Skip to content

sk-agents

sk_agents
sk_agents.a2a
sk_agents.a2a.redis_task_store

DEPRECATION NOTICE: A2A (Agent-to-Agent) functionality is being deprecated as part of the framework migration evaluation. This module is maintained for backward compatibility only. New development should avoid using A2A functionality.

Redis implementation of the TaskStore interface. This implementation uses Redis as the persistent store for Task objects.

sk_agents.a2a.redis_task_store.RedisTaskStore

Bases: TaskStore

Redis implementation of the TaskStore interface.

This class provides Redis-based persistence for Task objects.

Source code in src/sk_agents/a2a/redis_task_store.py
class RedisTaskStore(TaskStore):
    """TaskStore backed by Redis.

    Every Task is serialized to a JSON string and kept under the key
    ``f"{key_prefix}{task_id}"``.
    """

    def __init__(self, redis_client: Redis, ttl: int | None = None, key_prefix: str = "task:"):
        """Create a store that persists tasks through ``redis_client``.

        Args:
            redis_client: Redis client used for every read/write; its ``set``,
                ``get`` and ``delete`` methods are awaited, so it must be async.
            ttl: Optional expiry forwarded to Redis ``SET`` as ``ex`` (seconds).
                ``None`` writes keys without an expiry.
            key_prefix: String prepended to each task ID to build its Redis key
                (default: "task:").
        """
        self._key_prefix = key_prefix
        self._ttl = ttl
        self._redis = redis_client

    def _get_key(self, task_id: str) -> str:
        """Return the Redis key that holds the task identified by ``task_id``."""
        return "{}{}".format(self._key_prefix, task_id)

    async def save(self, task: Task):
        """Insert or overwrite ``task`` in Redis.

        The task is dumped in pydantic JSON mode, encoded with ``json.dumps`` and
        written under its key, applying the configured TTL.
        """
        payload = json.dumps(task.model_dump(mode="json"))
        await self._redis.set(self._get_key(task.id), payload, ex=self._ttl)

    async def get(self, task_id: str) -> Task | None:
        """Load the task stored for ``task_id``.

        Returns:
            The validated Task, or None when no value exists for the key.
        """
        raw = await self._redis.get(self._get_key(task_id))
        if raw is not None:
            return Task.model_validate(json.loads(raw))
        return None

    async def delete(self, task_id: str):
        """Remove the task stored for ``task_id`` from Redis."""
        key = self._get_key(task_id)
        await self._redis.delete(key)
sk_agents.a2a.redis_task_store.RedisTaskStore.__init__
__init__(
    redis_client: Redis,
    ttl: int | None = None,
    key_prefix: str = "task:",
)

Initialize the RedisTaskStore with a Redis client.

Parameters:

Name Type Description Default
redis_client Redis

An instance of Redis client

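required
ttl int | None

Optional expiry for stored task keys, forwarded to Redis SET as ex (seconds). None stores keys without an expiry.

None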
key_prefix str

Prefix used for Redis keys (default: "task:")

'task:'
Source code in src/sk_agents/a2a/redis_task_store.py
def __init__(self, redis_client: Redis, ttl: int | None = None, key_prefix: str = "task:"):
    """Create a store that persists tasks through ``redis_client``.

    Args:
        redis_client: Redis client used for every read/write; its ``set``,
            ``get`` and ``delete`` methods are awaited, so it must be async.
        ttl: Optional expiry forwarded to Redis ``SET`` as ``ex`` (seconds).
            ``None`` writes keys without an expiry.
        key_prefix: String prepended to each task ID to build its Redis key
            (default: "task:").
    """
    self._key_prefix = key_prefix
    self._ttl = ttl
    self._redis = redis_client
sk_agents.a2a.redis_task_store.RedisTaskStore.save async
save(task: Task)

Saves or updates a task in the Redis store.

Parameters:

Name Type Description Default
task Task

The Task object to save

required
Source code in src/sk_agents/a2a/redis_task_store.py
async def save(self, task: Task):
    """Insert or overwrite ``task`` in Redis.

    The task is dumped in pydantic JSON mode, encoded with ``json.dumps`` and
    written under its key, applying the configured TTL.

    Args:
        task: The Task to persist; ``task.id`` selects the Redis key.
    """
    payload = json.dumps(task.model_dump(mode="json"))
    await self._redis.set(self._get_key(task.id), payload, ex=self._ttl)
sk_agents.a2a.redis_task_store.RedisTaskStore.get async
get(task_id: str) -> Task | None

Retrieves a task from the Redis store by ID.

Parameters:

Name Type Description Default
task_id str

The ID of the task to retrieve

required

Returns:

Type Description
Task | None

The Task object if found, None otherwise

Source code in src/sk_agents/a2a/redis_task_store.py
async def get(self, task_id: str) -> Task | None:
    """Load the task stored for ``task_id``.

    Args:
        task_id: ID of the task to look up.

    Returns:
        The validated Task, or None when no value exists for the key.
    """
    raw = await self._redis.get(self._get_key(task_id))
    if raw is not None:
        return Task.model_validate(json.loads(raw))
    return None
sk_agents.a2a.redis_task_store.RedisTaskStore.delete async
delete(task_id: str)

Deletes a task from the Redis store by ID.

Parameters:

Name Type Description Default
task_id str

The ID of the task to delete

required
Source code in src/sk_agents/a2a/redis_task_store.py
async def delete(self, task_id: str):
    """Remove the task stored for ``task_id`` from Redis.

    Args:
        task_id: ID of the task whose key should be deleted.
    """
    key = self._get_key(task_id)
    await self._redis.delete(key)
sk_agents.a2a.response_classifier
sk_agents.a2a.response_classifier.A2AResponseClassifier

A class to classify responses from the A2A agent.

Source code in src/sk_agents/a2a/response_classifier.py
class A2AResponseClassifier:
    """
    A class to classify responses from the A2A agent.

    An LLM-backed ChatCompletionAgent, instructed by ``SYSTEM_PROMPT``, labels a
    response as ``completed``, ``failed``, ``input-required`` or ``auth-required``.
    """

    NAME = "a2a-response-classifier"
    SYSTEM_PROMPT = (
        "## System Prompt: Agent Output Classifier\n"
        "\n"
        "**You are an AI agent tasked with analyzing the output of another AI agent "
        '(referred to as the "Primary Agent") and classifying its status. Your output MUST '
        "be a JSON object.**\n"
        "\n"
        "Your goal is to determine which of the following categories best describes "
        "the Primary Agent's output and structure your response accordingly.\n"
        "\n"
        "**Possible Classification Statuses & JSON Output Structures:**\n"
        "\n"
        "1.  **Status: `completed`**\n"
        "    * The Primary Agent has successfully completed the assigned task or answered the "
        "user's query.\n"
        "    * **JSON Output Structure:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "completed"\n'
        "        }\n"
        "        ```\n"
        '    * Keywords/phrases to look for: "done," "completed," "finished," "success," '
        '"here is the result," "I have finished," "the task is complete," direct answers '
        "to questions, generated content that fulfills the request.\n"
        "    * Context: The output clearly indicates finality and achievement of the original "
        "goal.\n"
        "\n"
        "2.  **Status: `failed`**\n"
        "    * The Primary Agent has failed to complete the assigned task or answer the "
        "user's query.\n"
        "    * **JSON Output Structure:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "failed"\n'
        "        }\n"
        "        ```\n"
        '    * Keywords/phrases to look for: "failed," "unable to," "cannot complete," '
        '"error," "encountered a problem," "not possible," "I\'m sorry, I can\'t," '
        '"task aborted."\n'
        "    * Context: The output indicates an inability to proceed or a definitive negative "
        "outcome regarding the task. This includes technical errors, lack of capability, "
        "or hitting a dead end.\n"
        "\n"
        "3.  **Status: `input-required`**\n"
        "    * The Primary Agent requires additional information, clarification, or a decision "
        "from the user to continue or complete the task.\n"
        "    * **JSON Output Structure:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "input-required",\n'
        '          "message": "A description of what info is needed from the user and why."\n'
        "        }\n"
        "        ```\n"
        '    * Keywords/phrases to look for: "what do you mean by," "could you please '
        'specify," "which option do you prefer," "do you want to proceed," "please '
        'provide," "I need more information," questions directed at the user.\n'
        "    * Context: The output is a direct or indirect request for user interaction to "
        "resolve ambiguity, make a choice, or provide necessary data. The `message` field "
        "should summarize this request.\n"
        "\n"
        "4.  **Status: `auth-required`**\n"
        "    * The Primary Agent has indicated that it needs to perform some form of "
        "authentication (e.g., login, API key verification, permission grant) before it "
        "can proceed with the task.\n"
        "    * **JSON Output Structure:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "auth-required",\n'
        '          "message": "A description of what authentication is needed and why.",\n'
        '          "auth_details": {} // Likely a JSON structure extracted from the Primary '
        "Agent's output containing technical details about the auth request. Can be an "
        "empty object if no specific structure is found.\n"
        "        }\n"
        "        ```\n"
        '    * Keywords/phrases to look for: "please log in," "authentication required," '
        '"access denied," "invalid credentials," "API key needed," "sign in to continue," '
        '"verify your identity," "permissions needed."\n'
        "    * Context: The output explicitly states or strongly implies that a security or "
        "access barrier is preventing task progression. The `message` field should explain "
        "this. The `auth_details` field should attempt to capture any structured information "
        "(e.g., OAuth URLs, scopes needed, realm info) provided by the Primary Agent "
        "regarding the authentication. If the Primary Agent provides a JSON blob related to "
        "auth, try to pass that through in `auth_details`.\n"
        "\n"
        "**Your Analysis Process:**\n"
        "\n"
        "1.  **Carefully review the entire output from the Primary Agent.** Understand the "
        "context and the overall message.\n"
        "2.  **Look for explicit keywords and phrases** associated with each category.\n"
        "3.  **Consider the intent** behind the Primary Agent's message.\n"
        "4.  **Prioritize:**\n"
        "    * If authentication is mentioned as a blocker, classify as `auth-required`. "
        "Extract relevant details for the `message` and `auth_details` fields.\n"
        "    * If the agent is clearly asking the user a question to proceed (and it's not "
        "primarily an authentication request), classify as `input-required`. Formulate the "
        "`message` field.\n"
        "    * If the agent explicitly states success, classify as `completed`.\n"
        "    * If the agent explicitly states failure or an insurmountable error (not related "
        "to needing input or auth), classify as `failed`.\n"
        "5.  **Extract Information for `message` and `auth_details`:**\n"
        "    * For `input-required` and `auth-required`, the `message` should be a concise "
        "explanation derived from the Primary Agent's output.\n"
        "    * For `auth-required`, if the Primary Agent's output includes a structured "
        "(e.g., JSON) segment detailing the authentication requirements, attempt to extract "
        "and place this into the `auth_details` field. If no specific structure is found, "
        "`auth_details` can be an empty object `{}`. Do not invent details; only extract "
        "what is provided.\n"
        "6.  **If the output is ambiguous, try to infer the most likely category.** If truly "
        'unclear, you may need a default or "UNCLEAR" category (though this prompt focuses '
        "on the four defined). In such a case, defaulting to `failed` with an appropriate "
        "message might be a safe fallback if no other category fits.\n"
        "\n"
        "**Output Format:**\n"
        "\n"
        "Your output **MUST** be a single JSON object corresponding to one of the structures "
        "defined above.\n"
        "\n"
        "**Example Scenarios:**\n"
        "\n"
        "* **Primary Agent Output:** \"I've finished generating the report you asked for. "
        "It's attached below.\"\n"
        "    * **Your JSON Output:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "completed"\n'
        "        }\n"
        "        ```\n"
        "* **Primary Agent Output:** \"I'm sorry, I encountered an unexpected error and cannot "
        'process your request at this time. Error code: 503. Please try again later."\n'
        "    * **Your JSON Output:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "failed"\n'
        "        }\n"
        "        ```\n"
        '* **Primary Agent Output:** "To help you with that, could you please tell me which '
        'specific date range you are interested in for the sales data?"\n'
        "    * **Your JSON Output:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "input-required",\n'
        '          "message": "The agent needs to know the specific date range for the sales '
        'data to proceed."\n'
        "        }\n"
        "        ```\n"
        '* **Primary Agent Output:** "Access to this API endpoint requires authentication. '
        "Please provide a valid Bearer token. Details: {'type': 'Bearer', 'realm': "
        "'[api.example.com/auth](https://api.example.com/auth)'}\"\n"
        "    * **Your JSON Output:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "auth-required",\n'
        '          "message": "Access to the API endpoint requires a valid Bearer token.",\n'
        '          "auth_details": {\n'
        '            "type": "Bearer",\n'
        '            "realm": "[api.example.com/auth](https://api.example.com/auth)"\n'
        "          }\n"
        "        }\n"
        "        ```\n"
        '* **Primary Agent Output:** "You need to sign in to your account to access your '
        'profile. Click here to login."\n'
        "    * **Your JSON Output:**\n"
        "        ```json\n"
        "        {\n"
        '          "status": "auth-required",\n'
        '          "message": "User needs to sign in to their account to access their profile.",\n'
        '          "auth_details": {}\n'
        "        }\n"
        "        ```\n"
        "\n"
        "**Critical Considerations:**\n"
        "\n"
        "* Ensure your output is always valid JSON.\n"
        "* Be precise in your classification and in the information extracted for the `message` "
        "and `auth_details` fields.\n"
        "* Focus solely on the provided output from the Primary Agent.\n"
        "* Adhere to the prioritization logic.\n"
    )

    def __init__(self, app_config: AppConfig, chat_completion_builder: ChatCompletionBuilder):
        """
        Build the classifier agent.

        Args:
            app_config (AppConfig): Configuration from which the classifier model name
                is read (``TA_A2A_OUTPUT_CLASSIFIER_MODEL``).
            chat_completion_builder (ChatCompletionBuilder): Factory that creates the
                chat completion service for that model, registered under ``NAME``.
        """
        model_name = app_config.get(TA_A2A_OUTPUT_CLASSIFIER_MODEL.env_name)
        chat_completion = chat_completion_builder.get_chat_completion_for_model(
            service_id=self.NAME, model_name=model_name
        )
        kernel = Kernel()
        kernel.add_service(chat_completion)
        settings = kernel.get_prompt_execution_settings_from_service_id(self.NAME)
        # Request output shaped like A2AResponseClassification (structured output,
        # where the underlying service supports it).
        settings.response_format = A2AResponseClassification
        self.agent = ChatCompletionAgent(
            kernel=kernel,
            name=self.NAME,
            instructions=self.SYSTEM_PROMPT,
            arguments=KernelArguments(settings=settings),
        )

    async def classify_response(self, response: str) -> A2AResponseClassification:
        """
        Classify the response from the A2A agent.

        Only the first item produced by the classifier agent is used.

        Args:
            response (str): The response from the A2A agent.

        Returns:
            A2AResponseClassification: The classification of the response. A
            ``failed`` classification is returned when the classifier produces no
            output, or output that cannot be parsed into a classification.
        """
        chat_history = ChatHistory()
        chat_history.add_user_message(f"Please classify the following response:\n\n{response}")
        async for content in self.agent.invoke(messages=chat_history):
            try:
                data = json.loads(str(content.content))
                return A2AResponseClassification(**data)
            except (ValueError, TypeError) as exc:
                # json.JSONDecodeError is a ValueError, and a non-object payload makes
                # ``**data`` raise TypeError. Model validation errors are assumed to be
                # ValueError subclasses (true for pydantic v2) - TODO confirm.
                # Fall back to ``failed``, as SYSTEM_PROMPT itself recommends.
                return A2AResponseClassification(
                    status=A2AResponseStatus.failed,
                    message=f"Unable to parse response classifier output: {exc}",
                )
        return A2AResponseClassification(
            status=A2AResponseStatus.failed,
            message="No response received from response classifier.",
        )
sk_agents.a2a.response_classifier.A2AResponseClassifier.classify_response async
classify_response(
    response: str,
) -> A2AResponseClassification

Classify the response from the A2A agent.

Parameters:

Name Type Description Default
response str

The response from the A2A agent.

required

Returns:

Type Description
A2AResponseClassification

The classification of the response.

Source code in src/sk_agents/a2a/response_classifier.py
async def classify_response(self, response: str) -> A2AResponseClassification:
    """
    Classify the response from the A2A agent.

    Only the first item produced by the classifier agent is used.

    Args:
        response (str): The response from the A2A agent.

    Returns:
        A2AResponseClassification: The classification of the response. A
        ``failed`` classification is returned when the classifier produces no
        output, or output that cannot be parsed into a classification.
    """
    chat_history = ChatHistory()
    chat_history.add_user_message(f"Please classify the following response:\n\n{response}")
    async for content in self.agent.invoke(messages=chat_history):
        try:
            data = json.loads(str(content.content))
            return A2AResponseClassification(**data)
        except (ValueError, TypeError) as exc:
            # json.JSONDecodeError is a ValueError, and a non-object payload makes
            # ``**data`` raise TypeError. Model validation errors are assumed to be
            # ValueError subclasses (true for pydantic v2) - TODO confirm.
            # Fall back to ``failed``, as the classifier's system prompt recommends.
            return A2AResponseClassification(
                status=A2AResponseStatus.failed,
                message=f"Unable to parse response classifier output: {exc}",
            )
    return A2AResponseClassification(
        status=A2AResponseStatus.failed,
        message="No response received from response classifier.",
    )
sk_agents.appv3
class AppV3

@staticmethod def run(name, version, app_config, config, app): pass

sk_agents.auth

MCP OAuth 2.1 Authentication Components

This module provides OAuth 2.1 compliant authentication for MCP (Model Context Protocol) servers. All components follow the MCP specification (2025-06-18) for authorization: https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization

Key Features:

- PKCE (Proof Key for Code Exchange) for the authorization code flow
- Resource parameter binding for token audience validation
- State parameter for CSRF protection
- Token refresh with rotation
- Server metadata discovery (RFC 8414, RFC 9728)
- Dynamic client registration (RFC 7591; not yet implemented)

Architecture:

- This module is isolated from platform authentication (RequestAuthorizer).
- Platform auth: validates the user to the platform and returns a user_id.
- Service auth (MCP): manages OAuth tokens for external services on a per-user basis.

Components:

- oauth_client: main OAuth 2.1 client for authorization flows
- oauth_pkce: PKCE generation and validation
- oauth_models: request/response models for OAuth flows
- oauth_state_manager: state and PKCE verifier storage for OAuth flows
- server_metadata: authorization server metadata discovery (RFC 8414, RFC 9728)
- client_registration: dynamic client registration (RFC 7591)

sk_agents.auth.client_registration

Dynamic Client Registration

Implements dynamic client registration per RFC7591. Allows automatic OAuth client registration with authorization servers.

References:

- RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol

sk_agents.auth.client_registration.ClientRegistrationRequest

Bases: BaseModel

OAuth 2.0 Dynamic Client Registration Request (RFC7591)

Source code in src/sk_agents/auth/client_registration.py
class ClientRegistrationRequest(BaseModel):
    """
    OAuth 2.0 Dynamic Client Registration Request (RFC7591)

    Client metadata sent to an authorization server's registration endpoint.
    Only ``redirect_uris`` is required; every other field has a default.
    """

    # Human-readable name of the client (RFC 7591 ``client_name``).
    client_name: str = "teal-agents"
    # Redirection URIs the client will use in authorization requests.
    redirect_uris: list[HttpUrl]
    # Grant types the client will use; only authorization_code and refresh_token
    # are accepted, and both are requested by default.
    grant_types: list[Literal["authorization_code", "refresh_token"]] = [
        "authorization_code",
        "refresh_token",
    ]
    # Token endpoint authentication method. Per RFC 7591, "none" registers a
    # public client that has no client secret.
    token_endpoint_auth_method: Literal["none", "client_secret_basic", "client_secret_post"] = (
        "none"
    )
    # Authorization endpoint response types; "code" matches the authorization code flow.
    response_types: list[str] = ["code"]
    # Optional scope string (RFC 7591: space-separated scope values).
    scope: str | None = None
sk_agents.auth.client_registration.ClientRegistrationResponse

Bases: BaseModel

OAuth 2.0 Dynamic Client Registration Response (RFC7591)

Source code in src/sk_agents/auth/client_registration.py
class ClientRegistrationResponse(BaseModel):
    """
    OAuth 2.0 Dynamic Client Registration Response (RFC7591)

    Credentials and metadata returned by the authorization server for a
    registration request. Only ``client_id`` is required.
    """

    # Identifier issued to the registered client.
    client_id: str
    # Client secret, if the server issued one.
    client_secret: str | None = None
    # Issue time of client_id (RFC 7591: seconds since 1970-01-01T00:00:00Z).
    client_id_issued_at: int | None = None
    # Expiry of client_secret (RFC 7591: seconds since the epoch, or 0 if it never expires).
    client_secret_expires_at: int | None = None
    # Token for later management of this registration (RFC 7592), if provided.
    registration_access_token: str | None = None
    # Client configuration endpoint URI (RFC 7592), if provided.
    registration_client_uri: HttpUrl | None = None
sk_agents.auth.client_registration.DynamicClientRegistration

Dynamic Client Registration Client.

Intended to handle automatic OAuth client registration. Not yet implemented: both methods currently raise NotImplementedError (planned for the optional Phase 3).

Source code in src/sk_agents/auth/client_registration.py
class DynamicClientRegistration:
    """
    Client for OAuth 2.0 Dynamic Client Registration (RFC7591).

    Placeholder for the optional Phase 3 work: the constructor only records the
    timeout, and both registration methods currently raise NotImplementedError.
    """

    def __init__(self, timeout: float = 30.0):
        # Request timeout (presumably seconds) reserved for the future HTTP calls.
        self.timeout = timeout

    async def register_client(
        self,
        registration_endpoint: str,
        request: ClientRegistrationRequest,
    ) -> ClientRegistrationResponse:
        """
        Register an OAuth client with an authorization server.

        Not implemented yet (planned for Phase 3).

        Args:
            registration_endpoint: Registration endpoint URL from server metadata
            request: Client registration request

        Returns:
            ClientRegistrationResponse: Registered client credentials (once implemented)

        Raises:
            NotImplementedError: Always, until the Phase 3 implementation lands.
                The planned implementation is meant to report registration
                failures as httpx.HTTPError.
        """
        message = "Dynamic client registration not yet implemented (Phase 3)"
        raise NotImplementedError(message)

    async def update_client(
        self,
        registration_client_uri: str,
        registration_access_token: str,
        request: ClientRegistrationRequest,
    ) -> ClientRegistrationResponse:
        """
        Update the configuration of a previously registered OAuth client.

        Not implemented yet (planned for Phase 3).

        Args:
            registration_client_uri: Client configuration URI
            registration_access_token: Access token for client management
            request: Updated client configuration

        Returns:
            ClientRegistrationResponse: Updated client credentials (once implemented)

        Raises:
            NotImplementedError: Always, until the Phase 3 implementation lands.
        """
        message = "Dynamic client update not yet implemented (Phase 3)"
        raise NotImplementedError(message)
sk_agents.auth.client_registration.DynamicClientRegistration.register_client async
register_client(
    registration_endpoint: str,
    request: ClientRegistrationRequest,
) -> ClientRegistrationResponse

Register OAuth client with authorization server.

Parameters:

Name Type Description Default
registration_endpoint str

Registration endpoint URL from server metadata

required
request ClientRegistrationRequest

Client registration request

required

Returns:

Name Type Description
ClientRegistrationResponse ClientRegistrationResponse

Registered client credentials

Raises:

Type Description
HTTPError

If registration fails (planned behavior; the current stub always raises NotImplementedError)

Source code in src/sk_agents/auth/client_registration.py
async def register_client(
    self,
    registration_endpoint: str,
    request: ClientRegistrationRequest,
) -> ClientRegistrationResponse:
    """
    Register an OAuth client with an authorization server.

    Not implemented yet (planned for Phase 3).

    Args:
        registration_endpoint: Registration endpoint URL from server metadata
        request: Client registration request

    Returns:
        ClientRegistrationResponse: Registered client credentials (once implemented)

    Raises:
        NotImplementedError: Always, until the Phase 3 implementation lands.
            The planned implementation is meant to report registration
            failures as httpx.HTTPError.
    """
    message = "Dynamic client registration not yet implemented (Phase 3)"
    raise NotImplementedError(message)
sk_agents.auth.client_registration.DynamicClientRegistration.update_client async
update_client(
    registration_client_uri: str,
    registration_access_token: str,
    request: ClientRegistrationRequest,
) -> ClientRegistrationResponse

Update registered OAuth client configuration.

Parameters:

Name Type Description Default
registration_client_uri str

Client configuration URI

required
registration_access_token str

Access token for client management

required
request ClientRegistrationRequest

Updated client configuration

required

Returns:

Name Type Description
ClientRegistrationResponse ClientRegistrationResponse

Updated client credentials

Source code in src/sk_agents/auth/client_registration.py
async def update_client(
    self,
    registration_client_uri: str,
    registration_access_token: str,
    request: ClientRegistrationRequest,
) -> ClientRegistrationResponse:
    """
    Update the configuration of a previously registered OAuth client.

    Not implemented yet (planned for Phase 3).

    Args:
        registration_client_uri: Client configuration URI
        registration_access_token: Access token for client management
        request: Updated client configuration

    Returns:
        ClientRegistrationResponse: Updated client credentials (once implemented)

    Raises:
        NotImplementedError: Always, until the Phase 3 implementation lands.
    """
    message = "Dynamic client update not yet implemented (Phase 3)"
    raise NotImplementedError(message)
sk_agents.auth.oauth_client

OAuth 2.1 Client Implementation

Main OAuth client for handling authorization code flow with PKCE. Implements MCP specification requirements for OAuth authorization.

Key Features:

- Authorization URL generation with PKCE and the resource parameter
- Authorization code exchange for an access token
- Token refresh with rotation
- Resource-bound token acquisition

References:

- MCP Specification 2025-06-18
- OAuth 2.1 Draft
- RFC 8707 (Resource Indicators)

sk_agents.auth.oauth_client.OAuthClient

OAuth 2.1 Client for MCP Server Authentication.

Handles complete OAuth authorization code flow with PKCE and resource binding.

Source code in src/sk_agents/auth/oauth_client.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
class OAuthClient:
    """
    OAuth 2.1 Client for MCP Server Authentication.

    Handles complete OAuth authorization code flow with PKCE and resource binding.
    """

    def __init__(self, timeout: float = 30.0):
        """
        Initialize OAuth client.

        Args:
            timeout: HTTP request timeout in seconds (also passed to the metadata cache)
        """
        self.timeout = timeout
        self.pkce_manager = PKCEManager()
        self.state_manager = OAuthStateManager()
        self.metadata_cache = ServerMetadataCache(timeout=timeout)
        self.auth_storage_factory = AuthStorageFactory(AppConfig())
        self.auth_storage = self.auth_storage_factory.get_auth_storage_manager()

    @staticmethod
    def should_include_resource_param(
        protocol_version: str | None = None, has_prm: bool = False
    ) -> bool:
        """
        Determine if resource parameter should be included in OAuth requests.

        Per MCP specification 2025-06-18:
        - resource parameter MUST be included if protocol version >= 2025-06-18
        - resource parameter MUST be included if Protected Resource Metadata discovered
        - Otherwise, resource parameter SHOULD be omitted for backward compatibility

        Args:
            protocol_version: MCP protocol version (e.g., "2025-06-18")
            has_prm: Whether Protected Resource Metadata has been discovered

        Returns:
            bool: True if resource parameter should be included
        """
        # Discovered Protected Resource Metadata always requires resource binding
        if has_prm:
            return True

        # No protocol version: omit resource param (backward compat)
        if not protocol_version:
            return False

        # MCP revisions are ISO dates (YYYY-MM-DD), so lexicographic comparison
        # matches chronological order. 2025-06-18 is the revision that made
        # RFC 8707 resource indicators mandatory.
        try:
            return protocol_version >= "2025-06-18"
        except TypeError:
            # Non-string version value: be conservative and include resource param
            logger.warning(f"Failed to compare protocol version: {protocol_version}")
            return True

    @staticmethod
    def validate_token_scopes(
        requested_scopes: list[str] | None, token_response: "TokenResponse"
    ) -> None:
        """
        Validate that returned scopes don't exceed requested scopes (prevents escalation attacks).

        RFC 6749 Section 3.3 lets the authorization server issue a scope that differs
        from the one requested, in which case it must return the actual scope in the
        ``scope`` response field. This client accepts a narrower grant (logged as a
        warning). It rejects any scope that was not requested, treating it as a scope
        escalation.

        Args:
            requested_scopes: Scopes requested in authorization request
            token_response: Token response from authorization server

        Raises:
            ValueError: If scope escalation detected (returned > requested)
        """

        # If no scopes were requested, any returned scopes are acceptable
        if not requested_scopes:
            return

        # RFC 6749 Section 5.1: ``scope`` may be omitted when identical to the
        # requested scope, so treat a missing field as "all requested scopes granted"
        if not token_response.scope:
            logger.debug(
                "Token response contains no scope field - assuming all requested scopes granted"
            )
            return

        # Parse returned scopes (space-separated string)
        requested = set(requested_scopes)
        returned = set(token_response.scope.split())

        # Check for scope escalation: returned scopes must be subset of requested
        unauthorized_scopes = returned - requested

        if unauthorized_scopes:
            logger.error(
                f"Scope escalation attack detected! "
                f"Requested: {requested}, Returned: {returned}, Unauthorized: {unauthorized_scopes}"
            )
            raise ValueError(
                f"Server granted unauthorized scopes: {unauthorized_scopes}. "
                f"This is a scope escalation attack. Requested: {requested}, Returned: {returned}"
            )

        # Log scope reduction (informational - not an error)
        missing_scopes = requested - returned
        if missing_scopes:
            logger.warning(
                f"Server granted fewer scopes than requested. "
                f"Requested: {requested}, Granted: {returned}, Missing: {missing_scopes}"
            )
        else:
            logger.debug(f"Scope validation passed. Granted scopes: {returned}")

    def build_authorization_url(self, request: AuthorizationRequest) -> str:
        """
        Build complete OAuth authorization URL.

        Constructs URL with all required parameters:
        - response_type=code
        - client_id, redirect_uri
        - resource (canonical MCP server URI) - only if protocol version >= 2025-06-18
        - code_challenge, code_challenge_method=S256
        - scope, state

        If the base endpoint already carries a query component, it is preserved and
        the parameters are appended with ``&`` (RFC 6749 Section 3.1).

        Args:
            request: Authorization request parameters

        Returns:
            str: Complete authorization URL for user redirect
        """
        params = {
            "response_type": request.response_type,
            "client_id": request.client_id,
            "redirect_uri": str(request.redirect_uri),
            # Tolerate a missing scope list instead of failing inside str.join
            "scope": " ".join(request.scopes or []),
            "state": request.state,
            "code_challenge": request.code_challenge,
            "code_challenge_method": request.code_challenge_method,
        }

        # Conditionally include resource parameter per MCP spec 2025-06-18
        if request.resource:
            params["resource"] = request.resource

        # Build URL - use discovered authorization_endpoint if available
        if request.authorization_endpoint:
            base_url = str(request.authorization_endpoint)
            logger.debug(f"Using discovered authorization endpoint: {base_url}")
        else:
            # Fallback: construct from auth_server
            base_url = str(request.auth_server).rstrip("/")
            if not base_url.endswith("/authorize"):
                base_url = f"{base_url}/authorize"
            logger.debug(f"Using fallback authorization endpoint: {base_url}")

        # Keep any existing query component instead of producing a second "?"
        if "?" not in base_url:
            separator = "?"
        elif base_url.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"

        auth_url = f"{base_url}{separator}{urlencode(params)}"
        logger.debug(f"Built authorization URL for resource={request.resource}")
        return auth_url

    @staticmethod
    def _parse_error_body(response: httpx.Response) -> dict:
        """Return the JSON error object of a failed token response, or {} if unavailable."""
        # Media type may carry parameters, e.g. "application/json; charset=utf-8"
        content_type = response.headers.get("content-type", "")
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type != "application/json":
            return {}
        try:
            data = response.json()
        except ValueError:
            # Body advertised as JSON but not decodable
            return {}
        return data if isinstance(data, dict) else {}

    async def exchange_code_for_tokens(self, token_request: TokenRequest) -> TokenResponse:
        """
        Exchange authorization code for access token.

        Makes POST request to token endpoint with:
        - grant_type=authorization_code
        - code, redirect_uri
        - code_verifier (PKCE)
        - resource (canonical URI) - only if protocol version >= 2025-06-18
        - client_id (+ client_secret if confidential)

        Also used for the refresh_token grant (see refresh_access_token).

        Args:
            token_request: Token request parameters

        Returns:
            TokenResponse: Access token and metadata

        Raises:
            httpx.HTTPError: If token request fails
            ValueError: If required grant parameters are missing, the response is
                invalid, or scope escalation is detected
        """
        # Build request body
        body = {
            "grant_type": token_request.grant_type,
            "client_id": token_request.client_id,
        }

        # Conditionally include resource parameter per MCP spec 2025-06-18
        if token_request.resource:
            body["resource"] = token_request.resource

        # Add grant-specific parameters
        if token_request.grant_type == "authorization_code":
            if (
                not token_request.code
                or not token_request.redirect_uri
                or not token_request.code_verifier
            ):
                raise ValueError("Missing required parameters for authorization_code grant")
            body["code"] = token_request.code
            body["redirect_uri"] = str(token_request.redirect_uri)
            body["code_verifier"] = token_request.code_verifier
        elif token_request.grant_type == "refresh_token":
            if not token_request.refresh_token:
                raise ValueError("Missing refresh_token for refresh_token grant")
            body["refresh_token"] = token_request.refresh_token

        # Add client secret if provided (confidential client)
        if token_request.client_secret:
            body["client_secret"] = token_request.client_secret

        logger.debug(
            f"Exchanging code for tokens: endpoint={token_request.token_endpoint}, "
            f"grant_type={token_request.grant_type}"
        )

        # Make token request
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                str(token_request.token_endpoint),
                data=body,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

            if response.status_code != 200:
                error_data = self._parse_error_body(response)
                logger.error(
                    f"Token request failed: status={response.status_code}, error={error_data}"
                )
                raise httpx.HTTPError(
                    f"Token request failed: {error_data.get('error', 'unknown_error')}"
                )

            # Parse response
            token_data = response.json()
            token_response = TokenResponse(**token_data)

            # Validate scopes to prevent escalation attacks
            self.validate_token_scopes(token_request.requested_scopes, token_response)

            logger.info(f"Successfully obtained access token for resource={token_request.resource}")
            return token_response

    async def refresh_access_token(self, refresh_request: RefreshTokenRequest) -> TokenResponse:
        """
        Refresh expired access token.

        Makes POST request to token endpoint with:
        - grant_type=refresh_token
        - refresh_token
        - resource (must match original)
        - client_id

        Implements token rotation per OAuth 2.1.

        Args:
            refresh_request: Refresh token request parameters

        Returns:
            TokenResponse: New access token (and possibly new refresh token)

        Raises:
            httpx.HTTPError: If refresh fails
        """
        token_request = TokenRequest(
            token_endpoint=refresh_request.token_endpoint,
            grant_type="refresh_token",
            refresh_token=refresh_request.refresh_token,
            resource=refresh_request.resource,
            client_id=refresh_request.client_id,
            client_secret=refresh_request.client_secret,
            requested_scopes=refresh_request.requested_scopes,
        )

        logger.debug(f"Refreshing access token for resource={refresh_request.resource}")
        return await self.exchange_code_for_tokens(token_request)

    async def revoke_token(
        self,
        token: str,
        revocation_endpoint: str,
        client_id: str,
        client_secret: str | None = None,
        token_type_hint: str = "access_token",
    ) -> None:
        """
        Revoke an access or refresh token per RFC 7009.

        This allows clients to notify the authorization server that a token
        is no longer needed, enabling immediate invalidation.

        Args:
            token: The token to revoke (access or refresh token)
            revocation_endpoint: Token revocation endpoint URL
            client_id: OAuth client ID
            client_secret: OAuth client secret (for confidential clients)
            token_type_hint: Hint about token type ("access_token" or "refresh_token")

        Raises:
            httpx.HTTPError: If revocation request fails
        """
        # Build request body per RFC 7009
        body = {
            "token": token,
            "token_type_hint": token_type_hint,
            "client_id": client_id,
        }

        # Add client secret if provided (confidential clients)
        if client_secret:
            body["client_secret"] = client_secret

        logger.debug(f"Revoking token: endpoint={revocation_endpoint}, type_hint={token_type_hint}")

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(
                    revocation_endpoint,
                    data=body,
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                )

                # Per RFC 7009: Server responds with 200 regardless of token validity
                # This prevents token scanning attacks
                if response.status_code == 200:
                    logger.info(f"Successfully revoked token (type_hint={token_type_hint})")
                else:
                    logger.warning(
                        f"Token revocation returned unexpected status {response.status_code}"
                    )
                    response.raise_for_status()

        except httpx.HTTPError as e:
            logger.error(f"Failed to revoke token: {e}")
            raise

    async def _has_protected_resource_metadata(self, server_config: "McpServerConfig") -> bool:
        """Best-effort RFC 9728 discovery; True only if PRM was found for an HTTP server."""
        # Only HTTP MCP servers (those with a URL) can expose PRM
        if not server_config.url:
            return False
        try:
            prm = await self.metadata_cache.fetch_protected_resource_metadata(server_config.url)
        except Exception as e:
            # Discovery is optional; any failure simply means "no PRM"
            logger.debug(f"PRM discovery failed (optional): {e}")
            return False
        if prm:
            logger.info(
                f"Discovered PRM for {server_config.name}: "
                f"auth_servers={prm.authorization_servers}"
            )
        return prm is not None

    async def _resolve_token_endpoint(self, auth_server: str) -> str:
        """Return the RFC 8414 token endpoint for auth_server, else "{auth_server}/token"."""
        fallback = f"{auth_server.rstrip('/')}/token"
        try:
            metadata = await self.metadata_cache.fetch_auth_server_metadata(auth_server)
        except Exception as e:
            logger.warning(
                f"Failed to discover authorization server metadata: {e}. "
                f"Using fallback token endpoint: {fallback}"
            )
            return fallback

        discovered = getattr(metadata, "token_endpoint", None)
        if not discovered:
            logger.debug(f"No token_endpoint in server metadata; using fallback: {fallback}")
            return fallback

        logger.debug(f"Using discovered token endpoint: {discovered}")
        return str(discovered)

    async def initiate_authorization_flow(
        self,
        server_config: "McpServerConfig",
        user_id: str,
    ) -> str:
        """
        Initiate OAuth authorization flow for MCP server.

        Generates PKCE pair, state, stores flow state, and returns authorization URL.

        Args:
            server_config: MCP server configuration
            user_id: User ID initiating the flow

        Returns:
            str: Authorization URL for user redirect

        Raises:
            ValueError: If server configuration is invalid
        """
        from ska_utils import AppConfig

        from sk_agents.configs import TA_OAUTH_CLIENT_NAME

        # Discover Protected Resource Metadata (RFC 9728) if HTTP MCP server
        has_prm = await self._has_protected_resource_metadata(server_config)

        # Determine if resource parameter should be included (per MCP spec 2025-06-18)
        include_resource = self.should_include_resource_param(
            protocol_version=server_config.protocol_version, has_prm=has_prm
        )

        # Get canonical resource URI if needed
        resource = None
        if include_resource:
            try:
                resource = server_config.effective_canonical_uri
            except ValueError as e:
                logger.warning(
                    f"Cannot determine canonical URI for {server_config.name}: {e}. "
                    "Proceeding without resource parameter."
                )
                resource = None

        # Generate PKCE pair
        verifier, challenge = self.pkce_manager.generate_pkce_pair()

        # Generate state
        state = self.state_manager.generate_state()

        # Store flow state (always store resource for validation, even if not sent in auth request)
        self.state_manager.store_flow_state(
            state=state,
            verifier=verifier,
            user_id=user_id,
            server_name=server_config.name,
            resource=resource or server_config.url or "",  # Store for future reference
            scopes=server_config.scopes,
        )

        # Get client configuration
        app_config = AppConfig()
        client_name = app_config.get(TA_OAUTH_CLIENT_NAME.env_name)

        # Discover authorization server metadata (RFC 8414)
        authorization_endpoint = None
        metadata = None
        try:
            metadata = await self.metadata_cache.fetch_auth_server_metadata(
                server_config.auth_server
            )
            authorization_endpoint = str(metadata.authorization_endpoint)
            logger.info(f"Discovered authorization endpoint: {authorization_endpoint}")
        except Exception as e:
            logger.warning(
                f"Failed to discover authorization server metadata: {e}. Using fallback."
            )
            authorization_endpoint = None

        # Try dynamic client registration if no client_id configured (RFC 7591)
        client_id = server_config.oauth_client_id or client_name

        if not server_config.oauth_client_id and server_config.enable_dynamic_registration:
            try:
                # Check if metadata includes registration_endpoint
                if metadata and metadata.registration_endpoint:
                    logger.info(
                        f"No client_id configured for {server_config.name}. "
                        f"Attempting dynamic registration..."
                    )

                    from sk_agents.auth.client_registration import DynamicClientRegistration

                    registration_client = DynamicClientRegistration(timeout=self.timeout)
                    registration_response = await registration_client.register_client(
                        registration_endpoint=str(metadata.registration_endpoint),
                        redirect_uris=[str(server_config.oauth_redirect_uri)],
                        client_name=client_name,
                        scopes=server_config.scopes,
                    )

                    # Use registered credentials
                    client_id = registration_response.client_id
                    # Note: client_secret available in registration_response if needed

                    logger.info(
                        f"Successfully registered client for {server_config.name}: "
                        f"client_id={client_id}"
                    )

                    # TODO: Optionally persist client_id/secret for reuse
                    # NOTE(review): handle_callback exchanges the code with
                    # `server_config.oauth_client_id or client_name`, not this
                    # registered client_id, because the registered id is neither
                    # persisted nor stored in the flow state. Servers that bind the
                    # code to client_id will reject that exchange. Verify and persist.

                else:
                    logger.warning(
                        f"Dynamic registration enabled but no registration_endpoint "
                        f"discovered for {server_config.name}"
                    )
            except Exception as e:
                logger.warning(
                    f"Dynamic client registration failed for {server_config.name}: {e}. "
                    f"Falling back to default client_id."
                )
                # Continue with default client_name

        # Build authorization request
        auth_request = AuthorizationRequest(
            auth_server=server_config.auth_server,
            authorization_endpoint=authorization_endpoint,
            client_id=client_id,
            redirect_uri=server_config.oauth_redirect_uri,
            resource=resource,  # None if protocol version < 2025-06-18
            scopes=server_config.scopes,
            state=state,
            code_challenge=challenge,
            code_challenge_method="S256",
        )

        # Build and return authorization URL
        auth_url = self.build_authorization_url(auth_request)
        logger.info(
            f"Initiated OAuth flow for {server_config.name}: "
            f"user={user_id}, resource={resource}, state={state}"
        )
        return auth_url

    async def handle_callback(
        self,
        code: str,
        state: str,
        user_id: str,
        server_config: "McpServerConfig",
    ) -> OAuth2AuthData:
        """
        Handle OAuth callback after user authorization.

        Validates state, exchanges code for tokens, and stores in AuthStorage.
        The token endpoint comes from RFC 8414 server metadata when available,
        otherwise "{auth_server}/token".

        Args:
            code: Authorization code from callback
            state: State parameter from callback
            user_id: User ID to validate against
            server_config: MCP server configuration

        Returns:
            OAuth2AuthData: Stored token data

        Raises:
            ValueError: If state invalid or user mismatch
            httpx.HTTPError: If token exchange fails
        """
        from datetime import datetime, timedelta

        from ska_utils import AppConfig

        from sk_agents.configs import TA_OAUTH_CLIENT_NAME
        from sk_agents.mcp_client import build_auth_storage_key

        # Retrieve and validate flow state
        flow_state = self.state_manager.retrieve_flow_state(state, user_id)

        # Get token endpoint (from server metadata or construct from auth_server)
        token_endpoint = await self._resolve_token_endpoint(server_config.auth_server)

        # Get client configuration
        app_config = AppConfig()
        client_name = app_config.get(TA_OAUTH_CLIENT_NAME.env_name)

        # Discover Protected Resource Metadata (RFC 9728) if HTTP MCP server
        has_prm = await self._has_protected_resource_metadata(server_config)

        # Determine if resource parameter should be included (per MCP spec 2025-06-18)
        include_resource = self.should_include_resource_param(
            protocol_version=server_config.protocol_version, has_prm=has_prm
        )

        # Build token request
        token_request = TokenRequest(
            token_endpoint=token_endpoint,
            grant_type="authorization_code",
            code=code,
            redirect_uri=server_config.oauth_redirect_uri,
            code_verifier=flow_state.verifier,
            # Conditional per protocol version
            resource=flow_state.resource if include_resource else None,
            client_id=server_config.oauth_client_id or client_name,
            client_secret=server_config.oauth_client_secret,
            requested_scopes=flow_state.scopes,  # For scope validation
        )

        # Exchange code for tokens
        token_response = await self.exchange_code_for_tokens(token_request)

        # Use a single timestamp so issued_at and expires_at share the same base
        now = datetime.now(UTC)
        oauth_data = OAuth2AuthData(
            access_token=token_response.access_token,
            refresh_token=token_response.refresh_token,
            expires_at=now + timedelta(seconds=token_response.expires_in),
            scopes=token_response.scope.split() if token_response.scope else flow_state.scopes,
            audience=token_response.aud,
            resource=flow_state.resource,
            token_type=token_response.token_type,
            issued_at=now,
        )

        # Store in AuthStorage
        composite_key = build_auth_storage_key(server_config.auth_server, oauth_data.scopes)
        self.auth_storage.store(user_id, composite_key, oauth_data)

        logger.info(
            f"OAuth callback successful for {flow_state.server_name}: "
            f"user={user_id}, resource={flow_state.resource}"
        )

        # Clean up flow state
        self.state_manager.delete_flow_state(state, user_id)

        return oauth_data
sk_agents.auth.oauth_client.OAuthClient.__init__
__init__(timeout: float = 30.0)

Initialize OAuth client.

Parameters:

Name Type Description Default
timeout float

HTTP request timeout in seconds

30.0
Source code in src/sk_agents/auth/oauth_client.py
def __init__(self, timeout: float = 30.0):
    """
    Set up the OAuth client and the collaborators it delegates to.

    Args:
        timeout: HTTP request timeout in seconds; also handed to the metadata cache
    """
    self.timeout = timeout
    # PKCE generation and state bookkeeping for the authorization-code flow
    self.pkce_manager, self.state_manager = PKCEManager(), OAuthStateManager()
    self.metadata_cache = ServerMetadataCache(timeout=timeout)
    # Token persistence backend chosen by the factory from application config
    storage_factory = AuthStorageFactory(AppConfig())
    self.auth_storage_factory = storage_factory
    self.auth_storage = storage_factory.get_auth_storage_manager()
sk_agents.auth.oauth_client.OAuthClient.should_include_resource_param staticmethod
should_include_resource_param(
    protocol_version: str | None = None,
    has_prm: bool = False,
) -> bool

Determine if resource parameter should be included in OAuth requests.

Per MCP specification 2025-06-18, the resource parameter MUST be included if the protocol version is >= 2025-06-18. It MUST also be included if Protected Resource Metadata was discovered. Otherwise, it SHOULD be omitted for backward compatibility.

Parameters:

Name Type Description Default
protocol_version str | None

MCP protocol version (e.g., "2025-06-18")

None
has_prm bool

Whether Protected Resource Metadata has been discovered

False

Returns:

Name Type Description
bool bool

True if resource parameter should be included

Source code in src/sk_agents/auth/oauth_client.py
@staticmethod
def should_include_resource_param(
    protocol_version: str | None = None, has_prm: bool = False
) -> bool:
    """
    Determine if resource parameter should be included in OAuth requests.

    Per MCP specification 2025-06-18:
    - resource parameter MUST be included if protocol version >= 2025-06-18
    - resource parameter MUST be included if Protected Resource Metadata discovered
    - Otherwise, resource parameter SHOULD be omitted for backward compatibility

    Args:
        protocol_version: MCP protocol version (e.g., "2025-06-18")
        has_prm: Whether Protected Resource Metadata has been discovered

    Returns:
        bool: True if resource parameter should be included
    """
    # If we have Protected Resource Metadata, always include resource param
    if has_prm:
        return True

    # If no protocol version provided, don't include resource param (backward compat)
    if not protocol_version:
        return False

    # Check if protocol version is 2025-11-25 or later
    # Simple string comparison works for ISO date format (YYYY-MM-DD)
    try:
        return protocol_version >= "2025-11-25"
    except Exception:
        # If comparison fails, be conservative and include resource param
        logger.warning(f"Failed to compare protocol version: {protocol_version}")
        return True
sk_agents.auth.oauth_client.OAuthClient.validate_token_scopes staticmethod
validate_token_scopes(
    requested_scopes: list[str] | None,
    token_response: TokenResponse,
) -> None

Validate that returned scopes don't exceed requested scopes (prevents escalation attacks).

Per OAuth 2.1 Section 3.3, if scopes were requested, the returned scopes MUST be a subset of the requested scopes. Servers MUST NOT grant scopes the client did not request. Enforcing this prevents scope escalation attacks.

Parameters:

Name Type Description Default
requested_scopes list[str] | None

Scopes requested in authorization request

required
token_response TokenResponse

Token response from authorization server

required

Raises:

Type Description
ValueError

If scope escalation detected (returned > requested)

Source code in src/sk_agents/auth/oauth_client.py
@staticmethod
def validate_token_scopes(
    requested_scopes: list[str] | None, token_response: "TokenResponse"
) -> None:
    """
    Validate that returned scopes don't exceed requested scopes (prevents escalation attacks).

    Per OAuth 2.1 Section 3.3:
    - If scopes were requested, returned scopes MUST be a subset of requested scopes
    - Servers MUST NOT grant scopes not requested by the client
    - This prevents scope escalation attacks

    Args:
        requested_scopes: Scopes requested in authorization request
        token_response: Token response from authorization server

    Raises:
        ValueError: If scope escalation detected (returned > requested)
    """

    # If no scopes were requested, any returned scopes are acceptable
    if not requested_scopes:
        return

    # If server didn't return scope field, assume it granted all requested scopes
    # Per OAuth 2.1: "If omitted, authorization server defaults to all requested scopes"
    if not token_response.scope:
        logger.debug(
            "Token response contains no scope field - assuming all requested scopes granted"
        )
        return

    # Parse returned scopes (space-separated string)
    requested = set(requested_scopes)
    returned = set(token_response.scope.split())

    # Check for scope escalation: returned scopes must be subset of requested
    unauthorized_scopes = returned - requested

    if unauthorized_scopes:
        logger.error(
            f"Scope escalation attack detected! "
            f"Requested: {requested}, Returned: {returned}, Unauthorized: {unauthorized_scopes}"
        )
        raise ValueError(
            f"Server granted unauthorized scopes: {unauthorized_scopes}. "
            f"This is a scope escalation attack. Requested: {requested}, Returned: {returned}"
        )

    # Log scope reduction (informational - not an error)
    missing_scopes = requested - returned
    if missing_scopes:
        logger.warning(
            f"Server granted fewer scopes than requested. "
            f"Requested: {requested}, Granted: {returned}, Missing: {missing_scopes}"
        )
    else:
        logger.debug(f"Scope validation passed. Granted scopes: {returned}")
sk_agents.auth.oauth_client.OAuthClient.build_authorization_url
build_authorization_url(
    request: AuthorizationRequest,
) -> str

Build complete OAuth authorization URL.

Constructs the URL with all required parameters: response_type=code; client_id and redirect_uri; resource (the canonical MCP server URI), only if the protocol version is >= 2025-06-18; code_challenge and code_challenge_method=S256; and scope and state.

Parameters:

Name Type Description Default
request AuthorizationRequest

Authorization request parameters

required

Returns:

Name Type Description
str str

Complete authorization URL for user redirect

Source code in src/sk_agents/auth/oauth_client.py
def build_authorization_url(self, request: "AuthorizationRequest") -> str:
    """
    Build complete OAuth authorization URL.

    Constructs URL with all required parameters:
    - response_type=code
    - client_id, redirect_uri
    - resource (canonical MCP server URI) - only if protocol version >= 2025-06-18
    - code_challenge, code_challenge_method=S256
    - scope, state

    If the base endpoint already carries a query component, it is preserved and
    the parameters are appended with ``&`` (RFC 6749 Section 3.1).

    Args:
        request: Authorization request parameters

    Returns:
        str: Complete authorization URL for user redirect
    """
    params = {
        "response_type": request.response_type,
        "client_id": request.client_id,
        "redirect_uri": str(request.redirect_uri),
        # Tolerate a missing scope list instead of failing inside str.join
        "scope": " ".join(request.scopes or []),
        "state": request.state,
        "code_challenge": request.code_challenge,
        "code_challenge_method": request.code_challenge_method,
    }

    # Conditionally include resource parameter per MCP spec 2025-06-18
    if request.resource:
        params["resource"] = request.resource

    # Build URL - use discovered authorization_endpoint if available
    if request.authorization_endpoint:
        base_url = str(request.authorization_endpoint)
        logger.debug(f"Using discovered authorization endpoint: {base_url}")
    else:
        # Fallback: construct from auth_server
        base_url = str(request.auth_server).rstrip("/")
        if not base_url.endswith("/authorize"):
            base_url = f"{base_url}/authorize"
        logger.debug(f"Using fallback authorization endpoint: {base_url}")

    # Keep any existing query component instead of producing a second "?"
    if "?" not in base_url:
        separator = "?"
    elif base_url.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"

    auth_url = f"{base_url}{separator}{urlencode(params)}"
    logger.debug(f"Built authorization URL for resource={request.resource}")
    return auth_url
sk_agents.auth.oauth_client.OAuthClient.exchange_code_for_tokens async
exchange_code_for_tokens(
    token_request: TokenRequest,
) -> TokenResponse

Exchange authorization code for access token.

Makes a POST request to the token endpoint with: grant_type=authorization_code; code and redirect_uri; code_verifier (PKCE); resource (the canonical URI), only if the protocol version is >= 2025-06-18; and client_id, plus client_secret for confidential clients.

Parameters:

Name Type Description Default
token_request TokenRequest

Token request parameters

required

Returns:

Name Type Description
TokenResponse TokenResponse

Access token and metadata

Raises:

Type Description
HTTPError

If token request fails

ValueError

If response is invalid

Source code in src/sk_agents/auth/oauth_client.py
async def exchange_code_for_tokens(self, token_request: TokenRequest) -> TokenResponse:
    """
    Exchange authorization code for access token.

    Also serves the refresh_token grant (refresh_access_token delegates here).

    Makes POST request to token endpoint with:
    - grant_type (authorization_code or refresh_token)
    - code, redirect_uri, code_verifier (PKCE) for authorization_code
    - refresh_token for refresh_token
    - resource (canonical URI) - only when set on the request; callers decide
      based on the protocol version (MCP spec 2025-06-18)
    - client_id (+ client_secret if confidential)

    Args:
        token_request: Token request parameters

    Returns:
        TokenResponse: Access token and metadata

    Raises:
        httpx.HTTPError: If the token endpoint answers with a non-200 status
        ValueError: If required grant parameters are missing or the response is invalid
    """
    # Build request body
    body = {
        "grant_type": token_request.grant_type,
        "client_id": token_request.client_id,
    }

    # Conditionally include resource parameter per MCP spec 2025-06-18
    if token_request.resource:
        body["resource"] = token_request.resource

    # Add grant-specific parameters
    if token_request.grant_type == "authorization_code":
        if (
            not token_request.code
            or not token_request.redirect_uri
            or not token_request.code_verifier
        ):
            raise ValueError("Missing required parameters for authorization_code grant")
        body["code"] = token_request.code
        body["redirect_uri"] = str(token_request.redirect_uri)
        body["code_verifier"] = token_request.code_verifier
    elif token_request.grant_type == "refresh_token":
        if not token_request.refresh_token:
            raise ValueError("Missing refresh_token for refresh_token grant")
        body["refresh_token"] = token_request.refresh_token

    # Add client secret if provided (confidential client)
    if token_request.client_secret:
        body["client_secret"] = token_request.client_secret

    logger.debug(
        f"Exchanging code for tokens: endpoint={token_request.token_endpoint}, "
        f"grant_type={token_request.grant_type}"
    )

    # Make token request
    async with httpx.AsyncClient(timeout=self.timeout) as client:
        response = await client.post(
            str(token_request.token_endpoint),
            data=body,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if response.status_code != 200:
            # OAuth error bodies are JSON, but Content-Type frequently carries
            # parameters (e.g. "application/json; charset=utf-8"), so compare only
            # the media type. A malformed or non-object body must not mask the
            # HTTP failure, so it is treated as carrying no error details.
            error_data: dict = {}
            media_type = response.headers.get("content-type", "").split(";", 1)[0]
            if media_type.strip().lower() == "application/json":
                try:
                    parsed = response.json()
                except ValueError:
                    parsed = None
                if isinstance(parsed, dict):
                    error_data = parsed
            logger.error(
                f"Token request failed: status={response.status_code}, error={error_data}"
            )
            raise httpx.HTTPError(
                f"Token request failed: {error_data.get('error', 'unknown_error')}"
            )

        # Parse response
        token_data = response.json()
        token_response = TokenResponse(**token_data)

        # Validate scopes to prevent escalation attacks
        self.validate_token_scopes(token_request.requested_scopes, token_response)

        logger.info(f"Successfully obtained access token for resource={token_request.resource}")
        return token_response
sk_agents.auth.oauth_client.OAuthClient.refresh_access_token async
refresh_access_token(
    refresh_request: RefreshTokenRequest,
) -> TokenResponse

Refresh expired access token.

Makes a POST request to the token endpoint with:

- grant_type=refresh_token
- refresh_token
- resource (must match original)
- client_id

Implements token rotation per OAuth 2.1.

Parameters:

Name Type Description Default
refresh_request RefreshTokenRequest

Refresh token request parameters

required

Returns:

Name Type Description
TokenResponse TokenResponse

New access token (and possibly new refresh token)

Raises:

Type Description
HTTPError

If refresh fails

Source code in src/sk_agents/auth/oauth_client.py
async def refresh_access_token(self, refresh_request: RefreshTokenRequest) -> TokenResponse:
    """
    Obtain a fresh access token using a previously issued refresh token.

    Translates the refresh request into a ``TokenRequest`` with
    ``grant_type="refresh_token"`` and hands it to ``exchange_code_for_tokens``,
    which sends:
    - grant_type=refresh_token
    - refresh_token
    - resource (must match original)
    - client_id (+ client_secret when provided)

    Implements token rotation per OAuth 2.1.

    Args:
        refresh_request: Refresh token request parameters

    Returns:
        TokenResponse: New access token (and possibly new refresh token)

    Raises:
        httpx.HTTPError: If refresh fails
    """
    # Map every field of the refresh request onto the generic token request.
    request_fields = dict(
        token_endpoint=refresh_request.token_endpoint,
        grant_type="refresh_token",
        refresh_token=refresh_request.refresh_token,
        resource=refresh_request.resource,
        client_id=refresh_request.client_id,
        client_secret=refresh_request.client_secret,
        requested_scopes=refresh_request.requested_scopes,
    )
    generic_request = TokenRequest(**request_fields)

    logger.debug(f"Refreshing access token for resource={refresh_request.resource}")
    refreshed = await self.exchange_code_for_tokens(generic_request)
    return refreshed
sk_agents.auth.oauth_client.OAuthClient.revoke_token async
revoke_token(
    token: str,
    revocation_endpoint: str,
    client_id: str,
    client_secret: str | None = None,
    token_type_hint: str = "access_token",
) -> None

Revoke an access or refresh token per RFC 7009.

This allows clients to notify the authorization server that a token is no longer needed, enabling immediate invalidation.

Parameters:

Name Type Description Default
token str

The token to revoke (access or refresh token)

required
revocation_endpoint str

Token revocation endpoint URL

required
client_id str

OAuth client ID

required
client_secret str | None

OAuth client secret (for confidential clients)

None
token_type_hint str

Hint about token type ("access_token" or "refresh_token")

'access_token'

Raises:

Type Description
HTTPError

If revocation request fails

Source code in src/sk_agents/auth/oauth_client.py
async def revoke_token(
    self,
    token: str,
    revocation_endpoint: str,
    client_id: str,
    client_secret: str | None = None,
    token_type_hint: str = "access_token",
) -> None:
    """
    Ask the authorization server to invalidate a token (RFC 7009).

    Sends a form-encoded POST to the revocation endpoint so the server can
    drop the token immediately instead of waiting for it to expire.

    Args:
        token: The token to revoke (access or refresh token)
        revocation_endpoint: Token revocation endpoint URL
        client_id: OAuth client ID
        client_secret: OAuth client secret (for confidential clients)
        token_type_hint: Hint about token type ("access_token" or "refresh_token")

    Raises:
        httpx.HTTPError: If revocation request fails (logged, then re-raised)
    """
    # RFC 7009 request parameters; the secret is only sent when one was given.
    form: dict[str, str] = {"token": token, "token_type_hint": token_type_hint}
    form["client_id"] = client_id
    if client_secret:
        form["client_secret"] = client_secret

    logger.debug(f"Revoking token: endpoint={revocation_endpoint}, type_hint={token_type_hint}")

    try:
        async with httpx.AsyncClient(timeout=self.timeout) as http:
            resp = await http.post(
                revocation_endpoint,
                data=form,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

            # RFC 7009 servers answer 200 whether or not the token was valid,
            # which keeps the endpoint from being usable for token scanning.
            if resp.status_code != 200:
                logger.warning(
                    f"Token revocation returned unexpected status {resp.status_code}"
                )
                resp.raise_for_status()
            else:
                logger.info(f"Successfully revoked token (type_hint={token_type_hint})")
    except httpx.HTTPError as exc:
        logger.error(f"Failed to revoke token: {exc}")
        raise
sk_agents.auth.oauth_client.OAuthClient.initiate_authorization_flow async
initiate_authorization_flow(
    server_config: McpServerConfig, user_id: str
) -> str

Initiate OAuth authorization flow for MCP server.

Generates PKCE pair, state, stores flow state, and returns authorization URL.

Parameters:

Name Type Description Default
server_config McpServerConfig

MCP server configuration

required
user_id str

User ID initiating the flow

required

Returns:

Name Type Description
str str

Authorization URL for user redirect

Raises:

Type Description
ValueError

If server configuration is invalid

Source code in src/sk_agents/auth/oauth_client.py
async def initiate_authorization_flow(
    self,
    server_config: "McpServerConfig",
    user_id: str,
) -> str:
    """
    Initiate OAuth authorization flow for MCP server.

    Generates PKCE pair, state, stores flow state, and returns authorization URL.

    Steps, in order:
    1. For HTTP MCP servers (``server_config.url`` set), try to discover
       Protected Resource Metadata (RFC 9728); any failure is logged and ignored.
    2. Ask ``should_include_resource_param`` whether to send ``resource``; if so,
       resolve ``server_config.effective_canonical_uri`` (a ValueError there is
       logged and the flow continues without a resource parameter).
    3. Generate the PKCE verifier/challenge and a state value, then persist the
       flow state before any further network calls.
    4. Discover authorization server metadata (RFC 8414) for the authorization
       endpoint; on failure ``authorization_endpoint`` stays None and
       ``build_authorization_url`` falls back to ``<auth_server>/authorize``.
    5. If no ``oauth_client_id`` is configured and dynamic registration is
       enabled, attempt RFC 7591 registration against the discovered
       ``registration_endpoint``; on failure keep the configured client name.
    6. Build and return the authorization URL (PKCE method S256).

    Args:
        server_config: MCP server configuration
        user_id: User ID initiating the flow

    Returns:
        str: Authorization URL for user redirect

    Raises:
        ValueError: If server configuration is invalid (note: the ValueError from
            resolving the canonical URI is caught and logged, not propagated)
    """
    from ska_utils import AppConfig

    from sk_agents.configs import TA_OAUTH_CLIENT_NAME

    # Discover Protected Resource Metadata (RFC 9728) if HTTP MCP server.
    # Only whether PRM exists is used below (it feeds the resource-param decision).
    has_prm = False
    if server_config.url:  # Only for HTTP MCP servers
        try:
            prm = await self.metadata_cache.fetch_protected_resource_metadata(server_config.url)
            has_prm = prm is not None
            if prm:
                logger.info(
                    f"Discovered PRM for {server_config.name}: "
                    f"auth_servers={prm.authorization_servers}"
                )
        except Exception as e:
            # Best-effort: PRM discovery failures never abort the flow.
            logger.debug(f"PRM discovery failed (optional): {e}")
            has_prm = False

    # Determine if resource parameter should be included (per MCP spec 2025-06-18)
    include_resource = self.should_include_resource_param(
        protocol_version=server_config.protocol_version, has_prm=has_prm
    )

    # Get canonical resource URI if needed
    resource = None
    if include_resource:
        try:
            resource = server_config.effective_canonical_uri
        except ValueError as e:
            logger.warning(
                f"Cannot determine canonical URI for {server_config.name}: {e}. "
                "Proceeding without resource parameter."
            )
            resource = None

    # Generate PKCE pair (verifier is kept server-side; challenge goes in the URL)
    verifier, challenge = self.pkce_manager.generate_pkce_pair()

    # Generate state (CSRF protection; echoed back on the callback)
    state = self.state_manager.generate_state()

    # Store flow state (always store resource for validation, even if not sent in auth request).
    # When no resource was resolved this falls back to server_config.url, or "" if that is unset.
    self.state_manager.store_flow_state(
        state=state,
        verifier=verifier,
        user_id=user_id,
        server_name=server_config.name,
        resource=resource or server_config.url or "",  # Store for future reference
        scopes=server_config.scopes,
    )

    # Get client configuration
    app_config = AppConfig()
    client_name = app_config.get(TA_OAUTH_CLIENT_NAME.env_name)

    # Discover authorization server metadata (RFC 8414)
    authorization_endpoint = None
    metadata = None
    try:
        metadata = await self.metadata_cache.fetch_auth_server_metadata(
            server_config.auth_server
        )
        authorization_endpoint = str(metadata.authorization_endpoint)
        logger.info(f"Discovered authorization endpoint: {authorization_endpoint}")
    except Exception as e:
        # Any discovery error leaves authorization_endpoint as None, which makes
        # build_authorization_url construct the endpoint from auth_server.
        logger.warning(
            f"Failed to discover authorization server metadata: {e}. Using fallback."
        )
        authorization_endpoint = None

    # Try dynamic client registration if no client_id configured (RFC 7591)
    client_id = server_config.oauth_client_id or client_name

    if not server_config.oauth_client_id and server_config.enable_dynamic_registration:
        try:
            # Check if metadata includes registration_endpoint
            if metadata and metadata.registration_endpoint:
                logger.info(
                    f"No client_id configured for {server_config.name}. "
                    f"Attempting dynamic registration..."
                )

                from sk_agents.auth.client_registration import DynamicClientRegistration

                registration_client = DynamicClientRegistration(timeout=self.timeout)
                registration_response = await registration_client.register_client(
                    registration_endpoint=str(metadata.registration_endpoint),
                    redirect_uris=[str(server_config.oauth_redirect_uri)],
                    client_name=client_name,
                    scopes=server_config.scopes,
                )

                # Use registered credentials
                client_id = registration_response.client_id
                # Note: client_secret available in registration_response if needed
                # NOTE(review): the registered client_id is neither persisted nor stored in
                # the flow state. handle_callback builds its token request from
                # server_config.oauth_client_id or the configured client name, so the token
                # exchange may use a different client_id than this authorization request.
                # Confirm before relying on dynamic registration.

                logger.info(
                    f"Successfully registered client for {server_config.name}: "
                    f"client_id={client_id}"
                )

                # TODO: Optionally persist client_id/secret for reuse

            else:
                logger.warning(
                    f"Dynamic registration enabled but no registration_endpoint "
                    f"discovered for {server_config.name}"
                )
        except Exception as e:
            logger.warning(
                f"Dynamic client registration failed for {server_config.name}: {e}. "
                f"Falling back to default client_id."
            )
            # Continue with default client_name

    # Build authorization request
    auth_request = AuthorizationRequest(
        auth_server=server_config.auth_server,
        authorization_endpoint=authorization_endpoint,
        client_id=client_id,
        redirect_uri=server_config.oauth_redirect_uri,
        resource=resource,  # None when the resource param is excluded or the canonical URI failed
        scopes=server_config.scopes,
        state=state,
        code_challenge=challenge,
        code_challenge_method="S256",
    )

    # Build and return authorization URL
    auth_url = self.build_authorization_url(auth_request)
    # NOTE(review): the state value is logged at INFO level; confirm that is acceptable.
    logger.info(
        f"Initiated OAuth flow for {server_config.name}: "
        f"user={user_id}, resource={resource}, state={state}"
    )
    return auth_url
sk_agents.auth.oauth_client.OAuthClient.handle_callback async
handle_callback(
    code: str,
    state: str,
    user_id: str,
    server_config: McpServerConfig,
) -> OAuth2AuthData

Handle OAuth callback after user authorization.

Validates state, exchanges code for tokens, and stores in AuthStorage.

Parameters:

Name Type Description Default
code str

Authorization code from callback

required
state str

State parameter from callback

required
user_id str

User ID to validate against

required
server_config McpServerConfig

MCP server configuration

required

Returns:

Name Type Description
OAuth2AuthData OAuth2AuthData

Stored token data

Raises:

Type Description
ValueError

If state invalid or user mismatch

HTTPError

If token exchange fails

Source code in src/sk_agents/auth/oauth_client.py
async def handle_callback(
    self,
    code: str,
    state: str,
    user_id: str,
    server_config: "McpServerConfig",
) -> OAuth2AuthData:
    """
    Handle OAuth callback after user authorization.

    Validates state, exchanges code for tokens, and stores in AuthStorage.

    The token endpoint is taken from the authorization server's RFC 8414
    metadata (``token_endpoint``) when discovery succeeds. Otherwise it falls
    back to ``<auth_server>/token``.

    Args:
        code: Authorization code from callback
        state: State parameter from callback
        user_id: User ID to validate against
        server_config: MCP server configuration

    Returns:
        OAuth2AuthData: Stored token data

    Raises:
        ValueError: If state invalid or user mismatch
        httpx.HTTPError: If token exchange fails
    """
    from datetime import datetime, timedelta

    from ska_utils import AppConfig

    from sk_agents.configs import TA_OAUTH_CLIENT_NAME
    from sk_agents.mcp_client import build_auth_storage_key

    # Retrieve and validate flow state
    flow_state = self.state_manager.retrieve_flow_state(state, user_id)

    # Resolve the token endpoint. Prefer the discovered RFC 8414 metadata, the same
    # discovery that locates the authorization endpoint. Fall back to the
    # conventional "<auth_server>/token" when discovery fails or omits it.
    token_endpoint = f"{server_config.auth_server.rstrip('/')}/token"
    try:
        metadata = await self.metadata_cache.fetch_auth_server_metadata(
            server_config.auth_server
        )
        discovered_endpoint = getattr(metadata, "token_endpoint", None)
        if discovered_endpoint:
            token_endpoint = str(discovered_endpoint)
            logger.info(f"Discovered token endpoint: {token_endpoint}")
    except Exception as e:
        logger.warning(
            f"Failed to discover authorization server metadata: {e}. "
            f"Using fallback token endpoint: {token_endpoint}"
        )

    # Get client configuration
    app_config = AppConfig()
    client_name = app_config.get(TA_OAUTH_CLIENT_NAME.env_name)

    # Discover Protected Resource Metadata (RFC 9728) if HTTP MCP server
    has_prm = False
    if server_config.url:  # Only for HTTP MCP servers
        try:
            prm = await self.metadata_cache.fetch_protected_resource_metadata(server_config.url)
            has_prm = prm is not None
            if prm:
                logger.info(
                    f"Discovered PRM for {server_config.name}: "
                    f"auth_servers={prm.authorization_servers}"
                )
        except Exception as e:
            logger.debug(f"PRM discovery failed (optional): {e}")
            has_prm = False

    # Determine if resource parameter should be included (per MCP spec 2025-06-18)
    include_resource = self.should_include_resource_param(
        protocol_version=server_config.protocol_version, has_prm=has_prm
    )

    # Build token request
    token_request = TokenRequest(
        token_endpoint=token_endpoint,
        grant_type="authorization_code",
        code=code,
        redirect_uri=server_config.oauth_redirect_uri,
        code_verifier=flow_state.verifier,
        resource=flow_state.resource
        if include_resource
        else None,  # Conditional per protocol version
        client_id=server_config.oauth_client_id or client_name,
        client_secret=server_config.oauth_client_secret,
        requested_scopes=flow_state.scopes,  # For scope validation
    )

    # Exchange code for tokens
    token_response = await self.exchange_code_for_tokens(token_request)

    # Use a single timestamp so issued_at and expires_at are consistent.
    # NOTE(review): assumes the server returned expires_in; it is only RECOMMENDED
    # by RFC 6749 — confirm TokenResponse guarantees a value.
    now = datetime.now(UTC)

    # Create OAuth2AuthData
    oauth_data = OAuth2AuthData(
        access_token=token_response.access_token,
        refresh_token=token_response.refresh_token,
        expires_at=now + timedelta(seconds=token_response.expires_in),
        scopes=token_response.scope.split() if token_response.scope else flow_state.scopes,
        audience=token_response.aud,
        resource=flow_state.resource,
        token_type=token_response.token_type,
        issued_at=now,
    )

    # Store in AuthStorage
    composite_key = build_auth_storage_key(server_config.auth_server, oauth_data.scopes)
    self.auth_storage.store(user_id, composite_key, oauth_data)

    logger.info(
        f"OAuth callback successful for {flow_state.server_name}: "
        f"user={user_id}, resource={flow_state.resource}"
    )

    # Clean up flow state
    self.state_manager.delete_flow_state(state, user_id)

    return oauth_data
sk_agents.auth.oauth_error_handler

OAuth Error Handler

Handles OAuth error responses and WWW-Authenticate header parsing per:

- RFC 6750: Bearer Token Usage
- RFC 9728: Protected Resource Metadata
- MCP Specification 2025-06-18

Key functionality:

- Parse WWW-Authenticate headers from 401 responses
- Extract error codes: invalid_token, insufficient_scope, etc.
- Extract scope requirements for re-authorization
- Extract the resource_metadata URL for RFC 9728 discovery

sk_agents.auth.oauth_error_handler.WWWAuthenticateChallenge

Parsed WWW-Authenticate challenge from 401 response.

Per RFC 6750 Section 3:

    WWW-Authenticate: Bearer realm="example",
                      error="invalid_token",
                      error_description="The access token expired",
                      scope="read write",
                      resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"

Source code in src/sk_agents/auth/oauth_error_handler.py
class WWWAuthenticateChallenge:
    """
    Bearer challenge fields taken from a 401 response's WWW-Authenticate header.

    Header shape per RFC 6750 Section 3:
    WWW-Authenticate: Bearer realm="example",
                      error="invalid_token",
                      error_description="The access token expired",
                      scope="read write",
                      resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"
    """

    def __init__(
        self,
        realm: str | None = None,
        error: str | None = None,
        error_description: str | None = None,
        error_uri: str | None = None,
        scope: str | None = None,
        resource_metadata: str | None = None,
    ):
        # Every auth-param is kept verbatim; None means it was absent.
        self.realm, self.error = realm, error
        self.error_description, self.error_uri = error_description, error_uri
        self.scope = scope  # raw, space-separated scope string
        self.resource_metadata = resource_metadata

    @property
    def scopes(self) -> list[str]:
        """Scope string split on whitespace; empty list when no scope was sent."""
        if not self.scope:
            return []
        return self.scope.split()

    def requires_reauth(self) -> bool:
        """True when the error is ``invalid_token`` or ``insufficient_scope``."""
        return self.error == "invalid_token" or self.error == "insufficient_scope"

    def is_token_expired(self) -> bool:
        """True when the error is ``invalid_token`` (expired, revoked or malformed token)."""
        return "invalid_token" == self.error

    def is_insufficient_scope(self) -> bool:
        """True when the error is ``insufficient_scope``."""
        return "insufficient_scope" == self.error

    def __repr__(self) -> str:
        summary = f"error={self.error}, scope={self.scope}, realm={self.realm}"
        return f"WWWAuthenticateChallenge({summary})"
sk_agents.auth.oauth_error_handler.WWWAuthenticateChallenge.scopes property
scopes: list[str]

Get scopes as list.

sk_agents.auth.oauth_error_handler.WWWAuthenticateChallenge.requires_reauth
requires_reauth() -> bool

Check if error requires re-authorization.

Source code in src/sk_agents/auth/oauth_error_handler.py
def requires_reauth(self) -> bool:
    """Tell whether the challenge error demands a fresh authorization.

    Both ``invalid_token`` and ``insufficient_scope`` qualify; any other error
    code, or no error at all, does not.
    """
    return any(self.error == code for code in ("invalid_token", "insufficient_scope"))
sk_agents.auth.oauth_error_handler.WWWAuthenticateChallenge.is_token_expired
is_token_expired() -> bool

Check if error indicates token expiry.

Source code in src/sk_agents/auth/oauth_error_handler.py
def is_token_expired(self) -> bool:
    """Tell whether the challenge reports ``invalid_token``.

    RFC 6750 uses ``invalid_token`` for expired, revoked, malformed or otherwise
    invalid tokens; all of them are treated as the "expired" case here.
    """
    expired_code = "invalid_token"
    return expired_code == self.error
sk_agents.auth.oauth_error_handler.WWWAuthenticateChallenge.is_insufficient_scope
is_insufficient_scope() -> bool

Check if error indicates insufficient scopes.

Source code in src/sk_agents/auth/oauth_error_handler.py
def is_insufficient_scope(self) -> bool:
    """Tell whether the challenge reports ``insufficient_scope`` (token lacks scopes)."""
    scope_error = "insufficient_scope"
    return scope_error == self.error
sk_agents.auth.oauth_error_handler.OAuthErrorHandler

Handler for OAuth error responses.

Provides structured error handling for:

- 401 Unauthorized (invalid_token, insufficient_scope)
- 403 Forbidden (insufficient permissions)
- 400 Bad Request (malformed request)

Source code in src/sk_agents/auth/oauth_error_handler.py
class OAuthErrorHandler:
    """
    Handler for OAuth error responses.

    Provides structured error handling for:
    - 401 Unauthorized (invalid_token, insufficient_scope)
    - 403 Forbidden (insufficient permissions)
    - 400 Bad Request (malformed request)
    """

    @staticmethod
    def handle_401_response(response_headers: dict[str, str]) -> WWWAuthenticateChallenge | None:
        """
        Handle 401 Unauthorized response.

        Extracts WWW-Authenticate challenge for further processing. The header
        name is looked up case-insensitively, because HTTP field names are
        case-insensitive and plain dicts may use any casing
        (e.g. "Www-Authenticate").

        Args:
            response_headers: HTTP response headers

        Returns:
            Parsed WWW-Authenticate challenge, or None if header missing
        """
        www_auth = response_headers.get("WWW-Authenticate") or response_headers.get(
            "www-authenticate"
        )
        if not www_auth:
            # Fall back to a case-insensitive scan for other casings.
            www_auth = next(
                (
                    value
                    for name, value in response_headers.items()
                    if name.lower() == "www-authenticate" and value
                ),
                None,
            )
        if not www_auth:
            logger.warning("401 response missing WWW-Authenticate header")
            return None

        return parse_www_authenticate_header(www_auth)

    @staticmethod
    def should_refresh_token(challenge: WWWAuthenticateChallenge | None) -> bool:
        """
        Determine if token should be refreshed based on error.

        Args:
            challenge: Parsed WWW-Authenticate challenge

        Returns:
            True if token refresh should be attempted
        """
        if not challenge:
            return False

        # Refresh on invalid_token error
        return challenge.is_token_expired()

    @staticmethod
    def should_reauthorize(challenge: WWWAuthenticateChallenge | None) -> bool:
        """
        Determine if re-authorization is required.

        Any error other than ``invalid_token`` calls for re-authorization;
        ``invalid_token`` is handled by refresh. No challenge, or a challenge
        without an error code, yields False.

        Args:
            challenge: Parsed WWW-Authenticate challenge

        Returns:
            True if re-authorization flow should be initiated
        """
        # Return a real bool: the previous expression leaked None/"" when the
        # challenge carried no error code.
        if not challenge or not challenge.error:
            return False

        # Re-authorize on insufficient_scope or other auth errors
        return challenge.is_insufficient_scope() or challenge.error != "invalid_token"

    @staticmethod
    def get_required_scopes(challenge: WWWAuthenticateChallenge | None) -> list[str]:
        """
        Extract required scopes from challenge.

        For insufficient_scope errors, this returns the scopes needed.

        Args:
            challenge: Parsed WWW-Authenticate challenge

        Returns:
            List of required scopes, empty if none specified
        """
        if not challenge:
            return []

        return challenge.scopes
sk_agents.auth.oauth_error_handler.OAuthErrorHandler.handle_401_response staticmethod
handle_401_response(
    response_headers: dict[str, str],
) -> WWWAuthenticateChallenge | None

Handle 401 Unauthorized response.

Extracts WWW-Authenticate challenge for further processing.

Parameters:

Name Type Description Default
response_headers dict[str, str]

HTTP response headers

required

Returns:

Type Description
WWWAuthenticateChallenge | None

Parsed WWW-Authenticate challenge, or None if header missing

Source code in src/sk_agents/auth/oauth_error_handler.py
@staticmethod
def handle_401_response(response_headers: dict[str, str]) -> WWWAuthenticateChallenge | None:
    """
    Handle 401 Unauthorized response.

    Extracts WWW-Authenticate challenge for further processing. The header
    name is looked up case-insensitively, because HTTP field names are
    case-insensitive and plain dicts may use any casing (e.g. "Www-Authenticate").

    Args:
        response_headers: HTTP response headers

    Returns:
        Parsed WWW-Authenticate challenge, or None if header missing
    """
    www_auth = response_headers.get("WWW-Authenticate") or response_headers.get(
        "www-authenticate"
    )
    if not www_auth:
        # Fall back to a case-insensitive scan for other casings.
        www_auth = next(
            (
                value
                for name, value in response_headers.items()
                if name.lower() == "www-authenticate" and value
            ),
            None,
        )
    if not www_auth:
        logger.warning("401 response missing WWW-Authenticate header")
        return None

    return parse_www_authenticate_header(www_auth)
sk_agents.auth.oauth_error_handler.OAuthErrorHandler.should_refresh_token staticmethod
should_refresh_token(
    challenge: WWWAuthenticateChallenge | None,
) -> bool

Determine if token should be refreshed based on error.

Parameters:

Name Type Description Default
challenge WWWAuthenticateChallenge | None

Parsed WWW-Authenticate challenge

required

Returns:

Type Description
bool

True if token refresh should be attempted

Source code in src/sk_agents/auth/oauth_error_handler.py
@staticmethod
def should_refresh_token(challenge: WWWAuthenticateChallenge | None) -> bool:
    """
    Determine if token should be refreshed based on error.

    Args:
        challenge: Parsed WWW-Authenticate challenge

    Returns:
        True if token refresh should be attempted
    """
    if not challenge:
        return False

    # Refresh on invalid_token error
    return challenge.is_token_expired()
sk_agents.auth.oauth_error_handler.OAuthErrorHandler.should_reauthorize staticmethod
should_reauthorize(
    challenge: WWWAuthenticateChallenge | None,
) -> bool

Determine if re-authorization is required.

Parameters:

Name Type Description Default
challenge WWWAuthenticateChallenge | None

Parsed WWW-Authenticate challenge

required

Returns:

Type Description
bool

True if re-authorization flow should be initiated

Source code in src/sk_agents/auth/oauth_error_handler.py
@staticmethod
def should_reauthorize(challenge: WWWAuthenticateChallenge | None) -> bool:
    """
    Determine if re-authorization is required.

    Args:
        challenge: Parsed WWW-Authenticate challenge

    Returns:
        True if re-authorization flow should be initiated
    """
    if not challenge:
        return False

    # Re-authorize on insufficient_scope or other auth errors
    return challenge.is_insufficient_scope() or (
        challenge.error and challenge.error not in ("invalid_token",)
    )
sk_agents.auth.oauth_error_handler.OAuthErrorHandler.get_required_scopes staticmethod
get_required_scopes(
    challenge: WWWAuthenticateChallenge | None,
) -> list[str]

Extract required scopes from challenge.

For insufficient_scope errors, this returns the scopes needed.

Parameters:

Name Type Description Default
challenge WWWAuthenticateChallenge | None

Parsed WWW-Authenticate challenge

required

Returns:

Type Description
list[str]

List of required scopes, empty if none specified

Source code in src/sk_agents/auth/oauth_error_handler.py
@staticmethod
def get_required_scopes(challenge: WWWAuthenticateChallenge | None) -> list[str]:
    """
    Extract required scopes from challenge.

    For insufficient_scope errors, this returns the scopes needed.

    Args:
        challenge: Parsed WWW-Authenticate challenge

    Returns:
        List of required scopes, empty if none specified
    """
    if not challenge:
        return []

    return challenge.scopes
sk_agents.auth.oauth_error_handler.parse_www_authenticate_header
parse_www_authenticate_header(
    header_value: str,
) -> WWWAuthenticateChallenge | None

Parse WWW-Authenticate header per RFC 6750 + RFC 9728.

Format

WWW-Authenticate: Bearer realm="example", error="invalid_token", error_description="The access token expired", scope="read write", resource_metadata="https://..."

Parameters:

Name Type Description Default
header_value str

Value of WWW-Authenticate header

required

Returns:

Name Type Description
WWWAuthenticateChallenge WWWAuthenticateChallenge | None

Parsed challenge, or None if not a Bearer challenge

Raises:

Type Description
ValueError

If header is malformed

Source code in src/sk_agents/auth/oauth_error_handler.py
def parse_www_authenticate_header(header_value: str) -> WWWAuthenticateChallenge | None:
    """
    Parse WWW-Authenticate header per RFC 6750 + RFC 9728.

    Format:
        WWW-Authenticate: Bearer realm="example",
                          error="invalid_token",
                          error_description="The access token expired",
                          scope="read write",
                          resource_metadata="https://..."

    The auth-scheme and parameter names are matched case-insensitively, and
    optional whitespace around "=" is accepted (RFC 9110 Section 11).
    Parameter values are kept verbatim, and quoted values may not contain a
    double quote. All name=value pairs after the scheme are collected, so
    pairs from any later challenge in the same header value are included;
    for repeated names the last value wins.

    Args:
        header_value: Value of WWW-Authenticate header

    Returns:
        WWWAuthenticateChallenge: Parsed challenge, or None if the value is
        empty or not a Bearer challenge

    Raises:
        Nothing for malformed input: text that does not match name=value is ignored.
    """
    if not header_value:
        return None

    # Split off the auth-scheme token from the trimmed value. The old approach
    # sliced a fixed offset from the raw value, which broke on leading
    # whitespace, and its prefix check also accepted schemes like "BearerX".
    parts = header_value.strip().split(None, 1)
    if not parts or parts[0].lower() != "bearer":
        logger.debug(f"WWW-Authenticate header is not Bearer type: {header_value}")
        return None
    params_str = parts[1] if len(parts) > 1 else ""

    # Parse parameters using regex
    # Matches: param="value" or param=value (unquoted), with optional spaces around "="
    pattern = r'(\w+)\s*=\s*(?:"([^"]*)"|([^\s,]+))'

    params: dict[str, str] = {}
    for name, quoted, unquoted in re.findall(pattern, params_str):
        # Use quoted value if present, otherwise unquoted; names are case-insensitive
        params[name.lower()] = quoted if quoted else unquoted

    # Build challenge object
    challenge = WWWAuthenticateChallenge(
        realm=params.get("realm"),
        error=params.get("error"),
        error_description=params.get("error_description"),
        error_uri=params.get("error_uri"),
        scope=params.get("scope"),
        resource_metadata=params.get("resource_metadata"),
    )

    logger.debug(f"Parsed WWW-Authenticate challenge: {challenge}")
    return challenge
sk_agents.auth.oauth_error_handler.extract_field_from_www_authenticate
extract_field_from_www_authenticate(
    header_value: str, field_name: str
) -> str | None

Extract a specific field from WWW-Authenticate header.

Convenience function for extracting single fields.

Parameters:

Name Type Description Default
header_value str

Value of WWW-Authenticate header

required
field_name str

Field to extract (e.g., "error", "scope", "resource_metadata")

required

Returns:

Type Description
str | None

Field value, or None if not present

Source code in src/sk_agents/auth/oauth_error_handler.py
def extract_field_from_www_authenticate(header_value: str, field_name: str) -> str | None:
    """
    Extract a specific field from WWW-Authenticate header.

    Convenience function for extracting single fields.

    Only string-valued attributes of the parsed challenge are returned. An
    unknown name, or a name that happens to match a non-field attribute of the
    challenge object (a method, a dunder, ...), yields None. This keeps the
    ``str | None`` return contract instead of leaking arbitrary objects.

    Args:
        header_value: Value of WWW-Authenticate header
        field_name: Field to extract (e.g., "error", "scope", "resource_metadata")

    Returns:
        Field value, or None if not present (or the header is not a Bearer challenge)

    Raises:
        Whatever parse_www_authenticate_header raises; its docstring documents
        ValueError for a malformed header.
    """
    challenge = parse_www_authenticate_header(header_value)
    if not challenge:
        return None

    value = getattr(challenge, field_name, None)
    # Guard the declared return type: absent fields are None, and anything
    # that is not a parsed string value is treated as "not present".
    return value if isinstance(value, str) else None
sk_agents.auth.oauth_error_handler.build_www_authenticate_header
build_www_authenticate_header(
    error: str,
    error_description: str | None = None,
    scope: str | None = None,
    realm: str | None = None,
    resource_metadata: str | None = None,
) -> str

Build WWW-Authenticate header per RFC 6750 + RFC 9728.

For use when implementing MCP servers that need to challenge clients.

Parameters:

Name Type Description Default
error str

OAuth error code (e.g., "invalid_token", "insufficient_scope")

required
error_description str | None

Human-readable error description

None
scope str | None

Required scope(s) (space-separated)

None
realm str | None

Protection realm

None
resource_metadata str | None

URL for Protected Resource Metadata (RFC 9728)

None

Returns:

Name Type Description
str str

Formatted WWW-Authenticate header value

Example

Example (not executed as a doctest): calling build_www_authenticate_header(error="insufficient_scope", error_description="Token lacks required scopes", scope="read write", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource") returns a header value that starts with 'Bearer error="insufficient_scope", ...'.

Source code in src/sk_agents/auth/oauth_error_handler.py
def build_www_authenticate_header(
    error: str,
    error_description: str | None = None,
    scope: str | None = None,
    realm: str | None = None,
    resource_metadata: str | None = None,
) -> str:
    """
    Build WWW-Authenticate header per RFC 6750 + RFC 9728.

    For use when implementing MCP servers that need to challenge clients.

    Args:
        error: OAuth error code (e.g., "invalid_token", "insufficient_scope")
        error_description: Human-readable error description
        scope: Required scope(s) (space-separated)
        realm: Protection realm
        resource_metadata: URL for Protected Resource Metadata (RFC 9728)

    Returns:
        str: Formatted WWW-Authenticate header value

    Example:
        >>> build_www_authenticate_header(
        ...     error="insufficient_scope",
        ...     error_description="Token lacks required scopes",
        ...     scope="read write",
        ...     resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"
        ... )  # doctest: +SKIP
        'Bearer error="insufficient_scope", ...'
    """
    parts = ["Bearer"]

    if realm:
        parts.append(f'realm="{realm}"')

    parts.append(f'error="{error}"')

    if error_description:
        parts.append(f'error_description="{error_description}"')

    if scope:
        parts.append(f'scope="{scope}"')

    if resource_metadata:
        parts.append(f'resource_metadata="{resource_metadata}"')

    return ", ".join(parts)
sk_agents.auth.oauth_models

OAuth 2.1 Request and Response Models

Models for OAuth authorization flows following MCP specification.

sk_agents.auth.oauth_models.AuthorizationRequest

Bases: BaseModel

OAuth 2.1 Authorization Request

Used to construct authorization URL with all required parameters. Follows MCP spec requirement for PKCE and resource parameter.

Note: the resource parameter is optional. Include it only if the MCP protocol version is >= 2025-06-18, or if Protected Resource Metadata has been discovered.

Source code in src/sk_agents/auth/oauth_models.py
class AuthorizationRequest(BaseModel):
    """
    OAuth 2.1 Authorization Request

    Used to construct authorization URL with all required parameters.
    Follows MCP spec requirement for PKCE and resource parameter.

    Note: resource parameter is optional and should only be included if:
    - MCP protocol version >= 2025-06-18, OR
    - Protected Resource Metadata has been discovered
    """

    # Authorization server location. The base URL is required. The RFC 8414
    # discovered endpoint is optional; what the URL builder uses when it is
    # None is not shown here (presumably a path derived from auth_server --
    # TODO confirm in the caller).
    auth_server: HttpUrl = Field(..., description="Authorization server base URL")
    authorization_endpoint: HttpUrl | None = Field(
        None, description="Discovered authorization endpoint (RFC 8414)"
    )
    # Client identity and the callback the authorization server redirects to.
    client_id: str = Field(..., description="OAuth client ID")
    redirect_uri: HttpUrl = Field(..., description="OAuth callback URL")
    # Optional resource binding (see the class docstring for when to set it).
    # NOTE(review): presumably omitted from the authorization URL when None --
    # confirm in the code that serializes this model.
    resource: str | None = Field(
        None,
        description="Canonical MCP server URI (resource binding) per protocol version",
    )
    # Required list; no length constraint is declared, so an empty list validates.
    scopes: list[str] = Field(..., description="Requested OAuth scopes")
    # Per-flow CSRF state and PKCE challenge; both must be supplied by the caller.
    state: str = Field(..., description="CSRF protection state parameter")
    code_challenge: str = Field(..., description="PKCE code challenge (S256)")
    # The Literal types make pydantic reject anything other than S256 PKCE
    # and the authorization-code response type.
    code_challenge_method: Literal["S256"] = Field(
        default="S256", description="PKCE challenge method (must be S256)"
    )
    response_type: Literal["code"] = Field(
        default="code", description="OAuth response type (authorization code flow)"
    )
sk_agents.auth.oauth_models.TokenRequest

Bases: BaseModel

OAuth 2.1 Token Request

Used to exchange authorization code for access token. Includes PKCE verifier and resource parameter.

Note: the resource parameter is optional. Include it only if the MCP protocol version is >= 2025-06-18, or if Protected Resource Metadata has been discovered.

Source code in src/sk_agents/auth/oauth_models.py
class TokenRequest(BaseModel):
    """
    OAuth 2.1 Token Request

    Used to exchange authorization code for access token.
    Includes PKCE verifier and resource parameter.

    Note: resource parameter is optional and should only be included if:
    - MCP protocol version >= 2025-06-18, OR
    - Protected Resource Metadata has been discovered
    """

    token_endpoint: HttpUrl = Field(..., description="Token endpoint URL")
    # grant_type determines which optional fields below are meaningful. Per
    # the field descriptions, `code` belongs to the authorization_code grant
    # and `refresh_token` to the refresh_token grant. No validator in this
    # model enforces that pairing, so any combination validates.
    grant_type: Literal["authorization_code", "refresh_token"] = Field(
        ..., description="OAuth grant type"
    )
    code: str | None = Field(None, description="Authorization code (for authorization_code grant)")
    refresh_token: str | None = Field(None, description="Refresh token (for refresh_token grant)")
    redirect_uri: HttpUrl | None = Field(None, description="OAuth callback URL (must match)")
    code_verifier: str | None = Field(None, description="PKCE code verifier")
    resource: str | None = Field(
        None,
        description="Canonical MCP server URI (resource binding) per protocol version",
    )
    # Client credentials; client_secret stays None for public clients.
    client_id: str = Field(..., description="OAuth client ID")
    client_secret: str | None = Field(
        None, description="OAuth client secret (confidential clients only)"
    )
    # Scopes originally requested. Per the description, this is kept so that
    # granted scopes can be checked against it; the check itself lives elsewhere.
    requested_scopes: list[str] | None = Field(
        None, description="Requested scopes for validation (prevents escalation attacks)"
    )
sk_agents.auth.oauth_models.TokenResponse

Bases: BaseModel

OAuth 2.1 Token Response

Token endpoint response with access token and metadata.

Source code in src/sk_agents/auth/oauth_models.py
class TokenResponse(BaseModel):
    """
    OAuth 2.1 Token Response

    Token endpoint response with access token and metadata.
    """

    access_token: str = Field(..., description="OAuth access token")
    # Free-form str, not a Literal, so casing variants such as "bearer" validate.
    token_type: str = Field(..., description="Token type (usually 'Bearer')")
    # NOTE(review): expires_in is required here (Field(...)), but RFC 6749
    # §5.1 lists it as RECOMMENDED. A server that omits it will fail
    # validation -- confirm this strictness is intended.
    expires_in: int = Field(..., description="Token lifetime in seconds")
    refresh_token: str | None = Field(None, description="Refresh token (optional)")
    scope: str | None = Field(None, description="Granted scopes (space-separated)")
    aud: str | None = Field(None, description="Token audience (for validation)")
sk_agents.auth.oauth_models.RefreshTokenRequest

Bases: BaseModel

OAuth 2.1 Refresh Token Request

Request to refresh an expired access token.

Note: the resource parameter is optional. Include it only if the MCP protocol version is >= 2025-06-18, or if Protected Resource Metadata has been discovered. When included, it must match the resource sent in the original authorization request.

Source code in src/sk_agents/auth/oauth_models.py
class RefreshTokenRequest(BaseModel):
    """
    OAuth 2.1 Refresh Token Request

    Request to refresh an expired access token.

    Note: resource parameter is optional and should only be included if:
    - MCP protocol version >= 2025-06-18, OR
    - Protected Resource Metadata has been discovered
    - Must match the original authorization request resource if included
    """

    token_endpoint: HttpUrl = Field(..., description="Token endpoint URL")
    # Unlike TokenRequest, the refresh token is mandatory here.
    refresh_token: str = Field(..., description="Refresh token")
    resource: str | None = Field(
        None,
        description="Canonical MCP server URI (must match original) per protocol version",
    )
    # Client credentials; client_secret stays None for public clients.
    client_id: str = Field(..., description="OAuth client ID")
    client_secret: str | None = Field(
        None, description="OAuth client secret (confidential clients only)"
    )
    # Pinned by the Literal type: this model can only express a refresh grant.
    grant_type: Literal["refresh_token"] = Field(
        default="refresh_token", description="OAuth grant type"
    )
    requested_scopes: list[str] | None = Field(
        None, description="Original requested scopes for validation (prevents escalation)"
    )
sk_agents.auth.oauth_models.OAuthError

Bases: BaseModel

OAuth Error Response

Parsed from WWW-Authenticate header or token endpoint error response.

Source code in src/sk_agents/auth/oauth_models.py
class OAuthError(BaseModel):
    """
    OAuth Error Response

    Parsed from WWW-Authenticate header or token endpoint error response.
    """

    # Only the error code is mandatory; everything else is optional detail.
    error: str = Field(..., description="Error code (e.g., 'invalid_token', 'insufficient_scope')")
    error_description: str | None = Field(None, description="Human-readable error description")
    error_uri: str | None = Field(None, description="URL with error information")
    # Plain str rather than HttpUrl, so the URL is not format-validated here.
    oauth_server_metadata_url: str | None = Field(
        None, description="Authorization server metadata URL (from WWW-Authenticate)"
    )
sk_agents.auth.oauth_models.MCP401Response

Bases: BaseModel

MCP-compliant 401 Unauthorized response.

Per the MCP spec, servers should return a WWW-Authenticate header containing:
- `error`: the error code.
- `error_description`: a human-readable description.
- `scope`: the required scopes (for `insufficient_scope`).
- `resource_metadata` (optional): the URL for RFC 9728 discovery.

Source code in src/sk_agents/auth/oauth_models.py
class MCP401Response(BaseModel):
    """
    MCP-compliant 401 Unauthorized response.

    Per MCP spec, servers should return WWW-Authenticate header with:
    - error: Error code
    - error_description: Human-readable description
    - scope: Required scopes (for insufficient_scope)
    - resource_metadata: URL for RFC 9728 discovery (optional)
    """

    # Pre-rendered header value (not parsed or validated by this model).
    www_authenticate: str = Field(..., description="WWW-Authenticate header value")
    # Defaults to 401, but any int is accepted -- there is no Literal constraint.
    error_code: int = Field(401, description="HTTP status code")
    error_message: str = Field(
        "Authentication required", description="Human-readable error message"
    )
sk_agents.auth.oauth_models.MCP403Response

Bases: BaseModel

MCP-compliant 403 Forbidden response.

Source code in src/sk_agents/auth/oauth_models.py
class MCP403Response(BaseModel):
    """MCP-compliant 403 Forbidden response."""

    # Defaults to 403, but any int is accepted -- there is no Literal constraint.
    error_code: int = Field(403, description="HTTP status code")
    error_message: str = Field(
        "Insufficient permissions", description="Human-readable error message"
    )
    # None when the caller does not report which scopes were missing.
    required_scopes: list[str] | None = Field(
        None, description="Scopes required for this operation"
    )
sk_agents.auth.oauth_pkce

PKCE (Proof Key for Code Exchange) Implementation

Implements PKCE as required by OAuth 2.1 and MCP specification. PKCE prevents authorization code interception attacks.

References:
- OAuth 2.1, Section 7.5.2
- RFC 7636 (Proof Key for Code Exchange)

sk_agents.auth.oauth_pkce.PKCEManager

Manager for PKCE generation and validation.

Provides high-level interface for PKCE operations in OAuth flows.

Source code in src/sk_agents/auth/oauth_pkce.py
class PKCEManager:
    """
    Manager for PKCE generation and validation.

    Provides high-level interface for PKCE operations in OAuth flows. All
    methods are static thin wrappers over the module-level helpers
    ``generate_code_verifier``, ``generate_code_challenge`` and
    ``validate_code_verifier``.
    """

    @staticmethod
    def generate_pkce_pair() -> tuple[str, str]:
        """
        Generate PKCE verifier and challenge pair.

        Returns:
            tuple: (verifier, challenge), where challenge is the S256
            transform of verifier.
        """
        verifier = generate_code_verifier()
        challenge = generate_code_challenge(verifier)
        return verifier, challenge

    @staticmethod
    def validate_verifier(verifier: str) -> bool:
        """
        Validate code verifier meets requirements.

        Args:
            verifier: Code verifier to validate

        Returns:
            bool: True if valid
        """
        return validate_code_verifier(verifier)

    @staticmethod
    def verify_challenge(verifier: str, challenge: str) -> bool:
        """
        Verify that challenge matches verifier.

        Used by authorization server (not typically by client).

        The comparison is constant-time (``hmac.compare_digest``). Response
        time therefore does not reveal how many leading characters of a
        supplied challenge were correct.

        Args:
            verifier: Code verifier
            challenge: Code challenge to verify

        Returns:
            bool: True if challenge matches verifier. A non-str challenge
            never matches.
        """
        import hmac

        expected_challenge = generate_code_challenge(verifier)
        if not isinstance(challenge, str):
            # The previous ``==`` check simply returned False here; keep that.
            return False
        # Compare UTF-8 bytes: compare_digest raises TypeError for non-ASCII
        # str input, and the challenge may come from an untrusted request.
        return hmac.compare_digest(
            expected_challenge.encode("utf-8"), challenge.encode("utf-8")
        )
sk_agents.auth.oauth_pkce.PKCEManager.generate_pkce_pair staticmethod
generate_pkce_pair() -> tuple[str, str]

Generate PKCE verifier and challenge pair.

Returns:

Name Type Description
tuple tuple[str, str]

(verifier, challenge)

Source code in src/sk_agents/auth/oauth_pkce.py
@staticmethod
def generate_pkce_pair() -> tuple[str, str]:
    """
    Create a fresh PKCE verifier together with its matching challenge.

    Returns:
        tuple: ``(verifier, challenge)``; the challenge is derived from the
        verifier with ``generate_code_challenge``.
    """
    code_verifier = generate_code_verifier()
    return code_verifier, generate_code_challenge(code_verifier)
sk_agents.auth.oauth_pkce.PKCEManager.validate_verifier staticmethod
validate_verifier(verifier: str) -> bool

Validate code verifier meets requirements.

Parameters:

Name Type Description Default
verifier str

Code verifier to validate

required

Returns:

Name Type Description
bool bool

True if valid

Source code in src/sk_agents/auth/oauth_pkce.py
@staticmethod
def validate_verifier(verifier: str) -> bool:
    """
    Report whether ``verifier`` is an acceptable PKCE code verifier.

    Delegates to the module-level ``validate_code_verifier``.

    Args:
        verifier: Candidate code verifier.

    Returns:
        bool: True when the verifier passes validation.
    """
    is_valid = validate_code_verifier(verifier)
    return is_valid
sk_agents.auth.oauth_pkce.PKCEManager.verify_challenge staticmethod
verify_challenge(verifier: str, challenge: str) -> bool

Verify that challenge matches verifier.

Used by authorization server (not typically by client).

Parameters:

Name Type Description Default
verifier str

Code verifier

required
challenge str

Code challenge to verify

required

Returns:

Name Type Description
bool bool

True if challenge matches verifier

Source code in src/sk_agents/auth/oauth_pkce.py
@staticmethod
def verify_challenge(verifier: str, challenge: str) -> bool:
    """
    Verify that challenge matches verifier.

    Used by authorization server (not typically by client).

    The comparison is constant-time (``hmac.compare_digest``). Response time
    therefore does not reveal how many leading characters of a supplied
    challenge were correct.

    Args:
        verifier: Code verifier
        challenge: Code challenge to verify

    Returns:
        bool: True if challenge matches verifier. A non-str challenge never
        matches.
    """
    import hmac

    expected_challenge = generate_code_challenge(verifier)
    if not isinstance(challenge, str):
        # The previous ``==`` check simply returned False here; keep that.
        return False
    # Compare UTF-8 bytes: compare_digest raises TypeError for non-ASCII str
    # input, and the challenge may come from an untrusted request.
    return hmac.compare_digest(
        expected_challenge.encode("utf-8"), challenge.encode("utf-8")
    )
sk_agents.auth.oauth_pkce.generate_code_verifier
generate_code_verifier() -> str

Generate cryptographically random code verifier.

Per the OAuth 2.1 spec, a code verifier must be 43-128 characters long. It may use only the characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~".

Returns:

Name Type Description
str str

Base64url-encoded random verifier. It is always 43 characters (32 random bytes, padding removed), which is within the allowed 43-128 range.

Source code in src/sk_agents/auth/oauth_pkce.py
def generate_code_verifier() -> str:
    """
    Create a high-entropy PKCE code verifier.

    OAuth 2.1 requires a verifier of 43-128 characters drawn from
    [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~". Base64url output of 32
    random bytes with the "=" padding removed is exactly 43 such characters.

    Returns:
        str: Base64url-encoded random verifier (43 characters).
    """
    entropy = secrets.token_bytes(32)
    encoded = base64.urlsafe_b64encode(entropy).decode("utf-8")
    # Padding is not part of the allowed alphabet, so strip it.
    return encoded.rstrip("=")
sk_agents.auth.oauth_pkce.generate_code_challenge
generate_code_challenge(verifier: str) -> str

Generate PKCE code challenge from verifier using S256 method.

Per the OAuth 2.1 spec, challenge = BASE64URL(SHA256(verifier)), with the base64url "=" padding removed.

Parameters:

Name Type Description Default
verifier str

Code verifier from generate_code_verifier()

required

Returns:

Name Type Description
str str

Base64url-encoded SHA256 hash of verifier

Source code in src/sk_agents/auth/oauth_pkce.py
def generate_code_challenge(verifier: str) -> str:
    """
    Derive the S256 PKCE code challenge for ``verifier``.

    The challenge is BASE64URL(SHA256(verifier)) with "=" padding removed,
    as required by OAuth 2.1 / RFC 7636.

    Args:
        verifier: Code verifier, typically from generate_code_verifier().

    Returns:
        str: Unpadded base64url encoding of the verifier's SHA-256 digest.
    """
    digest = hashlib.sha256(verifier.encode("utf-8")).digest()
    return base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")
sk_agents.auth.oauth_pkce.validate_code_verifier
validate_code_verifier(verifier: str) -> bool

Validate code verifier meets OAuth 2.1 requirements.

Parameters:

Name Type Description Default
verifier str

Code verifier to validate

required

Returns:

Name Type Description
bool bool

True if valid, False otherwise

Source code in src/sk_agents/auth/oauth_pkce.py
def validate_code_verifier(verifier: str) -> bool:
    """
    Check a PKCE code verifier against the OAuth 2.1 rules.

    A valid verifier is 43 to 128 characters long and uses only unreserved
    characters: ASCII letters, digits, "-", ".", "_" and "~".

    Args:
        verifier: Candidate code verifier.

    Returns:
        bool: True if both the length and the alphabet rules hold, else False.
    """
    unreserved = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
    )
    # Length is checked first so the character scan is skipped for bad lengths.
    return 43 <= len(verifier) <= 128 and all(ch in unreserved for ch in verifier)
sk_agents.auth.oauth_state_manager

OAuth State Manager

Manages OAuth flow state for CSRF protection. Stores state parameter + PKCE verifier temporarily during OAuth flow.

Implementation uses AuthStorage with temporary keys and TTL.

sk_agents.auth.oauth_state_manager.OAuthFlowState

Represents temporary OAuth flow state.

Stored during authorization request, retrieved during callback.

Source code in src/sk_agents/auth/oauth_state_manager.py
class OAuthFlowState:
    """
    Represents temporary OAuth flow state.

    Stored during authorization request, retrieved during callback.
    """

    def __init__(
        self,
        state: str,
        verifier: str,
        user_id: str,
        server_name: str,
        resource: str,
        scopes: list[str],
        created_at: datetime,
    ):
        self.state = state
        self.verifier = verifier
        self.user_id = user_id
        self.server_name = server_name
        self.resource = resource
        self.scopes = scopes
        self.created_at = created_at

    def to_dict(self) -> dict[str, Any]:
        """Serialize to dict for storage"""
        return {
            "state": self.state,
            "verifier": self.verifier,
            "user_id": self.user_id,
            "server_name": self.server_name,
            "resource": self.resource,
            "scopes": self.scopes,
            "created_at": self.created_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OAuthFlowState":
        """Deserialize from storage dict"""
        return cls(
            state=data["state"],
            verifier=data["verifier"],
            user_id=data["user_id"],
            server_name=data["server_name"],
            resource=data["resource"],
            scopes=data["scopes"],
            created_at=datetime.fromisoformat(data["created_at"]),
        )

    def is_expired(self, ttl_seconds: int = 300) -> bool:
        """Check if flow state has expired (default 5 minutes)"""
        expires_at = self.created_at + timedelta(seconds=ttl_seconds)
        return datetime.now(UTC) > expires_at
sk_agents.auth.oauth_state_manager.OAuthFlowState.to_dict
to_dict() -> dict[str, Any]

Serialize to dict for storage

Source code in src/sk_agents/auth/oauth_state_manager.py
def to_dict(self) -> dict[str, Any]:
    """Serialize to dict for storage"""
    return {
        "state": self.state,
        "verifier": self.verifier,
        "user_id": self.user_id,
        "server_name": self.server_name,
        "resource": self.resource,
        "scopes": self.scopes,
        "created_at": self.created_at.isoformat(),
    }
sk_agents.auth.oauth_state_manager.OAuthFlowState.from_dict classmethod
from_dict(data: dict[str, Any]) -> OAuthFlowState

Deserialize from storage dict

Source code in src/sk_agents/auth/oauth_state_manager.py
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "OAuthFlowState":
    """Deserialize from storage dict"""
    return cls(
        state=data["state"],
        verifier=data["verifier"],
        user_id=data["user_id"],
        server_name=data["server_name"],
        resource=data["resource"],
        scopes=data["scopes"],
        created_at=datetime.fromisoformat(data["created_at"]),
    )
sk_agents.auth.oauth_state_manager.OAuthFlowState.is_expired
is_expired(ttl_seconds: int = 300) -> bool

Check if flow state has expired (default 5 minutes)

Source code in src/sk_agents/auth/oauth_state_manager.py
def is_expired(self, ttl_seconds: int = 300) -> bool:
    """Check if flow state has expired (default 5 minutes)"""
    expires_at = self.created_at + timedelta(seconds=ttl_seconds)
    return datetime.now(UTC) > expires_at
sk_agents.auth.oauth_state_manager.OAuthStateManager

Manager for OAuth flow state and CSRF protection.

Uses AuthStorage with temporary keys to store state during OAuth flow.

Source code in src/sk_agents/auth/oauth_state_manager.py
class OAuthStateManager:
    """
    Manager for OAuth flow state and CSRF protection.

    Uses AuthStorage with temporary keys to store state during OAuth flow.

    Each flow is written twice under the same key ``oauth_state:{state}``:

    * storage user ``oauth_flow_temp:{user_id}``, read by
      ``retrieve_flow_state`` when the caller knows the user;
    * storage user ``oauth_flow_temp:by_state``, read by
      ``retrieve_flow_state_by_state_only`` in the OAuth callback, where only
      the ``state`` parameter is available.

    Expiry is checked on read (``OAuthFlowState.is_expired``) because the
    storage backend is not relied on to enforce a TTL.
    """

    # Use a special user_id for temporary OAuth flow state
    TEMP_USER_PREFIX = "oauth_flow_temp"

    def __init__(self, ttl_seconds: int = 300):
        """
        Initialize state manager.

        Args:
            ttl_seconds: Time-to-live for state (default 5 minutes)
        """
        self.ttl_seconds = ttl_seconds
        self.auth_storage_factory = AuthStorageFactory(AppConfig())
        self.auth_storage = self.auth_storage_factory.get_auth_storage_manager()

    @staticmethod
    def generate_state() -> str:
        """
        Generate cryptographically random state parameter.

        Returns:
            str: Random URL-safe state string (32 random bytes, 43 characters)
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def _coerce_flow_data(data: Any) -> dict[str, Any]:
        """Normalize a stored flow-state record to a plain dict (shared by both retrieve paths)."""
        if isinstance(data, dict):
            return data
        # AuthStorage may return an object rather than the dict that was stored
        if hasattr(data, "to_dict"):
            return data.to_dict()
        if hasattr(data, "__dict__"):
            return data.__dict__
        logger.error(f"Unexpected flow state data type: {type(data)}")
        raise ValueError("Invalid OAuth flow state data")

    def store_flow_state(
        self,
        state: str,
        verifier: str,
        user_id: str,
        server_name: str,
        resource: str,
        scopes: list[str],
    ) -> None:
        """
        Store OAuth flow state temporarily.

        Stores in two locations:
        1. User-specific key for validation: oauth_flow_temp:{user_id}
        2. State-only key for callback retrieval: oauth_flow_temp:by_state

        Args:
            state: CSRF state parameter
            verifier: PKCE code verifier
            user_id: User ID for this flow
            server_name: MCP server name
            resource: Canonical server URI
            scopes: Requested scopes

        Raises:
            Exception: Any storage error is logged and re-raised.
        """
        flow_state = OAuthFlowState(
            state=state,
            verifier=verifier,
            user_id=user_id,
            server_name=server_name,
            resource=resource,
            scopes=scopes,
            created_at=datetime.now(UTC),
        )

        temp_key = f"oauth_state:{state}"

        # AuthStorage has no native TTL, so expiry is enforced on retrieval.
        try:
            # User-scoped copy, read by retrieve_flow_state(state, user_id)
            temp_user = f"{self.TEMP_USER_PREFIX}:{user_id}"
            self.auth_storage.store(temp_user, temp_key, flow_state.to_dict())

            # State-only copy, read in the OAuth callback where user_id is unknown
            state_only_user = f"{self.TEMP_USER_PREFIX}:by_state"
            self.auth_storage.store(state_only_user, temp_key, flow_state.to_dict())

            logger.debug(f"Stored OAuth flow state for state={state}, user={user_id}")
        except Exception as e:
            logger.error(f"Failed to store OAuth flow state: {e}")
            raise

    def retrieve_flow_state(self, state: str, user_id: str) -> OAuthFlowState:
        """
        Retrieve and validate OAuth flow state.

        Args:
            state: CSRF state parameter from callback
            user_id: User ID to validate against

        Returns:
            OAuthFlowState: Retrieved flow state

        Raises:
            ValueError: If state not found, expired, or user_id mismatch
        """
        temp_key = f"oauth_state:{state}"
        temp_user = f"{self.TEMP_USER_PREFIX}:{user_id}"

        try:
            data = self.auth_storage.retrieve(temp_user, temp_key)

            if not data:
                logger.warning(f"OAuth flow state not found for state={state}")
                raise ValueError("Invalid or expired OAuth state")

            flow_state = OAuthFlowState.from_dict(self._coerce_flow_data(data))

            if flow_state.is_expired(self.ttl_seconds):
                logger.warning(f"OAuth flow state expired for state={state}")
                # Remove both stored copies of the expired state
                self.delete_flow_state(state, user_id)
                raise ValueError("OAuth state expired")

            # Defense in depth: the record itself must name the same user
            if flow_state.user_id != user_id:
                logger.error(
                    f"OAuth flow user_id mismatch: expected={flow_state.user_id}, got={user_id}"
                )
                raise ValueError("OAuth state user mismatch (CSRF attempt?)")

            logger.debug(f"Retrieved valid OAuth flow state for state={state}, user={user_id}")
            return flow_state

        except Exception as e:
            logger.error(f"Failed to retrieve OAuth flow state: {e}")
            raise

    def retrieve_flow_state_by_state_only(self, state: str) -> OAuthFlowState:
        """
        Retrieve OAuth flow state using only the state parameter.

        This is used in OAuth callbacks where we don't have user_id upfront.
        It reads the ``oauth_flow_temp:by_state`` copy written by
        store_flow_state; the returned flow state carries the user_id.

        Args:
            state: CSRF state parameter from callback

        Returns:
            OAuthFlowState: Retrieved flow state with embedded user_id

        Raises:
            ValueError: If state not found or expired (expired state is
                deleted, matching retrieve_flow_state)
        """
        temp_key = f"oauth_state:{state}"
        state_only_user = f"{self.TEMP_USER_PREFIX}:by_state"

        try:
            data = self.auth_storage.retrieve(state_only_user, temp_key)

            if not data:
                logger.warning(f"OAuth flow state not found for state={state}")
                raise ValueError("Invalid or expired OAuth state")

            flow_state = OAuthFlowState.from_dict(self._coerce_flow_data(data))

            if flow_state.is_expired(self.ttl_seconds):
                logger.warning(f"OAuth flow state expired for state={state}")
                # Clean up both copies, as retrieve_flow_state does; the
                # record carries the user_id needed for the user-scoped key.
                self.delete_flow_state(state, flow_state.user_id)
                raise ValueError("OAuth state expired")

            logger.debug(f"Retrieved OAuth flow state for state={state}, user={flow_state.user_id}")
            return flow_state

        except Exception as e:
            logger.error(f"Failed to retrieve OAuth flow state by state only: {e}")
            raise

    def delete_flow_state(self, state: str, user_id: str) -> None:
        """
        Delete OAuth flow state after use or expiry.

        Both stored copies are deleted independently and failures are logged,
        not raised. A failure removing the user-scoped copy must not leave the
        by_state copy behind: that copy alone is enough for
        retrieve_flow_state_by_state_only() to accept the state again.

        Args:
            state: CSRF state parameter
            user_id: User ID
        """
        temp_key = f"oauth_state:{state}"
        temp_user = f"{self.TEMP_USER_PREFIX}:{user_id}"
        state_only_user = f"{self.TEMP_USER_PREFIX}:by_state"

        try:
            self.auth_storage.delete(temp_user, temp_key)
            logger.debug(f"Deleted OAuth flow state for state={state}, user={user_id}")
        except Exception as e:
            logger.warning(f"Failed to delete OAuth flow state: {e}")

        try:
            self.auth_storage.delete(state_only_user, temp_key)
            logger.debug(f"Deleted state-only OAuth flow state for state={state}")
        except Exception as e:
            logger.warning(f"Failed to delete state-only OAuth flow state: {e}")
sk_agents.auth.oauth_state_manager.OAuthStateManager.__init__
__init__(ttl_seconds: int = 300)

Initialize state manager.

Parameters:

Name Type Description Default
ttl_seconds int

Time-to-live for state (default 5 minutes)

300
Source code in src/sk_agents/auth/oauth_state_manager.py
def __init__(self, ttl_seconds: int = 300):
    """
    Set up the manager and obtain its AuthStorage backend.

    Args:
        ttl_seconds: Lifetime of a stored flow state, in seconds (default 300, i.e. 5 minutes)
    """
    self.ttl_seconds = ttl_seconds
    factory = AuthStorageFactory(AppConfig())
    self.auth_storage_factory = factory
    self.auth_storage = factory.get_auth_storage_manager()
sk_agents.auth.oauth_state_manager.OAuthStateManager.generate_state staticmethod
generate_state() -> str

Generate cryptographically random state parameter.

Returns:

Name Type Description
str str

Random URL-safe state string (32 random bytes, base64url-encoded to 43 characters)

Source code in src/sk_agents/auth/oauth_state_manager.py
@staticmethod
def generate_state() -> str:
    """
    Create a fresh, unguessable OAuth ``state`` value.

    Returns:
        str: URL-safe text token built from 32 bytes drawn from the
        ``secrets`` module (about 43 characters long).
    """
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token
sk_agents.auth.oauth_state_manager.OAuthStateManager.store_flow_state
store_flow_state(
    state: str,
    verifier: str,
    user_id: str,
    server_name: str,
    resource: str,
    scopes: list[str],
) -> None

Store OAuth flow state temporarily.

Stores in two locations:

1. User-specific key for validation: `oauth_flow_temp:{user_id}`
2. State-only key for callback retrieval: `oauth_flow_temp:by_state`

Parameters:

Name Type Description Default
state str

CSRF state parameter

required
verifier str

PKCE code verifier

required
user_id str

User ID for this flow

required
server_name str

MCP server name

required
resource str

Canonical server URI

required
scopes list[str]

Requested scopes

required
Source code in src/sk_agents/auth/oauth_state_manager.py
def store_flow_state(
    self,
    state: str,
    verifier: str,
    user_id: str,
    server_name: str,
    resource: str,
    scopes: list[str],
) -> None:
    """
    Persist in-progress OAuth flow state under two storage scopes.

    The serialized state is written twice, each time with the item key
    ``oauth_state:{state}``:

    1. ``{TEMP_USER_PREFIX}:{user_id}`` (e.g. ``oauth_flow_temp:{user_id}``),
       read back by ``retrieve_flow_state`` for validation.
    2. ``{TEMP_USER_PREFIX}:by_state`` (e.g. ``oauth_flow_temp:by_state``),
       read back by ``retrieve_flow_state_by_state_only`` in callbacks that
       carry no user ID.

    No TTL is passed to the storage layer. ``created_at`` is recorded so
    expiry can be enforced when the state is read back.

    Args:
        state: CSRF state parameter
        verifier: PKCE code verifier
        user_id: User ID for this flow
        server_name: MCP server name
        resource: Canonical server URI
        scopes: Requested scopes

    Raises:
        Exception: Any storage error is logged and re-raised unchanged.
    """
    flow_state = OAuthFlowState(
        state=state,
        verifier=verifier,
        user_id=user_id,
        server_name=server_name,
        resource=resource,
        scopes=scopes,
        created_at=datetime.now(UTC),
    )
    item_key = f"oauth_state:{state}"

    try:
        storage_owners = (
            f"{self.TEMP_USER_PREFIX}:{user_id}",  # user-scoped copy
            f"{self.TEMP_USER_PREFIX}:by_state",  # state-only copy for callbacks
        )
        for owner in storage_owners:
            self.auth_storage.store(owner, item_key, flow_state.to_dict())

        logger.debug(f"Stored OAuth flow state for state={state}, user={user_id}")
    except Exception as e:
        logger.error(f"Failed to store OAuth flow state: {e}")
        raise
sk_agents.auth.oauth_state_manager.OAuthStateManager.retrieve_flow_state
retrieve_flow_state(
    state: str, user_id: str
) -> OAuthFlowState

Retrieve and validate OAuth flow state.

Parameters:

Name Type Description Default
state str

CSRF state parameter from callback

required
user_id str

User ID to validate against

required

Returns:

Name Type Description
OAuthFlowState OAuthFlowState

Retrieved flow state

Raises:

Type Description
ValueError

If state not found, expired, or user_id mismatch

Source code in src/sk_agents/auth/oauth_state_manager.py
def retrieve_flow_state(self, state: str, user_id: str) -> OAuthFlowState:
    """
    Load the user-scoped OAuth flow state and validate it.

    Reads the copy that ``store_flow_state`` wrote under
    ``{TEMP_USER_PREFIX}:{user_id}``. If the storage backend returns an
    object, it is normalized to a dict. The ``OAuthFlowState`` is then
    rebuilt and checked for expiry and ownership. An expired entry is
    deleted before the error is raised.

    Args:
        state: CSRF state parameter from callback
        user_id: User ID the flow must belong to

    Returns:
        OAuthFlowState: Retrieved flow state

    Raises:
        ValueError: If the state is not found, malformed, expired, or owned by
            another user. Every exception (including these) is logged before
            being re-raised.
    """
    item_key = f"oauth_state:{state}"
    user_scope = f"{self.TEMP_USER_PREFIX}:{user_id}"

    try:
        raw = self.auth_storage.retrieve(user_scope, item_key)
        if not raw:
            logger.warning(f"OAuth flow state not found for state={state}")
            raise ValueError("Invalid or expired OAuth state")

        # The backend may hand back a dict or an object; normalize to a dict.
        if isinstance(raw, dict):
            payload = raw
        elif hasattr(raw, "to_dict"):
            payload = raw.to_dict()
        elif hasattr(raw, "__dict__"):
            payload = raw.__dict__
        else:
            logger.error(f"Unexpected flow state data type: {type(raw)}")
            raise ValueError("Invalid OAuth flow state data")

        flow_state = OAuthFlowState.from_dict(payload)

        if flow_state.is_expired(self.ttl_seconds):
            logger.warning(f"OAuth flow state expired for state={state}")
            # Storage does not expire entries on its own, so purge eagerly.
            self.delete_flow_state(state, user_id)
            raise ValueError("OAuth state expired")

        # CSRF protection: the stored flow must belong to the requesting user.
        owner = flow_state.user_id
        if owner != user_id:
            logger.error(f"OAuth flow user_id mismatch: expected={owner}, got={user_id}")
            raise ValueError("OAuth state user mismatch (CSRF attempt?)")

        logger.debug(f"Retrieved valid OAuth flow state for state={state}, user={user_id}")
        return flow_state
    except Exception as e:
        logger.error(f"Failed to retrieve OAuth flow state: {e}")
        raise
sk_agents.auth.oauth_state_manager.OAuthStateManager.retrieve_flow_state_by_state_only
retrieve_flow_state_by_state_only(
    state: str,
) -> OAuthFlowState

Retrieve OAuth flow state using only the state parameter.

This is used in OAuth callbacks where we don't have user_id upfront. The flow state contains user_id which we extract after retrieval.

Note: This method reads the state-only copy of the flow state that store_flow_state writes under `oauth_flow_temp:by_state`. For production, consider using a state→user_id mapping or encoding user_id in the state parameter itself.

Parameters:

Name Type Description Default
state str

CSRF state parameter from callback

required

Returns:

Name Type Description
OAuthFlowState OAuthFlowState

Retrieved flow state with embedded user_id

Raises:

Type Description
ValueError

If state not found or expired

Source code in src/sk_agents/auth/oauth_state_manager.py
def retrieve_flow_state_by_state_only(self, state: str) -> OAuthFlowState:
    """
    Retrieve OAuth flow state using only the state parameter.

    Used by OAuth callbacks that arrive without a user ID. ``store_flow_state``
    writes a second copy of every flow state under the shared
    ``{TEMP_USER_PREFIX}:by_state`` scope. This method reads that copy, and
    the user ID is available on the returned ``OAuthFlowState``.

    NOTE: no user ID check is performed here (unlike ``retrieve_flow_state``).
    For production, consider a state→user_id mapping or encoding user_id in
    the state parameter itself.

    An expired entry is deleted (both stored copies) before the error is
    raised, matching ``retrieve_flow_state``.

    Args:
        state: CSRF state parameter from callback

    Returns:
        OAuthFlowState: Retrieved flow state with embedded user_id

    Raises:
        ValueError: If state not found, malformed, or expired. Every exception
            is logged before being re-raised.
    """
    temp_key = f"oauth_state:{state}"

    try:
        # AuthStorage is user-scoped, so store_flow_state also writes a copy under a
        # well-known state-only scope; read that copy here.
        state_only_user = f"{self.TEMP_USER_PREFIX}:by_state"
        data = self.auth_storage.retrieve(state_only_user, temp_key)

        if not data:
            logger.warning(f"OAuth flow state not found for state={state}")
            raise ValueError("Invalid or expired OAuth state")

        # Handle both dict and object storage
        if not isinstance(data, dict):
            if hasattr(data, "to_dict"):
                data = data.to_dict()
            elif hasattr(data, "__dict__"):
                data = data.__dict__
            else:
                logger.error(f"Unexpected flow state data type: {type(data)}")
                raise ValueError("Invalid OAuth flow state data")

        flow_state = OAuthFlowState.from_dict(data)

        # Validate expiry
        if flow_state.is_expired(self.ttl_seconds):
            logger.warning(f"OAuth flow state expired for state={state}")
            # Clean up both stored copies, using the user ID embedded in the state
            # (previously expired entries were left in storage on this path).
            self.delete_flow_state(state, flow_state.user_id)
            raise ValueError("OAuth state expired")

        logger.debug(f"Retrieved OAuth flow state for state={state}, user={flow_state.user_id}")
        return flow_state

    except Exception as e:
        logger.error(f"Failed to retrieve OAuth flow state by state only: {e}")
        raise
sk_agents.auth.oauth_state_manager.OAuthStateManager.delete_flow_state
delete_flow_state(state: str, user_id: str) -> None

Delete OAuth flow state after use or expiry.

Parameters:

Name Type Description Default
state str

CSRF state parameter

required
user_id str

User ID

required
Source code in src/sk_agents/auth/oauth_state_manager.py
def delete_flow_state(self, state: str, user_id: str) -> None:
    """
    Delete OAuth flow state after use or expiry.

    Two stored copies are removed, both under the item key
    ``oauth_state:{state}``:

    * the user-scoped copy under ``{TEMP_USER_PREFIX}:{user_id}``;
    * the state-only copy under ``{TEMP_USER_PREFIX}:by_state``.

    Each deletion is attempted independently, so a failure on the
    user-scoped copy no longer leaves the state-only copy behind.
    Cleanup is best-effort: failures are logged, never raised.

    Args:
        state: CSRF state parameter
        user_id: User ID
    """
    temp_key = f"oauth_state:{state}"
    temp_user = f"{self.TEMP_USER_PREFIX}:{user_id}"
    state_only_user = f"{self.TEMP_USER_PREFIX}:by_state"

    # Delete from user-specific storage
    try:
        self.auth_storage.delete(temp_user, temp_key)
        logger.debug(f"Deleted OAuth flow state for state={state}, user={user_id}")
    except Exception as e:
        logger.warning(f"Failed to delete OAuth flow state: {e}")

    # Always attempt the state-only copy as well. Previously this step was skipped
    # whenever the user-scoped delete raised, leaving a copy retrievable by state alone.
    try:
        self.auth_storage.delete(state_only_user, temp_key)
        logger.debug(f"Deleted state-only OAuth flow state for state={state}")
    except Exception as e:
        logger.debug(f"Failed to delete state-only flow state (non-critical): {e}")
sk_agents.auth.server_metadata

Authorization Server Metadata Discovery

Implements server metadata discovery per RFC8414 and RFC9728. Used for dynamic discovery of OAuth endpoints and capabilities.

References:

- RFC 8414: OAuth 2.0 Authorization Server Metadata
- RFC 9728: OAuth 2.0 Protected Resource Metadata

sk_agents.auth.server_metadata.AuthServerMetadata

Bases: BaseModel

OAuth 2.0 Authorization Server Metadata (RFC8414)

Discovered from {auth_server}/.well-known/oauth-authorization-server

Source code in src/sk_agents/auth/server_metadata.py
class AuthServerMetadata(BaseModel):
    """
    OAuth 2.0 Authorization Server Metadata (RFC8414)

    Discovered from {auth_server}/.well-known/oauth-authorization-server

    Models the subset of RFC 8414 fields this package uses. The four fields
    without defaults (issuer, authorization_endpoint, token_endpoint,
    response_types_supported) must be present for validation to succeed.
    The optional fields default to None when the server omits them.
    """

    # Issuer identifier of the authorization server.
    issuer: HttpUrl
    # URL of the authorization endpoint (where the user is sent to authorize).
    authorization_endpoint: HttpUrl
    # URL of the token endpoint (authorization-code / refresh-token exchange).
    token_endpoint: HttpUrl
    # Token revocation endpoint, if advertised.
    revocation_endpoint: HttpUrl | None = None
    # Dynamic client registration endpoint, if advertised.
    registration_endpoint: HttpUrl | None = None
    # OAuth response_type values the server supports (e.g. "code").
    response_types_supported: list[str]
    # Supported grant types, if advertised.
    grant_types_supported: list[str] | None = None
    # PKCE code_challenge methods; discovery in this module warns when "S256" is absent.
    code_challenge_methods_supported: list[str] | None = None
    # Scopes the server advertises, if any.
    scopes_supported: list[str] | None = None
sk_agents.auth.server_metadata.ProtectedResourceMetadata

Bases: BaseModel

OAuth 2.0 Protected Resource Metadata (RFC9728)

Discovered from {mcp_server}/.well-known/oauth-protected-resource

Source code in src/sk_agents/auth/server_metadata.py
class ProtectedResourceMetadata(BaseModel):
    """
    OAuth 2.0 Protected Resource Metadata (RFC9728)

    Discovered from {mcp_server}/.well-known/oauth-protected-resource

    Only ``resource`` and ``authorization_servers`` are required by this
    model; the remaining fields default to None when absent.
    """

    # Identifier (URL) of the protected resource this document describes.
    resource: HttpUrl
    # Authorization servers that can issue tokens for this resource.
    # NOTE(review): RFC 9728 marks this field OPTIONAL, but this model requires it,
    # presumably because the client cannot start an OAuth flow without one — confirm.
    authorization_servers: list[HttpUrl]
    # Scopes the resource advertises, if any.
    scopes_supported: list[str] | None = None
    # Supported ways of presenting a bearer token (e.g. "header"), if advertised.
    bearer_methods_supported: list[str] | None = None
sk_agents.auth.server_metadata.ServerMetadataCache

Cache for server metadata to avoid repeated discovery requests.

Implements RFC 8414 and RFC 9728 discovery with TTL-based caching.

Source code in src/sk_agents/auth/server_metadata.py
class ServerMetadataCache:
    """
    In-memory, TTL-based cache in front of OAuth metadata discovery.

    Wraps two well-known lookups:

    * RFC 8414 authorization server metadata
      (``{auth_server}/.well-known/oauth-authorization-server``)
    * RFC 9728 protected resource metadata
      (``{mcp_server}/.well-known/oauth-protected-resource``)

    Entries older than ``ttl`` seconds are ignored and fetched again. The
    asyncio lock guards only cache reads and writes. It is not held while an
    HTTP request is in flight, so concurrent misses for the same key may each
    perform a request.
    """

    def __init__(self, timeout: float = 30.0, ttl: int = 3600):
        """
        Prepare an empty metadata cache.

        Args:
            timeout: Per-request HTTP timeout in seconds (default: 30)
            ttl: Seconds a cached entry stays fresh (default: 3600, one hour)
        """
        self.timeout = timeout
        self.ttl = ttl
        self._cache: dict[str, tuple[Any, datetime]] = {}
        self._lock = asyncio.Lock()

    async def _cache_lookup(self, key: str) -> tuple[bool, Any]:
        """
        Return ``(True, value)`` for a fresh entry under ``key``, else ``(False, None)``.

        A separate hit flag is required because ``None`` is itself a cacheable
        value (an absent protected-resource document).
        """
        async with self._lock:
            entry = self._cache.get(key)
        if entry is None:
            return False, None
        value, stored_at = entry
        if datetime.now(UTC) - stored_at >= timedelta(seconds=self.ttl):
            return False, None
        return True, value

    async def _cache_store(self, key: str, value: Any) -> None:
        """Record ``value`` under ``key``, stamped with the current UTC time."""
        async with self._lock:
            self._cache[key] = (value, datetime.now(UTC))

    async def fetch_auth_server_metadata(self, auth_server: str) -> AuthServerMetadata:
        """
        Discover authorization server metadata (RFC 8414), using the cache when fresh.

        The document is requested from
        ``{auth_server}/.well-known/oauth-authorization-server`` (trailing
        slashes on ``auth_server`` are stripped first). A successful result is
        cached under the raw ``auth_server`` string. Missing or non-S256 PKCE
        support is logged as a warning but does not fail discovery.

        Args:
            auth_server: Authorization server base URL

        Returns:
            AuthServerMetadata: Parsed metadata

        Raises:
            httpx.HTTPError: If the request fails or returns an error status
            ValueError: If the response cannot be parsed into metadata
        """
        hit, cached = await self._cache_lookup(auth_server)
        if hit:
            logger.debug(f"Cache hit for auth server metadata: {auth_server}")
            return cached

        base_url = auth_server.rstrip("/")
        well_known_url = f"{base_url}/.well-known/oauth-authorization-server"
        logger.info(f"Discovering authorization server metadata from {well_known_url}")

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.get(well_known_url)
                response.raise_for_status()
                payload = response.json()

            metadata = AuthServerMetadata(**payload)

            # MCP expects PKCE with S256; gaps are reported, not treated as fatal.
            methods = metadata.code_challenge_methods_supported
            if not methods:
                logger.warning(
                    f"Auth server {auth_server} does not advertise code_challenge_methods_supported"
                )
            elif "S256" not in methods:
                logger.warning(
                    f"Auth server {auth_server} does not advertise S256 PKCE support. "
                    f"Supported methods: {methods}"
                )

            logger.info(
                f"Successfully discovered metadata for {auth_server}: "
                f"authorization_endpoint={metadata.authorization_endpoint}, "
                f"token_endpoint={metadata.token_endpoint}"
            )

            await self._cache_store(auth_server, metadata)
            return metadata

        except httpx.HTTPStatusError as e:
            logger.error(
                f"Failed to fetch authorization server metadata from {well_known_url}: "
                f"HTTP {e.response.status_code}"
            )
            raise
        except httpx.HTTPError as e:
            logger.error(f"Network error fetching authorization server metadata: {e}")
            raise
        except Exception as e:
            logger.error(f"Failed to parse authorization server metadata: {e}")
            raise ValueError(f"Invalid authorization server metadata: {e}") from e

    async def fetch_protected_resource_metadata(
        self, mcp_server: str
    ) -> ProtectedResourceMetadata | None:
        """
        Discover protected resource metadata (RFC 9728), using the cache when fresh.

        The document is requested from
        ``{mcp_server}/.well-known/oauth-protected-resource``. It is optional,
        so a 404 yields ``None``, and that ``None`` is cached too (key
        ``prm:{mcp_server}``) to avoid repeated requests.

        Args:
            mcp_server: MCP server base URL

        Returns:
            ProtectedResourceMetadata | None: Parsed metadata, or None if the
            server does not publish it

        Raises:
            httpx.HTTPError: For non-404 HTTP errors and network failures
            ValueError: If metadata exists but is invalid or lists no
                authorization servers
        """
        cache_key = f"prm:{mcp_server}"
        hit, cached = await self._cache_lookup(cache_key)
        if hit:
            logger.debug(f"Cache hit for protected resource metadata: {mcp_server}")
            return cached

        well_known_url = f"{mcp_server.rstrip('/')}/.well-known/oauth-protected-resource"
        logger.info(f"Discovering protected resource metadata from {well_known_url}")

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.get(well_known_url)

                if response.status_code == 404:
                    # The document is optional; remember that it is absent.
                    logger.debug(
                        f"Protected resource metadata not available for {mcp_server} (404). "
                        "This is optional per RFC 9728."
                    )
                    await self._cache_store(cache_key, None)
                    return None

                response.raise_for_status()
                payload = response.json()

            metadata = ProtectedResourceMetadata(**payload)

            if not metadata.authorization_servers:
                raise ValueError("Protected resource metadata must include authorization_servers")

            logger.info(
                f"Successfully discovered protected resource metadata for {mcp_server}: "
                f"authorization_servers={metadata.authorization_servers}, "
                f"scopes_supported={metadata.scopes_supported}"
            )

            await self._cache_store(cache_key, metadata)
            return metadata

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                # Defensive: a 404 is normally handled before raise_for_status().
                await self._cache_store(cache_key, None)
                return None
            logger.error(
                f"Failed to fetch protected resource metadata from {well_known_url}: "
                f"HTTP {e.response.status_code}"
            )
            raise
        except httpx.HTTPError as e:
            logger.error(f"Network error fetching protected resource metadata: {e}")
            raise
        except Exception as e:
            logger.error(f"Failed to parse protected resource metadata: {e}")
            raise ValueError(f"Invalid protected resource metadata: {e}") from e
sk_agents.auth.server_metadata.ServerMetadataCache.__init__
__init__(timeout: float = 30.0, ttl: int = 3600)

Initialize metadata cache.

Parameters:

Name Type Description Default
timeout float

HTTP request timeout in seconds (default: 30)

30.0
ttl int

Cache TTL in seconds (default: 3600 = 1 hour)

3600
Source code in src/sk_agents/auth/server_metadata.py
def __init__(self, timeout: float = 30.0, ttl: int = 3600):
    """
    Prepare an empty metadata cache.

    Args:
        timeout: Per-request HTTP timeout, in seconds (default: 30)
        ttl: How long a cached entry counts as fresh, in seconds
            (default: 3600, one hour)
    """
    # Configuration first, then the mutable cache state and its guard lock.
    self.timeout, self.ttl = timeout, ttl
    self._cache: dict[str, tuple[Any, datetime]] = {}
    self._lock = asyncio.Lock()
sk_agents.auth.server_metadata.ServerMetadataCache.fetch_auth_server_metadata async
fetch_auth_server_metadata(
    auth_server: str,
) -> AuthServerMetadata

Fetch authorization server metadata from well-known endpoint.

Per RFC 8414, discovers OAuth endpoints from: {auth_server}/.well-known/oauth-authorization-server

Parameters:

Name Type Description Default
auth_server str

Authorization server base URL

required

Returns:

Name Type Description
AuthServerMetadata AuthServerMetadata

Parsed metadata

Raises:

Type Description
HTTPError

If discovery fails

ValueError

If metadata is invalid

Source code in src/sk_agents/auth/server_metadata.py
async def fetch_auth_server_metadata(self, auth_server: str) -> AuthServerMetadata:
    """
    Discover authorization server metadata (RFC 8414), using the cache when fresh.

    The document is requested from
    ``{auth_server}/.well-known/oauth-authorization-server`` (trailing slashes
    on ``auth_server`` are stripped first). A successful result is cached
    under the raw ``auth_server`` string for ``self.ttl`` seconds. Missing or
    non-S256 PKCE support is logged as a warning but does not fail discovery.

    Args:
        auth_server: Authorization server base URL

    Returns:
        AuthServerMetadata: Parsed metadata

    Raises:
        httpx.HTTPError: If the request fails or returns an error status
        ValueError: If the response cannot be parsed into metadata
    """
    # Serve a cached copy while it is still within its TTL.
    async with self._lock:
        entry = self._cache.get(auth_server)
        if entry is not None:
            cached_metadata, stored_at = entry
            if datetime.now(UTC) - stored_at < timedelta(seconds=self.ttl):
                logger.debug(f"Cache hit for auth server metadata: {auth_server}")
                return cached_metadata

    base_url = auth_server.rstrip("/")
    well_known_url = f"{base_url}/.well-known/oauth-authorization-server"
    logger.info(f"Discovering authorization server metadata from {well_known_url}")

    try:
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(well_known_url)
            response.raise_for_status()
            payload = response.json()

        metadata = AuthServerMetadata(**payload)

        # MCP expects PKCE with S256; gaps are reported, not treated as fatal.
        methods = metadata.code_challenge_methods_supported
        if not methods:
            logger.warning(
                f"Auth server {auth_server} does not advertise code_challenge_methods_supported"
            )
        elif "S256" not in methods:
            logger.warning(
                f"Auth server {auth_server} does not advertise S256 PKCE support. "
                f"Supported methods: {methods}"
            )

        logger.info(
            f"Successfully discovered metadata for {auth_server}: "
            f"authorization_endpoint={metadata.authorization_endpoint}, "
            f"token_endpoint={metadata.token_endpoint}"
        )

        async with self._lock:
            self._cache[auth_server] = (metadata, datetime.now(UTC))
        return metadata

    except httpx.HTTPStatusError as e:
        logger.error(
            f"Failed to fetch authorization server metadata from {well_known_url}: "
            f"HTTP {e.response.status_code}"
        )
        raise
    except httpx.HTTPError as e:
        logger.error(f"Network error fetching authorization server metadata: {e}")
        raise
    except Exception as e:
        logger.error(f"Failed to parse authorization server metadata: {e}")
        raise ValueError(f"Invalid authorization server metadata: {e}") from e
sk_agents.auth.server_metadata.ServerMetadataCache.fetch_protected_resource_metadata async
fetch_protected_resource_metadata(
    mcp_server: str,
) -> ProtectedResourceMetadata | None

Fetch protected resource metadata from MCP server.

Per RFC 9728, discovers resource metadata from: {mcp_server}/.well-known/oauth-protected-resource

Note: This metadata is OPTIONAL per RFC 9728. Returns None if not available.

Parameters:

Name Type Description Default
mcp_server str

MCP server base URL

required

Returns:

Name Type Description
ProtectedResourceMetadata ProtectedResourceMetadata | None

Parsed metadata, or None if not available

Raises:

Type Description
ValueError

If metadata exists but is invalid

Source code in src/sk_agents/auth/server_metadata.py
async def fetch_protected_resource_metadata(
    self, mcp_server: str
) -> ProtectedResourceMetadata | None:
    """
    Discover protected resource metadata (RFC 9728), using the cache when fresh.

    The document is requested from
    ``{mcp_server}/.well-known/oauth-protected-resource``. It is optional, so
    a 404 yields ``None``, and that ``None`` is cached as well (key
    ``prm:{mcp_server}``) to avoid repeated requests.

    Args:
        mcp_server: MCP server base URL

    Returns:
        ProtectedResourceMetadata | None: Parsed metadata, or None if the
        server does not publish it

    Raises:
        httpx.HTTPError: For non-404 HTTP errors and network failures
        ValueError: If metadata exists but is invalid or lists no
            authorization servers
    """
    cache_key = f"prm:{mcp_server}"

    # Serve a cached copy (possibly None) while it is still within its TTL.
    async with self._lock:
        entry = self._cache.get(cache_key)
        if entry is not None:
            cached_metadata, stored_at = entry
            if datetime.now(UTC) - stored_at < timedelta(seconds=self.ttl):
                logger.debug(f"Cache hit for protected resource metadata: {mcp_server}")
                return cached_metadata

    base_url = mcp_server.rstrip("/")
    well_known_url = f"{base_url}/.well-known/oauth-protected-resource"
    logger.info(f"Discovering protected resource metadata from {well_known_url}")

    try:
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(well_known_url)

            if response.status_code == 404:
                # The document is optional; remember that it is absent.
                logger.debug(
                    f"Protected resource metadata not available for {mcp_server} (404). "
                    "This is optional per RFC 9728."
                )
                async with self._lock:
                    self._cache[cache_key] = (None, datetime.now(UTC))
                return None

            response.raise_for_status()
            payload = response.json()

        metadata = ProtectedResourceMetadata(**payload)

        if not metadata.authorization_servers:
            raise ValueError("Protected resource metadata must include authorization_servers")

        logger.info(
            f"Successfully discovered protected resource metadata for {mcp_server}: "
            f"authorization_servers={metadata.authorization_servers}, "
            f"scopes_supported={metadata.scopes_supported}"
        )

        async with self._lock:
            self._cache[cache_key] = (metadata, datetime.now(UTC))
        return metadata

    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            # Defensive: a 404 is normally handled before raise_for_status().
            async with self._lock:
                self._cache[cache_key] = (None, datetime.now(UTC))
            return None
        logger.error(
            f"Failed to fetch protected resource metadata from {well_known_url}: "
            f"HTTP {e.response.status_code}"
        )
        raise
    except httpx.HTTPError as e:
        logger.error(f"Network error fetching protected resource metadata: {e}")
        raise
    except Exception as e:
        logger.error(f"Failed to parse protected resource metadata: {e}")
        raise ValueError(f"Invalid protected resource metadata: {e}") from e
sk_agents.auth_storage
sk_agents.auth_storage.auth_storage_factory
sk_agents.auth_storage.auth_storage_factory.AuthStorageFactory
Source code in src/sk_agents/auth_storage/auth_storage_factory.py
class AuthStorageFactory(metaclass=Singleton):
    """
    Factory for the application's SecureAuthStorageManager.

    If both TA_AUTH_STORAGE_MANAGER_MODULE and TA_AUTH_STORAGE_MANAGER_CLASS
    are configured, the named class is loaded from that module and validated.
    Otherwise the in-memory default (InMemorySecureAuthStorageManager) is used.
    """

    def __init__(self, app_config: AppConfig):
        """
        Resolve and validate the configured storage manager class.

        Args:
            app_config: Application configuration to read settings from.

        Raises:
            ImportError: If the configured module cannot be loaded.
            ValueError: If a module is configured without a class name, or the
                class is not found in the module.
            TypeError: If the class is not a SecureAuthStorageManager subclass.
        """
        self.app_config = app_config

        # Try to load custom module, fallback to default if not configured
        module_name, class_name = self._get_custom_auth_storage_config()
        if module_name and class_name:
            try:
                self.module = ModuleLoader.load_module(module_name)
            except Exception as e:
                raise ImportError(f"Failed to load module '{module_name}': {e}") from e

            self.class_name = class_name
            self._validate_custom_class()
        else:
            self.module = None
            self.class_name = None

    def get_auth_storage_manager(self) -> SecureAuthStorageManager:
        """
        Create a storage manager instance.

        A custom class receives ``app_config=`` when its constructor signature
        accepts it. When the signature cannot be determined, the previous
        behaviour is kept: try with ``app_config`` and retry without it on
        ``TypeError``.

        Returns:
            SecureAuthStorageManager: The custom manager, or a new
            InMemorySecureAuthStorageManager when none is configured.
        """
        if not (self.module and self.class_name):
            # Use default implementation
            return InMemorySecureAuthStorageManager()

        custom_class = getattr(self.module, self.class_name)
        accepts_config = self._accepts_app_config(custom_class)
        if accepts_config is True:
            # A TypeError raised inside the constructor now propagates instead of being
            # mistaken for "app_config not accepted" and the class built a second time.
            return custom_class(app_config=self.app_config)
        if accepts_config is False:
            return custom_class()

        # Signature unknown (or only **kwargs): keep the original try-then-fallback.
        try:
            return custom_class(app_config=self.app_config)
        except TypeError:
            return custom_class()

    @staticmethod
    def _accepts_app_config(custom_class) -> bool | None:
        """
        Report whether ``custom_class(app_config=...)`` is valid per its signature.

        Returns None when the signature is unavailable or too generic
        (``**kwargs``) to decide.
        """
        import inspect

        try:
            params = inspect.signature(custom_class).parameters.values()
        except (TypeError, ValueError):
            return None

        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        if any(p.name == "app_config" and p.kind in keyword_kinds for p in params):
            return True
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
            return None
        return False

    def _get_custom_auth_storage_config(self) -> tuple[str | None, str | None]:
        """
        Get custom auth storage configuration, returning None values if not configured.

        Raises:
            ValueError: If a module name is configured but the class name is
                missing or empty.
        """
        try:
            module_name = self.app_config.get(TA_AUTH_STORAGE_MANAGER_MODULE.env_name)
        except KeyError:
            return None, None

        try:
            class_name = self.app_config.get(TA_AUTH_STORAGE_MANAGER_CLASS.env_name)
        except KeyError:
            class_name = None

        if not class_name:
            if module_name:
                # Previously only the KeyError path raised; a None/empty class name
                # silently fell back to in-memory storage despite a configured module.
                raise ValueError("Custom Auth Storage Manager class name not provided")
            return None, None

        return module_name, class_name

    def _validate_custom_class(self):
        """Validate that the custom class is a proper SecureAuthStorageManager subclass."""
        if not hasattr(self.module, self.class_name):
            module_name = getattr(self.module, "__name__", "unknown module")
            raise ValueError(
                f"Custom Auth Storage Manager class: {self.class_name} "
                f"Not found in module: {module_name}"
            )

        custom_class = getattr(self.module, self.class_name)
        # isinstance guard: issubclass() on a non-class raises an unhelpful TypeError.
        is_manager_class = isinstance(custom_class, type) and issubclass(
            custom_class, SecureAuthStorageManager
        )
        if not is_manager_class:
            raise TypeError(
                f"Class '{self.class_name}' is not a subclass of SecureAuthStorageManager."
            )
sk_agents.auth_storage.custom
sk_agents.auth_storage.custom.example_redis_auth_storage

Complete Redis Authentication Storage Implementation

This example demonstrates a full-featured, production-ready Redis-based authentication storage implementation. It serves as a complete alternative to the default in-memory storage.

To use this implementation, set the following environment variables:

TA_AUTH_STORAGE_MANAGER_MODULE=src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
TA_AUTH_STORAGE_MANAGER_CLASS=RedisSecureAuthStorageManager

Redis configuration environment variables (all optional; defaults shown):

- TA_REDIS_HOST (default: localhost)
- TA_REDIS_PORT (default: 6379)
- TA_REDIS_DB (default: 0)
- TA_REDIS_TTL (default: 3600 seconds)
- TA_REDIS_PWD (optional)
- TA_REDIS_SSL (default: false)

sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager

Bases: SecureAuthStorageManager

Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
class RedisSecureAuthStorageManager(SecureAuthStorageManager):
    """Redis-backed implementation of SecureAuthStorageManager.

    Each (user_id, key) pair is stored as a JSON string under the Redis key
    ``auth_storage:{user_id}:{key}`` and expires after ``self.ttl`` seconds.
    ``store``/``retrieve``/``delete``/``clear_user_data`` each hold an
    instance-level ``threading.Lock`` for the duration of their Redis calls.

    NOTE(review): user_id and key are joined with ":" without escaping, so a
    user_id containing ":" can collide with another user's keys (user "a:b" /
    key "c" and user "a" / key "b:c" map to the same Redis key, and
    clear_user_data("a") also matches user "a:b"). Changing the key format would
    orphan existing data, so it is left unchanged here.
    """

    def __init__(self, app_config: AppConfig = None):
        """
        Initialize the Redis-based auth storage manager.

        Args:
            app_config: Application configuration object. If None, creates a new one.

        Raises:
            ConnectionError: If the initial PING fails with ``redis.ConnectionError``.
        """
        if app_config is None:
            app_config = AppConfig()

        self.app_config = app_config
        self._lock = threading.Lock()

        # Get Redis configuration
        redis_host = self.app_config.get(TA_REDIS_HOST.env_name) or "localhost"
        redis_port = int(self.app_config.get(TA_REDIS_PORT.env_name) or 6379)
        redis_db = int(self.app_config.get(TA_REDIS_DB.env_name) or 0)
        redis_password = self.app_config.get(TA_REDIS_PWD.env_name)
        # TLS is opt-in: only an explicit "true" (case-insensitive) enables it.
        # Comparing against "false" here would invert the setting.
        ssl_setting = self.app_config.get(TA_REDIS_SSL.env_name) or "false"
        redis_ssl = str(ssl_setting).strip().lower() == "true"
        self.ttl = int(self.app_config.get(TA_REDIS_TTL.env_name) or 3600)  # Default 1 hour

        # Initialize Redis client
        self.redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            db=redis_db,
            password=redis_password,
            ssl=redis_ssl,
            decode_responses=True,  # Automatically decode responses to strings
            socket_connect_timeout=5,
            socket_timeout=5,
            retry_on_timeout=True,
        )

        # Test connection
        try:
            self.redis_client.ping()
        except redis.ConnectionError as e:
            raise ConnectionError(f"Failed to connect to Redis: {e}") from e

    def _get_redis_key(self, user_id: str, key: str) -> str:
        """Generate a Redis key for the given user_id and key."""
        return f"auth_storage:{user_id}:{key}"

    @staticmethod
    def _escape_glob(value: str) -> str:
        """Backslash-escape Redis glob metacharacters (``* ? [ ] \\``) in *value*."""
        return "".join(f"\\{ch}" if ch in "*?[]\\" else ch for ch in value)

    def _serialize_auth_data(self, data: AuthData) -> str:
        """Serialize AuthData to JSON string."""
        return data.model_dump_json()

    def _deserialize_auth_data(self, data_str: str) -> AuthData:
        """Deserialize JSON string to AuthData."""
        data_dict = json.loads(data_str)
        # Import here to avoid circular imports
        from sk_agents.auth_storage.models import AuthData

        return AuthData.model_validate(data_dict)

    def store(self, user_id: str, key: str, data: AuthData) -> None:
        """Store authorization data for a given user and key with TTL.

        Raises:
            RuntimeError: Wrapping any ``redis.RedisError`` raised by the write.
        """
        with self._lock:
            try:
                redis_key = self._get_redis_key(user_id, key)
                serialized_data = self._serialize_auth_data(data)

                # Store with TTL
                self.redis_client.setex(redis_key, self.ttl, serialized_data)

            except redis.RedisError as e:
                raise RuntimeError(f"Failed to store auth data in Redis: {e}") from e

    def retrieve(self, user_id: str, key: str) -> AuthData | None:
        """Retrieve authorization data for a given user and key.

        Returns:
            The decoded AuthData, or None if nothing is stored under the key.

        Raises:
            RuntimeError: Wrapping a ``redis.RedisError`` from the read.
            ValueError: If the stored payload cannot be decoded; the entry is
                deleted (best-effort) before raising.
        """
        with self._lock:
            try:
                redis_key = self._get_redis_key(user_id, key)
                data_str = self.redis_client.get(redis_key)

                if data_str is None:
                    return None

                return self._deserialize_auth_data(data_str)

            except redis.RedisError as e:
                raise RuntimeError(f"Failed to retrieve auth data from Redis: {e}") from e
            except (json.JSONDecodeError, ValueError) as e:
                # If we can't deserialize the data, it's corrupted, so delete it
                try:
                    redis_key = self._get_redis_key(user_id, key)
                    self.redis_client.delete(redis_key)
                except redis.RedisError:
                    pass  # Ignore deletion errors
                raise ValueError(
                    f"Corrupted auth data found for user {user_id}, key {key}: {e}"
                ) from e

    def delete(self, user_id: str, key: str) -> None:
        """Delete authorization data for a given user and key.

        Raises:
            RuntimeError: Wrapping any ``redis.RedisError`` raised by DEL.
        """
        with self._lock:
            try:
                redis_key = self._get_redis_key(user_id, key)
                self.redis_client.delete(redis_key)

            except redis.RedisError as e:
                raise RuntimeError(f"Failed to delete auth data from Redis: {e}") from e

    def clear_user_data(self, user_id: str) -> int:
        """
        Clear all authorization data for a given user.

        Keys are discovered with incremental SCAN instead of KEYS, so the Redis
        server is not blocked while the keyspace is walked. Glob metacharacters
        in ``user_id`` are escaped so that, for example, a user_id of ``"*"``
        cannot match (and delete) other users' data.

        Returns:
            Number of keys deleted.

        Raises:
            RuntimeError: Wrapping any ``redis.RedisError``.
        """
        with self._lock:
            try:
                pattern = self._get_redis_key(self._escape_glob(user_id), "*")
                # SCAN may yield the same key more than once; de-duplicate in order.
                keys = list(dict.fromkeys(self.redis_client.scan_iter(match=pattern)))

                if not keys:
                    return 0

                return self.redis_client.delete(*keys)

            except redis.RedisError as e:
                raise RuntimeError(f"Failed to clear user data from Redis: {e}") from e

    def health_check(self) -> bool:
        """Check if Redis connection is healthy (PING succeeds)."""
        try:
            self.redis_client.ping()
            return True
        except redis.RedisError:
            return False
sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager.__init__
__init__(app_config: AppConfig = None)

Initialize the Redis-based auth storage manager.

Parameters:

Name Type Description Default
app_config AppConfig

Application configuration object. If None, creates a new one.

None
Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
def __init__(self, app_config: AppConfig = None):
    """
    Initialize the Redis-based auth storage manager.

    Args:
        app_config: Application configuration object. If None, creates a new one.

    Raises:
        ConnectionError: If the initial PING fails with ``redis.ConnectionError``.
    """
    if app_config is None:
        app_config = AppConfig()

    self.app_config = app_config
    self._lock = threading.Lock()

    # Get Redis configuration
    redis_host = self.app_config.get(TA_REDIS_HOST.env_name) or "localhost"
    redis_port = int(self.app_config.get(TA_REDIS_PORT.env_name) or 6379)
    redis_db = int(self.app_config.get(TA_REDIS_DB.env_name) or 0)
    redis_password = self.app_config.get(TA_REDIS_PWD.env_name)
    # TLS is opt-in: only an explicit "true" (case-insensitive) enables it.
    # Comparing against "false" here would invert the setting.
    ssl_setting = self.app_config.get(TA_REDIS_SSL.env_name) or "false"
    redis_ssl = str(ssl_setting).strip().lower() == "true"
    self.ttl = int(self.app_config.get(TA_REDIS_TTL.env_name) or 3600)  # Default 1 hour

    # Initialize Redis client
    self.redis_client = redis.Redis(
        host=redis_host,
        port=redis_port,
        db=redis_db,
        password=redis_password,
        ssl=redis_ssl,
        decode_responses=True,  # Automatically decode responses to strings
        socket_connect_timeout=5,
        socket_timeout=5,
        retry_on_timeout=True,
    )

    # Test connection
    try:
        self.redis_client.ping()
    except redis.ConnectionError as e:
        raise ConnectionError(f"Failed to connect to Redis: {e}") from e
sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager.store
store(user_id: str, key: str, data: AuthData) -> None

Store authorization data for a given user and key with TTL.

Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
def store(self, user_id: str, key: str, data: AuthData) -> None:
    """Persist *data* for (user_id, key), expiring after ``self.ttl`` seconds.

    Raises:
        RuntimeError: Wrapping any ``redis.RedisError`` raised while writing.
    """
    with self._lock:
        try:
            # Arguments are evaluated left to right: key, ttl, then the payload.
            self.redis_client.setex(
                self._get_redis_key(user_id, key),
                self.ttl,
                self._serialize_auth_data(data),
            )
        except redis.RedisError as e:
            raise RuntimeError(f"Failed to store auth data in Redis: {e}") from e
sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager.retrieve
retrieve(user_id: str, key: str) -> AuthData | None

Retrieve authorization data for a given user and key.

Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
def retrieve(self, user_id: str, key: str) -> AuthData | None:
    """Fetch and decode the auth data stored for (user_id, key).

    Returns:
        The decoded AuthData, or None when the GET returns nothing.

    Raises:
        RuntimeError: Wrapping a ``redis.RedisError`` from the read.
        ValueError: When the stored payload cannot be decoded; the offending
            entry is deleted (best-effort) before raising.
    """
    with self._lock:
        try:
            raw = self.redis_client.get(self._get_redis_key(user_id, key))
            return None if raw is None else self._deserialize_auth_data(raw)
        except redis.RedisError as e:
            raise RuntimeError(f"Failed to retrieve auth data from Redis: {e}") from e
        except (json.JSONDecodeError, ValueError) as e:
            # Undecodable payload: drop it so later reads don't keep failing.
            try:
                self.redis_client.delete(self._get_redis_key(user_id, key))
            except redis.RedisError:
                pass  # deletion is best-effort
            raise ValueError(
                f"Corrupted auth data found for user {user_id}, key {key}: {e}"
            ) from e
sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager.delete
delete(user_id: str, key: str) -> None

Delete authorization data for a given user and key.

Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
def delete(self, user_id: str, key: str) -> None:
    """Issue a DEL for the Redis key of (user_id, key); the DEL result is ignored.

    Raises:
        RuntimeError: Wrapping any ``redis.RedisError`` raised by the DEL.
    """
    with self._lock:
        try:
            target = self._get_redis_key(user_id, key)
            self.redis_client.delete(target)
        except redis.RedisError as e:
            raise RuntimeError(f"Failed to delete auth data from Redis: {e}") from e
sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager.clear_user_data
clear_user_data(user_id: str) -> int

Clear all authorization data for a given user.

Returns:

Type Description
int

Number of keys deleted.

Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
def clear_user_data(self, user_id: str) -> int:
    """
    Clear all authorization data for a given user.

    Keys are discovered with incremental SCAN instead of KEYS, so the Redis
    server is not blocked while the keyspace is walked. Glob metacharacters
    (``* ? [ ] \\``) in ``user_id`` are backslash-escaped so that, for example,
    a user_id of ``"*"`` cannot match (and delete) other users' data.

    Returns:
        Number of keys deleted.

    Raises:
        RuntimeError: Wrapping any ``redis.RedisError``.
    """
    with self._lock:
        try:
            escaped_user_id = "".join(
                f"\\{ch}" if ch in "*?[]\\" else ch for ch in user_id
            )
            pattern = self._get_redis_key(escaped_user_id, "*")
            # SCAN may yield the same key more than once; de-duplicate in order.
            keys = list(dict.fromkeys(self.redis_client.scan_iter(match=pattern)))

            if not keys:
                return 0

            return self.redis_client.delete(*keys)

        except redis.RedisError as e:
            raise RuntimeError(f"Failed to clear user data from Redis: {e}") from e
sk_agents.auth_storage.custom.example_redis_auth_storage.RedisSecureAuthStorageManager.health_check
health_check() -> bool

Check if Redis connection is healthy.

Source code in src/sk_agents/auth_storage/custom/example_redis_auth_storage.py
def health_check(self) -> bool:
    """Return True when Redis answers PING, False when the PING raises a RedisError."""
    try:
        self.redis_client.ping()
    except redis.RedisError:
        return False
    return True
sk_agents.auth_storage.in_memory_secure_auth_storage_manager
sk_agents.auth_storage.in_memory_secure_auth_storage_manager.InMemorySecureAuthStorageManager

Bases: SecureAuthStorageManager

A thread-safe, in-memory implementation of the SecureAuthStorageManager.

Source code in src/sk_agents/auth_storage/in_memory_secure_auth_storage_manager.py
class InMemorySecureAuthStorageManager(SecureAuthStorageManager):
    """A thread-safe, in-memory implementation of the SecureAuthStorageManager.

    Data lives in a per-instance ``{user_id: {key: AuthData}}`` dict guarded by a
    ``threading.Lock``. Nothing is persisted and entries never expire.
    """

    def __init__(self):
        self._storage: dict[str, dict[str, AuthData]] = {}
        self._lock = threading.Lock()

    def store(self, user_id: str, key: str, data: AuthData) -> None:
        """Store (or overwrite) *data* for the given user and key."""
        with self._lock:
            self._storage.setdefault(user_id, {})[key] = data

    def retrieve(self, user_id: str, key: str) -> AuthData | None:
        """Return the data stored for the given user and key, or None if absent."""
        with self._lock:
            return self._storage.get(user_id, {}).get(key)

    def delete(self, user_id: str, key: str) -> None:
        """Delete the data for the given user and key; a missing entry is a no-op.

        A user's inner dict is removed once it becomes empty, so users whose
        data has all been deleted do not leave empty entries behind.
        """
        with self._lock:
            user_data = self._storage.get(user_id)
            if user_data is None or key not in user_data:
                return
            del user_data[key]
            if not user_data:
                del self._storage[user_id]
sk_agents.auth_storage.models
sk_agents.auth_storage.models.OAuth2AuthData

Bases: BaseAuthData

Source code in src/sk_agents/auth_storage/models.py
class OAuth2AuthData(BaseAuthData):
    """OAuth2 token data, with optional MCP resource/audience binding."""

    auth_type: Literal["oauth2"] = "oauth2"
    access_token: str
    refresh_token: str | None = None
    expires_at: datetime
    # The scopes this token is valid for.
    scopes: list[str] = []

    # MCP OAuth 2.1 Compliance Fields
    audience: str | None = None  # Token audience (aud) for validation
    resource: str | None = None  # Resource binding (canonical MCP server URI)
    token_type: str = "Bearer"  # Token type (usually "Bearer")
    issued_at: datetime | None = None  # Token issue timestamp

    def is_valid_for_resource(self, resource_uri: str) -> bool:
        """
        Validate token is valid for specific resource.

        Checks:
        1. Token not expired
        2. Resource matches (if resource binding present)
        3. Audience matches (if audience present)

        A naive ``expires_at`` (no timezone) is interpreted as UTC so it can be
        compared with the timezone-aware current time instead of raising TypeError.

        Args:
            resource_uri: Canonical MCP server URI to validate against

        Returns:
            bool: True if token is valid for this resource
        """
        from datetime import UTC, datetime

        expires_at = self.expires_at
        if expires_at.tzinfo is None or expires_at.utcoffset() is None:
            expires_at = expires_at.replace(tzinfo=UTC)

        # Check expiry
        if expires_at <= datetime.now(UTC):
            return False

        # Check resource binding (MCP-specific)
        if self.resource and self.resource != resource_uri:
            return False

        # Check audience (OAuth 2.1 token audience validation)
        if self.audience and self.audience != resource_uri:
            return False

        return True
sk_agents.auth_storage.models.OAuth2AuthData.is_valid_for_resource
is_valid_for_resource(resource_uri: str) -> bool

Validate token is valid for specific resource.

Checks:

1. Token not expired
2. Resource matches (if resource binding present)
3. Audience matches (if audience present)

Parameters:

Name Type Description Default
resource_uri str

Canonical MCP server URI to validate against

required

Returns:

Name Type Description
bool bool

True if token is valid for this resource

Source code in src/sk_agents/auth_storage/models.py
def is_valid_for_resource(self, resource_uri: str) -> bool:
    """
    Validate token is valid for specific resource.

    Checks:
    1. Token not expired
    2. Resource matches (if resource binding present)
    3. Audience matches (if audience present)

    Args:
        resource_uri: Canonical MCP server URI to validate against

    Returns:
        bool: True if token is valid for this resource
    """
    from datetime import datetime

    # Check expiry
    if self.expires_at <= datetime.now(UTC):
        return False

    # Check resource binding (MCP-specific)
    if self.resource and self.resource != resource_uri:
        return False

    # Check audience (OAuth 2.1 token audience validation)
    if self.audience and self.audience != resource_uri:
        return False

    return True
sk_agents.auth_storage.secure_auth_storage_manager
sk_agents.auth_storage.secure_auth_storage_manager.SecureAuthStorageManager

Bases: ABC

Source code in src/sk_agents/auth_storage/secure_auth_storage_manager.py
class SecureAuthStorageManager(ABC):
    """Abstract interface for storing AuthData keyed by a (user_id, key) pair."""

    @abstractmethod
    def store(self, user_id: str, key: str, data: AuthData) -> None:
        """Persist *data* under the given user and key."""
        ...

    @abstractmethod
    def retrieve(self, user_id: str, key: str) -> AuthData | None:
        """Return the data stored under the given user and key (None per the return type)."""
        ...

    @abstractmethod
    def delete(self, user_id: str, key: str) -> None:
        """Remove the data stored under the given user and key."""
        ...
sk_agents.auth_storage.secure_auth_storage_manager.SecureAuthStorageManager.store abstractmethod
store(user_id: str, key: str, data: AuthData) -> None

Stores authorization data for a given user and key.

Source code in src/sk_agents/auth_storage/secure_auth_storage_manager.py
@abstractmethod
def store(self, user_id: str, key: str, data: AuthData) -> None:
    """Persist *data* under the given user and key (implemented by subclasses)."""
    ...
sk_agents.auth_storage.secure_auth_storage_manager.SecureAuthStorageManager.retrieve abstractmethod
retrieve(user_id: str, key: str) -> AuthData | None

Retrieves authorization data for a given user and key.

Source code in src/sk_agents/auth_storage/secure_auth_storage_manager.py
@abstractmethod
def retrieve(self, user_id: str, key: str) -> AuthData | None:
    """Return the data stored under the given user and key (implemented by subclasses)."""
    ...
sk_agents.auth_storage.secure_auth_storage_manager.SecureAuthStorageManager.delete abstractmethod
delete(user_id: str, key: str) -> None

Deletes authorization data for a given user and key.

Source code in src/sk_agents/auth_storage/secure_auth_storage_manager.py
@abstractmethod
def delete(self, user_id: str, key: str) -> None:
    """Remove the data stored under the given user and key (implemented by subclasses)."""
    ...
sk_agents.authorization
sk_agents.authorization.request_authorizer
sk_agents.authorization.request_authorizer.RequestAuthorizer

Bases: ABC

Source code in src/sk_agents/authorization/request_authorizer.py
class RequestAuthorizer(ABC):
    @abstractmethod
    async def authorize_request(self, auth_header: str) -> str:
        """
        Validate an Authorization header and identify the caller.

        Args:
            auth_header (str): Raw value of the HTTP 'Authorization' header,
                commonly 'Bearer <token>', though the accepted scheme is up to
                the implementation.

        Returns:
            str: A unique identifier for the authenticated user (a user ID,
                username, email, or similar), e.g. "user_12345" or
                "alice@example.com".

        Raises:
            ValueError: If the header is missing, malformed, or invalid.
            AuthenticationError (optional): Implementations may raise this to
                signal an authentication failure.
        """
        ...
sk_agents.authorization.request_authorizer.RequestAuthorizer.authorize_request abstractmethod async
authorize_request(auth_header: str) -> str

Validates the given authorization header and returns a unique identifier for the authenticated user.

Parameters:

Name Type Description Default
auth_header str

The value of the 'Authorization' HTTP header. Typically, this is in the format 'Bearer &lt;token&gt;' or some other scheme depending on the implementation.

required

Returns:

Name Type Description
str str

A unique string that identifies the authenticated user. This could be a user ID, username, email, or any other unique identifier suitable for tracking and authorization.

Examples str

"user_12345" "alice@example.com"

Raises:

Type Description
ValueError

If the authorization header is missing, malformed, or invalid.

AuthenticationError(optional)

If used in your implementation, it may be raised to signal an authentication failure.

Source code in src/sk_agents/authorization/request_authorizer.py
@abstractmethod
async def authorize_request(self, auth_header: str) -> str:
    """
    Validate an Authorization header and identify the caller.

    Args:
        auth_header (str): Raw value of the HTTP 'Authorization' header,
            commonly 'Bearer <token>', though the accepted scheme is up to
            the implementation.

    Returns:
        str: A unique identifier for the authenticated user (a user ID,
            username, email, or similar), e.g. "user_12345" or
            "alice@example.com".

    Raises:
        ValueError: If the header is missing, malformed, or invalid.
        AuthenticationError (optional): Implementations may raise this to
            signal an authentication failure.
    """
    ...
sk_agents.exceptions
sk_agents.exceptions.AgentsException

Bases: Exception

Base class for all exceptions in SKagents

Source code in src/sk_agents/exceptions.py
class AgentsException(Exception):
    """Root of the SKagents exception hierarchy; the framework's exceptions derive from it."""
sk_agents.exceptions.InvalidConfigException

Bases: AgentsException

Exception raised when the provided configuration is invalid

Source code in src/sk_agents/exceptions.py
class InvalidConfigException(AgentsException):
    """Exception raised when the provided configuration is invalid"""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.InvalidInputException

Bases: AgentsException

Exception raised when the provided input type is invalid

Source code in src/sk_agents/exceptions.py
class InvalidInputException(AgentsException):
    """Exception raised when the provided input type is invalid"""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.AgentInvokeException

Bases: AgentsException

Exception raised when invoking an Agent failed

Source code in src/sk_agents/exceptions.py
class AgentInvokeException(AgentsException):
    """Exception raised when invoking an Agent failed"""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.PersistenceCreateError

Bases: AgentsException

Exception raised for errors during task creation.

Source code in src/sk_agents/exceptions.py
class PersistenceCreateError(AgentsException):
    """Exception raised for errors during task creation."""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.PersistenceLoadError

Bases: AgentsException

Exception raised for errors during task loading.

Source code in src/sk_agents/exceptions.py
class PersistenceLoadError(AgentsException):
    """Exception raised for errors during task loading."""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.PersistenceUpdateError

Bases: AgentsException

Exception raised for errors during task update.

Source code in src/sk_agents/exceptions.py
class PersistenceUpdateError(AgentsException):
    """Exception raised for errors during task update."""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.PersistenceDeleteError

Bases: AgentsException

Exception raised for errors during task deletion.

Source code in src/sk_agents/exceptions.py
class PersistenceDeleteError(AgentsException):
    """Exception raised for errors during task deletion."""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.AuthenticationException

Bases: AgentsException

Exception raised for errors when authenticating users

Source code in src/sk_agents/exceptions.py
class AuthenticationException(AgentsException):
    """Exception raised for errors when authenticating users"""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.PluginCatalogDefinitionException

Bases: AgentsException

Exception raised when the parsed json does not match the PluginCatalogDefinition Model

Source code in src/sk_agents/exceptions.py
class PluginCatalogDefinitionException(AgentsException):
    """Exception raised when the parsed json does not match the PluginCatalogDefinition Model"""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.exceptions.PluginFileReadException

Bases: AgentsException

Raise this exception when the plugin file fails to be read

Source code in src/sk_agents/exceptions.py
class PluginFileReadException(AgentsException):
    """Raise this exception when the plugin file fails to be read"""

    message: str

    def __init__(self, message: str):
        # Forward to Exception.__init__ so exc.args / str(exc) carry the message
        # even when it is passed by keyword (on its own, BaseException records
        # only positional constructor arguments).
        super().__init__(message)
        self.message = message
sk_agents.hitl
sk_agents.hitl.hitl_manager
sk_agents.hitl.hitl_manager.HitlInterventionRequired

Bases: Exception

Exception raised when a tool call requires human-in-the-loop intervention.

Source code in src/sk_agents/hitl/hitl_manager.py
class HitlInterventionRequired(Exception):
    """
    Exception raised when a tool call
    requires human-in-the-loop intervention.

    Attributes:
        function_calls: The function calls passed to the constructor.
        plugin_name: ``plugin_name`` of the first function call, or None when
            ``function_calls`` is empty.
        function_name: ``function_name`` of the first function call, or None
            when ``function_calls`` is empty.
    """

    def __init__(self, function_calls: list[FunctionCallContent]):
        self.function_calls = function_calls
        # Always define these so handlers can read them without AttributeError,
        # even on the (internal-error) path where no calls were provided.
        self.plugin_name = None
        self.function_name = None
        if function_calls:
            self.plugin_name = function_calls[0].plugin_name
            self.function_name = function_calls[0].function_name
            message = f"HITL intervention required for {self.plugin_name}.{self.function_name}"
        else:
            message = "HITL intervention required but no function calls provided (internal error)"
        super().__init__(message)

    def __reduce__(self):
        # Exception's default reduction would call HitlInterventionRequired(message)
        # and treat the message string as the list of function calls. Rebuild from
        # the real constructor argument and restore any extra attributes.
        return (type(self), (self.function_calls,), self.__dict__)
sk_agents.hitl.hitl_manager.check_for_intervention
check_for_intervention(
    tool_call: FunctionCallContent,
) -> bool

Checks the plugin catalog to determine if a tool call requires Human-in-the-Loop intervention.

Source code in src/sk_agents/hitl/hitl_manager.py
def check_for_intervention(tool_call: FunctionCallContent) -> bool:
    """
    Checks the plugin catalog to determine if a tool call requires
    Human-in-the-Loop intervention.

    The tool is looked up in the catalog by the id
    ``"{plugin_name}-{function_name}"``.

    Returns:
        The tool's ``governance.requires_hitl`` value when the tool is in the
        catalog; False when no catalog is configured or the tool is not found.
    """
    catalog = PluginCatalogFactory().get_catalog()
    if not catalog:
        # Fallback if catalog is not configured
        return False

    tool_id = f"{tool_call.plugin_name}-{tool_call.function_name}"
    tool = catalog.get_tool(tool_id)
    if not tool:
        # Default to no intervention if tool is not in the catalog
        return False

    requires_hitl = tool.governance.requires_hitl
    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    logger.debug(
        "HITL Check: Intercepted call to %s. Requires HITL: %s", tool_id, requires_hitl
    )
    return requires_hitl
sk_agents.mcp_client

MCP Client for Teal Agents Platform - Clean Implementation

This module provides an MCP (Model Context Protocol) client that supports only the transports that are actually available in the MCP Python SDK.

Only supported transports:

- stdio: Local subprocess communication
- http: HTTP with Server-Sent Events for remote servers

WebSocket support will be added when it becomes available in the MCP SDK.

sk_agents.mcp_client.AuthRequiredError

Bases: Exception

Exception raised when MCP server authentication is required but missing.

This exception is raised during discovery when a server requires authentication (has auth_server + scopes configured) but the user has no valid token in AuthStorage.

Source code in src/sk_agents/mcp_client.py
class AuthRequiredError(Exception):
    """
    Exception raised when MCP server authentication is required but missing.

    This exception is raised during discovery when a server requires authentication
    (has auth_server + scopes configured) but the user has no valid token in AuthStorage.
    """

    def __init__(self, server_name: str, auth_server: str, scopes: list[str], message: str = None):
        self.server_name = server_name
        self.auth_server = auth_server
        self.scopes = scopes
        self.message = message or f"Authentication required for MCP server '{server_name}'"
        super().__init__(self.message)
sk_agents.mcp_client.McpConnectionManager

Request-scoped connection manager for MCP servers.

Manages MCP connections within a single agent invoke() request scope:

- Lazy connection establishment (connect on first tool call per server)
- Connection reuse within the request (all tools on same server share connection)
- Automatic cleanup at request end
- Session ID persistence via state manager for cross-request continuity

Lifecycle
  1. Created at start of invoke() request
  2. Connections created lazily when first tool from server is called
  3. Connections reused for all subsequent tool calls in same request
  4. Cleanup at end of invoke() - close connections, persist session IDs
Usage

```python
async with McpConnectionManager(servers, user_id, ...) as conn_mgr:
    session = await conn_mgr.get_or_create_session(server_name)
    result = await session.call_tool(tool_name, args)
```

Source code in src/sk_agents/mcp_client.py
class McpConnectionManager:
    """
    Request-scoped connection manager for MCP servers.

    Manages MCP connections within a single agent invoke() request scope:
    - Lazy connection establishment (connect on first tool call per server)
    - Connection reuse within the request (all tools on same server share connection)
    - Automatic cleanup at request end
    - Session ID persistence via state manager for cross-request continuity

    Lifecycle:
        1. Created at start of invoke() request
        2. Connections created lazily when first tool from server is called
        3. Connections reused for all subsequent tool calls in same request
        4. Cleanup at end of invoke() - close connections, persist session IDs

    Usage:
        async with McpConnectionManager(servers, user_id, ...) as conn_mgr:
            session = await conn_mgr.get_or_create_session(server_name)
            result = await session.call_tool(tool_name, args)
    """

    def __init__(
        self,
        server_configs: dict[str, McpServerConfig],
        user_id: str,
        session_id: str,
        state_manager=None,  # McpStateManager for session ID persistence
        app_config: AppConfig | None = None,
    ):
        """
        Args:
            server_configs: MCP server configurations keyed by server name.
            user_id: User the connections are opened for; also part of the key
                under which MCP session IDs are persisted.
            session_id: Agent session ID; combined with user_id and the server
                name as the persistence key in the state manager.
            state_manager: Optional store providing get_mcp_session /
                store_mcp_session / clear_mcp_session. When None, MCP session IDs
                are neither loaded nor persisted.
            app_config: Optional application config forwarded to session creation.
        """
        self._server_configs = server_configs
        self._user_id = user_id
        self._session_id = session_id
        self._state_manager = state_manager
        self._app_config = app_config

        # Active connections (created lazily)
        self._sessions: dict[str, ClientSession] = {}
        self._get_session_id_callbacks: dict[str, Callable[[], str | None]] = {}
        self._connection_stack: AsyncExitStack | None = None

        # Stored session IDs from previous requests
        self._stored_session_ids: dict[str, str] = {}

    async def __aenter__(self) -> "McpConnectionManager":
        """Enter context - initialize and load stored session IDs.

        Loading is best effort: a failed lookup for one server is logged at debug
        level and that server simply starts without a stored session ID.
        """
        self._connection_stack = AsyncExitStack()
        await self._connection_stack.__aenter__()

        # Pre-load stored session IDs for all configured servers
        if self._state_manager:
            for server_name in self._server_configs:
                try:
                    stored_id = await self._state_manager.get_mcp_session(
                        self._user_id, self._session_id, server_name
                    )
                    if stored_id:
                        self._stored_session_ids[server_name] = stored_id
                        logger.debug(f"Loaded stored MCP session for {server_name}")
                except Exception as e:
                    logger.debug(f"Could not load stored session for {server_name}: {e}")

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit context - persist session IDs and cleanup connections.

        Persisting is best effort: a failure for one server is logged as a warning
        and does not stop the others or the connection cleanup. The per-request
        session bookkeeping is always cleared, even if closing the connections
        raises, so no closed session is reported or reused afterwards.
        """
        try:
            # Persist session IDs for servers that were connected
            if self._state_manager:
                for server_name, get_session_id in self._get_session_id_callbacks.items():
                    try:
                        current_id = get_session_id() if get_session_id else None
                        if current_id:
                            await self._state_manager.store_mcp_session(
                                self._user_id, self._session_id, server_name, current_id
                            )
                            logger.debug(f"Persisted MCP session for {server_name}")
                    except Exception as e:
                        logger.warning(f"Failed to persist MCP session for {server_name}: {e}")
        finally:
            try:
                # Close all connections
                if self._connection_stack:
                    try:
                        await self._connection_stack.__aexit__(exc_type, exc_val, exc_tb)
                    except RuntimeError as e:
                        # Handle anyio task affinity errors gracefully.
                        # This can happen when the connection manager is used across
                        # recursive handler calls that change the async task context.
                        # The MCP SDK's streamablehttp_client uses anyio.create_task_group()
                        # which requires entering/exiting in the same async task.
                        err_str = str(e)
                        if "cancel scope" in err_str and "different task" in err_str:
                            logger.warning(
                                f"MCP cleanup encountered task affinity issue (non-fatal): {e}"
                            )
                        else:
                            raise
            finally:
                # Runs even when closing raised, so dead sessions are never kept.
                self._sessions.clear()
                self._get_session_id_callbacks.clear()

    async def get_or_create_session(self, server_name: str) -> ClientSession:
        """
        Get existing session or create new one for server (lazy connection).

        Args:
            server_name: Name of the MCP server

        Returns:
            Active MCP ClientSession for the server

        Raises:
            ValueError: If server_name is not one of the configured servers.
            RuntimeError: If the manager has not been entered as an async context manager.
        """
        if server_name in self._sessions:
            logger.debug(f"Reusing MCP session for {server_name}")
            return self._sessions[server_name]

        server_config = self._server_configs.get(server_name)
        if not server_config:
            raise ValueError(f"Unknown MCP server: {server_name}")

        if not self._connection_stack:
            raise RuntimeError("McpConnectionManager must be used as async context manager")

        # Try to resume the MCP session persisted by a previous request, if any.
        stored_session_id = self._stored_session_ids.get(server_name)

        session, get_session_id = await create_mcp_session_with_retry(
            server_config,
            self._connection_stack,
            self._user_id,
            mcp_session_id=stored_session_id,
            on_stale_session=self._create_stale_handler(server_name),
            app_config=self._app_config,
        )

        self._sessions[server_name] = session
        # Kept so __aexit__ can read the (possibly new) session ID and persist it.
        self._get_session_id_callbacks[server_name] = get_session_id
        logger.info(f"Created MCP session for {server_name}")
        return session

    def _create_stale_handler(self, server_name: str) -> Callable[[str], Awaitable[None]]:
        """Create callback to handle stale session ID.

        The returned coroutine function asks the state manager (if any) to clear
        the persisted ID for server_name, passing the stale ID as
        expected_session_id (presumably a compare-and-clear; confirm in the state
        manager). It then forgets the pre-loaded ID for this request. Failures to
        clear are logged at debug level and otherwise ignored.
        """

        async def handler(stale_id: str):
            logger.info(f"Clearing stale MCP session for {server_name}")
            if self._state_manager:
                try:
                    await self._state_manager.clear_mcp_session(
                        self._user_id, self._session_id, server_name, expected_session_id=stale_id
                    )
                except Exception as e:
                    logger.debug(f"Failed to clear stale session: {e}")
            self._stored_session_ids.pop(server_name, None)

        return handler

    def has_active_session(self, server_name: str) -> bool:
        """Check if server has an active session in this request."""
        return server_name in self._sessions

    def get_active_servers(self) -> list[str]:
        """Get list of servers with active sessions."""
        return list(self._sessions.keys())
sk_agents.mcp_client.McpConnectionManager.__aenter__ async
__aenter__() -> McpConnectionManager

Enter context - initialize and load stored session IDs.

Source code in src/sk_agents/mcp_client.py
async def __aenter__(self) -> "McpConnectionManager":
    """
    Open the connection stack and pre-load persisted MCP session IDs.

    Each configured server's previously stored session ID (if the state manager
    has one) is remembered so the first connection to that server can resume it.
    Lookup failures are logged at debug level and skipped.

    Returns:
        This manager, ready for get_or_create_session() calls.
    """
    stack = AsyncExitStack()
    self._connection_stack = stack
    await stack.__aenter__()

    if not self._state_manager:
        return self

    for name in self._server_configs:
        try:
            previous_id = await self._state_manager.get_mcp_session(
                self._user_id, self._session_id, name
            )
        except Exception as e:
            logger.debug(f"Could not load stored session for {name}: {e}")
            continue
        if previous_id:
            self._stored_session_ids[name] = previous_id
            logger.debug(f"Loaded stored MCP session for {name}")

    return self
sk_agents.mcp_client.McpConnectionManager.__aexit__ async
__aexit__(exc_type, exc_val, exc_tb)

Exit context - persist session IDs and cleanup connections.

Source code in src/sk_agents/mcp_client.py
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit context - persist session IDs and cleanup connections.

    Persisting is best effort: a failure for one server is logged as a warning
    and does not stop the others or the connection cleanup. The per-request
    session bookkeeping is always cleared, even if closing the connections
    raises, so no closed session is reported or reused afterwards.
    """
    try:
        # Persist session IDs for servers that were connected
        if self._state_manager:
            for server_name, get_session_id in self._get_session_id_callbacks.items():
                try:
                    current_id = get_session_id() if get_session_id else None
                    if current_id:
                        await self._state_manager.store_mcp_session(
                            self._user_id, self._session_id, server_name, current_id
                        )
                        logger.debug(f"Persisted MCP session for {server_name}")
                except Exception as e:
                    logger.warning(f"Failed to persist MCP session for {server_name}: {e}")
    finally:
        try:
            # Close all connections
            if self._connection_stack:
                try:
                    await self._connection_stack.__aexit__(exc_type, exc_val, exc_tb)
                except RuntimeError as e:
                    # Handle anyio task affinity errors gracefully.
                    # This can happen when the connection manager is used across
                    # recursive handler calls that change the async task context.
                    # The MCP SDK's streamablehttp_client uses anyio.create_task_group()
                    # which requires entering/exiting in the same async task.
                    err_str = str(e)
                    if "cancel scope" in err_str and "different task" in err_str:
                        logger.warning(
                            f"MCP cleanup encountered task affinity issue (non-fatal): {e}"
                        )
                    else:
                        raise
        finally:
            # Runs even when closing raised, so dead sessions are never kept.
            self._sessions.clear()
            self._get_session_id_callbacks.clear()
sk_agents.mcp_client.McpConnectionManager.get_or_create_session async
get_or_create_session(server_name: str) -> ClientSession

Get existing session or create new one for server (lazy connection).

Parameters:

Name Type Description Default
server_name str

Name of the MCP server

required

Returns:

Type Description
ClientSession

Active MCP ClientSession for the server

Source code in src/sk_agents/mcp_client.py
async def get_or_create_session(self, server_name: str) -> ClientSession:
    """
    Return the session for a server, connecting on first use in this request.

    Args:
        server_name: Name of the MCP server

    Returns:
        The MCP ClientSession for the server (reused if already open)

    Raises:
        ValueError: If server_name is not a configured server.
        RuntimeError: If the manager was not entered as an async context manager.
    """
    sessions = self._sessions
    if server_name in sessions:
        logger.debug(f"Reusing MCP session for {server_name}")
        return sessions[server_name]

    config = self._server_configs.get(server_name)
    if not config:
        raise ValueError(f"Unknown MCP server: {server_name}")
    if not self._connection_stack:
        raise RuntimeError("McpConnectionManager must be used as async context manager")

    # Resume a session ID persisted by an earlier request when one was loaded.
    new_session, session_id_getter = await create_mcp_session_with_retry(
        config,
        self._connection_stack,
        self._user_id,
        mcp_session_id=self._stored_session_ids.get(server_name),
        on_stale_session=self._create_stale_handler(server_name),
        app_config=self._app_config,
    )

    sessions[server_name] = new_session
    self._get_session_id_callbacks[server_name] = session_id_getter
    logger.info(f"Created MCP session for {server_name}")
    return new_session
sk_agents.mcp_client.McpConnectionManager.has_active_session
has_active_session(server_name: str) -> bool

Check if server has an active session in this request.

Source code in src/sk_agents/mcp_client.py
def has_active_session(self, server_name: str) -> bool:
    """Return True when server_name already has a session open in this request."""
    open_sessions = self._sessions
    return server_name in open_sessions
sk_agents.mcp_client.McpConnectionManager.get_active_servers
get_active_servers() -> list[str]

Get list of servers with active sessions.

Source code in src/sk_agents/mcp_client.py
def get_active_servers(self) -> list[str]:
    """Return the names of servers with a session open in this request, in connection order."""
    return [*self._sessions]
sk_agents.mcp_client.McpTool

Stateless wrapper for MCP tools to make them compatible with Semantic Kernel.

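This class stores the server configuration and tool metadata, but does NOT store active connections. Each invocation obtains its session from the request-scoped McpConnectionManager passed to invoke(), so all tool calls on the same server within one request share a connection.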

Source code in src/sk_agents/mcp_client.py
class McpTool:
    """
    Stateless wrapper for MCP tools to make them compatible with Semantic Kernel.

    This class stores the server configuration and tool metadata, but does NOT
    store active connections. Each invocation creates a temporary connection.
    """

    def __init__(
        self,
        tool_name: str,
        description: str,
        input_schema: dict[str, Any],
        output_schema: dict[str, Any] | None,
        server_config: "McpServerConfig",
        server_name: str,
    ):
        """
        Initialize stateless MCP tool.

        Args:
            tool_name: Name of the MCP tool
            description: Tool description
            input_schema: JSON schema for tool inputs
            output_schema: JSON schema for tool outputs (optional)
            server_config: MCP server configuration (for reconnection)
            server_name: Name of the MCP server
        """
        self.tool_name = tool_name
        self.description = description
        self.input_schema = input_schema
        self.output_schema = output_schema
        self.server_config = server_config
        self.server_name = server_name
        self.app_config: AppConfig | None = None  # Set via McpPlugin at instantiation time

    async def invoke(
        self,
        connection_manager: "McpConnectionManager",
        **kwargs,
    ) -> str:
        """
        Invoke the MCP tool using a request-scoped connection manager.

        Args:
            connection_manager: Request-scoped connection manager for connection reuse
            **kwargs: Tool arguments

        Returns:
            Tool execution result as string

        Raises:
            ValueError: If connection_manager is not provided
            RuntimeError: If tool execution fails
        """
        if not connection_manager:
            raise ValueError(
                f"connection_manager is required for MCP tool invocation. "
                f"Tool '{self.tool_name}' cannot be invoked without a connection manager."
            )

        try:
            if self.input_schema:
                self._validate_inputs(kwargs)

            logger.debug(f"Executing MCP tool: {self.server_name}.{self.tool_name}")
            session = await connection_manager.get_or_create_session(self.server_name)
            result = await session.call_tool(self.tool_name, kwargs)
            parsed = self._parse_result(result)
            logger.debug(f"MCP tool {self.tool_name} completed successfully")
            return parsed

        except ValueError:
            # Re-raise validation errors as-is
            raise
        except Exception as e:
            logger.error(f"Error invoking MCP tool {self.tool_name}: {e}")

            error_msg = str(e).lower()
            if "timeout" in error_msg:
                raise RuntimeError(
                    f"MCP tool '{self.tool_name}' timed out. Check server responsiveness."
                ) from e
            elif "connection" in error_msg:
                raise RuntimeError(
                    f"MCP tool '{self.tool_name}' connection failed. Check server availability."
                ) from e
            else:
                raise RuntimeError(f"MCP tool '{self.tool_name}' failed: {e}") from e

    def _parse_result(self, result: Any) -> str:
        """Parse MCP result into string format."""
        if hasattr(result, "content"):
            if isinstance(result.content, list) and len(result.content) > 0:
                return (
                    str(result.content[0].text)
                    if hasattr(result.content[0], "text")
                    else str(result.content[0])
                )
            return str(result.content)
        elif hasattr(result, "text"):
            return result.text
        else:
            return str(result)

    def _validate_inputs(self, kwargs: dict[str, Any]) -> None:
        """Basic input validation against the tool's JSON schema."""
        if not isinstance(self.input_schema, dict):
            return

        properties = self.input_schema.get("properties", {})
        required = self.input_schema.get("required", [])

        # Check required parameters
        for req_param in required:
            if req_param not in kwargs:
                raise ValueError(
                    f"Missing required parameter '{req_param}' for tool '{self.tool_name}'"
                )

        # Warn about unexpected parameters
        for param in kwargs:
            if param not in properties:
                logger.warning(f"Unexpected parameter '{param}' for tool '{self.tool_name}'")
sk_agents.mcp_client.McpTool.__init__
__init__(
    tool_name: str,
    description: str,
    input_schema: dict[str, Any],
    output_schema: dict[str, Any] | None,
    server_config: McpServerConfig,
    server_name: str,
)

Initialize stateless MCP tool.

Parameters:

Name Type Description Default
tool_name str

Name of the MCP tool

required
description str

Tool description

required
input_schema dict[str, Any]

JSON schema for tool inputs

required
output_schema dict[str, Any] | None

JSON schema for tool outputs (optional)

required
server_config McpServerConfig

MCP server configuration (for reconnection)

required
server_name str

Name of the MCP server

required
Source code in src/sk_agents/mcp_client.py
def __init__(
    self,
    tool_name: str,
    description: str,
    input_schema: dict[str, Any],
    output_schema: dict[str, Any] | None,
    server_config: "McpServerConfig",
    server_name: str,
):
    """
    Record the metadata and server configuration for a single MCP tool.

    No connection is opened here; the tool only keeps what it needs to be
    described to the model and routed to its server later.

    Args:
        tool_name: Name of the MCP tool
        description: Tool description
        input_schema: JSON schema for tool inputs
        output_schema: JSON schema for tool outputs (optional)
        server_config: MCP server configuration (for reconnection)
        server_name: Name of the MCP server
    """
    self.tool_name, self.description = tool_name, description
    self.input_schema, self.output_schema = input_schema, output_schema
    self.server_config, self.server_name = server_config, server_name
    # Set via McpPlugin at instantiation time; None until then.
    self.app_config: AppConfig | None = None
sk_agents.mcp_client.McpTool.invoke async
invoke(
    connection_manager: McpConnectionManager, **kwargs
) -> str

Invoke the MCP tool using a request-scoped connection manager.

Parameters:

Name Type Description Default
connection_manager McpConnectionManager

Request-scoped connection manager for connection reuse

required
**kwargs

Tool arguments

{}

Returns:

Type Description
str

Tool execution result as string

Raises:

Type Description
ValueError

If connection_manager is not provided

RuntimeError

If tool execution fails

Source code in src/sk_agents/mcp_client.py
async def invoke(
    self,
    connection_manager: "McpConnectionManager",
    **kwargs,
) -> str:
    """
    Invoke the MCP tool using a request-scoped connection manager.

    Args:
        connection_manager: Request-scoped connection manager for connection reuse
        **kwargs: Tool arguments

    Returns:
        Tool execution result as string

    Raises:
        ValueError: If connection_manager is not provided, a required argument
            is missing, or the server is unknown to the connection manager
        RuntimeError: If tool execution fails (timeouts and connection failures
            get dedicated messages)
    """
    if not connection_manager:
        raise ValueError(
            f"connection_manager is required for MCP tool invocation. "
            f"Tool '{self.tool_name}' cannot be invoked without a connection manager."
        )

    try:
        if self.input_schema:
            self._validate_inputs(kwargs)

        logger.debug(f"Executing MCP tool: {self.server_name}.{self.tool_name}")
        session = await connection_manager.get_or_create_session(self.server_name)
        result = await session.call_tool(self.tool_name, kwargs)
        parsed = self._parse_result(result)
        logger.debug(f"MCP tool {self.tool_name} completed successfully")
        return parsed

    except ValueError:
        # Re-raise validation errors as-is
        raise
    except Exception as e:
        logger.error(f"Error invoking MCP tool {self.tool_name}: {e}")

        error_msg = str(e).lower()
        # Classify by type as well as by message: a builtin TimeoutError often
        # has an empty message, and many libraries say "timed out".
        if isinstance(e, TimeoutError) or "timeout" in error_msg or "timed out" in error_msg:
            raise RuntimeError(
                f"MCP tool '{self.tool_name}' timed out. Check server responsiveness."
            ) from e
        elif isinstance(e, ConnectionError) or "connection" in error_msg:
            raise RuntimeError(
                f"MCP tool '{self.tool_name}' connection failed. Check server availability."
            ) from e
        else:
            raise RuntimeError(f"MCP tool '{self.tool_name}' failed: {e}") from e
sk_agents.mcp_client.McpPlugin

Bases: BasePlugin

Plugin wrapper that holds MCP tools for Semantic Kernel integration.

This plugin creates kernel functions with proper type annotations from MCP JSON schemas, allowing Semantic Kernel to expose full parameter information to the LLM.

MCP-Specific Design Note:

MCP plugins require both user_id and connection_manager:

  1. Per-User Authentication: MCP tools connect to external services that require OAuth2 authentication. Tokens are stored per-user in AuthStorage.

  2. Connection Reuse: All tool calls within a request share connections via the connection_manager, reducing overhead from per-tool-call to per-request per-server.

Parameters:

Name Type Description Default
tools list[McpTool]

List of MCP tools discovered from the server

required
server_name str

Name of the MCP server (used for logging and namespacing)

required
user_id str

User ID for OAuth2 token resolution (REQUIRED)

required
connection_manager McpConnectionManager

Request-scoped connection manager (REQUIRED)

required
authorization str | None

Optional standard authorization header (rarely used with MCP)

None
extra_data_collector

Optional collector for extra response data

None

Raises:

Type Description
ValueError

If user_id or connection_manager is not provided

Example

async with McpConnectionManager(configs, user_id, session_id) as conn_mgr:
    plugin_instance = plugin_class(
        user_id="user123",
        connection_manager=conn_mgr,
        extra_data_collector=collector,
    )
    kernel.add_plugin(plugin_instance, "mcp_github")

Source code in src/sk_agents/mcp_client.py
class McpPlugin(BasePlugin):
    """
    Plugin wrapper that holds MCP tools for Semantic Kernel integration.

    This plugin creates kernel functions with proper type annotations from MCP JSON schemas,
    allowing Semantic Kernel to expose full parameter information to the LLM.

    MCP-Specific Design Note:
    -------------------------
    MCP plugins require both user_id and connection_manager:

    1. **Per-User Authentication**: MCP tools connect to external services that require OAuth2
       authentication. Tokens are stored per-user in AuthStorage.

    2. **Connection Reuse**: All tool calls within a request share connections via the
       connection_manager, reducing overhead from per-tool-call to per-request per-server.

    Args:
        tools: List of MCP tools discovered from the server
        server_name: Name of the MCP server (used for logging and namespacing)
        user_id: User ID for OAuth2 token resolution (REQUIRED)
        connection_manager: Request-scoped connection manager (REQUIRED)
        authorization: Optional standard authorization header (rarely used with MCP)
        extra_data_collector: Optional collector for extra response data

    Raises:
        ValueError: If user_id or connection_manager is not provided

    Example:
        >>> async with McpConnectionManager(configs, user_id, session_id) as conn_mgr:
        ...     plugin_instance = plugin_class(
        ...         user_id="user123",
        ...         connection_manager=conn_mgr,
        ...         extra_data_collector=collector
        ...     )
        ...     kernel.add_plugin(plugin_instance, "mcp_github")
    """

    def __init__(
        self,
        tools: list[McpTool],
        server_name: str,
        user_id: str,
        connection_manager: "McpConnectionManager",
        authorization: str | None = None,
        extra_data_collector=None,
    ):
        if not user_id:
            raise ValueError(
                "MCP plugins require a user_id for per-request OAuth2 token resolution."
            )
        if not connection_manager:
            raise ValueError(
                "MCP plugins require a connection_manager for request-scoped connection reuse. "
                "Create one using McpConnectionManager and pass it to the plugin."
            )

        super().__init__(authorization, extra_data_collector)
        self.tools = tools
        self.server_name = server_name
        self.user_id = user_id
        self.connection_manager = connection_manager

        # Dynamically add kernel functions for each tool
        for tool in tools:
            self._add_tool_function(tool)

    def _add_tool_function(self, tool: McpTool):
        """
        Add a tool as a kernel function with proper type annotations.

        Converts MCP JSON schema to Python type hints so SK can expose
        full parameter information to the LLM.

        The kernel function name is "<server_name>_<tool_name>" with every
        character outside ASCII [0-9A-Za-z_] replaced by "_". Semantic Kernel
        rejects other characters in function names, and MCP tool names may
        contain e.g. "-" or ".". Names that are already valid are unchanged.
        """

        # Create a closure that captures the specific tool instance
        def create_tool_function(captured_tool: McpTool):
            # Create unique tool name to avoid collisions
            function_name = self._sanitize_function_name(
                f"{self.server_name}_{captured_tool.tool_name}"
            )

            @kernel_function(
                name=function_name,
                description=f"[{self.server_name}] {captured_tool.description}",
            )
            async def tool_function(**kwargs):
                return await captured_tool.invoke(
                    connection_manager=self.connection_manager,
                    **kwargs,
                )

            # Override __kernel_function_parameters__ after decoration.
            # The decorator has already read inspect.signature() (which only sees **kwargs),
            # but we can override the parameters it uses to build the LLM schema
            tool_function.__kernel_function_parameters__ = self._build_sk_parameters(
                captured_tool.input_schema
            )

            return tool_function

        # Create the function and set as attribute
        tool_function = create_tool_function(tool)

        # Sanitize tool name for Python attribute
        attr_name = self._sanitize_name(f"{self.server_name}_{tool.tool_name}")

        setattr(self, attr_name, tool_function)

    def _build_sk_parameters(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Build Semantic Kernel parameter dictionaries from MCP JSON schema.

        This creates the parameter metadata in the format expected by
        KernelParameterMetadata, which Semantic Kernel uses to build
        the schema sent to the LLM. It is applied by setting
        __kernel_function_parameters__ after decoration.

        A JSON-schema "type" given as a list (e.g. ["string", "null"]) is reduced
        to its first non-"null" entry; the full schema is still passed through
        unchanged as schema_data. Null "properties"/"required" are treated as empty.

        Args:
            input_schema: MCP tool's JSON schema for inputs

        Returns:
            List of parameter dictionaries for SK
        """
        if not input_schema or not isinstance(input_schema, dict):
            return []

        properties = input_schema.get("properties") or {}
        required = input_schema.get("required") or []
        params = []

        for param_name, param_schema in properties.items():
            if not isinstance(param_schema, dict):
                continue

            json_type = self._primary_json_type(param_schema.get("type", "string"))

            # Build parameter dict in SK format
            param_dict = {
                "name": param_name,
                "description": param_schema.get("description", ""),
                "is_required": param_name in required,
                "type_": json_type,  # JSON type string
                "default_value": param_schema.get("default", None),
                "schema_data": param_schema,  # Full JSON schema sent to LLM
                # Python type object for better type handling
                "type_object": self._json_type_to_python(json_type),
            }

            params.append(param_dict)

        return params

    @staticmethod
    def _primary_json_type(json_type: Any) -> str:
        """Reduce a JSON-schema "type" (a string or a list of strings) to one type name."""
        if isinstance(json_type, str):
            return json_type
        if isinstance(json_type, list):
            for candidate in json_type:
                if isinstance(candidate, str) and candidate != "null":
                    return candidate
        return "string"

    @staticmethod
    def _json_type_to_python(json_type: str) -> type:
        """
        Map JSON schema types to Python types.

        Args:
            json_type: JSON schema type string

        Returns:
            Corresponding Python type (str for unknown types)
        """
        type_map = {
            "string": str,
            "number": float,
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict,
        }
        return type_map.get(json_type, str)

    @staticmethod
    def _sanitize_function_name(name: str) -> str:
        """Replace characters outside ASCII [0-9A-Za-z_] with "_" (valid names are unchanged)."""
        return "".join(c if (c.isascii() and c.isalnum()) or c == "_" else "_" for c in name)

    @staticmethod
    def _sanitize_name(name: str) -> str:
        """Sanitize name for Python attribute (prefixing "tool_" when needed)."""
        sanitized = "".join(c if c.isalnum() or c == "_" else "_" for c in name)
        # Guard the empty case too: indexing [0] on "" would raise IndexError.
        if not sanitized or (not sanitized[0].isalpha() and sanitized[0] != "_"):
            sanitized = f"tool_{sanitized}"
        return sanitized
sk_agents.mcp_client.build_auth_storage_key
build_auth_storage_key(
    auth_server: str, scopes: list[str]
) -> str

Create deterministic key for storing OAuth tokens in AuthStorage.

Source code in src/sk_agents/mcp_client.py
def build_auth_storage_key(auth_server: str, scopes: list[str]) -> str:
    """
    Build the AuthStorage key for an authorization server and its scopes.

    Scopes are sorted before joining, so the key does not depend on their order.
    When there are no scopes (or they join to an empty string), the key is just
    the authorization server.
    """
    if not scopes:
        return auth_server
    scope_part = "|".join(sorted(scopes))
    if not scope_part:
        return auth_server
    return f"{auth_server}|{scope_part}"
sk_agents.mcp_client.normalize_canonical_uri
normalize_canonical_uri(uri: str) -> str

Normalize URI to canonical format for MCP resource parameter.

Per the MCP specification, a canonical URI must:

- be an absolute URI with a scheme
- use a lowercase scheme and host
- include a port only when it is non-standard (default ports are dropped)
- optionally include a path

Examples:

"HTTPS://API.Example.COM/mcp" -> "https://api.example.com/mcp" "https://example.com:443/mcp" -> "https://example.com/mcp" "https://example.com:8443/mcp" -> "https://example.com:8443/mcp"

Parameters:

Name Type Description Default
uri str

URI to normalize

required

Returns:

Name Type Description
str str

Normalized canonical URI

Raises:

Type Description
ValueError

If URI is invalid or not absolute

Source code in src/sk_agents/mcp_client.py
def normalize_canonical_uri(uri: str) -> str:
    """
    Normalize URI to canonical format for MCP resource parameter.

    Per MCP specification, canonical URI must be:
    - Absolute URI with scheme
    - Lowercase scheme and host
    - Optional port (only if non-standard)
    - Optional path

    The query string and fragment are always dropped. Only the default ports of
    http (80) and https (443) are removed; a non-numeric text after the last
    ``:`` in the authority (e.g. inside an IPv6 literal) is left untouched.

    Examples:
        "HTTPS://API.Example.COM/mcp" -> "https://api.example.com/mcp"
        "https://example.com:443/mcp" -> "https://example.com/mcp"
        "https://example.com:8443/mcp" -> "https://example.com:8443/mcp"

    Args:
        uri: URI to normalize

    Returns:
        str: Normalized canonical URI

    Raises:
        ValueError: If URI is invalid or not absolute
    """
    from urllib.parse import urlparse, urlunparse

    if not uri:
        raise ValueError("URI cannot be empty")

    try:
        parts = urlparse(uri)
    except Exception as exc:
        raise ValueError(f"Invalid URI format: {exc}") from exc

    # An absolute URI needs both a scheme and an authority (host) component.
    if not parts.scheme:
        raise ValueError(f"URI must be absolute with scheme (got: {uri})")
    if not parts.netloc:
        raise ValueError(f"URI must have a host component (got: {uri})")

    scheme = parts.scheme.lower()
    authority = parts.netloc.lower()

    # Strip the port only when it is the scheme's default.
    default_ports = {"http": 80, "https": 443}
    host, separator, port_text = authority.rpartition(":")
    if separator:
        try:
            port_value = int(port_text)
        except ValueError:
            # Not a port number (e.g. tail of an IPv6 literal): keep the authority as is.
            port_value = None
        if port_value is not None and default_ports.get(scheme) == port_value:
            authority = host

    # Params, query and fragment are intentionally omitted from the canonical form.
    canonical = urlunparse((scheme, authority, parts.path or "", "", "", ""))

    logger.debug(f"Normalized canonical URI: {uri} -> {canonical}")
    return canonical
sk_agents.mcp_client.validate_https_url
validate_https_url(
    url: str, allow_localhost: bool = True
) -> bool

Validate that URL uses HTTPS (or localhost for development).

Per MCP spec and OAuth 2.1, all endpoints must use HTTPS except localhost.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `url` | `str` | URL to validate | *required* |
| `allow_localhost` | `bool` | Allow http://localhost, http://127.0.0.1 or http://[::1] | `True` |

Returns:

Name Type Description
bool bool

True if valid, False otherwise

Source code in src/sk_agents/mcp_client.py
def validate_https_url(url: str, allow_localhost: bool = True) -> bool:
    """
    Validate that URL uses HTTPS (or localhost for development).

    Per MCP spec and OAuth 2.1, all endpoints must use HTTPS except localhost.
    A URL without a host component (e.g. ``"https://"`` or ``"https:///path"``)
    is never valid, because it does not identify an endpoint.

    Args:
        url: URL to validate
        allow_localhost: Allow plain HTTP for loopback hosts
            (http://localhost, http://127.0.0.1, http://[::1])

    Returns:
        bool: True if valid, False otherwise. Malformed input (including
        values urlparse cannot handle) yields False instead of raising.
    """
    from urllib.parse import urlparse

    loopback_hosts = ("localhost", "127.0.0.1", "::1")

    try:
        parsed = urlparse(url)
        scheme = parsed.scheme.lower()
        # .hostname is lowercased and stripped of userinfo, port and IPv6
        # brackets; parsing can raise ValueError for malformed IPv6 literals.
        hostname = parsed.hostname
    except Exception:
        return False

    # No host means no usable endpoint, regardless of scheme.
    if not hostname:
        return False

    # HTTPS is always valid
    if scheme == "https":
        return True

    # HTTP is only valid for localhost/127.0.0.1/::1 if allowed
    if scheme == "http" and allow_localhost and hostname in loopback_hosts:
        return True

    return False
sk_agents.mcp_client.get_package_version
get_package_version() -> str

Get package version for MCP client identification.

Source code in src/sk_agents/mcp_client.py
def get_package_version() -> str:
    """Return the installed ``sk-agents`` version, used to identify this MCP client.

    Any failure while reading the distribution metadata (for instance the
    distribution not being installed) results in the fallback ``"1.0.0"``.
    """
    fallback = "1.0.0"  # Fallback version
    try:
        from importlib.metadata import version as dist_version

        resolved = dist_version("sk-agents")
    except Exception:
        return fallback
    return resolved
sk_agents.mcp_client.validate_mcp_sdk_version
validate_mcp_sdk_version() -> None

Validate MCP SDK version compatibility.

Logs warnings if the installed MCP SDK version is too old to support all features.

Source code in src/sk_agents/mcp_client.py
def validate_mcp_sdk_version() -> None:
    """
    Validate MCP SDK version compatibility.

    Logs a warning if the installed MCP SDK version is older than 1.23.0, the
    minimum needed for MCP spec 2025-11-25.

    The version is read from the installed ``mcp`` distribution metadata. The
    module's ``__version__`` attribute is only a fallback (for an ``mcp`` module
    without distribution metadata), because the package is not guaranteed to
    define it and a missing attribute would otherwise look like version 0.0.0.

    Never raises for ordinary errors: any failure is logged as a warning.
    """
    try:
        from importlib.metadata import PackageNotFoundError
        from importlib.metadata import version as dist_version

        try:
            version_str = dist_version("mcp")
        except PackageNotFoundError:
            # No distribution metadata: fall back to the module attribute.
            # An ImportError here is reported by the outer handler.
            import mcp

            version_str = getattr(mcp, "__version__", "0.0.0")

        try:
            from packaging import version as pkg_version
        except ImportError:
            # packaging is not available, so versions cannot be compared reliably.
            logger.debug(f"MCP SDK version {version_str} (could not validate compatibility)")
            return

        installed_version = pkg_version.parse(version_str)
        required_version = pkg_version.parse("1.23.0")

        if installed_version < required_version:
            logger.warning(
                f"MCP SDK version {version_str} detected. "
                f"Required: >= 1.23.0 for MCP spec 2025-11-25. "
                f"Please upgrade the MCP SDK."
            )
        else:
            logger.debug(f"MCP SDK version {version_str} is compatible")
    except Exception as e:
        # Includes a missing mcp package and unparseable version strings.
        logger.warning(f"Could not validate MCP SDK version: {e}")
sk_agents.mcp_client.initialize_mcp_session async
initialize_mcp_session(
    session: ClientSession,
    server_name: str,
    server_info_obj: Any = None,
    protocol_version: str = "2025-11-25",
) -> Any

Initialize MCP session with proper protocol handshake.

This function handles the complete MCP initialization sequence:

1. Send initialize request with protocol version and capabilities
2. Receive initialization result from server
3. Send initialized notification (required by MCP spec)

Parameters:

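| Name | Type | Description | Default |
|------|------|-------------|---------|
| `session` | `ClientSession` | The MCP ClientSession to initialize | *required* |
| `server_name` | `str` | Name of the server for logging purposes | *required* |
| `server_info_obj` | `Any` | Optional server info object for logging | `None` |
| `protocol_version` | `str` | MCP protocol version sent in the initialize request | `"2025-11-25"` |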

Returns:

Type Description
Any

The initialization result from the server

Raises:

Type Description
ConnectionError

If initialization fails

Source code in src/sk_agents/mcp_client.py
async def initialize_mcp_session(
    session: ClientSession,
    server_name: str,
    server_info_obj: Any = None,
    protocol_version: str = "2025-11-25",
) -> Any:
    """
    Initialize MCP session with proper protocol handshake.

    This function handles the complete MCP initialization sequence:
    1. Send initialize request with protocol version and capabilities
    2. Receive initialization result from server
    3. Send initialized notification (required by MCP spec)

    If the SDK's ``initialize()`` rejects the keyword arguments (older SDKs),
    the call is retried without arguments and a warning is logged. If the
    session exposes neither ``send_initialized`` nor ``initialized``, a warning
    is logged and no extra notification is sent.

    Args:
        session: The MCP ClientSession to initialize
        server_name: Name of the server for logging purposes
        server_info_obj: Optional server info object for logging (currently
            not used by this function; accepted for backward compatibility)
        protocol_version: MCP protocol version sent in the initialize request

    Returns:
        The initialization result from the server

    Raises:
        ConnectionError: If initialization fails (the original error is chained)
    """
    try:
        # Step 1: Send initialize request (prefers spec path, falls back if SDK lacks args)
        try:
            init_result = await session.initialize(
                protocol_version=protocol_version,
                client_info={"name": "teal-agents", "version": get_package_version()},
                capabilities={
                    # Per MCP spec 2025-11-25: advertise root change notifications if supported
                    "roots": {"listChanged": True},
                    "sampling": {},
                    "experimental": {},
                },
            )
        except TypeError as e:
            # Older SDKs (<=1.22) don't accept keyword args; degrade gracefully.
            # Match any rejected keyword (protocol_version, client_info or
            # capabilities), not only the first one Python happens to report.
            if "unexpected keyword argument" in str(e):
                logger.warning(
                    f"MCP SDK initialize() does not accept protocol_version/capabilities; "
                    f"falling back to legacy initialize() for '{server_name}'. "
                    f"Upgrade SDK for full 2025-11-25 compliance."
                )
                init_result = await session.initialize()
            else:
                raise

        # The result object may expose snake_case or camelCase attribute names
        # depending on the SDK model; try both before reporting "unknown".
        server_info = getattr(init_result, "server_info", None)
        if server_info is None:
            server_info = getattr(init_result, "serverInfo", "unknown")
        negotiated_version = getattr(init_result, "protocol_version", None)
        if negotiated_version is None:
            negotiated_version = getattr(init_result, "protocolVersion", "unknown")

        logger.info(
            f"MCP session initialized for '{server_name}': "
            f"server={server_info}, "
            f"protocol={negotiated_version}"
        )

        # Step 2: Send initialized notification (MCP protocol requirement)
        # Per MCP spec: "After successful initialization, the client MUST send
        # an initialized notification to indicate it is ready to begin normal operations."
        # The spec requires an initialized notification; if SDK lacks it, warn and continue.
        if hasattr(session, "send_initialized"):
            await session.send_initialized()
            logger.debug(f"Sent initialized notification to '{server_name}'")
        elif hasattr(session, "initialized"):
            await session.initialized()
            logger.debug(f"Sent initialized notification to '{server_name}'")
        else:
            logger.warning(
                f"MCP SDK missing initialized notification method for '{server_name}'. "
                f"Upgrade SDK for full spec compliance."
            )

        return init_result

    except Exception as e:
        logger.error(f"Failed to initialize MCP session for '{server_name}': {e}")
        raise ConnectionError(f"MCP session initialization failed for '{server_name}': {e}") from e
sk_agents.mcp_client.graceful_shutdown_session async
graceful_shutdown_session(
    session: ClientSession, server_name: str
) -> None

Attempt graceful MCP session shutdown.

Per MCP spec, clients should attempt to notify servers before disconnecting. This is a best-effort operation and failures are logged but not raised.

Parameters:

Name Type Description Default
session ClientSession

The MCP ClientSession to shutdown

required
server_name str

Name of the server for logging purposes

required
Source code in src/sk_agents/mcp_client.py
async def graceful_shutdown_session(session: ClientSession, server_name: str) -> None:
    """
    Best-effort notification to an MCP server that the client is disconnecting.

    The first shutdown hook the session exposes is awaited, checking
    ``send_shutdown`` before ``shutdown``. When neither attribute exists a
    warning is logged. Any ``Exception`` raised while shutting down is logged at
    debug level and suppressed rather than propagated.

    Args:
        session: The MCP ClientSession to shutdown
        server_name: Name of the server for logging purposes
    """
    try:
        for hook_name in ("send_shutdown", "shutdown"):
            if not hasattr(session, hook_name):
                continue
            await getattr(session, hook_name)()
            logger.debug(f"Sent graceful shutdown to MCP server: {server_name}")
            break
        else:
            # Neither hook is available on this SDK's session object.
            logger.warning(
                f"MCP SDK missing shutdown method for '{server_name}'. "
                f"Upgrade SDK for full spec compliance."
            )
    except Exception as e:
        logger.debug(f"Graceful shutdown failed for {server_name}: {e}")
sk_agents.mcp_client.map_mcp_annotations_to_governance
map_mcp_annotations_to_governance(
    annotations: dict[str, Any], tool_description: str = ""
) -> Governance

Map MCP tool annotations to Teal Agents governance policies using secure-by-default approach.

Parameters:

Name Type Description Default
annotations dict[str, Any]

MCP tool annotations

required
tool_description str

Tool description for risk analysis

''

Returns:

Name Type Description
Governance Governance

Governance settings for the tool

Source code in src/sk_agents/mcp_client.py
def map_mcp_annotations_to_governance(
    annotations: dict[str, Any], tool_description: str = ""
) -> Governance:
    """
    Map MCP tool annotations to Teal Agents governance policies using secure-by-default approach.

    Rules, applied in order:

    1. Start fully restricted: HITL required, cost "high", data "sensitive".
    2. A truthy ``readOnlyHint`` relaxes this to no HITL, cost "low", data "public".
    3. A truthy ``destructiveHint`` restores the fully restricted settings.
    4. Keywords in ``tool_description`` (case-insensitive substring match) can only
       tighten the result, never relax it. Categories are checked in the order
       network, file system, code execution, database, and only the first match
       is applied:
       - network, code execution, database: fully restricted settings;
       - file system: HITL required, cost raised to at least "medium" and data
         sensitivity raised to at least "proprietary".

    Args:
        annotations: MCP tool annotations (``None`` is treated as no annotations)
        tool_description: Tool description for risk analysis

    Returns:
        Governance: Governance settings for the tool
    """
    network_keywords = (
        "http",
        "https",
        "api",
        "network",
        "request",
        "fetch",
        "download",
        "upload",
        "url",
        "web",
        "internet",
        "remote",
        "curl",
        "wget",
    )
    filesystem_keywords = (
        "file",
        "directory",
        "write",
        "delete",
        "create",
        "modify",
        "save",
        "remove",
        "mkdir",
        "rmdir",
        "chmod",
        "move",
        "copy",
    )
    execution_keywords = ("execute", "run", "command", "shell", "bash", "script", "eval", "exec")
    database_keywords = ("database", "sql", "query", "insert", "update", "delete", "drop")

    annotations = annotations or {}

    # SECURE-BY-DEFAULT: Start with HITL required for unknown tools
    requires_hitl = True
    cost = "high"
    data_sensitivity = "sensitive"

    # Only relax restrictions with explicit safe annotations
    if annotations.get("readOnlyHint", False):
        requires_hitl = False
        cost = "low"
        data_sensitivity = "public"

    # Destructive tools require HITL (already secure)
    if annotations.get("destructiveHint", False):
        requires_hitl = True
        cost = "high"
        data_sensitivity = "sensitive"

    # Enhanced risk analysis based on tool description
    if tool_description:
        description_lower = tool_description.lower()

        def mentions(keywords: tuple[str, ...]) -> bool:
            """Return True if any keyword occurs as a substring of the description."""
            return any(keyword in description_lower for keyword in keywords)

        if mentions(network_keywords):
            requires_hitl = True
            cost = "high"
            data_sensitivity = "sensitive"
        elif mentions(filesystem_keywords):
            # Escalate only. At this point the settings are either the read-only
            # values (low/public) or the fully restricted ones (high/sensitive);
            # an unannotated or destructive tool must keep the restricted values.
            requires_hitl = True
            if cost == "low":
                cost = "medium"
            if data_sensitivity == "public":
                data_sensitivity = "proprietary"
        elif mentions(execution_keywords) or mentions(database_keywords):
            requires_hitl = True
            cost = "high"
            data_sensitivity = "sensitive"

    return Governance(requires_hitl=requires_hitl, cost=cost, data_sensitivity=data_sensitivity)
sk_agents.mcp_client.apply_trust_level_governance
apply_trust_level_governance(
    base_governance: Governance,
    trust_level: str,
    tool_description: str = "",
) -> Governance

Apply server trust level controls to governance settings.

Trust levels provide defense-in-depth by applying additional security controls based on the server's trust relationship with the platform:

- untrusted: Maximum restrictions, force HITL for all operations
- sandboxed: Enhanced restrictions; HITL is always required and cost is raised to at least "medium"
- trusted: Base governance applies, but still enforce safety on detected risks

Parameters:

Name Type Description Default
base_governance Governance

Base governance settings from MCP annotations

required
trust_level str

Server trust level ("trusted", "sandboxed", "untrusted")

required
tool_description str

Tool description for additional risk analysis

''

Returns:

Name Type Description
Governance Governance

Governance with trust level controls applied

Source code in src/sk_agents/mcp_client.py
def apply_trust_level_governance(
    base_governance: Governance, trust_level: str, tool_description: str = ""
) -> Governance:
    """
    Apply server trust level controls to governance settings.

    Trust levels provide defense-in-depth by applying additional security controls
    based on the server's trust relationship with the platform:
    - untrusted: Maximum restrictions, force HITL for all operations
    - sandboxed: HITL always required; cost raised to at least "medium";
      data sensitivity kept from the base governance
    - trusted: Base governance applies, but HITL (with cost "high") is still
      enforced when the description contains high-risk keywords

    The trust level is matched case-insensitively, ignoring surrounding
    whitespace. Any other value fails closed: it is logged as a warning and
    treated as "untrusted", so a misconfigured level never grants trusted behavior.

    Args:
        base_governance: Base governance settings from MCP annotations
        trust_level: Server trust level ("trusted", "sandboxed", "untrusted")
        tool_description: Tool description for additional risk analysis
            (``None`` is treated as an empty description)

    Returns:
        Governance: Governance with trust level controls applied
    """
    known_levels = ("trusted", "sandboxed", "untrusted")
    level = trust_level.strip().lower() if isinstance(trust_level, str) else ""
    if level not in known_levels:
        # Fail closed: an unrecognized value must not fall through to the trusted branch.
        logger.warning(
            f"Unknown MCP server trust level {trust_level!r}; applying untrusted governance"
        )
        level = "untrusted"

    if level == "untrusted":
        # Force HITL for all tools from untrusted servers
        logger.debug("Applying untrusted server governance: forcing HITL")
        return Governance(requires_hitl=True, cost="high", data_sensitivity="sensitive")

    if level == "sandboxed":
        # Sandboxed servers always require HITL and never get the "low" cost tier.
        logger.debug("Applying sandboxed server governance: elevated restrictions")
        return Governance(
            requires_hitl=True,
            cost="medium" if base_governance.cost == "low" else base_governance.cost,
            data_sensitivity=base_governance.data_sensitivity,
        )

    # Trusted: use base governance, but still enforce HITL on high-risk operations
    # (defense-in-depth). Keywords are case-insensitive substring matches.
    description_lower = (tool_description or "").lower()
    high_risk_operations = (
        "delete",
        "remove",
        "drop",
        "truncate",
        "destroy",
        "kill",
        "execute",
        "exec",
        "eval",
        "run command",
        "shell",
        "system",
        "sudo",
        "admin",
        "root",
    )
    has_high_risk = any(keyword in description_lower for keyword in high_risk_operations)

    if has_high_risk and not base_governance.requires_hitl:
        # Override for high-risk operations even on trusted servers
        logger.debug(
            "Trusted server tool has high-risk indicators in description, "
            "enforcing HITL despite trust level"
        )
        return Governance(
            requires_hitl=True,
            cost="high",
            data_sensitivity=base_governance.data_sensitivity,
        )

    # For non-high-risk operations on trusted servers, use base governance
    logger.debug("Applying trusted server governance: using base governance")
    return base_governance
sk_agents.mcp_client.apply_governance_overrides
apply_governance_overrides(
    base_governance: Governance,
    tool_name: str,
    overrides: dict[str, GovernanceOverride] | None,
) -> Governance

Apply tool-specific governance overrides to base governance settings.

Parameters:

Name Type Description Default
base_governance Governance

Auto-inferred governance from MCP annotations

required
tool_name str

Name of the MCP tool

required
overrides dict[str, GovernanceOverride] | None

Optional governance overrides from server config

required

Returns:

Name Type Description
Governance Governance

Final governance with overrides applied

Source code in src/sk_agents/mcp_client.py
def apply_governance_overrides(
    base_governance: Governance, tool_name: str, overrides: dict[str, GovernanceOverride] | None
) -> Governance:
    """
    Merge a tool's configured governance override onto its inferred governance.

    When ``overrides`` is empty/``None`` or has no entry for ``tool_name``, the
    ``base_governance`` object itself is returned. Otherwise a new Governance is
    built in which each field takes the override's value when that value is not
    ``None`` and the base value otherwise.

    Args:
        base_governance: Auto-inferred governance from MCP annotations
        tool_name: Name of the MCP tool
        overrides: Optional governance overrides from server config

    Returns:
        Governance: Final governance with overrides applied
    """
    has_override = bool(overrides) and tool_name in overrides
    if not has_override:
        return base_governance

    override = overrides[tool_name]

    def _pick(override_value, base_value):
        # An explicit None in the override means "not specified": keep the base.
        return base_value if override_value is None else override_value

    return Governance(
        requires_hitl=_pick(override.requires_hitl, base_governance.requires_hitl),
        cost=_pick(override.cost, base_governance.cost),
        data_sensitivity=_pick(override.data_sensitivity, base_governance.data_sensitivity),
    )
sk_agents.mcp_client.resolve_server_auth_headers async
resolve_server_auth_headers(
    server_config: McpServerConfig,
    user_id: str = "default",
    app_config: AppConfig | None = None,
) -> dict[str, str]

Resolve authentication headers for MCP server connection.

Supports automatic token refresh with OAuth 2.1 compliance:

- Validates that the token audience matches the resource
- Automatically refreshes expired tokens
- Implements refresh token rotation per OAuth 2.1

Parameters:

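| Name | Type | Description | Default |
|------|------|-------------|---------|
| `server_config` | `McpServerConfig` | MCP server configuration | *required* |
| `user_id` | `str` | User ID for auth lookup | `'default'` |
| `app_config` | `AppConfig \| None` | Application configuration used to create the auth storage; a default `AppConfig` is created when omitted | `None` |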

Returns:

Type Description
dict[str, str]

Dict[str, str]: Headers to use for server connection

Raises:

Type Description
AuthRequiredError

If no valid token and refresh fails

Source code in src/sk_agents/mcp_client.py
async def resolve_server_auth_headers(
    server_config: McpServerConfig,
    user_id: str = "default",
    app_config: AppConfig | None = None,
) -> dict[str, str]:
    """
    Resolve authentication headers for MCP server connection.

    Now supports automatic token refresh with OAuth 2.1 compliance:
    - Validates token audience matches resource
    - Automatically refreshes expired tokens
    - Implements token rotation per OAuth 2.1

    Args:
        server_config: MCP server configuration
        user_id: User ID for auth lookup

    Returns:
        Dict[str, str]: Headers to use for server connection

    Raises:
        AuthRequiredError: If no valid token and refresh fails
    """
    headers = {}

    # Optional per-server user header injection (opt-in via config)
    if server_config.user_id_header:
        header_name = server_config.user_id_header
        source = server_config.user_id_source
        if source == "auth" and user_id and user_id != "default":
            headers[header_name] = user_id
            logger.info(f"Set {header_name} from auth user_id for {server_config.name}")
        elif source == "env":
            env_var = server_config.user_id_env_var or header_name.upper()
            env_val = os.getenv(env_var)
            if env_val:
                headers[header_name] = env_val
                logger.info(f"Set {header_name} from env {env_var} for {server_config.name}")
            else:
                logger.warning(
                    f"user_id_source=env configured for {server_config.name} "
                    f"but env var {env_var} is not set"
                )

    # Start with any manually configured headers
    if server_config.headers:
        # If OAuth is configured, filter out Authorization headers (OAuth takes precedence)
        # If OAuth is NOT configured, keep all headers including Authorization
        for header_key, header_value in server_config.headers.items():
            if header_key.lower() == "authorization" and (
                server_config.auth_server and server_config.scopes
            ):
                logger.warning(
                    "Ignoring static Authorization header for MCP server %s (OAuth configured). "
                    "OAuth token will be used instead.",
                    server_config.name,
                )
                continue
            headers[header_key] = header_value

    # Override Arcade-User-Id with runtime user_id; fallback to env when user_id is default/absent
    fallback_arcade_user = os.getenv("ARCADE_USER_ID")
    if user_id and user_id != "default":
        headers["Arcade-User-Id"] = user_id
        logger.info(f"Overriding Arcade-User-Id header with runtime user: {user_id}")
    elif fallback_arcade_user:
        headers["Arcade-User-Id"] = fallback_arcade_user
        logger.info(f"Using fallback Arcade-User-Id from env: {fallback_arcade_user}")

    # Precompute canonical resource URI for HTTP servers (enforce presence for spec compliance)
    resource_uri: str | None = None
    if server_config.transport == "http":
        try:
            resource_uri = server_config.effective_canonical_uri
        except Exception as e:
            logger.error(f"Unable to determine canonical URI for {server_config.name}: {e}")
            raise AuthRequiredError(
                server_name=server_config.name,
                auth_server=server_config.auth_server or "unknown",
                scopes=server_config.scopes or [],
                message=f"Missing or invalid canonical URI for HTTP MCP server "
                f"'{server_config.name}'",
            ) from e

    # If server has OAuth configuration, resolve tokens using OAuth flow
    if server_config.auth_server and server_config.scopes:
        try:
            # Use AuthStorageFactory directly - no wrapper needed
            from datetime import datetime, timedelta

            from sk_agents.auth.oauth_client import OAuthClient
            from sk_agents.auth.oauth_models import RefreshTokenRequest
            from sk_agents.configs import (
                TA_MCP_OAUTH_ENABLE_AUDIENCE_VALIDATION,
                TA_MCP_OAUTH_ENABLE_TOKEN_REFRESH,
            )

            if app_config is None:
                from ska_utils import AppConfig as SkaAppConfig

                app_config = SkaAppConfig()
            auth_storage_factory = AuthStorageFactory(app_config)
            auth_storage = auth_storage_factory.get_auth_storage_manager()

            # Check feature flags
            enable_refresh = (
                app_config.get(TA_MCP_OAUTH_ENABLE_TOKEN_REFRESH.env_name).lower() == "true"
            )
            # Enforce audience/resource validation for HTTP servers regardless of flag
            if server_config.transport == "http":
                enable_audience = True
            else:
                enable_audience = (
                    app_config.get(TA_MCP_OAUTH_ENABLE_AUDIENCE_VALIDATION.env_name).lower()
                    == "true"
                )

            # Generate composite key for OAuth2 token lookup
            composite_key = build_auth_storage_key(server_config.auth_server, server_config.scopes)

            # Retrieve stored auth data
            auth_data = auth_storage.retrieve(user_id, composite_key)

            if not auth_data or not isinstance(auth_data, OAuth2AuthData):
                logger.warning(f"No valid auth token found for MCP server: {server_config.name}")
                raise AuthRequiredError(
                    server_name=server_config.name,
                    auth_server=server_config.auth_server,
                    scopes=server_config.scopes,
                )

            # Validate token for this resource (expiry + audience + resource binding)
            if enable_audience and resource_uri:
                is_valid = auth_data.is_valid_for_resource(resource_uri)
            else:
                # Legacy behavior: only check expiry
                is_valid = auth_data.expires_at > datetime.now(UTC)

            # Token expired or invalid - try refresh
            if not is_valid:
                if enable_refresh and auth_data.refresh_token and resource_uri:
                    logger.info(
                        f"Token expired/invalid for {server_config.name}, attempting refresh"
                    )

                    try:
                        # Initialize OAuth client
                        oauth_client = OAuthClient()

                        # Discover Protected Resource Metadata (RFC 9728) for HTTP MCP
                        has_prm = False
                        if server_config.url:  # Only for HTTP MCP servers
                            try:
                                cache = oauth_client.metadata_cache
                                prm = await cache.fetch_protected_resource_metadata(
                                    server_config.url
                                )
                                has_prm = prm is not None
                                if prm:
                                    logger.debug(
                                        f"Discovered PRM for {server_config.name} "
                                        "during token refresh"
                                    )
                            except Exception as e:
                                logger.debug(f"PRM discovery failed (optional): {e}")
                                has_prm = False

                        # Determine if resource param should be included (MCP spec 2025-06-18)
                        include_resource = oauth_client.should_include_resource_param(
                            protocol_version=server_config.protocol_version, has_prm=has_prm
                        )

                        # Discover token endpoint from authorization server metadata (RFC 8414)
                        token_endpoint = None
                        try:
                            metadata = await oauth_client.metadata_cache.fetch_auth_server_metadata(
                                server_config.auth_server
                            )
                            token_endpoint = str(metadata.token_endpoint)
                            logger.debug(f"Discovered token endpoint for refresh: {token_endpoint}")
                        except Exception as e:
                            logger.debug(f"Failed to discover token endpoint: {e}. Using fallback.")
                            token_endpoint = f"{server_config.auth_server.rstrip('/')}/token"

                        # Build refresh request
                        refresh_request = RefreshTokenRequest(
                            token_endpoint=token_endpoint,
                            refresh_token=auth_data.refresh_token,
                            resource=resource_uri
                            if include_resource
                            else None,  # Conditional per protocol version
                            client_id=server_config.oauth_client_id
                            or app_config.get("TA_OAUTH_CLIENT_NAME"),
                            client_secret=server_config.oauth_client_secret,
                            requested_scopes=auth_data.scopes,  # For scope validation
                        )

                        # Refresh token
                        token_response = await oauth_client.refresh_access_token(refresh_request)

                        # Update auth data with new tokens
                        auth_data.access_token = token_response.access_token
                        auth_data.expires_at = datetime.now(UTC) + timedelta(
                            seconds=token_response.expires_in
                        )
                        auth_data.issued_at = datetime.now(UTC)

                        # Handle refresh token rotation (OAuth 2.1)
                        if token_response.refresh_token:
                            auth_data.refresh_token = token_response.refresh_token
                            logger.debug(f"Refresh token rotated for {server_config.name}")

                        # Update audience if provided
                        if token_response.aud:
                            auth_data.audience = token_response.aud

                        # Store updated auth data
                        auth_storage.store(user_id, composite_key, auth_data)

                        logger.info(f"Successfully refreshed token for {server_config.name}")

                    except httpx.HTTPStatusError as http_error:
                        # Handle 401 WWW-Authenticate challenges
                        if http_error.response.status_code == 401:
                            challenge = OAuthErrorHandler.handle_401_response(
                                dict(http_error.response.headers)
                            )

                            if challenge and OAuthErrorHandler.should_reauthorize(challenge):
                                logger.info(
                                    f"Received 401 with WWW-Authenticate challenge "
                                    f"during token refresh for {server_config.name}. "
                                    f"Error: {challenge.error}, "
                                    f"Description: {challenge.error_description}"
                                )
                                # Extract required scopes from challenge or use configured
                                required_scopes = (
                                    challenge.scopes if challenge.scopes else server_config.scopes
                                )
                                err_msg = challenge.error_description or challenge.error
                                raise AuthRequiredError(
                                    server_name=server_config.name,
                                    auth_server=server_config.auth_server,
                                    scopes=required_scopes,
                                    message=f"Token rejected by server: {err_msg}",
                                ) from http_error

                        # Re-raise other HTTP errors
                        logger.error(
                            f"HTTP error during token refresh for "
                            f"{server_config.name}: {http_error}"
                        )
                        raise AuthRequiredError(
                            server_name=server_config.name,
                            auth_server=server_config.auth_server,
                            scopes=server_config.scopes,
                            message=f"Token refresh HTTP error: {http_error}",
                        ) from http_error

                    except Exception as refresh_error:
                        logger.error(
                            f"Token refresh failed for {server_config.name}: {refresh_error}"
                        )
                        # Refresh failed - require re-authentication
                        raise AuthRequiredError(
                            server_name=server_config.name,
                            auth_server=server_config.auth_server,
                            scopes=server_config.scopes,
                            message=f"Token refresh failed for '{server_config.name}'. "
                            "Re-authentication required.",
                        ) from refresh_error
                else:
                    # Refresh not enabled or no refresh token
                    logger.warning(
                        f"Token expired for {server_config.name} and refresh not available"
                    )
                    raise AuthRequiredError(
                        server_name=server_config.name,
                        auth_server=server_config.auth_server,
                        scopes=server_config.scopes,
                        message=f"Token expired for '{server_config.name}'",
                    )

            # Token is valid (or was successfully refreshed)
            headers["Authorization"] = f"{auth_data.token_type} {auth_data.access_token}"
            logger.info(f"Resolved auth headers for MCP server: {server_config.name}")

        except AuthRequiredError:
            # Re-raise auth errors
            raise
        except Exception as e:
            logger.error(f"Failed to resolve auth for MCP server {server_config.name}: {e}")
            raise AuthRequiredError(
                server_name=server_config.name,
                auth_server=server_config.auth_server if server_config.auth_server else "unknown",
                scopes=server_config.scopes if server_config.scopes else [],
                message=f"Auth resolution failed: {e}",
            ) from e

    # Debug logging: show what headers we're about to send
    safe_headers = {}
    for k, v in headers.items():
        if k.lower() == "authorization":
            # Redact token but show format
            if v.startswith("Bearer "):
                safe_headers[k] = "Bearer [REDACTED]"
            elif v.startswith("ghp_"):
                safe_headers[k] = "ghp_[REDACTED]"
            else:
                safe_headers[k] = "[REDACTED]"
        else:
            safe_headers[k] = v
    logger.info(f"Resolved headers for {server_config.name}: {safe_headers}")

    return headers
sk_agents.mcp_client.revoke_mcp_server_tokens async
revoke_mcp_server_tokens(
    server_config: McpServerConfig, user_id: str = "default"
) -> None

Revoke all tokens for an MCP server.

Useful when the user logs out, a security incident is detected, or access to the server is no longer needed.

Parameters:

Name Type Description Default
server_config McpServerConfig

MCP server configuration

required
user_id str

User ID for token lookup

'default'

Raises:

Type Description
Exception

If revocation fails

Source code in src/sk_agents/mcp_client.py
async def revoke_mcp_server_tokens(
    server_config: McpServerConfig, user_id: str = "default"
) -> None:
    """
    Revoke the stored OAuth tokens of one MCP server and drop them from storage.

    Typical reasons to call this:
    - User logs out
    - Security incident detected
    - Server access no longer needed

    Nothing is revoked (only a debug/warning log is emitted) when any of these hold:
    - the server has no OAuth configuration (no ``auth_server`` or no ``scopes``);
    - no ``OAuth2AuthData`` is stored for ``user_id``;
    - the authorization server metadata advertises no ``revocation_endpoint``.
    In that last case the stored tokens are kept.

    Args:
        server_config: MCP server configuration
        user_id: User ID for token lookup

    Raises:
        Exception: If metadata discovery, revocation, or storage deletion fails;
            the error is logged and re-raised unchanged.
    """
    from ska_utils import AppConfig

    from sk_agents.auth.oauth_client import OAuthClient

    # Revocation only makes sense for servers configured for OAuth.
    if not (server_config.auth_server and server_config.scopes):
        logger.debug(f"Server {server_config.name} has no OAuth config, skipping revocation")
        return

    config = AppConfig()
    storage = AuthStorageFactory(config).get_auth_storage_manager()
    client = OAuthClient()

    # Tokens are stored under a key derived from the auth server and scope set.
    storage_key = build_auth_storage_key(server_config.auth_server, server_config.scopes)
    stored = storage.retrieve(user_id, storage_key)

    if not (stored and isinstance(stored, OAuth2AuthData)):
        logger.debug(f"No tokens found for {server_config.name}, skipping revocation")
        return

    try:
        metadata = await client.metadata_cache.fetch_auth_server_metadata(
            server_config.auth_server
        )

        if not metadata.revocation_endpoint:
            logger.warning(
                f"No revocation_endpoint discovered for {server_config.auth_server}. "
                f"Cannot revoke tokens."
            )
            return

        async def _revoke(token: str, hint: str) -> None:
            # One revocation request against the discovered endpoint.
            await client.revoke_token(
                token=token,
                revocation_endpoint=str(metadata.revocation_endpoint),
                client_id=server_config.oauth_client_id or config.get("TA_OAUTH_CLIENT_NAME"),
                client_secret=server_config.oauth_client_secret,
                token_type_hint=hint,
            )

        await _revoke(stored.access_token, "access_token")
        if stored.refresh_token:
            await _revoke(stored.refresh_token, "refresh_token")

        # Forget the tokens locally only after the server accepted the revocation(s).
        storage.delete(user_id, storage_key)

        logger.info(f"Successfully revoked and removed tokens for {server_config.name}")

    except Exception as e:
        logger.error(f"Failed to revoke tokens for {server_config.name}: {e}")
        raise
sk_agents.mcp_client.create_mcp_session_with_retry async
create_mcp_session_with_retry(
    server_config: McpServerConfig,
    connection_stack: AsyncExitStack,
    user_id: str = "default",
    max_retries: int = 3,
    mcp_session_id: str | None = None,
    on_stale_session: Callable[[str], Awaitable[None]]
    | None = None,
    app_config: AppConfig | None = None,
) -> tuple[ClientSession, Callable[[], str | None]]

Create MCP session with retry logic for transient failures.

This function wraps create_mcp_session with exponential backoff retry logic to handle transient network issues and temporary server unavailability.

Parameters:

Name Type Description Default
server_config McpServerConfig

MCP server configuration

required
connection_stack AsyncExitStack

AsyncExitStack for resource management

required
user_id str

User ID for authentication

'default'
max_retries int

Maximum number of connection attempts, including the first one (default: 3)

3

Returns:

Name Type Description
ClientSession tuple[ClientSession, Callable[[], str | None]]

Tuple of the initialized MCP session and a callable that returns the current MCP session ID (or None)

Raises:

Type Description
ConnectionError

If all retry attempts fail

ValueError

If server configuration is invalid

Source code in src/sk_agents/mcp_client.py
async def _close_failed_mcp_attempt(attempt_stack: AsyncExitStack, server_name: str) -> None:
    """Close the contexts opened by a failed connection attempt, logging (not raising) errors."""
    try:
        await attempt_stack.aclose()
    except Exception as cleanup_error:
        # Cleanup is best effort: a secondary failure here must not mask the
        # connection error that the caller is already handling.
        logger.debug(
            f"Error cleaning up failed MCP connection attempt for '{server_name}': "
            f"{cleanup_error}"
        )


async def _notify_stale_mcp_session(
    on_stale_session: Callable[[str], Awaitable[None]], mcp_session_id: str
) -> None:
    """Invoke the stale-session callback; a failing callback is logged and ignored."""
    try:
        await on_stale_session(mcp_session_id)
    except Exception:
        logger.debug("Failed to clear stale MCP session id during retry path")


async def create_mcp_session_with_retry(
    server_config: McpServerConfig,
    connection_stack: AsyncExitStack,
    user_id: str = "default",
    max_retries: int = 3,
    mcp_session_id: str | None = None,
    on_stale_session: Callable[[str], Awaitable[None]] | None = None,
    app_config: AppConfig | None = None,
) -> tuple[ClientSession, Callable[[], str | None]]:
    """
    Create an MCP session, retrying transient connection failures with backoff.

    Wraps create_mcp_session.
    - ConnectionError, TimeoutError and OSError are treated as transient. The
      failed attempt is torn down and retried after an exponential backoff
      (1s, 2s, 4s, ... between attempts).
    - Any other exception is re-raised immediately.

    Each attempt runs on its own AsyncExitStack nested inside
    ``connection_stack``. A failed attempt's contexts (for example a spawned
    stdio server process or an open HTTP stream) are closed right away instead
    of lingering until the caller closes ``connection_stack``. A successful
    attempt's contexts stay owned by ``connection_stack``.

    If ``mcp_session_id`` is set and ``on_stale_session`` is provided, the first
    failure invokes ``on_stale_session(mcp_session_id)``; errors from the callback
    are logged and ignored. After a transient failure the remaining attempts run
    without a session id.

    Args:
        server_config: MCP server configuration
        connection_stack: AsyncExitStack that owns the successful session's resources
        user_id: User ID for authentication
        max_retries: Maximum number of connection attempts, including the first
            (default: 3)
        mcp_session_id: Optional existing MCP session id to send when connecting
        on_stale_session: Optional async callback invoked with ``mcp_session_id``
            when connecting with it fails
        app_config: Optional application config forwarded to create_mcp_session

    Returns:
        Tuple of the initialized ClientSession and a callable returning the
        current MCP session id (or None).

    Raises:
        ConnectionError: If all attempts fail with transient errors (chained from
            the last error)
        ValueError: If server configuration is invalid
    """
    last_error: Exception | None = None

    for attempt in range(max_retries):
        # Nested per-attempt stack: exits in the same position (and with the same
        # exception propagation) as if the contexts were entered on
        # connection_stack directly, but can be closed early on failure.
        # Closing an already-closed AsyncExitStack is a no-op.
        attempt_stack = await connection_stack.enter_async_context(AsyncExitStack())
        try:
            session, get_session_id = await create_mcp_session(
                server_config,
                attempt_stack,
                user_id,
                mcp_session_id=mcp_session_id,
                app_config=app_config,
            )

        except (ConnectionError, TimeoutError, OSError) as e:
            last_error = e
            await _close_failed_mcp_attempt(attempt_stack, server_config.name)

            # A stored session id may be stale: clear it once and make the
            # remaining attempts with a fresh session.
            if mcp_session_id and on_stale_session:
                await _notify_stale_mcp_session(on_stale_session, mcp_session_id)
                mcp_session_id = None

            # Don't sleep after the last attempt
            if attempt < max_retries - 1:
                backoff_seconds = 2**attempt  # 1s, 2s, 4s, ...
                logger.warning(
                    f"MCP connection attempt {attempt + 1}/{max_retries} failed for "
                    f"'{server_config.name}': {e}. Retrying in {backoff_seconds}s..."
                )
                await asyncio.sleep(backoff_seconds)
            else:
                logger.error(
                    f"Failed to connect to MCP server '{server_config.name}' "
                    f"after {max_retries} attempts"
                )
            continue

        except Exception as e:
            await _close_failed_mcp_attempt(attempt_stack, server_config.name)
            # If failure might be due to stale session id, clear once then re-raise
            if mcp_session_id and on_stale_session:
                await _notify_stale_mcp_session(on_stale_session, mcp_session_id)
            logger.error(
                f"Non-retryable error connecting to MCP server '{server_config.name}': {e}"
            )
            raise

        if attempt > 0:
            logger.info(
                f"Successfully connected to MCP server '{server_config.name}' "
                f"after {attempt + 1} attempt(s)"
            )
        return session, get_session_id

    # All retries exhausted
    raise ConnectionError(
        f"Failed to connect to MCP server '{server_config.name}' after {max_retries} attempts. "
        f"Last error: {last_error}"
    ) from last_error
sk_agents.mcp_client.create_mcp_session async
create_mcp_session(
    server_config: McpServerConfig,
    connection_stack: AsyncExitStack,
    user_id: str = "default",
    mcp_session_id: str | None = None,
    app_config: AppConfig | None = None,
) -> tuple[ClientSession, Callable[[], str | None]]

Create MCP session using SDK transport factories.

Source code in src/sk_agents/mcp_client.py
async def create_mcp_session(
    server_config: McpServerConfig,
    connection_stack: AsyncExitStack,
    user_id: str = "default",
    mcp_session_id: str | None = None,
    app_config: AppConfig | None = None,
) -> tuple[ClientSession, Callable[[], str | None]]:
    """Create and initialize an MCP client session using the SDK transport factories.

    The transport is selected by ``server_config.transport``:

    - ``"stdio"``: starts ``server_config.command`` with ``args``/``env`` through the
      SDK stdio client. This transport has no MCP session id, so the returned getter
      always yields ``None``.
    - ``"http"``: resolves auth headers for ``user_id`` and connects with the SDK
      streamable HTTP client. If ``mcp_session_id`` is given, it is sent as the
      ``Mcp-Session-Id`` header (presumably to reuse a prior server-side session).
      When ``server_config.verify_ssl`` is ``False``, a custom httpx client with
      TLS verification disabled is used.

    The transport and session contexts are entered on ``connection_stack``, so they
    are closed when the caller closes that stack.

    Args:
        server_config: MCP server configuration.
        connection_stack: Exit stack that takes ownership of transport and session.
        user_id: User whose stored credentials are used for HTTP auth headers.
        mcp_session_id: Optional MCP session id to send with HTTP requests.
        app_config: Optional application config forwarded to auth header resolution.

    Returns:
        A ``(session, get_session_id)`` tuple: the initialized ``ClientSession`` and a
        callable returning the current MCP session id (or ``None``).

    Raises:
        NotImplementedError: If the MCP SDK streamable HTTP client cannot be imported.
        ValueError: If ``server_config.transport`` is neither ``"stdio"`` nor ``"http"``.
        AuthRequiredError: Propagated from ``resolve_server_auth_headers`` for HTTP.
    """
    transport_type = server_config.transport

    if transport_type == "stdio":
        from mcp.client.stdio import stdio_client

        server_params = StdioServerParameters(
            command=server_config.command, args=server_config.args, env=server_config.env or {}
        )

        read, write = await connection_stack.enter_async_context(stdio_client(server_params))
        session = await connection_stack.enter_async_context(ClientSession(read, write))

        await initialize_mcp_session(
            session,
            server_config.name,
            protocol_version=server_config.protocol_version or "2025-11-25",
        )
        # stdio has no MCP session id.
        return session, (lambda: None)

    elif transport_type == "http":
        # Resolve auth headers for HTTP transport
        resolved_headers = await resolve_server_auth_headers(
            server_config, user_id, app_config=app_config
        )

        # Guard only the import: an ImportError raised later (while connecting or
        # initializing) is a genuine failure and must not be reported as missing
        # SDK support.
        try:
            from mcp.client.streamable_http import streamablehttp_client
        except ImportError as err:
            raise NotImplementedError(
                "HTTP transport is not available. Please install the MCP SDK with HTTP support"
            ) from err

        # Create custom httpx client factory if SSL verification is disabled
        httpx_client_factory = None
        if getattr(server_config, "verify_ssl", True) is False:
            logger.warning(
                f"SSL verification disabled for MCP server '{server_config.name}'. "
                f"Creating custom httpx client factory with verify=False"
            )

            def create_insecure_http_client(
                headers: dict[str, str] | None = None,
                timeout: httpx.Timeout | None = None,
                auth: httpx.Auth | None = None,
            ) -> httpx.AsyncClient:
                """Create httpx client with SSL verification disabled."""
                logger.debug(
                    f"Creating insecure httpx client for {server_config.name} with verify=False"
                )
                kwargs: dict[str, Any] = {
                    "follow_redirects": True,
                    "verify": False,  # Disable SSL verification
                    "timeout": timeout if timeout is not None else httpx.Timeout(30.0),
                }
                if headers is not None:
                    kwargs["headers"] = headers
                if auth is not None:
                    kwargs["auth"] = auth

                # The headers carry the resolved credentials (e.g. Authorization);
                # never write their values to the log.
                sensitive_words = ("auth", "token", "key", "secret", "password", "cookie")
                loggable_kwargs = dict(kwargs)
                if headers is not None:
                    loggable_kwargs["headers"] = {
                        name: (
                            "[REDACTED]"
                            if any(word in name.lower() for word in sensitive_words)
                            else value
                        )
                        for name, value in headers.items()
                    }
                logger.debug(f"httpx.AsyncClient kwargs: {loggable_kwargs}")
                return httpx.AsyncClient(**kwargs)

            httpx_client_factory = create_insecure_http_client

        # Build kwargs for streamablehttp_client
        headers_with_session = resolved_headers.copy()
        if mcp_session_id:
            headers_with_session["Mcp-Session-Id"] = mcp_session_id

        client_kwargs = {
            "url": server_config.url,
            "headers": headers_with_session,
            "timeout": server_config.timeout or 30.0,
            "sse_read_timeout": server_config.sse_read_timeout or 300.0,
        }
        if httpx_client_factory is not None:
            client_kwargs["httpx_client_factory"] = httpx_client_factory
            logger.info(
                f"Passing custom httpx_client_factory to streamablehttp_client "
                f"for {server_config.name}"
            )
        else:
            logger.debug(
                f"No custom httpx_client_factory for {server_config.name}, "
                "using default SSL verification"
            )

        # Use streamable HTTP transport
        read, write, get_session_id = await connection_stack.enter_async_context(
            streamablehttp_client(**client_kwargs)
        )
        session = await connection_stack.enter_async_context(ClientSession(read, write))

        await initialize_mcp_session(
            session,
            server_config.name,
            protocol_version=server_config.protocol_version or "2025-11-25",
        )
        return session, get_session_id

    else:
        raise ValueError(f"Unsupported transport type: {transport_type}")
sk_agents.mcp_client.get_transport_info
get_transport_info(server_config: McpServerConfig) -> str

Get transport info for logging.

Source code in src/sk_agents/mcp_client.py
def get_transport_info(server_config: McpServerConfig) -> str:
    """Build a log-safe, one-line description of an MCP server's transport.

    - ``stdio``: ``"stdio:<command> <args>"``.
      - Any argument containing a sensitive keyword (token, key, secret, password,
        auth) is replaced with ``[REDACTED]``.
      - If that argument is a bare flag (starts with ``-`` and has no ``=``), the
        argument after it is redacted too, because it is most likely the secret
        value (``--token abc123``).
    - ``http``: ``"http:<url>"`` with the query string, fragment and any
      ``user:password@`` credentials removed.
    - anything else: ``"<transport>:unknown"``.

    Args:
        server_config: MCP server configuration.

    Returns:
        A string suitable for logging.
    """
    from urllib.parse import urlsplit, urlunsplit

    sensitive_keywords = ("token", "key", "secret", "password", "auth")

    if server_config.transport == "stdio":
        # Sanitize sensitive arguments
        safe_args: list[str] = []
        redact_next = False
        for arg in server_config.args or []:
            is_sensitive = any(keyword in arg.lower() for keyword in sensitive_keywords)
            safe_args.append("[REDACTED]" if is_sensitive or redact_next else arg)
            # A bare sensitive flag such as "--token" is normally followed by its
            # secret value, which itself contains no keyword.
            redact_next = is_sensitive and arg.startswith("-") and "=" not in arg
        return f"stdio:{server_config.command} {' '.join(safe_args)}"

    if server_config.transport == "http":
        url = server_config.url or ""
        try:
            parts = urlsplit(url)
        except ValueError:
            # Unparseable URL: at least drop the query string.
            return f"http:{url.split('?')[0]}"
        # Query strings and fragments may carry tokens; "user:password@" carries credentials.
        host = parts.netloc.rpartition("@")[2]
        return f"http:{urlunsplit((parts.scheme, host, parts.path, '', ''))}"

    return f"{server_config.transport}:unknown"
sk_agents.mcp_discovery

MCP State Management Module.

sk_agents.mcp_discovery.DiscoveryManagerFactory

Factory for MCP state manager with dependency injection.

Uses singleton pattern to ensure only one factory instance exists. Dynamically loads state manager implementation based on environment variables.

Configuration

TA_MCP_DISCOVERY_MODULE: Python module containing the manager class. TA_MCP_DISCOVERY_CLASS: Name of the manager class.

Defaults to InMemoryStateManager for development.

Source code in src/sk_agents/mcp_discovery/discovery_manager_factory.py
class DiscoveryManagerFactory(metaclass=Singleton):
    """
    Factory for the MCP state (discovery) manager, with dependency injection.

    Uses the ``Singleton`` metaclass, which is intended to ensure only one factory
    instance exists. The manager implementation is loaded dynamically from
    configuration on first use and cached afterwards.

    Configuration:
        TA_MCP_DISCOVERY_MODULE: Python module containing the manager class
        TA_MCP_DISCOVERY_CLASS: Manager class name

    If the configured class cannot be loaded, falls back to InMemoryStateManager
    (suitable for development).
    """

    def __init__(self, app_config: AppConfig):
        """
        Initialize factory with app configuration.

        Args:
            app_config: Application configuration; read for the discovery module/class
                settings and passed to the manager constructor
        """
        self.app_config = app_config
        # Built lazily by get_discovery_manager().
        self._manager: McpStateManager | None = None  # noqa: F821

    def get_discovery_manager(self) -> "McpStateManager":  # noqa: F821
        """
        Get state manager instance (cached singleton).

        Loads manager implementation on first call based on configuration,
        then caches for subsequent calls. If the configured class cannot be
        imported or instantiated, the error is logged and InMemoryStateManager
        is used instead.

        Returns:
            McpStateManager instance

        Raises:
            Exception: If neither the configured manager nor the
                InMemoryStateManager fallback can be loaded
        """
        if self._manager is None:
            import importlib

            # Import here to avoid circular dependency
            from sk_agents.configs import TA_MCP_DISCOVERY_CLASS, TA_MCP_DISCOVERY_MODULE

            module_name = self.app_config.get(TA_MCP_DISCOVERY_MODULE.env_name)
            class_name = self.app_config.get(TA_MCP_DISCOVERY_CLASS.env_name)

            try:
                # Surface missing configuration as a readable error (still triggers fallback).
                if not module_name or not class_name:
                    raise ValueError(
                        f"{TA_MCP_DISCOVERY_MODULE.env_name} and "
                        f"{TA_MCP_DISCOVERY_CLASS.env_name} must both be set"
                    )
                # importlib.import_module is the documented replacement for calling
                # __import__ directly; it returns the named (leaf) module.
                module = importlib.import_module(module_name)
                manager_class = getattr(module, class_name)
                self._manager = manager_class(self.app_config)
                logger.info(f"Initialized MCP state manager: {class_name}")

            except Exception as e:
                logger.error(
                    f"Failed to load state manager {class_name} from {module_name}: {e}. "
                    f"Falling back to InMemoryStateManager",
                    exc_info=True,
                )

                # Fallback to in-memory implementation
                try:
                    from sk_agents.mcp_discovery.in_memory_discovery_manager import (
                        InMemoryStateManager,
                    )

                    self._manager = InMemoryStateManager(self.app_config)
                    logger.info("Fallback to InMemoryStateManager successful")

                except Exception as fallback_error:
                    logger.critical(
                        f"Failed to load fallback InMemoryStateManager: {fallback_error}"
                    )
                    raise

        return self._manager
sk_agents.mcp_discovery.DiscoveryManagerFactory.__init__
__init__(app_config: AppConfig)

Initialize factory with app configuration.

Parameters:

Name Type Description Default
app_config AppConfig

Application configuration for env vars

required
Source code in src/sk_agents/mcp_discovery/discovery_manager_factory.py
def __init__(self, app_config: AppConfig):
    """
    Remember the application configuration; the manager is created lazily.

    Args:
        app_config: Application configuration; read for the discovery module/class
            settings and handed to the manager constructor
    """
    # No manager yet: get_discovery_manager() builds and caches it on first use.
    self._manager: McpStateManager | None = None  # noqa: F821
    self.app_config = app_config
sk_agents.mcp_discovery.DiscoveryManagerFactory.get_discovery_manager
get_discovery_manager() -> McpStateManager

Get state manager instance (cached singleton).

Loads manager implementation on first call based on configuration, then caches for subsequent calls.

Returns:

Type Description
McpStateManager

McpStateManager instance

Raises:

Type Description
Exception

If neither the configured manager class nor the InMemoryStateManager fallback can be loaded

Source code in src/sk_agents/mcp_discovery/discovery_manager_factory.py
def get_discovery_manager(self) -> "McpStateManager":  # noqa: F821
    """
    Get state manager instance (cached singleton).

    Loads manager implementation on first call based on configuration,
    then caches for subsequent calls. If the configured class cannot be
    imported or instantiated, the error is logged and InMemoryStateManager
    is used instead.

    Returns:
        McpStateManager instance

    Raises:
        Exception: If neither the configured manager nor the
            InMemoryStateManager fallback can be loaded
    """
    if self._manager is None:
        import importlib

        # Import here to avoid circular dependency
        from sk_agents.configs import TA_MCP_DISCOVERY_CLASS, TA_MCP_DISCOVERY_MODULE

        module_name = self.app_config.get(TA_MCP_DISCOVERY_MODULE.env_name)
        class_name = self.app_config.get(TA_MCP_DISCOVERY_CLASS.env_name)

        try:
            # Surface missing configuration as a readable error (still triggers fallback).
            if not module_name or not class_name:
                raise ValueError(
                    f"{TA_MCP_DISCOVERY_MODULE.env_name} and "
                    f"{TA_MCP_DISCOVERY_CLASS.env_name} must both be set"
                )
            # importlib.import_module is the documented replacement for calling
            # __import__ directly; it returns the named (leaf) module.
            module = importlib.import_module(module_name)
            manager_class = getattr(module, class_name)
            self._manager = manager_class(self.app_config)
            logger.info(f"Initialized MCP state manager: {class_name}")

        except Exception as e:
            logger.error(
                f"Failed to load state manager {class_name} from {module_name}: {e}. "
                f"Falling back to InMemoryStateManager",
                exc_info=True,
            )

            # Fallback to in-memory implementation
            try:
                from sk_agents.mcp_discovery.in_memory_discovery_manager import (
                    InMemoryStateManager,
                )

                self._manager = InMemoryStateManager(self.app_config)
                logger.info("Fallback to InMemoryStateManager successful")

            except Exception as fallback_error:
                logger.critical(
                    f"Failed to load fallback InMemoryStateManager: {fallback_error}"
                )
                raise

    return self._manager
sk_agents.mcp_discovery.InMemoryStateManager

Bases: McpStateManager

In-memory implementation of MCP state manager.

Stores MCP state in memory, with access serialized by an asyncio lock (this guards concurrent coroutines on one event loop, not OS threads). Suitable for development and testing, single-instance deployments, and scenarios where persistence is not required.

Note: State is lost on server restart.

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
class InMemoryStateManager(McpStateManager):
    """
    In-memory implementation of MCP state manager.

    Stores MCP state in a process-local dict keyed by (user_id, session_id).
    Every operation runs under a single ``asyncio.Lock``, so access is safe for
    concurrent coroutines on one event loop (it is not a thread lock).

    Stored state is isolated from callers: ``create_discovery`` and
    ``update_discovery`` keep deep copies, and ``load_discovery`` hands out
    deep copies, so mutating an ``McpState`` object outside this manager never
    changes what is stored.

    Suitable for:
    - Development and testing
    - Single-instance deployments
    - Scenarios where persistence is not required

    Note: State is lost on server restart.
    """

    def __init__(self, app_config):
        """
        Initialize in-memory state manager.

        Args:
            app_config: Application configuration (for consistency with other managers)
        """
        self.app_config = app_config
        # Storage: {(user_id, session_id): McpState}
        self._storage: dict[tuple[str, str], McpState] = {}
        # Guards every read and write of self._storage.
        self._lock = asyncio.Lock()

    def _make_key(self, user_id: str, session_id: str) -> tuple[str, str]:
        """
        Create composite key for storage.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            Tuple of (user_id, session_id)
        """
        return (user_id, session_id)

    async def create_discovery(self, state: McpState) -> None:
        """
        Create initial MCP state.

        A deep copy of ``state`` is stored, so later changes to the caller's
        object do not leak into the store.

        Args:
            state: MCP state to create

        Raises:
            DiscoveryCreateError: If state already exists
        """
        async with self._lock:
            key = self._make_key(state.user_id, state.session_id)
            if key in self._storage:
                raise DiscoveryCreateError(
                    f"MCP state already exists for user={state.user_id}, session={state.session_id}"
                )
            # Keep a private copy: storing the caller's reference would let
            # external mutations bypass update_discovery and the lock.
            self._storage[key] = copy.deepcopy(state)
            logger.debug(f"Created MCP state for user={state.user_id}, session={state.session_id}")

    async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
        """
        Load MCP state.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            Deep copy of MCP state if exists, None otherwise.
            Returns a copy to prevent external mutations.
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)
            if state is None:
                return None
            # Return deep copy to prevent external mutations bypassing update_discovery
            return copy.deepcopy(state)

    async def update_discovery(self, state: McpState) -> None:
        """
        Update existing MCP state.

        The stored value is replaced with a deep copy of ``state``, so the
        caller's object stays decoupled from the store.

        Args:
            state: Updated MCP state

        Raises:
            DiscoveryUpdateError: If state does not exist
        """
        async with self._lock:
            key = self._make_key(state.user_id, state.session_id)
            if key not in self._storage:
                raise DiscoveryUpdateError(
                    f"MCP state not found for user={state.user_id}, session={state.session_id}"
                )
            # Same isolation rule as create_discovery / load_discovery.
            self._storage[key] = copy.deepcopy(state)
            logger.debug(f"Updated MCP state for user={state.user_id}, session={state.session_id}")

    async def delete_discovery(self, user_id: str, session_id: str) -> None:
        """
        Delete MCP state. Deleting a missing entry is a no-op.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            if key in self._storage:
                del self._storage[key]
                logger.debug(f"Deleted MCP state for user={user_id}, session={session_id}")

    async def mark_completed(self, user_id: str, session_id: str) -> None:
        """
        Mark discovery as completed.

        If state doesn't exist, auto-creates it with discovery_completed=True
        and empty discovered_servers dict. A warning is logged when auto-creating.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            if key in self._storage:
                self._storage[key].discovery_completed = True
                logger.debug(f"Marked discovery completed for user={user_id}, session={session_id}")
            else:
                # Auto-create state if it doesn't exist
                logger.warning(
                    f"MCP state not found for user={user_id}, session={session_id}. "
                    f"Auto-creating with discovery_completed=True."
                )
                state = McpState(
                    user_id=user_id,
                    session_id=session_id,
                    discovered_servers={},
                    discovery_completed=True,
                    created_at=datetime.now(UTC),
                )
                self._storage[key] = state

    async def is_completed(self, user_id: str, session_id: str) -> bool:
        """
        Check if discovery is completed.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            True if discovery completed, False otherwise (including when no
            state exists)
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)
            return state.discovery_completed if state else False

    async def store_mcp_session(
        self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
    ) -> None:
        """
        Store MCP session ID for a server.

        Auto-creates the (user_id, session_id) state (discovery_completed=False)
        if it is missing. The server entry is rebuilt to hold only its
        ``plugin_data`` and ``session`` bucket. An existing bucket keeps its
        ``created_at``, and ``last_used_at`` is refreshed.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            mcp_session_id: MCP session ID from server
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)

            # Auto-create state if doesn't exist
            if not state:
                logger.warning(
                    f"MCP state not found for user={user_id}, session={session_id}. "
                    f"Auto-creating to store session for {server_name}."
                )
                state = McpState(
                    user_id=user_id,
                    session_id=session_id,
                    discovered_servers={},
                    discovery_completed=False,
                    created_at=datetime.now(UTC),
                )
                self._storage[key] = state

            # Ensure server entry exists and preserve plugin_data if present
            existing_entry = state.discovered_servers.get(server_name, {})
            plugin_data = existing_entry.get("plugin_data")
            state.discovered_servers[server_name] = {
                "plugin_data": plugin_data,
                **(
                    {"session": existing_entry.get("session")}
                    if existing_entry.get("session")
                    else {}
                ),
            }

            # Store session data
            session_bucket = state.discovered_servers[server_name].get("session", {})
            now_iso = datetime.now(UTC).isoformat()
            session_bucket.update(
                {
                    "mcp_session_id": mcp_session_id,
                    # Preserve the original creation time when re-storing
                    "created_at": session_bucket.get("created_at", now_iso),
                    "last_used_at": now_iso,
                }
            )
            state.discovered_servers[server_name]["session"] = session_bucket

            logger.debug(
                f"Stored MCP session {mcp_session_id} for server={server_name}, "
                f"user={user_id}, session={session_id}"
            )

    async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
        """
        Get MCP session ID for a server.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Returns:
            MCP session ID if exists, None otherwise
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)

            if not state:
                return None

            server_data = state.discovered_servers.get(server_name)
            if not server_data:
                return None

            session_bucket = server_data.get("session")
            if not session_bucket:
                return None
            return session_bucket.get("mcp_session_id")

    async def update_session_last_used(
        self, user_id: str, session_id: str, server_name: str
    ) -> None:
        """
        Update last_used timestamp for an MCP session.

        If the server entry has no session bucket yet, a bucket containing only
        ``last_used_at`` is created.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Raises:
            DiscoveryUpdateError: If state or server doesn't exist
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)

            if not state:
                raise DiscoveryUpdateError(
                    f"MCP state not found for user={user_id}, session={session_id}"
                )

            if server_name not in state.discovered_servers:
                raise DiscoveryUpdateError(
                    f"Server {server_name} not found in state for "
                    f"user={user_id}, session={session_id}"
                )

            session_bucket = state.discovered_servers[server_name].get("session")
            if not session_bucket:
                session_bucket = {}
            session_bucket["last_used_at"] = datetime.now(UTC).isoformat()
            state.discovered_servers[server_name]["session"] = session_bucket

            logger.debug(
                f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
            )

    async def clear_mcp_session(
        self,
        user_id: str,
        session_id: str,
        server_name: str,
        expected_session_id: str | None = None,
    ) -> None:
        """
        Remove stored MCP session info for a server if present.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            expected_session_id: When given, the session is cleared only if the
                stored ``mcp_session_id`` is missing or equals this value. This
                avoids wiping a session that was replaced concurrently.
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)
            if not state:
                return
            entry = state.discovered_servers.get(server_name)
            if not entry:
                return
            if "session" in entry:
                if expected_session_id:
                    # The "session" key is known to exist here, so a
                    # ``.get("session", {})`` default would never apply; guard
                    # against a None/empty bucket explicitly, as get_mcp_session does.
                    current = (entry.get("session") or {}).get("mcp_session_id")
                    if current and current != expected_session_id:
                        # Another session already replaced it; do not clear
                        return
                entry.pop("session", None)
                state.discovered_servers[server_name] = entry
                logger.debug(
                    f"Cleared MCP session for server={server_name}, "
                    f"user={user_id}, session={session_id}"
                )
sk_agents.mcp_discovery.InMemoryStateManager.__init__
__init__(app_config)

Initialize in-memory state manager.

Parameters:

Name Type Description Default
app_config

Application configuration (for consistency with other managers)

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
def __init__(self, app_config):
    """
    Set up an empty in-memory MCP state store.

    Args:
        app_config: Application configuration; kept on the instance so the
            constructor signature matches the other state managers.
    """
    # A single asyncio lock serializes every access to the storage dict.
    self._lock = asyncio.Lock()
    # Maps (user_id, session_id) -> McpState.
    self._storage: dict[tuple[str, str], McpState] = {}
    self.app_config = app_config
sk_agents.mcp_discovery.InMemoryStateManager.create_discovery async
create_discovery(state: McpState) -> None

Create initial MCP state.

Parameters:

Name Type Description Default
state McpState

MCP state to create

required

Raises:

Type Description
DiscoveryCreateError

If state already exists

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def create_discovery(self, state: McpState) -> None:
    """
    Create initial MCP state.

    A deep copy of ``state`` is stored, so later changes to the caller's
    object do not leak into the store. This matches load_discovery, which
    hands out deep copies.

    Args:
        state: MCP state to create

    Raises:
        DiscoveryCreateError: If state already exists
    """
    async with self._lock:
        key = self._make_key(state.user_id, state.session_id)
        if key in self._storage:
            raise DiscoveryCreateError(
                f"MCP state already exists for user={state.user_id}, session={state.session_id}"
            )
        # Keep a private copy: storing the caller's reference would let
        # external mutations bypass update_discovery and the lock.
        self._storage[key] = copy.deepcopy(state)
        logger.debug(f"Created MCP state for user={state.user_id}, session={state.session_id}")
sk_agents.mcp_discovery.InMemoryStateManager.load_discovery async
load_discovery(
    user_id: str, session_id: str
) -> McpState | None

Load MCP state.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
McpState | None

Deep copy of MCP state if exists, None otherwise.

McpState | None

Returns a copy to prevent external mutations.

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
    """
    Fetch the MCP state stored for a (user_id, session_id) pair.

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        An independent deep copy of the stored state, or None when nothing is
        stored. Because a copy is returned, edits made by the caller only take
        effect through update_discovery.
    """
    async with self._lock:
        stored = self._storage.get(self._make_key(user_id, session_id))
        return None if stored is None else copy.deepcopy(stored)
sk_agents.mcp_discovery.InMemoryStateManager.update_discovery async
update_discovery(state: McpState) -> None

Update existing MCP state.

Parameters:

Name Type Description Default
state McpState

Updated MCP state

required

Raises:

Type Description
DiscoveryUpdateError

If state does not exist

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def update_discovery(self, state: McpState) -> None:
    """
    Update existing MCP state.

    The stored value is replaced with a deep copy of ``state``, so the
    caller's object stays decoupled from the store. This matches
    load_discovery, which hands out deep copies.

    Args:
        state: Updated MCP state

    Raises:
        DiscoveryUpdateError: If state does not exist
    """
    async with self._lock:
        key = self._make_key(state.user_id, state.session_id)
        if key not in self._storage:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={state.user_id}, session={state.session_id}"
            )
        # Keep a private copy so later mutations of `state` cannot bypass the lock.
        self._storage[key] = copy.deepcopy(state)
        logger.debug(f"Updated MCP state for user={state.user_id}, session={state.session_id}")
sk_agents.mcp_discovery.InMemoryStateManager.delete_discovery async
delete_discovery(user_id: str, session_id: str) -> None

Delete MCP state.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def delete_discovery(self, user_id: str, session_id: str) -> None:
    """
    Remove the MCP state for a (user_id, session_id) pair.

    Deleting a pair that has no stored state does nothing and logs nothing.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        if key not in self._storage:
            return
        del self._storage[key]
        logger.debug(f"Deleted MCP state for user={user_id}, session={session_id}")
sk_agents.mcp_discovery.InMemoryStateManager.mark_completed async
mark_completed(user_id: str, session_id: str) -> None

Mark discovery as completed.

If state doesn't exist, auto-creates it with discovery_completed=True and empty discovered_servers dict. A warning is logged when auto-creating.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def mark_completed(self, user_id: str, session_id: str) -> None:
    """
    Flag discovery as finished for a (user_id, session_id) pair.

    When no state is stored yet, a new one is created with
    discovery_completed=True and no discovered servers, and a warning is
    logged.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        if key not in self._storage:
            logger.warning(
                f"MCP state not found for user={user_id}, session={session_id}. "
                f"Auto-creating with discovery_completed=True."
            )
            self._storage[key] = McpState(
                user_id=user_id,
                session_id=session_id,
                discovered_servers={},
                discovery_completed=True,
                created_at=datetime.now(UTC),
            )
            return
        # Existing state is flipped in place.
        self._storage[key].discovery_completed = True
        logger.debug(f"Marked discovery completed for user={user_id}, session={session_id}")
sk_agents.mcp_discovery.InMemoryStateManager.is_completed async
is_completed(user_id: str, session_id: str) -> bool

Check if discovery is completed.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
bool

True if discovery completed, False otherwise

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def is_completed(self, user_id: str, session_id: str) -> bool:
    """
    Report whether discovery has finished for a (user_id, session_id) pair.

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The stored discovery_completed flag, or False if no state is stored.
    """
    async with self._lock:
        state = self._storage.get(self._make_key(user_id, session_id))
        if state:
            return state.discovery_completed
        return False
sk_agents.mcp_discovery.InMemoryStateManager.store_mcp_session async
store_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    mcp_session_id: str,
) -> None

Store MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
mcp_session_id str

MCP session ID from server

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def store_mcp_session(
    self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
) -> None:
    """
    Record the MCP session ID a server issued for this agent session.

    If no state exists for (user_id, session_id), one is created with
    discovery_completed=False and a warning is logged. The server entry is
    rebuilt to hold its ``plugin_data`` and, when it already had a non-empty
    one, its ``session`` bucket. The bucket then gets the new
    ``mcp_session_id``, keeps any earlier ``created_at``, and gets a fresh
    ``last_used_at``.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        mcp_session_id: MCP session ID from server
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        state = self._storage.get(key)

        if not state:
            logger.warning(
                f"MCP state not found for user={user_id}, session={session_id}. "
                f"Auto-creating to store session for {server_name}."
            )
            state = McpState(
                user_id=user_id,
                session_id=session_id,
                discovered_servers={},
                discovery_completed=False,
                created_at=datetime.now(UTC),
            )
            self._storage[key] = state

        previous = state.discovered_servers.get(server_name, {})
        rebuilt = {"plugin_data": previous.get("plugin_data")}
        if previous.get("session"):
            rebuilt["session"] = previous.get("session")
        state.discovered_servers[server_name] = rebuilt

        bucket = rebuilt.get("session", {})
        timestamp = datetime.now(UTC).isoformat()
        bucket.update(
            {
                "mcp_session_id": mcp_session_id,
                "created_at": bucket.get("created_at", timestamp),
                "last_used_at": timestamp,
            }
        )
        rebuilt["session"] = bucket

        logger.debug(
            f"Stored MCP session {mcp_session_id} for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )
sk_agents.mcp_discovery.InMemoryStateManager.get_mcp_session async
get_mcp_session(
    user_id: str, session_id: str, server_name: str
) -> str | None

Get MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Returns:

Type Description
str | None

MCP session ID if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
    """
    Get MCP session ID for a server.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Returns:
        MCP session ID if exists, None otherwise
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        state = self._storage.get(key)

        if not state:
            return None

        server_data = state.discovered_servers.get(server_name)
        if not server_data:
            return None

        session_bucket = server_data.get("session")
        if not session_bucket:
            return None
        return session_bucket.get("mcp_session_id")
sk_agents.mcp_discovery.InMemoryStateManager.update_session_last_used async
update_session_last_used(
    user_id: str, session_id: str, server_name: str
) -> None

Update last_used timestamp for an MCP session.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Raises:

Type Description
DiscoveryUpdateError

If state or server doesn't exist

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def update_session_last_used(
    self, user_id: str, session_id: str, server_name: str
) -> None:
    """
    Refresh the ``last_used_at`` timestamp of a server's MCP session.

    If the server entry has no (or an empty) session bucket, a bucket holding
    only ``last_used_at`` is created.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Raises:
        DiscoveryUpdateError: If state or server doesn't exist
    """
    async with self._lock:
        state = self._storage.get(self._make_key(user_id, session_id))
        if not state:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={user_id}, session={session_id}"
            )

        servers = state.discovered_servers
        if server_name not in servers:
            raise DiscoveryUpdateError(
                f"Server {server_name} not found in state for "
                f"user={user_id}, session={session_id}"
            )

        bucket = servers[server_name].get("session") or {}
        bucket["last_used_at"] = datetime.now(UTC).isoformat()
        servers[server_name]["session"] = bucket

        logger.debug(
            f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
        )
sk_agents.mcp_discovery.InMemoryStateManager.clear_mcp_session async
clear_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None

Remove stored MCP session info for a server if present.

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def clear_mcp_session(
    self,
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None:
    """
    Remove stored MCP session info for a server if present.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        expected_session_id: When given, the session is cleared only if the
            stored ``mcp_session_id`` is missing or equals this value. This
            avoids wiping a session that was replaced concurrently.
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        state = self._storage.get(key)
        if not state:
            return
        entry = state.discovered_servers.get(server_name)
        if not entry:
            return
        if "session" in entry:
            if expected_session_id:
                # The "session" key is known to exist here, so a
                # ``.get("session", {})`` default would never apply; guard
                # against a None/empty bucket explicitly (as get_mcp_session does)
                # instead of raising AttributeError.
                current = (entry.get("session") or {}).get("mcp_session_id")
                if current and current != expected_session_id:
                    # Another session already replaced it; do not clear
                    return
            entry.pop("session", None)
            state.discovered_servers[server_name] = entry
            logger.debug(
                f"Cleared MCP session for server={server_name}, "
                f"user={user_id}, session={session_id}"
            )
sk_agents.mcp_discovery.McpState

MCP state for a specific user session.

Stores the results of MCP server discovery and session management, including: which servers have been discovered, the serialized plugin data for each server, MCP session IDs for stateful servers, and the completion status.

Scoped to (user_id, session_id) for session-level isolation.

Structure of discovered_servers (as read and written by InMemoryStateManager): { "server_name": { "plugin_data": ..., # Serialized plugin data for the server "session": { # Optional; present once an MCP session has been stored "mcp_session_id": "session-abc123", # For stateful servers "created_at": "2025-01-15T10:00:00+00:00", # Session creation timestamp (ISO 8601) "last_used_at": "2025-01-15T10:30:00+00:00" # Session activity timestamp (ISO 8601) } } }

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class McpState:
    """
    MCP state for a specific user session.

    Stores the results of MCP server discovery and session management including:
    - Which servers have been discovered
    - Serialized plugin data for each server
    - MCP session IDs for stateful servers
    - Completion status

    Scoped to (user_id, session_id) for session-level isolation.

    Structure of discovered_servers:
    {
        "server_name": {
            "tools": [...],  # Plugin metadata
            "mcp_session_id": "session-abc123",  # Optional, for stateful servers
            "last_used_at": "2025-01-15T10:30:00Z",  # Optional, session activity timestamp
            "created_at": "2025-01-15T10:00:00Z"  # Optional, session creation timestamp
        }
    }
    """

    def __init__(
        self,
        user_id: str,
        session_id: str,
        discovered_servers: dict[str, dict],
        discovery_completed: bool,
        created_at: datetime | None = None,
        failed_servers: dict[str, str] | None = None,
    ):
        """
        Initialize MCP state.

        Args:
            user_id: User ID for authentication and scoping
            session_id: Session ID for conversation grouping
            discovered_servers: Mapping of server_name to plugin data and session info
            discovery_completed: Whether discovery has finished successfully
            created_at: Timestamp of state creation (defaults to now)
            failed_servers: Dictionary of failed servers and their error messages
        """
        self.user_id = user_id
        self.session_id = session_id
        self.discovered_servers = discovered_servers
        self.discovery_completed = discovery_completed
        self.created_at = created_at or datetime.now(UTC)
        self.failed_servers = failed_servers or {}
sk_agents.mcp_discovery.McpState.__init__
__init__(
    user_id: str,
    session_id: str,
    discovered_servers: dict[str, dict],
    discovery_completed: bool,
    created_at: datetime | None = None,
    failed_servers: dict[str, str] | None = None,
)

Initialize MCP state.

Parameters:

Name Type Description Default
user_id str

User ID for authentication and scoping

required
session_id str

Session ID for conversation grouping

required
discovered_servers dict[str, dict]

Mapping of server_name to plugin data and session info

required
discovery_completed bool

Whether discovery has finished successfully

required
created_at datetime | None

Timestamp of state creation (defaults to now)

None
failed_servers dict[str, str] | None

Dictionary of failed servers and their error messages

None
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
def __init__(
    self,
    user_id: str,
    session_id: str,
    discovered_servers: dict[str, dict],
    discovery_completed: bool,
    created_at: datetime | None = None,
    failed_servers: dict[str, str] | None = None,
):
    """
    Initialize MCP state.

    Args:
        user_id: User ID for authentication and scoping
        session_id: Session ID for conversation grouping
        discovered_servers: Mapping of server_name to plugin data and session info
        discovery_completed: Whether discovery has finished successfully
        created_at: Timestamp of state creation (defaults to now)
        failed_servers: Dictionary of failed servers and their error messages
    """
    self.user_id = user_id
    self.session_id = session_id
    self.discovered_servers = discovered_servers
    self.discovery_completed = discovery_completed
    self.created_at = created_at or datetime.now(UTC)
    self.failed_servers = failed_servers or {}
sk_agents.mcp_discovery.McpStateManager

Bases: ABC

Abstract interface for MCP state management (discovery + sessions).

Implementations must provide storage for MCP state scoped to (user_id, session_id) combinations. This enables session-level tool isolation, shared discovery across tasks in the same session, MCP session persistence for stateful servers, and external state storage (Redis, in-memory, etc.).

Follows the same pattern as TaskPersistenceManager (for task state) and SecureAuthStorageManager (for OAuth tokens).

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class McpStateManager(ABC):
    """
    Abstract storage contract for MCP state (discovery results plus MCP sessions).

    All state is keyed by a (user_id, session_id) pair. Concrete subclasses
    decide where it lives (Redis, in-memory, etc.), which provides:
    - Tool isolation per session
    - Discovery results shared by every task in the same session
    - Persistent MCP session IDs for stateful servers
    - Externalized state storage

    Mirrors the shape of TaskPersistenceManager (task state) and
    SecureAuthStorageManager (OAuth tokens).
    """

    @abstractmethod
    async def create_discovery(self, state: McpState) -> None:
        """
        Persist a brand-new state record for the state's (user_id, session_id).

        Args:
            state: MCP state to create

        Raises:
            DiscoveryCreateError: If a record already exists for this
                (user_id, session_id)
        """
        ...

    @abstractmethod
    async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
        """
        Fetch the stored state record for (user_id, session_id).

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            The stored MCP state, or None when nothing is stored
        """
        ...

    @abstractmethod
    async def update_discovery(self, state: McpState) -> None:
        """
        Overwrite an already-stored state record with ``state``.

        Args:
            state: Updated MCP state

        Raises:
            DiscoveryUpdateError: If no record is stored for the state's key
        """
        ...

    @abstractmethod
    async def delete_discovery(self, user_id: str, session_id: str) -> None:
        """
        Remove the stored state record for (user_id, session_id).

        Args:
            user_id: User ID
            session_id: Session ID
        """
        ...

    @abstractmethod
    async def mark_completed(self, user_id: str, session_id: str) -> None:
        """
        Flag discovery as finished for (user_id, session_id).

        A missing record must be auto-created with an empty
        discovered_servers dict and discovery_completed=True, and a warning
        must be logged when that happens.

        Implementations must be idempotent: repeated calls leave the same
        result as a single call.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        ...

    @abstractmethod
    async def is_completed(self, user_id: str, session_id: str) -> bool:
        """
        Report whether discovery has finished for (user_id, session_id).

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            True when discovery is flagged as completed, False otherwise
        """
        ...

    @abstractmethod
    async def store_mcp_session(
        self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
    ) -> None:
        """
        Remember the MCP session ID issued by ``server_name``.

        A missing state record is created, and a missing server entry in
        discovered_servers is added.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            mcp_session_id: MCP session ID from server

        Raises:
            DiscoveryUpdateError: If the state could not be updated
        """
        ...

    @abstractmethod
    async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
        """
        Look up the remembered MCP session ID for ``server_name``.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Returns:
            The stored MCP session ID, or None if there is none
        """
        ...

    @abstractmethod
    async def update_session_last_used(
        self, user_id: str, session_id: str, server_name: str
    ) -> None:
        """
        Refresh the last_used timestamp of the MCP session for ``server_name``.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Raises:
            DiscoveryUpdateError: If the state record or the server entry is missing
        """
        ...

    @abstractmethod
    async def clear_mcp_session(
        self,
        user_id: str,
        session_id: str,
        server_name: str,
        expected_session_id: str | None = None,
    ) -> None:
        """
        Forget the stored MCP session for ``server_name``, if there is one.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            expected_session_id: When given, only clear if the stored session
                ID equals this value
        """
        ...
sk_agents.mcp_discovery.McpStateManager.create_discovery abstractmethod async
create_discovery(state: McpState) -> None

Create initial state for (user_id, session_id).

Parameters:

Name Type Description Default
state McpState

MCP state to create

required

Raises:

Type Description
DiscoveryCreateError

If state already exists for this (user_id, session_id)

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def create_discovery(self, state: McpState) -> None:
    """
    Persist a brand-new state record for the state's (user_id, session_id).

    Args:
        state: MCP state to create

    Raises:
        DiscoveryCreateError: If a record already exists for this (user_id, session_id)
    """
    ...
sk_agents.mcp_discovery.McpStateManager.load_discovery abstractmethod async
load_discovery(
    user_id: str, session_id: str
) -> McpState | None

Load MCP state for (user_id, session_id).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
McpState | None

MCP state if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
    """
    Fetch the stored state record for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The stored MCP state, or None when nothing is stored
    """
    ...
sk_agents.mcp_discovery.McpStateManager.update_discovery abstractmethod async
update_discovery(state: McpState) -> None

Update existing MCP state.

Parameters:

Name Type Description Default
state McpState

Updated MCP state

required

Raises:

Type Description
DiscoveryUpdateError

If state does not exist

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def update_discovery(self, state: McpState) -> None:
    """
    Overwrite an already-stored state record with ``state``.

    Args:
        state: Updated MCP state

    Raises:
        DiscoveryUpdateError: If no record is stored for the state's key
    """
    ...
sk_agents.mcp_discovery.McpStateManager.delete_discovery abstractmethod async
delete_discovery(user_id: str, session_id: str) -> None

Delete MCP state for (user_id, session_id).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def delete_discovery(self, user_id: str, session_id: str) -> None:
    """
    Remove the stored state record for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID
    """
    ...
sk_agents.mcp_discovery.McpStateManager.mark_completed abstractmethod async
mark_completed(user_id: str, session_id: str) -> None

Mark discovery as completed for (user_id, session_id).

If the state does not exist, it will be created automatically with an empty discovered_servers dict and discovery_completed=True. A warning will be logged when auto-creating.

This operation is idempotent - calling it multiple times has the same effect as calling it once.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def mark_completed(self, user_id: str, session_id: str) -> None:
    """
    Flag discovery as finished for (user_id, session_id).

    A missing record must be auto-created with an empty discovered_servers
    dict and discovery_completed=True, and a warning must be logged when
    that happens.

    Implementations must be idempotent: repeated calls leave the same
    result as a single call.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    ...
sk_agents.mcp_discovery.McpStateManager.is_completed abstractmethod async
is_completed(user_id: str, session_id: str) -> bool

Check if discovery is completed for (user_id, session_id).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
bool

True if discovery completed, False otherwise

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def is_completed(self, user_id: str, session_id: str) -> bool:
    """
    Report whether discovery has finished for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        True when discovery is flagged as completed, False otherwise
    """
    ...
sk_agents.mcp_discovery.McpStateManager.store_mcp_session abstractmethod async
store_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    mcp_session_id: str,
) -> None

Store MCP session ID for a server.

If state doesn't exist, it will be created. If server doesn't exist in discovered_servers, it will be added.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
mcp_session_id str

MCP session ID from server

required

Raises:

Type Description
DiscoveryUpdateError

If state update fails

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def store_mcp_session(
    self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
) -> None:
    """
    Remember the MCP session ID issued by ``server_name``.

    A missing state record is created, and a missing server entry in
    discovered_servers is added.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        mcp_session_id: MCP session ID from server

    Raises:
        DiscoveryUpdateError: If the state could not be updated
    """
    ...
sk_agents.mcp_discovery.McpStateManager.get_mcp_session abstractmethod async
get_mcp_session(
    user_id: str, session_id: str, server_name: str
) -> str | None

Get MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Returns:

Type Description
str | None

MCP session ID if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
    """
    Get MCP session ID for a server.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Returns:
        MCP session ID if exists, None otherwise
    """
    pass
sk_agents.mcp_discovery.McpStateManager.update_session_last_used abstractmethod async
update_session_last_used(
    user_id: str, session_id: str, server_name: str
) -> None

Update last_used timestamp for an MCP session.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Raises:

Type Description
DiscoveryUpdateError

If state or server doesn't exist

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def update_session_last_used(
    self, user_id: str, session_id: str, server_name: str
) -> None:
    """
    Refresh the last_used timestamp of the MCP session for ``server_name``.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Raises:
        DiscoveryUpdateError: If the state record or the server entry is missing
    """
    ...
sk_agents.mcp_discovery.McpStateManager.clear_mcp_session abstractmethod async
clear_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None

Clear the stored MCP session for a given server (if present).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
expected_session_id str | None

Optional session id to match before clearing

None
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def clear_mcp_session(
    self,
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None:
    """
    Clear the stored MCP session for a given server (if present).

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        expected_session_id: Optional session id to match before clearing
    """
    pass
sk_agents.mcp_discovery.RedisStateManager

Bases: McpStateManager

Redis-backed implementation of MCP state manager.

Stores MCP state in Redis for:

- Production deployments
- Multi-instance horizontal scaling
- Persistence across server restarts
- Shared state across distributed systems

Uses the same Redis configuration as other components (TA_REDIS_*).

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
class RedisStateManager(McpStateManager):
    """
    Redis-backed implementation of MCP state manager.

    Stores MCP state in Redis for:
    - Production deployments
    - Multi-instance horizontal scaling
    - Persistence across server restarts
    - Shared state across distributed systems

    Uses the same Redis configuration as other components (TA_REDIS_*).

    Each (user_id, session_id) pair is stored as one JSON string under the key
    ``mcp_state:{user_id}:{session_id}``; every write (re)applies ``self.ttl``.
    Check-and-write operations are done either with conditional SETs (NX/XX)
    or with server-side Lua scripts, so that each one is a single atomic Redis
    command and concurrent workers cannot overwrite each other's state.
    """

    def __init__(self, app_config: AppConfig, redis_client: Redis | None = None):
        """
        Initialize Redis state manager.

        Args:
            app_config: Application configuration for Redis connection
            redis_client: Optional pre-configured Redis client (for testing)
        """
        self.app_config = app_config
        self.redis = redis_client or self._create_redis_client()
        self.key_prefix = "mcp_state"

        # TTL support: Default to 24 hours (86400 seconds)
        from sk_agents.configs import TA_REDIS_TTL

        ttl_str = self.app_config.get(TA_REDIS_TTL.env_name)
        if ttl_str:
            self.ttl = int(ttl_str)
        else:
            # Default to 24 hours for discovery state
            self.ttl = 86400

        logger.debug(f"Redis state manager initialized with TTL={self.ttl}s")

    async def close(self) -> None:
        """Close Redis connection and cleanup resources."""
        if self.redis:
            await self.redis.close()
            logger.debug("Redis state manager connection closed")

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

    def _create_redis_client(self) -> Redis:
        """
        Create Redis client from app configuration.

        Reuses existing TA_REDIS_* environment variables for consistency
        with other persistence components.

        Returns:
            Configured Redis client

        Raises:
            ValueError: If required Redis config is missing
        """
        from sk_agents.configs import (
            TA_REDIS_DB,
            TA_REDIS_HOST,
            TA_REDIS_PORT,
            TA_REDIS_PWD,
            TA_REDIS_SSL,
        )

        host = self.app_config.get(TA_REDIS_HOST.env_name)
        port_str = self.app_config.get(TA_REDIS_PORT.env_name)
        db_str = self.app_config.get(TA_REDIS_DB.env_name, default="0")
        ssl_str = self.app_config.get(TA_REDIS_SSL.env_name, default="false")
        pwd = self.app_config.get(TA_REDIS_PWD.env_name, default=None)

        if not host:
            raise ValueError("TA_REDIS_HOST must be configured for Redis discovery manager")
        if not port_str:
            raise ValueError("TA_REDIS_PORT must be configured for Redis discovery manager")

        port = int(port_str)
        db = int(db_str)
        ssl = strtobool(ssl_str)

        logger.info(
            f"Creating Redis discovery client: host={host}, port={port}, db={db}, ssl={ssl}"
        )

        return Redis(host=host, port=port, db=db, ssl=ssl, password=pwd)

    def _make_key(self, user_id: str, session_id: str) -> str:
        """
        Create Redis key for storage.

        Format: mcp_state:{user_id}:{session_id}

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            Redis key string
        """
        return f"{self.key_prefix}:{user_id}:{session_id}"

    async def create_discovery(self, state: McpState) -> None:
        """
        Create initial MCP state in Redis.

        The existence check and the write are a single ``SET ... NX``, so two
        workers racing to create the same state cannot both succeed. The
        loser gets DiscoveryCreateError instead of silently overwriting the
        winner's data.

        Args:
            state: MCP state to create

        Raises:
            DiscoveryCreateError: If state already exists
        """
        key = self._make_key(state.user_id, state.session_id)
        data = self._serialize(state)

        # NX: only write when the key is absent; redis-py returns None otherwise.
        created = await self.redis.set(key, data, ex=self.ttl, nx=True)
        if not created:
            raise DiscoveryCreateError(
                f"MCP state already exists for user={state.user_id}, session={state.session_id}"
            )

        logger.debug(
            f"Created Redis MCP state: user={state.user_id}, session={state.session_id}, "
            f"TTL={self.ttl}s"
        )

    async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
        """
        Load MCP state from Redis.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            MCP state if exists, None otherwise
        """
        key = self._make_key(user_id, session_id)
        data = await self.redis.get(key)
        if not data:
            return None
        return self._deserialize(data, user_id, session_id)

    async def update_discovery(self, state: McpState) -> None:
        """
        Update existing MCP state in Redis.

        The existence check and the write are a single ``SET ... XX``. This
        prevents an update from resurrecting a key that was concurrently
        deleted or expired between a separate check and the write.

        Args:
            state: Updated MCP state

        Raises:
            DiscoveryUpdateError: If state does not exist
        """
        key = self._make_key(state.user_id, state.session_id)
        data = self._serialize(state)

        # XX: only write when the key already exists; the TTL is refreshed as well.
        updated = await self.redis.set(key, data, ex=self.ttl, xx=True)
        if not updated:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={state.user_id}, session={state.session_id}"
            )

        logger.debug(f"Updated Redis MCP state: user={state.user_id}, session={state.session_id}")

    async def delete_discovery(self, user_id: str, session_id: str) -> None:
        """
        Delete discovery state from Redis.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        key = self._make_key(user_id, session_id)
        await self.redis.delete(key)
        logger.debug(f"Deleted Redis discovery state: user={user_id}, session={session_id}")

    async def mark_completed(self, user_id: str, session_id: str) -> None:
        """
        Mark discovery as completed in Redis using atomic operations.

        If state doesn't exist, auto-creates it with discovery_completed=True
        and empty discovered_servers dict. A warning is logged when auto-creating.

        The flag update is a Lua script (atomic read-modify-write). The
        auto-create path uses ``SET ... NX``, so a state created concurrently
        by another worker (e.g. via store_mcp_session) is never clobbered.
        That state is flagged through the Lua script instead.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        key = self._make_key(user_id, session_id)

        # Lua script for atomic mark_completed operation
        lua_script = """
        local key = KEYS[1]
        local ttl = tonumber(ARGV[1])
        local data = redis.call('GET', key)

        if data then
            -- State exists, update discovery_completed field
            local obj = cjson.decode(data)
            obj.discovery_completed = true
            local updated_data = cjson.encode(obj)
            redis.call('SET', key, updated_data, 'EX', ttl)
            return 1
        else
            -- State doesn't exist, return 0 to signal auto-create
            return 0
        end
        """

        result = await self.redis.eval(lua_script, 1, key, self.ttl)
        if result == 1:
            logger.debug(f"Marked discovery completed: user={user_id}, session={session_id}")
            return

        # Auto-create state if it doesn't exist
        logger.warning(
            f"MCP state not found for user={user_id}, session={session_id}. "
            f"Auto-creating with discovery_completed=True."
        )
        state = McpState(
            user_id=user_id,
            session_id=session_id,
            discovered_servers={},
            discovery_completed=True,
            created_at=datetime.now(UTC),
        )
        data = self._serialize(state)

        # NX: another worker may have created the state after the Lua script
        # ran. A plain SET here would wipe that worker's servers/sessions.
        if await self.redis.set(key, data, ex=self.ttl, nx=True):
            logger.debug(f"Auto-created discovery state: user={user_id}, session={session_id}")
            return

        # Lost the race: the state exists now, so flag it atomically instead.
        result = await self.redis.eval(lua_script, 1, key, self.ttl)
        if result == 1:
            logger.debug(f"Marked discovery completed: user={user_id}, session={session_id}")
        else:
            # The key disappeared again between the two calls (deleted or
            # expired); fall back to writing the fresh completed state.
            await self.redis.set(key, data, ex=self.ttl)
            logger.debug(f"Auto-created discovery state: user={user_id}, session={session_id}")

    async def is_completed(self, user_id: str, session_id: str) -> bool:
        """
        Check if discovery is completed in Redis.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            True if discovery completed, False otherwise
        """
        state = await self.load_discovery(user_id, session_id)
        return state.discovery_completed if state else False

    async def store_mcp_session(
        self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
    ) -> None:
        """
        Store MCP session ID for a server using atomic Lua script.

        Creates a minimal state (and/or the server entry) when missing. The
        session's created_at is preserved if already set; last_used_at is
        always refreshed.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            mcp_session_id: MCP session ID from server
        """
        key = self._make_key(user_id, session_id)

        # Lua script for atomic store operation
        lua_script = """
        local key = KEYS[1]
        local ttl = tonumber(ARGV[1])
        local server_name = ARGV[2]
        local mcp_session_id = ARGV[3]
        local timestamp = ARGV[4]

        local data = redis.call('GET', key)
        local obj

        if data then
            -- State exists, update it
            obj = cjson.decode(data)
        else
            -- State doesn't exist, create minimal state
            obj = {
                user_id = ARGV[5],
                session_id = ARGV[6],
                discovered_servers = {},
                discovery_completed = false,
                created_at = timestamp
            }
        end

        -- Ensure server entry exists
        if not obj.discovered_servers[server_name] then
            obj.discovered_servers[server_name] = {}
        end

        -- Store session data
        if not obj.discovered_servers[server_name].session then
            obj.discovered_servers[server_name].session = {}
        end

        obj.discovered_servers[server_name].session.mcp_session_id = mcp_session_id
        local sess = obj.discovered_servers[server_name].session
        sess.created_at = sess.created_at or timestamp
        sess.last_used_at = timestamp

        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
        """

        timestamp = datetime.now(UTC).isoformat()
        await self.redis.eval(
            lua_script,
            1,
            key,
            self.ttl,
            server_name,
            mcp_session_id,
            timestamp,
            user_id,
            session_id,
        )

        logger.debug(
            f"Stored MCP session {mcp_session_id} for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )

    async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
        """
        Get MCP session ID for a server.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Returns:
            MCP session ID if exists, None otherwise
        """
        state = await self.load_discovery(user_id, session_id)

        if not state:
            return None

        server_data = state.discovered_servers.get(server_name)
        if not server_data:
            return None

        session_bucket = server_data.get("session")
        if not session_bucket:
            return None

        return session_bucket.get("mcp_session_id")

    async def update_session_last_used(
        self, user_id: str, session_id: str, server_name: str
    ) -> None:
        """
        Update last_used timestamp using atomic Lua script.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Raises:
            DiscoveryUpdateError: If state or server doesn't exist
        """
        key = self._make_key(user_id, session_id)

        # Lua script for atomic update
        lua_script = """
        local key = KEYS[1]
        local ttl = tonumber(ARGV[1])
        local server_name = ARGV[2]
        local timestamp = ARGV[3]

        local data = redis.call('GET', key)
        if not data then
            return 0  -- State not found
        end

        local obj = cjson.decode(data)

        if not obj.discovered_servers[server_name] then
            return -1  -- Server not found
        end

        if not obj.discovered_servers[server_name].session then
            obj.discovered_servers[server_name].session = {}
        end
        obj.discovered_servers[server_name].session.last_used_at = timestamp

        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
        """

        timestamp = datetime.now(UTC).isoformat()
        result = await self.redis.eval(lua_script, 1, key, self.ttl, server_name, timestamp)

        if result == 0:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={user_id}, session={session_id}"
            )
        elif result == -1:
            raise DiscoveryUpdateError(
                f"Server {server_name} not found in state for user={user_id}, session={session_id}"
            )

        logger.debug(
            f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
        )

    async def clear_mcp_session(
        self,
        user_id: str,
        session_id: str,
        server_name: str,
        expected_session_id: str | None = None,
    ) -> None:
        """
        Remove stored MCP session info for a server if present.

        Best-effort: a missing state, a missing server entry, or a session ID
        that no longer matches ``expected_session_id`` is only logged at
        debug level, never raised.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            expected_session_id: When given (non-empty), clear only if the
                stored MCP session ID equals it
        """
        key = self._make_key(user_id, session_id)

        lua_script = """
        local key = KEYS[1]
        local server_name = ARGV[1]
        local ttl = tonumber(ARGV[2])
        local expected_session_id = ARGV[3]

        local data = redis.call('GET', key)
        if not data then
            return 0 -- state missing
        end

        local obj = cjson.decode(data)
        if not obj.discovered_servers[server_name] then
            return -1 -- server missing
        end

        -- Only clear if expected matches or no expectation provided
        if obj.discovered_servers[server_name].session then
            local current = obj.discovered_servers[server_name].session.mcp_session_id
            if expected_session_id ~= nil and expected_session_id ~= '' then
                if current ~= expected_session_id then
                    return -2  -- session changed, skip clear
                end
            end
        end

        obj.discovered_servers[server_name].session = nil

        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
        """

        # Empty string tells the Lua script "no expectation".
        expected_arg = expected_session_id or ""
        result = await self.redis.eval(lua_script, 1, key, server_name, self.ttl, expected_arg)
        if result == 0:
            logger.debug(
                f"clear_mcp_session: state missing for user={user_id}, session={session_id}"
            )
        elif result == -1:
            logger.debug(
                f"clear_mcp_session: server missing for user={user_id}, "
                f"session={session_id}, server={server_name}"
            )
        elif result == -2:
            logger.debug(
                f"clear_mcp_session: session changed for user={user_id}, "
                f"session={session_id}, server={server_name}"
            )
        else:
            logger.debug(
                f"Cleared MCP session for server={server_name}, "
                f"user={user_id}, session={session_id}"
            )

    def _serialize(self, state: McpState) -> str:
        """
        Serialize MCP state to JSON.

        Args:
            state: MCP state to serialize

        Returns:
            JSON string representation
        """
        return json.dumps(
            {
                "user_id": state.user_id,
                "session_id": state.session_id,
                "discovered_servers": state.discovered_servers,
                "discovery_completed": state.discovery_completed,
                "created_at": state.created_at.isoformat(),
                "failed_servers": state.failed_servers,
            }
        )

    def _deserialize(self, data: str | bytes, user_id: str, session_id: str) -> McpState:
        """
        Deserialize JSON to MCP state object.

        Args:
            data: JSON string or bytes from Redis
            user_id: User ID (for validation)
            session_id: Session ID (for validation)

        Returns:
            McpState object

        Raises:
            ValueError: If deserialized user_id/session_id don't match parameters
        """
        # Handle bytes from Redis
        if isinstance(data, bytes):
            data = data.decode("utf-8")

        obj = json.loads(data)

        # Validate that serialized data matches the key parameters
        if obj["user_id"] != user_id:
            raise ValueError(
                f"Deserialized user_id '{obj['user_id']}' does not match "
                f"expected user_id '{user_id}'"
            )
        if obj["session_id"] != session_id:
            raise ValueError(
                f"Deserialized session_id '{obj['session_id']}' does not match "
                f"expected session_id '{session_id}'"
            )

        return McpState(
            user_id=user_id,
            session_id=session_id,
            discovered_servers=obj["discovered_servers"],
            discovery_completed=obj["discovery_completed"],
            created_at=datetime.fromisoformat(obj["created_at"]),
            failed_servers=obj.get("failed_servers", {}),
        )
sk_agents.mcp_discovery.RedisStateManager.__init__
__init__(
    app_config: AppConfig, redis_client: Redis | None = None
)

Initialize Redis state manager.

Parameters:

Name Type Description Default
app_config AppConfig

Application configuration for Redis connection

required
redis_client Redis | None

Optional pre-configured Redis client (for testing)

None
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
def __init__(self, app_config: AppConfig, redis_client: Redis | None = None):
    """
    Set up the Redis-backed MCP state manager.

    Args:
        app_config: Application configuration for the Redis connection; it is
            also consulted for the TA_REDIS_TTL expiry override
        redis_client: Optional pre-configured Redis client (for testing). When
            omitted (or falsy), a client is created via ``_create_redis_client()``.
    """
    self.app_config = app_config
    self.redis = redis_client if redis_client else self._create_redis_client()
    self.key_prefix = "mcp_state"

    from sk_agents.configs import TA_REDIS_TTL

    # Keys expire after TA_REDIS_TTL seconds; an unset/empty setting means the
    # 24-hour default (86400 s) for discovery state.
    raw_ttl = self.app_config.get(TA_REDIS_TTL.env_name)
    self.ttl = int(raw_ttl) if raw_ttl else 86400

    logger.debug(f"Redis state manager initialized with TTL={self.ttl}s")
sk_agents.mcp_discovery.RedisStateManager.close async
close() -> None

Close Redis connection and cleanup resources.

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def close(self) -> None:
    """Close the Redis client held by this manager, if one is set."""
    if not self.redis:
        return
    await self.redis.close()
    logger.debug("Redis state manager connection closed")
sk_agents.mcp_discovery.RedisStateManager.__aenter__ async
__aenter__()

Async context manager entry.

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def __aenter__(self):
    """Enter the ``async with`` block; the manager itself is the bound value."""
    manager = self
    return manager
sk_agents.mcp_discovery.RedisStateManager.__aexit__ async
__aexit__(exc_type, exc_val, exc_tb)

Async context manager exit.

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """
    Leave the ``async with`` block by closing the Redis connection.

    Returns None, so an exception raised inside the block is not suppressed.
    """
    await self.close()
    return None
sk_agents.mcp_discovery.RedisStateManager.create_discovery async
create_discovery(state: McpState) -> None

Create initial MCP state in Redis.

Parameters:

Name Type Description Default
state McpState

MCP state to create

required

Raises:

Type Description
DiscoveryCreateError

If state already exists

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def create_discovery(self, state: McpState) -> None:
    """
    Create initial MCP state in Redis.

    The write uses ``SET ... NX``, so the "does it exist?" check and the write
    are one atomic Redis command. Two workers racing to create the same state
    cannot both succeed.

    Args:
        state: MCP state to create

    Raises:
        DiscoveryCreateError: If state already exists
    """
    key = self._make_key(state.user_id, state.session_id)
    data = self._serialize(state)

    # NX: write only if the key is absent. redis-py returns None (falsy) when
    # the key already existed and nothing was written. The TTL is applied
    # in the same command.
    created = await self.redis.set(key, data, ex=self.ttl, nx=True)
    if not created:
        raise DiscoveryCreateError(
            f"MCP state already exists for user={state.user_id}, session={state.session_id}"
        )

    logger.debug(
        f"Created Redis MCP state: user={state.user_id}, session={state.session_id}, "
        f"TTL={self.ttl}s"
    )
sk_agents.mcp_discovery.RedisStateManager.load_discovery async
load_discovery(
    user_id: str, session_id: str
) -> McpState | None

Load MCP state from Redis.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
McpState | None

MCP state if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
    """
    Fetch and decode the MCP state stored for a user/session pair.

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The decoded MCP state, or None when no value (or an empty value) is stored
    """
    raw = await self.redis.get(self._make_key(user_id, session_id))
    return self._deserialize(raw, user_id, session_id) if raw else None
sk_agents.mcp_discovery.RedisStateManager.update_discovery async
update_discovery(state: McpState) -> None

Update existing MCP state in Redis.

Parameters:

Name Type Description Default
state McpState

Updated MCP state

required

Raises:

Type Description
DiscoveryUpdateError

If state does not exist

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def update_discovery(self, state: McpState) -> None:
    """
    Update existing MCP state in Redis.

    The write uses ``SET ... XX``, so the existence check and the overwrite
    are one atomic command. A state that expired or was deleted concurrently
    is therefore never silently re-created. The TTL is refreshed on every
    successful update.

    Args:
        state: Updated MCP state

    Raises:
        DiscoveryUpdateError: If state does not exist
    """
    key = self._make_key(state.user_id, state.session_id)
    data = self._serialize(state)

    # XX: only overwrite an existing key. redis-py returns None (falsy) when
    # the key was absent and nothing was written.
    updated = await self.redis.set(key, data, ex=self.ttl, xx=True)
    if not updated:
        raise DiscoveryUpdateError(
            f"MCP state not found for user={state.user_id}, session={state.session_id}"
        )

    logger.debug(f"Updated Redis MCP state: user={state.user_id}, session={state.session_id}")
sk_agents.mcp_discovery.RedisStateManager.delete_discovery async
delete_discovery(user_id: str, session_id: str) -> None

Delete discovery state from Redis.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def delete_discovery(self, user_id: str, session_id: str) -> None:
    """
    Remove the discovery state stored for a user/session pair.

    The result of the delete is not inspected, so removing a state that does
    not exist is not reported as an error.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    await self.redis.delete(self._make_key(user_id, session_id))
    logger.debug(f"Deleted Redis discovery state: user={user_id}, session={session_id}")
sk_agents.mcp_discovery.RedisStateManager.mark_completed async
mark_completed(user_id: str, session_id: str) -> None

Mark discovery as completed in Redis using atomic operation.

If state doesn't exist, auto-creates it with discovery_completed=True and empty discovered_servers dict. A warning is logged when auto-creating.

Uses Lua script for atomic read-modify-write to prevent race conditions in multi-worker deployments.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def mark_completed(self, user_id: str, session_id: str) -> None:
    """
    Mark discovery as completed in Redis using atomic operation.

    If state doesn't exist, auto-creates it with discovery_completed=True
    and empty discovered_servers dict. A warning is logged when auto-creating.

    Both the update and the auto-create run inside a single Lua script. A
    state written concurrently by another worker is therefore never
    overwritten by the auto-created placeholder, which a separate
    follow-up SET could do.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    key = self._make_key(user_id, session_id)

    # Pre-serialized placeholder. The script writes it only if no state
    # exists at the moment it runs.
    fallback_state = McpState(
        user_id=user_id,
        session_id=session_id,
        discovered_servers={},
        discovery_completed=True,
        created_at=datetime.now(UTC),
    )
    fallback_data = self._serialize(fallback_state)

    # Lua script for atomic mark_completed operation (update or auto-create)
    lua_script = """
    local key = KEYS[1]
    local ttl = tonumber(ARGV[1])
    local fallback = ARGV[2]
    local data = redis.call('GET', key)

    if data then
        -- State exists, update discovery_completed field
        local obj = cjson.decode(data)
        obj.discovery_completed = true
        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
    else
        -- State doesn't exist, create it from the pre-serialized fallback
        redis.call('SET', key, fallback, 'EX', ttl)
        return 0
    end
    """

    result = await self.redis.eval(lua_script, 1, key, self.ttl, fallback_data)

    if result == 1:
        logger.debug(f"Marked discovery completed: user={user_id}, session={session_id}")
    else:
        # The script already wrote the placeholder state atomically.
        logger.warning(
            f"MCP state not found for user={user_id}, session={session_id}. "
            f"Auto-creating with discovery_completed=True."
        )
        logger.debug(f"Auto-created discovery state: user={user_id}, session={session_id}")
sk_agents.mcp_discovery.RedisStateManager.is_completed async
is_completed(user_id: str, session_id: str) -> bool

Check if discovery is completed in Redis.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
bool

True if discovery completed, False otherwise

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def is_completed(self, user_id: str, session_id: str) -> bool:
    """
    Report whether discovery has finished for a user/session pair.

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The stored ``discovery_completed`` flag, or False when no state is stored
    """
    state = await self.load_discovery(user_id, session_id)
    if not state:
        return False
    return state.discovery_completed
sk_agents.mcp_discovery.RedisStateManager.store_mcp_session async
store_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    mcp_session_id: str,
) -> None

Store MCP session ID for a server using atomic Lua script.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
mcp_session_id str

MCP session ID from server

required
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def store_mcp_session(
    self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
) -> None:
    """
    Record the MCP session ID issued by a server, atomically, via a Lua script.

    The script runs as one Redis command. It loads the stored state, or builds
    a minimal one with ``discovery_completed = false`` when none exists. It
    then makes sure ``discovered_servers[server_name].session`` exists, sets
    its ``mcp_session_id`` and ``last_used_at``, and keeps any existing
    ``created_at``. Finally it writes the state back with the configured TTL.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        mcp_session_id: MCP session ID from server
    """
    # KEYS[1] = state key. ARGV = ttl, server_name, mcp_session_id, timestamp,
    # user_id, session_id (the last two are only used when auto-creating).
    script = """
    local key = KEYS[1]
    local ttl = tonumber(ARGV[1])
    local server_name = ARGV[2]
    local mcp_session_id = ARGV[3]
    local timestamp = ARGV[4]

    local data = redis.call('GET', key)
    local obj

    if data then
        -- State exists, update it
        obj = cjson.decode(data)
    else
        -- State doesn't exist, create minimal state
        obj = {
            user_id = ARGV[5],
            session_id = ARGV[6],
            discovered_servers = {},
            discovery_completed = false,
            created_at = timestamp
        }
    end

    -- Ensure server entry exists
    if not obj.discovered_servers[server_name] then
        obj.discovered_servers[server_name] = {}
    end

    -- Store session data
    if not obj.discovered_servers[server_name].session then
        obj.discovered_servers[server_name].session = {}
    end

    obj.discovered_servers[server_name].session.mcp_session_id = mcp_session_id
    local sess = obj.discovered_servers[server_name].session
    sess.created_at = sess.created_at or timestamp
    sess.last_used_at = timestamp

    local updated_data = cjson.encode(obj)
    redis.call('SET', key, updated_data, 'EX', ttl)
    return 1
    """

    now_iso = datetime.now(UTC).isoformat()
    script_args = (self.ttl, server_name, mcp_session_id, now_iso, user_id, session_id)
    await self.redis.eval(script, 1, self._make_key(user_id, session_id), *script_args)

    logger.debug(
        f"Stored MCP session {mcp_session_id} for server={server_name}, "
        f"user={user_id}, session={session_id}"
    )
sk_agents.mcp_discovery.RedisStateManager.get_mcp_session async
get_mcp_session(
    user_id: str, session_id: str, server_name: str
) -> str | None

Get MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Returns:

Type Description
str | None

MCP session ID if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
    """
    Look up the MCP session ID stored for one server.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Returns:
        The stored ``mcp_session_id``, or None if the state, the server entry,
        or its session bucket is missing or empty
    """
    state = await self.load_discovery(user_id, session_id)
    server_data = state.discovered_servers.get(server_name) if state else None
    session_bucket = server_data.get("session") if server_data else None
    return session_bucket.get("mcp_session_id") if session_bucket else None
sk_agents.mcp_discovery.RedisStateManager.update_session_last_used async
update_session_last_used(
    user_id: str, session_id: str, server_name: str
) -> None

Update last_used timestamp using atomic Lua script.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Raises:

Type Description
DiscoveryUpdateError

If state or server doesn't exist

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def update_session_last_used(
    self, user_id: str, session_id: str, server_name: str
) -> None:
    """
    Update last_used timestamp using atomic Lua script.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Raises:
        DiscoveryUpdateError: If state or server doesn't exist
    """
    key = self._make_key(user_id, session_id)

    # Lua script for atomic update
    lua_script = """
    local key = KEYS[1]
    local ttl = tonumber(ARGV[1])
    local server_name = ARGV[2]
    local timestamp = ARGV[3]

    local data = redis.call('GET', key)
    if not data then
        return 0  -- State not found
    end

    local obj = cjson.decode(data)

    if not obj.discovered_servers[server_name] then
        return -1  -- Server not found
    end

    if not obj.discovered_servers[server_name].session then
        obj.discovered_servers[server_name].session = {}
    end
    obj.discovered_servers[server_name].session.last_used_at = timestamp

    local updated_data = cjson.encode(obj)
    redis.call('SET', key, updated_data, 'EX', ttl)
    return 1
    """

    timestamp = datetime.now(UTC).isoformat()
    result = await self.redis.eval(lua_script, 1, key, self.ttl, server_name, timestamp)

    if result == 0:
        raise DiscoveryUpdateError(
            f"MCP state not found for user={user_id}, session={session_id}"
        )
    elif result == -1:
        raise DiscoveryUpdateError(
            f"Server {server_name} not found in state for user={user_id}, session={session_id}"
        )

    logger.debug(
        f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
    )
sk_agents.mcp_discovery.RedisStateManager.clear_mcp_session async
clear_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None

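Remove stored MCP session info for a server if present. If expected_session_id is given and a different session ID is currently stored, the session is left in place.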

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def clear_mcp_session(
    self,
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None:
    """
    Remove stored MCP session info for a server if present.

    Runs as one atomic Lua script. The session is skipped (left in place)
    only when ``expected_session_id`` is given and a different, non-empty
    session ID is currently stored, meaning another session replaced it. A
    session bucket that holds no ID is cleared; this matches the in-memory
    manager's behavior. A missing state or server entry is not an error; the
    outcome is only logged at debug level.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        expected_session_id: If given, only clear when the stored ID matches it
            (or when no ID is stored)
    """
    key = self._make_key(user_id, session_id)

    # Return codes: 1 = cleared, 0 = state missing, -1 = server missing,
    # -2 = a different session is stored (skipped).
    lua_script = """
    local key = KEYS[1]
    local server_name = ARGV[1]
    local ttl = tonumber(ARGV[2])
    local expected_session_id = ARGV[3]

    local data = redis.call('GET', key)
    if not data then
        return 0 -- state missing
    end

    local obj = cjson.decode(data)
    local entry = obj.discovered_servers[server_name]
    -- cjson decodes JSON null to cjson.null, which is truthy in Lua
    if not entry or entry == cjson.null then
        return -1 -- server missing
    end

    -- Skip the clear only if an expectation was provided AND a different,
    -- non-empty session id is currently stored
    local session = entry.session
    if session and session ~= cjson.null
        and expected_session_id ~= nil and expected_session_id ~= '' then
        local current = session.mcp_session_id
        if current and current ~= cjson.null and current ~= expected_session_id then
            return -2  -- session changed, skip clear
        end
    end

    entry.session = nil

    local updated_data = cjson.encode(obj)
    redis.call('SET', key, updated_data, 'EX', ttl)
    return 1
    """

    # An empty string tells the script "no expectation".
    expected_arg = expected_session_id or ""
    result = await self.redis.eval(lua_script, 1, key, server_name, self.ttl, expected_arg)
    if result == 0:
        logger.debug(
            f"clear_mcp_session: state missing for user={user_id}, session={session_id}"
        )
    elif result == -1:
        logger.debug(
            f"clear_mcp_session: server missing for user={user_id}, "
            f"session={session_id}, server={server_name}"
        )
    elif result == -2:
        logger.debug(
            f"clear_mcp_session: session changed for user={user_id}, "
            f"session={session_id}, server={server_name}"
        )
    else:
        logger.debug(
            f"Cleared MCP session for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )
sk_agents.mcp_discovery.discovery_manager_factory

MCP State Manager Factory

Provides a singleton factory for creating MCP state manager instances with dynamic module loading and dependency injection.

Follows the same pattern as PersistenceFactory and AuthStorageFactory.

sk_agents.mcp_discovery.discovery_manager_factory.DiscoveryManagerFactory

Factory for MCP state manager with dependency injection.

Uses singleton pattern to ensure only one factory instance exists. Dynamically loads state manager implementation based on environment variables.

Configuration

- TA_MCP_DISCOVERY_MODULE: Python module containing the manager class
- TA_MCP_DISCOVERY_CLASS: Manager class name

Defaults to InMemoryStateManager for development.

Source code in src/sk_agents/mcp_discovery/discovery_manager_factory.py
class DiscoveryManagerFactory(metaclass=Singleton):
    """
    Factory for MCP state manager with dependency injection.

    Uses singleton pattern to ensure only one factory instance exists.
    Dynamically loads state manager implementation based on
    environment variables.

    Configuration:
        TA_MCP_DISCOVERY_MODULE: Python module containing manager class
        TA_MCP_DISCOVERY_CLASS: Manager class name

    Defaults to InMemoryStateManager for development.
    """

    def __init__(self, app_config: AppConfig):
        """
        Initialize factory with app configuration.

        Args:
            app_config: Application configuration for env vars
        """
        self.app_config = app_config
        # Created lazily by get_discovery_manager().
        self._manager: McpStateManager | None = None  # noqa: F821

    def get_discovery_manager(self) -> "McpStateManager":  # noqa: F821
        """
        Get state manager instance (cached singleton).

        Loads manager implementation on first call based on configuration,
        then caches for subsequent calls. If the configured module/class
        cannot be imported or instantiated, the error is logged (with its
        traceback) and an InMemoryStateManager is used instead.

        Returns:
            McpStateManager instance

        Raises:
            Exception: Only if the InMemoryStateManager fallback itself cannot
                be loaded; the fallback's error is re-raised.
        """
        if self._manager is None:
            import importlib

            # Import here to avoid circular dependency
            from sk_agents.configs import TA_MCP_DISCOVERY_CLASS, TA_MCP_DISCOVERY_MODULE

            module_name = self.app_config.get(TA_MCP_DISCOVERY_MODULE.env_name)
            class_name = self.app_config.get(TA_MCP_DISCOVERY_CLASS.env_name)

            try:
                # import_module returns the leaf module for dotted names, like
                # __import__(name, fromlist=[...]) did, and is the documented API.
                module = importlib.import_module(module_name)
                manager_class = getattr(module, class_name)
                self._manager = manager_class(self.app_config)
                logger.info(f"Initialized MCP state manager: {class_name}")

            except Exception as e:
                # logger.exception keeps the traceback, which is usually needed to
                # diagnose a bad module path or a failing constructor.
                logger.exception(
                    f"Failed to load state manager {class_name} from {module_name}: {e}. "
                    f"Falling back to InMemoryStateManager"
                )

                # Fallback to in-memory implementation
                try:
                    from sk_agents.mcp_discovery.in_memory_discovery_manager import (
                        InMemoryStateManager,
                    )

                    self._manager = InMemoryStateManager(self.app_config)
                    logger.info("Fallback to InMemoryStateManager successful")

                except Exception as fallback_error:
                    logger.critical(
                        f"Failed to load fallback InMemoryStateManager: {fallback_error}"
                    )
                    raise

        return self._manager
sk_agents.mcp_discovery.discovery_manager_factory.DiscoveryManagerFactory.__init__
__init__(app_config: AppConfig)

Initialize factory with app configuration.

Parameters:

Name Type Description Default
app_config AppConfig

Application configuration for env vars

required
Source code in src/sk_agents/mcp_discovery/discovery_manager_factory.py
def __init__(self, app_config: AppConfig):
    """
    Remember the application configuration used to resolve the manager class.

    The manager itself is not built here; it is created lazily on the first
    call to ``get_discovery_manager``.

    Args:
        app_config: Application configuration for env vars
    """
    self._manager: McpStateManager | None = None  # noqa: F821
    self.app_config = app_config
sk_agents.mcp_discovery.discovery_manager_factory.DiscoveryManagerFactory.get_discovery_manager
get_discovery_manager() -> McpStateManager

Get state manager instance (cached singleton).

Loads manager implementation on first call based on configuration, then caches for subsequent calls.

Returns:

Type Description
McpStateManager

McpStateManager instance

Raises:

Type Description
Exception

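Only if the InMemoryStateManager fallback also fails to load. A failure to load the configured manager class is logged and handled by falling back to the in-memory manager.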

Source code in src/sk_agents/mcp_discovery/discovery_manager_factory.py
def get_discovery_manager(self) -> "McpStateManager":  # noqa: F821
    """
    Get state manager instance (cached singleton).

    Loads manager implementation on first call based on configuration,
    then caches for subsequent calls. If the configured module/class
    cannot be imported or instantiated, the error is logged (with its
    traceback) and an InMemoryStateManager is used instead.

    Returns:
        McpStateManager instance

    Raises:
        Exception: Only if the InMemoryStateManager fallback itself cannot
            be loaded; the fallback's error is re-raised.
    """
    if self._manager is None:
        import importlib

        # Import here to avoid circular dependency
        from sk_agents.configs import TA_MCP_DISCOVERY_CLASS, TA_MCP_DISCOVERY_MODULE

        module_name = self.app_config.get(TA_MCP_DISCOVERY_MODULE.env_name)
        class_name = self.app_config.get(TA_MCP_DISCOVERY_CLASS.env_name)

        try:
            # import_module returns the leaf module for dotted names, like
            # __import__(name, fromlist=[...]) did, and is the documented API.
            module = importlib.import_module(module_name)
            manager_class = getattr(module, class_name)
            self._manager = manager_class(self.app_config)
            logger.info(f"Initialized MCP state manager: {class_name}")

        except Exception as e:
            # logger.exception keeps the traceback, which is usually needed to
            # diagnose a bad module path or a failing constructor.
            logger.exception(
                f"Failed to load state manager {class_name} from {module_name}: {e}. "
                f"Falling back to InMemoryStateManager"
            )

            # Fallback to in-memory implementation
            try:
                from sk_agents.mcp_discovery.in_memory_discovery_manager import (
                    InMemoryStateManager,
                )

                self._manager = InMemoryStateManager(self.app_config)
                logger.info("Fallback to InMemoryStateManager successful")

            except Exception as fallback_error:
                logger.critical(
                    f"Failed to load fallback InMemoryStateManager: {fallback_error}"
                )
                raise

    return self._manager
sk_agents.mcp_discovery.in_memory_discovery_manager

In-Memory MCP State Manager

Provides an in-memory implementation for development and testing. Follows the same pattern as InMemoryPersistenceManager.

sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager

Bases: McpStateManager

In-memory implementation of MCP state manager.

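Stores MCP state in memory. Access is serialized with an asyncio.Lock, which is safe across coroutines on one event loop but is not a cross-thread lock. Suitable for:

- Development and testing
- Single-instance deployments
- Scenarios where persistence is not required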

Note: State is lost on server restart.

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
class InMemoryStateManager(McpStateManager):
    """
    In-memory implementation of MCP state manager.

    Stores MCP state in a process-local dict. Access is serialized with an
    ``asyncio.Lock``, which guards against interleaving coroutines on one event
    loop (it is not a cross-thread lock).
    Suitable for:
    - Development and testing
    - Single-instance deployments
    - Scenarios where persistence is not required

    States are deep-copied both when stored (create/update) and when returned
    (load), so callers never share mutable state objects with the store.

    Note: State is lost on server restart.
    """

    def __init__(self, app_config):
        """
        Initialize in-memory state manager.

        Args:
            app_config: Application configuration (for consistency with other managers)
        """
        self.app_config = app_config
        # Storage: {(user_id, session_id): McpState}
        self._storage: dict[tuple[str, str], McpState] = {}
        self._lock = asyncio.Lock()

    def _make_key(self, user_id: str, session_id: str) -> tuple[str, str]:
        """
        Create composite key for storage.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            Tuple of (user_id, session_id)
        """
        return (user_id, session_id)

    async def create_discovery(self, state: McpState) -> None:
        """
        Create initial MCP state.

        A deep copy of ``state`` is stored. Mutating the caller's object
        afterwards does not change the stored state; use update_discovery.

        Args:
            state: MCP state to create

        Raises:
            DiscoveryCreateError: If state already exists
        """
        async with self._lock:
            key = self._make_key(state.user_id, state.session_id)
            if key in self._storage:
                raise DiscoveryCreateError(
                    f"MCP state already exists for user={state.user_id}, session={state.session_id}"
                )
            # Private copy: mirrors load_discovery and the Redis backend, where
            # the caller's object is never shared with the store.
            self._storage[key] = copy.deepcopy(state)
            logger.debug(f"Created MCP state for user={state.user_id}, session={state.session_id}")

    async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
        """
        Load MCP state.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            Deep copy of MCP state if exists, None otherwise.
            Returns a copy to prevent external mutations.
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)
            if state is None:
                return None
            # Return deep copy to prevent external mutations bypassing update_discovery
            return copy.deepcopy(state)

    async def update_discovery(self, state: McpState) -> None:
        """
        Update existing MCP state.

        A deep copy of ``state`` replaces the stored state.

        Args:
            state: Updated MCP state

        Raises:
            DiscoveryUpdateError: If state does not exist
        """
        async with self._lock:
            key = self._make_key(state.user_id, state.session_id)
            if key not in self._storage:
                raise DiscoveryUpdateError(
                    f"MCP state not found for user={state.user_id}, session={state.session_id}"
                )
            self._storage[key] = copy.deepcopy(state)
            logger.debug(f"Updated MCP state for user={state.user_id}, session={state.session_id}")

    async def delete_discovery(self, user_id: str, session_id: str) -> None:
        """
        Delete MCP state. Deleting a missing state is a no-op.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            if key in self._storage:
                del self._storage[key]
                logger.debug(f"Deleted MCP state for user={user_id}, session={session_id}")

    async def mark_completed(self, user_id: str, session_id: str) -> None:
        """
        Mark discovery as completed.

        If state doesn't exist, auto-creates it with discovery_completed=True
        and empty discovered_servers dict. A warning is logged when auto-creating.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            if key in self._storage:
                self._storage[key].discovery_completed = True
                logger.debug(f"Marked discovery completed for user={user_id}, session={session_id}")
            else:
                # Auto-create state if it doesn't exist
                logger.warning(
                    f"MCP state not found for user={user_id}, session={session_id}. "
                    f"Auto-creating with discovery_completed=True."
                )
                state = McpState(
                    user_id=user_id,
                    session_id=session_id,
                    discovered_servers={},
                    discovery_completed=True,
                    created_at=datetime.now(UTC),
                )
                self._storage[key] = state

    async def is_completed(self, user_id: str, session_id: str) -> bool:
        """
        Check if discovery is completed.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            True if discovery completed, False otherwise
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)
            return state.discovery_completed if state else False

    async def store_mcp_session(
        self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
    ) -> None:
        """
        Store MCP session ID for a server.

        Auto-creates the state (with a warning) when none exists. All existing
        keys of the server entry are preserved. ``plugin_data`` is ensured to
        be present (None if it was missing), and the session's ``created_at``
        is kept if it was already set.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            mcp_session_id: MCP session ID from server
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)

            # Auto-create state if doesn't exist
            if not state:
                logger.warning(
                    f"MCP state not found for user={user_id}, session={session_id}. "
                    f"Auto-creating to store session for {server_name}."
                )
                state = McpState(
                    user_id=user_id,
                    session_id=session_id,
                    discovered_servers={},
                    discovery_completed=False,
                    created_at=datetime.now(UTC),
                )
                self._storage[key] = state

            # Reuse the existing entry so unrelated keys survive (as in the Redis
            # backend); tolerate a missing or None entry.
            entry = state.discovered_servers.get(server_name) or {}
            entry.setdefault("plugin_data", None)

            # A missing/empty session bucket is replaced by a fresh one.
            session_bucket = entry.get("session") or {}
            now_iso = datetime.now(UTC).isoformat()
            session_bucket.update(
                {
                    "mcp_session_id": mcp_session_id,
                    "created_at": session_bucket.get("created_at", now_iso),
                    "last_used_at": now_iso,
                }
            )
            entry["session"] = session_bucket
            state.discovered_servers[server_name] = entry

            logger.debug(
                f"Stored MCP session {mcp_session_id} for server={server_name}, "
                f"user={user_id}, session={session_id}"
            )

    async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
        """
        Get MCP session ID for a server.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Returns:
            MCP session ID if exists, None otherwise
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)

            if not state:
                return None

            server_data = state.discovered_servers.get(server_name)
            if not server_data:
                return None

            session_bucket = server_data.get("session")
            if not session_bucket:
                return None
            return session_bucket.get("mcp_session_id")

    async def update_session_last_used(
        self, user_id: str, session_id: str, server_name: str
    ) -> None:
        """
        Update last_used timestamp for an MCP session.

        Creates an empty session bucket if the server entry has none.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Raises:
            DiscoveryUpdateError: If state or server doesn't exist
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)

            if not state:
                raise DiscoveryUpdateError(
                    f"MCP state not found for user={user_id}, session={session_id}"
                )

            if server_name not in state.discovered_servers:
                raise DiscoveryUpdateError(
                    f"Server {server_name} not found in state for "
                    f"user={user_id}, session={session_id}"
                )

            session_bucket = state.discovered_servers[server_name].get("session")
            if not session_bucket:
                session_bucket = {}
            session_bucket["last_used_at"] = datetime.now(UTC).isoformat()
            state.discovered_servers[server_name]["session"] = session_bucket

            logger.debug(
                f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
            )

    async def clear_mcp_session(
        self,
        user_id: str,
        session_id: str,
        server_name: str,
        expected_session_id: str | None = None,
    ) -> None:
        """
        Remove stored MCP session info for a server if present.

        If ``expected_session_id`` is given and a different, non-empty session
        ID is currently stored, the session is left in place because another
        session replaced it. Missing state, server entry, or session is a no-op.
        """
        async with self._lock:
            key = self._make_key(user_id, session_id)
            state = self._storage.get(key)
            if not state:
                return
            entry = state.discovered_servers.get(server_name)
            if not entry:
                return
            if "session" not in entry:
                return
            if expected_session_id:
                # `or {}` tolerates a session value of None.
                current = (entry["session"] or {}).get("mcp_session_id")
                if current and current != expected_session_id:
                    # Another session already replaced it; do not clear
                    return
            entry.pop("session", None)
            logger.debug(
                f"Cleared MCP session for server={server_name}, "
                f"user={user_id}, session={session_id}"
            )
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.__init__
__init__(app_config)

Initialize in-memory state manager.

Parameters:

Name Type Description Default
app_config

Application configuration (for consistency with other managers)

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
def __init__(self, app_config):
    """
    Set up an empty, process-local store for MCP state.

    Args:
        app_config: Application configuration; stored for consistency with
            the other state managers.
    """
    self.app_config = app_config
    # Serializes coroutine access to the storage dict below.
    self._lock = asyncio.Lock()
    # Maps (user_id, session_id) -> McpState
    self._storage: dict[tuple[str, str], McpState] = {}
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.create_discovery async
create_discovery(state: McpState) -> None

Create initial MCP state.

Parameters:

Name Type Description Default
state McpState

MCP state to create

required

Raises:

Type Description
DiscoveryCreateError

If state already exists

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def create_discovery(self, state: McpState) -> None:
    """
    Insert a brand-new MCP state record.

    The record is keyed by ``(state.user_id, state.session_id)``; the object
    passed in is stored as-is (no copy).

    Args:
        state: MCP state to store

    Raises:
        DiscoveryCreateError: If a record for that (user_id, session_id) is
            already stored
    """
    async with self._lock:
        uid, sid = state.user_id, state.session_id
        storage_key = self._make_key(uid, sid)
        if storage_key not in self._storage:
            self._storage[storage_key] = state
            logger.debug(f"Created MCP state for user={uid}, session={sid}")
            return
        raise DiscoveryCreateError(
            f"MCP state already exists for user={uid}, session={sid}"
        )
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.load_discovery async
load_discovery(
    user_id: str, session_id: str
) -> McpState | None

Load MCP state.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
McpState | None

Deep copy of MCP state if exists, None otherwise.

McpState | None

Returns a copy to prevent external mutations.

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
    """
    Fetch the MCP state stored for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        An independent deep copy of the stored state, or None when nothing is
        stored. Callers must go through update_discovery to persist changes.
    """
    async with self._lock:
        found = self._storage.get(self._make_key(user_id, session_id))
        # Copying keeps callers from mutating the stored object behind the lock.
        return None if found is None else copy.deepcopy(found)
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.update_discovery async
update_discovery(state: McpState) -> None

Update existing MCP state.

Parameters:

Name Type Description Default
state McpState

Updated MCP state

required

Raises:

Type Description
DiscoveryUpdateError

If state does not exist

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def update_discovery(self, state: McpState) -> None:
    """
    Replace an already-stored MCP state record.

    NOTE: the object passed in is stored by reference (no copy).

    Args:
        state: Updated MCP state

    Raises:
        DiscoveryUpdateError: If no record exists for
            (state.user_id, state.session_id)
    """
    async with self._lock:
        uid, sid = state.user_id, state.session_id
        storage_key = self._make_key(uid, sid)
        if storage_key in self._storage:
            self._storage[storage_key] = state
            logger.debug(f"Updated MCP state for user={uid}, session={sid}")
            return
        raise DiscoveryUpdateError(f"MCP state not found for user={uid}, session={sid}")
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.delete_discovery async
delete_discovery(user_id: str, session_id: str) -> None

Delete MCP state.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def delete_discovery(self, user_id: str, session_id: str) -> None:
    """
    Drop the MCP state stored for (user_id, session_id), if any.

    Deleting a pair that has no record is a silent no-op.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    async with self._lock:
        storage_key = self._make_key(user_id, session_id)
        try:
            del self._storage[storage_key]
        except KeyError:
            return
        logger.debug(f"Deleted MCP state for user={user_id}, session={session_id}")
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.mark_completed async
mark_completed(user_id: str, session_id: str) -> None

Mark discovery as completed.

If state doesn't exist, auto-creates it with discovery_completed=True and empty discovered_servers dict. A warning is logged when auto-creating.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def mark_completed(self, user_id: str, session_id: str) -> None:
    """
    Flag discovery as finished for (user_id, session_id).

    When no record exists yet, one is created with an empty
    discovered_servers dict and discovery_completed=True, and a warning is
    logged.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    async with self._lock:
        storage_key = self._make_key(user_id, session_id)
        if storage_key not in self._storage:
            # Auto-create state if it doesn't exist
            logger.warning(
                f"MCP state not found for user={user_id}, session={session_id}. "
                f"Auto-creating with discovery_completed=True."
            )
            self._storage[storage_key] = McpState(
                user_id=user_id,
                session_id=session_id,
                discovered_servers={},
                discovery_completed=True,
                created_at=datetime.now(UTC),
            )
            return
        self._storage[storage_key].discovery_completed = True
        logger.debug(f"Marked discovery completed for user={user_id}, session={session_id}")
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.is_completed async
is_completed(user_id: str, session_id: str) -> bool

Check if discovery is completed.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
bool

True if discovery completed, False otherwise

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def is_completed(self, user_id: str, session_id: str) -> bool:
    """
    Report whether discovery has finished for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The stored discovery_completed flag, or False when no record exists
    """
    async with self._lock:
        current = self._storage.get(self._make_key(user_id, session_id))
        if not current:
            return False
        return current.discovery_completed
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.store_mcp_session async
store_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    mcp_session_id: str,
) -> None

Store MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
mcp_session_id str

MCP session ID from server

required
Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def store_mcp_session(
    self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
) -> None:
    """
    Record the MCP session id a server handed out.

    Creates the (user_id, session_id) record if missing, with a warning. The
    server entry is rebuilt to hold only ``plugin_data`` (None when absent)
    and, if one was already present, its ``session`` bucket. The bucket then
    gets the new ``mcp_session_id``, keeps any earlier ``created_at``, and
    gets a fresh ``last_used_at`` (ISO-8601, UTC).

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        mcp_session_id: MCP session ID from server
    """
    async with self._lock:
        storage_key = self._make_key(user_id, session_id)
        state = self._storage.get(storage_key)

        if not state:
            logger.warning(
                f"MCP state not found for user={user_id}, session={session_id}. "
                f"Auto-creating to store session for {server_name}."
            )
            state = McpState(
                user_id=user_id,
                session_id=session_id,
                discovered_servers={},
                discovery_completed=False,
                created_at=datetime.now(UTC),
            )
            self._storage[storage_key] = state

        previous = state.discovered_servers.get(server_name, {})
        rebuilt = {"plugin_data": previous.get("plugin_data")}
        prior_session = previous.get("session")
        if prior_session:
            rebuilt["session"] = prior_session
        state.discovered_servers[server_name] = rebuilt

        bucket = rebuilt.get("session", {})
        stamp = datetime.now(UTC).isoformat()
        first_seen = bucket.get("created_at", stamp)
        bucket.update(
            {
                "mcp_session_id": mcp_session_id,
                "created_at": first_seen,
                "last_used_at": stamp,
            }
        )
        rebuilt["session"] = bucket

        logger.debug(
            f"Stored MCP session {mcp_session_id} for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.get_mcp_session async
get_mcp_session(
    user_id: str, session_id: str, server_name: str
) -> str | None

Get MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Returns:

Type Description
str | None

MCP session ID if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
    """
    Get MCP session ID for a server.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Returns:
        MCP session ID if exists, None otherwise
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        state = self._storage.get(key)

        if not state:
            return None

        server_data = state.discovered_servers.get(server_name)
        if not server_data:
            return None

        session_bucket = server_data.get("session")
        if not session_bucket:
            return None
        return session_bucket.get("mcp_session_id")
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.update_session_last_used async
update_session_last_used(
    user_id: str, session_id: str, server_name: str
) -> None

Update last_used timestamp for an MCP session.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Raises:

Type Description
DiscoveryUpdateError

If state or server doesn't exist

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def update_session_last_used(
    self, user_id: str, session_id: str, server_name: str
) -> None:
    """
    Stamp the server's session bucket with the current UTC time.

    An empty bucket is created when the server entry has none.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Raises:
        DiscoveryUpdateError: If the (user_id, session_id) record or the
            server entry does not exist
    """
    async with self._lock:
        state = self._storage.get(self._make_key(user_id, session_id))
        if not state:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={user_id}, session={session_id}"
            )

        servers = state.discovered_servers
        if server_name not in servers:
            raise DiscoveryUpdateError(
                f"Server {server_name} not found in state for "
                f"user={user_id}, session={session_id}"
            )

        bucket = servers[server_name].get("session") or {}
        bucket["last_used_at"] = datetime.now(UTC).isoformat()
        servers[server_name]["session"] = bucket

        logger.debug(
            f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
        )
sk_agents.mcp_discovery.in_memory_discovery_manager.InMemoryStateManager.clear_mcp_session async
clear_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None

Remove stored MCP session info for a server if present.

Source code in src/sk_agents/mcp_discovery/in_memory_discovery_manager.py
async def clear_mcp_session(
    self,
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None:
    """Remove stored MCP session info for a server if present.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        expected_session_id: If given, only clear when the stored
            ``mcp_session_id`` equals it (or no id is stored). This keeps a
            caller holding a stale id from wiping a session that has already
            been replaced.
    """
    async with self._lock:
        key = self._make_key(user_id, session_id)
        state = self._storage.get(key)
        if not state:
            return
        entry = state.discovered_servers.get(server_name)
        if not entry:
            return
        if "session" not in entry:
            return
        if expected_session_id:
            # The "session" key can be present with a None value; treat that as
            # "no stored id" rather than raising AttributeError.
            session_bucket = entry.get("session") or {}
            current = session_bucket.get("mcp_session_id")
            if current and current != expected_session_id:
                # Another session already replaced it; do not clear
                return
        entry.pop("session", None)
        state.discovered_servers[server_name] = entry
        logger.debug(
            f"Cleared MCP session for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )
sk_agents.mcp_discovery.mcp_discovery_manager

MCP State Manager - Abstract Interface

Provides abstract base class for managing MCP tool discovery and session state. Follows the same pattern as TaskPersistenceManager and SecureAuthStorageManager.

sk_agents.mcp_discovery.mcp_discovery_manager.DiscoveryError

Bases: Exception

Base exception for MCP state manager errors.

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class DiscoveryError(Exception):
    """Root of the exception hierarchy raised by MCP state managers."""
sk_agents.mcp_discovery.mcp_discovery_manager.DiscoveryCreateError

Bases: DiscoveryError

Raised when state creation fails.

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class DiscoveryCreateError(DiscoveryError):
    """Signals that a new MCP state record could not be created."""
sk_agents.mcp_discovery.mcp_discovery_manager.DiscoveryUpdateError

Bases: DiscoveryError

Raised when state update fails.

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class DiscoveryUpdateError(DiscoveryError):
    """Signals that an existing MCP state record could not be updated."""
sk_agents.mcp_discovery.mcp_discovery_manager.McpState

MCP state for a specific user session.

Stores the results of MCP server discovery and session management including: - Which servers have been discovered - Serialized plugin data for each server - MCP session IDs for stateful servers - Completion status

Scoped to (user_id, session_id) for session-level isolation.

Structure of discovered_servers, as written by the state managers' store_mcp_session / update_session_last_used: { "server_name": { "plugin_data": ..., # Serialized plugin data (may be None) "session": { # Optional, present for stateful servers "mcp_session_id": "session-abc123", "created_at": "2025-01-15T10:00:00+00:00", # Session creation timestamp "last_used_at": "2025-01-15T10:30:00+00:00" # Session activity timestamp } } }

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class McpState:
    """
    MCP state for a specific user session.

    Stores the results of MCP server discovery and session management including:
    - Which servers have been discovered
    - Serialized plugin data for each server
    - MCP session IDs for stateful servers
    - Completion status

    Scoped to (user_id, session_id) for session-level isolation.

    Structure of discovered_servers:
    {
        "server_name": {
            "tools": [...],  # Plugin metadata
            "mcp_session_id": "session-abc123",  # Optional, for stateful servers
            "last_used_at": "2025-01-15T10:30:00Z",  # Optional, session activity timestamp
            "created_at": "2025-01-15T10:00:00Z"  # Optional, session creation timestamp
        }
    }
    """

    def __init__(
        self,
        user_id: str,
        session_id: str,
        discovered_servers: dict[str, dict],
        discovery_completed: bool,
        created_at: datetime | None = None,
        failed_servers: dict[str, str] | None = None,
    ):
        """
        Initialize MCP state.

        Args:
            user_id: User ID for authentication and scoping
            session_id: Session ID for conversation grouping
            discovered_servers: Mapping of server_name to plugin data and session info
            discovery_completed: Whether discovery has finished successfully
            created_at: Timestamp of state creation (defaults to now)
            failed_servers: Dictionary of failed servers and their error messages
        """
        self.user_id = user_id
        self.session_id = session_id
        self.discovered_servers = discovered_servers
        self.discovery_completed = discovery_completed
        self.created_at = created_at or datetime.now(UTC)
        self.failed_servers = failed_servers or {}
sk_agents.mcp_discovery.mcp_discovery_manager.McpState.__init__
__init__(
    user_id: str,
    session_id: str,
    discovered_servers: dict[str, dict],
    discovery_completed: bool,
    created_at: datetime | None = None,
    failed_servers: dict[str, str] | None = None,
)

Initialize MCP state.

Parameters:

Name Type Description Default
user_id str

User ID for authentication and scoping

required
session_id str

Session ID for conversation grouping

required
discovered_servers dict[str, dict]

Mapping of server_name to plugin data and session info

required
discovery_completed bool

Whether discovery has finished successfully

required
created_at datetime | None

Timestamp of state creation (defaults to now)

None
failed_servers dict[str, str] | None

Dictionary of failed servers and their error messages

None
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
def __init__(
    self,
    user_id: str,
    session_id: str,
    discovered_servers: dict[str, dict],
    discovery_completed: bool,
    created_at: datetime | None = None,
    failed_servers: dict[str, str] | None = None,
):
    """
    Initialize MCP state.

    Args:
        user_id: User ID for authentication and scoping
        session_id: Session ID for conversation grouping
        discovered_servers: Mapping of server_name to plugin data and session info
        discovery_completed: Whether discovery has finished successfully
        created_at: Timestamp of state creation (defaults to now)
        failed_servers: Dictionary of failed servers and their error messages
    """
    self.user_id = user_id
    self.session_id = session_id
    self.discovered_servers = discovered_servers
    self.discovery_completed = discovery_completed
    self.created_at = created_at or datetime.now(UTC)
    self.failed_servers = failed_servers or {}
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager

Bases: ABC

Abstract interface for MCP state management (discovery + sessions).

Implementations must provide storage for MCP state scoped to (user_id, session_id) combinations. This enables: - Session-level tool isolation - Shared discovery across tasks in the same session - MCP session persistence for stateful servers - External state storage (Redis, in-memory, etc.)

Pattern matches: - TaskPersistenceManager (for task state) - SecureAuthStorageManager (for OAuth tokens)

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
class McpStateManager(ABC):
    """
    Abstract contract for persisting MCP discovery results and MCP sessions.

    Every record is keyed by the pair (user_id, session_id). Keeping state at
    that granularity provides:
    - Tool isolation between sessions
    - Discovery results shared by all tasks within one session
    - Persistence of MCP session ids for stateful servers
    - Pluggable backing stores (Redis, in-memory, etc.)

    Modeled on the same pattern as:
    - TaskPersistenceManager (task state)
    - SecureAuthStorageManager (OAuth tokens)
    """

    @abstractmethod
    async def create_discovery(self, state: McpState) -> None:
        """
        Persist a new record for (state.user_id, state.session_id).

        Args:
            state: MCP state to create

        Raises:
            DiscoveryCreateError: If a record already exists for this (user_id, session_id)
        """
        ...

    @abstractmethod
    async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
        """
        Fetch the record for (user_id, session_id).

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            The stored MCP state, or None when no record exists
        """
        ...

    @abstractmethod
    async def update_discovery(self, state: McpState) -> None:
        """
        Overwrite an existing record with ``state``.

        Args:
            state: Updated MCP state

        Raises:
            DiscoveryUpdateError: If no record exists yet
        """
        ...

    @abstractmethod
    async def delete_discovery(self, user_id: str, session_id: str) -> None:
        """
        Remove the record for (user_id, session_id).

        Args:
            user_id: User ID
            session_id: Session ID
        """
        ...

    @abstractmethod
    async def mark_completed(self, user_id: str, session_id: str) -> None:
        """
        Flag discovery as finished for (user_id, session_id).

        A missing record must be auto-created with an empty discovered_servers
        dict and discovery_completed=True, and a warning logged when that
        happens.

        Must be idempotent: repeated calls have the same effect as a single call.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        ...

    @abstractmethod
    async def is_completed(self, user_id: str, session_id: str) -> bool:
        """
        Report whether discovery has finished for (user_id, session_id).

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            True if discovery completed, False otherwise
        """
        ...

    @abstractmethod
    async def store_mcp_session(
        self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
    ) -> None:
        """
        Record the MCP session id issued by a server.

        Missing records are created, and a missing server entry is added to
        discovered_servers.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            mcp_session_id: MCP session ID from server

        Raises:
            DiscoveryUpdateError: If the state update fails
        """
        ...

    @abstractmethod
    async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
        """
        Look up the MCP session id stored for a server.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Returns:
            The MCP session id, or None when none is stored
        """
        ...

    @abstractmethod
    async def update_session_last_used(
        self, user_id: str, session_id: str, server_name: str
    ) -> None:
        """
        Refresh the last-used timestamp of a server's MCP session.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Raises:
            DiscoveryUpdateError: If the record or the server entry is missing
        """
        ...

    @abstractmethod
    async def clear_mcp_session(
        self,
        user_id: str,
        session_id: str,
        server_name: str,
        expected_session_id: str | None = None,
    ) -> None:
        """
        Forget the stored MCP session for a server, if there is one.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            expected_session_id: Optional session id that must match the stored
                one before it is cleared
        """
        ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.create_discovery abstractmethod async
create_discovery(state: McpState) -> None

Create initial state for (user_id, session_id).

Parameters:

Name Type Description Default
state McpState

MCP state to create

required

Raises:

Type Description
DiscoveryCreateError

If state already exists for this (user_id, session_id)

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def create_discovery(self, state: McpState) -> None:
    """
    Persist a new record for (state.user_id, state.session_id).

    Args:
        state: MCP state to create

    Raises:
        DiscoveryCreateError: If a record already exists for this (user_id, session_id)
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.load_discovery abstractmethod async
load_discovery(
    user_id: str, session_id: str
) -> McpState | None

Load MCP state for (user_id, session_id).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
McpState | None

MCP state if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
    """
    Fetch the record for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The stored MCP state, or None when no record exists
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.update_discovery abstractmethod async
update_discovery(state: McpState) -> None

Update existing MCP state.

Parameters:

Name Type Description Default
state McpState

Updated MCP state

required

Raises:

Type Description
DiscoveryUpdateError

If state does not exist

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def update_discovery(self, state: McpState) -> None:
    """
    Overwrite an existing record with ``state``.

    Args:
        state: Updated MCP state

    Raises:
        DiscoveryUpdateError: If no record exists yet
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.delete_discovery abstractmethod async
delete_discovery(user_id: str, session_id: str) -> None

Delete MCP state for (user_id, session_id).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def delete_discovery(self, user_id: str, session_id: str) -> None:
    """
    Remove the record for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.mark_completed abstractmethod async
mark_completed(user_id: str, session_id: str) -> None

Mark discovery as completed for (user_id, session_id).

If the state does not exist, it will be created automatically with an empty discovered_servers dict and discovery_completed=True. A warning will be logged when auto-creating.

This operation is idempotent - calling it multiple times has the same effect as calling it once.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def mark_completed(self, user_id: str, session_id: str) -> None:
    """
    Flag discovery as finished for (user_id, session_id).

    A missing record must be auto-created with an empty discovered_servers
    dict and discovery_completed=True, and a warning logged when that happens.

    Must be idempotent: repeated calls have the same effect as a single call.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.is_completed abstractmethod async
is_completed(user_id: str, session_id: str) -> bool

Check if discovery is completed for (user_id, session_id).

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required

Returns:

Type Description
bool

True if discovery completed, False otherwise

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def is_completed(self, user_id: str, session_id: str) -> bool:
    """
    Report whether discovery has finished for (user_id, session_id).

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        True if discovery completed, False otherwise
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.store_mcp_session abstractmethod async
store_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    mcp_session_id: str,
) -> None

Store MCP session ID for a server.

If state doesn't exist, it will be created. If server doesn't exist in discovered_servers, it will be added.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
mcp_session_id str

MCP session ID from server

required

Raises:

Type Description
DiscoveryUpdateError

If state update fails

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def store_mcp_session(
    self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
) -> None:
    """
    Record the MCP session id issued by a server.

    Missing records are created, and a missing server entry is added to
    discovered_servers.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        mcp_session_id: MCP session ID from server

    Raises:
        DiscoveryUpdateError: If the state update fails
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.get_mcp_session abstractmethod async
get_mcp_session(
    user_id: str, session_id: str, server_name: str
) -> str | None

Get MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Returns:

Type Description
str | None

MCP session ID if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
    """
    Get MCP session ID for a server.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Returns:
        MCP session ID if exists, None otherwise
    """
    pass
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.update_session_last_used abstractmethod async
update_session_last_used(
    user_id: str, session_id: str, server_name: str
) -> None

Update last_used timestamp for an MCP session.

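Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `user_id` | `str` | User ID | *required* |
| `session_id` | `str` | Teal agent session ID | *required* |
| `server_name` | `str` | Name of the MCP server | *required* |

Raises:

| Type | Description |
|------|-------------|
| `DiscoveryUpdateError` | If state or server doesn't exist |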

Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def update_session_last_used(
    self, user_id: str, session_id: str, server_name: str
) -> None:
    """
    Refresh the last-used timestamp of the MCP session for ``server_name``.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Raises:
        DiscoveryUpdateError: If the state or the server entry is missing
    """
    ...
sk_agents.mcp_discovery.mcp_discovery_manager.McpStateManager.clear_mcp_session abstractmethod async
clear_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None

Clear the stored MCP session for a given server (if present).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `user_id` | `str` | User ID | *required* |
| `session_id` | `str` | Teal agent session ID | *required* |
| `server_name` | `str` | Name of the MCP server | *required* |
| `expected_session_id` | `str \| None` | Optional session ID that must match before clearing | `None` |
Source code in src/sk_agents/mcp_discovery/mcp_discovery_manager.py
@abstractmethod
async def clear_mcp_session(
    self,
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None:
    """
    Clear the stored MCP session for a given server (if present).

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        expected_session_id: Optional session id to match before clearing
    """
    pass
sk_agents.mcp_discovery.redis_discovery_manager

Redis MCP State Manager

Provides Redis-backed implementation for production deployments. Follows the same pattern as Redis persistence and auth storage.

sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager

Bases: McpStateManager

Redis-backed implementation of MCP state manager.

Stores MCP state in Redis for:

- Production deployments
- Multi-instance horizontal scaling
- Persistence across server restarts
- Shared state across distributed systems

Uses the same Redis configuration as other components (TA_REDIS_*).

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
class RedisStateManager(McpStateManager):
    """
    Redis-backed implementation of MCP state manager.

    Stores MCP state in Redis for:
    - Production deployments
    - Multi-instance horizontal scaling
    - Persistence across server restarts
    - Shared state across distributed systems

    Uses the same Redis configuration as other components (TA_REDIS_*).

    Each (user_id, session_id) pair is stored as one JSON string under the key
    ``mcp_state:{user_id}:{session_id}``. Every write sets the key's
    expiration to ``self.ttl`` seconds.

    Writes that may race with other workers are done atomically, either with
    ``SET NX`` / ``SET XX`` or inside a single server-side Lua script.

    NOTE(review): the Lua scripts round-trip the stored JSON through Redis's
    ``cjson`` library, which encodes empty JSON arrays as ``{}``. If any value
    inside ``discovered_servers`` / ``failed_servers`` can be an empty list,
    confirm that readers tolerate getting an empty dict back.
    """

    def __init__(self, app_config: AppConfig, redis_client: Redis | None = None):
        """
        Initialize Redis state manager.

        Args:
            app_config: Application configuration for Redis connection
            redis_client: Optional pre-configured Redis client (for testing).
                When falsy, a client is built from the TA_REDIS_* settings.
        """
        self.app_config = app_config
        self.redis = redis_client or self._create_redis_client()
        self.key_prefix = "mcp_state"

        # TTL support: Default to 24 hours (86400 seconds)
        from sk_agents.configs import TA_REDIS_TTL

        ttl_str = self.app_config.get(TA_REDIS_TTL.env_name)
        if ttl_str:
            self.ttl = int(ttl_str)
        else:
            # Default to 24 hours for discovery state
            self.ttl = 86400

        logger.debug(f"Redis state manager initialized with TTL={self.ttl}s")

    async def close(self) -> None:
        """Close Redis connection and cleanup resources."""
        if self.redis:
            await self.redis.close()
            logger.debug("Redis state manager connection closed")

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit; closes the connection, never suppresses errors."""
        await self.close()

    def _create_redis_client(self) -> Redis:
        """
        Create Redis client from app configuration.

        Reuses existing TA_REDIS_* environment variables for consistency
        with other persistence components.

        Returns:
            Configured Redis client

        Raises:
            ValueError: If required Redis config is missing
        """
        from sk_agents.configs import (
            TA_REDIS_DB,
            TA_REDIS_HOST,
            TA_REDIS_PORT,
            TA_REDIS_PWD,
            TA_REDIS_SSL,
        )

        host = self.app_config.get(TA_REDIS_HOST.env_name)
        port_str = self.app_config.get(TA_REDIS_PORT.env_name)
        db_str = self.app_config.get(TA_REDIS_DB.env_name, default="0")
        ssl_str = self.app_config.get(TA_REDIS_SSL.env_name, default="false")
        pwd = self.app_config.get(TA_REDIS_PWD.env_name, default=None)

        if not host:
            raise ValueError("TA_REDIS_HOST must be configured for Redis discovery manager")
        if not port_str:
            raise ValueError("TA_REDIS_PORT must be configured for Redis discovery manager")

        port = int(port_str)
        db = int(db_str)
        ssl = strtobool(ssl_str)

        logger.info(
            f"Creating Redis discovery client: host={host}, port={port}, db={db}, ssl={ssl}"
        )

        return Redis(host=host, port=port, db=db, ssl=ssl, password=pwd)

    def _make_key(self, user_id: str, session_id: str) -> str:
        """
        Create Redis key for storage.

        Format: mcp_state:{user_id}:{session_id}

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            Redis key string
        """
        return f"{self.key_prefix}:{user_id}:{session_id}"

    async def create_discovery(self, state: McpState) -> None:
        """
        Create initial MCP state in Redis.

        The write uses ``SET ... NX``, so creation is atomic. With the old
        EXISTS-then-SET sequence, two workers could both pass the check and
        the second would silently overwrite the first.

        Args:
            state: MCP state to create

        Raises:
            DiscoveryCreateError: If state already exists
        """
        key = self._make_key(state.user_id, state.session_id)
        data = self._serialize(state)
        # NX: only write when the key is absent. redis-py returns None (falsy)
        # when the write was skipped.
        created = await self.redis.set(key, data, ex=self.ttl, nx=True)
        if not created:
            raise DiscoveryCreateError(
                f"MCP state already exists for user={state.user_id}, session={state.session_id}"
            )

        logger.debug(
            f"Created Redis MCP state: user={state.user_id}, session={state.session_id}, "
            f"TTL={self.ttl}s"
        )

    async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
        """
        Load MCP state from Redis.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            MCP state if exists, None otherwise
        """
        key = self._make_key(user_id, session_id)
        data = await self.redis.get(key)
        if not data:
            return None
        return self._deserialize(data, user_id, session_id)

    async def update_discovery(self, state: McpState) -> None:
        """
        Update existing MCP state in Redis.

        The write uses ``SET ... XX``, so it only succeeds if the key still
        exists at write time. A separate EXISTS check could pass and then the
        key could expire or be deleted before the SET recreated it.

        Args:
            state: Updated MCP state

        Raises:
            DiscoveryUpdateError: If state does not exist
        """
        key = self._make_key(state.user_id, state.session_id)
        data = self._serialize(state)
        # XX: only write when the key exists; also refreshes the TTL.
        updated = await self.redis.set(key, data, ex=self.ttl, xx=True)
        if not updated:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={state.user_id}, session={state.session_id}"
            )

        logger.debug(f"Updated Redis MCP state: user={state.user_id}, session={state.session_id}")

    async def delete_discovery(self, user_id: str, session_id: str) -> None:
        """
        Delete discovery state from Redis.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        key = self._make_key(user_id, session_id)
        await self.redis.delete(key)
        logger.debug(f"Deleted Redis discovery state: user={user_id}, session={session_id}")

    async def mark_completed(self, user_id: str, session_id: str) -> None:
        """
        Mark discovery as completed in Redis using one atomic Lua script.

        If state exists, only its ``discovery_completed`` field is set to true.
        If it doesn't exist, the same script creates it with
        discovery_completed=True and an empty discovered_servers dict, and a
        warning is logged.

        Creating the default state inside the script, rather than with a
        separate SET afterwards, means a state written concurrently by another
        worker (e.g. via ``store_mcp_session``) is never overwritten.

        Args:
            user_id: User ID
            session_id: Session ID
        """
        key = self._make_key(user_id, session_id)

        # Serialized default state; the script writes it only if the key is missing.
        default_state = McpState(
            user_id=user_id,
            session_id=session_id,
            discovered_servers={},
            discovery_completed=True,
            created_at=datetime.now(UTC),
        )
        default_data = self._serialize(default_state)

        # Lua script for atomic mark_completed (update or create) operation
        lua_script = """
        local key = KEYS[1]
        local ttl = tonumber(ARGV[1])
        local default_data = ARGV[2]
        local data = redis.call('GET', key)

        if data then
            -- State exists, update discovery_completed field
            local obj = cjson.decode(data)
            obj.discovery_completed = true
            local updated_data = cjson.encode(obj)
            redis.call('SET', key, updated_data, 'EX', ttl)
            return 1
        end

        -- State doesn't exist: create it within the same atomic script
        redis.call('SET', key, default_data, 'EX', ttl)
        return 0
        """

        result = await self.redis.eval(lua_script, 1, key, self.ttl, default_data)

        if result == 1:
            logger.debug(f"Marked discovery completed: user={user_id}, session={session_id}")
        else:
            logger.warning(
                f"MCP state not found for user={user_id}, session={session_id}. "
                f"Auto-creating with discovery_completed=True."
            )
            logger.debug(f"Auto-created discovery state: user={user_id}, session={session_id}")

    async def is_completed(self, user_id: str, session_id: str) -> bool:
        """
        Check if discovery is completed in Redis.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            True if discovery completed, False otherwise
        """
        state = await self.load_discovery(user_id, session_id)
        return state.discovery_completed if state else False

    async def store_mcp_session(
        self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
    ) -> None:
        """
        Store MCP session ID for a server using atomic Lua script.

        Creates a minimal state (discovery_completed=false) when none exists,
        and the server entry when it is missing. The session's ``created_at``
        is kept if already set; ``last_used_at`` is always refreshed.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            mcp_session_id: MCP session ID from server
        """
        key = self._make_key(user_id, session_id)

        # Lua script for atomic store operation
        lua_script = """
        local key = KEYS[1]
        local ttl = tonumber(ARGV[1])
        local server_name = ARGV[2]
        local mcp_session_id = ARGV[3]
        local timestamp = ARGV[4]

        local data = redis.call('GET', key)
        local obj

        if data then
            -- State exists, update it
            obj = cjson.decode(data)
        else
            -- State doesn't exist, create minimal state
            obj = {
                user_id = ARGV[5],
                session_id = ARGV[6],
                discovered_servers = {},
                discovery_completed = false,
                created_at = timestamp
            }
        end

        -- Ensure server entry exists
        if not obj.discovered_servers[server_name] then
            obj.discovered_servers[server_name] = {}
        end

        -- Store session data
        if not obj.discovered_servers[server_name].session then
            obj.discovered_servers[server_name].session = {}
        end

        obj.discovered_servers[server_name].session.mcp_session_id = mcp_session_id
        local sess = obj.discovered_servers[server_name].session
        sess.created_at = sess.created_at or timestamp
        sess.last_used_at = timestamp

        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
        """

        timestamp = datetime.now(UTC).isoformat()
        await self.redis.eval(
            lua_script,
            1,
            key,
            self.ttl,
            server_name,
            mcp_session_id,
            timestamp,
            user_id,
            session_id,
        )

        logger.debug(
            f"Stored MCP session {mcp_session_id} for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )

    async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
        """
        Get MCP session ID for a server.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Returns:
            MCP session ID if exists, None otherwise
        """
        state = await self.load_discovery(user_id, session_id)

        if not state:
            return None

        server_data = state.discovered_servers.get(server_name)
        if not server_data:
            return None

        session_bucket = server_data.get("session")
        if not session_bucket:
            return None

        return session_bucket.get("mcp_session_id")

    async def update_session_last_used(
        self, user_id: str, session_id: str, server_name: str
    ) -> None:
        """
        Update last_used timestamp using atomic Lua script.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server

        Raises:
            DiscoveryUpdateError: If state or server doesn't exist
        """
        key = self._make_key(user_id, session_id)

        # Lua script for atomic update
        lua_script = """
        local key = KEYS[1]
        local ttl = tonumber(ARGV[1])
        local server_name = ARGV[2]
        local timestamp = ARGV[3]

        local data = redis.call('GET', key)
        if not data then
            return 0  -- State not found
        end

        local obj = cjson.decode(data)

        if not obj.discovered_servers[server_name] then
            return -1  -- Server not found
        end

        if not obj.discovered_servers[server_name].session then
            obj.discovered_servers[server_name].session = {}
        end
        obj.discovered_servers[server_name].session.last_used_at = timestamp

        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
        """

        timestamp = datetime.now(UTC).isoformat()
        result = await self.redis.eval(lua_script, 1, key, self.ttl, server_name, timestamp)

        if result == 0:
            raise DiscoveryUpdateError(
                f"MCP state not found for user={user_id}, session={session_id}"
            )
        elif result == -1:
            raise DiscoveryUpdateError(
                f"Server {server_name} not found in state for user={user_id}, session={session_id}"
            )

        logger.debug(
            f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
        )

    async def clear_mcp_session(
        self,
        user_id: str,
        session_id: str,
        server_name: str,
        expected_session_id: str | None = None,
    ) -> None:
        """
        Remove stored MCP session info for a server if present.

        A missing state, a missing server entry, or a stored session ID that
        differs from ``expected_session_id`` (when one is given) are all
        treated as no-ops and only logged at debug level.

        Args:
            user_id: User ID
            session_id: Teal agent session ID
            server_name: Name of the MCP server
            expected_session_id: Optional session id to match before clearing
        """
        key = self._make_key(user_id, session_id)

        lua_script = """
        local key = KEYS[1]
        local server_name = ARGV[1]
        local ttl = tonumber(ARGV[2])
        local expected_session_id = ARGV[3]

        local data = redis.call('GET', key)
        if not data then
            return 0 -- state missing
        end

        local obj = cjson.decode(data)
        if not obj.discovered_servers[server_name] then
            return -1 -- server missing
        end

        -- Only clear if expected matches or no expectation provided
        if obj.discovered_servers[server_name].session then
            local current = obj.discovered_servers[server_name].session.mcp_session_id
            if expected_session_id ~= nil and expected_session_id ~= '' then
                if current ~= expected_session_id then
                    return -2  -- session changed, skip clear
                end
            end
        end

        obj.discovered_servers[server_name].session = nil

        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
        """

        # Empty string tells the script "no expectation" (ARGV values are strings).
        expected_arg = expected_session_id or ""
        result = await self.redis.eval(lua_script, 1, key, server_name, self.ttl, expected_arg)
        if result == 0:
            logger.debug(
                f"clear_mcp_session: state missing for user={user_id}, session={session_id}"
            )
        elif result == -1:
            logger.debug(
                f"clear_mcp_session: server missing for user={user_id}, "
                f"session={session_id}, server={server_name}"
            )
        elif result == -2:
            logger.debug(
                f"clear_mcp_session: session changed for user={user_id}, "
                f"session={session_id}, server={server_name}"
            )
        else:
            logger.debug(
                f"Cleared MCP session for server={server_name}, "
                f"user={user_id}, session={session_id}"
            )

    def _serialize(self, state: McpState) -> str:
        """
        Serialize MCP state to JSON.

        Args:
            state: MCP state to serialize

        Returns:
            JSON string representation
        """
        return json.dumps(
            {
                "user_id": state.user_id,
                "session_id": state.session_id,
                "discovered_servers": state.discovered_servers,
                "discovery_completed": state.discovery_completed,
                "created_at": state.created_at.isoformat(),
                "failed_servers": state.failed_servers,
            }
        )

    def _deserialize(self, data: str | bytes, user_id: str, session_id: str) -> McpState:
        """
        Deserialize JSON to MCP state object.

        Args:
            data: JSON string or bytes from Redis
            user_id: User ID (for validation)
            session_id: Session ID (for validation)

        Returns:
            McpState object

        Raises:
            ValueError: If deserialized user_id/session_id don't match parameters
        """
        # Handle bytes from Redis
        if isinstance(data, bytes):
            data = data.decode("utf-8")

        obj = json.loads(data)

        # Validate that serialized data matches the key parameters
        if obj["user_id"] != user_id:
            raise ValueError(
                f"Deserialized user_id '{obj['user_id']}' does not match "
                f"expected user_id '{user_id}'"
            )
        if obj["session_id"] != session_id:
            raise ValueError(
                f"Deserialized session_id '{obj['session_id']}' does not match "
                f"expected session_id '{session_id}'"
            )

        return McpState(
            user_id=user_id,
            session_id=session_id,
            discovered_servers=obj["discovered_servers"],
            discovery_completed=obj["discovery_completed"],
            created_at=datetime.fromisoformat(obj["created_at"]),
            # States created by store_mcp_session's Lua script lack this key.
            failed_servers=obj.get("failed_servers", {}),
        )
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.__init__
__init__(
    app_config: AppConfig, redis_client: Redis | None = None
)

Initialize Redis state manager.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `app_config` | `AppConfig` | Application configuration for Redis connection | *required* |
| `redis_client` | `Redis \| None` | Optional pre-configured Redis client (for testing) | `None` |
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
def __init__(self, app_config: AppConfig, redis_client: Redis | None = None):
    """
    Set up the Redis-backed MCP state manager.

    Args:
        app_config: Application configuration for Redis connection
        redis_client: Optional pre-configured Redis client (for testing).
            When falsy, a client is built from the TA_REDIS_* settings.
    """
    self.app_config = app_config
    self.redis = redis_client or self._create_redis_client()
    self.key_prefix = "mcp_state"

    from sk_agents.configs import TA_REDIS_TTL

    # An unset/empty TA_REDIS_TTL falls back to 24 hours (86400 seconds).
    configured_ttl = self.app_config.get(TA_REDIS_TTL.env_name)
    self.ttl = int(configured_ttl) if configured_ttl else 86400

    logger.debug(f"Redis state manager initialized with TTL={self.ttl}s")
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.close async
close() -> None

Close Redis connection and cleanup resources.

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def close(self) -> None:
    """Release the Redis connection; does nothing when no client is set."""
    client = self.redis
    if not client:
        return
    await client.close()
    logger.debug("Redis state manager connection closed")
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.__aenter__ async
__aenter__()

Async context manager entry.

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def __aenter__(self):
    """Async context manager entry."""
    return self
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.__aexit__ async
__aexit__(exc_type, exc_val, exc_tb)

Async context manager exit.

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context manager exit."""
    await self.close()
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.create_discovery async
create_discovery(state: McpState) -> None

Create initial MCP state in Redis.

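Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `McpState` | MCP state to create | *required* |

Raises:

| Type | Description |
|------|-------------|
| `DiscoveryCreateError` | If state already exists |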

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def create_discovery(self, state: McpState) -> None:
    """
    Create initial MCP state in Redis.

    The write uses ``SET ... NX``, so creation is atomic. With the old
    EXISTS-then-SET sequence, two workers could both pass the check and the
    second would silently overwrite the first.

    Args:
        state: MCP state to create

    Raises:
        DiscoveryCreateError: If state already exists
    """
    key = self._make_key(state.user_id, state.session_id)
    data = self._serialize(state)
    # NX: only write when the key is absent. redis-py returns None (falsy)
    # when the write was skipped.
    created = await self.redis.set(key, data, ex=self.ttl, nx=True)
    if not created:
        raise DiscoveryCreateError(
            f"MCP state already exists for user={state.user_id}, session={state.session_id}"
        )

    logger.debug(
        f"Created Redis MCP state: user={state.user_id}, session={state.session_id}, "
        f"TTL={self.ttl}s"
    )
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.load_discovery async
load_discovery(
    user_id: str, session_id: str
) -> McpState | None

Load MCP state from Redis.

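Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `user_id` | `str` | User ID | *required* |
| `session_id` | `str` | Session ID | *required* |

Returns:

| Type | Description |
|------|-------------|
| `McpState \| None` | MCP state if exists, None otherwise |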

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def load_discovery(self, user_id: str, session_id: str) -> McpState | None:
    """
    Fetch the MCP state stored for ``(user_id, session_id)``.

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The deserialized MCP state, or None when the key holds no data
        (missing key or empty value).
    """
    raw = await self.redis.get(self._make_key(user_id, session_id))
    return self._deserialize(raw, user_id, session_id) if raw else None
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.update_discovery async
update_discovery(state: McpState) -> None

Update existing MCP state in Redis.

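Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `McpState` | Updated MCP state | *required* |

Raises:

| Type | Description |
|------|-------------|
| `DiscoveryUpdateError` | If state does not exist |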

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def update_discovery(self, state: McpState) -> None:
    """
    Update existing MCP state in Redis.

    The write uses ``SET ... XX``, so it only succeeds if the key still exists
    at write time. A separate EXISTS check could pass and then the key could
    expire or be deleted before the SET recreated it.

    Args:
        state: Updated MCP state

    Raises:
        DiscoveryUpdateError: If state does not exist
    """
    key = self._make_key(state.user_id, state.session_id)
    data = self._serialize(state)
    # XX: only write when the key exists; also refreshes the TTL.
    updated = await self.redis.set(key, data, ex=self.ttl, xx=True)
    if not updated:
        raise DiscoveryUpdateError(
            f"MCP state not found for user={state.user_id}, session={state.session_id}"
        )

    logger.debug(f"Updated Redis MCP state: user={state.user_id}, session={state.session_id}")
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.delete_discovery async
delete_discovery(user_id: str, session_id: str) -> None

Delete discovery state from Redis.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `user_id` | `str` | User ID | *required* |
| `session_id` | `str` | Session ID | *required* |
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def delete_discovery(self, user_id: str, session_id: str) -> None:
    """
    Remove the stored discovery state for ``(user_id, session_id)``.

    The result of the delete is not inspected, so removing a key that does
    not exist is not treated as an error.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    await self.redis.delete(self._make_key(user_id, session_id))
    logger.debug(f"Deleted Redis discovery state: user={user_id}, session={session_id}")
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.mark_completed async
mark_completed(user_id: str, session_id: str) -> None

Mark discovery as completed in Redis using atomic operation.

If state doesn't exist, auto-creates it with discovery_completed=True and empty discovered_servers dict. A warning is logged when auto-creating.

Uses Lua script for atomic read-modify-write to prevent race conditions in multi-worker deployments.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `user_id` | `str` | User ID | *required* |
| `session_id` | `str` | Session ID | *required* |
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def mark_completed(self, user_id: str, session_id: str) -> None:
    """
    Mark discovery as completed in Redis using one atomic Lua script.

    If state exists, only its ``discovery_completed`` field is set to true.
    If it doesn't exist, the same script creates it with
    discovery_completed=True and an empty discovered_servers dict, and a
    warning is logged.

    Creating the default state inside the script, rather than with a separate
    SET afterwards, means a state written concurrently by another worker
    (e.g. via ``store_mcp_session``) is never overwritten.

    Args:
        user_id: User ID
        session_id: Session ID
    """
    key = self._make_key(user_id, session_id)

    # Serialized default state; the script writes it only if the key is missing.
    default_state = McpState(
        user_id=user_id,
        session_id=session_id,
        discovered_servers={},
        discovery_completed=True,
        created_at=datetime.now(UTC),
    )
    default_data = self._serialize(default_state)

    # Lua script for atomic mark_completed (update or create) operation
    lua_script = """
    local key = KEYS[1]
    local ttl = tonumber(ARGV[1])
    local default_data = ARGV[2]
    local data = redis.call('GET', key)

    if data then
        -- State exists, update discovery_completed field
        local obj = cjson.decode(data)
        obj.discovery_completed = true
        local updated_data = cjson.encode(obj)
        redis.call('SET', key, updated_data, 'EX', ttl)
        return 1
    end

    -- State doesn't exist: create it within the same atomic script
    redis.call('SET', key, default_data, 'EX', ttl)
    return 0
    """

    result = await self.redis.eval(lua_script, 1, key, self.ttl, default_data)

    if result == 1:
        logger.debug(f"Marked discovery completed: user={user_id}, session={session_id}")
    else:
        logger.warning(
            f"MCP state not found for user={user_id}, session={session_id}. "
            f"Auto-creating with discovery_completed=True."
        )
        logger.debug(f"Auto-created discovery state: user={user_id}, session={session_id}")
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.is_completed async
is_completed(user_id: str, session_id: str) -> bool

Check if discovery is completed in Redis.

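Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `user_id` | `str` | User ID | *required* |
| `session_id` | `str` | Session ID | *required* |

Returns:

| Type | Description |
|------|-------------|
| `bool` | True if discovery completed, False otherwise |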

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def is_completed(self, user_id: str, session_id: str) -> bool:
    """
    Report whether MCP discovery has finished for a user's session.

    Args:
        user_id: User ID
        session_id: Session ID

    Returns:
        The stored state's discovery_completed flag, or False when no
        state is stored for this user/session pair.
    """
    loaded = await self.load_discovery(user_id, session_id)
    if not loaded:
        # Nothing persisted yet, so discovery cannot have completed.
        return False
    return loaded.discovery_completed
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.store_mcp_session async
store_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    mcp_session_id: str,
) -> None

Store MCP session ID for a server using atomic Lua script.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required
mcp_session_id str

MCP session ID from server

required
Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def store_mcp_session(
    self, user_id: str, session_id: str, server_name: str, mcp_session_id: str
) -> None:
    """
    Persist the MCP session ID for one server atomically via a Lua script.

    The script runs inside Redis, so the read-modify-write cannot interleave
    with other workers. When no state document exists yet, it bootstraps a
    minimal one (user_id, session_id, empty discovered_servers,
    discovery_completed=false, created_at). It creates the server entry and
    its session bucket on demand. It records mcp_session_id, keeps an existing
    session created_at (or sets it), stamps last_used_at, and rewrites the key
    with the configured TTL.

    NOTE(review): the state is decoded and re-encoded with Redis' Lua cjson.
    Confirm that empty JSON arrays in the stored state survive that round
    trip (cjson may re-encode empty tables as objects).

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server
        mcp_session_id: MCP session ID from server
    """
    script = """
    local key = KEYS[1]
    local ttl = tonumber(ARGV[1])
    local server_name = ARGV[2]
    local mcp_session_id = ARGV[3]
    local timestamp = ARGV[4]

    local data = redis.call('GET', key)
    local obj

    if data then
        -- State exists, update it
        obj = cjson.decode(data)
    else
        -- State doesn't exist, create minimal state
        obj = {
            user_id = ARGV[5],
            session_id = ARGV[6],
            discovered_servers = {},
            discovery_completed = false,
            created_at = timestamp
        }
    end

    -- Ensure server entry exists
    if not obj.discovered_servers[server_name] then
        obj.discovered_servers[server_name] = {}
    end

    -- Store session data
    if not obj.discovered_servers[server_name].session then
        obj.discovered_servers[server_name].session = {}
    end

    obj.discovered_servers[server_name].session.mcp_session_id = mcp_session_id
    local sess = obj.discovered_servers[server_name].session
    sess.created_at = sess.created_at or timestamp
    sess.last_used_at = timestamp

    local updated_data = cjson.encode(obj)
    redis.call('SET', key, updated_data, 'EX', ttl)
    return 1
    """

    redis_key = self._make_key(user_id, session_id)
    now_iso = datetime.now(UTC).isoformat()

    # ARGV order must match the script: ttl, server, MCP session id, timestamp,
    # then the user/session ids used only when bootstrapping a new state.
    script_args = (self.ttl, server_name, mcp_session_id, now_iso, user_id, session_id)
    await self.redis.eval(script, 1, redis_key, *script_args)

    logger.debug(
        f"Stored MCP session {mcp_session_id} for server={server_name}, "
        f"user={user_id}, session={session_id}"
    )
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.get_mcp_session async
get_mcp_session(
    user_id: str, session_id: str, server_name: str
) -> str | None

Get MCP session ID for a server.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Returns:

Type Description
str | None

MCP session ID if exists, None otherwise

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def get_mcp_session(self, user_id: str, session_id: str, server_name: str) -> str | None:
    """
    Get MCP session ID for a server.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Returns:
        MCP session ID if exists, None otherwise
    """
    state = await self.load_discovery(user_id, session_id)

    if not state:
        return None

    server_data = state.discovered_servers.get(server_name)
    if not server_data:
        return None

    session_bucket = server_data.get("session")
    if not session_bucket:
        return None

    return session_bucket.get("mcp_session_id")
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.update_session_last_used async
update_session_last_used(
    user_id: str, session_id: str, server_name: str
) -> None

Update the last_used_at timestamp of a server's MCP session using an atomic Lua script.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Teal agent session ID

required
server_name str

Name of the MCP server

required

Raises:

Type Description
DiscoveryUpdateError

If state or server doesn't exist

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def update_session_last_used(
    self, user_id: str, session_id: str, server_name: str
) -> None:
    """
    Refresh a server's session last_used_at timestamp atomically via Lua.

    The script runs inside Redis so the read-modify-write is atomic. It
    creates the session bucket if the server entry has none, sets
    last_used_at, and rewrites the key with the configured TTL.

    Args:
        user_id: User ID
        session_id: Teal agent session ID
        server_name: Name of the MCP server

    Raises:
        DiscoveryUpdateError: If no state is stored for the user/session
            (script returns 0) or the server has no entry in
            discovered_servers (script returns -1).
    """
    script = """
    local key = KEYS[1]
    local ttl = tonumber(ARGV[1])
    local server_name = ARGV[2]
    local timestamp = ARGV[3]

    local data = redis.call('GET', key)
    if not data then
        return 0  -- State not found
    end

    local obj = cjson.decode(data)

    if not obj.discovered_servers[server_name] then
        return -1  -- Server not found
    end

    if not obj.discovered_servers[server_name].session then
        obj.discovered_servers[server_name].session = {}
    end
    obj.discovered_servers[server_name].session.last_used_at = timestamp

    local updated_data = cjson.encode(obj)
    redis.call('SET', key, updated_data, 'EX', ttl)
    return 1
    """

    redis_key = self._make_key(user_id, session_id)
    now_iso = datetime.now(UTC).isoformat()
    status = await self.redis.eval(script, 1, redis_key, self.ttl, server_name, now_iso)

    # Translate the script's sentinel return codes into exceptions.
    if status == 0:
        raise DiscoveryUpdateError(
            f"MCP state not found for user={user_id}, session={session_id}"
        )
    if status == -1:
        raise DiscoveryUpdateError(
            f"Server {server_name} not found in state for user={user_id}, session={session_id}"
        )

    logger.debug(
        f"Updated last_used for server={server_name}, user={user_id}, session={session_id}"
    )
sk_agents.mcp_discovery.redis_discovery_manager.RedisStateManager.clear_mcp_session async
clear_mcp_session(
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None

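Remove the stored MCP session info for a server, if present. When expected_session_id is given, the session is cleared only if its stored MCP session ID still matches it.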

Source code in src/sk_agents/mcp_discovery/redis_discovery_manager.py
async def clear_mcp_session(
    self,
    user_id: str,
    session_id: str,
    server_name: str,
    expected_session_id: str | None = None,
) -> None:
    """Remove stored MCP session info for a server if present."""
    key = self._make_key(user_id, session_id)

    lua_script = """
    local key = KEYS[1]
    local server_name = ARGV[1]
    local ttl = tonumber(ARGV[2])
    local expected_session_id = ARGV[3]

    local data = redis.call('GET', key)
    if not data then
        return 0 -- state missing
    end

    local obj = cjson.decode(data)
    if not obj.discovered_servers[server_name] then
        return -1 -- server missing
    end

    -- Only clear if expected matches or no expectation provided
    if obj.discovered_servers[server_name].session then
        local current = obj.discovered_servers[server_name].session.mcp_session_id
        if expected_session_id ~= nil and expected_session_id ~= '' then
            if current ~= expected_session_id then
                return -2  -- session changed, skip clear
            end
        end
    end

    obj.discovered_servers[server_name].session = nil

    local updated_data = cjson.encode(obj)
    redis.call('SET', key, updated_data, 'EX', ttl)
    return 1
    """

    expected_arg = expected_session_id or ""
    result = await self.redis.eval(lua_script, 1, key, server_name, self.ttl, expected_arg)
    if result == 0:
        logger.debug(
            f"clear_mcp_session: state missing for user={user_id}, session={session_id}"
        )
    elif result == -1:
        logger.debug(
            f"clear_mcp_session: server missing for user={user_id}, "
            f"session={session_id}, server={server_name}"
        )
    elif result == -2:
        logger.debug(
            f"clear_mcp_session: session changed for user={user_id}, "
            f"session={session_id}, server={server_name}"
        )
    else:
        logger.debug(
            f"Cleared MCP session for server={server_name}, "
            f"user={user_id}, session={session_id}"
        )
sk_agents.mcp_plugin_registry

MCP Plugin Registry - Discovers and stores MCP tools at session start.

This registry discovers MCP tools and stores them in external state. At request time, tools are loaded from state and used to instantiate McpPlugin directly in kernel_builder.

sk_agents.mcp_plugin_registry.McpPluginRegistry

Registry for MCP tools with per-session isolation.

At session start, this registry:

1. Connects to MCP servers temporarily
2. Discovers available tools
3. Registers tools in the catalog for governance/HITL
4. Serializes tool data to external storage (via McpStateManager)

At request time:

- Tools are loaded from storage via get_tools_for_session()
- kernel_builder instantiates McpPlugin directly with these tools

This ensures proper multi-tenant isolation and horizontal scalability. Tool state is stored externally (Redis/InMemory) instead of class variables.

Source code in src/sk_agents/mcp_plugin_registry.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
class McpPluginRegistry:
    """
    Registry for MCP tools with per-session isolation.

    At session start, this registry:
    1. Connects to MCP servers temporarily
    2. Discovers available tools
    3. Registers tools in catalog for governance/HITL
    4. Serializes tool data to external storage (via McpStateManager)

    At request time:
    - Tools are loaded from storage via get_tools_for_session()
    - kernel_builder instantiates McpPlugin directly with these tools

    This ensures proper multi-tenant isolation and horizontal scalability.
    Tool state is stored externally (Redis/InMemory) instead of class variables.
    """

    @staticmethod
    def _apply_governance_overrides(
        base_governance: Governance, tool_name: str, overrides: dict[str, GovernanceOverride] | None
    ) -> Governance:
        """Apply manual governance overrides from config.

        Each override field that is not None replaces the matching field of
        ``base_governance``; fields left as None keep the base value. Returns
        ``base_governance`` unchanged when there is no override for the tool.
        """
        if not overrides or tool_name not in overrides:
            return base_governance

        override = overrides[tool_name]

        return Governance(
            requires_hitl=override.requires_hitl
            if override.requires_hitl is not None
            else base_governance.requires_hitl,
            cost=override.cost if override.cost is not None else base_governance.cost,
            data_sensitivity=override.data_sensitivity
            if override.data_sensitivity is not None
            else base_governance.data_sensitivity,
        )

    @staticmethod
    def _create_auth_if_needed(server_config: McpServerConfig) -> Oauth2PluginAuth | None:
        """Create auth config if server requires OAuth2.

        Returns an Oauth2PluginAuth only when both auth_server and scopes are
        set on the server config; otherwise None.
        """
        if server_config.auth_server and server_config.scopes:
            return Oauth2PluginAuth(
                auth_server=server_config.auth_server, scopes=server_config.scopes
            )
        return None

    @classmethod
    async def discover_and_materialize(
        cls,
        mcp_servers: list[McpServerConfig],
        user_id: str,
        session_id: str,
        discovery_manager,  # McpStateManager
        app_config,
    ) -> None:
        """
        Discover MCP tools and store in external state.

        This is called once per session when first invoked.
        Creates temporary connections to discover tools, then closes them.
        Servers that fail for reasons other than auth are recorded in
        ``state.failed_servers`` and skipped; a server that is discovered
        successfully has any earlier failure entry removed.

        Args:
            mcp_servers: List of MCP server configurations
            user_id: User ID for authentication
            session_id: Session ID for scoping
            discovery_manager: Manager for storing discovery state
            app_config: Application config, passed through to auth resolution

        Raises:
            ValueError: If no discovery state has been initialized for the session
            AuthRequiredError: If any server requires authentication that is missing
        """
        import traceback

        from sk_agents.mcp_client import AuthRequiredError

        logger.info(f"Starting MCP discovery for session {session_id} ({len(mcp_servers)} servers)")

        # Load existing state
        state = await discovery_manager.load_discovery(user_id, session_id)
        if not state:
            raise ValueError(f"Discovery state not initialized for session: {session_id}")

        auth_errors = []  # Collect auth errors to surface to user

        for server_config in mcp_servers:
            try:
                # Discover this server
                plugin_data, discovered_session_id = await cls._discover_server(
                    server_config, user_id, session_id, discovery_manager, app_config
                )

                # Preserve any existing session bucket
                existing_entry = state.discovered_servers.get(server_config.name, {})
                session_bucket = existing_entry.get("session") or {}

                # Always persist freshly discovered plugin data
                state.discovered_servers[server_config.name] = {
                    "plugin_data": plugin_data,
                    **({"session": session_bucket} if session_bucket else {}),
                }
                # A successful discovery supersedes any failure recorded earlier.
                state.failed_servers.pop(server_config.name, None)
                await discovery_manager.update_discovery(state)

                # If discovery yielded a session id, persist via state manager API.
                # The helper returns a state that already contains the stored
                # session, so later update_discovery(state) calls do not drop it.
                if discovered_session_id:
                    state = await cls._persist_discovered_session(
                        state,
                        discovery_manager,
                        user_id,
                        session_id,
                        server_config.name,
                        discovered_session_id,
                    )

            except AuthRequiredError as e:
                # Auth error - collect and surface to user
                logger.warning(
                    f"Auth required for MCP server {server_config.name} (session: {session_id})"
                )
                auth_errors.append(e)
            except Exception as e:
                # Other errors - log and continue with remaining servers
                error_details = "".join(traceback.format_exception(type(e), e, e.__traceback__))

                # Surface the chained cause or ExceptionGroup members (e.g. from a
                # TaskGroup) so the log shows the real failure, not just the wrapper.
                underlying_error = str(e)
                if hasattr(e, "__cause__") and e.__cause__:
                    underlying_error = f"{e} (caused by: {e.__cause__})"
                elif hasattr(e, "exceptions"):
                    # ExceptionGroup-style
                    underlying_error = f"{e} (sub-exceptions: {e.exceptions})"

                logger.error(
                    f"Failed to discover MCP server {server_config.name} "
                    f"for session {session_id}:\n"
                    f"Error: {underlying_error}\n"
                    f"Full traceback:\n{error_details}"
                )

                # Capture failure in state
                state.failed_servers[server_config.name] = underlying_error
                try:
                    await discovery_manager.update_discovery(state)
                except Exception as update_err:
                    logger.error(
                        f"Failed to persist discovery error for {server_config.name}: {update_err}"
                    )

                continue

        # If any servers require auth, raise the first one to trigger auth challenge
        if auth_errors:
            logger.info(
                f"MCP discovery requires auth for {len(auth_errors)} server(s) "
                f"(session: {session_id}): {[e.server_name for e in auth_errors]}"
            )
            raise auth_errors[0]  # Raise first auth error to trigger challenge

        logger.info(
            f"MCP discovery complete for session {session_id}. "
            f"Discovered {len(state.discovered_servers)} servers"
        )

    @classmethod
    async def _persist_discovered_session(
        cls,
        state,
        discovery_manager,
        user_id: str,
        session_id: str,
        server_name: str,
        mcp_session_id: str,
    ):
        """Store a discovered MCP session ID and return a state that reflects it.

        store_mcp_session / update_session_last_used write to the backing store
        directly, not to the in-memory ``state`` held by the caller. The caller
        later writes that object back with update_discovery(state), which
        presumably persists the whole state (verify in the manager
        implementations). The state is therefore reloaded after a successful
        store so the new session is not overwritten by a stale copy. If the
        reload fails, the session id is mirrored into the local state instead.
        Persistence failures are logged, never raised (best-effort, as before).
        """
        try:
            await discovery_manager.store_mcp_session(
                user_id, session_id, server_name, mcp_session_id
            )
        except Exception as err:
            logger.warning(f"Failed to persist MCP session for {server_name}: {err}")
            return state

        try:
            await discovery_manager.update_session_last_used(user_id, session_id, server_name)
        except Exception as err:
            logger.warning(f"Failed to persist MCP session for {server_name}: {err}")

        refreshed = None
        try:
            refreshed = await discovery_manager.load_discovery(user_id, session_id)
        except Exception as err:
            logger.warning(
                f"Failed to reload MCP state after storing session for {server_name}: {err}"
            )

        if refreshed is not None:
            return refreshed

        # Fallback: keep the session id in the local copy so the next write keeps it.
        entry = state.discovered_servers.setdefault(server_name, {})
        entry.setdefault("session", {})["mcp_session_id"] = mcp_session_id
        return state

    @classmethod
    async def _discover_server(
        cls,
        server_config: McpServerConfig,
        user_id: str,
        session_id: str,
        discovery_manager,
        app_config,
    ) -> tuple[dict, str | None]:
        """
        Discover tools from a single MCP server.

        Validates auth first, opens a temporary MCP connection (reusing a stored
        MCP session id when the discovery manager has one), lists the tools,
        registers each in the plugin catalog, and serializes them.

        Returns:
            Tuple: (Serialized plugin data, optional mcp_session_id)

        Raises:
            AuthRequiredError: Propagated from auth resolution
        """
        # Imported locally, as in discover_and_materialize, so the except clause
        # below does not depend on a module-level import.
        from sk_agents.mcp_client import AuthRequiredError

        logger.info(f"Discovering tools from MCP server: {server_config.name}")

        # Pre-flight auth validation using unified resolver (handles refresh/audience)
        try:
            await resolve_server_auth_headers(
                server_config,
                user_id=user_id,
                app_config=app_config,
            )
            logger.info(f"Auth verified for {server_config.name}, proceeding with discovery")
        except AuthRequiredError:
            raise
        except Exception as e:
            logger.error(f"Auth resolution failed for {server_config.name}: {e}")
            raise

        # Temporary connection for discovery
        async with AsyncExitStack() as stack:
            stored_session_id = None
            if discovery_manager:
                try:
                    stored_session_id = await discovery_manager.get_mcp_session(
                        user_id, session_id, server_config.name
                    )
                except Exception:
                    logger.debug("Unable to fetch stored MCP session id for discovery")

            # The conditional wraps the whole lambda: without a manager there is
            # nothing to clear, so no callback (None) is passed at all.
            on_stale_session = (
                (
                    lambda sid: discovery_manager.clear_mcp_session(
                        user_id, session_id, server_config.name, expected_session_id=sid
                    )
                )
                if discovery_manager
                else None
            )

            # Create temp connection (reuse session id if available)
            session, get_session_id = await create_mcp_session_with_retry(
                server_config,
                stack,
                user_id,
                mcp_session_id=stored_session_id,
                on_stale_session=on_stale_session,
            )

            # List available tools
            tools_result = await session.list_tools()
            logger.info(f"Found {len(tools_result.tools)} tools on {server_config.name}")

            # Create stateless McpTool objects
            mcp_tools = []
            for tool_info in tools_result.tools:
                mcp_tool = McpTool(
                    tool_name=tool_info.name,
                    description=tool_info.description,
                    input_schema=tool_info.inputSchema,
                    output_schema=getattr(tool_info, "outputSchema", None),
                    server_config=server_config,
                    server_name=server_config.name,
                )
                mcp_tools.append(mcp_tool)

                # Register in catalog for governance/HITL
                cls._register_tool_in_catalog(tool_info, server_config)

            # Serialize plugin data for storage
            plugin_data = cls._serialize_plugin_data(mcp_tools, server_config.name)

            session_identifier = get_session_id() if get_session_id else None

            logger.info(f"Discovered {len(mcp_tools)} tools from {server_config.name}")
            # Connection auto-closes when exiting context

            return plugin_data, session_identifier

    @classmethod
    def _register_tool_in_catalog(cls, tool_info: Any, server_config: McpServerConfig) -> None:
        """Register tool in catalog for governance and HITL.

        Best-effort: any exception is logged and swallowed so that a catalog
        problem does not fail discovery of the server.
        """
        try:
            catalog = PluginCatalogFactory().get_catalog()
            if not catalog:
                logger.warning("Plugin catalog not available, skipping catalog registration")
                return

            # Create consistent tool_id format: mcp_{server_name}_{tool_name}
            tool_id = f"mcp_{server_config.name}_{tool_info.name}"

            # Map MCP annotations to governance.
            # Newer MCP SDKs return a ToolAnnotations object without dict-like access.
            annotations_obj = getattr(tool_info, "annotations", None)
            if annotations_obj is None:
                annotations = {}
            elif hasattr(annotations_obj, "model_dump"):
                annotations = annotations_obj.model_dump() or {}
            elif isinstance(annotations_obj, dict):
                annotations = annotations_obj
            else:
                # Best-effort fallback for unknown types
                annotations = {}

            base_governance = map_mcp_annotations_to_governance(annotations)
            governance_with_trust = apply_trust_level_governance(
                base_governance, server_config.trust_level, tool_info.description or ""
            )

            # Apply manual overrides from config
            governance = cls._apply_governance_overrides(
                governance_with_trust, tool_info.name, server_config.tool_governance_overrides
            )

            # Create auth config if needed
            auth = cls._create_auth_if_needed(server_config)

            # Create PluginTool for catalog
            plugin_tool = PluginTool(
                tool_id=tool_id,
                name=tool_info.name,
                description=tool_info.description,
                governance=governance,
                auth=auth,
            )

            # Register in catalog
            plugin_id = f"mcp_{server_config.name}"
            catalog.register_dynamic_tool(plugin_tool, plugin_id=plugin_id)

            logger.debug(
                f"Registered tool in catalog: {tool_id} (requires_hitl={governance.requires_hitl})"
            )

        except Exception as e:
            logger.error(f"Failed to register tool {tool_info.name} in catalog: {e}")
            # Don't fail the whole discovery if catalog registration fails

    @classmethod
    def _serialize_plugin_data(cls, tools: list[McpTool], server_name: str) -> dict:
        """
        Serialize plugin tools to storable format.

        Args:
            tools: List of McpTool objects
            server_name: Name of the MCP server

        Returns:
            Dict: Serialized plugin data of the form
            {"server_name": ..., "tools": [tool dicts]}
        """

        def _sanitize_server_config(server_config):
            """Drop secrets before persisting discovery state."""
            cfg = server_config.model_dump()

            # Remove confidential OAuth client secret
            cfg.pop("oauth_client_secret", None)

            # Strip Authorization headers to avoid token leakage
            headers = cfg.get("headers") or {}
            cfg["headers"] = {k: v for k, v in headers.items() if k.lower() != "authorization"}

            # Drop env entries that look sensitive (best‑effort)
            env = cfg.get("env")
            if isinstance(env, dict):
                cfg["env"] = {
                    k: v
                    for k, v in env.items()
                    if not any(s in k.lower() for s in ["secret", "token", "key", "password"])
                }

            return cfg

        tools_data = []
        for tool in tools:
            tools_data.append(
                {
                    "tool_name": tool.tool_name,
                    "description": tool.description,
                    "input_schema": tool.input_schema,
                    "output_schema": tool.output_schema,
                    "server_name": tool.server_name,
                    "server_config": _sanitize_server_config(tool.server_config),
                }
            )
        return {"server_name": server_name, "tools": tools_data}

    @classmethod
    def _deserialize_tools(cls, plugin_data: dict) -> list[McpTool]:
        """
        Deserialize plugin data to McpTool list.

        Args:
            plugin_data: Serialized plugin data from storage

        Returns:
            List of McpTool objects
        """
        from sk_agents.tealagents.v1alpha1.config import McpServerConfig

        tools = []
        for tool_data in plugin_data["tools"]:
            server_config = McpServerConfig(**tool_data["server_config"])
            tool = McpTool(
                tool_name=tool_data["tool_name"],
                description=tool_data["description"],
                input_schema=tool_data["input_schema"],
                output_schema=tool_data["output_schema"],
                server_config=server_config,
                server_name=tool_data["server_name"],
            )
            tools.append(tool)

        return tools

    @staticmethod
    def _extract_plugin_data(entry: Any):
        """Return the serialized plugin data held by a discovered_servers entry, or None.

        Current entries keep it under "plugin_data"; legacy entries are the plugin
        data themselves (they carry "tools"). A dict entry with neither, such as
        one holding only a "session" bucket, has no tools and yields None.
        Non-dict entries are returned unchanged for _deserialize_tools to handle.
        """
        if not isinstance(entry, dict):
            return entry
        plugin_blob = entry.get("plugin_data")
        if plugin_blob:
            return plugin_blob
        return entry if "tools" in entry else None

    @classmethod
    async def get_tools_for_session(
        cls,
        user_id: str,
        session_id: str,
        discovery_manager,  # McpStateManager
    ) -> dict[str, list[McpTool]]:
        """
        Load MCP tools from external storage for this session.

        Returns an empty dict until discovery is marked completed. Entries
        without any tool data (e.g. a server that only has a stored session)
        are skipped with a warning instead of aborting the whole load.

        Args:
            user_id: User ID
            session_id: Session ID
            discovery_manager: Manager for loading discovery state

        Returns:
            Dictionary mapping server_name to list of McpTool objects
        """
        # Load state from external storage
        state = await discovery_manager.load_discovery(user_id, session_id)
        if not state or not state.discovery_completed:
            return {}

        # Deserialize tools for each server
        server_tools: dict[str, list[McpTool]] = {}
        for server_name, entry in state.discovered_servers.items():
            plugin_data = cls._extract_plugin_data(entry)
            if plugin_data is None:
                logger.warning(
                    f"Skipping MCP server {server_name} for session {session_id}: "
                    f"no stored tool data"
                )
                continue
            server_tools[server_name] = cls._deserialize_tools(plugin_data)

        logger.debug(f"Loaded tools for {len(server_tools)} MCP servers for session {session_id}")
        return server_tools
sk_agents.mcp_plugin_registry.McpPluginRegistry.discover_and_materialize async classmethod
discover_and_materialize(
    mcp_servers: list[McpServerConfig],
    user_id: str,
    session_id: str,
    discovery_manager,
    app_config,
) -> None

Discover MCP tools and store in external state.

This is called once per session when first invoked. Creates temporary connections to discover tools, then closes them.

Parameters:

Name Type Description Default
mcp_servers list[McpServerConfig]

List of MCP server configurations

required
user_id str

User ID for authentication

required
session_id str

Session ID for scoping

required
discovery_manager

Manager for storing discovery state

required

Raises:

Type Description
AuthRequiredError

If any server requires authentication that is missing

Source code in src/sk_agents/mcp_plugin_registry.py
@classmethod
async def discover_and_materialize(
    cls,
    mcp_servers: list[McpServerConfig],
    user_id: str,
    session_id: str,
    discovery_manager,  # McpStateManager
    app_config,
) -> None:
    """
    Discover MCP tools and store in external state.

    This is called once per session when first invoked.
    Creates temporary connections to discover tools, then closes them.

    Servers are processed one at a time:

    - On success, the server's plugin data is written to
      ``state.discovered_servers``, keeping any existing ``"session"`` bucket.
      Any earlier entry in ``state.failed_servers`` is dropped, and the state
      is persisted immediately.
    - On a non-auth failure, the error is logged, recorded in
      ``state.failed_servers``, and the loop moves on.
    - On an ``AuthRequiredError``, the error is collected. After all servers
      have been tried, the first collected auth error is re-raised.

    Args:
        mcp_servers: List of MCP server configurations
        user_id: User ID for authentication
        session_id: Session ID for scoping
        discovery_manager: Manager for loading and storing discovery state
        app_config: Application configuration, forwarded to ``_discover_server``

    Raises:
        ValueError: If discovery state has not been initialized for the session
        AuthRequiredError: If any server requires authentication that is missing
    """
    import traceback

    from sk_agents.mcp_client import AuthRequiredError

    logger.info(f"Starting MCP discovery for session {session_id} ({len(mcp_servers)} servers)")

    # Load existing state
    state = await discovery_manager.load_discovery(user_id, session_id)
    if not state:
        raise ValueError(f"Discovery state not initialized for session: {session_id}")

    auth_errors = []  # Collect auth errors to surface to user

    for server_config in mcp_servers:
        try:
            # Discover this server
            plugin_data, discovered_session_id = await cls._discover_server(
                server_config, user_id, session_id, discovery_manager, app_config
            )

            # Preserve any existing session bucket. Legacy entries store the
            # plugin data directly instead of a dict wrapper (the same shape
            # get_tools_for_session falls back to), so only a dict entry can
            # carry a "session" key.
            existing_entry = state.discovered_servers.get(server_config.name)
            if isinstance(existing_entry, dict):
                session_bucket = existing_entry.get("session") or {}
            else:
                session_bucket = {}

            # Always persist freshly discovered plugin data
            state.discovered_servers[server_config.name] = {
                "plugin_data": plugin_data,
                **({"session": session_bucket} if session_bucket else {}),
            }
            # A successful discovery supersedes a failure recorded for this
            # server by an earlier run against the same state.
            state.failed_servers.pop(server_config.name, None)
            await discovery_manager.update_discovery(state)

            # If discovery yielded a session id, persist via state manager API
            if discovered_session_id:
                try:
                    await discovery_manager.store_mcp_session(
                        user_id,
                        session_id,
                        server_config.name,
                        discovered_session_id,
                    )
                    await discovery_manager.update_session_last_used(
                        user_id, session_id, server_config.name
                    )
                except Exception as err:
                    # Best effort: tools are already persisted above.
                    logger.warning(
                        f"Failed to persist MCP session for {server_config.name}: {err}"
                    )

        except AuthRequiredError as e:
            # Auth error - collect and surface to user
            logger.warning(
                f"Auth required for MCP server {server_config.name} (session: {session_id})"
            )
            auth_errors.append(e)
        except Exception as e:
            # Other errors - log and continue with remaining servers
            error_details = "".join(traceback.format_exception(type(e), e, e.__traceback__))

            # Surface the chained cause, or the sub-exceptions of an
            # ExceptionGroup-style error (e.g. from a TaskGroup), in the message.
            underlying_error = str(e)
            if e.__cause__:
                underlying_error = f"{e} (caused by: {e.__cause__})"
            elif hasattr(e, "exceptions"):
                underlying_error = f"{e} (sub-exceptions: {e.exceptions})"

            logger.error(
                f"Failed to discover MCP server {server_config.name} "
                f"for session {session_id}:\n"
                f"Error: {underlying_error}\n"
                f"Full traceback:\n{error_details}"
            )

            # Capture failure in state
            state.failed_servers[server_config.name] = underlying_error
            try:
                await discovery_manager.update_discovery(state)
            except Exception as update_err:
                logger.error(
                    f"Failed to persist discovery error for {server_config.name}: {update_err}"
                )

            continue

    # If any servers require auth, raise the first one to trigger auth challenge
    if auth_errors:
        logger.info(
            f"MCP discovery requires auth for {len(auth_errors)} server(s) "
            f"(session: {session_id}): {[e.server_name for e in auth_errors]}"
        )
        raise auth_errors[0]  # Raise first auth error to trigger challenge

    logger.info(
        f"MCP discovery complete for session {session_id}. "
        f"Discovered {len(state.discovered_servers)} servers"
    )
sk_agents.mcp_plugin_registry.McpPluginRegistry.get_tools_for_session async classmethod
get_tools_for_session(
    user_id: str, session_id: str, discovery_manager
) -> dict[str, list[McpTool]]

Load MCP tools from external storage for this session.

Parameters:

Name Type Description Default
user_id str

User ID

required
session_id str

Session ID

required
discovery_manager

Manager for loading discovery state

required

Returns:

Type Description
dict[str, list[McpTool]]

Dictionary mapping server_name to list of McpTool objects

Source code in src/sk_agents/mcp_plugin_registry.py
@classmethod
async def get_tools_for_session(
    cls,
    user_id: str,
    session_id: str,
    discovery_manager,  # McpStateManager
) -> dict[str, list[McpTool]]:
    """
    Rebuild the per-server MCP tool lists persisted for a session.

    Returns an empty mapping until discovery state exists and is flagged
    ``discovery_completed``. Each ``discovered_servers`` entry is either a
    dict wrapping the serialized plugin under ``"plugin_data"`` or, for older
    records, the serialized plugin itself.

    Args:
        user_id: User ID
        session_id: Session ID
        discovery_manager: Manager for loading discovery state

    Returns:
        Dictionary mapping server_name to list of McpTool objects
    """
    state = await discovery_manager.load_discovery(user_id, session_id)
    if not (state and state.discovery_completed):
        return {}

    def _payload_of(entry):
        # Prefer the wrapped payload; fall back to the entry itself (legacy
        # shape) when it is missing or empty.
        wrapped = entry.get("plugin_data") if isinstance(entry, dict) else None
        return wrapped or entry

    server_tools = {
        server_name: cls._deserialize_tools(_payload_of(entry))
        for server_name, entry in state.discovered_servers.items()
    }

    logger.debug(f"Loaded tools for {len(server_tools)} MCP servers for session {session_id}")
    return server_tools
sk_agents.persistence
sk_agents.persistence.custom
sk_agents.persistence.custom.example_redis_persistence

Complete Redis Task Persistence Implementation

This example demonstrates a full-featured, production-ready Redis-based task persistence implementation. It serves as a complete alternative to the default in-memory storage.

To use this implementation, set the following environment variables:

TA_PERSISTENCE_MODULE=src/sk_agents/persistence/custom/example_redis_persistence.py
TA_PERSISTENCE_CLASS=RedisTaskPersistenceManager

Redis configuration environment variables:

- TA_REDIS_HOST (default: localhost)
- TA_REDIS_PORT (default: 6379)
- TA_REDIS_DB (default: 0)
- TA_REDIS_TTL (default: 3600 seconds)
- TA_REDIS_PWD (optional)
- TA_REDIS_SSL (default: false)

sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager

Bases: TaskPersistenceManager

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
class RedisTaskPersistenceManager(TaskPersistenceManager):
    """Redis-backed ``TaskPersistenceManager``.

    Each task is stored as a JSON string under ``task_persistence:task:<task_id>``
    with a TTL of ``self.ttl`` seconds. For every item of a task, the task id
    is added to the set ``task_persistence:request_index:<request_id>``
    (refreshed to the same TTL), so tasks can be found by request id.

    NOTE(review): the client calls below are not awaited, so each Redis round
    trip blocks the event loop inside these ``async`` methods.
    """

    def __init__(self, app_config: AppConfig | None = None):
        """
        Initialize the Redis-based task persistence manager.

        Reads the connection settings from ``app_config``, builds the client,
        and verifies connectivity with a PING.

        Args:
            app_config: Application configuration object. If None, creates a new one.

        Raises:
            ConnectionError: If the initial PING fails with ``redis.ConnectionError``.
        """
        if app_config is None:
            app_config = AppConfig()

        self.app_config = app_config
        self._lock = threading.Lock()

        # Get Redis configuration, falling back to defaults when unset
        redis_host = self.app_config.get(TA_REDIS_HOST.env_name) or "localhost"
        redis_port = int(self.app_config.get(TA_REDIS_PORT.env_name) or 6379)
        redis_db = int(self.app_config.get(TA_REDIS_DB.env_name) or 0)
        redis_password = self.app_config.get(TA_REDIS_PWD.env_name)
        # TLS is opt-in: only an explicit "true" (any case) enables it.
        ssl_setting = self.app_config.get(TA_REDIS_SSL.env_name)
        redis_ssl = str(ssl_setting or "").strip().lower() == "true"
        self.ttl = int(self.app_config.get(TA_REDIS_TTL.env_name) or 3600)  # Default 1 hour

        # Initialize Redis client
        self.redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            db=redis_db,
            password=redis_password,
            ssl=redis_ssl,
            decode_responses=True,  # Automatically decode responses to strings
            socket_connect_timeout=5,
            socket_timeout=5,
            retry_on_timeout=True,
        )

        # Fail fast if the server is unreachable
        try:
            self.redis_client.ping()
        except redis.ConnectionError as e:
            raise ConnectionError(f"Failed to connect to Redis: {e}") from e

    def _get_task_key(self, task_id: str) -> str:
        """Generate a Redis key for the given task_id."""
        return f"task_persistence:task:{task_id}"

    def _get_request_index_key(self, request_id: str) -> str:
        """Generate a Redis key for request_id index."""
        return f"task_persistence:request_index:{request_id}"

    def _serialize_task(self, task: AgentTask) -> str:
        """Serialize AgentTask to JSON string."""
        return task.model_dump_json()

    def _deserialize_task(self, task_str: str) -> AgentTask:
        """Deserialize JSON string to AgentTask."""
        task_dict = json.loads(task_str)
        return AgentTask.model_validate(task_dict)

    async def create(self, task: AgentTask) -> None:
        """Store a new task and index it by each item's request_id.

        Args:
            task: The task to persist; its ``task_id`` must not already exist.

        Raises:
            PersistenceCreateError: If the task already exists, Redis reports
                an error, or any other failure occurs while storing it.
        """
        try:
            task_key = self._get_task_key(task.task_id)

            # Refuse to overwrite an existing task.
            # NOTE(review): exists + setex is not atomic; SET NX EX would be.
            if self.redis_client.exists(task_key):
                raise PersistenceCreateError(
                    message=f"Task with ID '{task.task_id}' already exists."
                )

            # Serialize and store the task
            serialized_task = self._serialize_task(task)
            self.redis_client.setex(task_key, self.ttl, serialized_task)

            # Update request_id indexes
            for item in task.items:
                request_index_key = self._get_request_index_key(item.request_id)
                self.redis_client.sadd(request_index_key, task.task_id)
                self.redis_client.expire(request_index_key, self.ttl)

        except PersistenceCreateError:
            # Already descriptive; don't re-wrap it as an "unexpected" error.
            raise
        except redis.RedisError as e:
            raise PersistenceCreateError(
                message=f"Failed to create task '{task.task_id}' in Redis: {e}"
            ) from e
        except Exception as e:
            raise PersistenceCreateError(
                message=f"Unexpected error creating task '{task.task_id}': {e}"
            ) from e

    async def load(self, task_id: str) -> AgentTask | None:
        """Load a task from Redis by task_id.

        Returns:
            The deserialized task, or None if no task is stored under the id.

        Raises:
            PersistenceLoadError: On Redis errors, or when the stored payload
                cannot be decoded (the corrupted key is deleted first, best effort).
        """
        with self._lock:
            try:
                task_key = self._get_task_key(task_id)
                task_str = self.redis_client.get(task_key)

                if task_str is None:
                    return None

                return self._deserialize_task(task_str)

            except redis.RedisError as e:
                raise PersistenceLoadError(
                    message=f"Failed to load task '{task_id}' from Redis: {e}"
                ) from e
            except (json.JSONDecodeError, ValueError) as e:
                # If we can't deserialize the task, it's corrupted, so delete it
                try:
                    task_key = self._get_task_key(task_id)
                    self.redis_client.delete(task_key)
                except redis.RedisError:
                    pass  # Best effort: the load error below is what matters
                raise PersistenceLoadError(
                    message=f"Corrupted task data found for task_id {task_id}: {e}"
                ) from e

    async def update(self, task: AgentTask) -> None:
        """Replace an existing task and re-point its request_id indexes.

        Raises:
            PersistenceUpdateError: If the task does not exist, Redis reports
                an error, or any other failure occurs while updating it.
        """
        try:
            task_key = self._get_task_key(task.task_id)

            # Check if task exists
            old_task_str = self.redis_client.get(task_key)
            if old_task_str is None:
                raise PersistenceUpdateError(
                    message=f"Task with ID '{task.task_id}' does not exist for update."
                )

            # Deserialize old task to clean up old request_id indexes
            old_task = self._deserialize_task(old_task_str)

            # Remove old request_id associations
            for item in old_task.items:
                request_index_key = self._get_request_index_key(item.request_id)
                self.redis_client.srem(request_index_key, task.task_id)

            # Update the task
            serialized_task = self._serialize_task(task)
            self.redis_client.setex(task_key, self.ttl, serialized_task)

            # Add new request_id associations
            for item in task.items:
                request_index_key = self._get_request_index_key(item.request_id)
                self.redis_client.sadd(request_index_key, task.task_id)
                self.redis_client.expire(request_index_key, self.ttl)

        except PersistenceUpdateError:
            # Already descriptive; don't re-wrap it as an "unexpected" error.
            raise
        except redis.RedisError as e:
            raise PersistenceUpdateError(
                message=f"Failed to update task '{task.task_id}' in Redis: {e}"
            ) from e
        except Exception as e:
            raise PersistenceUpdateError(
                message=f"Unexpected error updating task '{task.task_id}': {e}"
            ) from e

    async def delete(self, task_id: str) -> None:
        """Delete a task and remove it from its request_id indexes.

        Raises:
            PersistenceDeleteError: If the task does not exist, Redis reports
                an error, or any other failure occurs while deleting it.
        """
        try:
            task_key = self._get_task_key(task_id)

            # Get the task first to clean up request_id indexes
            task_str = self.redis_client.get(task_key)
            if task_str is None:
                raise PersistenceDeleteError(
                    message=f"Task with ID '{task_id}' does not exist for deletion."
                )

            task = self._deserialize_task(task_str)

            # Remove from request_id indexes
            for item in task.items:
                request_index_key = self._get_request_index_key(item.request_id)
                self.redis_client.srem(request_index_key, task_id)

            # Delete the task
            self.redis_client.delete(task_key)

        except PersistenceDeleteError:
            # Already descriptive; don't re-wrap it as an "unexpected" error.
            raise
        except redis.RedisError as e:
            raise PersistenceDeleteError(
                message=f"Failed to delete task '{task_id}' from Redis: {e}"
            ) from e
        except Exception as e:
            raise PersistenceDeleteError(
                message=f"Unexpected error deleting task '{task_id}': {e}"
            ) from e

    async def load_by_request_id(self, request_id: str) -> AgentTask | None:
        """Load a task indexed under ``request_id``.

        Returns:
            A task whose items reference the request id, or None if none is
            indexed. If several tasks are indexed, an arbitrary one is loaded.

        Raises:
            PersistenceLoadError: On Redis errors or any failure while loading.
        """
        try:
            request_index_key = self._get_request_index_key(request_id)
            task_ids = self.redis_client.smembers(request_index_key)

            if not task_ids:
                return None

            # Redis sets are unordered, so this picks an arbitrary member.
            task_id = next(iter(task_ids))
            return await self.load(task_id)

        except PersistenceLoadError:
            # Raised by load() with a specific message; don't double-wrap it.
            raise
        except redis.RedisError as e:
            raise PersistenceLoadError(
                message=f"Failed to load task by request_id '{request_id}' from Redis: {e}"
            ) from e
        except Exception as e:
            raise PersistenceLoadError(
                message=f"Unexpected error loading task by request_id '{request_id}': {e}"
            ) from e

    def health_check(self) -> bool:
        """Check if Redis connection is healthy."""
        try:
            self.redis_client.ping()
            return True
        except redis.RedisError:
            return False

    def clear_all_tasks(self) -> int:
        """
        Clear all task data (useful for testing).

        Returns:
            Number of keys deleted.

        Raises:
            RuntimeError: If Redis reports an error.
        """
        try:
            # NOTE(review): KEYS scans the whole keyspace in one blocking call.
            task_keys = self.redis_client.keys("task_persistence:task:*")
            request_index_keys = self.redis_client.keys("task_persistence:request_index:*")

            all_keys = task_keys + request_index_keys

            if not all_keys:
                return 0

            return self.redis_client.delete(*all_keys)

        except redis.RedisError as e:
            raise RuntimeError(f"Failed to clear all tasks from Redis: {e}") from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.__init__
__init__(app_config: AppConfig = None)

Initialize the Redis-based task persistence manager.

Parameters:

Name Type Description Default
app_config AppConfig

Application configuration object. If None, creates a new one.

None
Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
def __init__(self, app_config: AppConfig | None = None):
    """
    Initialize the Redis-based task persistence manager.

    Reads the connection settings from ``app_config``, builds the client, and
    verifies connectivity with a PING.

    Args:
        app_config: Application configuration object. If None, creates a new one.

    Raises:
        ConnectionError: If the initial PING fails with ``redis.ConnectionError``.
    """
    if app_config is None:
        app_config = AppConfig()

    self.app_config = app_config
    self._lock = threading.Lock()

    # Get Redis configuration, falling back to defaults when unset
    redis_host = self.app_config.get(TA_REDIS_HOST.env_name) or "localhost"
    redis_port = int(self.app_config.get(TA_REDIS_PORT.env_name) or 6379)
    redis_db = int(self.app_config.get(TA_REDIS_DB.env_name) or 0)
    redis_password = self.app_config.get(TA_REDIS_PWD.env_name)
    # TLS is opt-in: only an explicit "true" (any case) enables it.
    ssl_setting = self.app_config.get(TA_REDIS_SSL.env_name)
    redis_ssl = str(ssl_setting or "").strip().lower() == "true"
    self.ttl = int(self.app_config.get(TA_REDIS_TTL.env_name) or 3600)  # Default 1 hour

    # Initialize Redis client
    self.redis_client = redis.Redis(
        host=redis_host,
        port=redis_port,
        db=redis_db,
        password=redis_password,
        ssl=redis_ssl,
        decode_responses=True,  # Automatically decode responses to strings
        socket_connect_timeout=5,
        socket_timeout=5,
        retry_on_timeout=True,
    )

    # Fail fast if the server is unreachable
    try:
        self.redis_client.ping()
    except redis.ConnectionError as e:
        raise ConnectionError(f"Failed to connect to Redis: {e}") from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.create async
create(task: AgentTask) -> None

Create a new task in Redis.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
async def create(self, task: AgentTask) -> None:
    """Store a new task and index it by each item's request_id.

    Args:
        task: The task to persist; its ``task_id`` must not already exist.

    Raises:
        PersistenceCreateError: If the task already exists, Redis reports an
            error, or any other failure occurs while storing it.
    """
    try:
        task_key = self._get_task_key(task.task_id)

        # Refuse to overwrite an existing task.
        # NOTE(review): exists + setex is not atomic; SET NX EX would be.
        if self.redis_client.exists(task_key):
            raise PersistenceCreateError(
                message=f"Task with ID '{task.task_id}' already exists."
            )

        # Serialize and store the task
        serialized_task = self._serialize_task(task)
        self.redis_client.setex(task_key, self.ttl, serialized_task)

        # Update request_id indexes (same TTL as the task itself)
        for item in task.items:
            request_index_key = self._get_request_index_key(item.request_id)
            self.redis_client.sadd(request_index_key, task.task_id)
            self.redis_client.expire(request_index_key, self.ttl)

    except PersistenceCreateError:
        # Already descriptive; don't re-wrap it as an "unexpected" error.
        raise
    except redis.RedisError as e:
        raise PersistenceCreateError(
            message=f"Failed to create task '{task.task_id}' in Redis: {e}"
        ) from e
    except Exception as e:
        raise PersistenceCreateError(
            message=f"Unexpected error creating task '{task.task_id}': {e}"
        ) from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.load async
load(task_id: str) -> AgentTask | None

Load a task from Redis by task_id.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
async def load(self, task_id: str) -> AgentTask | None:
    """Fetch and deserialize the task stored under ``task_id``.

    The whole lookup runs while holding ``self._lock``.

    Returns:
        The deserialized task, or None when no key exists for the id.

    Raises:
        PersistenceLoadError: On a Redis error, or when decoding the stored
            payload raises ``json.JSONDecodeError``/``ValueError``. In the
            decoding case, the key is deleted first (Redis errors during that
            delete are ignored).
    """
    with self._lock:
        try:
            raw = self.redis_client.get(self._get_task_key(task_id))
            return None if raw is None else self._deserialize_task(raw)
        except redis.RedisError as e:
            raise PersistenceLoadError(
                message=f"Failed to load task '{task_id}' from Redis: {e}"
            ) from e
        except (json.JSONDecodeError, ValueError) as e:
            # The stored payload is unusable; purge it so later loads see a miss.
            try:
                self.redis_client.delete(self._get_task_key(task_id))
            except redis.RedisError:
                pass
            raise PersistenceLoadError(
                message=f"Corrupted task data found for task_id {task_id}: {e}"
            ) from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.update async
update(task: AgentTask) -> None

Update an existing task in Redis.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
async def update(self, task: AgentTask) -> None:
    """Replace an existing task and re-point its request_id indexes.

    Args:
        task: The new version of the task; a task with the same ``task_id``
            must already be stored.

    Raises:
        PersistenceUpdateError: If the task does not exist, Redis reports an
            error, or any other failure occurs while updating it.
    """
    try:
        task_key = self._get_task_key(task.task_id)

        # Check if task exists
        old_task_str = self.redis_client.get(task_key)
        if old_task_str is None:
            raise PersistenceUpdateError(
                message=f"Task with ID '{task.task_id}' does not exist for update."
            )

        # Deserialize old task to clean up old request_id indexes
        old_task = self._deserialize_task(old_task_str)

        # Remove old request_id associations
        for item in old_task.items:
            request_index_key = self._get_request_index_key(item.request_id)
            self.redis_client.srem(request_index_key, task.task_id)

        # Update the task (TTL is reset on every write)
        serialized_task = self._serialize_task(task)
        self.redis_client.setex(task_key, self.ttl, serialized_task)

        # Add new request_id associations
        for item in task.items:
            request_index_key = self._get_request_index_key(item.request_id)
            self.redis_client.sadd(request_index_key, task.task_id)
            self.redis_client.expire(request_index_key, self.ttl)

    except PersistenceUpdateError:
        # Already descriptive; don't re-wrap it as an "unexpected" error.
        raise
    except redis.RedisError as e:
        raise PersistenceUpdateError(
            message=f"Failed to update task '{task.task_id}' in Redis: {e}"
        ) from e
    except Exception as e:
        raise PersistenceUpdateError(
            message=f"Unexpected error updating task '{task.task_id}': {e}"
        ) from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.delete async
delete(task_id: str) -> None

Delete a task from Redis.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
async def delete(self, task_id: str) -> None:
    """Delete a task and remove it from its request_id indexes.

    Args:
        task_id: Identifier of the task to delete.

    Raises:
        PersistenceDeleteError: If the task does not exist, Redis reports an
            error, or any other failure occurs while deleting it.
    """
    try:
        task_key = self._get_task_key(task_id)

        # Get the task first to clean up request_id indexes
        task_str = self.redis_client.get(task_key)
        if task_str is None:
            raise PersistenceDeleteError(
                message=f"Task with ID '{task_id}' does not exist for deletion."
            )

        task = self._deserialize_task(task_str)

        # Remove from request_id indexes
        for item in task.items:
            request_index_key = self._get_request_index_key(item.request_id)
            self.redis_client.srem(request_index_key, task_id)

        # Delete the task
        self.redis_client.delete(task_key)

    except PersistenceDeleteError:
        # Already descriptive; don't re-wrap it as an "unexpected" error.
        raise
    except redis.RedisError as e:
        raise PersistenceDeleteError(
            message=f"Failed to delete task '{task_id}' from Redis: {e}"
        ) from e
    except Exception as e:
        raise PersistenceDeleteError(
            message=f"Unexpected error deleting task '{task_id}': {e}"
        ) from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.load_by_request_id async
load_by_request_id(request_id: str) -> AgentTask | None

Load a task by request_id.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
async def load_by_request_id(self, request_id: str) -> AgentTask | None:
    """Load a task indexed under ``request_id``.

    Args:
        request_id: Request id recorded on one of the task's items.

    Returns:
        A task whose items reference the request id, or None if none is
        indexed. If several tasks are indexed, an arbitrary one is loaded.

    Raises:
        PersistenceLoadError: On Redis errors or any failure while loading.
    """
    try:
        request_index_key = self._get_request_index_key(request_id)
        task_ids = self.redis_client.smembers(request_index_key)

        if not task_ids:
            return None

        # Redis sets are unordered, so this picks an arbitrary member.
        task_id = next(iter(task_ids))
        return await self.load(task_id)

    except PersistenceLoadError:
        # Raised by load() with a specific message; don't double-wrap it.
        raise
    except redis.RedisError as e:
        raise PersistenceLoadError(
            message=f"Failed to load task by request_id '{request_id}' from Redis: {e}"
        ) from e
    except Exception as e:
        raise PersistenceLoadError(
            message=f"Unexpected error loading task by request_id '{request_id}': {e}"
        ) from e
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.health_check
health_check() -> bool

Check if Redis connection is healthy.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
def health_check(self) -> bool:
    """Report whether the Redis server answers a PING.

    Returns:
        True once ``ping()`` returns (whatever it returns), False if it raises
        ``redis.RedisError``. Other exceptions propagate.
    """
    try:
        self.redis_client.ping()
    except redis.RedisError:
        return False
    else:
        return True
sk_agents.persistence.custom.example_redis_persistence.RedisTaskPersistenceManager.clear_all_tasks
clear_all_tasks() -> int

Clear all task data (useful for testing).

Returns:

Type Description
int

Number of keys deleted.

Source code in src/sk_agents/persistence/custom/example_redis_persistence.py
def clear_all_tasks(self) -> int:
    """
    Remove every task key and request-index key (intended for tests).

    Returns:
        The count reported by Redis ``DELETE``, or 0 when nothing matched.

    Raises:
        RuntimeError: If Redis reports an error while listing or deleting keys.
    """
    patterns = ("task_persistence:task:*", "task_persistence:request_index:*")
    try:
        doomed = []
        for pattern in patterns:
            doomed.extend(self.redis_client.keys(pattern))

        # DELETE requires at least one key, so short-circuit the empty case.
        return self.redis_client.delete(*doomed) if doomed else 0
    except redis.RedisError as e:
        raise RuntimeError(f"Failed to clear all tasks from Redis: {e}") from e
sk_agents.persistence.persistence_factory
sk_agents.persistence.persistence_factory.PersistenceFactory
Source code in src/sk_agents/persistence/persistence_factory.py
class PersistenceFactory(metaclass=Singleton):
    """Build the configured ``TaskPersistenceManager``.

    When ``TA_PERSISTENCE_MODULE``/``TA_PERSISTENCE_CLASS`` are not both at
    their default values, the named class is loaded from the named module and
    validated at construction time. Otherwise ``InMemoryPersistenceManager``
    is used.
    """

    def __init__(self, app_config: AppConfig):
        self.app_config = app_config

        # Try to load custom module, fallback to default if not configured
        module_name, class_name = self._get_custom_persistence_config()
        if module_name and class_name:
            try:
                self.module = ModuleLoader.load_module(module_name)
            except Exception as e:
                raise ImportError(f"Failed to load module '{module_name}': {e}") from e

            self.class_name = class_name
            self._validate_custom_class()
        else:
            self.module = None
            self.class_name = None

    def get_persistence_manager(self) -> TaskPersistenceManager:
        """Return a new persistence manager instance.

        If the custom class's constructor accepts ``app_config`` (by name or
        via ``**kwargs``), the config is passed; otherwise it is called with
        no arguments. A ``TypeError`` raised inside a constructor that does
        accept ``app_config`` now propagates, instead of being masked by a
        silent retry without config. If the signature cannot be inspected,
        the old behavior applies: try ``app_config=``, and fall back to a
        no-argument call on ``TypeError``.
        """
        if not (self.module and self.class_name):
            # Use default implementation
            return InMemoryPersistenceManager()

        custom_class = getattr(self.module, self.class_name)
        accepts_config = self._accepts_app_config(custom_class)
        if accepts_config is None:
            try:
                return custom_class(app_config=self.app_config)
            except TypeError:
                # Fallback if app_config not accepted
                return custom_class()
        if accepts_config:
            return custom_class(app_config=self.app_config)
        return custom_class()

    @staticmethod
    def _accepts_app_config(factory) -> bool | None:
        """Tell whether ``factory`` can be called with ``app_config=``; None if unknown."""
        import inspect

        try:
            parameters = inspect.signature(factory).parameters.values()
        except (TypeError, ValueError):
            # Some callables (e.g. certain builtins) have no introspectable signature.
            return None
        return any(
            param.kind is inspect.Parameter.VAR_KEYWORD
            or (
                param.name == "app_config"
                and param.kind is not inspect.Parameter.POSITIONAL_ONLY
            )
            for param in parameters
        )

    def _get_custom_persistence_config(self) -> tuple[str | None, str | None]:
        """Get custom persistence configuration, returning None values if using defaults."""
        module_name = self.app_config.get(TA_PERSISTENCE_MODULE.env_name)
        class_name = self.app_config.get(TA_PERSISTENCE_CLASS.env_name)

        # Check if we're using the default values (which means no custom config)
        if (
            module_name == TA_PERSISTENCE_MODULE.default_value
            and class_name == TA_PERSISTENCE_CLASS.default_value
        ):
            return None, None

        return module_name, class_name

    def _validate_custom_class(self):
        """Validate that the custom class is a proper TaskPersistenceManager subclass.

        Raises:
            ValueError: If the module has no attribute named ``self.class_name``.
            TypeError: If that attribute is not a TaskPersistenceManager subclass.
        """
        if not hasattr(self.module, self.class_name):
            module_name = getattr(self.module, "__name__", "unknown module")
            raise ValueError(
                f"Custom Task Persistence Manager class: {self.class_name} "
                f"Not found in module: {module_name}"
            )

        custom_class = getattr(self.module, self.class_name)
        # issubclass() raises its own undescriptive TypeError for non-classes,
        # so check isinstance(type) first and report both cases the same way.
        if not isinstance(custom_class, type) or not issubclass(
            custom_class, TaskPersistenceManager
        ):
            raise TypeError(
                f"Class '{self.class_name}' is not a subclass of TaskPersistenceManager."
            )
sk_agents.plugin_catalog
sk_agents.plugin_catalog.local_plugin_catalog
sk_agents.plugin_catalog.local_plugin_catalog.FileBasedPluginCatalog

Bases: PluginCatalog

File-based implementation that loads plugins from JSON files.

Source code in src/sk_agents/plugin_catalog/local_plugin_catalog.py
class FileBasedPluginCatalog(PluginCatalog):
    """File-based implementation that loads plugins from JSON files.

    Plugins are read once, at construction, from the path configured under
    ``TA_PLUGIN_CATALOG_FILE``. More plugins and tools can be registered at
    runtime. Plugins are indexed by ``plugin_id`` and tools by ``tool_id``.
    """

    def __init__(self, app_config: AppConfig):
        self.app_config = app_config
        self.catalog_path = Path(self.app_config.get(TA_PLUGIN_CATALOG_FILE.env_name))
        self._plugins: dict[str, Plugin] = {}
        self._tools: dict[str, PluginTool] = {}
        self._load_plugins()

    def get_plugin(self, plugin_id: str) -> Plugin | None:
        """Get a plugin by its ID."""
        return self._plugins.get(plugin_id)

    def get_tool(self, tool_id: str) -> PluginTool | None:
        """Get a tool by its ID."""
        return self._tools.get(tool_id)

    def register_dynamic_plugin(self, plugin: Plugin) -> None:
        """Register a plugin discovered at runtime (e.g., from MCP servers)."""
        self._plugins[plugin.plugin_id] = plugin

        # Index all tools from this plugin for quick lookup
        for tool in plugin.tools:
            self._tools[tool.tool_id] = tool

    def register_dynamic_tool(self, tool: PluginTool, plugin_id: str | None = None) -> None:
        """Register a tool discovered at runtime.

        Args:
            tool: The tool to index by its ``tool_id``.
            plugin_id: If given, the tool is attached to this plugin. A minimal
                MCP-typed plugin is created if the id is not registered yet.
        """
        # Add tool to tools index
        self._tools[tool.tool_id] = tool

        # If plugin_id is provided, ensure the plugin exists or create it
        if plugin_id:
            if plugin_id not in self._plugins:
                # Create a minimal plugin for this tool
                from sk_agents.plugin_catalog.models import McpPluginType

                plugin = Plugin(
                    plugin_id=plugin_id,
                    name=f"Dynamic Plugin: {plugin_id}",
                    description="Dynamically created plugin for runtime tools",
                    version="1.0.0",
                    owner="dynamic-registration",
                    plugin_type=McpPluginType(),
                    tools=[tool],
                )
                self._plugins[plugin_id] = plugin
            else:
                # Add tool to existing plugin
                existing_plugin = self._plugins[plugin_id]
                if tool not in existing_plugin.tools:
                    existing_plugin.tools.append(tool)

    def unregister_dynamic_plugin(self, plugin_id: str) -> bool:
        """Unregister a plugin and drop its tools from the tool index.

        Returns:
            True if the plugin was registered and has been removed, else False.
        """
        plugin = self._plugins.pop(plugin_id, None)
        if plugin is None:
            return False

        # Remove all tools from this plugin
        for tool in plugin.tools:
            self._tools.pop(tool.tool_id, None)
        return True

    def _load_plugins(self) -> None:
        """Load plugins from a single JSON file.

        A missing file leaves the catalog empty.

        Raises:
            PluginCatalogDefinitionException: If the JSON does not validate as a
                ``PluginCatalogDefinition``.
            PluginFileReadException: For any other failure reading or parsing
                the file.
        """
        if not self.catalog_path.exists():
            return

        try:
            # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
            with open(self.catalog_path, encoding="utf-8") as local_plugin_json:
                catalog_data = json.load(local_plugin_json)

            # Validate and convert to PluginCatalogDefinition
            try:
                catalog_definition = PluginCatalogDefinition.model_validate(catalog_data)
            except Exception as validation_error:
                raise PluginCatalogDefinitionException(
                    message="Plugin catalog definition validation failed"
                ) from validation_error
            # Process the validated plugins
            for plugin in catalog_definition.plugins:
                self._plugins[plugin.plugin_id] = plugin

                # Index tools for quick lookup
                for tool in plugin.tools:
                    self._tools[tool.tool_id] = tool

        except PluginCatalogDefinitionException:
            # Re-raise our custom exception
            raise
        except Exception as e:
            raise PluginFileReadException(
                message="""
                Catalog encountered an error
                when attempting to read file
                """
            ) from e
sk_agents.plugin_catalog.local_plugin_catalog.FileBasedPluginCatalog.get_plugin
get_plugin(plugin_id: str) -> Plugin | None

Get a plugin by its ID.

Source code in src/sk_agents/plugin_catalog/local_plugin_catalog.py
def get_plugin(self, plugin_id: str) -> Plugin | None:
    """Look up a plugin in the catalog.

    Args:
        plugin_id: Identifier of the plugin to find.

    Returns:
        The plugin stored under ``plugin_id``, or ``None`` if there is none.
    """
    found = self._plugins.get(plugin_id)
    return found
sk_agents.plugin_catalog.local_plugin_catalog.FileBasedPluginCatalog.get_tool
get_tool(tool_id: str) -> PluginTool | None

Get a tool by its ID.

Source code in src/sk_agents/plugin_catalog/local_plugin_catalog.py
def get_tool(self, tool_id: str) -> PluginTool | None:
    """Look up a tool in the catalog's tool index.

    Args:
        tool_id: Identifier of the tool to find.

    Returns:
        The tool stored under ``tool_id``, or ``None`` if there is none.
    """
    found = self._tools.get(tool_id)
    return found
sk_agents.plugin_catalog.local_plugin_catalog.FileBasedPluginCatalog.register_dynamic_plugin
register_dynamic_plugin(plugin: Plugin) -> None

Register a plugin discovered at runtime (e.g., from MCP servers).

Source code in src/sk_agents/plugin_catalog/local_plugin_catalog.py
def register_dynamic_plugin(self, plugin: Plugin) -> None:
    """Add a plugin discovered at runtime (e.g. from an MCP server) to the catalog.

    The plugin is stored under its ``plugin_id``, and each of its tools is
    indexed under its ``tool_id``. Existing entries with the same IDs are
    overwritten.
    """
    self._plugins[plugin.plugin_id] = plugin
    self._tools.update({tool.tool_id: tool for tool in plugin.tools})
sk_agents.plugin_catalog.local_plugin_catalog.FileBasedPluginCatalog.register_dynamic_tool
register_dynamic_tool(
    tool: PluginTool, plugin_id: str = None
) -> None

Register a tool discovered at runtime.

Source code in src/sk_agents/plugin_catalog/local_plugin_catalog.py
def register_dynamic_tool(self, tool: PluginTool, plugin_id: str | None = None) -> None:
    """Register a tool discovered at runtime.

    The tool is always indexed under its ``tool_id``. When ``plugin_id`` is
    given, the tool is also attached to that plugin:

    * if no plugin with that ID exists, a minimal MCP-typed plugin containing
      only this tool is created and registered;
    * otherwise the tool is appended to the existing plugin's ``tools`` list,
      unless an equal tool is already present there.

    Args:
        tool: The tool to register.
        plugin_id: Optional ID of the owning plugin. A falsy value (``None``
            or ``""``) registers the tool without attaching it to a plugin.
    """
    # Add tool to tools index
    self._tools[tool.tool_id] = tool

    if not plugin_id:
        return

    if plugin_id not in self._plugins:
        # Function-scope import, presumably to avoid an import cycle — TODO confirm.
        from sk_agents.plugin_catalog.models import McpPluginType

        # Create a minimal plugin for this tool
        self._plugins[plugin_id] = Plugin(
            plugin_id=plugin_id,
            name=f"Dynamic Plugin: {plugin_id}",
            description="Dynamically created plugin for runtime tools",
            version="1.0.0",
            owner="dynamic-registration",
            plugin_type=McpPluginType(),
            tools=[tool],
        )
    else:
        # Add tool to existing plugin (equality check avoids duplicate entries)
        existing_plugin = self._plugins[plugin_id]
        if tool not in existing_plugin.tools:
            existing_plugin.tools.append(tool)
sk_agents.plugin_catalog.local_plugin_catalog.FileBasedPluginCatalog.unregister_dynamic_plugin
unregister_dynamic_plugin(plugin_id: str) -> bool

Unregister a dynamically registered plugin.

Source code in src/sk_agents/plugin_catalog/local_plugin_catalog.py
def unregister_dynamic_plugin(self, plugin_id: str) -> bool:
    """Remove a plugin and the tool-index entries for its tools.

    For every tool the plugin lists, the entry stored under that tool's
    ``tool_id`` is dropped from the tool index (if present). Then the plugin
    itself is removed.

    Returns:
        ``True`` if a plugin with ``plugin_id`` was registered and has been
        removed, ``False`` if no such plugin was registered.
    """
    if plugin_id not in self._plugins:
        return False

    doomed = self._plugins[plugin_id]
    for plugin_tool in doomed.tools:
        self._tools.pop(plugin_tool.tool_id, None)

    del self._plugins[plugin_id]
    return True
sk_agents.plugin_catalog.models
sk_agents.plugin_catalog.models.GovernanceOverride

Bases: BaseModel

Optional governance overrides for MCP tools.

Only specified fields will override auto-inferred values.

Source code in src/sk_agents/plugin_catalog/models.py
class GovernanceOverride(BaseModel):
    """Optional governance overrides for MCP tools.

    Only specified fields will override auto-inferred values. Every field
    defaults to ``None``, which means "not specified": the auto-inferred
    value is kept for that field.
    """

    # NOTE(review): presumably whether calls need human-in-the-loop approval — confirm.
    requires_hitl: bool | None = None
    # Restricted to the three literal levels below, or None (no override).
    cost: Literal["low", "medium", "high"] | None = None
    # Restricted to the four literal classifications below, or None (no override).
    data_sensitivity: Literal["public", "proprietary", "confidential", "sensitive"] | None = None
sk_agents.plugin_catalog.plugin_catalog
sk_agents.plugin_catalog.plugin_catalog.PluginCatalog

Bases: ABC

Source code in src/sk_agents/plugin_catalog/plugin_catalog.py
class PluginCatalog(ABC):
    """Abstract lookup interface for plugins and their tools.

    Subclasses must implement ``get_plugin`` and ``get_tool``. The runtime
    registration hooks come with do-nothing defaults, so catalogs that cannot
    change at runtime do not need to override them.
    """

    @abstractmethod
    def get_plugin(self, plugin_id: str) -> Plugin | None:
        """Return the plugin registered under ``plugin_id``, or ``None``."""

    @abstractmethod
    def get_tool(self, tool_id: str) -> PluginTool | None:
        """Return the tool registered under ``tool_id``, or ``None``."""

    # --- Dynamic registration hooks for MCP and other runtime-discovered tools ---

    def register_dynamic_plugin(self, plugin: Plugin) -> None:
        """Register a plugin discovered at runtime (e.g., from MCP servers).

        The base implementation ignores the plugin; subclasses may override it.
        """
        del plugin  # unused by the default no-op

    def register_dynamic_tool(self, tool: PluginTool, plugin_id: str | None = None) -> None:
        """Register a tool discovered at runtime.

        The base implementation ignores its arguments; subclasses may override it.
        """
        del tool, plugin_id  # unused by the default no-op

    def unregister_dynamic_plugin(self, plugin_id: str) -> bool:
        """Unregister a dynamically registered plugin.

        The base implementation removes nothing and always reports ``False``.
        """
        del plugin_id  # unused by the default no-op
        return False
sk_agents.plugin_catalog.plugin_catalog.PluginCatalog.register_dynamic_plugin
register_dynamic_plugin(plugin: Plugin) -> None

Register a plugin discovered at runtime (e.g., from MCP servers).

Source code in src/sk_agents/plugin_catalog/plugin_catalog.py
def register_dynamic_plugin(self, plugin: Plugin) -> None:
    """Hook for registering a plugin discovered at runtime (e.g. from MCP servers).

    This base implementation does nothing; catalogs that support runtime
    registration override it.
    """
    del plugin  # intentionally unused by the default implementation
sk_agents.plugin_catalog.plugin_catalog.PluginCatalog.register_dynamic_tool
register_dynamic_tool(
    tool: PluginTool, plugin_id: str | None = None
) -> None

Register a tool discovered at runtime.

Source code in src/sk_agents/plugin_catalog/plugin_catalog.py
def register_dynamic_tool(self, tool: PluginTool, plugin_id: str | None = None) -> None:
    """Hook for registering a single tool discovered at runtime.

    This base implementation does nothing; catalogs that support runtime
    registration override it.
    """
    del tool, plugin_id  # intentionally unused by the default implementation
sk_agents.plugin_catalog.plugin_catalog.PluginCatalog.unregister_dynamic_plugin
unregister_dynamic_plugin(plugin_id: str) -> bool

Unregister a dynamically registered plugin.

Source code in src/sk_agents/plugin_catalog/plugin_catalog.py
def unregister_dynamic_plugin(self, plugin_id: str) -> bool:
    """Hook for removing a dynamically registered plugin.

    This base implementation removes nothing and always returns ``False``.
    """
    del plugin_id  # intentionally unused by the default implementation
    removed = False
    return removed
sk_agents.plugin_catalog.plugin_catalog_factory
sk_agents.plugin_catalog.plugin_catalog_factory.PluginCatalogFactory

Singleton factory for creating PluginCatalog instances based on environment variables.

Source code in src/sk_agents/plugin_catalog/plugin_catalog_factory.py
class PluginCatalogFactory(metaclass=Singleton):
    """
    Singleton factory for creating PluginCatalog
    instances based on environment variables.

    The implementation class is chosen by the TA_PLUGIN_CATALOG_MODULE and
    TA_PLUGIN_CATALOG_CLASS settings read through ``AppConfig``. The created
    instance is cached and returned by every later ``get_catalog`` call.
    """

    def __init__(self):
        super().__init__()
        # Register this package's config definitions before reading any of them.
        AppConfig.add_configs(configs)
        self.app_config = AppConfig()
        self._catalog_instance: PluginCatalog | None = None

    def get_catalog(self) -> PluginCatalog:
        """
        Get the plugin catalog instance,
        creating it if it doesn't exist.
        """
        if self._catalog_instance is None:
            self._catalog_instance = self._create_catalog()
        return self._catalog_instance

    def _create_catalog(self) -> PluginCatalog:
        """
        Create a new plugin catalog instance based on environment variables.

        Returns:
            An instance of the configured class, constructed with ``self.app_config``.

        Raises:
            ValueError: If either the module or the class setting is missing.
            ImportError: If the configured module cannot be imported.
            AttributeError: If the module has no attribute named like the configured class.
            TypeError: If that attribute is not a class inheriting from PluginCatalog.
        """
        module_name = self.app_config.get(TA_PLUGIN_CATALOG_MODULE.env_name)
        class_name = self.app_config.get(TA_PLUGIN_CATALOG_CLASS.env_name)

        if not module_name or not class_name:
            raise ValueError(
                "Both TA_PLUGIN_CATALOG_MODULE and TA_PLUGIN_CATALOG_CLASS "
                "environment variables must be set"
            )

        # Each try wraps only the step its error message describes, so failures
        # elsewhere (notably inside the catalog's constructor) are not mislabelled.
        try:
            module = ModuleLoader.load_module(module_name)
        except ImportError as e:
            raise ImportError(f"Failed to import module '{module_name}': {e}") from e

        try:
            catalog_class = getattr(module, class_name)
        except AttributeError as e:
            raise AttributeError(
                f"Class '{class_name}' not found in module '{module_name}': {e}"
            ) from e

        # isinstance(..., type) first: issubclass() raises its own unhelpful
        # TypeError when handed a non-class object such as a function.
        if not (isinstance(catalog_class, type) and issubclass(catalog_class, PluginCatalog)):
            raise TypeError(
                f"Class {class_name} in module {module_name} must inherit from PluginCatalog"
            )

        # Instantiate outside the try blocks: constructor errors propagate unchanged.
        return catalog_class(self.app_config)
sk_agents.plugin_catalog.plugin_catalog_factory.PluginCatalogFactory.get_catalog
get_catalog() -> PluginCatalog

Get the plugin catalog instance, creating it if it doesn't exist.

Source code in src/sk_agents/plugin_catalog/plugin_catalog_factory.py
def get_catalog(self) -> PluginCatalog:
    """
    Return the shared plugin catalog, building it on first use.

    The object produced by ``_create_catalog`` is cached on the factory, so
    later calls return that same instance.
    """
    catalog = self._catalog_instance
    if catalog is None:
        catalog = self._catalog_instance = self._create_catalog()
    return catalog
sk_agents.routes
sk_agents.routes.Routes
Source code in src/sk_agents/routes.py
(Listing below corresponds to lines 56–527 of src/sk_agents/routes.py.)
class Routes:
    @staticmethod
    def get_url(name: str, version: str, app_config: AppConfig) -> str:
        base_url = app_config.get(TA_AGENT_BASE_URL.env_name)
        if not base_url:
            logger.exception("Base URL is not provided in the app config.")
            raise ValueError("Base URL is not provided in the app config.")
        return f"{base_url}/{name}/{version}/a2a"

    @staticmethod
    def get_provider(app_config: AppConfig) -> AgentProvider:
        return AgentProvider(
            organization=app_config.get(TA_PROVIDER_ORG.env_name),
            url=app_config.get(TA_PROVIDER_URL.env_name),
        )

    @staticmethod
    def get_agent_card(config: BaseConfig, app_config: AppConfig) -> AgentCard:
        if config.metadata is None:
            logger.exception("Agent card metadata is not provided in the config.")
            raise ValueError("Agent card metadata is not provided in the config.")

        metadata = config.metadata
        skills = [
            AgentSkill(
                id=skill.id,
                name=skill.name,
                description=skill.description,
                tags=skill.tags,
                examples=skill.examples,
                inputModes=skill.input_modes,
                outputModes=skill.output_modes,
            )
            for skill in metadata.skills
        ]
        return AgentCard(
            name=config.name,
            version=str(config.version),
            description=metadata.description,
            url=Routes.get_url(config.name, config.version, app_config),
            provider=Routes.get_provider(app_config),
            documentationUrl=config.metadata.documentation_url,
            capabilities=AgentCapabilities(
                streaming=True, pushNotifications=False, stateTransitionHistory=True
            ),
            defaultInputModes=["text"],
            defaultOutputModes=["text"],
            skills=skills,
        )

    @staticmethod
    def _create_chat_completions_builder(app_config: AppConfig):
        return ChatCompletionBuilder(app_config)

    @staticmethod
    def _create_remote_plugin_loader(app_config: AppConfig):
        remote_plugin_catalog = RemotePluginCatalog(app_config)
        return RemotePluginLoader(remote_plugin_catalog)

    @staticmethod
    def _create_kernel_builder(app_config: AppConfig, authorization: str):
        chat_completions = Routes._create_chat_completions_builder(app_config)
        remote_plugin_loader = Routes._create_remote_plugin_loader(app_config)
        kernel_builder = KernelBuilder(
            chat_completions, remote_plugin_loader, app_config, authorization
        )
        return kernel_builder

    @staticmethod
    def _create_agent_builder(app_config: AppConfig, authorization: str):
        kernel_builder = Routes._create_kernel_builder(app_config, authorization)
        agent_builder = AgentBuilder(kernel_builder, authorization)
        return agent_builder

    @staticmethod
    def get_request_handler(
        config: BaseConfig,
        app_config: AppConfig,
        chat_completion_builder: ChatCompletionBuilder,
        state_manager: StateManager,
        task_store: TaskStore,
    ) -> DefaultRequestHandler:
        return DefaultRequestHandler(
            agent_executor=A2AAgentExecutor(
                config, app_config, chat_completion_builder, state_manager
            ),
            task_store=task_store,
        )

    @staticmethod
    def get_task_handler(
        config: BaseConfig,
        app_config: AppConfig,
        authorization: str,
        state_manager: TaskPersistenceManager,
        mcp_discovery_manager=None,  # McpStateManager - Optional
    ) -> TealAgentsV1Alpha1Handler:
        agent_builder = Routes._create_agent_builder(app_config, authorization)
        return TealAgentsV1Alpha1Handler(
            config, app_config, agent_builder, state_manager, mcp_discovery_manager
        )

    @staticmethod
    def get_a2a_routes(
        name: str,
        version: str,
        description: str,
        config: BaseConfig,
        app_config: AppConfig,
        chat_completion_builder: ChatCompletionBuilder,
        task_store: TaskStore,
        state_manager: StateManager,
    ) -> APIRouter:
        """
        DEPRECATION NOTICE: A2A (Agent-to-Agent) routes are being deprecated
        as part of the framework migration evaluation. This method is maintained for
        backward compatibility only. New development should avoid using A2A functionality.
        """
        a2a_app = A2AStarletteApplication(
            agent_card=Routes.get_agent_card(config, app_config),
            http_handler=Routes.get_request_handler(
                config, app_config, chat_completion_builder, state_manager, task_store
            ),
        )
        a2a_router = APIRouter()

        @a2a_router.post("")
        @docstring_parameter(description)
        async def handle_a2a(request: Request):
            """
            {0}

            Agent-to-Agent Invocation
            """
            return await a2a_app._handle_requests(request)

        @a2a_router.get("/.well-known/agent.json")
        @docstring_parameter(f"{name}:{version} - {description}")
        async def handle_get_agent_card(request: Request):
            """
            Retrieve agent card for {0}
            """
            return await a2a_app._handle_get_agent_card(request)

        return a2a_router

    @staticmethod
    def get_rest_routes(
        name: str,
        version: str,
        description: str,
        root_handler_name: str,
        config: BaseConfig,
        app_config: AppConfig,
        input_class: type,
        output_class: type,
    ) -> APIRouter:
        router = APIRouter()

        @router.post("")
        @docstring_parameter(description)
        async def invoke(inputs: input_class, request: Request) -> InvokeResponse[output_class]:  # type: ignore
            """
            {0}
            """
            st = get_telemetry()
            context = extract(request.headers)

            authorization = request.headers.get("authorization", None)
            with (
                st.tracer.start_as_current_span(
                    f"{name}-{version}-invoke",
                    context=context,
                )
                if st.telemetry_enabled()
                else nullcontext()
            ):
                match root_handler_name:
                    case "skagents":
                        handler: BaseHandler = skagents_handle(config, app_config, authorization)
                    case _:
                        raise ValueError(f"Unknown apiVersion: {config.apiVersion}")

                inv_inputs = inputs.__dict__
                output = await handler.invoke(inputs=inv_inputs)
                return output

        @router.post("/sse")
        @docstring_parameter(description)
        async def invoke_sse(inputs: input_class, request: Request) -> StreamingResponse:
            """
            {0}
            Initiate SSE call
            """
            st = get_telemetry()
            context = extract(request.headers)
            authorization = request.headers.get("authorization", None)
            inv_inputs = inputs.__dict__

            async def event_generator():
                with (
                    st.tracer.start_as_current_span(
                        f"{config.service_name}-{str(config.version)}-invoke_sse",
                        context=context,
                    )
                    if st.telemetry_enabled()
                    else nullcontext()
                ):
                    match root_handler_name:
                        case "skagents":
                            handler: BaseHandler = skagents_handle(
                                config, app_config, authorization
                            )
                            # noinspection PyTypeChecker
                            async for content in handler.invoke_stream(inputs=inv_inputs):
                                yield get_sse_event_for_response(content)
                        case _:
                            logger.exception(
                                "Unknown apiVersion: %s", config.apiVersion, exc_info=True
                            )
                            raise ValueError(f"Unknown apiVersion: {config.apiVersion}")

            return StreamingResponse(event_generator(), media_type="text/event-stream")

        return router

    @staticmethod
    def get_websocket_routes(
        name: str,
        version: str,
        root_handler_name: str,
        config: BaseConfig,
        app_config: AppConfig,
        input_class: type,
    ) -> APIRouter:
        router = APIRouter()

        @router.websocket("/stream")
        async def invoke_stream(websocket: WebSocket) -> None:
            await websocket.accept()
            st = get_telemetry()
            context = extract(websocket.headers)

            authorization = websocket.headers.get("authorization", None)
            try:
                data = await websocket.receive_json()
                with (
                    st.tracer.start_as_current_span(
                        f"{name}-{str(version)}-invoke_stream",
                        context=context,
                    )
                    if st.telemetry_enabled()
                    else nullcontext()
                ):
                    inputs = input_class(**data)
                    inv_inputs = inputs.__dict__
                    match root_handler_name:
                        case "skagents":
                            handler: BaseHandler = skagents_handle(
                                config, app_config, authorization
                            )
                            async for content in handler.invoke_stream(inputs=inv_inputs):
                                if isinstance(content, PartialResponse):
                                    await websocket.send_text(content.output_partial)
                            await websocket.close()
                        case _:
                            logger.exception(
                                "Unknown apiVersion: %s", config.apiVersion, exc_info=True
                            )
                            raise ValueError(f"Unknown apiVersion %s: {config.apiVersion}")
            except WebSocketDisconnect:
                logger.exception("websocket disconnected")
                print("websocket disconnected")

        return router

    @staticmethod
    def get_stateful_routes(
        name: str,
        version: str,
        description: str,
        config: BaseConfig,
        app_config: AppConfig,
        state_manager: TaskPersistenceManager,
        authorizer: RequestAuthorizer,
        auth_storage_manager: SecureAuthStorageManager,
        mcp_discovery_manager=None,  # McpStateManager - Optional
        input_class: type[UserMessage] = UserMessage,
    ) -> APIRouter:
        """
        Get the stateful API routes for the given configuration.
        """
        router = APIRouter()

        async def get_user_id(authorization: str = Header(None)):
            user_id = await authorizer.authorize_request(authorization)
            if not user_id:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
                )
            return user_id

        @router.post(
            "",
            response_model=StateResponse,
            summary="Send a message to the agent",
            response_description="Agent response with state identifiers",
            tags=["Agent"],
        )
        async def chat(message: input_class, user_id: str = Depends(get_user_id)) -> StateResponse:
            # Handle new task creation or task retrieval
            teal_handler = Routes.get_task_handler(
                config, app_config, user_id, state_manager, mcp_discovery_manager
            )
            response_content = await teal_handler.invoke(user_id, message)
            # Return response with state identifiers
            status = TaskStatus.COMPLETED.value
            if type(response_content) is HitlResponse:
                status = TaskStatus.PAUSED.value
            return StateResponse(
                session_id=response_content.session_id,
                task_id=response_content.task_id,
                request_id=response_content.request_id,
                status=status,
                content=response_content,  # Replace with actual response
            )

        return router

    @staticmethod
    def get_resume_routes(
        config: BaseConfig,
        app_config: AppConfig,
        state_manager: TaskPersistenceManager,
        mcp_discovery_manager=None,
    ) -> APIRouter:
        router = APIRouter()

        @router.post("/tealagents/v1alpha1/resume/{request_id}")
        async def resume(request_id: str, request: Request, body: ResumeRequest):
            authorization = request.headers.get("authorization", None)
            teal_handler = Routes.get_task_handler(
                config, app_config, authorization, state_manager, mcp_discovery_manager
            )
            try:
                return await teal_handler.resume_task(authorization, request_id, body, stream=False)
            except Exception as e:
                logger.exception(f"Error in resume: {e}")
                raise HTTPException(status_code=500, detail="Internal Server Error") from e

        @router.post("/tealagents/v1alpha1/resume/{request_id}/sse")
        async def resume_sse(request_id: str, request: Request, body: ResumeRequest):
            authorization = request.headers.get("authorization", None)
            teal_handler = Routes.get_task_handler(
                config, app_config, authorization, state_manager, mcp_discovery_manager
            )

            async def event_generator():
                try:
                    async for content in teal_handler.resume_task(
                        authorization, request_id, body, stream=True
                    ):
                        yield get_sse_event_for_response(content)
                except Exception as e:
                    logger.exception(f"Error in resume_sse: {e}")
                    raise HTTPException(status_code=500, detail="Internal Server Error") from e

            return StreamingResponse(event_generator(), media_type="text/event-stream")

        return router

    @staticmethod
    def get_oauth_callback_routes(
        config: BaseConfig,
        app_config: AppConfig,
    ) -> APIRouter:
        """
        Get OAuth 2.1 callback routes for MCP server authentication.

        This route handles the OAuth redirect callback after user authorization.
        """
        router = APIRouter()

        @router.get("/oauth/callback")
        async def oauth_callback(
            code: str,
            state: str,
        ):
            """
            Handle OAuth 2.1 callback from authorization server.

            Validates state, exchanges code for tokens, and stores in AuthStorage.

            Args:
                code: Authorization code from auth server
                state: CSRF state parameter

            Returns:
                Success response with server name and token metadata
            """
            from sk_agents.auth.oauth_client import OAuthClient
            from sk_agents.auth.oauth_state_manager import OAuthStateManager

            try:
                # Initialize OAuth components
                oauth_client = OAuthClient()
                state_manager = OAuthStateManager()

                # Retrieve flow state using state parameter only
                # This extracts user_id without requiring it upfront
                try:
                    flow_state = state_manager.retrieve_flow_state_by_state_only(state)
                except ValueError as e:
                    logger.warning(f"Invalid OAuth state in callback: {e}")
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid or expired state parameter",
                    ) from e

                user_id = flow_state.user_id
                server_name = flow_state.server_name

                # Look up server config from agent configuration
                mcp_servers = (
                    getattr(config.spec.agent, "mcp_servers", None)
                    if hasattr(config, "spec")
                    else None
                )
                if not mcp_servers:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail="No MCP servers configured",
                    )

                server_config = None
                for server in mcp_servers:
                    if server.name == server_name:
                        server_config = server
                        break

                if not server_config:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail=f"MCP server '{server_name}' not found in configuration",
                    )

                # Handle callback (validate state, exchange code, store tokens)
                oauth_data = await oauth_client.handle_callback(
                    code=code, state=state, user_id=user_id, server_config=server_config
                )

                logger.info(f"OAuth callback successful for user={user_id}, server={server_name}")

                # Return success response
                return {
                    "status": "success",
                    "message": f"Successfully authenticated to {server_name}",
                    "server_name": server_name,
                    "scopes": oauth_data.scopes,
                    "expires_at": oauth_data.expires_at.isoformat(),
                }

            except HTTPException:
                raise
            except Exception as e:
                logger.exception(f"Error in OAuth callback: {e}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"OAuth callback failed: {str(e)}",
                ) from e

        return router
sk_agents.routes.Routes.get_a2a_routes staticmethod
get_a2a_routes(
    name: str,
    version: str,
    description: str,
    config: BaseConfig,
    app_config: AppConfig,
    chat_completion_builder: ChatCompletionBuilder,
    task_store: TaskStore,
    state_manager: StateManager,
) -> APIRouter

DEPRECATION NOTICE: A2A (Agent-to-Agent) routes are being deprecated as part of the framework migration evaluation. This method is maintained for backward compatibility only. New development should avoid using A2A functionality.

Source code in src/sk_agents/routes.py
@staticmethod
def get_a2a_routes(
    name: str,
    version: str,
    description: str,
    config: BaseConfig,
    app_config: AppConfig,
    chat_completion_builder: ChatCompletionBuilder,
    task_store: TaskStore,
    state_manager: StateManager,
) -> APIRouter:
    """
    DEPRECATION NOTICE: A2A (Agent-to-Agent) routes are being deprecated
    as part of the framework migration evaluation. This method is maintained for
    backward compatibility only. New development should avoid using A2A functionality.

    Builds a router exposing the A2A JSON-RPC endpoint (POST "") and the agent
    card (GET "/.well-known/agent.json"), both delegated to an
    A2AStarletteApplication.
    """
    card = Routes.get_agent_card(config, app_config)
    request_handler = Routes.get_request_handler(
        config, app_config, chat_completion_builder, state_manager, task_store
    )
    a2a_app = A2AStarletteApplication(agent_card=card, http_handler=request_handler)

    router = APIRouter()

    # Endpoint docstrings are templated by docstring_parameter and become the
    # OpenAPI descriptions, so their text is kept verbatim.
    @router.post("")
    @docstring_parameter(description)
    async def handle_a2a(request: Request):
        """
        {0}

        Agent-to-Agent Invocation
        """
        response = await a2a_app._handle_requests(request)
        return response

    @router.get("/.well-known/agent.json")
    @docstring_parameter(f"{name}:{version} - {description}")
    async def handle_get_agent_card(request: Request):
        """
        Retrieve agent card for {0}
        """
        response = await a2a_app._handle_get_agent_card(request)
        return response

    return router
sk_agents.routes.Routes.get_stateful_routes staticmethod
get_stateful_routes(
    name: str,
    version: str,
    description: str,
    config: BaseConfig,
    app_config: AppConfig,
    state_manager: TaskPersistenceManager,
    authorizer: RequestAuthorizer,
    auth_storage_manager: SecureAuthStorageManager,
    mcp_discovery_manager=None,
    input_class: type[UserMessage] = UserMessage,
) -> APIRouter

Get the stateful API routes for the given configuration.

Source code in src/sk_agents/routes.py
@staticmethod
def get_stateful_routes(
    name: str,
    version: str,
    description: str,
    config: BaseConfig,
    app_config: AppConfig,
    state_manager: TaskPersistenceManager,
    authorizer: RequestAuthorizer,
    auth_storage_manager: SecureAuthStorageManager,
    mcp_discovery_manager=None,  # McpStateManager - Optional
    input_class: type[UserMessage] = UserMessage,
) -> APIRouter:
    """
    Get the stateful API routes for the given configuration.

    The single POST "" route authenticates the caller via ``authorizer``,
    invokes the agent and reports the task as PAUSED when the handler returns
    a HitlResponse, otherwise as COMPLETED.
    """
    router = APIRouter()

    # FastAPI dependency: resolve the user ID from the Authorization header via
    # the configured authorizer, or reject with 401.
    async def get_user_id(authorization: str = Header(None)):
        user_id = await authorizer.authorize_request(authorization)
        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
            )
        return user_id

    @router.post(
        "",
        response_model=StateResponse,
        summary="Send a message to the agent",
        response_description="Agent response with state identifiers",
        tags=["Agent"],
    )
    async def chat(message: input_class, user_id: str = Depends(get_user_id)) -> StateResponse:
        # Handle new task creation or task retrieval.
        # NOTE(review): user_id is passed in get_task_handler's `authorization`
        # slot — confirm this is intended rather than the raw header.
        teal_handler = Routes.get_task_handler(
            config, app_config, user_id, state_manager, mcp_discovery_manager
        )
        response_content = await teal_handler.invoke(user_id, message)
        # Named task_status so it does not shadow the imported `status` module.
        task_status = TaskStatus.COMPLETED.value
        if isinstance(response_content, HitlResponse):
            task_status = TaskStatus.PAUSED.value
        return StateResponse(
            session_id=response_content.session_id,
            task_id=response_content.task_id,
            request_id=response_content.request_id,
            status=task_status,
            content=response_content,  # Replace with actual response
        )

    return router
sk_agents.routes.Routes.get_oauth_callback_routes staticmethod
get_oauth_callback_routes(
    config: BaseConfig, app_config: AppConfig
) -> APIRouter

Get OAuth 2.1 callback routes for MCP server authentication.

This route handles the OAuth redirect callback after user authorization.

Source code in src/sk_agents/routes.py
@staticmethod
def get_oauth_callback_routes(
    config: BaseConfig,
    app_config: AppConfig,
) -> APIRouter:
    """
    Get OAuth 2.1 callback routes for MCP server authentication.

    This route handles the OAuth redirect callback after user authorization.

    Args:
        config: Agent configuration. MCP server definitions are read from
            ``config.spec.agent.mcp_servers`` when present.
        app_config: Application configuration (not used by the callback route).

    Returns:
        An ``APIRouter`` exposing ``GET /oauth/callback``.
    """
    router = APIRouter()

    @router.get("/oauth/callback")
    async def oauth_callback(
        code: str,
        state: str,
    ):
        """
        Handle OAuth 2.1 callback from authorization server.

        Validates state, exchanges code for tokens, and stores in AuthStorage.

        Args:
            code: Authorization code from auth server
            state: CSRF state parameter

        Returns:
            Success response with server name and token metadata
        """
        from sk_agents.auth.oauth_client import OAuthClient
        from sk_agents.auth.oauth_state_manager import OAuthStateManager

        try:
            oauth_client = OAuthClient()
            state_manager = OAuthStateManager()

            # The state parameter alone identifies the pending flow, so the user ID
            # is recovered from stored flow state rather than supplied by the caller.
            try:
                flow_state = state_manager.retrieve_flow_state_by_state_only(state)
            except ValueError as e:
                logger.warning(f"Invalid OAuth state in callback: {e}")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired state parameter",
                ) from e

            user_id = flow_state.user_id
            server_name = flow_state.server_name

            # Tolerate configs missing `spec` or `spec.agent` instead of failing
            # with an AttributeError.
            agent_spec = getattr(getattr(config, "spec", None), "agent", None)
            mcp_servers = getattr(agent_spec, "mcp_servers", None)
            if not mcp_servers:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="No MCP servers configured",
                )

            server_config = next(
                (server for server in mcp_servers if server.name == server_name), None
            )
            if server_config is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"MCP server '{server_name}' not found in configuration",
                )

            # Validate state, exchange the code for tokens, and store them.
            oauth_data = await oauth_client.handle_callback(
                code=code, state=state, user_id=user_id, server_config=server_config
            )

            logger.info(f"OAuth callback successful for user={user_id}, server={server_name}")

            # NOTE(review): expires_at is presumably None when the token response
            # omitted `expires_in` (optional per RFC 6749 §5.1). Guard rather than crash.
            expires_at = oauth_data.expires_at
            return {
                "status": "success",
                "message": f"Successfully authenticated to {server_name}",
                "server_name": server_name,
                "scopes": oauth_data.scopes,
                "expires_at": expires_at.isoformat() if expires_at is not None else None,
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.exception(f"Error in OAuth callback: {e}")
            # This endpoint is reachable without authentication. Keep internal error
            # text out of the response; the full exception is logged above.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="OAuth callback failed",
            ) from e

    return router
sk_agents.ska_types
sk_agents.ska_types.BaseMultiModalInputWithUserContext

Bases: KernelBaseModel

The history of a chat interaction between an automated assistant and a human with multimodal input (text and images), along with context about the user.

Source code in src/sk_agents/ska_types.py
class BaseMultiModalInputWithUserContext(KernelBaseModel):
    """The history of a chat interaction between an automated assistant and a
    human with multimodal input (text and images), along with context about the user."""

    # Optional session identifier; None when the caller does not supply one.
    session_id: str | None = None
    # Prior messages, each a HistoryMultiModalMessage; None when omitted.
    chat_history: list[HistoryMultiModalMessage] | None = None
    # Free-form string-to-string context about the user; None when omitted.
    user_context: dict[str, str] | None = None
sk_agents.ska_types.HistoryMessage

Bases: BaseModel

A single interaction in a chat history.
'role' - Either 'user' (requestor) or 'assistant' (responder) indicating who sent the message.
'content' - The content of the message

Source code in src/sk_agents/ska_types.py
class HistoryMessage(BaseModel):
    """A single interaction in a chat history.<br/>
    'role' - Either 'user' (requestor) or 'assistant' (responder) indicating
    who sent the message.<br/>
    'content' - The content of the message"""

    # Sender of the message; validation rejects any value other than these two literals.
    role: Literal["user", "assistant"]
    # Plain-text body of the message.
    content: str
sk_agents.ska_types.BaseInput

Bases: KernelBaseModel

The history of a chat interaction between an automated assistant and a human.

Source code in src/sk_agents/ska_types.py
class BaseInput(KernelBaseModel):
    """The history of a chat interaction between an automated assistant and a
    human."""

    # Text-only prior messages; None when the caller sends no history.
    chat_history: list[HistoryMessage] | None = None
sk_agents.ska_types.BaseInputWithUserContext

Bases: KernelBaseModel

The history of a chat interaction between an automated assistant and a human, along with context about the user.

Source code in src/sk_agents/ska_types.py
class BaseInputWithUserContext(KernelBaseModel):
    """The history of a chat interaction between an automated assistant and a
    human, along with context about the user."""

    # Text-only prior messages; None when the caller sends no history.
    chat_history: list[HistoryMessage] | None = None
    # Free-form string-to-string context about the user; None when omitted.
    user_context: dict[str, str] | None = None
sk_agents.state
sk_agents.state.redis_state_manager

Redis implementation of the StateManager interface. This implementation uses Redis as the persistent store for task state management.

sk_agents.state.redis_state_manager.RedisStateManager

Bases: StateManager

Redis implementation of the StateManager interface.

This class provides Redis-based persistence for task state management.

Source code in src/sk_agents/state/redis_state_manager.py
class RedisStateManager(StateManager):
    """Redis implementation of the StateManager interface.

    This class provides Redis-based persistence for task state management.
    Each task's message history is a Redis list stored at
    ``{key_prefix}{task_id}:messages``. Its cancellation flag is a string key
    stored at ``{key_prefix}{task_id}:canceled``.
    """

    def __init__(
        self,
        redis_client: Redis,
        ttl: int | None = None,
        key_prefix: str = "task_state:",
    ):
        """Initialize the RedisStateManager with a Redis client.

        Args:
            redis_client: An instance of Redis client
            ttl: Optional expiry in seconds for the keys this manager writes.
                A falsy value (None or 0) means keys never expire.
            key_prefix: Prefix used for Redis keys (default: "task_state:")
        """
        self._redis = redis_client
        self._key_prefix = key_prefix
        self._ttl = ttl

    def _get_message_key(self, task_id: str) -> str:
        """Generate a Redis key for a task's messages.

        Args:
            task_id: The ID of the task

        Returns:
            A Redis key string for the task's messages
        """
        return f"{self._key_prefix}{task_id}:messages"

    def _get_canceled_key(self, task_id: str) -> str:
        """Generate a Redis key for a task's canceled status.

        Args:
            task_id: The ID of the task

        Returns:
            A Redis key string for the task's canceled status
        """
        return f"{self._key_prefix}{task_id}:canceled"

    async def update_task_messages(
        self, task_id: str, new_message: HistoryMultiModalMessage
    ) -> list[HistoryMultiModalMessage]:
        """Updates the messages for a specific task.

        Appends a new message to the task's message history and returns
        the complete list of messages.

        Args:
            task_id: The ID of the task
            new_message: The new message to add to the task's history

        Returns:
            The complete list of messages for the task
        """
        message_key = self._get_message_key(task_id)

        # mode="json" ensures enums and other rich types serialize to JSON-safe values.
        message_json = json.dumps(new_message.model_dump(mode="json"))

        await self._redis.rpush(message_key, message_json)
        # Reset the list's expiry after every append, but only when a TTL is configured.
        if self._ttl:
            await self._redis.expire(message_key, int(self._ttl))

        message_jsons = await self._redis.lrange(message_key, 0, -1)

        messages = [
            HistoryMultiModalMessage.model_validate(json.loads(msg)) for msg in message_jsons
        ]

        return messages

    async def set_canceled(self, task_id: str) -> None:
        """Marks a task as canceled.

        Args:
            task_id: The ID of the task to mark as canceled
        """
        # Treat a falsy TTL as "no expiry" and coerce it to int, the same way
        # update_task_messages does. redis-py rejects non-integer `ex` values,
        # and Redis rejects an expiry of 0.
        expiry = int(self._ttl) if self._ttl else None
        await self._redis.set(self._get_canceled_key(task_id), "1", ex=expiry)

    async def is_canceled(self, task_id: str) -> bool:
        """Checks if a task is marked as canceled.

        Args:
            task_id: The ID of the task to check

        Returns:
            True if the task is canceled, False otherwise
        """
        canceled = await self._redis.get(self._get_canceled_key(task_id))
        # Clients created without decode_responses=True return bytes, so
        # normalize before comparing.
        if isinstance(canceled, bytes):
            canceled = canceled.decode()
        return canceled == "1"
sk_agents.state.redis_state_manager.RedisStateManager.__init__
__init__(
    redis_client: Redis,
    ttl: int | None = None,
    key_prefix: str = "task_state:",
)

Initialize the RedisStateManager with a Redis client.

Parameters:

Name Type Description Default
redis_client Redis

An instance of Redis client

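required
ttl int | None

Optional expiry in seconds for the keys this manager writes; None means keys never expire

None
key_prefix str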

Prefix used for Redis keys (default: "task_state:")

'task_state:'
Source code in src/sk_agents/state/redis_state_manager.py
def __init__(
    self,
    redis_client: Redis,
    ttl: int | None = None,
    key_prefix: str = "task_state:",
):
    """Create a state manager backed by the given Redis client.

    Args:
        redis_client: Async Redis client used for every read and write.
        ttl: Optional expiry in seconds for the keys this manager writes;
            None means keys never expire.
        key_prefix: String prepended to every key (default: "task_state:").
    """
    self._ttl = ttl
    self._key_prefix = key_prefix
    self._redis = redis_client
sk_agents.state.redis_state_manager.RedisStateManager.update_task_messages async
update_task_messages(
    task_id: str, new_message: HistoryMultiModalMessage
) -> list[HistoryMultiModalMessage]

Updates the messages for a specific task.

Appends a new message to the task's message history and returns the complete list of messages.

Parameters:

Name Type Description Default
task_id str

The ID of the task

required
new_message HistoryMultiModalMessage

The new message to add to the task's history

required

Returns:

Type Description
list[HistoryMultiModalMessage]

The complete list of messages for the task

Source code in src/sk_agents/state/redis_state_manager.py
async def update_task_messages(
    self, task_id: str, new_message: HistoryMultiModalMessage
) -> list[HistoryMultiModalMessage]:
    """Append ``new_message`` to a task's history and return the full history.

    Args:
        task_id: ID of the task whose history is extended.
        new_message: Message pushed onto the end of the stored list.

    Returns:
        Every stored message for the task, in list order, including the new one.
    """
    key = self._get_message_key(task_id)

    # mode="json" turns enums and similar rich types into JSON-safe values.
    payload = new_message.model_dump(mode="json")
    await self._redis.rpush(key, json.dumps(payload))

    # The expiry is reset after each append, only when a TTL is configured.
    if self._ttl:
        await self._redis.expire(key, int(self._ttl))

    stored = await self._redis.lrange(key, 0, -1)
    return [HistoryMultiModalMessage.model_validate(json.loads(raw)) for raw in stored]
sk_agents.state.redis_state_manager.RedisStateManager.set_canceled async
set_canceled(task_id: str) -> None

Marks a task as canceled.

Parameters:

Name Type Description Default
task_id str

The ID of the task to mark as canceled

required
Source code in src/sk_agents/state/redis_state_manager.py
async def set_canceled(self, task_id: str) -> None:
    """Marks a task as canceled.

    Writes the flag value "1" under the task's canceled key. An expiry is
    applied only when a truthy TTL is configured.

    Args:
        task_id: The ID of the task to mark as canceled
    """
    # Treat a falsy TTL as "no expiry" and coerce it to int, the same way
    # update_task_messages does. redis-py rejects non-integer `ex` values,
    # and Redis rejects an expiry of 0.
    expiry = int(self._ttl) if self._ttl else None
    await self._redis.set(self._get_canceled_key(task_id), "1", ex=expiry)
sk_agents.state.redis_state_manager.RedisStateManager.is_canceled async
is_canceled(task_id: str) -> bool

Checks if a task is marked as canceled.

Parameters:

Name Type Description Default
task_id str

The ID of the task to check

required

Returns:

Type Description
bool

True if the task is canceled, False otherwise

Source code in src/sk_agents/state/redis_state_manager.py
async def is_canceled(self, task_id: str) -> bool:
    """Checks if a task is marked as canceled.

    Args:
        task_id: The ID of the task to check

    Returns:
        True if the task's canceled flag is stored as "1", False otherwise
        (including when the flag is absent).
    """
    canceled = await self._redis.get(self._get_canceled_key(task_id))
    # Clients created without decode_responses=True return bytes, so
    # normalize before comparing.
    if isinstance(canceled, bytes):
        canceled = canceled.decode()
    return canceled == "1"
sk_agents.stateful
sk_agents.stateful.UserMessage

Bases: BaseModel

New input model for the tealagents/v1alpha1 API version. Unlike BaseMultiModalInput, chat history is maintained server-side.

Source code in src/sk_agents/stateful.py
class UserMessage(BaseModel):
    """
    New input model for the tealagents/v1alpha1 API version.
    Unlike BaseMultiModalInput, chat history is maintained server-side.
    """

    # Omit both IDs to start a new task; supply them to continue an existing one.
    session_id: UUID4 | None = None
    task_id: UUID4 | None = None
    # Content items of this message.
    items: list[MultiModalItem]

    @field_validator("session_id", "task_id", mode="before")
    @classmethod
    def validate_uuid(cls, v):
        """Coerce non-UUID inputs for the ID fields into ``uuid.UUID``.

        ``None`` and existing ``uuid.UUID`` values pass through unchanged.

        Raises:
            ValueError: If ``v`` cannot be parsed as a UUID. This covers
                malformed strings, non-string types such as int or list
                (AttributeError inside ``uuid.UUID``), and bytes
                (TypeError inside ``uuid.UUID``). Raising ValueError lets
                Pydantic report a ValidationError instead of leaking a raw exception.
        """
        if v is not None and not isinstance(v, uuid.UUID):
            try:
                return uuid.UUID(v)
            except (ValueError, AttributeError, TypeError) as err:
                raise ValueError(f"Invalid UUID format: {v}") from err
        return v
sk_agents.stateful.TaskState

Bases: BaseModel

Model for the state associated with a Task ID

Source code in src/sk_agents/stateful.py
class TaskState(BaseModel):
    """Model for the state associated with a Task ID"""

    task_id: UUID4
    session_id: UUID4
    user_id: str  # User identity for authorization
    messages: list[dict[str, Any]]  # Chat history and execution trace
    # Lifecycle status; a newly constructed task starts as RUNNING.
    status: TaskStatus = TaskStatus.RUNNING
    # Naive UTC timestamps (datetime.utcnow attaches no tzinfo).
    # NOTE(review): utcnow is deprecated since Python 3.12. Moving to aware
    # datetimes would also require changing the state managers that assign
    # datetime.utcnow() to updated_at, so naive and aware values are never mixed.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Free-form extra data attached to the task.
    metadata: dict[str, Any] = Field(default_factory=dict)
sk_agents.stateful.RequestState

Bases: BaseModel

Model for the state associated with a Request ID

Source code in src/sk_agents/stateful.py
class RequestState(BaseModel):
    """Model for the state associated with a Request ID"""

    request_id: UUID4
    # Task this request belongs to.
    task_id: UUID4
    # Lifecycle status; a newly constructed request starts as RUNNING.
    status: TaskStatus = TaskStatus.RUNNING
    # Naive UTC timestamps (datetime.utcnow attaches no tzinfo).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Free-form extra data attached to the request.
    metadata: dict[str, Any] = Field(default_factory=dict)
sk_agents.stateful.StateResponse

Bases: BaseModel

Response model including state identifiers

Source code in src/sk_agents/stateful.py
class StateResponse(BaseModel):
    """Response model including state identifiers"""

    # Identifiers the client sends back to continue the same session/task.
    session_id: UUID4
    task_id: UUID4
    request_id: UUID4
    # Outcome of the request (e.g. COMPLETED, or PAUSED for a HitlResponse).
    status: TaskStatus
    # The agent's payload; the concrete type depends on how the invocation ended.
    content: InvokeResponse | RejectedToolResponse | HitlResponse | TealAgentsResponse
sk_agents.stateful.StateManager

Bases: ABC

Abstract base class for state management

Source code in src/sk_agents/stateful.py
class StateManager(ABC):
    """Abstract base class for state management.

    Implementations persist TaskState and RequestState objects keyed by their IDs.
    """

    @abstractmethod
    async def create_task(self, session_id: UUID4 | None, user_id: str) -> tuple[UUID4, UUID4]:
        """Create a new task for ``user_id`` and return ``(session_id, task_id)``.

        The bundled implementations reuse ``session_id`` when given and
        generate a new one when it is None.
        """

    @abstractmethod
    async def get_task(self, task_id: UUID4) -> TaskState:
        """Return the TaskState stored for ``task_id``.

        The bundled implementations raise ValueError when it does not exist.
        """

    @abstractmethod
    async def update_task(self, task_state: TaskState) -> None:
        """Persist ``task_state`` under its task ID, refreshing ``updated_at``."""

    @abstractmethod
    async def create_request(self, task_id: UUID4) -> UUID4:
        """Create a new request belonging to ``task_id`` and return its request ID."""

    @abstractmethod
    async def get_request(self, request_id: UUID4) -> RequestState:
        """Return the RequestState stored for ``request_id``.

        The bundled implementations raise ValueError when it does not exist.
        """

    @abstractmethod
    async def update_request(self, request_state: RequestState) -> None:
        """Persist ``request_state`` under its request ID, refreshing ``updated_at``."""
sk_agents.stateful.StateManager.create_task abstractmethod async
create_task(
    session_id: UUID4 | None, user_id: str
) -> tuple[UUID4, UUID4]

Create a new task and return session_id and task_id

Source code in src/sk_agents/stateful.py
@abstractmethod
async def create_task(self, session_id: UUID4 | None, user_id: str) -> tuple[UUID4, UUID4]:
    """Create a new task for ``user_id`` and return ``(session_id, task_id)``.

    The bundled implementations reuse ``session_id`` when given and generate
    a new one when it is None.
    """
sk_agents.stateful.StateManager.get_task abstractmethod async
get_task(task_id: UUID4) -> TaskState

Get a task by ID

Source code in src/sk_agents/stateful.py
@abstractmethod
async def get_task(self, task_id: UUID4) -> TaskState:
    """Return the TaskState stored for ``task_id``.

    The bundled implementations raise ValueError when it does not exist.
    """
sk_agents.stateful.StateManager.update_task abstractmethod async
update_task(task_state: TaskState) -> None

Update a task state

Source code in src/sk_agents/stateful.py
@abstractmethod
async def update_task(self, task_state: TaskState) -> None:
    """Persist ``task_state`` under its task ID, refreshing ``updated_at``."""
sk_agents.stateful.StateManager.create_request abstractmethod async
create_request(task_id: UUID4) -> UUID4

Create a new request and return request_id

Source code in src/sk_agents/stateful.py
@abstractmethod
async def create_request(self, task_id: UUID4) -> UUID4:
    """Create a new request belonging to ``task_id`` and return its request ID."""
sk_agents.stateful.StateManager.get_request abstractmethod async
get_request(request_id: UUID4) -> RequestState

Get a request by ID

Source code in src/sk_agents/stateful.py
@abstractmethod
async def get_request(self, request_id: UUID4) -> RequestState:
    """Return the RequestState stored for ``request_id``.

    The bundled implementations raise ValueError when it does not exist.
    """
sk_agents.stateful.StateManager.update_request abstractmethod async
update_request(request_state: RequestState) -> None

Update a request state

Source code in src/sk_agents/stateful.py
@abstractmethod
async def update_request(self, request_state: RequestState) -> None:
    """Persist ``request_state`` under its request ID, refreshing ``updated_at``."""
sk_agents.stateful.InMemoryStateManager

Bases: StateManager

In-memory implementation of state manager

Source code in src/sk_agents/stateful.py
class InMemoryStateManager(StateManager):
    """State manager that keeps task and request state in process-local dicts.

    Nothing is persisted. All state lives in ``self.tasks`` and
    ``self.requests`` and disappears with the instance.
    """

    def __init__(self):
        # task_id -> TaskState
        self.tasks: dict[UUID4, TaskState] = {}
        # request_id -> RequestState
        self.requests: dict[UUID4, RequestState] = {}

    async def create_task(self, session_id: UUID4 | None, user_id: str) -> tuple[UUID4, UUID4]:
        """Register a new, empty task and return ``(session_id, task_id)``.

        A fresh session ID is generated when ``session_id`` is falsy.
        """
        resolved_session_id = session_id or uuid.uuid4()
        new_task_id = uuid.uuid4()
        self.tasks[new_task_id] = TaskState(
            task_id=new_task_id,
            session_id=resolved_session_id,
            user_id=user_id,
            messages=[],
        )
        return resolved_session_id, new_task_id

    async def get_task(self, task_id: UUID4) -> TaskState:
        """Return the stored task.

        Raises:
            ValueError: If ``task_id`` is unknown.
        """
        if task_id in self.tasks:
            return self.tasks[task_id]
        raise ValueError(f"Task not found: {task_id}")

    async def update_task(self, task_state: TaskState) -> None:
        """Stamp ``updated_at`` on ``task_state`` and store it under its task ID."""
        task_state.updated_at = datetime.utcnow()
        self.tasks[task_state.task_id] = task_state

    async def create_request(self, task_id: UUID4) -> UUID4:
        """Register a new request for ``task_id`` and return its generated ID."""
        new_request_id = uuid.uuid4()
        self.requests[new_request_id] = RequestState(request_id=new_request_id, task_id=task_id)
        return new_request_id

    async def get_request(self, request_id: UUID4) -> RequestState:
        """Return the stored request.

        Raises:
            ValueError: If ``request_id`` is unknown.
        """
        if request_id in self.requests:
            return self.requests[request_id]
        raise ValueError(f"Request not found: {request_id}")

    async def update_request(self, request_state: RequestState) -> None:
        """Stamp ``updated_at`` on ``request_state`` and store it under its request ID."""
        request_state.updated_at = datetime.utcnow()
        self.requests[request_state.request_id] = request_state
sk_agents.stateful.RedisStateManager

Bases: StateManager

Redis implementation of state manager

Source code in src/sk_agents/stateful.py
class RedisStateManager(StateManager):
    """Redis implementation of state manager.

    Task states are stored as JSON strings under ``task:{task_id}``.
    Request states are stored under ``request:{request_id}``. Every write
    applies the optional ``ttl``.
    """

    def __init__(self, redis_client: Redis, ttl: int | None = None):
        self.redis = redis_client
        self.ttl = ttl  # Time-to-live in seconds

    async def create_task(self, session_id: UUID4 | None, user_id: str) -> tuple[UUID4, UUID4]:
        """Create and persist an empty task, returning ``(session_id, task_id)``.

        A new session ID is generated when ``session_id`` is falsy.
        """
        session_id = session_id or uuid.uuid4()
        task_id = uuid.uuid4()
        task_state = TaskState(task_id=task_id, session_id=session_id, user_id=user_id, messages=[])
        await self._set_task(task_state)
        return session_id, task_id

    async def get_task(self, task_id: UUID4) -> TaskState:
        """Load a task state from Redis.

        Raises:
            ValueError: If nothing is stored under ``task_id``, for example
                because it never existed or its TTL elapsed.
        """
        # NOTE(review): the A2A RedisTaskStore also defaults to a "task:" key
        # prefix. If both share a Redis database, confirm their IDs can never coincide.
        key = f"task:{task_id}"
        data = await self.redis.get(key)
        if not data:
            raise ValueError(f"Task not found: {task_id}")
        # model_validate_json is the Pydantic v2 replacement for the deprecated parse_raw.
        return TaskState.model_validate_json(data)

    async def update_task(self, task_state: TaskState) -> None:
        """Stamp ``updated_at`` on ``task_state`` and persist it."""
        task_state.updated_at = datetime.utcnow()
        await self._set_task(task_state)

    async def _set_task(self, task_state: TaskState) -> None:
        """Write ``task_state`` as JSON under its task key, applying the TTL."""
        key = f"task:{task_state.task_id}"
        # model_dump_json is the Pydantic v2 replacement for the deprecated .json().
        await self.redis.set(key, task_state.model_dump_json(), ex=self.ttl)

    async def create_request(self, task_id: UUID4) -> UUID4:
        """Create and persist a request for ``task_id``, returning its new ID."""
        request_id = uuid.uuid4()
        request_state = RequestState(request_id=request_id, task_id=task_id)
        await self._set_request(request_state)
        return request_id

    async def get_request(self, request_id: UUID4) -> RequestState:
        """Load a request state from Redis.

        Raises:
            ValueError: If nothing is stored under ``request_id``, for example
                because it never existed or its TTL elapsed.
        """
        key = f"request:{request_id}"
        data = await self.redis.get(key)
        if not data:
            raise ValueError(f"Request not found: {request_id}")
        return RequestState.model_validate_json(data)

    async def update_request(self, request_state: RequestState) -> None:
        """Stamp ``updated_at`` on ``request_state`` and persist it."""
        request_state.updated_at = datetime.utcnow()
        await self._set_request(request_state)

    async def _set_request(self, request_state: RequestState) -> None:
        """Write ``request_state`` as JSON under its request key, applying the TTL."""
        key = f"request:{request_state.request_id}"
        await self.redis.set(key, request_state.model_dump_json(), ex=self.ttl)
sk_agents.stateful.AuthenticationManager

Bases: ABC

Abstract base class for authentication management

Source code in src/sk_agents/stateful.py
class AuthenticationManager(ABC):
    """Abstract base class for authentication management.

    Implementations turn a caller's token into a user ID and decide whether
    that user may access a given task.
    """

    @abstractmethod
    async def authorize_request(self, token: str) -> str:
        """Authenticate ``token`` and return the user ID it identifies."""
        pass

    @abstractmethod
    async def validate_task_access(self, task_id: UUID4, user_id: str) -> bool:
        """Return True if ``user_id`` may access the task ``task_id``."""
        pass
sk_agents.stateful.AuthenticationManager.authorize_request abstractmethod async
authorize_request(token: str) -> str

Authenticate a token and return the user ID

Source code in src/sk_agents/stateful.py
@abstractmethod
async def authorize_request(self, token: str) -> str:
    """Authenticate ``token`` and return the user ID it identifies."""
    pass
sk_agents.stateful.AuthenticationManager.validate_task_access abstractmethod async
validate_task_access(task_id: UUID4, user_id: str) -> bool

Validate if a user has access to a task

Source code in src/sk_agents/stateful.py
@abstractmethod
async def validate_task_access(self, task_id: UUID4, user_id: str) -> bool:
    """Return True if ``user_id`` may access the task ``task_id``."""
    pass
sk_agents.stateful.MockAuthenticationManager

Bases: AuthenticationManager

Mock implementation of authentication manager for development

Source code in src/sk_agents/stateful.py
class MockAuthenticationManager(AuthenticationManager):
    """Development-only authentication manager that trusts every caller.

    It performs no verification: any token is accepted and every task-access
    check succeeds. Do not use it where real authorization is required.
    """

    async def authorize_request(self, token: str) -> str:
        # No validation: the raw token is echoed back as the user ID. A missing
        # or empty token maps to a fixed placeholder identity. A real
        # implementation would validate the token (e.g. with Entra ID).
        if token:
            return token
        return "anonymous-user"

    async def validate_task_access(self, task_id: UUID4, user_id: str) -> bool:
        # Ownership is never checked; a real implementation would confirm
        # that user_id owns task_id.
        return True
sk_agents.tealagents
sk_agents.tealagents.kernel_builder
sk_agents.tealagents.kernel_builder.KernelBuilder
Source code in src/sk_agents/tealagents/kernel_builder.py
class KernelBuilder:
    """Assemble Semantic Kernel ``Kernel`` instances for an agent.

    A kernel is built from a chat-completion service (from ``ChatCompletionBuilder``),
    local plugins (from the plugin loader) and remote plugins (from
    ``RemotePluginLoader``). MCP plugins are loaded separately, in an async context,
    via :meth:`load_mcp_plugins`.
    """

    def __init__(
        self,
        chat_completion_builder: ChatCompletionBuilder,
        remote_plugin_loader: RemotePluginLoader,
        app_config: AppConfig,
        authorization: str | None = None,
    ):
        """
        Args:
            chat_completion_builder: Provides chat-completion services and model metadata.
            remote_plugin_loader: Loads remote plugins into a kernel.
            app_config: Application configuration; also used to create the auth
                storage manager and the request authorizer.
            authorization: Authorization value forwarded to the MCP plugins created
                by :meth:`load_mcp_plugins`.
        """
        self.chat_completion_builder: ChatCompletionBuilder = chat_completion_builder
        self.remote_plugin_loader = remote_plugin_loader
        self.app_config: AppConfig = app_config
        self.authorization = authorization
        self.logger = logging.getLogger(__name__)

        # Initialize auth storage and authorizer for token cache functionality
        self.auth_storage_manager: SecureAuthStorageManager = AuthStorageFactory(
            app_config
        ).get_auth_storage_manager()
        self.authorizer: RequestAuthorizer = AuthorizerFactory(app_config).get_authorizer()

    async def build_kernel(
        self,
        model_name: str,
        service_id: str,
        plugins: list[str],
        remote_plugins: list[str],
        mcp_servers: list[McpServerConfig] | None = None,
        authorization: str | None = None,
        extra_data_collector: ExtraDataCollector | None = None,
        user_id: str | None = None,
    ) -> Kernel:
        """Create a kernel with its chat service, local plugins and remote plugins.

        Args:
            model_name: Model to obtain a chat-completion service for.
            service_id: Service ID passed to the chat-completion builder.
            plugins: Names of local plugins to instantiate and add.
            remote_plugins: Names of remote plugins to load.
            mcp_servers: Not used here; MCP plugins are loaded by :meth:`load_mcp_plugins`.
            authorization: Passed to each local plugin's constructor.
            extra_data_collector: Passed to each local plugin's constructor.
            user_id: Not used here.

        Returns:
            The configured kernel.

        Raises:
            Exception: Any error from the underlying steps is logged and re-raised.
        """
        try:
            kernel = self._create_base_kernel(model_name, service_id)
            kernel = self._parse_plugins(plugins, kernel, authorization, extra_data_collector)
            kernel = self._load_remote_plugins(remote_plugins, kernel)

            # MCP plugins will be loaded separately in async context by handler
            # Remove sync MCP loading to avoid event loop conflicts

            return kernel
        except Exception as e:
            self.logger.exception(f"Could not build kernel with service ID {service_id}. - {e}")
            raise

    def get_model_type_for_name(self, model_name: str) -> ModelType:
        """Return the model type the chat-completion builder reports for ``model_name``.

        Raises:
            Exception: Any error from the builder is logged and re-raised.
        """
        try:
            return self.chat_completion_builder.get_model_type_for_name(model_name)
        except Exception as e:
            self.logger.exception(f"Could not get model type for {model_name}. - {e}")
            raise

    def model_supports_structured_output(self, model_name: str) -> bool:
        """Return whether ``model_name`` supports structured output (delegates to the builder)."""
        return self.chat_completion_builder.model_supports_structured_output(model_name)

    def _create_base_kernel(self, model_name: str, service_id: str) -> Kernel:
        """Create an empty kernel holding only the chat-completion service for the model."""
        try:
            chat_completion = self.chat_completion_builder.get_chat_completion_for_model(
                service_id=service_id,
                model_name=model_name,
            )

            kernel = Kernel()
            kernel.add_service(chat_completion)

            return kernel
        except Exception as e:
            self.logger.exception(
                f"Could not create base kernel with service id {service_id}. - {e}"
            )
            raise

    def _load_remote_plugins(self, remote_plugins: list[str], kernel: Kernel) -> Kernel:
        """Load ``remote_plugins`` into ``kernel``; no-op when the list is None or empty."""
        if not remote_plugins:
            return kernel
        try:
            self.remote_plugin_loader.load_remote_plugins(kernel, remote_plugins)
            return kernel
        except Exception as e:
            self.logger.exception(f"Could not load remote plugins. - {e}")
            raise

    async def load_mcp_plugins(
        self,
        kernel: Kernel,
        user_id: str,
        session_id: str,
        mcp_discovery_manager,
        connection_manager,
    ) -> Kernel:
        """
        Load MCP plugins by instantiating McpPlugin directly with tools from storage.

        This loads tools discovered at session start and creates McpPlugin instances
        for each MCP server. Only tools that the user has authenticated to access
        will be loaded, ensuring proper multi-tenant isolation at the session level.

        Each plugin is registered as ``mcp_<server_name>``. Every character of the
        server name outside ``[0-9A-Za-z_]`` is replaced with ``_``, because SK rejects
        other plugin-name characters.

        Args:
            kernel: The kernel to add plugins to
            user_id: User ID to get plugins for (required)
            session_id: Session ID for plugin isolation (required)
            mcp_discovery_manager: Discovery manager for loading tool state (required)
            connection_manager: Request-scoped connection manager for connection reuse (required)

        Returns:
            The kernel with session's MCP plugins loaded

        Raises:
            ValueError: If any required argument is missing/falsy.
            Exception: Any error while loading is logged and re-raised.

        Note: MCP tools must be discovered first via McpPluginRegistry.discover_and_materialize()
        before calling this method.
        """
        if not user_id:
            raise ValueError("user_id is required when loading MCP plugins")
        if not session_id:
            raise ValueError("session_id is required when loading MCP plugins")
        if not mcp_discovery_manager:
            raise ValueError("mcp_discovery_manager is required when loading MCP plugins")
        if not connection_manager:
            raise ValueError("connection_manager is required when loading MCP plugins")

        try:
            import re

            from sk_agents.mcp_client import McpPlugin
            from sk_agents.mcp_plugin_registry import McpPluginRegistry

            # Get tools for THIS session (session-level isolation)
            server_tools = await McpPluginRegistry.get_tools_for_session(
                user_id, session_id, mcp_discovery_manager
            )

            if not server_tools:
                self.logger.debug(f"No MCP tools found for user {user_id}, session {session_id}")
                return kernel

            # Instantiate McpPlugin directly for each server
            for server_name, tools in server_tools.items():
                plugin_instance = McpPlugin(
                    tools=tools,
                    server_name=server_name,
                    user_id=user_id,
                    connection_manager=connection_manager,
                    authorization=self.authorization,
                    extra_data_collector=None,
                )

                # SK requires plugin names to match ^[0-9A-Za-z_]+; replace every other
                # character (not only "-" and ".") so any server name can be registered.
                sanitized_server_name = re.sub(r"[^0-9A-Za-z_]", "_", server_name)
                kernel.add_plugin(plugin_instance, f"mcp_{sanitized_server_name}")
                self.logger.info(
                    f"Loaded MCP plugin for {server_name} as mcp_{sanitized_server_name} "
                    f"(user: {user_id}, session: {session_id})"
                )

            self.logger.info(
                f"Loaded {len(server_tools)} MCP plugins for user {user_id}, session {session_id}"
            )
            return kernel

        except Exception as e:
            self.logger.exception(
                f"Could not load MCP plugins for user {user_id}, session {session_id}. - {e}"
            )
            raise

    def _parse_plugins(
        self,
        plugin_names: list[str],
        kernel: Kernel,
        authorization: str | None = None,
        extra_data_collector: ExtraDataCollector | None = None,
    ) -> Kernel:
        """Instantiate the named local plugins and add them to ``kernel``.

        Each plugin class is constructed as ``plugin_class(authorization,
        extra_data_collector)`` and registered under its plugin name. Returns
        ``kernel`` unchanged when ``plugin_names`` is None or empty.
        """
        if not plugin_names:
            return kernel

        plugins = get_plugin_loader().get_plugins(plugin_names)

        for plugin_name, plugin_class in plugins.items():
            # Non-MCP plugins receive the request's authorization as-is
            # (MCP plugins handle auth differently via user_id).
            kernel.add_plugin(plugin_class(authorization, extra_data_collector), plugin_name)

        return kernel

    async def _get_plugin_authorization(
        self, plugin_name: str, original_authorization: str | None = None
    ) -> str | None:
        """
        Get plugin-specific authorization, checking token cache for stored OAuth2 tokens.

        NOTE(review): not called by any other method of this class.

        Args:
            plugin_name: Name of the plugin requesting authorization
            original_authorization: Original authorization header from the request

        Returns:
            ``"Bearer <access_token>"`` when a cached entry with an ``access_token``
            attribute exists for the user and plugin; ``original_authorization`` when
            no user ID could be extracted from it; otherwise ``None`` (including when
            ``original_authorization`` is empty or any error occurs).
        """
        if not original_authorization:
            return None

        try:
            # Extract user ID from the authorization header
            user_id = await self.authorizer.authorize_request(original_authorization)
            if not user_id:
                self.logger.warning(
                    f"Could not extract user ID from authorization for plugin {plugin_name}"
                )
                return original_authorization

            # Try to retrieve cached OAuth2 tokens for this user and plugin
            cached_auth_data = self.auth_storage_manager.retrieve(user_id, plugin_name)

            if cached_auth_data and hasattr(cached_auth_data, "access_token"):
                self.logger.info(f"Using cached token for plugin {plugin_name}, user {user_id}")
                # Return the cached access token in Bearer format
                return f"Bearer {cached_auth_data.access_token}"

            self.logger.debug(
                f"No cached tokens found for plugin {plugin_name}, user {user_id} - "
                f"returning None"
            )
            return None

        except Exception as e:
            # Best-effort lookup: failures fall back to "no plugin authorization".
            self.logger.warning(
                f"Error retrieving cached tokens for plugin {plugin_name}: {e} - returning None"
            )
            return None
sk_agents.tealagents.kernel_builder.KernelBuilder.load_mcp_plugins async
load_mcp_plugins(
    kernel: Kernel,
    user_id: str,
    session_id: str,
    mcp_discovery_manager,
    connection_manager,
) -> Kernel

Load MCP plugins by instantiating McpPlugin directly with tools from storage.

This loads tools discovered at session start and creates McpPlugin instances for each MCP server. Only tools that the user has authenticated to access will be loaded, ensuring proper multi-tenant isolation at the session level.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `kernel` | `Kernel` | The kernel to add plugins to | required |
| `user_id` | `str` | User ID to get plugins for (required) | required |
| `session_id` | `str` | Session ID for plugin isolation (required) | required |
| `mcp_discovery_manager` | | Discovery manager for loading tool state (required) | required |
| `connection_manager` | | Request-scoped connection manager for connection reuse (required) | required |

Returns:

| Type | Description |
| --- | --- |
| `Kernel` | The kernel with the session's MCP plugins loaded |

Note: MCP tools must be discovered first via McpPluginRegistry.discover_and_materialize() before calling this method.

Source code in src/sk_agents/tealagents/kernel_builder.py
async def load_mcp_plugins(
    self,
    kernel: Kernel,
    user_id: str,
    session_id: str,
    mcp_discovery_manager,
    connection_manager,
) -> Kernel:
    """
    Load MCP plugins by instantiating McpPlugin directly with tools from storage.

    This loads tools discovered at session start and creates McpPlugin instances
    for each MCP server. Only tools that the user has authenticated to access
    will be loaded, ensuring proper multi-tenant isolation at the session level.

    Each plugin is registered as ``mcp_<server_name>``. Every character of the
    server name outside ``[0-9A-Za-z_]`` is replaced with ``_``, because SK rejects
    other plugin-name characters.

    Args:
        kernel: The kernel to add plugins to
        user_id: User ID to get plugins for (required)
        session_id: Session ID for plugin isolation (required)
        mcp_discovery_manager: Discovery manager for loading tool state (required)
        connection_manager: Request-scoped connection manager for connection reuse (required)

    Returns:
        The kernel with session's MCP plugins loaded

    Raises:
        ValueError: If any required argument is missing/falsy.
        Exception: Any error while loading is logged and re-raised.

    Note: MCP tools must be discovered first via McpPluginRegistry.discover_and_materialize()
    before calling this method.
    """
    if not user_id:
        raise ValueError("user_id is required when loading MCP plugins")
    if not session_id:
        raise ValueError("session_id is required when loading MCP plugins")
    if not mcp_discovery_manager:
        raise ValueError("mcp_discovery_manager is required when loading MCP plugins")
    if not connection_manager:
        raise ValueError("connection_manager is required when loading MCP plugins")

    try:
        import re

        from sk_agents.mcp_client import McpPlugin
        from sk_agents.mcp_plugin_registry import McpPluginRegistry

        # Get tools for THIS session (session-level isolation)
        server_tools = await McpPluginRegistry.get_tools_for_session(
            user_id, session_id, mcp_discovery_manager
        )

        if not server_tools:
            self.logger.debug(f"No MCP tools found for user {user_id}, session {session_id}")
            return kernel

        # Instantiate McpPlugin directly for each server
        for server_name, tools in server_tools.items():
            plugin_instance = McpPlugin(
                tools=tools,
                server_name=server_name,
                user_id=user_id,
                connection_manager=connection_manager,
                authorization=self.authorization,
                extra_data_collector=None,
            )

            # SK requires plugin names to match ^[0-9A-Za-z_]+; replace every other
            # character (not only "-" and ".") so any server name can be registered.
            sanitized_server_name = re.sub(r"[^0-9A-Za-z_]", "_", server_name)
            kernel.add_plugin(plugin_instance, f"mcp_{sanitized_server_name}")
            self.logger.info(
                f"Loaded MCP plugin for {server_name} as mcp_{sanitized_server_name} "
                f"(user: {user_id}, session: {session_id})"
            )

        self.logger.info(
            f"Loaded {len(server_tools)} MCP plugins for user {user_id}, session {session_id}"
        )
        return kernel

    except Exception as e:
        self.logger.exception(
            f"Could not load MCP plugins for user {user_id}, session {session_id}. - {e}"
        )
        raise
sk_agents.tealagents.models
sk_agents.tealagents.models.AuthChallengeResponse

Bases: BaseModel

Response when MCP server authentication is required before agent construction.

Source code in src/sk_agents/tealagents/models.py
class AuthChallengeResponse(BaseModel):
    """Response when MCP server authentication is required before agent construction.

    Carries the identifiers of the request that triggered the challenge, one
    challenge entry per MCP server that still needs authorization, and the URL the
    client should call to resume the agent flow once authorization is complete.
    """

    # Identifiers of the task, session and request that triggered the challenge.
    task_id: str
    session_id: str
    request_id: str
    message: str = "Authentication required for MCP servers."
    # NOTE(review): entries built by TealAgentsV1Alpha1Handler use the keys
    # "server_name", "auth_server", "scopes" and "auth_url"; the shape is not
    # enforced by this model (plain ``dict``).
    auth_challenges: list[dict]  # List of auth challenge details per server
    resume_url: str  # URL to resume agent flow after auth completion
sk_agents.tealagents.models.TaskStatus

Bases: Enum

Enum representing the status of a task

Source code in src/sk_agents/tealagents/models.py
class TaskStatus(Enum):
    """Enum representing the status of a task.

    Member values are the human-readable status strings (e.g. ``"Running"``).
    """

    RUNNING = "Running"
    PAUSED = "Paused"
    COMPLETED = "Completed"
    FAILED = "Failed"
    # NOTE(review): TealAgentsV1Alpha1Handler._configure_agent_task also accepts a
    # "Canceled" status literal, which has no member here. Confirm whether one is needed.
sk_agents.tealagents.v1alpha1
sk_agents.tealagents.v1alpha1.agent
sk_agents.tealagents.v1alpha1.agent.handler
sk_agents.tealagents.v1alpha1.agent.handler.TealAgentsV1Alpha1Handler

Bases: BaseHandler

Source code in src/sk_agents/tealagents/v1alpha1/agent/handler.py
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
class TealAgentsV1Alpha1Handler(BaseHandler):
    def __init__(
        self,
        config: BaseConfig,
        app_config: AppConfig,
        agent_builder: AgentBuilder,
        state_manager: TaskPersistenceManager,
        discovery_manager=None,  # McpStateManager - Optional, only needed for MCP
    ):
        """Wire the handler up from its configuration and collaborators.

        Args:
            config: Agent configuration; must expose a ``spec`` attribute.
            app_config: Application-level configuration.
            agent_builder: Builder used to construct agents.
            state_manager: Persistence manager for agent tasks.
            discovery_manager: Optional MCP discovery state manager; only needed
                when MCP servers are configured.

        Raises:
            ValueError: If ``config`` has no ``spec`` attribute.
        """
        self.version = config.version
        self.name = config.name
        self.app_config = app_config

        if not hasattr(config, "spec"):
            raise ValueError("Invalid config")
        self.config = Config(config=config)

        self.agent_builder = agent_builder
        self.state = state_manager
        self.authorizer = DummyAuthorizer()
        self.discovery_manager = discovery_manager

        # Session IDs that have already been shown the MCP auth status message,
        # so that message appears at most once per session.
        self._mcp_status_shown_per_session: set[str] = set()

    async def _create_mcp_connection_manager(self, user_id: str, session_id: str):
        """
        Build a request-scoped MCP connection manager when MCP is in use.

        The returned manager provides:
        - Lazy connection establishment (connect on first tool call per server)
        - Connection reuse within the request (all tools share connections)
        - Automatic cleanup at request end
        - Session ID persistence for cross-request continuity

        Args:
            user_id: User ID for authentication
            session_id: Session ID for session-level scoping

        Returns:
            An McpConnectionManager, or None when no MCP servers are configured, no
            discovery manager is available, or construction fails (a warning is logged).
        """
        servers = self.config.get_agent().mcp_servers
        if not (servers and self.discovery_manager):
            return None

        try:
            from sk_agents.mcp_client import McpConnectionManager

            # The manager looks servers up by name.
            configs_by_name = {cfg.name: cfg for cfg in servers}
            manager = McpConnectionManager(
                server_configs=configs_by_name,
                user_id=user_id,
                session_id=session_id,
                state_manager=self.discovery_manager,
                app_config=self.app_config,
            )
        except Exception as e:
            logger.warning(
                f"Failed to create MCP connection manager: {e}. "
                "Falling back to per-tool connections."
            )
            return None
        return manager

    async def _ensure_session_discovery(
        self, user_id: str, session_id: str, task_id: str, request_id: str
    ) -> AuthChallengeResponse | None:
        """
        Ensure MCP tool discovery has been performed for this session.

        Discovery happens once per (user_id, session_id) when first detected.
        All tasks in the session share the discovered tools.

        Flow:
        1. No discovery manager -> nothing to do.
        2. Discovery already marked completed -> nothing to do.
        3. Create the discovery state if none exists yet.
        4. No MCP servers configured -> mark discovery completed.
        5. Otherwise discover and materialize tools, then mark discovery completed.
           If a server requires authentication, an auth challenge is returned and
           discovery is NOT marked completed, so a later call will attempt it again.

        Args:
            user_id: User ID for authentication
            session_id: Session ID for session-level scoping
            task_id: Task ID for auth challenge response
            request_id: Request ID for auth challenge response

        Returns:
            AuthChallengeResponse if authentication is required, otherwise None.

        Raises:
            Exception: Any non-authentication error raised during discovery is logged
                and re-raised.
        """
        # Early return if no discovery manager (no MCP servers configured)
        if not self.discovery_manager:
            return None

        # Check if discovery already completed for this session
        if await self.discovery_manager.is_completed(user_id, session_id):
            logger.debug(f"MCP discovery already completed for session: {session_id}")
            return None

        # Load or create discovery state
        discovery_state = await self.discovery_manager.load_discovery(user_id, session_id)
        if not discovery_state:
            from sk_agents.mcp_discovery.mcp_discovery_manager import McpState

            discovery_state = McpState(
                user_id=user_id,
                session_id=session_id,
                discovered_servers={},
                discovery_completed=False,
            )
            await self.discovery_manager.create_discovery(discovery_state)
            logger.info(f"Created discovery state for session: {session_id}")

        # Check if MCP servers configured
        mcp_servers = self.config.get_agent().mcp_servers
        if not mcp_servers:
            await self.discovery_manager.mark_completed(user_id, session_id)
            return None

        # Imported before the ``try`` on purpose. If this import lived inside the
        # ``try`` and failed, evaluating ``except AuthRequiredError`` below would raise
        # UnboundLocalError and mask the real import error.
        from sk_agents.mcp_client import AuthRequiredError

        try:
            from sk_agents.mcp_plugin_registry import McpPluginRegistry

            logger.info(
                f"Starting MCP discovery for session {session_id} ({len(mcp_servers)} servers)"
            )

            await McpPluginRegistry.discover_and_materialize(
                mcp_servers, user_id, session_id, self.discovery_manager, self.app_config
            )

            await self.discovery_manager.mark_completed(user_id, session_id)
            logger.info(f"MCP discovery completed for session {session_id}")
            return None

        except AuthRequiredError as e:
            # Auth required - return challenge
            logger.info(
                f"MCP discovery requires authentication for '{e.server_name}' "
                f"(session: {session_id})"
            )

            try:
                # Find server config
                server_config = next((s for s in mcp_servers if s.name == e.server_name), None)
                if not server_config:
                    raise ValueError(f"Server config not found for '{e.server_name}'")

                # Initiate OAuth 2.1 authorization flow with PKCE
                from sk_agents.auth.oauth_client import OAuthClient

                oauth_client = OAuthClient()

                # Generate authorization URL with PKCE
                auth_url = await oauth_client.initiate_authorization_flow(
                    server_config=server_config, user_id=user_id
                )

                logger.info(f"Generated OAuth authorization URL for {e.server_name}")

            except Exception as oauth_error:
                # Still return a challenge so the client learns which server needs auth.
                # The URL carries an error marker instead of a usable authorization link.
                logger.error(f"Failed to initiate OAuth flow: {oauth_error}")
                auth_url = f"{e.auth_server}/authorize?error=oauth_client_failed"

            return AuthChallengeResponse(
                task_id=task_id,
                session_id=session_id,
                request_id=request_id,
                message=f"Authentication required for MCP server '{e.server_name}'.",
                auth_challenges=[
                    {
                        "server_name": e.server_name,
                        "auth_server": e.auth_server,
                        "scopes": e.scopes,
                        "auth_url": auth_url,
                    }
                ],
                resume_url="/tealagents/v1alpha1/invoke",
            )

        except Exception as e:
            logger.error(f"MCP discovery failed for session {session_id}: {e}")
            raise

    @staticmethod
    async def _invoke_function(
        kernel: Kernel, fc_content: FunctionCallContent
    ) -> FunctionResultContent:
        """Execute one tool call on ``kernel`` and wrap the outcome as a result item.

        Args:
            kernel: Kernel that hosts the target plugin function.
            fc_content: The function call (plugin name, function name and arguments).

        Returns:
            A FunctionResultContent built from ``fc_content`` and the invocation result.
        """
        target = kernel.get_function(fc_content.plugin_name, fc_content.function_name)
        call_args = fc_content.to_kernel_arguments()
        outcome = await target.invoke(kernel, call_args)
        return FunctionResultContent.from_function_call_content_and_result(fc_content, outcome)

    @staticmethod
    def _augment_with_user_context(inputs: UserMessage, chat_history: ChatHistory) -> None:
        """Append the request's user context to ``chat_history`` as a user message.

        Each key/value pair becomes an indented ``key: value`` line under a fixed
        header line. Nothing is added when ``inputs.user_context`` is falsy.
        """
        context = inputs.user_context
        if not context:
            return
        lines = [f"  {key}: {value}\n" for key, value in context.items()]
        text = "The following user context was provided:\n" + "".join(lines)
        chat_history.add_message(
            ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text=text)])
        )

    @staticmethod
    def _configure_agent_task(
        session_id: str,
        user_id: str,
        task_id: str,
        role: Literal["user", "assistant"],
        request_id: str,
        inputs: UserMessage,
        status: Literal["Running", "Paused", "Completed", "Failed", "Canceled"],
    ) -> AgentTask:
        """Create a new ``AgentTask`` whose items are the items of ``inputs``.

        Every input item is wrapped in an ``AgentTaskItem`` tagged with ``task_id``,
        ``role`` and ``request_id``. Timestamps come from ``datetime.now()``
        (naive local time).
        """
        task_items = [
            AgentTaskItem(
                task_id=task_id,
                role=role,
                item=item,
                request_id=request_id,
                updated=datetime.now(),
            )
            for item in inputs.items
        ]
        return AgentTask(
            task_id=task_id,
            session_id=session_id,
            user_id=user_id,
            items=task_items,
            created_at=datetime.now(),
            last_updated=datetime.now(),
            status=status,
        )

    async def authenticate_user(self, token: str) -> str:
        """Resolve ``token`` to a user ID via the handler's authorizer.

        Raises:
            AuthenticationException: Wrapping any error raised by the authorizer.
        """
        try:
            return await self.authorizer.authorize_request(auth_header=token)
        except Exception as e:
            failure = f"Unable to authenticate user, exception message: {e}"
            raise AuthenticationException(message=failure) from e

    async def authenticate_mcp_servers(
        self, user_id: str, session_id: str, task_id: str, request_id: str
    ) -> AuthChallengeResponse | None:
        """
        Check stored credentials for every configured MCP server before agent construction.

        Only servers that declare both an ``auth_server`` and ``scopes`` are checked.
        For each such server with no stored auth data for ``user_id``, a challenge
        entry with an authorization URL is produced.

        Returns:
            AuthChallengeResponse listing every server that still needs authentication,
            or None when nothing is missing. None is also returned, after logging a
            warning, if the check itself fails.
        """
        mcp_servers = self.config.get_agent().mcp_servers
        if not mcp_servers:
            return None

        try:
            from sk_agents.auth_storage.auth_storage_factory import AuthStorageFactory
            from sk_agents.mcp_client import build_auth_storage_key

            auth_storage = AuthStorageFactory(self.app_config).get_auth_storage_manager()

            challenges = []
            for server in mcp_servers:
                # Servers without an auth server or scopes need no user authentication.
                if not (server.auth_server and server.scopes):
                    continue

                storage_key = build_auth_storage_key(server.auth_server, server.scopes)
                if auth_storage.retrieve(user_id, storage_key):
                    continue  # already authenticated for this server

                encoded_scopes = "%20".join(server.scopes)
                challenges.append(
                    {
                        "server_name": server.name,
                        "auth_server": server.auth_server,
                        "scopes": server.scopes,
                        "auth_url": (
                            f"{server.auth_server}/authorize?"
                            f"client_id=teal_agents&scope={encoded_scopes}&response_type=code"
                        ),
                    }
                )

            if not challenges:
                return None

            return AuthChallengeResponse(
                task_id=task_id,
                session_id=session_id,
                request_id=request_id,
                message=f"Authentication required for {len(challenges)} MCP server(s).",
                auth_challenges=challenges,
                resume_url=f"/tealagents/v1alpha1/resume/{request_id}",
            )

        except Exception as e:
            logger.warning(f"Error during MCP server authentication check: {e}")
            # Continue without MCP auth if there are issues with auth storage
            return None

    @staticmethod
    def handle_state_id(inputs: UserMessage) -> tuple[str, str, str]:
        """Return the ``(session_id, task_id, request_id)`` triple for a message.

        A session or task ID supplied on ``inputs`` is reused; a missing or empty one
        is replaced by a fresh UUID4 string. The request ID is always newly generated.
        """
        session_id = inputs.session_id or str(uuid.uuid4())
        task_id = inputs.task_id or str(uuid.uuid4())
        return session_id, task_id, str(uuid.uuid4())

    async def _manage_incoming_task(
        self, task_id: str, session_id: str, user_id: str, request_id: str, inputs: UserMessage
    ) -> AgentTask | None:
        """Load the task for ``task_id`` or create it, recording this request's input.

        * No task stored: a task is built with ``_configure_agent_task`` (status
          "Running") and persisted with ``state.create``.
        * Task stored and owned by ``user_id``: the user items that
          ``_configure_agent_task`` builds for this request are appended, and the task
          is persisted with ``state.update``. This keeps follow-up messages in the task
          and presumably lets ``state.load_by_request_id`` find it by the new
          ``request_id``.
        * Task stored but owned by another user: it is returned untouched, so the
          caller's ownership check rejects the request before anything is modified.

        Returns:
            The loaded or newly created task. The optional return type is kept for
            callers that check for None.

        Raises:
            AgentInvokeException: If loading, creating, or updating the task fails.
        """
        try:
            agent_task = await self.state.load(task_id)
            if not agent_task:
                agent_task = TealAgentsV1Alpha1Handler._configure_agent_task(
                    session_id=session_id,
                    user_id=user_id,
                    task_id=task_id,
                    role="user",
                    request_id=request_id,
                    inputs=inputs,
                    status="Running",
                )
                await self.state.create(agent_task)
                return agent_task

            # Never modify another user's task; the caller's _validate_user_id
            # rejects this request.
            if agent_task.user_id != user_id:
                return agent_task

            # NOTE(review): _configure_agent_task is reused only to build this
            # request's user items, assuming it places them in ``.items`` as it does
            # for a new task. Confirm against its implementation.
            request_items = TealAgentsV1Alpha1Handler._configure_agent_task(
                session_id=session_id,
                user_id=user_id,
                task_id=task_id,
                role="user",
                request_id=request_id,
                inputs=inputs,
                status="Running",
            ).items
            agent_task.items.extend(request_items)
            agent_task.last_updated = datetime.now()
            await self.state.update(agent_task)
            return agent_task
        except (PersistenceLoadError, PersistenceCreateError) as e:
            raise AgentInvokeException(
                f"Failed to load or create task {task_id}: {e.message}"
            ) from e
        except Exception as e:
            raise AgentInvokeException(
                f"Unexpected error occurred while managing incoming task {task_id}: {str(e)}"
            ) from e

    async def _manage_agent_response_task(
        self, agent_task: AgentTask, agent_response: TealAgentsResponse
    ) -> None:
        """Append the agent's text output to ``agent_task`` as an assistant item and persist it."""
        response_content = MultiModalItem(
            content_type=ContentType.TEXT, content=agent_response.output
        )
        agent_task.items.append(
            AgentTaskItem(
                task_id=agent_response.task_id,
                role="assistant",
                item=response_content,
                request_id=agent_response.request_id,
                updated=datetime.now(),
            )
        )
        agent_task.last_updated = datetime.now()
        await self.state.update(agent_task)

    @staticmethod
    def _validate_user_id(user_id: str, task_id: str, agent_task: AgentTask) -> None:
        """Ensure ``user_id`` owns ``agent_task``.

        The check is an explicit comparison rather than ``assert``. Assertions are
        removed when Python runs with ``-O``, which would silently disable this
        authorization check.

        Raises:
            AgentInvokeException: If ``user_id`` differs from ``agent_task.user_id``.
        """
        if user_id != agent_task.user_id:
            raise AgentInvokeException(
                message=f"Invalid user ID {user_id} and task ID {task_id} provided."
            )

    @staticmethod
    def _build_chat_history(agent_task: AgentTask, chat_history: ChatHistory) -> ChatHistory:
        """Append one message per task item to ``chat_history`` and return it.

        Each message carries the item's role and only that item's content, converted
        with ``item_to_content``.
        """
        for task_item in agent_task.items:
            # Build a fresh items list per message. Reusing one list across iterations
            # made every message repeat the content of all earlier items.
            message_items: list[TextContent | ImageContent] = [item_to_content(task_item.item)]
            chat_history.add_message(
                ChatMessageContent(role=task_item.role, items=message_items)
            )
        return chat_history

    @staticmethod
    def _rejected_task_item(task_id: str, request_id: str) -> AgentTaskItem:
        """Build the user-role task item that records a rejected HITL tool execution."""
        rejection_note = MultiModalItem(
            content_type=ContentType.TEXT, content="tool execution rejected"
        )
        return AgentTaskItem(
            task_id=task_id,
            role="user",
            item=rejection_note,
            request_id=request_id,
            updated=datetime.now(),
        )

    @staticmethod
    def _approved_task_item(task_id: str, request_id: str) -> AgentTaskItem:
        """Build the user-role task item that records an approved HITL tool execution."""
        approval_note = MultiModalItem(
            content_type=ContentType.TEXT, content="tool execution approved"
        )
        return AgentTaskItem(
            task_id=task_id,
            role="user",
            item=approval_note,
            request_id=request_id,
            updated=datetime.now(),
        )

    async def _manage_hitl_exception(
        self,
        agent_task: AgentTask,
        session_id: str,
        task_id: str,
        request_id: str,
        function_calls: list,
        chat_history: ChatHistory,
    ):
        """Pause ``agent_task`` until a human approves or rejects ``function_calls``.

        The task is marked "Paused". An assistant item is appended that stores the
        pending tool calls (as ``model_dump()`` dicts) and the current chat history, so
        the task can later be resumed. The task is then persisted.

        Returns:
            A HitlResponse carrying the tool calls plus approve/reject URLs for
            ``request_id``.
        """
        agent_task.status = "Paused"
        pause_marker = MultiModalItem(
            content_type=ContentType.TEXT, content="HITL intervention required."
        )
        agent_task.items.append(
            AgentTaskItem(
                task_id=task_id,
                role="assistant",
                item=pause_marker,
                request_id=request_id,
                updated=datetime.now(),
                pending_tool_calls=[call.model_dump() for call in function_calls],
                chat_history=chat_history,
            )
        )
        agent_task.last_updated = datetime.now()
        await self.state.update(agent_task)

        resume_endpoint = f"/tealagents/v1alpha1/resume/{request_id}"
        return HitlResponse(
            session_id=session_id,
            task_id=task_id,
            request_id=request_id,
            tool_calls=[call.model_dump() for call in function_calls],
            approval_url=f"{resume_endpoint}?action=approve",
            rejection_url=f"{resume_endpoint}?action=reject",
        )

    @staticmethod
    async def _manage_function_calls(
        function_calls: list[FunctionCallContent], chat_history: ChatHistory, kernel: Kernel
    ) -> None:
        """Execute tool calls that need no human approval and defer the rest to HITL.

        Calls that ``hitl_manager.check_for_intervention`` does not flag are invoked
        concurrently via ``asyncio.gather``. Their results are appended to
        ``chat_history`` in the order the calls were given. Flagged calls are not
        executed.

        Raises:
            hitl_manager.HitlInterventionRequired: If any call requires intervention.
                It is raised after the non-intervention calls have run.
        """
        intervention_calls = []
        non_intervention_calls = []

        # Separate function calls into intervention and non-intervention
        for fc in function_calls:
            if hitl_manager.check_for_intervention(fc):
                intervention_calls.append(fc)
            else:
                non_intervention_calls.append(fc)

        # Process non-intervention function calls first
        if non_intervention_calls:
            results = await asyncio.gather(
                *[
                    TealAgentsV1Alpha1Handler._invoke_function(kernel, fc)
                    for fc in non_intervention_calls
                ]
            )
            for result in results:
                chat_history.add_message(result.to_chat_message_content())

        # Handle intervention function calls
        if intervention_calls:
            logger.info(f"Intervention required for {len(intervention_calls)} function calls.")
            raise hitl_manager.HitlInterventionRequired(intervention_calls)

    async def prepare_agent_response(
        self,
        agent_task: AgentTask,
        request_id: str,
        response: ChatMessageContent | list[str],
        token_usage: TokenUsage,
        extra_data_collector: ExtraDataCollector,
    ):
        """Build the final TealAgentsResponse, record it on the task, and return it.

        Args:
            agent_task: Task the response belongs to; session and task IDs come from it.
            request_id: ID of the request being answered.
            response: A chat message (its ``content`` is used) or a list of text
                fragments, which are concatenated.
            token_usage: Token usage reported on the response.
            extra_data_collector: Source of the response's ``extra_data``.

        Returns:
            The TealAgentsResponse, after it has been appended to ``agent_task`` as an
            assistant item and persisted.
        """
        if isinstance(response, list):
            agent_output = "".join(response)
        else:
            agent_output = response.content

        total_tokens = token_usage.total_tokens
        session_id = agent_task.session_id
        task_id = agent_task.task_id

        agent_response = TealAgentsResponse(
            session_id=session_id,
            task_id=task_id,
            request_id=request_id,
            output=agent_output,
            source=f"{self.name}:{self.version}",
            token_usage=token_usage,
            extra_data=extra_data_collector.get_extra_data(),
        )
        await self._manage_agent_response_task(agent_task, agent_response)
        logger.info(
            f"{self.name}:{self.version} "
            f"successful invocation with {total_tokens} tokens. "
            f"Session ID: {session_id}, Task ID: {task_id}, "
            f"Request ID {request_id}"
        )
        return agent_response

    async def resume_task(
        self, auth_token: str, request_id: str, action_status: ResumeRequest, stream: bool
    ) -> (
        TealAgentsResponse
        | RejectedToolResponse
        | HitlResponse
        | AsyncIterable[TealAgentsResponse | TealAgentsPartialResponse | HitlResponse]
    ):
        """Resume a task that was paused for human-in-the-loop (HITL) tool approval.

        The task is looked up by ``request_id``. It must belong to the authenticated
        user, be in the "Paused" state, and end with the item recorded at pause time,
        which carries the chat history and the pending tool calls.

        * Any action other than "approve" cancels the task, records a rejection item,
          and returns a RejectedToolResponse.
        * "approve" records an approval item, executes the pending tool calls, and
          continues the agent loop with ``recursion_invoke``. When ``stream`` is true,
          it instead returns an async iterator over ``recursion_invoke_stream``.

        In streaming mode the MCP connection manager (if any) stays open while the
        returned iterator runs, so the continued loop never uses an exited connection
        manager. The pending tool calls therefore run when iteration starts.

        Raises:
            AgentInvokeException: If the task is missing, not owned by the caller, in
                the wrong state, or lacks the chat history or pending tool calls
                needed to resume.
        """
        user_id = await self.authenticate_user(token=auth_token)
        agent_task = await self.state.load_by_request_id(request_id)
        if agent_task is None:
            raise AgentInvokeException(f"No agent task found for request ID: {request_id}")

        session_id = agent_task.session_id
        task_id = agent_task.task_id

        # Check ownership before reporting anything about the task's internal state.
        TealAgentsV1Alpha1Handler._validate_user_id(user_id, task_id, agent_task)

        # Validate task has items
        if not agent_task.items:
            raise AgentInvokeException(
                f"Cannot resume task {request_id}: task has no items. "
                f"Task may be corrupted or improperly initialized."
            )

        # The last item is the one appended when the task was paused; it holds both
        # the chat history and the pending tool calls.
        paused_item = agent_task.items[-1]
        if paused_item.chat_history is None:
            raise AgentInvokeException(
                f"Cannot resume task {request_id}: chat history not preserved in paused state. "
                f"This indicates a persistence layer issue during HITL pause."
            )
        chat_history = paused_item.chat_history

        # Validate task is in correct state for resumption
        if agent_task.status != "Paused":
            raise AgentInvokeException(
                f"Cannot resume task {task_id}: task is in '{agent_task.status}' state, "
                f"expected 'Paused'. Task may have already been processed or cancelled."
            )

        if action_status.action != "approve":
            agent_task.status = "Canceled"
            agent_task.items.append(
                TealAgentsV1Alpha1Handler._rejected_task_item(
                    task_id=task_id, request_id=request_id
                )
            )
            agent_task.last_updated = datetime.now()
            await self.state.update(agent_task)

            return RejectedToolResponse(
                task_id=task_id, session_id=agent_task.session_id, request_id=request_id
            )

        # Validate the pending tool calls BEFORE recording approval, so a malformed
        # task is not left in the "Running" state with an approval item.
        if not paused_item.pending_tool_calls:
            raise AgentInvokeException(
                f"Pending tool calls not found for request ID: {request_id}. "
                f"The paused task item has no pending tool calls."
            )
        pending_tools = [
            FunctionCallContent(**function_call)
            for function_call in paused_item.pending_tool_calls
        ]

        # Record Approval state
        agent_task.status = "Running"
        agent_task.items.append(
            TealAgentsV1Alpha1Handler._approved_task_item(task_id=task_id, request_id=request_id)
        )
        agent_task.last_updated = datetime.now()
        await self.state.update(agent_task)

        # Create request-scoped connection manager for MCP connection reuse
        connection_manager = await self._create_mcp_connection_manager(user_id, session_id)

        async def _run_pending_tools(conn_mgr=None) -> None:
            """Execute the approved tool calls and add their results to chat_history."""
            extra_data_collector = ExtraDataCollector()
            agent = await self.agent_builder.build_agent(
                self.config.get_agent(), extra_data_collector, user_id=user_id
            )

            # Load MCP plugins after agent construction (per-session isolation);
            # a connection manager is required for MCP plugin loading.
            if self.config.get_agent().mcp_servers and self.discovery_manager and conn_mgr:
                await self.agent_builder.kernel_builder.load_mcp_plugins(
                    agent.agent.kernel, user_id, session_id, self.discovery_manager, conn_mgr
                )

            kernel = agent.agent.kernel
            # Execute the tool calls concurrently, just as the agent would have.
            results = await asyncio.gather(
                *[TealAgentsV1Alpha1Handler._invoke_function(kernel, fc) for fc in pending_tools]
            )
            for result in results:
                chat_history.add_message(result.to_chat_message_content())

        if stream:

            async def _resume_stream():
                # The connection manager must stay open while the stream is consumed,
                # so it is entered inside the generator rather than around the return.
                if connection_manager:
                    async with connection_manager:
                        await _run_pending_tools(connection_manager)
                        async for chunk in self.recursion_invoke_stream(
                            chat_history,
                            session_id,
                            task_id,
                            request_id,
                            connection_manager=connection_manager,
                        ):
                            yield chunk
                else:
                    await _run_pending_tools()
                    async for chunk in self.recursion_invoke_stream(
                        chat_history, session_id, task_id, request_id, connection_manager=None
                    ):
                        yield chunk

            return _resume_stream()

        if connection_manager:
            async with connection_manager:
                await _run_pending_tools(connection_manager)
                return await self.recursion_invoke(
                    inputs=chat_history,
                    session_id=session_id,
                    request_id=request_id,
                    task_id=task_id,
                    connection_manager=connection_manager,
                )

        await _run_pending_tools()
        return await self.recursion_invoke(
            inputs=chat_history,
            session_id=session_id,
            request_id=request_id,
            task_id=task_id,
            connection_manager=None,
        )

    async def invoke(
        self, auth_token: str, inputs: UserMessage
    ) -> TealAgentsResponse | HitlResponse | AuthChallengeResponse:
        """Handle a non-streaming user request end to end.

        Steps:
        1. Authenticate the caller.
        2. Assign session/task/request IDs, writing the session and task IDs back onto
           ``inputs``.
        3. Run MCP discovery for the session.
        4. Load or create the task and check ownership.
        5. Check MCP server credentials.
        6. Rebuild the chat history from the task.
        7. Run the agent through ``recursion_invoke``.

        Returns:
            An AuthChallengeResponse when MCP discovery or MCP server authentication
            needs user action; otherwise the result of ``recursion_invoke``.

        Raises:
            AgentInvokeException: If no agent task is available for the request.
        """
        logger.info("Beginning processing invoke")

        user_id = await self.authenticate_user(token=auth_token)

        # IDs come first because auth challenges must reference them.
        session_id, task_id, request_id = TealAgentsV1Alpha1Handler.handle_state_id(inputs)
        inputs.session_id = session_id
        inputs.task_id = task_id

        # Discovery may itself require the user to authenticate.
        discovery_challenge = await self._ensure_session_discovery(
            user_id, session_id, task_id, request_id
        )
        if discovery_challenge:
            logger.info("Returning auth challenge from MCP discovery")
            return discovery_challenge

        agent_task = await self._manage_incoming_task(
            task_id, session_id, user_id, request_id, inputs
        )
        if agent_task is None:
            raise AgentInvokeException("Agent task not created")
        TealAgentsV1Alpha1Handler._validate_user_id(user_id, task_id, agent_task)

        mcp_challenge = await self.authenticate_mcp_servers(
            user_id, session_id, task_id, request_id
        )
        if mcp_challenge:
            logger.info(
                f"MCP authentication required for {len(mcp_challenge.auth_challenges)} server(s)"
            )
            return mcp_challenge

        chat_history = ChatHistory()
        TealAgentsV1Alpha1Handler._augment_with_user_context(
            inputs=inputs, chat_history=chat_history
        )
        TealAgentsV1Alpha1Handler._build_chat_history(agent_task, chat_history)
        logger.info("Building the final response")

        invoke_kwargs = {
            "inputs": chat_history,
            "session_id": session_id,
            "request_id": request_id,
            "task_id": task_id,
        }
        # A request-scoped connection manager lets MCP connections be reused.
        connection_manager = await self._create_mcp_connection_manager(user_id, session_id)
        if connection_manager:
            async with connection_manager:
                result = await self.recursion_invoke(
                    **invoke_kwargs, connection_manager=connection_manager
                )
        else:
            result = await self.recursion_invoke(**invoke_kwargs)
        logger.info("Final response complete")

        return result

    async def invoke_stream(
        self, auth_token: str, inputs: UserMessage
    ) -> AsyncIterable[
        TealAgentsResponse | TealAgentsPartialResponse | HitlResponse | AuthChallengeResponse
    ]:
        """Handle a streaming user request, yielding responses as they are produced.

        Yields, in order:
        * An AuthChallengeResponse, after which the generator stops, when MCP
          discovery or MCP server authentication needs user action.
        * Once per session, when MCP servers are configured, a partial response
          summarising which MCP servers connected and which failed.
        * The chunks produced by ``recursion_invoke_stream``.

        Raises:
            AgentInvokeException: If no agent task is available for the request.
        """
        logger.info("Beginning processing invoke")
        user_id = await self.authenticate_user(token=auth_token)

        # IDs come first because auth challenges must reference them.
        session_id, task_id, request_id = TealAgentsV1Alpha1Handler.handle_state_id(inputs)

        # Discovery may itself require the user to authenticate.
        discovery_challenge = await self._ensure_session_discovery(
            user_id, session_id, task_id, request_id
        )
        if discovery_challenge:
            logger.info("Returning auth challenge from MCP discovery")
            yield discovery_challenge
            return

        configured_servers = self.config.get_agent().mcp_servers
        status_pending = session_id not in self._mcp_status_shown_per_session

        if status_pending and configured_servers:
            # Failures recorded during discovery, if the state can be loaded.
            failures = {}
            if self.discovery_manager:
                try:
                    discovery = await self.discovery_manager.load_discovery(user_id, session_id)
                    if discovery:
                        failures = discovery.failed_servers
                except Exception:
                    logger.debug("Failed to load discovery state for status message")

            connected = [
                server.name for server in configured_servers if server.name not in failures
            ]

            status_lines = []
            if connected:
                status_lines.append(f"✅ MCP connected: {', '.join(connected)}")

            if failures:
                described = []
                for server_name, reason in failures.items():
                    # Long error text is cut to 50 characters to keep the line readable.
                    if len(reason) > 50:
                        reason = reason[:50] + "..."
                    described.append(f"{server_name} ({reason})")
                status_lines.append(f"⚠️ MCP connection failed: {', '.join(described)}")

            yield TealAgentsPartialResponse(
                task_id=task_id,
                session_id=session_id,
                request_id=request_id,
                output_partial="\n".join(status_lines) + "\n\n",
            )
            # Only show the status once per session.
            self._mcp_status_shown_per_session.add(session_id)

        agent_task = await self._manage_incoming_task(
            task_id, session_id, user_id, request_id, inputs
        )
        if agent_task is None:
            raise AgentInvokeException("Agent task not created")
        TealAgentsV1Alpha1Handler._validate_user_id(user_id, task_id, agent_task)

        mcp_challenge = await self.authenticate_mcp_servers(
            user_id, session_id, task_id, request_id
        )
        if mcp_challenge:
            logger.info(
                f"MCP authentication required for {len(mcp_challenge.auth_challenges)} server(s)"
            )
            yield mcp_challenge
            return

        chat_history = ChatHistory()
        TealAgentsV1Alpha1Handler._augment_with_user_context(
            inputs=inputs, chat_history=chat_history
        )
        logger.info("Building the final response")
        TealAgentsV1Alpha1Handler._build_chat_history(agent_task, chat_history)

        # A request-scoped connection manager lets MCP connections be reused; it stays
        # open for the whole stream.
        connection_manager = await self._create_mcp_connection_manager(user_id, session_id)
        if connection_manager:
            async with connection_manager:
                async for chunk in self.recursion_invoke_stream(
                    chat_history,
                    session_id,
                    task_id,
                    request_id,
                    connection_manager=connection_manager,
                ):
                    yield chunk
        else:
            async for chunk in self.recursion_invoke_stream(
                chat_history, session_id, task_id, request_id
            ):
                yield chunk

        logger.info("Final response complete")

    async def recursion_invoke(
        self,
        inputs: ChatHistory,
        session_id: str,
        task_id: str,
        request_id: str,
        connection_manager=None,
    ) -> TealAgentsResponse | HitlResponse:
        """Run the agent on ``inputs`` until it answers or needs human approval.

        Setup:
        * The task is loaded by ``request_id``.
        * The agent is built for the task's user.
        * MCP plugins are loaded onto its kernel when MCP servers, a discovery manager,
          and ``connection_manager`` are all available.

        The LLM is then called with manual tool calling:
        * If it returns function calls, those not requiring intervention are executed,
          their results are added to ``inputs``, and this method recurses.
        * If any call requires intervention, the task is paused and a HitlResponse is
          returned.

        Returns:
            The final TealAgentsResponse (also recorded on the task), or a HitlResponse.

        Raises:
            PersistenceLoadError: If no task exists for ``request_id``.
            AgentInvokeException: If the LLM returns no direct answer, or any other
                error occurs while invoking the agent.
        """
        chat_history = inputs
        agent_task = await self.state.load_by_request_id(request_id)
        if not agent_task:
            raise PersistenceLoadError(f"Agent task with ID {task_id} not found in state.")

        user_id = agent_task.user_id
        extra_data_collector = ExtraDataCollector()
        agent = await self.agent_builder.build_agent(
            self.config.get_agent(), extra_data_collector, user_id=user_id
        )

        # Load MCP plugins after agent construction (per-session isolation);
        # a connection manager is required for MCP plugin loading.
        if self.config.get_agent().mcp_servers and self.discovery_manager and connection_manager:
            await self.agent_builder.kernel_builder.load_mcp_plugins(
                agent.agent.kernel, user_id, session_id, self.discovery_manager, connection_manager
            )

        completion_tokens: int = 0
        prompt_tokens: int = 0
        total_tokens: int = 0

        try:
            # Manual tool calling implementation
            kernel = agent.agent.kernel
            arguments = agent.agent.arguments
            chat_completion_service, settings = kernel.select_ai_service(
                arguments=arguments, type=ChatCompletionClientBase
            )
            assert isinstance(chat_completion_service, ChatCompletionClientBase)

            responses = await chat_completion_service.get_chat_message_contents(
                chat_history=chat_history,
                settings=settings,
                kernel=kernel,
                arguments=arguments,
            )
            response_list = []
            for response_chunk in responses:
                chat_history.add_message(response_chunk)
                response_list.append(response_chunk)

            function_calls = []
            final_response = None

            for response in response_list:
                call_usage = get_token_usage_for_response(agent.get_model_type(), response)
                completion_tokens += call_usage.completion_tokens
                prompt_tokens += call_usage.prompt_tokens
                total_tokens += call_usage.total_tokens

                # A response may carry several items, e.g. multiple tool calls.
                fc_in_response = [
                    item for item in response.items if isinstance(item, FunctionCallContent)
                ]
                if fc_in_response:
                    function_calls.extend(fc_in_response)
                else:
                    # No function calls: this is a direct answer.
                    final_response = response

            token_usage = TokenUsage(
                completion_tokens=completion_tokens,
                prompt_tokens=prompt_tokens,
                total_tokens=total_tokens,
            )

            if function_calls:
                await self._manage_function_calls(function_calls, chat_history, kernel)
                # Recurse so the LLM can produce an answer from the tool results.
                return await self.recursion_invoke(
                    inputs=chat_history,
                    session_id=session_id,
                    task_id=task_id,
                    request_id=request_id,
                    connection_manager=connection_manager,
                )

            if final_response is None:
                error_msg = (
                    f"No response received from LLM for Session ID {session_id}, "
                    f"Task ID {task_id}, Request ID {request_id}. "
                    f"Function calls processed: {len(function_calls)}"
                )
                logger.error(error_msg)
                raise AgentInvokeException(error_msg)
        except hitl_manager.HitlInterventionRequired as hitl_exc:
            return await self._manage_hitl_exception(
                agent_task, session_id, task_id, request_id, hitl_exc.function_calls, chat_history
            )
        except AgentInvokeException:
            # Propagate as-is. Wrapping again here would nest the message once per
            # recursion level (and re-wrap the explicit "No response" error above).
            raise
        except Exception as e:
            error_msg = (
                f"Error invoking {self.name}:{self.version} "
                f"for Session ID {session_id}, Task ID {task_id}, "
                f"Request ID {request_id}, Error message: {str(e)}"
            )
            logger.exception(error_msg)
            raise AgentInvokeException(error_msg) from e

        # Persist and return response
        return await self.prepare_agent_response(
            agent_task, request_id, final_response, token_usage, extra_data_collector
        )

    async def recursion_invoke_stream(
        self,
        inputs: ChatHistory,
        session_id: str,
        task_id: str,
        request_id: str,
        connection_manager=None,
    ) -> AsyncIterable[TealAgentsResponse | TealAgentsPartialResponse | HitlResponse]:
        """Streaming counterpart of ``recursion_invoke``.

        The LLM is called with ``get_chat_message_contents``. Each response's content
        is either parsed as ExtraDataPartial (and collected) or yielded as a
        TealAgentsPartialResponse. Then:
        * If function calls were returned, those not requiring intervention are
          executed and this generator recurses, yielding the recursive chunks.
        * If intervention is required, a HitlResponse is yielded.
        * Otherwise a final TealAgentsResponse built from the yielded text is recorded
          on the task and yielded.

        Raises:
            PersistenceLoadError: If no task exists for ``request_id``.
            AgentInvokeException: If any error occurs while invoking the agent.
        """
        chat_history = inputs
        agent_task = await self.state.load_by_request_id(request_id)
        if not agent_task:
            raise PersistenceLoadError(f"Agent task with ID {task_id} not found in state.")

        user_id = agent_task.user_id
        extra_data_collector = ExtraDataCollector()
        agent = await self.agent_builder.build_agent(
            self.config.get_agent(), extra_data_collector, user_id=user_id
        )

        # Load MCP plugins after agent construction (per-session isolation);
        # a connection manager is required for MCP plugin loading.
        if self.config.get_agent().mcp_servers and self.discovery_manager and connection_manager:
            await self.agent_builder.kernel_builder.load_mcp_plugins(
                agent.agent.kernel, user_id, session_id, self.discovery_manager, connection_manager
            )

        final_response = []
        completion_tokens: int = 0
        prompt_tokens: int = 0
        total_tokens: int = 0

        try:
            kernel = agent.agent.kernel
            arguments = agent.agent.arguments
            chat_completion_service, settings = kernel.select_ai_service(
                arguments=arguments, type=ChatCompletionClientBase
            )
            assert isinstance(chat_completion_service, ChatCompletionClientBase)

            responses = await chat_completion_service.get_chat_message_contents(
                chat_history=chat_history,
                settings=settings,
                kernel=kernel,
                arguments=arguments,
            )
            response_list = []
            for response_chunk in responses:
                chat_history.add_message(response_chunk)
                response_list.append(response_chunk)

            for response in response_list:
                call_usage = get_token_usage_for_response(agent.get_model_type(), response)
                completion_tokens += call_usage.completion_tokens
                prompt_tokens += call_usage.prompt_tokens
                total_tokens += call_usage.total_tokens

                if response.content:
                    try:
                        # Content may be an ExtraDataPartial JSON payload.
                        extra_data_partial: ExtraDataPartial = ExtraDataPartial.new_from_json(
                            response.content
                        )
                        extra_data_collector.add_extra_data_items(extra_data_partial.extra_data)
                    except Exception:
                        if len(response.content) > 0:
                            # Not extra data: emit it as a partial response.
                            final_response.append(response.content)
                            yield TealAgentsPartialResponse(
                                session_id=session_id,
                                task_id=task_id,
                                request_id=request_id,
                                output_partial=response.content,
                                source=f"{self.name}:{self.version}",
                            )

            token_usage = TokenUsage(
                completion_tokens=completion_tokens,
                prompt_tokens=prompt_tokens,
                total_tokens=total_tokens,
            )
            if not response_list:
                return

            # The responses are complete (non-streaming) messages, so function calls
            # are collected from each one directly, as recursion_invoke does, rather
            # than summing the messages with ``+``.
            function_calls = [
                item
                for response in response_list
                for item in response.items
                if isinstance(item, FunctionCallContent)
            ]

            if function_calls:
                await self._manage_function_calls(function_calls, chat_history, kernel)
                # Recurse so the LLM can stream an answer from the tool results.
                async for final_response_chunk in self.recursion_invoke_stream(
                    chat_history,
                    session_id,
                    task_id,
                    request_id,
                    connection_manager=connection_manager,
                ):
                    yield final_response_chunk
                return
        except hitl_manager.HitlInterventionRequired as hitl_exc:
            yield await self._manage_hitl_exception(
                agent_task, session_id, task_id, request_id, hitl_exc.function_calls, chat_history
            )
            return
        except AgentInvokeException:
            # Propagate as-is instead of nesting the message once per recursion level.
            raise
        except Exception as e:
            error_msg = (
                f"Error invoking stream for {self.name}:{self.version} "
                f"for Session ID {session_id}, Task ID {task_id}, "
                f"Request ID {request_id}, Error message: {str(e)}"
            )
            logger.exception(error_msg)
            raise AgentInvokeException(error_msg) from e

        # Persist and return response
        yield await self.prepare_agent_response(
            agent_task, request_id, final_response, token_usage, extra_data_collector
        )
sk_agents.tealagents.v1alpha1.agent.handler.TealAgentsV1Alpha1Handler.authenticate_mcp_servers async
authenticate_mcp_servers(
    user_id: str,
    session_id: str,
    task_id: str,
    request_id: str,
) -> AuthChallengeResponse | None

Authenticate MCP servers before agent construction.

Returns AuthChallengeResponse if authentication is needed, None if all servers are authenticated.

Source code in src/sk_agents/tealagents/v1alpha1/agent/handler.py
async def authenticate_mcp_servers(
    self, user_id: str, session_id: str, task_id: str, request_id: str
) -> AuthChallengeResponse | None:
    """
    Authenticate MCP servers before agent construction.

    For every configured MCP server that declares both ``auth_server`` and
    ``scopes``, look up stored credentials for ``user_id``. Servers with no
    stored credentials are collected into a single auth challenge.

    Args:
        user_id: User whose stored credentials are checked.
        session_id: Session ID copied into the challenge response.
        task_id: Task ID copied into the challenge response.
        request_id: Request ID copied into the challenge response and used to
            build its ``resume_url``.

    Returns:
        AuthChallengeResponse if authentication is needed for at least one server,
        None if all servers are authenticated or no MCP servers are configured.
        Any exception raised during the check is logged as a warning and also
        yields None (best-effort: agent construction continues without MCP auth).
    """
    mcp_servers = self.config.get_agent().mcp_servers
    if not mcp_servers:
        return None

    try:
        from urllib.parse import quote, urlencode

        from sk_agents.auth_storage.auth_storage_factory import AuthStorageFactory
        from sk_agents.mcp_client import build_auth_storage_key

        auth_storage_factory = AuthStorageFactory(self.app_config)
        auth_storage = auth_storage_factory.get_auth_storage_manager()

        missing_auth_servers = []

        for server_config in mcp_servers:
            # Only servers that declare OAuth (auth_server + scopes) need a credential check.
            if not (server_config.auth_server and server_config.scopes):
                continue

            composite_key = build_auth_storage_key(
                server_config.auth_server, server_config.scopes
            )
            if auth_storage.retrieve(user_id, composite_key):
                continue  # credentials already stored for this server

            # Encode the query properly so scopes containing reserved characters
            # ("&", "#", "+", "=", ...) cannot corrupt the authorize URL. quote()
            # renders spaces as %20 and ":" / "/" stay literal, so ordinary scopes
            # produce the same query as a plain "%20".join of the scope list.
            query = urlencode(
                {
                    "client_id": "teal_agents",
                    "scope": " ".join(server_config.scopes),
                    "response_type": "code",
                },
                safe=":/",
                quote_via=quote,
            )
            # Strip a trailing slash so "https://host/" does not yield "https://host//authorize".
            authorize_endpoint = f"{server_config.auth_server.rstrip('/')}/authorize"
            missing_auth_servers.append(
                {
                    "server_name": server_config.name,
                    "auth_server": server_config.auth_server,
                    "scopes": server_config.scopes,
                    "auth_url": f"{authorize_endpoint}?{query}",
                }
            )

        if missing_auth_servers:
            num_servers = len(missing_auth_servers)
            return AuthChallengeResponse(
                task_id=task_id,
                session_id=session_id,
                request_id=request_id,
                message=f"Authentication required for {num_servers} MCP server(s).",
                auth_challenges=missing_auth_servers,
                resume_url=f"/tealagents/v1alpha1/resume/{request_id}",
            )

        return None

    except Exception as e:
        logger.warning(f"Error during MCP server authentication check: {e}")
        # Continue without MCP auth if there are issues with auth storage
        return None
sk_agents.tealagents.v1alpha1.config
sk_agents.tealagents.v1alpha1.config.McpServerConfig

Bases: BaseModel

Configuration for an MCP server connection supporting multiple transports.

Source code in src/sk_agents/tealagents/v1alpha1/config.py
class McpServerConfig(BaseModel):
    """Connection settings for a single MCP server.

    Two transports are supported: ``stdio`` (configured through ``command``,
    ``args`` and ``env``) and ``http`` (configured through ``url`` and related
    fields). Unknown keys are retained on the model (``extra="allow"``).
    Cross-field rules, including OAuth and HTTPS checks, are enforced by
    ``validate_transport_fields`` after construction.
    """

    model_config = ConfigDict(extra="allow")

    name: str
    # Transport kind: "stdio" for local servers, "http" for remote servers
    transport: Literal["stdio", "http"] = "stdio"

    # --- stdio transport ---
    command: str | None = None
    args: list[str] = []
    env: dict[str, str] | None = None

    # --- http transport ---
    url: str | None = None
    headers: dict[str, str] | None = None  # Only non-sensitive headers belong here
    timeout: float | None = None  # Validator fills 30.0 for http when unset
    sse_read_timeout: float | None = None  # Validator fills 300.0 for http when unset
    verify_ssl: bool = True  # SSL verification on by default; opt-out intended for dev

    # --- server-level auth used for tool catalog integration ---
    auth_server: str | None = None  # OAuth2 authorization server URL
    scopes: list[str] = []  # OAuth2 scopes required by this server's tools

    # Optional per-tool governance overrides
    tool_governance_overrides: dict[str, GovernanceOverride] | None = None

    # Trust level of the server, used for additional governance controls
    trust_level: Literal["trusted", "sandboxed", "untrusted"] = "untrusted"

    # Timeout (seconds) for individual MCP operations
    request_timeout: float | None = 30.0

    # Opt-in injection of a user-id header for this server
    user_id_header: str | None = None  # header name, e.g. "Arcade-User-Id"
    user_id_source: Literal["auth", "env"] | None = None  # where the value is read from
    user_id_env_var: str | None = None  # env var name used when source == "env"

    # --- OAuth 2.1 (MCP compliance) ---
    oauth_client_id: str | None = None  # pre-registered OAuth client ID
    oauth_client_secret: str | None = None  # client secret, confidential clients only
    canonical_uri: str | None = None  # explicit override for the canonical URI
    enable_dynamic_registration: bool = True  # attempt RFC 7591 dynamic registration

    # MCP protocol version, used to decide which OAuth parameters are included
    protocol_version: str | None = None  # e.g. "2025-06-18"

    # --- authorization server metadata discovery (RFC 8414 / RFC 9728) ---
    enable_metadata_discovery: bool = True
    metadata_cache_ttl: int = 3600  # seconds (1 hour)

    @property
    def effective_canonical_uri(self) -> str:
        """
        Canonical MCP server URI used for resource parameter binding.

        Per MCP spec, canonical URI must be:
        - Absolute HTTPS URI
        - Lowercase scheme and host
        - Optional port and path

        An explicit ``canonical_uri`` wins; otherwise ``url`` is used for the
        http transport. Whichever is chosen goes through ``normalize_canonical_uri``.

        Returns:
            str: The normalized canonical URI.

        Raises:
            ValueError: For stdio transport without an explicit ``canonical_uri``
                (stdio does not use OAuth), or when no URI source is available.
        """
        from sk_agents.mcp_client import normalize_canonical_uri

        if self.canonical_uri:
            source_uri = self.canonical_uri
        elif self.transport == "http" and self.url:
            source_uri = self.url
        elif self.transport == "stdio":
            raise ValueError(
                f"Canonical URI not applicable for stdio transport (server: {self.name})"
            )
        else:
            raise ValueError(
                f"Cannot determine canonical URI for server '{self.name}'. "
                f"Provide 'canonical_uri' or ensure 'url' is set for HTTP transport."
            )
        return normalize_canonical_uri(source_uri)

    @property
    def oauth_redirect_uri(self) -> str:
        """Platform-wide OAuth redirect URI, read from application config on each access."""
        from ska_utils import AppConfig

        from sk_agents.configs import TA_OAUTH_REDIRECT_URI

        return AppConfig().get(TA_OAUTH_REDIRECT_URI.env_name)

    @model_validator(mode="after")
    def validate_transport_fields(self):
        """Enforce transport-specific requirements and OAuth rules after construction.

        - stdio: ``command`` is required and must not contain shell metacharacters.
        - http: ``url`` is required and must start with ``http://`` or ``https://``.
          Unset ``timeout`` / ``sse_read_timeout`` become 30.0 / 300.0. A
          ``UserWarning`` is emitted when neither OAuth nor an ``Authorization``
          header is configured. ``auth_server`` and ``scopes`` must be given together.
        - Whenever both OAuth fields are set, ``auth_server`` must be an HTTP(S) URL.
          If TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION is "true", ``auth_server`` and the
          platform redirect URI must also pass ``validate_https_url`` (localhost allowed).

        Returns:
            The validated model instance (timeouts may have been filled in).

        Raises:
            ValueError: When any of the rules above is violated.
        """
        transport = self.transport
        if transport == "stdio":
            command = self.command
            if not command:
                raise ValueError("'command' is required for stdio transport")
            # Basic guard against shell metacharacters in the command string
            unsafe_chars = (";", "&", "|", "`", "$")
            if any(ch in command for ch in unsafe_chars):
                raise ValueError("Command contains potentially unsafe characters")
        elif transport == "http":
            url = self.url
            if not url:
                raise ValueError("'url' is required for http transport")
            if not url.startswith(("http://", "https://")):
                raise ValueError("HTTP transport URL must start with 'http://' or 'https://'")

            # Fill in transport timeouts the caller left unset
            if self.timeout is None:
                self.timeout = 30.0
            if self.sse_read_timeout is None:
                self.sse_read_timeout = 300.0

            oauth_configured = bool(self.auth_server and self.scopes)
            header_configured = any(
                header_name.lower() == "authorization" for header_name in (self.headers or {})
            )

            if not (oauth_configured or header_configured):
                import warnings

                warnings.warn(
                    f"MCP server '{self.name}' is configured without authentication. "
                    f"This should only be used for:\n"
                    f"  - Public/read-only MCP servers\n"
                    f"  - Development/testing environments\n"
                    f"  - Internal networks with network-level security\n"
                    f"For production use with sensitive data, configure OAuth "
                    f"(auth_server + scopes) or provide Authorization header.",
                    UserWarning,
                    stacklevel=2,
                )

            # auth_server and scopes are only meaningful as a pair
            if (self.auth_server or self.scopes) and not oauth_configured:
                raise ValueError(
                    "Both auth_server and scopes are required when using OAuth authentication. "
                    "Provide both or neither for simple header-based authentication."
                )

        # Everything below applies only when OAuth is configured
        if not (self.auth_server and self.scopes):
            return self

        if not self.auth_server.startswith(("http://", "https://")):
            raise ValueError("auth_server must be a valid HTTP/HTTPS URL")

        # HTTPS enforcement (per OAuth 2.1 and MCP spec); imports are deferred
        from ska_utils import AppConfig

        from sk_agents.configs import TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION
        from sk_agents.mcp_client import validate_https_url

        strict_flag = AppConfig().get(TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION.env_name)
        if strict_flag.lower() != "true":
            return self

        if not validate_https_url(self.auth_server, allow_localhost=True):
            raise ValueError(
                f"auth_server must use HTTPS (or http://localhost for dev): "
                f"{self.auth_server}. "
                f"Disable with TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION=false"
            )

        redirect_uri = self.oauth_redirect_uri
        if redirect_uri and not validate_https_url(redirect_uri, allow_localhost=True):
            raise ValueError(
                f"OAuth redirect_uri must use HTTPS (or http://localhost for dev): "
                f"{redirect_uri}. "
                f"Disable with TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION=false"
            )

        return self
sk_agents.tealagents.v1alpha1.config.McpServerConfig.effective_canonical_uri property
effective_canonical_uri: str

Get canonical MCP server URI for resource parameter binding.

Per the MCP spec, a canonical URI must be:

- an absolute HTTPS URI
- lowercase in its scheme and host
- optionally include a port and path

Returns:

- `str`: Canonical URI (either the explicit `canonical_uri` or one computed from `url`).

Raises:

- `ValueError`: If the canonical URI cannot be determined.

sk_agents.tealagents.v1alpha1.config.McpServerConfig.oauth_redirect_uri property
oauth_redirect_uri: str

Get platform OAuth redirect URI from config.

sk_agents.tealagents.v1alpha1.config.McpServerConfig.validate_transport_fields
validate_transport_fields()

Validate that required fields are provided for the selected transport.

Source code in src/sk_agents/tealagents/v1alpha1/config.py
@model_validator(mode="after")
def validate_transport_fields(self):
    """Enforce transport-specific requirements and OAuth rules after construction.

    - stdio: ``command`` is required and must not contain shell metacharacters.
    - http: ``url`` is required and must start with ``http://`` or ``https://``.
      Unset ``timeout`` / ``sse_read_timeout`` become 30.0 / 300.0. A
      ``UserWarning`` is emitted when neither OAuth nor an ``Authorization``
      header is configured. ``auth_server`` and ``scopes`` must be given together.
    - Whenever both OAuth fields are set, ``auth_server`` must be an HTTP(S) URL.
      If TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION is "true", ``auth_server`` and the
      platform redirect URI must also pass ``validate_https_url`` (localhost allowed).

    Returns:
        The validated model instance (timeouts may have been filled in).

    Raises:
        ValueError: When any of the rules above is violated.
    """
    transport = self.transport
    if transport == "stdio":
        command = self.command
        if not command:
            raise ValueError("'command' is required for stdio transport")
        # Basic guard against shell metacharacters in the command string
        unsafe_chars = (";", "&", "|", "`", "$")
        if any(ch in command for ch in unsafe_chars):
            raise ValueError("Command contains potentially unsafe characters")
    elif transport == "http":
        url = self.url
        if not url:
            raise ValueError("'url' is required for http transport")
        if not url.startswith(("http://", "https://")):
            raise ValueError("HTTP transport URL must start with 'http://' or 'https://'")

        # Fill in transport timeouts the caller left unset
        if self.timeout is None:
            self.timeout = 30.0
        if self.sse_read_timeout is None:
            self.sse_read_timeout = 300.0

        oauth_configured = bool(self.auth_server and self.scopes)
        header_configured = any(
            header_name.lower() == "authorization" for header_name in (self.headers or {})
        )

        if not (oauth_configured or header_configured):
            import warnings

            warnings.warn(
                f"MCP server '{self.name}' is configured without authentication. "
                f"This should only be used for:\n"
                f"  - Public/read-only MCP servers\n"
                f"  - Development/testing environments\n"
                f"  - Internal networks with network-level security\n"
                f"For production use with sensitive data, configure OAuth "
                f"(auth_server + scopes) or provide Authorization header.",
                UserWarning,
                stacklevel=2,
            )

        # auth_server and scopes are only meaningful as a pair
        if (self.auth_server or self.scopes) and not oauth_configured:
            raise ValueError(
                "Both auth_server and scopes are required when using OAuth authentication. "
                "Provide both or neither for simple header-based authentication."
            )

    # Everything below applies only when OAuth is configured
    if not (self.auth_server and self.scopes):
        return self

    if not self.auth_server.startswith(("http://", "https://")):
        raise ValueError("auth_server must be a valid HTTP/HTTPS URL")

    # HTTPS enforcement (per OAuth 2.1 and MCP spec); imports are deferred
    from ska_utils import AppConfig

    from sk_agents.configs import TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION
    from sk_agents.mcp_client import validate_https_url

    strict_flag = AppConfig().get(TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION.env_name)
    if strict_flag.lower() != "true":
        return self

    if not validate_https_url(self.auth_server, allow_localhost=True):
        raise ValueError(
            f"auth_server must use HTTPS (or http://localhost for dev): "
            f"{self.auth_server}. "
            f"Disable with TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION=false"
        )

    redirect_uri = self.oauth_redirect_uri
    if redirect_uri and not validate_https_url(redirect_uri, allow_localhost=True):
        raise ValueError(
            f"OAuth redirect_uri must use HTTPS (or http://localhost for dev): "
            f"{redirect_uri}. "
            f"Disable with TA_MCP_OAUTH_STRICT_HTTPS_VALIDATION=false"
        )

    return self
sk_agents.utility_routes
sk_agents.utility_routes.HealthStatus

Bases: BaseModel

Health check response model.

Source code in src/sk_agents/utility_routes.py
class HealthStatus(BaseModel):
    """Health check response model.

    Used as the ``response_model`` of the ``GET /health`` route built by
    ``UtilityRoutes.get_health_routes``.
    """

    status: str  # "healthy" when produced by the /health route
    timestamp: str  # ISO-8601 text produced via datetime.isoformat()
    version: str | None = None  # str(config.version), or None when no version is set
    uptime: float | None = None  # seconds elapsed since UtilityRoutes.start_time
    # NOTE(review): not populated by the /health route; presumably reserved for
    # per-dependency status details. Confirm before relying on it.
    dependencies: dict[str, Any] | None = None
sk_agents.utility_routes.ReadinessStatus

Bases: BaseModel

Readiness check response model.

Source code in src/sk_agents/utility_routes.py
class ReadinessStatus(BaseModel):
    """Readiness check response model.

    NOTE(review): neither route built by ``UtilityRoutes.get_health_routes``
    returns this model; confirm where (or whether) it is used.
    """

    ready: bool  # presumably True when the service can accept traffic; confirm
    timestamp: str  # presumably an ISO-8601 string like the other status models
    checks: dict[str, Any]  # individual check results; the schema is not defined here
sk_agents.utility_routes.LivenessStatus

Bases: BaseModel

Liveness check response model.

Source code in src/sk_agents/utility_routes.py
class LivenessStatus(BaseModel):
    """Liveness check response model.

    Used as the ``response_model`` of the ``GET /health/live`` route built by
    ``UtilityRoutes.get_health_routes``.
    """

    alive: bool  # always True when returned by the /health/live route
    timestamp: str  # ISO-8601 text produced via datetime.isoformat()
sk_agents.utility_routes.UtilityRoutes

Utility routes for health checks and system monitoring.

Source code in src/sk_agents/utility_routes.py
class UtilityRoutes:
    """Factory for the service's health-check API routes.

    ``start_time`` is the reference point from which ``/health`` measures uptime.
    """

    def __init__(self, start_time: datetime | None = None):
        # Use the supplied start time, or the construction time when none is given
        self.start_time = start_time if start_time else datetime.now()

    def get_health_routes(
        self,
        config: BaseConfig,
        app_config: AppConfig,
    ) -> APIRouter:
        """
        Build a router exposing the health check endpoints.

        Args:
            config: Base configuration; its ``version`` is echoed by ``/health``.
            app_config: Application configuration (not read by these routes).

        Returns:
            APIRouter: Router serving ``GET /health`` and ``GET /health/live``.
        """
        health_router = APIRouter()

        @health_router.get(
            "/health",
            response_model=HealthStatus,
            summary="Health check endpoint",
            description="Returns the health status of the application",
            tags=["Health"],
        )
        async def health_check(request: Request) -> HealthStatus:
            """
            Report application status, configured version and uptime in seconds.
            """
            try:
                now = datetime.now()
                elapsed_seconds = (now - self.start_time).total_seconds()
                return HealthStatus(
                    status="healthy",
                    timestamp=now.isoformat(),
                    version=None if not config.version else str(config.version),
                    uptime=elapsed_seconds,
                )
            except Exception as exc:
                # Any failure while building the payload is surfaced as a 503
                logger.exception(f"Health check failed: {exc}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Service unhealthy",
                ) from exc

        @health_router.get(
            "/health/live",
            response_model=LivenessStatus,
            summary="Liveness probe",
            description="Kubernetes liveness probe endpoint",
            tags=["Health"],
        )
        async def liveness_check(request: Request) -> LivenessStatus:
            """
            Kubernetes liveness probe: returns alive=True with the current timestamp.
            """
            try:
                return LivenessStatus(alive=True, timestamp=datetime.now().isoformat())
            except Exception as exc:
                logger.exception(f"Liveness check failed: {exc}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Service not alive",
                ) from exc

        return health_router
sk_agents.utility_routes.UtilityRoutes.get_health_routes
get_health_routes(
    config: BaseConfig, app_config: AppConfig
) -> APIRouter

Get health check routes for the application.

Parameters:

- `config` (`BaseConfig`, required): Base configuration.
- `app_config` (`AppConfig`, required): Application configuration.

Returns:

- `APIRouter`: Router with health check endpoints.

Source code in src/sk_agents/utility_routes.py
def get_health_routes(
    self,
    config: BaseConfig,
    app_config: AppConfig,
) -> APIRouter:
    """
    Build a router exposing the health check endpoints.

    Args:
        config: Base configuration; its ``version`` is echoed by ``/health``.
        app_config: Application configuration (not read by these routes).

    Returns:
        APIRouter: Router serving ``GET /health`` and ``GET /health/live``.
    """
    health_router = APIRouter()

    @health_router.get(
        "/health",
        response_model=HealthStatus,
        summary="Health check endpoint",
        description="Returns the health status of the application",
        tags=["Health"],
    )
    async def health_check(request: Request) -> HealthStatus:
        """
        Report application status, configured version and uptime in seconds.
        """
        try:
            now = datetime.now()
            elapsed_seconds = (now - self.start_time).total_seconds()
            return HealthStatus(
                status="healthy",
                timestamp=now.isoformat(),
                version=None if not config.version else str(config.version),
                uptime=elapsed_seconds,
            )
        except Exception as exc:
            # Any failure while building the payload is surfaced as a 503
            logger.exception(f"Health check failed: {exc}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service unhealthy",
            ) from exc

    @health_router.get(
        "/health/live",
        response_model=LivenessStatus,
        summary="Liveness probe",
        description="Kubernetes liveness probe endpoint",
        tags=["Health"],
    )
    async def liveness_check(request: Request) -> LivenessStatus:
        """
        Kubernetes liveness probe: returns alive=True with the current timestamp.
        """
        try:
            return LivenessStatus(alive=True, timestamp=datetime.now().isoformat())
        except Exception as exc:
            logger.exception(f"Liveness check failed: {exc}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service not alive",
            ) from exc

    return health_router