Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions mcpgateway/cache/session_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

# Standard
import asyncio
from asyncio import Task
from datetime import datetime, timezone
import json
import logging
Expand Down Expand Up @@ -184,7 +185,7 @@ def __init__(
# Set up backend-specific components
if self._backend == "memory":
# Nothing special needed for memory backend
self._session_message = None
self._session_message: dict[str, Any] | None = None

elif self._backend == "none":
# No session tracking - this is just a dummy registry
Expand Down Expand Up @@ -296,7 +297,7 @@ def __init__(
self._sessions: Dict[str, Any] = {} # Local transport cache
self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id
self._lock = asyncio.Lock()
self._cleanup_task = None
self._cleanup_task: Task | None = None

async def initialize(self) -> None:
"""Initialize the registry with async setup.
Expand Down Expand Up @@ -702,7 +703,7 @@ async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
else:
msg_json = json.dumps(str(message))

self._session_message: Dict[str, Any] = {"session_id": session_id, "message": msg_json}
self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": msg_json}
Copy link

Copilot AI Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type annotation on assignment is incorrect. The variable _session_message is being assigned a dict value {"session_id": session_id, "message": msg_json}, but the annotation indicates it could be None. This should be just self._session_message = {"session_id": session_id, "message": msg_json} without the type annotation, as the type was already declared at initialization (line 188).

Suggested change
self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": msg_json}
self._session_message = {"session_id": session_id, "message": msg_json}

Copilot uses AI. Check for mistakes.

elif self._backend == "redis":
try:
Expand Down Expand Up @@ -840,7 +841,7 @@ async def respond(
elif self._backend == "memory":
# if self._session_message:
transport = self.get_session_sync(session_id)
if transport:
if transport and self._session_message:
message = json.loads(str(self._session_message.get("message")))
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)

Expand Down Expand Up @@ -868,7 +869,7 @@ async def respond(

elif self._backend == "database":

def _db_read_session(session_id: str) -> SessionRecord:
def _db_read_session(session_id: str) -> SessionRecord | None:
"""Check if session still exists in the database.

Queries the SessionRecord table to verify that the session
Expand Down Expand Up @@ -903,7 +904,7 @@ def _db_read_session(session_id: str) -> SessionRecord:
finally:
db_session.close()

def _db_read(session_id: str) -> SessionMessageRecord:
def _db_read(session_id: str) -> SessionMessageRecord | None:
"""Read pending message for a session from the database.

Retrieves the first (oldest) unprocessed message for the given
Expand Down Expand Up @@ -1348,23 +1349,23 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo
result = {}

if "method" in message and "id" in message:
method = message["method"]
params = message.get("params", {})
params["server_id"] = server_id
req_id = message["id"]

rpc_input = {
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": req_id,
}
# Get the token from the current authentication context
# The user object doesn't contain the token directly, we need to reconstruct it
# Since we don't have access to the original headers here, we need a different approach
# We'll extract the token from the session or create a new admin token
token = None
try:
method = message["method"]
params = message.get("params", {})
params["server_id"] = server_id
req_id = message["id"]

rpc_input = {
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": req_id,
}
# Get the token from the current authentication context
# The user object doesn't contain the token directly, we need to reconstruct it
# Since we don't have access to the original headers here, we need a different approach
# We'll extract the token from the session or create a new admin token
token = None
if hasattr(user, "get") and "auth_token" in user:
token = user["auth_token"]
else:
Expand Down
21 changes: 15 additions & 6 deletions mcpgateway/services/logging_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def emit(self, record: logging.LogRecord) -> None:
# No running loop, can't store
return

# Schedule the coroutine
asyncio.run_coroutine_threadsafe(
# Schedule the coroutine and store the future (fire-and-forget)
future = asyncio.run_coroutine_threadsafe(
self.storage.add_log(
level=log_level,
message=message,
Expand All @@ -167,6 +167,8 @@ def emit(self, record: logging.LogRecord) -> None:
),
self.loop,
)
# Add a done callback to catch any exceptions without blocking
future.add_done_callback(lambda f: f.exception() if not f.cancelled() else None)
Copy link

Copilot AI Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The done callback silently ignores exceptions. While this may be intentional for the fire-and-forget pattern, the lambda should at least log the exception when it occurs for debugging purposes, similar to how exceptions are handled elsewhere in this module.

Copilot uses AI. Check for mistakes.
except Exception:
# Silently fail to avoid logging recursion
pass # nosec B110 - Intentional to prevent logging recursion
Expand Down Expand Up @@ -204,6 +206,7 @@ def __init__(self) -> None:
self._subscribers: List[asyncio.Queue[_LogMessage]] = []
self._loggers: Dict[str, logging.Logger] = {}
self._storage: LogStorageService | None = None # Will be initialized if admin UI is enabled
self._storage_handler: Optional[StorageHandler] = None # Track the storage handler for cleanup

async def initialize(self) -> None:
"""Initialize logging service.
Expand Down Expand Up @@ -249,10 +252,10 @@ async def initialize(self) -> None:
self._storage = LogStorageService()

# Add storage handler to capture all logs
storage_handler = StorageHandler(self._storage)
storage_handler.setFormatter(text_formatter)
storage_handler.setLevel(getattr(logging, settings.log_level.upper()))
root_logger.addHandler(storage_handler)
self._storage_handler = StorageHandler(self._storage)
self._storage_handler.setFormatter(text_formatter)
self._storage_handler.setLevel(getattr(logging, settings.log_level.upper()))
root_logger.addHandler(self._storage_handler)

logging.info(f"Log storage initialized with {settings.log_buffer_size_mb}MB buffer")

Expand All @@ -271,6 +274,12 @@ async def shutdown(self) -> None:
>>> asyncio.run(service.shutdown())

"""
# Remove storage handler from root logger if it was added
if self._storage_handler:
root_logger = logging.getLogger()
root_logger.removeHandler(self._storage_handler)
self._storage_handler = None

# Clear subscribers
self._subscribers.clear()
logging.info("Logging service shutdown")
Expand Down
2 changes: 1 addition & 1 deletion mcpgateway/transports/streamablehttp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ async def get_prompt(prompt_id: str, arguments: dict[str, str] | None = None) ->
if not result or not result.messages:
logger.warning(f"No content returned by prompt: {prompt_id}")
return []
message_dicts = [message.dict() for message in result.messages]
message_dicts = [message.model_dump() for message in result.messages]
return types.GetPromptResult(messages=message_dicts, description=result.description)
except Exception as e:
logger.exception(f"Error getting prompt '{prompt_id}': {e}")
Expand Down
13 changes: 9 additions & 4 deletions plugins/content_moderation/content_moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,12 @@ async def _moderate_with_patterns(self, text: str) -> ModerationResult:
break

return ModerationResult(
flagged=flagged, categories=categories, action=action, provider=ModerationProvider.IBM_WATSON, confidence=max_score, details={"method": "pattern_matching"} # Default fallback
flagged=flagged,
categories=categories,
action=action,
provider=ModerationProvider.IBM_WATSON,
confidence=max_score,
details={"method": "pattern_matching"}, # Default fallback
)

async def _extract_text_content(self, payload: Any) -> List[str]:
Expand Down Expand Up @@ -555,7 +560,7 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, _context: Plugin

if self._cfg.audit_decisions:
logger.info(
f"Content moderation - Prompt: {payload.prompt_id}, Result: {result.flagged}, " f"Action: {result.action}, Provider: {result.provider}, " f"Confidence: {result.confidence:.2f}"
f"Content moderation - Prompt: {payload.prompt_id}, Result: {result.flagged}, Action: {result.action}, Provider: {result.provider}, Confidence: {result.confidence:.2f}"
)

if result.action == ModerationAction.BLOCK:
Expand All @@ -572,7 +577,7 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, _context: Plugin
"flagged_text_preview": text[:100] + "..." if len(text) > 100 else text,
},
),
metadata={"moderation_result": result.dict(), "provider": result.provider.value},
metadata={"moderation_result": result.model_dump(), "provider": result.provider.value},
)
elif result.modified_content:
# Modify the payload with redacted/transformed content
Expand All @@ -598,7 +603,7 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, _context: PluginC
result = await self._moderate_content(text)

if self._cfg.audit_decisions:
logger.info(f"Content moderation - Tool: {payload.name}, Result: {result.flagged}, " f"Action: {result.action}, Provider: {result.provider}")
logger.info(f"Content moderation - Tool: {payload.name}, Result: {result.flagged}, Action: {result.action}, Provider: {result.provider}")

if result.action == ModerationAction.BLOCK:
return ToolPreInvokeResult(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ dev = [
"pytest-env>=1.1.5",
"pytest-examples>=0.0.18",
"pytest-httpx>=0.35.0",
"pytest-integration-mark>=0.2.0",
"pytest-md-report>=0.7.0",
"pytest-rerunfailures>=16.0.1",
"pytest-timeout>=2.4.0",
Expand Down
Loading
Loading