|
50 | 50 |
|
51 | 51 | # Standard |
52 | 52 | import asyncio |
| 53 | +from asyncio import Task |
53 | 54 | from datetime import datetime, timezone |
54 | 55 | import json |
55 | 56 | import logging |
@@ -184,7 +185,7 @@ def __init__( |
184 | 185 | # Set up backend-specific components |
185 | 186 | if self._backend == "memory": |
186 | 187 | # Nothing special needed for memory backend |
187 | | - self._session_message = None |
| 188 | + self._session_message: dict[str, Any] | None = None |
188 | 189 |
|
189 | 190 | elif self._backend == "none": |
190 | 191 | # No session tracking - this is just a dummy registry |
@@ -296,7 +297,7 @@ def __init__( |
296 | 297 | self._sessions: Dict[str, Any] = {} # Local transport cache |
297 | 298 | self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id |
298 | 299 | self._lock = asyncio.Lock() |
299 | | - self._cleanup_task = None |
| 300 | + self._cleanup_task: Task | None = None |
300 | 301 |
|
301 | 302 | async def initialize(self) -> None: |
302 | 303 | """Initialize the registry with async setup. |
@@ -702,7 +703,7 @@ async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None: |
702 | 703 | else: |
703 | 704 | msg_json = json.dumps(str(message)) |
704 | 705 |
|
705 | | - self._session_message: Dict[str, Any] = {"session_id": session_id, "message": msg_json} |
| 706 | + self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": msg_json} |
706 | 707 |
|
707 | 708 | elif self._backend == "redis": |
708 | 709 | try: |
@@ -840,7 +841,7 @@ async def respond( |
840 | 841 | elif self._backend == "memory": |
841 | 842 | # if self._session_message: |
842 | 843 | transport = self.get_session_sync(session_id) |
843 | | - if transport: |
| 844 | + if transport and self._session_message: |
844 | 845 | message = json.loads(str(self._session_message.get("message"))) |
845 | 846 | await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url) |
846 | 847 |
|
@@ -868,7 +869,7 @@ async def respond( |
868 | 869 |
|
869 | 870 | elif self._backend == "database": |
870 | 871 |
|
871 | | - def _db_read_session(session_id: str) -> SessionRecord: |
| 872 | + def _db_read_session(session_id: str) -> SessionRecord | None: |
872 | 873 | """Check if session still exists in the database. |
873 | 874 |
|
874 | 875 | Queries the SessionRecord table to verify that the session |
@@ -903,7 +904,7 @@ def _db_read_session(session_id: str) -> SessionRecord: |
903 | 904 | finally: |
904 | 905 | db_session.close() |
905 | 906 |
|
906 | | - def _db_read(session_id: str) -> SessionMessageRecord: |
| 907 | + def _db_read(session_id: str) -> SessionMessageRecord | None: |
907 | 908 | """Read pending message for a session from the database. |
908 | 909 |
|
909 | 910 | Retrieves the first (oldest) unprocessed message for the given |
@@ -1348,23 +1349,23 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo |
1348 | 1349 | result = {} |
1349 | 1350 |
|
1350 | 1351 | if "method" in message and "id" in message: |
| 1352 | + method = message["method"] |
| 1353 | + params = message.get("params", {}) |
| 1354 | + params["server_id"] = server_id |
| 1355 | + req_id = message["id"] |
| 1356 | + |
| 1357 | + rpc_input = { |
| 1358 | + "jsonrpc": "2.0", |
| 1359 | + "method": method, |
| 1360 | + "params": params, |
| 1361 | + "id": req_id, |
| 1362 | + } |
| 1363 | + # Get the token from the current authentication context |
| 1364 | + # The user object doesn't contain the token directly, we need to reconstruct it |
| 1365 | + # Since we don't have access to the original headers here, we need a different approach |
| 1366 | + # We'll extract the token from the session or create a new admin token |
| 1367 | + token = None |
1351 | 1368 | try: |
1352 | | - method = message["method"] |
1353 | | - params = message.get("params", {}) |
1354 | | - params["server_id"] = server_id |
1355 | | - req_id = message["id"] |
1356 | | - |
1357 | | - rpc_input = { |
1358 | | - "jsonrpc": "2.0", |
1359 | | - "method": method, |
1360 | | - "params": params, |
1361 | | - "id": req_id, |
1362 | | - } |
1363 | | - # Get the token from the current authentication context |
1364 | | - # The user object doesn't contain the token directly, we need to reconstruct it |
1365 | | - # Since we don't have access to the original headers here, we need a different approach |
1366 | | - # We'll extract the token from the session or create a new admin token |
1367 | | - token = None |
1368 | 1369 | if hasattr(user, "get") and "auth_token" in user: |
1369 | 1370 | token = user["auth_token"] |
1370 | 1371 | else: |
|
0 commit comments