From a611ac13d257645a5637d2a18d525118ab24b15b Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Fri, 5 Dec 2025 18:21:48 +0530 Subject: [PATCH 01/13] Agentic Class --- clarifai/runners/models/agentic_class.py | 1324 ++++++++++++++++++++++ 1 file changed, 1324 insertions(+) create mode 100644 clarifai/runners/models/agentic_class.py diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py new file mode 100644 index 00000000..f1f520f3 --- /dev/null +++ b/clarifai/runners/models/agentic_class.py @@ -0,0 +1,1324 @@ +"""Base class for creating OpenAI-compatible API server with MCP (Model Context Protocol) support.""" + +import asyncio +import json +import os +from typing import Any, Dict, Iterator, List + +from clarifai_grpc.grpc.api.status import status_code_pb2 +from pydantic_core import from_json, to_json + +from clarifai.runners.models.model_class import ModelClass +from clarifai.runners.models.openai_class import OpenAIModelClass +from clarifai.utils.logging import logger + + +class AgenticModelClass(OpenAIModelClass): + """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support. + + This class extends OpenAIModelClass to enable agentic behavior by integrating LLMs with MCP servers. + It handles tool discovery, execution, and iterative tool calling for both chat completions and + responses endpoints, supporting both streaming and non-streaming modes. + + To use this class, create a subclass and set the following class attributes: + - client: The OpenAI-compatible client instance + - model: The name of the model to use with the client + + Example: + class MyAgenticModel(AgenticModelClass): + client = OpenAI(api_key="your-key") + model = "gpt-4" + """ + + async def _connect_to_servers( + self, mcp_servers: List[str], max_retries: int = 2, retry_delay: float = 1.0 + ) -> Dict[str, Any]: + """Connect to all configured Clarifai MCP servers. + + Args: + mcp_servers: List of MCP server URLs to connect to + max_retries: Maximum number of retry attempts per server (default: 2) + retry_delay: Delay in seconds between retry attempts (default: 1.0) + + Returns: + Dictionary mapping server URLs to client info and tools + """ + try: + from fastmcp import Client + from fastmcp.client.transports import StreamableHttpTransport + except ImportError: + raise ImportError( + "fastmcp package is required to use MCP functionality. " + "Install it with: pip install fastmcp" + ) + + mcp_clients = {} + + for mcp_url in mcp_servers: + last_error = None + connected = False + + for attempt in range(max_retries): + try: + # Create transport for this server + transport = StreamableHttpTransport( + url=mcp_url, + headers={"Authorization": "Bearer " + os.environ["CLARIFAI_PAT"]}, + ) + + # Create and connect client + client = Client(transport) + await client.__aenter__() + + # Store client with server info + mcp_clients[mcp_url] = {"client": client, "tools": []} + + # List available tools + tools_result = await client.list_tools() + mcp_clients[mcp_url]["tools"] = tools_result + + logger.info(f"✓ Connected to {mcp_url} with {len(tools_result)} tools") + connected = True + break # Success, exit retry loop + + except Exception as e: + last_error = e + if attempt < max_retries - 1: + logger.warning( + f"⚠ Failed to connect to {mcp_url} (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {retry_delay}s..." 
+ ) + await asyncio.sleep(retry_delay) + else: + logger.error( + f"❌ Failed to connect to {mcp_url} after {max_retries} attempts: {e}" + ) + + if not connected: + # Log final failure if all retries exhausted + logger.error( + f"❌ Could not connect to {mcp_url} after {max_retries} attempts. " + f"Last error: {last_error}" + ) + # Continue with other servers even if one fails + + return mcp_clients + + async def _get_mcp_tools_and_clients( + self, mcp_servers: List[str] + ) -> tuple[List[dict], dict, dict]: + """Get available tools and clients from all connected MCP servers. + + Args: + mcp_servers: List of MCP server URLs + + Returns: + A tuple of (tools in OpenAI format, mcp_clients dictionary, tool_to_server mapping). + """ + mcp_clients = await self._connect_to_servers(mcp_servers) + + all_tools = [] + tool_to_server = {} # Map tool name to server URL + + for mcp_url, server_info in mcp_clients.items(): + tools = server_info["tools"] + for tool in tools: + tool_name = tool.name + all_tools.append( + { + "type": "function", + "function": { + "name": tool_name, + "description": f"{tool.description}", + "parameters": tool.inputSchema, + }, + } + ) + # Map tool name to its server URL + tool_to_server[tool_name] = mcp_url + + logger.info(f"Access to the {len(all_tools)} tools") + return all_tools, mcp_clients, tool_to_server + + async def _cleanup(self, mcp_clients: dict): + """Clean up MCP client resources. + + Args: + mcp_clients: Dictionary of MCP clients to clean up + """ + logger.info("Cleaning up MCP connections...") + for mcp_url, server_info in mcp_clients.items(): + try: + client = server_info["client"] + # Try to close the client properly + if hasattr(client, 'close') and callable(getattr(client, 'close', None)): + if asyncio.iscoroutinefunction(client.close): + await client.close() + else: + client.close() + else: + await client.__aexit__(None, None, None) + logger.info(f"✓ Disconnected from {mcp_url}") + except Exception as e: + # Log other errors but don't fail cleanup + logger.warning(f"⚠ Error disconnecting from {mcp_url}: {e} (continuing cleanup)") + + def _init_token_accumulation(self): + """Initialize token accumulation for a new request.""" + if not hasattr(self._thread_local, 'accumulated_tokens'): + self._thread_local.accumulated_tokens = {'prompt_tokens': 0, 'completion_tokens': 0} + + def _accumulate_usage(self, resp): + """Accumulate token usage from response object without calling set_output_context. + + This method extracts tokens from the response and adds them to the accumulated total. + It should be called for each API response in a multi-call request flow. 
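+        For example, an agentic request that makes an initial completion call and two
+        follow-up calls after tool execution accumulates usage from all three responses;
+        _finalize_token_usage() then reports the combined total once.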
+ + Args: + resp: Response object with usage information + """ + # Extract usage from response (same logic as base _set_usage) + has_usage = getattr(resp, "usage", None) + has_response_usage = getattr(resp, "response", None) and getattr( + resp.response, "usage", None + ) + + if has_response_usage or has_usage: + prompt_tokens = 0 + completion_tokens = 0 + if has_usage: + prompt_tokens = getattr(resp.usage, "prompt_tokens", 0) or getattr( + resp.usage, "input_tokens", 0 + ) + completion_tokens = getattr(resp.usage, "completion_tokens", 0) or getattr( + resp.usage, "output_tokens", 0 + ) + else: + prompt_tokens = getattr(resp.response.usage, "input_tokens", 0) + completion_tokens = getattr(resp.response.usage, "output_tokens", 0) + + if prompt_tokens is None: + prompt_tokens = 0 + if completion_tokens is None: + completion_tokens = 0 + + # Only accumulate if we have valid tokens + if prompt_tokens > 0 or completion_tokens > 0: + self._init_token_accumulation() + self._thread_local.accumulated_tokens['prompt_tokens'] += prompt_tokens + self._thread_local.accumulated_tokens['completion_tokens'] += completion_tokens + + def _finalize_token_usage(self): + """Finalize token accumulation and set the total in output context. + + This should be called once at the end of a request that may have multiple API calls. + """ + if hasattr(self._thread_local, 'accumulated_tokens'): + prompt_tokens = self._thread_local.accumulated_tokens['prompt_tokens'] + completion_tokens = self._thread_local.accumulated_tokens['completion_tokens'] + + if prompt_tokens > 0 or completion_tokens > 0: + self.set_output_context( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + # Clean up + del self._thread_local.accumulated_tokens + + def _set_usage(self, resp): + """Override _set_usage to accumulate tokens across multiple API calls. + + In agentic flows, multiple OpenAI API calls are made (initial + recursive calls after tool execution). + This method accumulates tokens from all calls and only sets the final total once. + + Args: + resp: Response object with usage information + """ + # Accumulate tokens instead of immediately setting them + self._accumulate_usage(resp) + + def _handle_chat_completions( + self, + request_data: Dict[str, Any], + mcp_servers: List[str] = None, + mcp_clients: dict = None, + tools: List[dict] = None, + ): + """Handle chat completion requests with optional MCP tool support.""" + if mcp_servers and tools: + request_data = request_data.copy() + request_data["tools"] = tools + request_data["tool_choice"] = request_data.get("tool_choice", "auto") + + # Use base class implementation + return super()._handle_chat_completions(request_data) + + def _convert_tools_to_response_api_format(self, tools: List[dict]) -> List[dict]: + """Convert tools from chat completion format to response API format. 
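+
+        Tools already in response API format (with a top-level "name") pass through
+        unchanged; only the nested "function" payload is flattened.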
+ + Chat completion format: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}} + Response API format: {"type": "function", "name": ..., "description": ..., "parameters": ...} + + Args: + tools: List of tools in chat completion format + + Returns: + List of tools in response API format + """ + response_api_tools = [] + for tool in tools: + if isinstance(tool, dict): + tool_type = tool.get("type", "function") + # Check if it's in chat completion format (has nested "function") + if "function" in tool: + func = tool["function"] + response_api_tools.append( + { + "type": tool_type, + "name": func.get("name"), + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ) + # Already in response API format + elif "name" in tool: + response_api_tools.append(tool) + return response_api_tools + + def _handle_responses( + self, + request_data: Dict[str, Any], + mcp_servers: List[str] = None, + mcp_clients: dict = None, + tools: List[dict] = None, + ): + """Handle response API requests with optional MCP tool support.""" + # If we have MCP tools, convert them to response API format and add them to the request + if mcp_servers and tools: + request_data = request_data.copy() # Don't modify original + # Convert tools from chat completion format to response API format + response_api_tools = self._convert_tools_to_response_api_format(tools) + request_data["tools"] = response_api_tools + request_data["tool_choice"] = request_data.get("tool_choice", "auto") + + # Use base class implementation + return super()._handle_responses(request_data) + + def _route_request( + self, + endpoint: str, + request_data: Dict[str, Any], + mcp_servers: List[str] = None, + mcp_clients: dict = None, + tools: List[dict] = None, + ): + """Route the request to appropriate handler based on endpoint, with optional MCP support.""" + # For chat completions, pass MCP parameters + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + return self._handle_chat_completions(request_data, mcp_servers, mcp_clients, tools) + + # For responses endpoint, pass MCP parameters + if endpoint == self.ENDPOINT_RESPONSES: + return self._handle_responses(request_data, mcp_servers, mcp_clients, tools) + + # For other endpoints, use base class implementation + return super()._route_request(endpoint, request_data) + + async def _execute_tool_calls( + self, + tool_calls: List[Any], + mcp_clients: dict, + messages: List[dict], + tool_to_server: dict = None, + ): + """Execute tool calls from chat completion and add results to messages. 
Handles both OpenAI tool_call objects and dict format.""" + for tool_call in tool_calls: + # Handle both OpenAI tool_call objects and dict format + if hasattr(tool_call, 'function'): + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool_id = tool_call.id + else: + tool_name = tool_call['function']['name'] + tool_args = json.loads(tool_call['function']['arguments']) + tool_id = tool_call['id'] + + result = None + error_msg = None + + # If we have tool-to-server mapping, try the correct server first + if tool_to_server and tool_name in tool_to_server: + server_url = tool_to_server[tool_name] + if server_url in mcp_clients: + try: + logger.info(f"Calling tool {tool_name} with arguments {tool_args}") + result = await mcp_clients[server_url]["client"].call_tool( + tool_name, arguments=tool_args + ) + except Exception as e: + error_msg = str(e) + logger.error(f"❌ Error calling tool {tool_name}: {e}") + + # If not found or failed, try all servers as fallback + if result is None: + for server_url, server_info in mcp_clients.items(): + # Skip if we already tried this server + if ( + tool_to_server + and tool_name in tool_to_server + and tool_to_server[tool_name] == server_url + ): + continue + try: + logger.info(f"Calling tool {tool_name} with arguments {tool_args}") + result = await server_info["client"].call_tool( + tool_name, arguments=tool_args + ) + break + except Exception as e: + error_msg = str(e) + logger.error(f"❌ Error calling tool {tool_name}: {e}") + continue + + if result: + content = ( + result.content[0].text if hasattr(result, 'content') else str(result[0].text) + ) + else: + content = f"Error: Failed to execute tool {tool_name}. {error_msg if error_msg else 'Tool not found on any server.'}" + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": content, + } + ) + + async def _execute_response_api_tool_calls( + self, + tool_calls: List[Dict[str, Any]], + mcp_clients: dict, + input_items: List[Any], + tool_to_server: dict = None, + ): + """Execute tool calls from response API and add results to input items. 
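+
+        Each result is appended as a "function_call_output" item keyed by call_id so
+        the model can consume it on the next request.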
+
+        Args:
+            tool_calls: List of tool call dicts from response API output
+            mcp_clients: Dictionary of MCP clients
+            input_items: List of input items (can be modified in place)
+            tool_to_server: Mapping of tool names to server URLs
+        """
+        for tool_call in tool_calls:
+            tool_name = tool_call.get("name")
+            tool_args_str = tool_call.get("arguments", "{}")
+            tool_id = tool_call.get("id")
+            call_id = tool_call.get("call_id")
+
+            # Parse arguments
+            try:
+                tool_args = (
+                    json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str
+                )
+            except json.JSONDecodeError:
+                tool_args = {}
+
+            result = None
+            error_msg = None
+
+            # If we have tool-to-server mapping, try the correct server first
+            if tool_to_server and tool_name in tool_to_server:
+                server_url = tool_to_server[tool_name]
+                if server_url in mcp_clients:
+                    try:
+                        logger.info(f"Calling tool {tool_name} with arguments {tool_args}")
+                        result = await mcp_clients[server_url]["client"].call_tool(
+                            tool_name, arguments=tool_args
+                        )
+                    except Exception as e:
+                        error_msg = str(e)
+                        logger.error(f"❌ Error calling tool {tool_name}: {e}")
+
+            # If not found or failed, try all servers as fallback
+            if result is None:
+                for server_url, server_info in mcp_clients.items():
+                    # Skip if we already tried this server
+                    if (
+                        tool_to_server
+                        and tool_name in tool_to_server
+                        and tool_to_server[tool_name] == server_url
+                    ):
+                        continue
+                    try:
+                        logger.info(f"Calling tool {tool_name} with arguments {tool_args}")
+                        result = await server_info["client"].call_tool(
+                            tool_name, arguments=tool_args
+                        )
+                        break
+                    except Exception as e:
+                        error_msg = str(e)
+                        logger.error(f"❌ Error calling tool {tool_name}: {e}")
+                        continue
+
+            # Get tool output
+            if result:
+                content = (
+                    result.content[0].text if hasattr(result, 'content') else str(result[0].text)
+                )
+            else:
+                content = f"Error: Failed to execute tool {tool_name}. {error_msg if error_msg else 'Tool not found on any server.'}"
+
+            # Use call_id if available, otherwise use id (call_id is required for function_call_output)
+            output_call_id = call_id if call_id else tool_id
+            if not output_call_id:
+                # If neither is available, skip this tool call
+                logger.warning(
+                    f"⚠ No call_id or id found for tool {tool_name}, skipping output"
+                )
+                continue
+
+            input_items.append(
+                {
+                    "type": "function_call_output",
+                    "call_id": output_call_id,
+                    "output": content,
+                }
+            )
+
+    def _extract_tool_calls_from_response_output(
+        self, response_output: List[Any]
+    ) -> List[Dict[str, Any]]:
+        """Extract tool calls from response API output array.
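+
+        A call is treated as pending when its status is empty, "pending", or
+        "in_progress", or when it has no output yet.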
+ + Args: + response_output: List of output items from response API + + Returns: + List of tool call dictionaries that need to be executed + """ + tool_calls = [] + for item in response_output: + # Convert item to dict if it's a Pydantic model + if not isinstance(item, dict): + if hasattr(item, 'model_dump'): + item = item.model_dump() + elif hasattr(item, 'dict'): + item = item.dict() + elif hasattr(item, '__dict__'): + item = item.__dict__ + else: + continue + + # Check if item is a function_tool_call that needs execution + item_type = item.get("type") + if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: + # Only execute if status indicates it needs execution (not already completed) + status = item.get("status", "") + output = item.get("output") + # Execute if status is pending/in_progress/empty or if output is missing + if status in ["pending", "in_progress", ""] or output is None: + tool_calls.append(item) + return tool_calls + + def _convert_output_items_to_input_items( + self, response_output: List[Any] + ) -> List[Dict[str, Any]]: + """Convert response API output items to input items format. + + This includes messages, reasoning, and completed tool calls (with outputs). + Excludes tool calls that are pending or in progress. + + Args: + response_output: List of output items from response API + + Returns: + List of input items in the format expected by response API + """ + input_items = [] + for item in response_output: + # Convert item to dict if it's a Pydantic model + if not isinstance(item, dict): + if hasattr(item, 'model_dump'): + item = item.model_dump() + elif hasattr(item, 'dict'): + item = item.dict() + elif hasattr(item, '__dict__'): + item = item.__dict__ + else: + continue + + item_type = item.get("type") + + # Include messages and reasoning as-is + if item_type in ["message", "reasoning"]: + input_items.append(item) + # Include completed tool calls (with output) as function_tool_call items + elif item_type in ["function_tool_call", "function_call", "function", "tool_call"]: + status = item.get("status", "") + output = item.get("output") + # Only include if it's completed (has output) + if output is not None or status in ["completed", "done"]: + input_items.append(item) + + return input_items + + def _accumulate_tool_call_delta(self, tool_call_delta, tool_calls_accumulated: dict): + """Accumulate tool call data from a streaming delta.""" + index = tool_call_delta.index + if index not in tool_calls_accumulated: + tool_calls_accumulated[index] = { + "id": tool_call_delta.id, + "type": "function", + "function": {"name": "", "arguments": ""}, + } + if tool_call_delta.id: + tool_calls_accumulated[index]["id"] = tool_call_delta.id + if tool_call_delta.function: + if tool_call_delta.function.name: + tool_calls_accumulated[index]["function"]["name"] = tool_call_delta.function.name + if tool_call_delta.function.arguments: + tool_calls_accumulated[index]["function"]["arguments"] += ( + tool_call_delta.function.arguments + ) + + def _convert_accumulated_tool_calls(self, tool_calls_accumulated: dict) -> List[dict]: + """Convert accumulated tool calls dictionary to list format in chat completion format.""" + tool_calls_list = [] + for idx in sorted(tool_calls_accumulated.keys()): + tc = tool_calls_accumulated[idx] + tool_calls_list.append( + { + "id": tc["id"], + "type": tc["type"], + "function": { + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + }, + } + ) + return tool_calls_list + + def 
_accumulate_response_tool_call_delta( + self, delta_item: Dict[str, Any], tool_calls_accumulated: Dict[str, Dict[str, Any]] + ): + """Accumulate tool call data from a streaming delta in response API format. + + Args: + delta_item: A delta item from response API streaming (type="function_tool_call") + tool_calls_accumulated: Dictionary mapping call_id to accumulated tool call data + """ + # Get call_id or generate one if not present + call_id = delta_item.get("call_id") or delta_item.get("id") + if not call_id: + # Use a temporary ID based on output_index if available + output_index = delta_item.get("output_index", 0) + call_id = f"temp_{output_index}" + + if call_id not in tool_calls_accumulated: + tool_calls_accumulated[call_id] = { + "id": call_id, + "type": "function_tool_call", + "name": "", + "arguments": "", + "status": "in_progress", + } + + # Accumulate name (may come incrementally) + if "name" in delta_item and delta_item["name"]: + tool_calls_accumulated[call_id]["name"] = delta_item["name"] + + # Accumulate arguments (may come incrementally as string) + if "arguments" in delta_item and delta_item["arguments"]: + tool_calls_accumulated[call_id]["arguments"] += delta_item["arguments"] + + # Update status if present + if "status" in delta_item: + tool_calls_accumulated[call_id]["status"] = delta_item["status"] + + def _create_completion_request( + self, + messages: List[dict], + tools: List[dict], + max_tokens: int, + temperature: float, + top_p: float, + stream: bool = False, + ): + """Create a completion request with common parameters.""" + kwargs = { + "model": self.model, + "messages": messages, + "max_completion_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + } + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = "auto" + if stream: + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + return self.client.chat.completions.create(**kwargs) + + def _bridge_async_generator(self, async_gen_func): + """Bridge an async generator to a sync generator.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + gen = async_gen_func() + while True: + try: + yield loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + break + finally: + loop.close() + + async def _stream_with_mcp_tools_json( + self, + openai_messages: List[dict], + tools: List[dict], + mcp_clients: dict, + max_tokens: int, + temperature: float, + top_p: float, + tool_to_server: dict = None, + ): + """Async generator to handle MCP tool calls with streaming support for chat completions, yielding JSON chunks.""" + tool_calls_accumulated = {} + streaming_response = "" + + stream = self._create_completion_request( + openai_messages, tools, max_tokens, temperature, top_p, stream=True + ) + + for chunk in stream: + self._set_usage(chunk) + yield chunk.model_dump_json() + + if chunk.choices: + delta = chunk.choices[0].delta + if delta.tool_calls: + for tool_call_delta in delta.tool_calls: + self._accumulate_tool_call_delta(tool_call_delta, tool_calls_accumulated) + if delta.content: + streaming_response += delta.content + + # Execute tool calls if any were accumulated + if tool_calls_accumulated: + tool_calls_list = self._convert_accumulated_tool_calls(tool_calls_accumulated) + openai_messages.append( + { + "role": "assistant", + "content": streaming_response if streaming_response else None, + "tool_calls": tool_calls_list, + } + ) + await self._execute_tool_calls( + tool_calls_list, mcp_clients, openai_messages, tool_to_server + ) + + 
# Continue streaming with tool results (recursive call - don't finalize here) + async for chunk_json in self._stream_with_mcp_tools_json( + openai_messages, tools, mcp_clients, max_tokens, temperature, top_p, tool_to_server + ): + yield chunk_json + # Note: Finalization happens at the top level in openai_stream_transport + + async def _stream_responses_with_mcp_tools_json( + self, + request_data: Dict[str, Any], + tools: List[dict], + mcp_clients: dict, + tool_to_server: dict = None, + ): + """Async generator to handle MCP tool calls with streaming support for response API, yielding JSON chunks.""" + # Get input items + input_data = request_data.get("input", "") + if isinstance(input_data, str): + input_items = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input_data}], + } + ] + else: + input_items = input_data if isinstance(input_data, list) else [] + + # Create request with tools (convert to response API format) + response_args = {**request_data, "model": self.model} + if tools: + # Convert tools from chat completion format to response API format + response_api_tools = self._convert_tools_to_response_api_format(tools) + response_args["tools"] = response_api_tools + response_args["tool_choice"] = response_args.get("tool_choice", "auto") + + # Stream the response and accumulate output + stream = self.client.responses.create(**response_args) + accumulated_output = [] + tool_calls_accumulated = {} # Track tool calls incrementally + original_to_filtered_index_map = {} # Map original output indices to filtered indices (for messages only) + + for chunk in stream: + self._set_usage(chunk) + + # Handle different event types from response API streaming + chunk_type = getattr(chunk, 'type', None) or chunk.__class__.__name__ + + # Check if this event contains a non-message item that should be filtered + should_yield = True + item_to_check = None + + # Check response.output_item.added events + if ( + chunk_type == 'response.output_item.added' + or chunk_type == 'ResponseOutputItemAddedEvent' + ) and hasattr(chunk, 'item'): + item_to_check = chunk.item + # Build index mapping for messages as we see them + if hasattr(chunk, 'output_index'): + item_dict = ( + item_to_check + if isinstance(item_to_check, dict) + else ( + item_to_check.model_dump() + if hasattr(item_to_check, 'model_dump') + else item_to_check.dict() + if hasattr(item_to_check, 'dict') + else {} + ) + ) + item_type = item_dict.get("type") + if item_type == "message": + # This is a message, map original index to filtered index + original_index = chunk.output_index + # The filtered index is just the count of messages we've seen so far + original_to_filtered_index_map[original_index] = len( + original_to_filtered_index_map + ) + # Check response.output_item.done events + elif ( + chunk_type == 'response.output_item.done' + or chunk_type == 'ResponseOutputItemDoneEvent' + ) and hasattr(chunk, 'item'): + item_to_check = chunk.item + # Check events with output_index (like response.output_item.delta) + elif hasattr(chunk, 'output_index'): + original_index = chunk.output_index + # Only yield if this index maps to a message + if original_index not in original_to_filtered_index_map: + should_yield = False + + # If we have an item to check, verify it's a message type + if item_to_check: + item_dict = ( + item_to_check + if isinstance(item_to_check, dict) + else ( + item_to_check.model_dump() + if hasattr(item_to_check, 'model_dump') + else item_to_check.dict() + if hasattr(item_to_check, 'dict') + else {} 
+ ) + ) + item_type = item_dict.get("type") + # Only yield if it's a message, otherwise skip (but still process internally) + if item_type != "message": + should_yield = False + + # For response.completed events, filter output to only include messages before yielding + if ( + chunk_type == 'response.completed' or chunk_type == 'ResponseCompletedEvent' + ) and hasattr(chunk, 'response'): + response = chunk.response + if hasattr(response, 'output') and response.output: + # Filter output to only include message items and build index mapping + filtered_output = [] + filtered_index = 0 + original_to_filtered_index_map.clear() # Reset mapping for this response + + for original_index, item in enumerate(response.output): + item_dict = ( + item + if isinstance(item, dict) + else ( + item.model_dump() + if hasattr(item, 'model_dump') + else item.dict() + if hasattr(item, 'dict') + else {} + ) + ) + item_type = item_dict.get("type") + # Only include message items in the filtered output (as dicts for JSON serialization) + if item_type == "message": + filtered_output.append(item_dict) + original_to_filtered_index_map[original_index] = filtered_index + filtered_index += 1 + # Still accumulate tool calls for internal processing + elif item_type in [ + "function_tool_call", + "function_call", + "function", + "tool_call", + ]: + item_id = item_dict.get("id") + if item_id: + existing_ids = [ + i.get("id") + if isinstance(i, dict) + else (getattr(i, "id", None) if hasattr(i, "id") else None) + for i in accumulated_output + ] + if item_id not in existing_ids: + accumulated_output.append(item_dict) + else: + accumulated_output.append(item_dict) + else: + # For other types, still accumulate but don't include in filtered output + item_id = item_dict.get("id") + if item_id: + existing_ids = [ + i.get("id") + if isinstance(i, dict) + else (getattr(i, "id", None) if hasattr(i, "id") else None) + for i in accumulated_output + ] + if item_id not in existing_ids: + accumulated_output.append(item_dict) + else: + accumulated_output.append(item_dict) + + # Create a modified response with filtered output + response_dict = ( + response.model_dump() + if hasattr(response, 'model_dump') + else response.dict() + if hasattr(response, 'dict') + else {} + ) + response_dict["output"] = filtered_output + + # Create modified chunk with filtered response + modified_chunk_dict = { + "type": "response.completed", + "sequence_number": getattr(chunk, 'sequence_number', None), + "response": response_dict, + } + yield json.dumps(modified_chunk_dict) + else: + # No output to filter, yield as-is + yield chunk.model_dump_json() + elif should_yield: + # For events with output_index, remap to filtered index if it's a message index + if hasattr(chunk, 'output_index'): + original_index = chunk.output_index + if original_index in original_to_filtered_index_map: + # Remap the output_index to the filtered index + chunk_dict = ( + chunk.model_dump() + if hasattr(chunk, 'model_dump') + else ( + chunk.dict() + if hasattr(chunk, 'dict') + else json.loads(chunk.model_dump_json()) + if hasattr(chunk, 'model_dump_json') + else {} + ) + ) + chunk_dict["output_index"] = original_to_filtered_index_map[original_index] + yield json.dumps(chunk_dict) + # else: already filtered out by should_yield = False above + else: + # For all other chunk types, yield as-is (if not filtered out) + yield chunk.model_dump_json() + + # Handle ResponseOutputItemAddedEvent - initial tool call item + if ( + chunk_type == 'response.output_item.added' + or chunk_type == 
'ResponseOutputItemAddedEvent' + ) and hasattr(chunk, 'item'): + item = chunk.item + item_dict = ( + item + if isinstance(item, dict) + else ( + item.model_dump() + if hasattr(item, 'model_dump') + else item.dict() + if hasattr(item, 'dict') + else {} + ) + ) + item_type = item_dict.get("type") + + # If it's a tool call, start accumulating it + if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: + item_id = item_dict.get("id") or item_dict.get("call_id") + call_id = item_dict.get("call_id") + if item_id: + tool_calls_accumulated[item_id] = { + "id": item_id, + "call_id": call_id, # Preserve call_id for function_call_output + "type": item_type, + "name": item_dict.get("name", ""), + "arguments": item_dict.get("arguments", ""), + "status": item_dict.get("status", "in_progress"), + } + + # Handle ResponseFunctionCallArgumentsDeltaEvent - incremental argument updates + elif ( + chunk_type == 'response.function_call_arguments.delta' + or chunk_type == 'ResponseFunctionCallArgumentsDeltaEvent' + ): + item_id = getattr(chunk, 'item_id', None) + delta = getattr(chunk, 'delta', '') + + if item_id and item_id in tool_calls_accumulated: + # Accumulate the delta arguments + tool_calls_accumulated[item_id]["arguments"] += delta + + # Handle ResponseFunctionCallArgumentsDoneEvent - arguments complete + elif ( + chunk_type == 'response.function_call_arguments.done' + or chunk_type == 'ResponseFunctionCallArgumentsDoneEvent' + ): + item_id = getattr(chunk, 'item_id', None) + arguments = getattr(chunk, 'arguments', '') + + if item_id and item_id in tool_calls_accumulated: + # Set final arguments + tool_calls_accumulated[item_id]["arguments"] = arguments + + # Handle ResponseOutputItemDoneEvent - tool call item completed + elif ( + chunk_type == 'response.output_item.done' + or chunk_type == 'ResponseOutputItemDoneEvent' + ) and hasattr(chunk, 'item'): + item = chunk.item + item_dict = ( + item + if isinstance(item, dict) + else ( + item.model_dump() + if hasattr(item, 'model_dump') + else item.dict() + if hasattr(item, 'dict') + else {} + ) + ) + item_type = item_dict.get("type") + + # If it's a completed tool call, add to accumulated output + if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: + item_id = item_dict.get("id") + if item_id and item_id in tool_calls_accumulated: + # Update with final status and preserve call_id if present + tool_calls_accumulated[item_id]["status"] = item_dict.get( + "status", "completed" + ) + if "call_id" in item_dict: + tool_calls_accumulated[item_id]["call_id"] = item_dict.get("call_id") + # Add to accumulated output + accumulated_output.append(tool_calls_accumulated[item_id]) + else: + # Not in accumulated, add directly + accumulated_output.append(item_dict) + else: + # Non-tool-call item + accumulated_output.append(item_dict) + + # Handle standard response objects with output (fallback) + elif hasattr(chunk, 'output') and chunk.output: + for item in chunk.output: + item_dict = ( + item + if isinstance(item, dict) + else ( + item.model_dump() + if hasattr(item, 'model_dump') + else item.dict() + if hasattr(item, 'dict') + else {} + ) + ) + item_id = item_dict.get("id") + if item_id: + existing_ids = [ + i.get("id") + if isinstance(i, dict) + else (getattr(i, "id", None) if hasattr(i, "id") else None) + for i in accumulated_output + ] + if item_id not in existing_ids: + accumulated_output.append(item_dict) + else: + accumulated_output.append(item_dict) + + # After streaming completes, add any remaining 
accumulated tool calls + for call_id, call_data in tool_calls_accumulated.items(): + # Only add if it has a name and is not already in accumulated_output + if call_data.get("name"): + existing_ids = [ + i.get("id") + if isinstance(i, dict) + else (getattr(i, "id", None) if hasattr(i, "id") else None) + for i in accumulated_output + ] + if call_id not in existing_ids: + accumulated_output.append(call_data) + + # Check for tool calls in accumulated output + tool_calls = self._extract_tool_calls_from_response_output(accumulated_output) + # Execute tool calls if any + if tool_calls: + # Convert model's output (messages, reasoning, completed tool calls) to input items + model_output_items = self._convert_output_items_to_input_items(accumulated_output) + input_items.extend(model_output_items) + + # Execute tool calls and add results to input + await self._execute_response_api_tool_calls( + tool_calls, mcp_clients, input_items, tool_to_server + ) + + # Update request with new input including model output and tool results + request_data["input"] = input_items + + # Continue streaming with tool results (recursive call - don't finalize here) + async for chunk_json in self._stream_responses_with_mcp_tools_json( + request_data, tools, mcp_clients, tool_to_server + ): + yield chunk_json + # Note: Finalization happens at the top level in openai_stream_transport + + @ModelClass.method + def openai_transport(self, msg: str) -> str: + """Process an OpenAI-compatible request and send it to the appropriate OpenAI endpoint. + + Args: + msg: JSON string containing the request parameters including 'openai_endpoint' + + Returns: + JSON string containing the response or error + """ + try: + request_data = from_json(msg) + request_data = self._update_old_fields(request_data) + mcp_servers = request_data.pop("mcp_servers", None) + endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) + tools = request_data.get("tools") + + if mcp_servers and len(mcp_servers) > 0 and tools is None: + + async def run_with_mcp(): + logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") + ( + tools_local, + mcp_clients_local, + tool_to_server_local, + ) = await self._get_mcp_tools_and_clients(mcp_servers) + try: + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + response = self._route_request( + endpoint, request_data, mcp_servers, mcp_clients_local, tools_local + ) + + # Handle tool calls iteratively for chat completions + while response.choices and response.choices[0].message.tool_calls: + messages = request_data.get("messages", []) + messages.append(response.choices[0].message) + await self._execute_tool_calls( + response.choices[0].message.tool_calls, + mcp_clients_local, + messages, + tool_to_server_local, + ) + request_data["messages"] = messages + response = self._route_request( + endpoint, + request_data, + mcp_servers, + mcp_clients_local, + tools_local, + ) + + return response + elif endpoint == self.ENDPOINT_RESPONSES: + response = self._route_request( + endpoint, request_data, mcp_servers, mcp_clients_local, tools_local + ) + + # Handle tool calls iteratively for response API + # Get input items (can be string or list) + input_data = request_data.get("input", "") + if isinstance(input_data, str): + input_items = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input_data}], + } + ] + else: + input_items = input_data if isinstance(input_data, list) else [] + + # Extract tool calls from response output + response_output = ( + response.output if 
hasattr(response, 'output') else [] + ) + tool_calls = self._extract_tool_calls_from_response_output( + response_output + ) + + while tool_calls: + # Convert model's output (messages, reasoning, completed tool calls) to input items + model_output_items = self._convert_output_items_to_input_items( + response_output + ) + input_items.extend(model_output_items) + + # Execute tool calls and add results to input + await self._execute_response_api_tool_calls( + tool_calls, + mcp_clients_local, + input_items, + tool_to_server_local, + ) + # Update request with new input including model output and tool results + request_data["input"] = input_items + + # Make new request with tool results + response = self._route_request( + endpoint, + request_data, + mcp_servers, + mcp_clients_local, + tools_local, + ) + + # Check for more tool calls + response_output = ( + response.output if hasattr(response, 'output') else [] + ) + tool_calls = self._extract_tool_calls_from_response_output( + response_output + ) + + return response + else: + return self._route_request(endpoint, request_data) + finally: + await self._cleanup(mcp_clients_local) + + response = asyncio.run(run_with_mcp()) + else: + response = self._route_request(endpoint, request_data) + + # Finalize token usage accumulation (sum of all API calls) + self._finalize_token_usage() + return response.model_dump_json() + except Exception as e: + logger.exception(e) + return to_json( + { + "code": status_code_pb2.MODEL_PREDICTION_FAILED, + "description": "Model prediction failed", + "details": str(e), + } + ) + + @ModelClass.method + def openai_stream_transport(self, msg: str) -> Iterator[str]: + """Process an OpenAI-compatible request and return a streaming response iterator. + + This method is used when stream=True and returns an iterator of strings directly, + without converting to a list or JSON serializing. Supports chat completions and responses endpoints. + + Args: + msg: The request as a JSON string. + + Returns: + Iterator[str]: An iterator yielding text chunks from the streaming response. 
+ """ + try: + request_data = from_json(msg) + request_data = self._update_old_fields(request_data) + mcp_servers = request_data.pop("mcp_servers", None) + endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) + + if endpoint not in [self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES]: + raise ValueError("Streaming is only supported for chat completions and responses.") + + if mcp_servers and len(mcp_servers) > 0 and request_data.get("tools") is None: + + async def run_with_mcp(): + logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") + ( + tools_local, + mcp_clients_local, + tool_to_server_local, + ) = await self._get_mcp_tools_and_clients(mcp_servers) + try: + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + messages = request_data.get("messages", []) + async for chunk_json in self._stream_with_mcp_tools_json( + messages, + tools_local, + mcp_clients_local, + request_data.get("max_completion_tokens", 4096), + request_data.get("temperature", 1.0), + request_data.get("top_p", 1.0), + tool_to_server_local, + ): + yield chunk_json + # Finalize token usage accumulation after streaming completes + self._finalize_token_usage() + elif endpoint == self.ENDPOINT_RESPONSES: + async for chunk_json in self._stream_responses_with_mcp_tools_json( + request_data, tools_local, mcp_clients_local, tool_to_server_local + ): + yield chunk_json + # Finalize token usage accumulation after streaming completes + self._finalize_token_usage() + else: + # Fallback for other endpoints + response_args = {**request_data, "model": self.model} + for chunk in self.client.responses.create(**response_args): + self._set_usage(chunk) + yield chunk.model_dump_json() + # Finalize token usage accumulation after streaming completes + self._finalize_token_usage() + finally: + await self._cleanup(mcp_clients_local) + + yield from self._bridge_async_generator(run_with_mcp) + return + + # Non-MCP path or responses endpoint + if endpoint == self.ENDPOINT_RESPONSES: + response_args = {**request_data, "model": self.model} + for chunk in self.client.responses.create(**response_args): + self._set_usage(chunk) + yield chunk.model_dump_json() + # Finalize token usage accumulation after streaming completes + self._finalize_token_usage() + else: + completion_args = self._create_completion_args(request_data) + for chunk in self.client.chat.completions.create(**completion_args): + self._set_usage(chunk) + yield chunk.model_dump_json() + # Finalize token usage accumulation after streaming completes + self._finalize_token_usage() + + except Exception as e: + logger.exception(e) + yield to_json( + { + "code": status_code_pb2.MODEL_PREDICTION_FAILED, + "description": "Model prediction failed", + "details": str(e), + } + ) From 9277d247f0501eb086d867253c597f2052afd1e7 Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Tue, 9 Dec 2025 16:45:38 +0530 Subject: [PATCH 02/13] reuse connection --- clarifai/runners/models/agentic_class.py | 1220 +++++++++++++--------- 1 file changed, 710 insertions(+), 510 deletions(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index f1f520f3..d479cbc2 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -3,7 +3,9 @@ import asyncio import json import os -from typing import Any, Dict, Iterator, List +import threading +import time +from typing import Any, Dict, Iterator, List, Optional from clarifai_grpc.grpc.api.status import status_code_pb2 from pydantic_core import from_json, 
to_json @@ -13,35 +15,104 @@ from clarifai.utils.logging import logger -class AgenticModelClass(OpenAIModelClass): - """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support. +class MCPConnectionPool: + """Thread-safe connection pool for MCP servers with persistent connections. - This class extends OpenAIModelClass to enable agentic behavior by integrating LLMs with MCP servers. - It handles tool discovery, execution, and iterative tool calling for both chat completions and - responses endpoints, supporting both streaming and non-streaming modes. + This class manages MCP client connections across multiple requests, + maintaining persistent connections and handling reconnection when needed. + """ - To use this class, create a subclass and set the following class attributes: - - client: The OpenAI-compatible client instance - - model: The name of the model to use with the client + _instance: Optional['MCPConnectionPool'] = None + _lock = threading.Lock() + + def __new__(cls): + """Singleton pattern to ensure one connection pool per process.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._connections: Dict[ + str, Dict[str, Any] + ] = {} # url -> {client, tools, loop, last_used, lock} + self._connection_locks: Dict[str, threading.Lock] = {} # url -> lock for that connection + self._global_lock = threading.Lock() + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._loop_thread: Optional[threading.Thread] = None + self._loop_lock = threading.Lock() + self._max_idle_time = 300 # 5 minutes idle timeout + self._cleanup_interval = 60 # Check for stale connections every minute + self._last_cleanup = time.time() + self._initialized = True + + def _get_or_create_event_loop(self) -> asyncio.AbstractEventLoop: + """Get or create a persistent event loop running in a background thread. + + This ensures MCP connections persist across request boundaries even when + the request's event loop is closed. + """ + with self._loop_lock: + # Check if we have a running loop + if self._loop is not None and self._loop_thread is not None: + if self._loop_thread.is_alive() and not self._loop.is_closed(): + return self._loop - Example: - class MyAgenticModel(AgenticModelClass): - client = OpenAI(api_key="your-key") - model = "gpt-4" - """ + # Create a new event loop in a background thread + loop_ready = threading.Event() - async def _connect_to_servers( - self, mcp_servers: List[str], max_retries: int = 2, retry_delay: float = 1.0 - ) -> Dict[str, Any]: - """Connect to all configured Clarifai MCP servers. + def run_loop(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + loop_ready.set() + self._loop.run_forever() + + self._loop_thread = threading.Thread(target=run_loop, daemon=True) + self._loop_thread.start() + loop_ready.wait(timeout=5.0) # Wait for loop to be ready + + if self._loop is None: + raise RuntimeError("Failed to create event loop for MCP connections") + + return self._loop + + def _run_coroutine(self, coro) -> Any: + """Run a coroutine in the persistent event loop. 
Args: - mcp_servers: List of MCP server URLs to connect to - max_retries: Maximum number of retry attempts per server (default: 2) - retry_delay: Delay in seconds between retry attempts (default: 1.0) + coro: Coroutine to run Returns: - Dictionary mapping server URLs to client info and tools + Result of the coroutine + """ + loop = self._get_or_create_event_loop() + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result(timeout=30.0) # 30 second timeout for operations + + def _get_connection_lock(self, url: str) -> threading.Lock: + """Get or create a lock for a specific URL.""" + with self._global_lock: + if url not in self._connection_locks: + self._connection_locks[url] = threading.Lock() + return self._connection_locks[url] + + async def _connect_single_server( + self, url: str, max_retries: int = 2, retry_delay: float = 1.0 + ) -> Optional[Dict[str, Any]]: + """Connect to a single MCP server with retries. + + Args: + url: MCP server URL + max_retries: Maximum retry attempts + retry_delay: Delay between retries in seconds + + Returns: + Dictionary with client and tools, or None if connection failed """ try: from fastmcp import Client @@ -52,76 +123,177 @@ async def _connect_to_servers( "Install it with: pip install fastmcp" ) - mcp_clients = {} + last_error = None + + for attempt in range(max_retries): + try: + transport = StreamableHttpTransport( + url=url, + headers={"Authorization": "Bearer " + os.environ.get("CLARIFAI_PAT", "")}, + ) + + client = Client(transport) + await client.__aenter__() - for mcp_url in mcp_servers: - last_error = None - connected = False + # List available tools + tools_result = await client.list_tools() - for attempt in range(max_retries): - try: - # Create transport for this server - transport = StreamableHttpTransport( - url=mcp_url, - headers={"Authorization": "Bearer " + os.environ["CLARIFAI_PAT"]}, + logger.info(f"✓ Connected to {url} with {len(tools_result)} tools") + + return { + "client": client, + "tools": tools_result, + "last_used": time.time(), + "connected_at": time.time(), + } + + except Exception as e: + last_error = e + if attempt < max_retries - 1: + logger.warning( + f"⚠ Failed to connect to {url} (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {retry_delay}s..." + ) + await asyncio.sleep(retry_delay) + else: + logger.error( + f"❌ Failed to connect to {url} after {max_retries} attempts: {e}" ) - # Create and connect client - client = Client(transport) - await client.__aenter__() + return None - # Store client with server info - mcp_clients[mcp_url] = {"client": client, "tools": []} + async def _verify_connection(self, url: str, connection_info: Dict[str, Any]) -> bool: + """Verify that a connection is still valid. 
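+
+        Uses list_tools() with a 5-second timeout as a lightweight health check.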
- # List available tools - tools_result = await client.list_tools() - mcp_clients[mcp_url]["tools"] = tools_result + Args: + url: MCP server URL + connection_info: Connection info dictionary - logger.info(f"✓ Connected to {mcp_url} with {len(tools_result)} tools") - connected = True - break # Success, exit retry loop + Returns: + True if connection is valid, False otherwise + """ + try: + client = connection_info["client"] + # Try to list tools as a health check + await asyncio.wait_for(client.list_tools(), timeout=5.0) + return True + except Exception as e: + logger.warning(f"⚠ Connection to {url} is no longer valid: {e}") + return False - except Exception as e: - last_error = e - if attempt < max_retries - 1: - logger.warning( - f"⚠ Failed to connect to {mcp_url} (attempt {attempt + 1}/{max_retries}): {e}. " - f"Retrying in {retry_delay}s..." - ) - await asyncio.sleep(retry_delay) - else: - logger.error( - f"❌ Failed to connect to {mcp_url} after {max_retries} attempts: {e}" + async def _disconnect_single_server(self, url: str, connection_info: Dict[str, Any]): + """Disconnect from a single MCP server. + + Args: + url: MCP server URL + connection_info: Connection info dictionary + """ + try: + client = connection_info["client"] + if hasattr(client, 'close') and callable(getattr(client, 'close', None)): + if asyncio.iscoroutinefunction(client.close): + await client.close() + else: + client.close() + else: + await client.__aexit__(None, None, None) + logger.info(f"✓ Disconnected from {url}") + except Exception as e: + logger.warning(f"⚠ Error disconnecting from {url}: {e}") + + def get_connections( + self, mcp_servers: List[str], max_retries: int = 2, retry_delay: float = 1.0 + ) -> Dict[str, Any]: + """Get connections for the specified MCP servers. + + This method reuses existing connections when possible and creates + new ones as needed. Thread-safe. + + Args: + mcp_servers: List of MCP server URLs + max_retries: Maximum retry attempts for new connections + retry_delay: Delay between retries + + Returns: + Dictionary mapping server URLs to client info and tools + """ + # Periodic cleanup of stale connections + self._maybe_cleanup_stale_connections() + + result = {} + urls_to_connect = [] + + # First pass: get existing valid connections + for url in mcp_servers: + lock = self._get_connection_lock(url) + with lock: + if url in self._connections: + connection_info = self._connections[url] + # Check if connection is still valid + try: + is_valid = self._run_coroutine( + self._verify_connection(url, connection_info) ) + if is_valid: + connection_info["last_used"] = time.time() + result[url] = connection_info + logger.debug(f"Reusing existing connection to {url}") + continue + else: + # Connection is stale, remove it + del self._connections[url] + except Exception as e: + logger.warning(f"⚠ Error verifying connection to {url}: {e}") + # Remove potentially stale connection + if url in self._connections: + del self._connections[url] - if not connected: - # Log final failure if all retries exhausted - logger.error( - f"❌ Could not connect to {mcp_url} after {max_retries} attempts. 
" - f"Last error: {last_error}" - ) - # Continue with other servers even if one fails + urls_to_connect.append(url) + + # Second pass: connect to servers that need new connections + if urls_to_connect: - return mcp_clients + async def connect_servers(): + tasks = [] + for url in urls_to_connect: + tasks.append(self._connect_single_server(url, max_retries, retry_delay)) + return await asyncio.gather(*tasks, return_exceptions=True) - async def _get_mcp_tools_and_clients( + try: + results = self._run_coroutine(connect_servers()) + + for url, connection_result in zip(urls_to_connect, results): + if isinstance(connection_result, Exception): + logger.error(f"❌ Failed to connect to {url}: {connection_result}") + continue + if connection_result is not None: + lock = self._get_connection_lock(url) + with lock: + self._connections[url] = connection_result + result[url] = connection_result + except Exception as e: + logger.error(f"❌ Error connecting to MCP servers: {e}") + + return result + + def get_tools_and_mapping( self, mcp_servers: List[str] - ) -> tuple[List[dict], dict, dict]: - """Get available tools and clients from all connected MCP servers. + ) -> tuple[List[dict], Dict[str, Any], Dict[str, str]]: + """Get tools and server mapping for the specified MCP servers. Args: mcp_servers: List of MCP server URLs Returns: - A tuple of (tools in OpenAI format, mcp_clients dictionary, tool_to_server mapping). + Tuple of (tools in OpenAI format, mcp_clients dictionary, tool_to_server mapping) """ - mcp_clients = await self._connect_to_servers(mcp_servers) + mcp_clients = self.get_connections(mcp_servers) all_tools = [] - tool_to_server = {} # Map tool name to server URL + tool_to_server = {} for mcp_url, server_info in mcp_clients.items(): - tools = server_info["tools"] + tools = server_info.get("tools", []) for tool in tools: tool_name = tool.name all_tools.append( @@ -134,34 +306,223 @@ async def _get_mcp_tools_and_clients( }, } ) - # Map tool name to its server URL tool_to_server[tool_name] = mcp_url - logger.info(f"Access to the {len(all_tools)} tools") + logger.info(f"Access to {len(all_tools)} tools from {len(mcp_clients)} servers") return all_tools, mcp_clients, tool_to_server - async def _cleanup(self, mcp_clients: dict): - """Clean up MCP client resources. + def _maybe_cleanup_stale_connections(self): + """Clean up connections that have been idle for too long.""" + current_time = time.time() + + # Only run cleanup periodically + if current_time - self._last_cleanup < self._cleanup_interval: + return + + self._last_cleanup = current_time + urls_to_remove = [] + + with self._global_lock: + for url, connection_info in self._connections.items(): + last_used = connection_info.get("last_used", 0) + if current_time - last_used > self._max_idle_time: + urls_to_remove.append(url) + + for url in urls_to_remove: + self.disconnect(url) + + def disconnect(self, url: str): + """Disconnect from a specific MCP server. 
Args: - mcp_clients: Dictionary of MCP clients to clean up + url: MCP server URL to disconnect from """ - logger.info("Cleaning up MCP connections...") - for mcp_url, server_info in mcp_clients.items(): + lock = self._get_connection_lock(url) + with lock: + if url in self._connections: + connection_info = self._connections.pop(url) + try: + self._run_coroutine(self._disconnect_single_server(url, connection_info)) + except Exception as e: + logger.warning(f"⚠ Error during disconnect from {url}: {e}") + + def disconnect_all(self): + """Disconnect from all MCP servers.""" + with self._global_lock: + urls = list(self._connections.keys()) + + for url in urls: + self.disconnect(url) + + # Stop the event loop + with self._loop_lock: + if self._loop is not None and not self._loop.is_closed(): + self._loop.call_soon_threadsafe(self._loop.stop) + if self._loop_thread is not None: + self._loop_thread.join(timeout=5.0) + self._loop = None + self._loop_thread = None + + def call_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + mcp_clients: Dict[str, Any], + tool_to_server: Dict[str, str], + ) -> Any: + """Call a tool on the appropriate MCP server. + + Args: + tool_name: Name of the tool to call + arguments: Arguments to pass to the tool + mcp_clients: Dictionary of MCP clients + tool_to_server: Mapping of tool names to server URLs + + Returns: + Tool call result + """ + + async def _call_tool(): + result = None + error_msg = None + + # Try the mapped server first + if tool_to_server and tool_name in tool_to_server: + server_url = tool_to_server[tool_name] + if server_url in mcp_clients: + try: + logger.info(f"Calling tool {tool_name} with arguments {arguments}") + result = await mcp_clients[server_url]["client"].call_tool( + tool_name, arguments=arguments + ) + return result + except Exception as e: + error_msg = str(e) + logger.error(f"❌ Error calling tool {tool_name}: {e}") + + # Fallback: try all servers + for server_url, server_info in mcp_clients.items(): + if tool_to_server and tool_name in tool_to_server: + if tool_to_server[tool_name] == server_url: + continue # Already tried this one + try: + logger.info(f"Calling tool {tool_name} with arguments {arguments}") + result = await server_info["client"].call_tool(tool_name, arguments=arguments) + return result + except Exception as e: + error_msg = str(e) + logger.error(f"❌ Error calling tool {tool_name}: {e}") + continue + + raise Exception( + f"Failed to execute tool {tool_name}. " + f"{error_msg if error_msg else 'Tool not found on any server.'}" + ) + + return self._run_coroutine(_call_tool()) + + async def call_tool_async( + self, + tool_name: str, + arguments: Dict[str, Any], + mcp_clients: Dict[str, Any], + tool_to_server: Dict[str, str], + ) -> Any: + """Async version of call_tool for use within async contexts. 
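+
+        Tries the server mapped in tool_to_server first, then falls back to every
+        other connected server before raising.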
+ + Args: + tool_name: Name of the tool to call + arguments: Arguments to pass to the tool + mcp_clients: Dictionary of MCP clients + tool_to_server: Mapping of tool names to server URLs + + Returns: + Tool call result + """ + result = None + error_msg = None + + # Try the mapped server first + if tool_to_server and tool_name in tool_to_server: + server_url = tool_to_server[tool_name] + if server_url in mcp_clients: + try: + logger.info(f"Calling tool {tool_name} with arguments {arguments}") + result = await mcp_clients[server_url]["client"].call_tool( + tool_name, arguments=arguments + ) + return result + except Exception as e: + error_msg = str(e) + logger.error(f"❌ Error calling tool {tool_name}: {e}") + + # Fallback: try all servers + for server_url, server_info in mcp_clients.items(): + if tool_to_server and tool_name in tool_to_server: + if tool_to_server[tool_name] == server_url: + continue try: - client = server_info["client"] - # Try to close the client properly - if hasattr(client, 'close') and callable(getattr(client, 'close', None)): - if asyncio.iscoroutinefunction(client.close): - await client.close() - else: - client.close() - else: - await client.__aexit__(None, None, None) - logger.info(f"✓ Disconnected from {mcp_url}") + logger.info(f"Calling tool {tool_name} with arguments {arguments}") + result = await server_info["client"].call_tool(tool_name, arguments=arguments) + return result except Exception as e: - # Log other errors but don't fail cleanup - logger.warning(f"⚠ Error disconnecting from {mcp_url}: {e} (continuing cleanup)") + error_msg = str(e) + logger.error(f"❌ Error calling tool {tool_name}: {e}") + continue + + raise Exception( + f"Failed to execute tool {tool_name}. " + f"{error_msg if error_msg else 'Tool not found on any server.'}" + ) + + +class AgenticModelClass(OpenAIModelClass): + """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support. + + This class extends OpenAIModelClass to enable agentic behavior by integrating LLMs with MCP servers. + It handles tool discovery, execution, and iterative tool calling for both chat completions and + responses endpoints, supporting both streaming and non-streaming modes. + + MCP connections are maintained persistently across requests using a connection pool, which + significantly improves performance by avoiding reconnection overhead. + + To use this class, create a subclass and set the following class attributes: + - client: The OpenAI-compatible client instance + - model: The name of the model to use with the client + + Example: + class MyAgenticModel(AgenticModelClass): + client = OpenAI(api_key="your-key") + model = "gpt-4" + """ + + # Singleton connection pool shared across all instances + _mcp_pool: Optional[MCPConnectionPool] = None + _pool_lock = threading.Lock() + + @classmethod + def _get_mcp_pool(cls) -> MCPConnectionPool: + """Get or create the MCP connection pool singleton.""" + if cls._mcp_pool is None: + with cls._pool_lock: + if cls._mcp_pool is None: + cls._mcp_pool = MCPConnectionPool() + return cls._mcp_pool + + def _get_mcp_tools_and_clients(self, mcp_servers: List[str]) -> tuple[List[dict], dict, dict]: + """Get available tools and clients from all connected MCP servers. + + This method uses the connection pool to reuse existing connections + when possible, significantly improving performance. + + Args: + mcp_servers: List of MCP server URLs + + Returns: + A tuple of (tools in OpenAI format, mcp_clients dictionary, tool_to_server mapping). 
+ """ + pool = self._get_mcp_pool() + return pool.get_tools_and_mapping(mcp_servers) def _init_token_accumulation(self): """Initialize token accumulation for a new request.""" @@ -177,7 +538,6 @@ def _accumulate_usage(self, resp): Args: resp: Response object with usage information """ - # Extract usage from response (same logic as base _set_usage) has_usage = getattr(resp, "usage", None) has_response_usage = getattr(resp, "response", None) and getattr( resp.response, "usage", None @@ -202,17 +562,13 @@ def _accumulate_usage(self, resp): if completion_tokens is None: completion_tokens = 0 - # Only accumulate if we have valid tokens if prompt_tokens > 0 or completion_tokens > 0: self._init_token_accumulation() self._thread_local.accumulated_tokens['prompt_tokens'] += prompt_tokens self._thread_local.accumulated_tokens['completion_tokens'] += completion_tokens def _finalize_token_usage(self): - """Finalize token accumulation and set the total in output context. - - This should be called once at the end of a request that may have multiple API calls. - """ + """Finalize token accumulation and set the total in output context.""" if hasattr(self._thread_local, 'accumulated_tokens'): prompt_tokens = self._thread_local.accumulated_tokens['prompt_tokens'] completion_tokens = self._thread_local.accumulated_tokens['completion_tokens'] @@ -222,19 +578,10 @@ def _finalize_token_usage(self): prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) - # Clean up del self._thread_local.accumulated_tokens def _set_usage(self, resp): - """Override _set_usage to accumulate tokens across multiple API calls. - - In agentic flows, multiple OpenAI API calls are made (initial + recursive calls after tool execution). - This method accumulates tokens from all calls and only sets the final total once. - - Args: - resp: Response object with usage information - """ - # Accumulate tokens instead of immediately setting them + """Override _set_usage to accumulate tokens across multiple API calls.""" self._accumulate_usage(resp) def _handle_chat_completions( @@ -250,26 +597,14 @@ def _handle_chat_completions( request_data["tools"] = tools request_data["tool_choice"] = request_data.get("tool_choice", "auto") - # Use base class implementation return super()._handle_chat_completions(request_data) def _convert_tools_to_response_api_format(self, tools: List[dict]) -> List[dict]: - """Convert tools from chat completion format to response API format. 
- - Chat completion format: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}} - Response API format: {"type": "function", "name": ..., "description": ..., "parameters": ...} - - Args: - tools: List of tools in chat completion format - - Returns: - List of tools in response API format - """ + """Convert tools from chat completion format to response API format.""" response_api_tools = [] for tool in tools: if isinstance(tool, dict): tool_type = tool.get("type", "function") - # Check if it's in chat completion format (has nested "function") if "function" in tool: func = tool["function"] response_api_tools.append( @@ -280,7 +615,6 @@ def _convert_tools_to_response_api_format(self, tools: List[dict]) -> List[dict] "parameters": func.get("parameters", {}), } ) - # Already in response API format elif "name" in tool: response_api_tools.append(tool) return response_api_tools @@ -293,15 +627,12 @@ def _handle_responses( tools: List[dict] = None, ): """Handle response API requests with optional MCP tool support.""" - # If we have MCP tools, convert them to response API format and add them to the request if mcp_servers and tools: - request_data = request_data.copy() # Don't modify original - # Convert tools from chat completion format to response API format + request_data = request_data.copy() response_api_tools = self._convert_tools_to_response_api_format(tools) request_data["tools"] = response_api_tools request_data["tool_choice"] = request_data.get("tool_choice", "auto") - # Use base class implementation return super()._handle_responses(request_data) def _route_request( @@ -312,28 +643,29 @@ def _route_request( mcp_clients: dict = None, tools: List[dict] = None, ): - """Route the request to appropriate handler based on endpoint, with optional MCP support.""" - # For chat completions, pass MCP parameters + """Route the request to appropriate handler based on endpoint.""" if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: return self._handle_chat_completions(request_data, mcp_servers, mcp_clients, tools) - # For responses endpoint, pass MCP parameters if endpoint == self.ENDPOINT_RESPONSES: return self._handle_responses(request_data, mcp_servers, mcp_clients, tools) - # For other endpoints, use base class implementation return super()._route_request(endpoint, request_data) - async def _execute_tool_calls( + def _execute_tool_calls( self, tool_calls: List[Any], mcp_clients: dict, messages: List[dict], tool_to_server: dict = None, ): - """Execute tool calls from chat completion and add results to messages. Handles both OpenAI tool_call objects and dict format.""" + """Execute tool calls from chat completion and add results to messages. + + Uses the connection pool for tool execution. 
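For reference, the two shapes the converter translates between look roughly like this (tool name and schema are illustrative):

    # Chat-completions format: the function payload is nested under "function".
    chat_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }

    # Responses-API format: the same fields sit at the top level.
    responses_tool = {
        "type": "function",
        "name": "get_weather",
        "description": "Look up current weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }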
+ """ + pool = self._get_mcp_pool() + for tool_call in tool_calls: - # Handle both OpenAI tool_call objects and dict format if hasattr(tool_call, 'function'): tool_name = tool_call.function.name tool_args = json.loads(tool_call.function.arguments) @@ -343,49 +675,51 @@ async def _execute_tool_calls( tool_args = json.loads(tool_call['function']['arguments']) tool_id = tool_call['id'] - result = None - error_msg = None + try: + result = pool.call_tool(tool_name, tool_args, mcp_clients, tool_to_server) + content = ( + result.content[0].text if hasattr(result, 'content') else str(result[0].text) + ) + except Exception as e: + content = f"Error: {str(e)}" - # If we have tool-to-server mapping, try the correct server first - if tool_to_server and tool_name in tool_to_server: - server_url = tool_to_server[tool_name] - if server_url in mcp_clients: - try: - logger.info(f"Calling tool {tool_name} with arguments {tool_args}") - result = await mcp_clients[server_url]["client"].call_tool( - tool_name, arguments=tool_args - ) - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": content, + } + ) - # If not found or failed, try all servers as fallback - if result is None: - for server_url, server_info in mcp_clients.items(): - # Skip if we already tried this server - if ( - tool_to_server - and tool_name in tool_to_server - and tool_to_server[tool_name] == server_url - ): - continue - try: - logger.info(f"Calling tool {tool_name} with arguments {tool_args}") - result = await server_info["client"].call_tool( - tool_name, arguments=tool_args - ) - break - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") - continue + async def _execute_tool_calls_async( + self, + tool_calls: List[Any], + mcp_clients: dict, + messages: List[dict], + tool_to_server: dict = None, + ): + """Async version of _execute_tool_calls for streaming contexts.""" + pool = self._get_mcp_pool() + + for tool_call in tool_calls: + if hasattr(tool_call, 'function'): + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool_id = tool_call.id + else: + tool_name = tool_call['function']['name'] + tool_args = json.loads(tool_call['function']['arguments']) + tool_id = tool_call['id'] - if result: + try: + result = await pool.call_tool_async( + tool_name, tool_args, mcp_clients, tool_to_server + ) content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - else: - content = f"Error: Failed to execute tool {tool_name}. {error_msg if error_msg else 'Tool not found on any server.'}" + except Exception as e: + content = f"Error: {str(e)}" messages.append( { @@ -395,28 +729,22 @@ async def _execute_tool_calls( } ) - async def _execute_response_api_tool_calls( + def _execute_response_api_tool_calls( self, tool_calls: List[Dict[str, Any]], mcp_clients: dict, input_items: List[Any], tool_to_server: dict = None, ): - """Execute tool calls from response API and add results to input items. 
+ """Execute tool calls from response API and add results to input items.""" + pool = self._get_mcp_pool() - Args: - tool_calls: List of tool call dicts from response API output - mcp_clients: Dictionary of MCP clients - input_items: List of input items (can be modified in place) - tool_to_server: Mapping of tool names to server URLs - """ for tool_call in tool_calls: tool_name = tool_call.get("name") tool_args_str = tool_call.get("arguments", "{}") tool_id = tool_call.get("id") call_id = tool_call.get("call_id") - # Parse arguments try: tool_args = ( json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str @@ -424,58 +752,63 @@ async def _execute_response_api_tool_calls( except json.JSONDecodeError: tool_args = {} - result = None - error_msg = None + try: + result = pool.call_tool(tool_name, tool_args, mcp_clients, tool_to_server) + content = ( + result.content[0].text if hasattr(result, 'content') else str(result[0].text) + ) + except Exception as e: + content = f"Error: {str(e)}" - # If we have tool-to-server mapping, try the correct server first - if tool_to_server and tool_name in tool_to_server: - server_url = tool_to_server[tool_name] - if server_url in mcp_clients: - try: - logger.info(f"Calling tool {tool_name} with arguments {tool_args}") - result = await mcp_clients[server_url]["client"].call_tool( - tool_name, arguments=tool_args - ) - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") + output_call_id = call_id if call_id else tool_id + if not output_call_id: + logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping output") + continue - # If not found or failed, try all servers as fallback - if result is None: - for server_url, server_info in mcp_clients.items(): - # Skip if we already tried this server - if ( - tool_to_server - and tool_name in tool_to_server - and tool_to_server[tool_name] == server_url - ): - continue - try: - logger.info(f"Calling tool {tool_name} with arguments {tool_args}") - result = await server_info["client"].call_tool( - tool_name, arguments=tool_args - ) - break - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name} : {e}") - continue + input_items.append( + { + "type": "function_call_output", + "call_id": output_call_id, + "output": content, + } + ) - # Get tool output - if result: + async def _execute_response_api_tool_calls_async( + self, + tool_calls: List[Dict[str, Any]], + mcp_clients: dict, + input_items: List[Any], + tool_to_server: dict = None, + ): + """Async version of _execute_response_api_tool_calls for streaming contexts.""" + pool = self._get_mcp_pool() + + for tool_call in tool_calls: + tool_name = tool_call.get("name") + tool_args_str = tool_call.get("arguments", "{}") + tool_id = tool_call.get("id") + call_id = tool_call.get("call_id") + + try: + tool_args = ( + json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str + ) + except json.JSONDecodeError: + tool_args = {} + + try: + result = await pool.call_tool_async( + tool_name, tool_args, mcp_clients, tool_to_server + ) content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - else: - content = f"Error: Failed to execute tool {tool_name}. 
{error_msg if error_msg else 'Tool not found on any server.'}" + except Exception as e: + content = f"Error: {str(e)}" - # Use call_id if available, otherwise use id (call_id is required for function_call_output) output_call_id = call_id if call_id else tool_id if not output_call_id: - # If neither is available, skip this tool call - logger.warning( - f"⚠ Warning: No call_id or id found for tool {tool_name}, skipping output" - ) + logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping output") continue input_items.append( @@ -489,17 +822,9 @@ async def _execute_response_api_tool_calls( def _extract_tool_calls_from_response_output( self, response_output: List[Any] ) -> List[Dict[str, Any]]: - """Extract tool calls from response API output array. - - Args: - response_output: List of output items from response API - - Returns: - List of tool call dictionaries that need to be executed - """ + """Extract tool calls from response API output array.""" tool_calls = [] for item in response_output: - # Convert item to dict if it's a Pydantic model if not isinstance(item, dict): if hasattr(item, 'model_dump'): item = item.model_dump() @@ -510,13 +835,10 @@ def _extract_tool_calls_from_response_output( else: continue - # Check if item is a function_tool_call that needs execution item_type = item.get("type") if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: - # Only execute if status indicates it needs execution (not already completed) status = item.get("status", "") output = item.get("output") - # Execute if status is pending/in_progress/empty or if output is missing if status in ["pending", "in_progress", ""] or output is None: tool_calls.append(item) return tool_calls @@ -524,20 +846,9 @@ def _extract_tool_calls_from_response_output( def _convert_output_items_to_input_items( self, response_output: List[Any] ) -> List[Dict[str, Any]]: - """Convert response API output items to input items format. - - This includes messages, reasoning, and completed tool calls (with outputs). - Excludes tool calls that are pending or in progress. 
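The responses endpoint uses a different feedback shape: each result becomes a function_call_output input item whose call_id must match the model's function call (values are illustrative):

    input_items = []
    input_items.append(
        {
            "type": "function_call_output",
            "call_id": "call_abc123",         # must match the model's function call
            "output": '{"temperature": 18}',  # tool result text, or "Error: ..." on failure
        }
    )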
- - Args: - response_output: List of output items from response API - - Returns: - List of input items in the format expected by response API - """ + """Convert response API output items to input items format.""" input_items = [] for item in response_output: - # Convert item to dict if it's a Pydantic model if not isinstance(item, dict): if hasattr(item, 'model_dump'): item = item.model_dump() @@ -550,14 +861,11 @@ def _convert_output_items_to_input_items( item_type = item.get("type") - # Include messages and reasoning as-is if item_type in ["message", "reasoning"]: input_items.append(item) - # Include completed tool calls (with output) as function_tool_call items elif item_type in ["function_tool_call", "function_call", "function", "tool_call"]: status = item.get("status", "") output = item.get("output") - # Only include if it's completed (has output) if output is not None or status in ["completed", "done"]: input_items.append(item) @@ -583,7 +891,7 @@ def _accumulate_tool_call_delta(self, tool_call_delta, tool_calls_accumulated: d ) def _convert_accumulated_tool_calls(self, tool_calls_accumulated: dict) -> List[dict]: - """Convert accumulated tool calls dictionary to list format in chat completion format.""" + """Convert accumulated tool calls dictionary to list format.""" tool_calls_list = [] for idx in sorted(tool_calls_accumulated.keys()): tc = tool_calls_accumulated[idx] @@ -599,43 +907,6 @@ def _convert_accumulated_tool_calls(self, tool_calls_accumulated: dict) -> List[ ) return tool_calls_list - def _accumulate_response_tool_call_delta( - self, delta_item: Dict[str, Any], tool_calls_accumulated: Dict[str, Dict[str, Any]] - ): - """Accumulate tool call data from a streaming delta in response API format. - - Args: - delta_item: A delta item from response API streaming (type="function_tool_call") - tool_calls_accumulated: Dictionary mapping call_id to accumulated tool call data - """ - # Get call_id or generate one if not present - call_id = delta_item.get("call_id") or delta_item.get("id") - if not call_id: - # Use a temporary ID based on output_index if available - output_index = delta_item.get("output_index", 0) - call_id = f"temp_{output_index}" - - if call_id not in tool_calls_accumulated: - tool_calls_accumulated[call_id] = { - "id": call_id, - "type": "function_tool_call", - "name": "", - "arguments": "", - "status": "in_progress", - } - - # Accumulate name (may come incrementally) - if "name" in delta_item and delta_item["name"]: - tool_calls_accumulated[call_id]["name"] = delta_item["name"] - - # Accumulate arguments (may come incrementally as string) - if "arguments" in delta_item and delta_item["arguments"]: - tool_calls_accumulated[call_id]["arguments"] += delta_item["arguments"] - - # Update status if present - if "status" in delta_item: - tool_calls_accumulated[call_id]["status"] = delta_item["status"] - def _create_completion_request( self, messages: List[dict], @@ -662,18 +933,52 @@ def _create_completion_request( return self.client.chat.completions.create(**kwargs) def _bridge_async_generator(self, async_gen_func): - """Bridge an async generator to a sync generator.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + """Bridge an async generator to a sync generator using the pool's event loop.""" + pool = self._get_mcp_pool() + loop = pool._get_or_create_event_loop() + + # Create a queue for communication between async generator and sync iteration + queue = asyncio.Queue() + done_event = threading.Event() + exception_holder = [None] + + async def 
producer(): + try: + async for item in async_gen_func(): + await queue.put(item) + except Exception as e: + exception_holder[0] = e + finally: + await queue.put(None) # Sentinel to signal completion + done_event.set() + + # Start the producer in the background loop + future = asyncio.run_coroutine_threadsafe(producer(), loop) + try: - gen = async_gen_func() while True: + # Get from queue with timeout to allow checking for exceptions + get_future = asyncio.run_coroutine_threadsafe(queue.get(), loop) try: - yield loop.run_until_complete(gen.__anext__()) - except StopAsyncIteration: + item = get_future.result(timeout=30.0) + except Exception as e: + if exception_holder[0]: + raise exception_holder[0] + raise + + if item is None: # Sentinel break + yield item + + # Check if there was an exception in the producer + if exception_holder[0]: + raise exception_holder[0] finally: - loop.close() + # Ensure the future is done + try: + future.result(timeout=1.0) + except: + pass async def _stream_with_mcp_tools_json( self, @@ -685,7 +990,7 @@ async def _stream_with_mcp_tools_json( top_p: float, tool_to_server: dict = None, ): - """Async generator to handle MCP tool calls with streaming support for chat completions, yielding JSON chunks.""" + """Async generator to handle MCP tool calls with streaming support for chat completions.""" tool_calls_accumulated = {} streaming_response = "" @@ -705,7 +1010,6 @@ async def _stream_with_mcp_tools_json( if delta.content: streaming_response += delta.content - # Execute tool calls if any were accumulated if tool_calls_accumulated: tool_calls_list = self._convert_accumulated_tool_calls(tool_calls_accumulated) openai_messages.append( @@ -715,16 +1019,14 @@ async def _stream_with_mcp_tools_json( "tool_calls": tool_calls_list, } ) - await self._execute_tool_calls( + await self._execute_tool_calls_async( tool_calls_list, mcp_clients, openai_messages, tool_to_server ) - # Continue streaming with tool results (recursive call - don't finalize here) async for chunk_json in self._stream_with_mcp_tools_json( openai_messages, tools, mcp_clients, max_tokens, temperature, top_p, tool_to_server ): yield chunk_json - # Note: Finalization happens at the top level in openai_stream_transport async def _stream_responses_with_mcp_tools_json( self, @@ -733,8 +1035,7 @@ async def _stream_responses_with_mcp_tools_json( mcp_clients: dict, tool_to_server: dict = None, ): - """Async generator to handle MCP tool calls with streaming support for response API, yielding JSON chunks.""" - # Get input items + """Async generator to handle MCP tool calls with streaming support for response API.""" input_data = request_data.get("input", "") if isinstance(input_data, str): input_items = [ @@ -747,37 +1048,30 @@ async def _stream_responses_with_mcp_tools_json( else: input_items = input_data if isinstance(input_data, list) else [] - # Create request with tools (convert to response API format) response_args = {**request_data, "model": self.model} if tools: - # Convert tools from chat completion format to response API format response_api_tools = self._convert_tools_to_response_api_format(tools) response_args["tools"] = response_api_tools response_args["tool_choice"] = response_args.get("tool_choice", "auto") - # Stream the response and accumulate output stream = self.client.responses.create(**response_args) accumulated_output = [] - tool_calls_accumulated = {} # Track tool calls incrementally - original_to_filtered_index_map = {} # Map original output indices to filtered indices (for messages only) + 
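Stripped of the timeout and error plumbing, the queue-based bridge above follows this general shape (a sketch, assuming a background event loop is already running in another thread):

    import asyncio

    def bridge(async_gen_func, loop):
        queue = asyncio.Queue()

        async def producer():
            # Drain the async generator into the queue, then signal completion.
            async for item in async_gen_func():
                await queue.put(item)
            await queue.put(None)  # sentinel

        asyncio.run_coroutine_threadsafe(producer(), loop)
        while True:
            # Block the calling (sync) thread until the loop thread produces an item.
            item = asyncio.run_coroutine_threadsafe(queue.get(), loop).result()
            if item is None:
                break
            yield item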
tool_calls_accumulated = {} + original_to_filtered_index_map = {} for chunk in stream: self._set_usage(chunk) - # Handle different event types from response API streaming chunk_type = getattr(chunk, 'type', None) or chunk.__class__.__name__ - # Check if this event contains a non-message item that should be filtered should_yield = True item_to_check = None - # Check response.output_item.added events if ( chunk_type == 'response.output_item.added' or chunk_type == 'ResponseOutputItemAddedEvent' ) and hasattr(chunk, 'item'): item_to_check = chunk.item - # Build index mapping for messages as we see them if hasattr(chunk, 'output_index'): item_dict = ( item_to_check @@ -792,26 +1086,20 @@ async def _stream_responses_with_mcp_tools_json( ) item_type = item_dict.get("type") if item_type == "message": - # This is a message, map original index to filtered index original_index = chunk.output_index - # The filtered index is just the count of messages we've seen so far original_to_filtered_index_map[original_index] = len( original_to_filtered_index_map ) - # Check response.output_item.done events elif ( chunk_type == 'response.output_item.done' or chunk_type == 'ResponseOutputItemDoneEvent' ) and hasattr(chunk, 'item'): item_to_check = chunk.item - # Check events with output_index (like response.output_item.delta) elif hasattr(chunk, 'output_index'): original_index = chunk.output_index - # Only yield if this index maps to a message if original_index not in original_to_filtered_index_map: should_yield = False - # If we have an item to check, verify it's a message type if item_to_check: item_dict = ( item_to_check @@ -825,20 +1113,17 @@ async def _stream_responses_with_mcp_tools_json( ) ) item_type = item_dict.get("type") - # Only yield if it's a message, otherwise skip (but still process internally) if item_type != "message": should_yield = False - # For response.completed events, filter output to only include messages before yielding if ( chunk_type == 'response.completed' or chunk_type == 'ResponseCompletedEvent' ) and hasattr(chunk, 'response'): response = chunk.response if hasattr(response, 'output') and response.output: - # Filter output to only include message items and build index mapping filtered_output = [] filtered_index = 0 - original_to_filtered_index_map.clear() # Reset mapping for this response + original_to_filtered_index_map.clear() for original_index, item in enumerate(response.output): item_dict = ( @@ -853,12 +1138,10 @@ async def _stream_responses_with_mcp_tools_json( ) ) item_type = item_dict.get("type") - # Only include message items in the filtered output (as dicts for JSON serialization) if item_type == "message": filtered_output.append(item_dict) original_to_filtered_index_map[original_index] = filtered_index filtered_index += 1 - # Still accumulate tool calls for internal processing elif item_type in [ "function_tool_call", "function_call", @@ -878,7 +1161,6 @@ async def _stream_responses_with_mcp_tools_json( else: accumulated_output.append(item_dict) else: - # For other types, still accumulate but don't include in filtered output item_id = item_dict.get("id") if item_id: existing_ids = [ @@ -892,7 +1174,6 @@ async def _stream_responses_with_mcp_tools_json( else: accumulated_output.append(item_dict) - # Create a modified response with filtered output response_dict = ( response.model_dump() if hasattr(response, 'model_dump') @@ -902,7 +1183,6 @@ async def _stream_responses_with_mcp_tools_json( ) response_dict["output"] = filtered_output - # Create modified chunk with filtered 
response modified_chunk_dict = { "type": "response.completed", "sequence_number": getattr(chunk, 'sequence_number', None), @@ -910,14 +1190,11 @@ async def _stream_responses_with_mcp_tools_json( } yield json.dumps(modified_chunk_dict) else: - # No output to filter, yield as-is yield chunk.model_dump_json() elif should_yield: - # For events with output_index, remap to filtered index if it's a message index if hasattr(chunk, 'output_index'): original_index = chunk.output_index if original_index in original_to_filtered_index_map: - # Remap the output_index to the filtered index chunk_dict = ( chunk.model_dump() if hasattr(chunk, 'model_dump') @@ -931,12 +1208,9 @@ async def _stream_responses_with_mcp_tools_json( ) chunk_dict["output_index"] = original_to_filtered_index_map[original_index] yield json.dumps(chunk_dict) - # else: already filtered out by should_yield = False above else: - # For all other chunk types, yield as-is (if not filtered out) yield chunk.model_dump_json() - # Handle ResponseOutputItemAddedEvent - initial tool call item if ( chunk_type == 'response.output_item.added' or chunk_type == 'ResponseOutputItemAddedEvent' @@ -955,21 +1229,19 @@ async def _stream_responses_with_mcp_tools_json( ) item_type = item_dict.get("type") - # If it's a tool call, start accumulating it if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: item_id = item_dict.get("id") or item_dict.get("call_id") call_id = item_dict.get("call_id") if item_id: tool_calls_accumulated[item_id] = { "id": item_id, - "call_id": call_id, # Preserve call_id for function_call_output + "call_id": call_id, "type": item_type, "name": item_dict.get("name", ""), "arguments": item_dict.get("arguments", ""), "status": item_dict.get("status", "in_progress"), } - # Handle ResponseFunctionCallArgumentsDeltaEvent - incremental argument updates elif ( chunk_type == 'response.function_call_arguments.delta' or chunk_type == 'ResponseFunctionCallArgumentsDeltaEvent' @@ -978,10 +1250,8 @@ async def _stream_responses_with_mcp_tools_json( delta = getattr(chunk, 'delta', '') if item_id and item_id in tool_calls_accumulated: - # Accumulate the delta arguments tool_calls_accumulated[item_id]["arguments"] += delta - # Handle ResponseFunctionCallArgumentsDoneEvent - arguments complete elif ( chunk_type == 'response.function_call_arguments.done' or chunk_type == 'ResponseFunctionCallArgumentsDoneEvent' @@ -990,10 +1260,8 @@ async def _stream_responses_with_mcp_tools_json( arguments = getattr(chunk, 'arguments', '') if item_id and item_id in tool_calls_accumulated: - # Set final arguments tool_calls_accumulated[item_id]["arguments"] = arguments - # Handle ResponseOutputItemDoneEvent - tool call item completed elif ( chunk_type == 'response.output_item.done' or chunk_type == 'ResponseOutputItemDoneEvent' @@ -1012,26 +1280,20 @@ async def _stream_responses_with_mcp_tools_json( ) item_type = item_dict.get("type") - # If it's a completed tool call, add to accumulated output if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: item_id = item_dict.get("id") if item_id and item_id in tool_calls_accumulated: - # Update with final status and preserve call_id if present tool_calls_accumulated[item_id]["status"] = item_dict.get( "status", "completed" ) if "call_id" in item_dict: tool_calls_accumulated[item_id]["call_id"] = item_dict.get("call_id") - # Add to accumulated output accumulated_output.append(tool_calls_accumulated[item_id]) else: - # Not in accumulated, add directly 
accumulated_output.append(item_dict) else: - # Non-tool-call item accumulated_output.append(item_dict) - # Handle standard response objects with output (fallback) elif hasattr(chunk, 'output') and chunk.output: for item in chunk.output: item_dict = ( @@ -1058,9 +1320,7 @@ async def _stream_responses_with_mcp_tools_json( else: accumulated_output.append(item_dict) - # After streaming completes, add any remaining accumulated tool calls for call_id, call_data in tool_calls_accumulated.items(): - # Only add if it has a name and is not already in accumulated_output if call_data.get("name"): existing_ids = [ i.get("id") @@ -1071,39 +1331,25 @@ async def _stream_responses_with_mcp_tools_json( if call_id not in existing_ids: accumulated_output.append(call_data) - # Check for tool calls in accumulated output tool_calls = self._extract_tool_calls_from_response_output(accumulated_output) - # Execute tool calls if any if tool_calls: - # Convert model's output (messages, reasoning, completed tool calls) to input items model_output_items = self._convert_output_items_to_input_items(accumulated_output) input_items.extend(model_output_items) - # Execute tool calls and add results to input - await self._execute_response_api_tool_calls( + await self._execute_response_api_tool_calls_async( tool_calls, mcp_clients, input_items, tool_to_server ) - # Update request with new input including model output and tool results request_data["input"] = input_items - # Continue streaming with tool results (recursive call - don't finalize here) async for chunk_json in self._stream_responses_with_mcp_tools_json( request_data, tools, mcp_clients, tool_to_server ): yield chunk_json - # Note: Finalization happens at the top level in openai_stream_transport @ModelClass.method def openai_transport(self, msg: str) -> str: - """Process an OpenAI-compatible request and send it to the appropriate OpenAI endpoint. 
- - Args: - msg: JSON string containing the request parameters including 'openai_endpoint' - - Returns: - JSON string containing the response or error - """ + """Process an OpenAI-compatible request and send it to the appropriate OpenAI endpoint.""" try: request_data = from_json(msg) request_data = self._update_old_fields(request_data) @@ -1112,112 +1358,86 @@ def openai_transport(self, msg: str) -> str: tools = request_data.get("tools") if mcp_servers and len(mcp_servers) > 0 and tools is None: + logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") + tools_local, mcp_clients_local, tool_to_server_local = ( + self._get_mcp_tools_and_clients(mcp_servers) + ) - async def run_with_mcp(): - logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") - ( - tools_local, - mcp_clients_local, - tool_to_server_local, - ) = await self._get_mcp_tools_and_clients(mcp_servers) - try: - if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: - response = self._route_request( - endpoint, request_data, mcp_servers, mcp_clients_local, tools_local - ) + # Note: No cleanup needed - connections are maintained in the pool - # Handle tool calls iteratively for chat completions - while response.choices and response.choices[0].message.tool_calls: - messages = request_data.get("messages", []) - messages.append(response.choices[0].message) - await self._execute_tool_calls( - response.choices[0].message.tool_calls, - mcp_clients_local, - messages, - tool_to_server_local, - ) - request_data["messages"] = messages - response = self._route_request( - endpoint, - request_data, - mcp_servers, - mcp_clients_local, - tools_local, - ) - - return response - elif endpoint == self.ENDPOINT_RESPONSES: - response = self._route_request( - endpoint, request_data, mcp_servers, mcp_clients_local, tools_local - ) + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + response = self._route_request( + endpoint, request_data, mcp_servers, mcp_clients_local, tools_local + ) - # Handle tool calls iteratively for response API - # Get input items (can be string or list) - input_data = request_data.get("input", "") - if isinstance(input_data, str): - input_items = [ - { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": input_data}], - } - ] - else: - input_items = input_data if isinstance(input_data, list) else [] + while response.choices and response.choices[0].message.tool_calls: + messages = request_data.get("messages", []) + messages.append(response.choices[0].message) + self._execute_tool_calls( + response.choices[0].message.tool_calls, + mcp_clients_local, + messages, + tool_to_server_local, + ) + request_data["messages"] = messages + response = self._route_request( + endpoint, + request_data, + mcp_servers, + mcp_clients_local, + tools_local, + ) - # Extract tool calls from response output - response_output = ( - response.output if hasattr(response, 'output') else [] - ) - tool_calls = self._extract_tool_calls_from_response_output( - response_output - ) + elif endpoint == self.ENDPOINT_RESPONSES: + response = self._route_request( + endpoint, request_data, mcp_servers, mcp_clients_local, tools_local + ) - while tool_calls: - # Convert model's output (messages, reasoning, completed tool calls) to input items - model_output_items = self._convert_output_items_to_input_items( - response_output - ) - input_items.extend(model_output_items) - - # Execute tool calls and add results to input - await self._execute_response_api_tool_calls( - tool_calls, - mcp_clients_local, - input_items, - 
tool_to_server_local, - ) - # Update request with new input including model output and tool results - request_data["input"] = input_items - - # Make new request with tool results - response = self._route_request( - endpoint, - request_data, - mcp_servers, - mcp_clients_local, - tools_local, - ) - - # Check for more tool calls - response_output = ( - response.output if hasattr(response, 'output') else [] - ) - tool_calls = self._extract_tool_calls_from_response_output( - response_output - ) - - return response - else: - return self._route_request(endpoint, request_data) - finally: - await self._cleanup(mcp_clients_local) + input_data = request_data.get("input", "") + if isinstance(input_data, str): + input_items = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input_data}], + } + ] + else: + input_items = input_data if isinstance(input_data, list) else [] + + response_output = response.output if hasattr(response, 'output') else [] + tool_calls = self._extract_tool_calls_from_response_output(response_output) + + while tool_calls: + model_output_items = self._convert_output_items_to_input_items( + response_output + ) + input_items.extend(model_output_items) - response = asyncio.run(run_with_mcp()) + self._execute_response_api_tool_calls( + tool_calls, + mcp_clients_local, + input_items, + tool_to_server_local, + ) + request_data["input"] = input_items + + response = self._route_request( + endpoint, + request_data, + mcp_servers, + mcp_clients_local, + tools_local, + ) + + response_output = response.output if hasattr(response, 'output') else [] + tool_calls = self._extract_tool_calls_from_response_output(response_output) + + else: + response = self._route_request(endpoint, request_data) else: response = self._route_request(endpoint, request_data) - # Finalize token usage accumulation (sum of all API calls) self._finalize_token_usage() return response.model_dump_json() except Exception as e: @@ -1232,17 +1452,7 @@ async def run_with_mcp(): @ModelClass.method def openai_stream_transport(self, msg: str) -> Iterator[str]: - """Process an OpenAI-compatible request and return a streaming response iterator. - - This method is used when stream=True and returns an iterator of strings directly, - without converting to a list or JSON serializing. Supports chat completions and responses endpoints. - - Args: - msg: The request as a JSON string. - - Returns: - Iterator[str]: An iterator yielding text chunks from the streaming response. 
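Conceptually, the non-streaming agentic loop above reduces to this skeleton (a sketch, not the exact implementation; execute_tool is a hypothetical helper that dispatches to the MCP pool and returns the result text):

    response = client.chat.completions.create(model=model, messages=messages, tools=tools)
    while response.choices and response.choices[0].message.tool_calls:
        messages.append(response.choices[0].message)
        for call in response.choices[0].message.tool_calls:
            result = execute_tool(call)  # hypothetical MCP dispatch
            messages.append({"role": "tool", "tool_call_id": call.id, "content": result})
        response = client.chat.completions.create(model=model, messages=messages, tools=tools)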
- """ + """Process an OpenAI-compatible request and return a streaming response iterator.""" try: request_data = from_json(msg) request_data = self._update_old_fields(request_data) @@ -1253,64 +1463,54 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]: raise ValueError("Streaming is only supported for chat completions and responses.") if mcp_servers and len(mcp_servers) > 0 and request_data.get("tools") is None: + logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") + tools_local, mcp_clients_local, tool_to_server_local = ( + self._get_mcp_tools_and_clients(mcp_servers) + ) - async def run_with_mcp(): - logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") - ( - tools_local, - mcp_clients_local, - tool_to_server_local, - ) = await self._get_mcp_tools_and_clients(mcp_servers) - try: - if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: - messages = request_data.get("messages", []) - async for chunk_json in self._stream_with_mcp_tools_json( - messages, - tools_local, - mcp_clients_local, - request_data.get("max_completion_tokens", 4096), - request_data.get("temperature", 1.0), - request_data.get("top_p", 1.0), - tool_to_server_local, - ): - yield chunk_json - # Finalize token usage accumulation after streaming completes - self._finalize_token_usage() - elif endpoint == self.ENDPOINT_RESPONSES: - async for chunk_json in self._stream_responses_with_mcp_tools_json( - request_data, tools_local, mcp_clients_local, tool_to_server_local - ): - yield chunk_json - # Finalize token usage accumulation after streaming completes - self._finalize_token_usage() - else: - # Fallback for other endpoints - response_args = {**request_data, "model": self.model} - for chunk in self.client.responses.create(**response_args): - self._set_usage(chunk) - yield chunk.model_dump_json() - # Finalize token usage accumulation after streaming completes - self._finalize_token_usage() - finally: - await self._cleanup(mcp_clients_local) - - yield from self._bridge_async_generator(run_with_mcp) + # Note: No cleanup needed - connections are maintained in the pool + + async def stream_generator(): + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + messages = request_data.get("messages", []) + async for chunk_json in self._stream_with_mcp_tools_json( + messages, + tools_local, + mcp_clients_local, + request_data.get("max_completion_tokens", 4096), + request_data.get("temperature", 1.0), + request_data.get("top_p", 1.0), + tool_to_server_local, + ): + yield chunk_json + self._finalize_token_usage() + elif endpoint == self.ENDPOINT_RESPONSES: + async for chunk_json in self._stream_responses_with_mcp_tools_json( + request_data, tools_local, mcp_clients_local, tool_to_server_local + ): + yield chunk_json + self._finalize_token_usage() + else: + response_args = {**request_data, "model": self.model} + for chunk in self.client.responses.create(**response_args): + self._set_usage(chunk) + yield chunk.model_dump_json() + self._finalize_token_usage() + + yield from self._bridge_async_generator(stream_generator) return - # Non-MCP path or responses endpoint if endpoint == self.ENDPOINT_RESPONSES: response_args = {**request_data, "model": self.model} for chunk in self.client.responses.create(**response_args): self._set_usage(chunk) yield chunk.model_dump_json() - # Finalize token usage accumulation after streaming completes self._finalize_token_usage() else: completion_args = self._create_completion_args(request_data) for chunk in self.client.chat.completions.create(**completion_args): 
self._set_usage(chunk) yield chunk.model_dump_json() - # Finalize token usage accumulation after streaming completes self._finalize_token_usage() except Exception as e: From c0cf5bf8f4792f7ce854a0b4f13e2c1c38278d25 Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Tue, 9 Dec 2025 17:35:03 +0530 Subject: [PATCH 03/13] Reuse MCP connections and optimize the pool --- clarifai/runners/models/agentic_class.py | 1355 +++++++++++----------- 1 file changed, 667 insertions(+), 688 deletions(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index d479cbc2..7bfa67d5 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -5,7 +5,9 @@ import os import threading import time -from typing import Any, Dict, Iterator, List, Optional +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple from clarifai_grpc.grpc.api.status import status_code_pb2 from pydantic_core import from_json, to_json @@ -15,16 +17,39 @@ from clarifai.utils.logging import logger -class MCPConnectionPool: - """Thread-safe connection pool for MCP servers with persistent connections. +@dataclass +class MCPConnection: + """Represents a single MCP server connection with metadata.""" - This class manages MCP client connections across multiple requests, - maintaining persistent connections and handling reconnection when needed. - """ + client: Any + tools: List[Any] + tool_names: Set[str] # For O(1) tool lookup + last_used: float + connected_at: float + url: str + lock: threading.RLock = field(default_factory=threading.RLock) + use_count: int = 0 + + def mark_used(self): + """Mark connection as recently used.""" + self.last_used = time.time() + self.use_count += 1 + + +class MCPConnectionPool: + """Thread-safe connection pool for MCP servers with persistent connections.""" _instance: Optional['MCPConnectionPool'] = None _lock = threading.Lock() + # Pool configuration + DEFAULT_MAX_IDLE_TIME = 600 # 10 minutes; servers idle longer than this are disconnected + DEFAULT_CLEANUP_INTERVAL = 120 # 2 minutes between stale-connection sweeps + DEFAULT_VERIFY_THRESHOLD = 60 # Only verify if idle > 60 seconds + DEFAULT_CONNECT_TIMEOUT = 30 # Connection timeout + DEFAULT_TOOL_CALL_TIMEOUT = 60 # Tool call timeout + MAX_PARALLEL_CONNECTIONS = 10 # Max concurrent connection attempts + def __new__(cls): """Singleton pattern to ensure one connection pool per process.""" if cls._instance is None: @@ -34,85 +59,87 @@ def __new__(cls): cls._instance._initialized = False return cls._instance - def __init__(self): + def __init__( + self, + max_idle_time: float = DEFAULT_MAX_IDLE_TIME, + cleanup_interval: float = DEFAULT_CLEANUP_INTERVAL, + verify_threshold: float = DEFAULT_VERIFY_THRESHOLD, + ): if self._initialized: return - self._connections: Dict[ - str, Dict[str, Any] - ] = {} # url -> {client, tools, loop, last_used, lock} - self._connection_locks: Dict[str, threading.Lock] = {} # url -> lock for that connection - self._global_lock = threading.Lock() + self._connections: Dict[str, MCPConnection] = {} + self._global_lock = threading.RLock() self._loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None self._loop_lock = threading.Lock() - self._max_idle_time = 300 # 5 minutes idle timeout - self._cleanup_interval = 60 # Check for stale connections every minute + self._loop_ready = threading.Event() + + # Configuration +
self._max_idle_time = max_idle_time + self._cleanup_interval = cleanup_interval + self._verify_threshold = verify_threshold self._last_cleanup = time.time() - self._initialized = True - def _get_or_create_event_loop(self) -> asyncio.AbstractEventLoop: - """Get or create a persistent event loop running in a background thread. + # Tool name to URL cache for O(1) lookup + self._tool_to_url_cache: Dict[str, str] = {} + self._cache_lock = threading.RLock() - This ensures MCP connections persist across request boundaries even when - the request's event loop is closed. - """ + # Thread pool for parallel operations + self._executor = ThreadPoolExecutor( + max_workers=self.MAX_PARALLEL_CONNECTIONS, thread_name_prefix="mcp_pool_" + ) + + # Start the background event loop immediately + self._start_event_loop() + + self._initialized = True + + def _start_event_loop(self): + """Start the persistent event loop in a background thread.""" with self._loop_lock: - # Check if we have a running loop if self._loop is not None and self._loop_thread is not None: if self._loop_thread.is_alive() and not self._loop.is_closed(): - return self._loop + return - # Create a new event loop in a background thread - loop_ready = threading.Event() + self._loop_ready.clear() def run_loop(): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - loop_ready.set() + self._loop_ready.set() self._loop.run_forever() - self._loop_thread = threading.Thread(target=run_loop, daemon=True) + self._loop_thread = threading.Thread( + target=run_loop, daemon=True, name="mcp_event_loop" + ) self._loop_thread.start() - loop_ready.wait(timeout=5.0) # Wait for loop to be ready - - if self._loop is None: - raise RuntimeError("Failed to create event loop for MCP connections") - return self._loop + if not self._loop_ready.wait(timeout=10.0): + raise RuntimeError("Failed to start MCP event loop") - def _run_coroutine(self, coro) -> Any: - """Run a coroutine in the persistent event loop. + def _get_event_loop(self) -> asyncio.AbstractEventLoop: + """Get the persistent event loop, starting it if necessary.""" + if self._loop is None or self._loop.is_closed(): + self._start_event_loop() + return self._loop - Args: - coro: Coroutine to run - - Returns: - Result of the coroutine - """ - loop = self._get_or_create_event_loop() + def _run_coroutine(self, coro, timeout: float = DEFAULT_CONNECT_TIMEOUT) -> Any: + """Run a coroutine in the persistent event loop with timeout.""" + loop = self._get_event_loop() future = asyncio.run_coroutine_threadsafe(coro, loop) - return future.result(timeout=30.0) # 30 second timeout for operations - - def _get_connection_lock(self, url: str) -> threading.Lock: - """Get or create a lock for a specific URL.""" - with self._global_lock: - if url not in self._connection_locks: - self._connection_locks[url] = threading.Lock() - return self._connection_locks[url] + return future.result(timeout=timeout) - async def _connect_single_server( - self, url: str, max_retries: int = 2, retry_delay: float = 1.0 - ) -> Optional[Dict[str, Any]]: - """Connect to a single MCP server with retries. + def _run_coroutine_nowait(self, coro) -> asyncio.Future: + """Schedule a coroutine without waiting for result.""" + loop = self._get_event_loop() + return asyncio.run_coroutine_threadsafe(coro, loop) - Args: - url: MCP server URL - max_retries: Maximum retry attempts - retry_delay: Delay between retries in seconds + async def _create_client(self, url: str) -> Tuple[Any, List[Any]]: + """Create and connect a single MCP client. 
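The dedicated loop-thread pattern behind _start_event_loop and _run_coroutine can be exercised in isolation; a minimal sketch:

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def ping():
        await asyncio.sleep(0.1)
        return "pong"

    # Schedule work onto the background loop from synchronous code and block for the result.
    future = asyncio.run_coroutine_threadsafe(ping(), loop)
    print(future.result(timeout=5.0))  # -> pong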
Returns: - Dictionary with client and tools, or None if connection failed + Tuple of (client, tools) """ try: from fastmcp import Client @@ -123,154 +150,204 @@ async def _connect_single_server( "Install it with: pip install fastmcp" ) + transport = StreamableHttpTransport( + url=url, + headers={"Authorization": "Bearer " + os.environ.get("CLARIFAI_PAT", "")}, + ) + + client = Client(transport) + await client.__aenter__() + tools = await client.list_tools() + + return client, tools + + async def _connect_single_server( + self, url: str, max_retries: int = 2, retry_delay: float = 1.0 + ) -> Optional[MCPConnection]: + """Connect to a single MCP server with retries.""" last_error = None for attempt in range(max_retries): try: - transport = StreamableHttpTransport( - url=url, - headers={"Authorization": "Bearer " + os.environ.get("CLARIFAI_PAT", "")}, + client, tools = await asyncio.wait_for( + self._create_client(url), timeout=self.DEFAULT_CONNECT_TIMEOUT ) - client = Client(transport) - await client.__aenter__() + # Build tool name set for O(1) lookup + tool_names = {tool.name for tool in tools} - # List available tools - tools_result = await client.list_tools() + connection = MCPConnection( + client=client, + tools=tools, + tool_names=tool_names, + last_used=time.time(), + connected_at=time.time(), + url=url, + ) - logger.info(f"✓ Connected to {url} with {len(tools_result)} tools") - - return { - "client": client, - "tools": tools_result, - "last_used": time.time(), - "connected_at": time.time(), - } + logger.info(f"✓ Connected to {url} with {len(tools)} tools") + return connection + except asyncio.TimeoutError: + last_error = TimeoutError( + f"Connection timeout after {self.DEFAULT_CONNECT_TIMEOUT}s" + ) + logger.warning( + f"⚠ Timeout connecting to {url} (attempt {attempt + 1}/{max_retries})" + ) except Exception as e: last_error = e if attempt < max_retries - 1: logger.warning( - f"⚠ Failed to connect to {url} (attempt {attempt + 1}/{max_retries}): {e}. " - f"Retrying in {retry_delay}s..." + f"⚠ Failed to connect to {url} (attempt {attempt + 1}/{max_retries}): {e}" ) await asyncio.sleep(retry_delay) - else: - logger.error( - f"❌ Failed to connect to {url} after {max_retries} attempts: {e}" - ) + logger.error(f"❌ Failed to connect to {url} after {max_retries} attempts: {last_error}") return None - async def _verify_connection(self, url: str, connection_info: Dict[str, Any]) -> bool: - """Verify that a connection is still valid. 
+ async def _connect_servers_parallel( + self, urls: List[str], max_retries: int = 2, retry_delay: float = 1.0 + ) -> Dict[str, MCPConnection]: + """Connect to multiple servers in parallel.""" + if not urls: + return {} - Args: - url: MCP server URL - connection_info: Connection info dictionary + tasks = [self._connect_single_server(url, max_retries, retry_delay) for url in urls] - Returns: - True if connection is valid, False otherwise - """ + results = await asyncio.gather(*tasks, return_exceptions=True) + + connections = {} + for url, result in zip(urls, results): + if isinstance(result, Exception): + logger.error(f"❌ Error connecting to {url}: {result}") + elif result is not None: + connections[url] = result + + return connections + + async def _verify_connection_async(self, connection: MCPConnection) -> bool: + """Verify that a connection is still valid.""" try: - client = connection_info["client"] - # Try to list tools as a health check - await asyncio.wait_for(client.list_tools(), timeout=5.0) + await asyncio.wait_for(connection.client.list_tools(), timeout=5.0) return True except Exception as e: - logger.warning(f"⚠ Connection to {url} is no longer valid: {e}") + logger.warning(f"⚠ Connection to {connection.url} is no longer valid: {e}") return False - async def _disconnect_single_server(self, url: str, connection_info: Dict[str, Any]): - """Disconnect from a single MCP server. - - Args: - url: MCP server URL - connection_info: Connection info dictionary - """ + async def _disconnect_async(self, connection: MCPConnection): + """Disconnect from a single MCP server.""" try: - client = connection_info["client"] + client = connection.client if hasattr(client, 'close') and callable(getattr(client, 'close', None)): if asyncio.iscoroutinefunction(client.close): - await client.close() + await asyncio.wait_for(client.close(), timeout=5.0) else: client.close() else: - await client.__aexit__(None, None, None) - logger.info(f"✓ Disconnected from {url}") + await asyncio.wait_for(client.__aexit__(None, None, None), timeout=5.0) + logger.info(f"✓ Disconnected from {connection.url}") except Exception as e: - logger.warning(f"⚠ Error disconnecting from {url}: {e}") + logger.warning(f"⚠ Error disconnecting from {connection.url}: {e}") + + def _should_verify_connection(self, connection: MCPConnection) -> bool: + """Check if connection needs verification based on idle time.""" + idle_time = time.time() - connection.last_used + return idle_time > self._verify_threshold + + def _update_tool_cache(self, connections: Dict[str, MCPConnection]): + """Update the tool-to-URL cache from connections.""" + with self._cache_lock: + for url, conn in connections.items(): + for tool_name in conn.tool_names: + self._tool_to_url_cache[tool_name] = url + + def _invalidate_tool_cache(self, url: str): + """Remove tools from cache when a connection is removed.""" + with self._cache_lock: + self._tool_to_url_cache = { + name: cached_url + for name, cached_url in self._tool_to_url_cache.items() + if cached_url != url + } def get_connections( self, mcp_servers: List[str], max_retries: int = 2, retry_delay: float = 1.0 - ) -> Dict[str, Any]: + ) -> Dict[str, MCPConnection]: """Get connections for the specified MCP servers. - This method reuses existing connections when possible and creates - new ones as needed. Thread-safe. 
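The fan-out in _connect_servers_parallel is the standard gather-with-exceptions idiom; a condensed sketch where connect stands in for _connect_single_server:

    import asyncio

    async def connect_all(urls, connect):
        # One task per server; return_exceptions=True keeps a single failure
        # from cancelling the rest and preserves per-URL result ordering.
        results = await asyncio.gather(*(connect(url) for url in urls), return_exceptions=True)
        return {url: res for url, res in zip(urls, results) if not isinstance(res, Exception)}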
- - Args: - mcp_servers: List of MCP server URLs - max_retries: Maximum retry attempts for new connections - retry_delay: Delay between retries - - Returns: - Dictionary mapping server URLs to client info and tools + Uses lazy verification: only connections that have been + idle longer than the verification threshold are re-verified. """ - # Periodic cleanup of stale connections self._maybe_cleanup_stale_connections() result = {} urls_to_connect = [] + urls_to_verify = [] - # First pass: get existing valid connections - for url in mcp_servers: - lock = self._get_connection_lock(url) - with lock: + # First pass: categorize URLs + with self._global_lock: + for url in mcp_servers: if url in self._connections: - connection_info = self._connections[url] - # Check if connection is still valid - try: - is_valid = self._run_coroutine( - self._verify_connection(url, connection_info) - ) - if is_valid: - connection_info["last_used"] = time.time() - result[url] = connection_info - logger.debug(f"Reusing existing connection to {url}") - continue - else: - # Connection is stale, remove it - del self._connections[url] - except Exception as e: - logger.warning(f"⚠ Error verifying connection to {url}: {e}") - # Remove potentially stale connection - if url in self._connections: - del self._connections[url] - - urls_to_connect.append(url) + connection = self._connections[url] + if self._should_verify_connection(connection): + urls_to_verify.append(url) + else: + # Recently used, assume still valid + connection.mark_used() + result[url] = connection + else: + urls_to_connect.append(url) - # Second pass: connect to servers that need new connections - if urls_to_connect: + # Verify stale connections in parallel + if urls_to_verify: - async def connect_servers(): + async def verify_all(): tasks = [] - for url in urls_to_connect: - tasks.append(self._connect_single_server(url, max_retries, retry_delay)) + for url in urls_to_verify: + conn = self._connections.get(url) + if conn: + tasks.append(self._verify_connection_async(conn)) + else: + # asyncio.coroutine was removed in Python 3.11; an awaitable + # that resolves to False marks the missing connection as failed. + tasks.append(asyncio.sleep(0, result=False)) return await asyncio.gather(*tasks, return_exceptions=True) try: - results = self._run_coroutine(connect_servers()) + verify_results = self._run_coroutine(verify_all(), timeout=15.0) + + with self._global_lock: + for url, is_valid in zip(urls_to_verify, verify_results): + if isinstance(is_valid, Exception) or not is_valid: + # Connection is stale, need to reconnect + if url in self._connections: + self._invalidate_tool_cache(url) + del self._connections[url] + urls_to_connect.append(url) + else: + # Connection is valid + connection = self._connections[url] + connection.mark_used() + result[url] = connection except Exception as e: + logger.error(f"❌ Error verifying connections: {e}") + # On verification failure, try to reconnect all + urls_to_connect.extend(urls_to_verify) + + # Connect to new servers in parallel + if urls_to_connect: + try: + new_connections = self._run_coroutine( + self._connect_servers_parallel(urls_to_connect, max_retries, retry_delay), + timeout=self.DEFAULT_CONNECT_TIMEOUT * max_retries + 10, + ) + + with self._global_lock: + self._connections.update(new_connections) + result.update(new_connections) + + # Update tool cache + self._update_tool_cache(new_connections) - for url, connection_result in zip(urls_to_connect, results): - if isinstance(connection_result, Exception): - logger.error(f"❌ Failed to connect to {url}: {connection_result}") - continue - if connection_result is not None: - lock =
self._get_connection_lock(url) - with lock: - self._connections[url] = connection_result - result[url] = connection_result except Exception as e: logger.error(f"❌ Error connecting to MCP servers: {e}") @@ -278,81 +355,194 @@ async def connect_servers(): def get_tools_and_mapping( self, mcp_servers: List[str] - ) -> tuple[List[dict], Dict[str, Any], Dict[str, str]]: + ) -> Tuple[List[dict], Dict[str, MCPConnection], Dict[str, str]]: """Get tools and server mapping for the specified MCP servers. - Args: - mcp_servers: List of MCP server URLs - Returns: - Tuple of (tools in OpenAI format, mcp_clients dictionary, tool_to_server mapping) + Tuple of (tools in OpenAI format, connections dict, tool_to_server mapping) """ - mcp_clients = self.get_connections(mcp_servers) + connections = self.get_connections(mcp_servers) all_tools = [] tool_to_server = {} + seen_tools = set() # Avoid duplicate tools - for mcp_url, server_info in mcp_clients.items(): - tools = server_info.get("tools", []) - for tool in tools: - tool_name = tool.name - all_tools.append( - { - "type": "function", - "function": { - "name": tool_name, - "description": f"{tool.description}", - "parameters": tool.inputSchema, - }, - } - ) - tool_to_server[tool_name] = mcp_url + for url, conn in connections.items(): + for tool in conn.tools: + if tool.name not in seen_tools: + seen_tools.add(tool.name) + all_tools.append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.inputSchema, + }, + } + ) + tool_to_server[tool.name] = url + + logger.info(f"Access to {len(all_tools)} tools from {len(connections)} servers") + return all_tools, connections, tool_to_server + + def get_cached_tool_url(self, tool_name: str) -> Optional[str]: + """Get URL for a tool from cache. 
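
# For reference, each entry that get_tools_and_mapping above emits follows the
# OpenAI function-tool schema. The name, description, and parameters below are
# invented placeholders; the real values come from each MCP tool, and the MCP
# inputSchema is already JSON Schema, so it is passed through unchanged.
example_tool_spec = {
    "type": "function",
    "function": {
        "name": "search_docs",  # hypothetical tool name
        "description": "Search indexed documents",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}
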
O(1) lookup.""" + with self._cache_lock: + return self._tool_to_url_cache.get(tool_name) + + async def call_tool_async( + self, + tool_name: str, + arguments: Dict[str, Any], + connections: Dict[str, MCPConnection], + tool_to_server: Dict[str, str], + ) -> Any: + """Call a tool on the appropriate MCP server.""" + # Try cached/mapped server first + server_url = tool_to_server.get(tool_name) or self.get_cached_tool_url(tool_name) + + if server_url and server_url in connections: + conn = connections[server_url] + with conn.lock: + try: + logger.info(f"Calling tool {tool_name} on {server_url}") + result = await asyncio.wait_for( + conn.client.call_tool(tool_name, arguments=arguments), + timeout=self.DEFAULT_TOOL_CALL_TIMEOUT, + ) + conn.mark_used() + return result + except asyncio.TimeoutError: + logger.error(f"❌ Timeout calling tool {tool_name}") + raise + except Exception as e: + logger.error(f"❌ Error calling tool {tool_name}: {e}") + # Fall through to try other servers + + # Fallback: find server with this tool + for url, conn in connections.items(): + if url == server_url: + continue # Already tried + if tool_name in conn.tool_names: + with conn.lock: + try: + logger.info(f"Calling tool {tool_name} on {url} (fallback)") + result = await asyncio.wait_for( + conn.client.call_tool(tool_name, arguments=arguments), + timeout=self.DEFAULT_TOOL_CALL_TIMEOUT, + ) + conn.mark_used() + # Update cache for future lookups + with self._cache_lock: + self._tool_to_url_cache[tool_name] = url + return result + except Exception as e: + logger.error(f"❌ Error calling tool {tool_name} on {url}: {e}") + continue + + raise Exception(f"Tool {tool_name} not found on any connected server") - logger.info(f"Access to {len(all_tools)} tools from {len(mcp_clients)} servers") - return all_tools, mcp_clients, tool_to_server + def call_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + connections: Dict[str, MCPConnection], + tool_to_server: Dict[str, str], + ) -> Any: + """Synchronous wrapper for call_tool_async.""" + return self._run_coroutine( + self.call_tool_async(tool_name, arguments, connections, tool_to_server), + timeout=self.DEFAULT_TOOL_CALL_TIMEOUT + 5, + ) + + async def call_tools_parallel( + self, + tool_calls: List[Tuple[str, Dict[str, Any]]], + connections: Dict[str, MCPConnection], + tool_to_server: Dict[str, str], + ) -> List[Tuple[str, Any, Optional[Exception]]]: + """Execute multiple tool calls in parallel. 
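
# The error-isolation idea behind call_tools_parallel, reduced to a standalone
# sketch: each call is wrapped so a failure becomes part of the result tuple
# instead of propagating out of asyncio.gather and cancelling the batch.
# _demo_call is an invented stand-in for a real MCP tool call.
import asyncio

async def _demo_call(name: str, args: dict):
    try:
        if name == "boom":  # stands in for a tool that raises on the server
            raise RuntimeError("tool failed")
        return (name, {"echo": args}, None)
    except Exception as e:
        return (name, None, e)

async def demo_batch():
    # Both results come back; the failure is carried in the third slot.
    return await asyncio.gather(_demo_call("ok", {}), _demo_call("boom", {}))
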
+ + Args: + tool_calls: List of (tool_name, arguments) tuples + connections: Connection dictionary + tool_to_server: Tool to server mapping + + Returns: + List of (tool_name, result, exception) tuples + """ + + async def call_single(tool_name: str, args: Dict[str, Any]): + try: + result = await self.call_tool_async(tool_name, args, connections, tool_to_server) + return (tool_name, result, None) + except Exception as e: + return (tool_name, None, e) + + tasks = [call_single(name, args) for name, args in tool_calls] + return await asyncio.gather(*tasks) def _maybe_cleanup_stale_connections(self): """Clean up connections that have been idle for too long.""" current_time = time.time() - # Only run cleanup periodically if current_time - self._last_cleanup < self._cleanup_interval: return self._last_cleanup = current_time - urls_to_remove = [] with self._global_lock: - for url, connection_info in self._connections.items(): - last_used = connection_info.get("last_used", 0) - if current_time - last_used > self._max_idle_time: - urls_to_remove.append(url) + urls_to_remove = [ + url + for url, conn in self._connections.items() + if current_time - conn.last_used > self._max_idle_time + ] for url in urls_to_remove: - self.disconnect(url) + self._disconnect_url(url) - def disconnect(self, url: str): - """Disconnect from a specific MCP server. + def _disconnect_url(self, url: str): + """Disconnect from a specific URL.""" + with self._global_lock: + connection = self._connections.pop(url, None) - Args: - url: MCP server URL to disconnect from - """ - lock = self._get_connection_lock(url) - with lock: - if url in self._connections: - connection_info = self._connections.pop(url) - try: - self._run_coroutine(self._disconnect_single_server(url, connection_info)) - except Exception as e: - logger.warning(f"⚠ Error during disconnect from {url}: {e}") + if connection: + self._invalidate_tool_cache(url) + try: + self._run_coroutine(self._disconnect_async(connection), timeout=10.0) + except Exception as e: + logger.warning(f"⚠ Error during disconnect from {url}: {e}") + + def disconnect(self, url: str): + """Public method to disconnect from a specific MCP server.""" + self._disconnect_url(url) def disconnect_all(self): - """Disconnect from all MCP servers.""" + """Disconnect from all MCP servers and cleanup.""" with self._global_lock: urls = list(self._connections.keys()) - for url in urls: - self.disconnect(url) + # Disconnect all in parallel + async def disconnect_all_async(): + tasks = [] + for url in urls: + conn = self._connections.get(url) + if conn: + tasks.append(self._disconnect_async(conn)) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + try: + self._run_coroutine(disconnect_all_async(), timeout=30.0) + except Exception as e: + logger.warning(f"⚠ Error during bulk disconnect: {e}") + + with self._global_lock: + self._connections.clear() + + with self._cache_lock: + self._tool_to_url_cache.clear() # Stop the event loop with self._loop_lock: @@ -363,140 +553,50 @@ def disconnect_all(self): self._loop = None self._loop_thread = None - def call_tool( - self, - tool_name: str, - arguments: Dict[str, Any], - mcp_clients: Dict[str, Any], - tool_to_server: Dict[str, str], - ) -> Any: - """Call a tool on the appropriate MCP server. 
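
# The idle sweep in _maybe_cleanup_stale_connections, reduced to a standalone
# sketch over a plain {url: last_used_timestamp} dict. The real pool stores
# MCPConnection objects and also disconnects each stale entry; sweep_idle is
# an invented helper for illustration only.
import time

def sweep_idle(last_used_by_url: dict, max_idle: float = 600.0) -> list:
    now = time.time()
    stale = [url for url, ts in last_used_by_url.items() if now - ts > max_idle]
    for url in stale:
        del last_used_by_url[url]  # the real code also closes the client here
    return stale
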
+ # Shutdown executor + self._executor.shutdown(wait=False) - Args: - tool_name: Name of the tool to call - arguments: Arguments to pass to the tool - mcp_clients: Dictionary of MCP clients - tool_to_server: Mapping of tool names to server URLs + def warm_up(self, mcp_servers: List[str]): + """Pre-establish connections to servers. - Returns: - Tool call result + Call this during initialization to avoid connection latency on first request. """ + logger.info(f"Warming up connections to {len(mcp_servers)} MCP servers") + self.get_connections(mcp_servers) - async def _call_tool(): - result = None - error_msg = None - - # Try the mapped server first - if tool_to_server and tool_name in tool_to_server: - server_url = tool_to_server[tool_name] - if server_url in mcp_clients: - try: - logger.info(f"Calling tool {tool_name} with arguments {arguments}") - result = await mcp_clients[server_url]["client"].call_tool( - tool_name, arguments=arguments - ) - return result - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") - - # Fallback: try all servers - for server_url, server_info in mcp_clients.items(): - if tool_to_server and tool_name in tool_to_server: - if tool_to_server[tool_name] == server_url: - continue # Already tried this one - try: - logger.info(f"Calling tool {tool_name} with arguments {arguments}") - result = await server_info["client"].call_tool(tool_name, arguments=arguments) - return result - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") - continue - - raise Exception( - f"Failed to execute tool {tool_name}. " - f"{error_msg if error_msg else 'Tool not found on any server.'}" - ) - - return self._run_coroutine(_call_tool()) - - async def call_tool_async( - self, - tool_name: str, - arguments: Dict[str, Any], - mcp_clients: Dict[str, Any], - tool_to_server: Dict[str, str], - ) -> Any: - """Async version of call_tool for use within async contexts. 
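
# Hypothetical startup snippet for warm_up above; the URL is a placeholder and
# CLARIFAI_PAT is assumed to be set in the environment. Paying the connect and
# list_tools cost once at boot keeps that latency off the first request.
pool = MCPConnectionPool()
pool.warm_up(["https://mcp.example.com/mcp"])
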
- - Args: - tool_name: Name of the tool to call - arguments: Arguments to pass to the tool - mcp_clients: Dictionary of MCP clients - tool_to_server: Mapping of tool names to server URLs - - Returns: - Tool call result - """ - result = None - error_msg = None - - # Try the mapped server first - if tool_to_server and tool_name in tool_to_server: - server_url = tool_to_server[tool_name] - if server_url in mcp_clients: - try: - logger.info(f"Calling tool {tool_name} with arguments {arguments}") - result = await mcp_clients[server_url]["client"].call_tool( - tool_name, arguments=arguments - ) - return result - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") - - # Fallback: try all servers - for server_url, server_info in mcp_clients.items(): - if tool_to_server and tool_name in tool_to_server: - if tool_to_server[tool_name] == server_url: - continue - try: - logger.info(f"Calling tool {tool_name} with arguments {arguments}") - result = await server_info["client"].call_tool(tool_name, arguments=arguments) - return result - except Exception as e: - error_msg = str(e) - logger.error(f"❌ Error calling tool {tool_name}: {e}") - continue + def get_stats(self) -> Dict[str, Any]: + """Get pool statistics for monitoring.""" + with self._global_lock: + connections_info = [] + for url, conn in self._connections.items(): + connections_info.append( + { + "url": url, + "tools_count": len(conn.tools), + "use_count": conn.use_count, + "idle_seconds": time.time() - conn.last_used, + "connected_seconds": time.time() - conn.connected_at, + } + ) - raise Exception( - f"Failed to execute tool {tool_name}. " - f"{error_msg if error_msg else 'Tool not found on any server.'}" - ) + return { + "total_connections": len(self._connections), + "cached_tools": len(self._tool_to_url_cache), + "connections": connections_info, + } class AgenticModelClass(OpenAIModelClass): - """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support. - - This class extends OpenAIModelClass to enable agentic behavior by integrating LLMs with MCP servers. - It handles tool discovery, execution, and iterative tool calling for both chat completions and - responses endpoints, supporting both streaming and non-streaming modes. - - MCP connections are maintained persistently across requests using a connection pool, which - significantly improves performance by avoiding reconnection overhead. - - To use this class, create a subclass and set the following class attributes: - - client: The OpenAI-compatible client instance - - model: The name of the model to use with the client - - Example: - class MyAgenticModel(AgenticModelClass): - client = OpenAI(api_key="your-key") - model = "gpt-4" + """Base class for wrapping OpenAI-compatible servers with MCP support. + + Optimizations over base implementation: + - Persistent connection pool across requests + - Parallel tool execution + - Tool name caching for O(1) lookup + - Lazy connection verification + - Efficient streaming with queue-based async bridging """ - # Singleton connection pool shared across all instances _mcp_pool: Optional[MCPConnectionPool] = None _pool_lock = threading.Lock() @@ -509,18 +609,16 @@ def _get_mcp_pool(cls) -> MCPConnectionPool: cls._mcp_pool = MCPConnectionPool() return cls._mcp_pool - def _get_mcp_tools_and_clients(self, mcp_servers: List[str]) -> tuple[List[dict], dict, dict]: - """Get available tools and clients from all connected MCP servers. 
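
# _get_mcp_pool above uses the double-checked locking idiom. A generic sketch
# of the same pattern (SharedPool is an illustration, not a class from this
# patch): the fast path skips the lock once the instance exists, and the
# re-check under the lock guards against a racing thread.
import threading

class SharedPool:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        if cls._instance is None:          # fast path, no lock once built
            with cls._lock:
                if cls._instance is None:  # re-check: another thread may have won
                    cls._instance = cls()
        return cls._instance
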
- - This method uses the connection pool to reuse existing connections - when possible, significantly improving performance. - - Args: - mcp_servers: List of MCP server URLs + @classmethod + def warm_up_mcp(cls, mcp_servers: List[str]): + """Pre-establish MCP connections during model initialization.""" + pool = cls._get_mcp_pool() + pool.warm_up(mcp_servers) - Returns: - A tuple of (tools in OpenAI format, mcp_clients dictionary, tool_to_server mapping). - """ + def _get_mcp_tools_and_clients( + self, mcp_servers: List[str] + ) -> Tuple[List[dict], Dict[str, MCPConnection], Dict[str, str]]: + """Get available tools and clients from MCP servers.""" pool = self._get_mcp_pool() return pool.get_tools_and_mapping(mcp_servers) @@ -530,14 +628,7 @@ def _init_token_accumulation(self): self._thread_local.accumulated_tokens = {'prompt_tokens': 0, 'completion_tokens': 0} def _accumulate_usage(self, resp): - """Accumulate token usage from response object without calling set_output_context. - - This method extracts tokens from the response and adds them to the accumulated total. - It should be called for each API response in a multi-call request flow. - - Args: - resp: Response object with usage information - """ + """Accumulate token usage from response object.""" has_usage = getattr(resp, "usage", None) has_response_usage = getattr(resp, "response", None) and getattr( resp.response, "usage", None @@ -557,10 +648,8 @@ def _accumulate_usage(self, resp): prompt_tokens = getattr(resp.response.usage, "input_tokens", 0) completion_tokens = getattr(resp.response.usage, "output_tokens", 0) - if prompt_tokens is None: - prompt_tokens = 0 - if completion_tokens is None: - completion_tokens = 0 + prompt_tokens = prompt_tokens or 0 + completion_tokens = completion_tokens or 0 if prompt_tokens > 0 or completion_tokens > 0: self._init_token_accumulation() @@ -596,7 +685,6 @@ def _handle_chat_completions( request_data = request_data.copy() request_data["tools"] = tools request_data["tool_choice"] = request_data.get("tool_choice", "auto") - return super()._handle_chat_completions(request_data) def _convert_tools_to_response_api_format(self, tools: List[dict]) -> List[dict]: @@ -632,7 +720,6 @@ def _handle_responses( response_api_tools = self._convert_tools_to_response_api_format(tools) request_data["tools"] = response_api_tools request_data["tool_choice"] = request_data.get("tool_choice", "auto") - return super()._handle_responses(request_data) def _route_request( @@ -646,25 +733,22 @@ def _route_request( """Route the request to appropriate handler based on endpoint.""" if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: return self._handle_chat_completions(request_data, mcp_servers, mcp_clients, tools) - if endpoint == self.ENDPOINT_RESPONSES: return self._handle_responses(request_data, mcp_servers, mcp_clients, tools) - return super()._route_request(endpoint, request_data) def _execute_tool_calls( self, tool_calls: List[Any], - mcp_clients: dict, + connections: Dict[str, MCPConnection], messages: List[dict], - tool_to_server: dict = None, + tool_to_server: Dict[str, str], ): - """Execute tool calls from chat completion and add results to messages. - - Uses the connection pool for tool execution. 
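
# The tool-injection step in _handle_chat_completions above, in isolation:
# copy the request, attach the discovered tools, and default tool_choice to
# "auto" without clobbering a caller-supplied value. inject_tools is an
# illustrative helper, not a method from this patch.
def inject_tools(request_data: dict, tools: list) -> dict:
    request_data = request_data.copy()
    request_data["tools"] = tools
    request_data["tool_choice"] = request_data.get("tool_choice", "auto")
    return request_data
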
- """ + """Execute tool calls from chat completion and add results to messages.""" pool = self._get_mcp_pool() + # Prepare tool calls for potential parallel execution + parsed_calls = [] for tool_call in tool_calls: if hasattr(tool_call, 'function'): tool_name = tool_call.function.name @@ -674,9 +758,12 @@ def _execute_tool_calls( tool_name = tool_call['function']['name'] tool_args = json.loads(tool_call['function']['arguments']) tool_id = tool_call['id'] + parsed_calls.append((tool_id, tool_name, tool_args)) + # Execute tools (could be parallelized for independent tools) + for tool_id, tool_name, tool_args in parsed_calls: try: - result = pool.call_tool(tool_name, tool_args, mcp_clients, tool_to_server) + result = pool.call_tool(tool_name, tool_args, connections, tool_to_server) content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) @@ -694,13 +781,15 @@ def _execute_tool_calls( async def _execute_tool_calls_async( self, tool_calls: List[Any], - mcp_clients: dict, + connections: Dict[str, MCPConnection], messages: List[dict], - tool_to_server: dict = None, + tool_to_server: Dict[str, str], ): - """Async version of _execute_tool_calls for streaming contexts.""" + """Async version with parallel tool execution support.""" pool = self._get_mcp_pool() + # Parse all tool calls + parsed_calls = [] for tool_call in tool_calls: if hasattr(tool_call, 'function'): tool_name = tool_call.function.name @@ -710,16 +799,20 @@ async def _execute_tool_calls_async( tool_name = tool_call['function']['name'] tool_args = json.loads(tool_call['function']['arguments']) tool_id = tool_call['id'] + parsed_calls.append((tool_id, tool_name, tool_args)) - try: - result = await pool.call_tool_async( - tool_name, tool_args, mcp_clients, tool_to_server - ) + # Execute all tools in parallel + tool_inputs = [(name, args) for _, name, args in parsed_calls] + results = await pool.call_tools_parallel(tool_inputs, connections, tool_to_server) + + # Map results back to tool IDs + for (tool_id, tool_name, _), (_, result, error) in zip(parsed_calls, results): + if error: + content = f"Error: {str(error)}" + else: content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - except Exception as e: - content = f"Error: {str(e)}" messages.append( { @@ -732,9 +825,9 @@ async def _execute_tool_calls_async( def _execute_response_api_tool_calls( self, tool_calls: List[Dict[str, Any]], - mcp_clients: dict, + connections: Dict[str, MCPConnection], input_items: List[Any], - tool_to_server: dict = None, + tool_to_server: Dict[str, str], ): """Execute tool calls from response API and add results to input items.""" pool = self._get_mcp_pool() @@ -753,16 +846,16 @@ def _execute_response_api_tool_calls( tool_args = {} try: - result = pool.call_tool(tool_name, tool_args, mcp_clients, tool_to_server) + result = pool.call_tool(tool_name, tool_args, connections, tool_to_server) content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) except Exception as e: content = f"Error: {str(e)}" - output_call_id = call_id if call_id else tool_id + output_call_id = call_id or tool_id if not output_call_id: - logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping output") + logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping") continue input_items.append( @@ -776,13 +869,15 @@ def _execute_response_api_tool_calls( async def _execute_response_api_tool_calls_async( self, tool_calls: List[Dict[str, Any]], - mcp_clients: 
dict, + connections: Dict[str, MCPConnection], input_items: List[Any], - tool_to_server: dict = None, + tool_to_server: Dict[str, str], ): - """Async version of _execute_response_api_tool_calls for streaming contexts.""" + """Async version with parallel tool execution support.""" pool = self._get_mcp_pool() + # Parse all tool calls + parsed_calls = [] for tool_call in tool_calls: tool_name = tool_call.get("name") tool_args_str = tool_call.get("arguments", "{}") @@ -796,19 +891,24 @@ async def _execute_response_api_tool_calls_async( except json.JSONDecodeError: tool_args = {} - try: - result = await pool.call_tool_async( - tool_name, tool_args, mcp_clients, tool_to_server - ) + parsed_calls.append((tool_name, tool_args, tool_id, call_id)) + + # Execute all tools in parallel + tool_inputs = [(name, args) for name, args, _, _ in parsed_calls] + results = await pool.call_tools_parallel(tool_inputs, connections, tool_to_server) + + # Map results back + for (tool_name, _, tool_id, call_id), (_, result, error) in zip(parsed_calls, results): + if error: + content = f"Error: {str(error)}" + else: content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - except Exception as e: - content = f"Error: {str(e)}" - output_call_id = call_id if call_id else tool_id + output_call_id = call_id or tool_id if not output_call_id: - logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping output") + logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping") continue input_items.append( @@ -860,7 +960,6 @@ def _convert_output_items_to_input_items( continue item_type = item.get("type") - if item_type in ["message", "reasoning"]: input_items.append(item) elif item_type in ["function_tool_call", "function_call", "function", "tool_call"]: @@ -868,7 +967,6 @@ def _convert_output_items_to_input_items( output = item.get("output") if output is not None or status in ["completed", "done"]: input_items.append(item) - return input_items def _accumulate_tool_call_delta(self, tool_call_delta, tool_calls_accumulated: dict): @@ -892,20 +990,19 @@ def _accumulate_tool_call_delta(self, tool_call_delta, tool_calls_accumulated: d def _convert_accumulated_tool_calls(self, tool_calls_accumulated: dict) -> List[dict]: """Convert accumulated tool calls dictionary to list format.""" - tool_calls_list = [] - for idx in sorted(tool_calls_accumulated.keys()): - tc = tool_calls_accumulated[idx] - tool_calls_list.append( - { - "id": tc["id"], - "type": tc["type"], - "function": { - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }, - } + return [ + { + "id": tc["id"], + "type": tc["type"], + "function": { + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + }, + } + for tc in ( + tool_calls_accumulated[idx] for idx in sorted(tool_calls_accumulated.keys()) ) - return tool_calls_list + ] def _create_completion_request( self, @@ -933,14 +1030,14 @@ def _create_completion_request( return self.client.chat.completions.create(**kwargs) def _bridge_async_generator(self, async_gen_func): - """Bridge an async generator to a sync generator using the pool's event loop.""" + """Bridge an async generator to a sync generator using efficient queue-based approach.""" pool = self._get_mcp_pool() - loop = pool._get_or_create_event_loop() + loop = pool._get_event_loop() - # Create a queue for communication between async generator and sync iteration - queue = asyncio.Queue() - done_event = threading.Event() - exception_holder = [None] + # 
Use a bounded queue to apply backpressure + queue: asyncio.Queue = asyncio.Queue(maxsize=100) + done = threading.Event() + exception_holder: List[Optional[Exception]] = [None] async def producer(): try: @@ -949,48 +1046,43 @@ async def producer(): except Exception as e: exception_holder[0] = e finally: - await queue.put(None) # Sentinel to signal completion - done_event.set() + await queue.put(None) # Sentinel + done.set() - # Start the producer in the background loop - future = asyncio.run_coroutine_threadsafe(producer(), loop) + # Start producer + asyncio.run_coroutine_threadsafe(producer(), loop) try: while True: - # Get from queue with timeout to allow checking for exceptions - get_future = asyncio.run_coroutine_threadsafe(queue.get(), loop) + # Get from queue with timeout try: - item = get_future.result(timeout=30.0) - except Exception as e: + get_future = asyncio.run_coroutine_threadsafe(queue.get(), loop) + item = get_future.result(timeout=60.0) + except Exception: if exception_holder[0]: raise exception_holder[0] raise - if item is None: # Sentinel + if item is None: break yield item - # Check if there was an exception in the producer if exception_holder[0]: raise exception_holder[0] finally: - # Ensure the future is done - try: - future.result(timeout=1.0) - except: - pass + done.wait(timeout=1.0) async def _stream_with_mcp_tools_json( self, openai_messages: List[dict], tools: List[dict], - mcp_clients: dict, + connections: Dict[str, MCPConnection], max_tokens: int, temperature: float, top_p: float, - tool_to_server: dict = None, + tool_to_server: Dict[str, str], ): - """Async generator to handle MCP tool calls with streaming support for chat completions.""" + """Async generator for streaming chat completions with MCP tools.""" tool_calls_accumulated = {} streaming_response = "" @@ -1015,16 +1107,18 @@ async def _stream_with_mcp_tools_json( openai_messages.append( { "role": "assistant", - "content": streaming_response if streaming_response else None, + "content": streaming_response or None, "tool_calls": tool_calls_list, } ) + + # Execute tools in parallel await self._execute_tool_calls_async( - tool_calls_list, mcp_clients, openai_messages, tool_to_server + tool_calls_list, connections, openai_messages, tool_to_server ) async for chunk_json in self._stream_with_mcp_tools_json( - openai_messages, tools, mcp_clients, max_tokens, temperature, top_p, tool_to_server + openai_messages, tools, connections, max_tokens, temperature, top_p, tool_to_server ): yield chunk_json @@ -1032,10 +1126,10 @@ async def _stream_responses_with_mcp_tools_json( self, request_data: Dict[str, Any], tools: List[dict], - mcp_clients: dict, - tool_to_server: dict = None, + connections: Dict[str, MCPConnection], + tool_to_server: Dict[str, str], ): - """Async generator to handle MCP tool calls with streaming support for response API.""" + """Async generator for streaming response API with MCP tools.""" input_data = request_data.get("input", "") if isinstance(input_data, str): input_items = [ @@ -1061,223 +1155,114 @@ async def _stream_responses_with_mcp_tools_json( for chunk in stream: self._set_usage(chunk) - chunk_type = getattr(chunk, 'type', None) or chunk.__class__.__name__ should_yield = True item_to_check = None - if ( - chunk_type == 'response.output_item.added' - or chunk_type == 'ResponseOutputItemAddedEvent' + # Process chunk based on type (condensed for brevity - same logic as before) + if chunk_type in ( + 'response.output_item.added', + 'ResponseOutputItemAddedEvent', ) and hasattr(chunk, 
'item'): item_to_check = chunk.item if hasattr(chunk, 'output_index'): - item_dict = ( - item_to_check - if isinstance(item_to_check, dict) - else ( - item_to_check.model_dump() - if hasattr(item_to_check, 'model_dump') - else item_to_check.dict() - if hasattr(item_to_check, 'dict') - else {} - ) - ) - item_type = item_dict.get("type") - if item_type == "message": - original_index = chunk.output_index - original_to_filtered_index_map[original_index] = len( + item_dict = self._to_dict(item_to_check) + if item_dict.get("type") == "message": + original_to_filtered_index_map[chunk.output_index] = len( original_to_filtered_index_map ) - elif ( - chunk_type == 'response.output_item.done' - or chunk_type == 'ResponseOutputItemDoneEvent' + + # Track tool calls + if item_dict.get("type") in [ + "function_tool_call", + "function_call", + "function", + "tool_call", + ]: + item_id = item_dict.get("id") or item_dict.get("call_id") + if item_id: + tool_calls_accumulated[item_id] = { + "id": item_id, + "call_id": item_dict.get("call_id"), + "type": item_dict.get("type"), + "name": item_dict.get("name", ""), + "arguments": item_dict.get("arguments", ""), + "status": item_dict.get("status", "in_progress"), + } + + elif chunk_type in ( + 'response.output_item.done', + 'ResponseOutputItemDoneEvent', ) and hasattr(chunk, 'item'): item_to_check = chunk.item + elif hasattr(chunk, 'output_index'): - original_index = chunk.output_index - if original_index not in original_to_filtered_index_map: + if chunk.output_index not in original_to_filtered_index_map: should_yield = False + # Filter non-message items if item_to_check: - item_dict = ( - item_to_check - if isinstance(item_to_check, dict) - else ( - item_to_check.model_dump() - if hasattr(item_to_check, 'model_dump') - else item_to_check.dict() - if hasattr(item_to_check, 'dict') - else {} - ) - ) - item_type = item_dict.get("type") - if item_type != "message": + item_dict = self._to_dict(item_to_check) + if item_dict.get("type") != "message": should_yield = False - if ( - chunk_type == 'response.completed' or chunk_type == 'ResponseCompletedEvent' - ) and hasattr(chunk, 'response'): + # Handle response.completed + if chunk_type in ('response.completed', 'ResponseCompletedEvent') and hasattr( + chunk, 'response' + ): response = chunk.response if hasattr(response, 'output') and response.output: - filtered_output = [] - filtered_index = 0 - original_to_filtered_index_map.clear() - - for original_index, item in enumerate(response.output): - item_dict = ( - item - if isinstance(item, dict) - else ( - item.model_dump() - if hasattr(item, 'model_dump') - else item.dict() - if hasattr(item, 'dict') - else {} - ) - ) - item_type = item_dict.get("type") - if item_type == "message": - filtered_output.append(item_dict) - original_to_filtered_index_map[original_index] = filtered_index - filtered_index += 1 - elif item_type in [ - "function_tool_call", - "function_call", - "function", - "tool_call", - ]: - item_id = item_dict.get("id") - if item_id: - existing_ids = [ - i.get("id") - if isinstance(i, dict) - else (getattr(i, "id", None) if hasattr(i, "id") else None) - for i in accumulated_output - ] - if item_id not in existing_ids: - accumulated_output.append(item_dict) - else: - accumulated_output.append(item_dict) - else: - item_id = item_dict.get("id") - if item_id: - existing_ids = [ - i.get("id") - if isinstance(i, dict) - else (getattr(i, "id", None) if hasattr(i, "id") else None) - for i in accumulated_output - ] - if item_id not in existing_ids: - 
accumulated_output.append(item_dict) - else: - accumulated_output.append(item_dict) - - response_dict = ( - response.model_dump() - if hasattr(response, 'model_dump') - else response.dict() - if hasattr(response, 'dict') - else {} + filtered_output, accumulated_output = self._process_response_output( + response.output, accumulated_output, tool_calls_accumulated ) + + response_dict = self._to_dict(response) response_dict["output"] = filtered_output - modified_chunk_dict = { - "type": "response.completed", - "sequence_number": getattr(chunk, 'sequence_number', None), - "response": response_dict, - } - yield json.dumps(modified_chunk_dict) + yield json.dumps( + { + "type": "response.completed", + "sequence_number": getattr(chunk, 'sequence_number', None), + "response": response_dict, + } + ) else: yield chunk.model_dump_json() elif should_yield: - if hasattr(chunk, 'output_index'): - original_index = chunk.output_index - if original_index in original_to_filtered_index_map: - chunk_dict = ( - chunk.model_dump() - if hasattr(chunk, 'model_dump') - else ( - chunk.dict() - if hasattr(chunk, 'dict') - else json.loads(chunk.model_dump_json()) - if hasattr(chunk, 'model_dump_json') - else {} - ) - ) - chunk_dict["output_index"] = original_to_filtered_index_map[original_index] - yield json.dumps(chunk_dict) - else: + if ( + hasattr(chunk, 'output_index') + and chunk.output_index in original_to_filtered_index_map + ): + chunk_dict = self._to_dict(chunk) + chunk_dict["output_index"] = original_to_filtered_index_map[chunk.output_index] + yield json.dumps(chunk_dict) + elif not hasattr(chunk, 'output_index'): yield chunk.model_dump_json() - if ( - chunk_type == 'response.output_item.added' - or chunk_type == 'ResponseOutputItemAddedEvent' - ) and hasattr(chunk, 'item'): - item = chunk.item - item_dict = ( - item - if isinstance(item, dict) - else ( - item.model_dump() - if hasattr(item, 'model_dump') - else item.dict() - if hasattr(item, 'dict') - else {} - ) - ) - item_type = item_dict.get("type") - - if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: - item_id = item_dict.get("id") or item_dict.get("call_id") - call_id = item_dict.get("call_id") - if item_id: - tool_calls_accumulated[item_id] = { - "id": item_id, - "call_id": call_id, - "type": item_type, - "name": item_dict.get("name", ""), - "arguments": item_dict.get("arguments", ""), - "status": item_dict.get("status", "in_progress"), - } - - elif ( - chunk_type == 'response.function_call_arguments.delta' - or chunk_type == 'ResponseFunctionCallArgumentsDeltaEvent' + # Handle argument deltas + if chunk_type in ( + 'response.function_call_arguments.delta', + 'ResponseFunctionCallArgumentsDeltaEvent', ): item_id = getattr(chunk, 'item_id', None) - delta = getattr(chunk, 'delta', '') - if item_id and item_id in tool_calls_accumulated: - tool_calls_accumulated[item_id]["arguments"] += delta + tool_calls_accumulated[item_id]["arguments"] += getattr(chunk, 'delta', '') - elif ( - chunk_type == 'response.function_call_arguments.done' - or chunk_type == 'ResponseFunctionCallArgumentsDoneEvent' + elif chunk_type in ( + 'response.function_call_arguments.done', + 'ResponseFunctionCallArgumentsDoneEvent', ): item_id = getattr(chunk, 'item_id', None) - arguments = getattr(chunk, 'arguments', '') - if item_id and item_id in tool_calls_accumulated: - tool_calls_accumulated[item_id]["arguments"] = arguments + tool_calls_accumulated[item_id]["arguments"] = getattr(chunk, 'arguments', '') - elif ( - chunk_type == 
'response.output_item.done' - or chunk_type == 'ResponseOutputItemDoneEvent' + # Handle output item done + elif chunk_type in ( + 'response.output_item.done', + 'ResponseOutputItemDoneEvent', ) and hasattr(chunk, 'item'): - item = chunk.item - item_dict = ( - item - if isinstance(item, dict) - else ( - item.model_dump() - if hasattr(item, 'model_dump') - else item.dict() - if hasattr(item, 'dict') - else {} - ) - ) + item_dict = self._to_dict(chunk.item) item_type = item_dict.get("type") if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: @@ -1294,62 +1279,80 @@ async def _stream_responses_with_mcp_tools_json( else: accumulated_output.append(item_dict) - elif hasattr(chunk, 'output') and chunk.output: - for item in chunk.output: - item_dict = ( - item - if isinstance(item, dict) - else ( - item.model_dump() - if hasattr(item, 'model_dump') - else item.dict() - if hasattr(item, 'dict') - else {} - ) - ) - item_id = item_dict.get("id") - if item_id: - existing_ids = [ - i.get("id") - if isinstance(i, dict) - else (getattr(i, "id", None) if hasattr(i, "id") else None) - for i in accumulated_output - ] - if item_id not in existing_ids: - accumulated_output.append(item_dict) - else: - accumulated_output.append(item_dict) - + # Add remaining accumulated tool calls for call_id, call_data in tool_calls_accumulated.items(): if call_data.get("name"): - existing_ids = [ - i.get("id") - if isinstance(i, dict) - else (getattr(i, "id", None) if hasattr(i, "id") else None) - for i in accumulated_output - ] + existing_ids = {self._get_id(i) for i in accumulated_output} if call_id not in existing_ids: accumulated_output.append(call_data) + # Execute tool calls if any tool_calls = self._extract_tool_calls_from_response_output(accumulated_output) if tool_calls: model_output_items = self._convert_output_items_to_input_items(accumulated_output) input_items.extend(model_output_items) + # Execute tools in parallel await self._execute_response_api_tool_calls_async( - tool_calls, mcp_clients, input_items, tool_to_server + tool_calls, connections, input_items, tool_to_server ) request_data["input"] = input_items async for chunk_json in self._stream_responses_with_mcp_tools_json( - request_data, tools, mcp_clients, tool_to_server + request_data, tools, connections, tool_to_server ): yield chunk_json + def _to_dict(self, obj: Any) -> dict: + """Convert object to dictionary.""" + if isinstance(obj, dict): + return obj + if hasattr(obj, 'model_dump'): + return obj.model_dump() + if hasattr(obj, 'dict'): + return obj.dict() + if hasattr(obj, '__dict__'): + return obj.__dict__ + return {} + + def _get_id(self, item: Any) -> Optional[str]: + """Get ID from an item.""" + if isinstance(item, dict): + return item.get("id") + return getattr(item, "id", None) + + def _process_response_output( + self, + output: List[Any], + accumulated_output: List[Dict], + tool_calls_accumulated: Dict[str, Dict], + ) -> Tuple[List[Dict], List[Dict]]: + """Process response output, filtering messages and accumulating tool calls.""" + filtered_output = [] + + for item in output: + item_dict = self._to_dict(item) + item_type = item_dict.get("type") + + if item_type == "message": + filtered_output.append(item_dict) + elif item_type in ["function_tool_call", "function_call", "function", "tool_call"]: + item_id = item_dict.get("id") + existing_ids = {self._get_id(i) for i in accumulated_output} + if not item_id or item_id not in existing_ids: + accumulated_output.append(item_dict) + else: + item_id = item_dict.get("id") 
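
# The dedup step at this point, as a standalone sketch (dedup_append is an
# invented helper): an item is appended only if it has no id or its id has not
# been accumulated yet, mirroring the set-based check in the surrounding code.
def dedup_append(accumulated: list, item: dict) -> None:
    item_id = item.get("id")
    existing = {i.get("id") for i in accumulated if isinstance(i, dict)}
    if not item_id or item_id not in existing:
        accumulated.append(item)
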
+ existing_ids = {self._get_id(i) for i in accumulated_output} + if not item_id or item_id not in existing_ids: + accumulated_output.append(item_dict) + + return filtered_output, accumulated_output + @ModelClass.method def openai_transport(self, msg: str) -> str: - """Process an OpenAI-compatible request and send it to the appropriate OpenAI endpoint.""" + """Process an OpenAI-compatible request.""" try: request_data = from_json(msg) request_data = self._update_old_fields(request_data) @@ -1358,16 +1361,14 @@ def openai_transport(self, msg: str) -> str: tools = request_data.get("tools") if mcp_servers and len(mcp_servers) > 0 and tools is None: - logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") - tools_local, mcp_clients_local, tool_to_server_local = ( - self._get_mcp_tools_and_clients(mcp_servers) + logger.info(f"Getting tools for MCP servers: {mcp_servers}") + tools_local, connections, tool_to_server = self._get_mcp_tools_and_clients( + mcp_servers ) - # Note: No cleanup needed - connections are maintained in the pool - if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: response = self._route_request( - endpoint, request_data, mcp_servers, mcp_clients_local, tools_local + endpoint, request_data, mcp_servers, connections, tools_local ) while response.choices and response.choices[0].message.tool_calls: @@ -1375,22 +1376,18 @@ def openai_transport(self, msg: str) -> str: messages.append(response.choices[0].message) self._execute_tool_calls( response.choices[0].message.tool_calls, - mcp_clients_local, + connections, messages, - tool_to_server_local, + tool_to_server, ) request_data["messages"] = messages response = self._route_request( - endpoint, - request_data, - mcp_servers, - mcp_clients_local, - tools_local, + endpoint, request_data, mcp_servers, connections, tools_local ) elif endpoint == self.ENDPOINT_RESPONSES: response = self._route_request( - endpoint, request_data, mcp_servers, mcp_clients_local, tools_local + endpoint, request_data, mcp_servers, connections, tools_local ) input_data = request_data.get("input", "") @@ -1413,26 +1410,15 @@ def openai_transport(self, msg: str) -> str: response_output ) input_items.extend(model_output_items) - self._execute_response_api_tool_calls( - tool_calls, - mcp_clients_local, - input_items, - tool_to_server_local, + tool_calls, connections, input_items, tool_to_server ) request_data["input"] = input_items - response = self._route_request( - endpoint, - request_data, - mcp_servers, - mcp_clients_local, - tools_local, + endpoint, request_data, mcp_servers, connections, tools_local ) - response_output = response.output if hasattr(response, 'output') else [] tool_calls = self._extract_tool_calls_from_response_output(response_output) - else: response = self._route_request(endpoint, request_data) else: @@ -1452,7 +1438,7 @@ def openai_transport(self, msg: str) -> str: @ModelClass.method def openai_stream_transport(self, msg: str) -> Iterator[str]: - """Process an OpenAI-compatible request and return a streaming response iterator.""" + """Process an OpenAI-compatible request with streaming.""" try: request_data = from_json(msg) request_data = self._update_old_fields(request_data) @@ -1460,58 +1446,51 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]: endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) if endpoint not in [self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES]: - raise ValueError("Streaming is only supported for chat completions and responses.") + raise ValueError("Streaming only 
supported for chat completions and responses.") if mcp_servers and len(mcp_servers) > 0 and request_data.get("tools") is None: - logger.info(f"Getting tools and clients for MCP servers: {mcp_servers}") - tools_local, mcp_clients_local, tool_to_server_local = ( - self._get_mcp_tools_and_clients(mcp_servers) + logger.info(f"Getting tools for MCP servers: {mcp_servers}") + tools_local, connections, tool_to_server = self._get_mcp_tools_and_clients( + mcp_servers ) - # Note: No cleanup needed - connections are maintained in the pool - async def stream_generator(): if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: messages = request_data.get("messages", []) async for chunk_json in self._stream_with_mcp_tools_json( messages, tools_local, - mcp_clients_local, + connections, request_data.get("max_completion_tokens", 4096), request_data.get("temperature", 1.0), request_data.get("top_p", 1.0), - tool_to_server_local, + tool_to_server, ): yield chunk_json self._finalize_token_usage() elif endpoint == self.ENDPOINT_RESPONSES: async for chunk_json in self._stream_responses_with_mcp_tools_json( - request_data, tools_local, mcp_clients_local, tool_to_server_local + request_data, tools_local, connections, tool_to_server ): yield chunk_json self._finalize_token_usage() - else: - response_args = {**request_data, "model": self.model} - for chunk in self.client.responses.create(**response_args): - self._set_usage(chunk) - yield chunk.model_dump_json() - self._finalize_token_usage() yield from self._bridge_async_generator(stream_generator) return + # Non-MCP path if endpoint == self.ENDPOINT_RESPONSES: response_args = {**request_data, "model": self.model} for chunk in self.client.responses.create(**response_args): self._set_usage(chunk) yield chunk.model_dump_json() - self._finalize_token_usage() else: completion_args = self._create_completion_args(request_data) for chunk in self.client.chat.completions.create(**completion_args): self._set_usage(chunk) yield chunk.model_dump_json() - self._finalize_token_usage() + + self._finalize_token_usage() except Exception as e: logger.exception(e) From b174b0ec6c624d3106fad6379c65870ea2140ef1 Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Tue, 9 Dec 2025 18:28:00 +0530 Subject: [PATCH 04/13] simplify code --- clarifai/runners/models/agentic_class.py | 1735 ++++++++-------------- 1 file changed, 629 insertions(+), 1106 deletions(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index 7bfa67d5..a6d9b06d 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -5,7 +5,6 @@ import os import threading import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Dict, Iterator, List, Optional, Set, Tuple @@ -19,136 +18,85 @@ @dataclass class MCPConnection: - """Represents a single MCP server connection with metadata.""" + """Single MCP server connection with metadata.""" client: Any tools: List[Any] - tool_names: Set[str] # For O(1) tool lookup - last_used: float - connected_at: float + tool_names: Set[str] url: str - lock: threading.RLock = field(default_factory=threading.RLock) - use_count: int = 0 + last_used: float = field(default_factory=time.time) - def mark_used(self): - """Mark connection as recently used.""" + def touch(self): + """Mark as recently used.""" self.last_used = time.time() - self.use_count += 1 class MCPConnectionPool: - """Thread-safe connection pool for MCP servers with persistent 
connections.""" + """Thread-safe connection pool for MCP servers.""" _instance: Optional['MCPConnectionPool'] = None - _lock = threading.Lock() + _instance_lock = threading.Lock() - # Pool configuration - DEFAULT_MAX_IDLE_TIME = 600 # 10 minutes, if mcp server idle time > DEFAULT_MAX_IDLE_TIME it disconnect from server - DEFAULT_CLEANUP_INTERVAL = 120 # 2 minute, - DEFAULT_VERIFY_THRESHOLD = 60 # Only verify if idle > 60 seconds - DEFAULT_CONNECT_TIMEOUT = 30 # Connection timeout - DEFAULT_TOOL_CALL_TIMEOUT = 60 # Tool call timeout - MAX_PARALLEL_CONNECTIONS = 10 # Max concurrent connection attempts + # Timeouts + CONNECT_TIMEOUT = 30.0 + TOOL_CALL_TIMEOUT = 60.0 + VERIFY_IDLE_THRESHOLD = 60.0 # Verify if idle > 60s + MAX_IDLE_TIME = 600.0 # Remove if idle > 10min def __new__(cls): - """Singleton pattern to ensure one connection pool per process.""" if cls._instance is None: - with cls._lock: + with cls._instance_lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - def __init__( - self, - max_idle_time: float = DEFAULT_MAX_IDLE_TIME, - cleanup_interval: float = DEFAULT_CLEANUP_INTERVAL, - verify_threshold: float = DEFAULT_VERIFY_THRESHOLD, - ): + def __init__(self): if self._initialized: return self._connections: Dict[str, MCPConnection] = {} - self._global_lock = threading.RLock() + self._lock = threading.RLock() + + # Tool caches for O(1) lookup + self._tool_to_url: Dict[str, str] = {} + self._all_tools: Dict[str, dict] = {} # tool_name -> OpenAI format + + # Background event loop self._loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None - self._loop_lock = threading.Lock() - self._loop_ready = threading.Event() - - # Configuration - self._max_idle_time = max_idle_time - self._cleanup_interval = cleanup_interval - self._verify_threshold = verify_threshold - self._last_cleanup = time.time() - - # Tool name to URL cache for O(1) lookup - self._tool_to_url_cache: Dict[str, str] = {} - self._cache_lock = threading.RLock() - - # Thread pool for parallel operations - self._executor = ThreadPoolExecutor( - max_workers=self.MAX_PARALLEL_CONNECTIONS, thread_name_prefix="mcp_pool_" - ) - - # Start the background event loop immediately self._start_event_loop() self._initialized = True def _start_event_loop(self): - """Start the persistent event loop in a background thread.""" - with self._loop_lock: - if self._loop is not None and self._loop_thread is not None: - if self._loop_thread.is_alive() and not self._loop.is_closed(): - return - - self._loop_ready.clear() - - def run_loop(): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - self._loop_ready.set() - self._loop.run_forever() - - self._loop_thread = threading.Thread( - target=run_loop, daemon=True, name="mcp_event_loop" - ) - self._loop_thread.start() + """Start background event loop for async operations.""" + ready = threading.Event() + + def run(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + ready.set() + self._loop.run_forever() - if not self._loop_ready.wait(timeout=10.0): - raise RuntimeError("Failed to start MCP event loop") + self._loop_thread = threading.Thread(target=run, daemon=True, name="mcp_pool") + self._loop_thread.start() + ready.wait(timeout=5.0) - def _get_event_loop(self) -> asyncio.AbstractEventLoop: - """Get the persistent event loop, starting it if necessary.""" + def _run_async(self, coro, timeout: float = CONNECT_TIMEOUT) -> Any: + 
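
# The sync-to-async bridge behind _run_async, as a standalone sketch
# (start_background_loop and run_sync are invented names): a loop runs forever
# on a daemon thread, and run_coroutine_threadsafe hands coroutines to it,
# returning a concurrent.futures.Future the caller can block on with a timeout.
import asyncio
import threading

def start_background_loop() -> asyncio.AbstractEventLoop:
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop

def run_sync(loop: asyncio.AbstractEventLoop, coro, timeout: float = 30.0):
    return asyncio.run_coroutine_threadsafe(coro, loop).result(timeout=timeout)
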
"""Run coroutine in background loop.""" if self._loop is None or self._loop.is_closed(): self._start_event_loop() - return self._loop - - def _run_coroutine(self, coro, timeout: float = DEFAULT_CONNECT_TIMEOUT) -> Any: - """Run a coroutine in the persistent event loop with timeout.""" - loop = self._get_event_loop() - future = asyncio.run_coroutine_threadsafe(coro, loop) + future = asyncio.run_coroutine_threadsafe(coro, self._loop) return future.result(timeout=timeout) - def _run_coroutine_nowait(self, coro) -> asyncio.Future: - """Schedule a coroutine without waiting for result.""" - loop = self._get_event_loop() - return asyncio.run_coroutine_threadsafe(coro, loop) - - async def _create_client(self, url: str) -> Tuple[Any, List[Any]]: - """Create and connect a single MCP client. - - Returns: - Tuple of (client, tools) - """ + async def _create_connection(self, url: str) -> MCPConnection: + """Create new MCP connection.""" try: from fastmcp import Client from fastmcp.client.transports import StreamableHttpTransport except ImportError: - raise ImportError( - "fastmcp package is required to use MCP functionality. " - "Install it with: pip install fastmcp" - ) + raise ImportError("fastmcp required: pip install fastmcp") transport = StreamableHttpTransport( url=url, @@ -156,223 +104,145 @@ async def _create_client(self, url: str) -> Tuple[Any, List[Any]]: ) client = Client(transport) - await client.__aenter__() - tools = await client.list_tools() - - return client, tools - - async def _connect_single_server( - self, url: str, max_retries: int = 2, retry_delay: float = 1.0 - ) -> Optional[MCPConnection]: - """Connect to a single MCP server with retries.""" - last_error = None + await asyncio.wait_for(client.__aenter__(), timeout=self.CONNECT_TIMEOUT) + tools = await asyncio.wait_for(client.list_tools(), timeout=10.0) - for attempt in range(max_retries): - try: - client, tools = await asyncio.wait_for( - self._create_client(url), timeout=self.DEFAULT_CONNECT_TIMEOUT - ) - - # Build tool name set for O(1) lookup - tool_names = {tool.name for tool in tools} - - connection = MCPConnection( - client=client, - tools=tools, - tool_names=tool_names, - last_used=time.time(), - connected_at=time.time(), - url=url, - ) - - logger.info(f"✓ Connected to {url} with {len(tools)} tools") - return connection - - except asyncio.TimeoutError: - last_error = TimeoutError( - f"Connection timeout after {self.DEFAULT_CONNECT_TIMEOUT}s" - ) - logger.warning( - f"⚠ Timeout connecting to {url} (attempt {attempt + 1}/{max_retries})" - ) - except Exception as e: - last_error = e - if attempt < max_retries - 1: - logger.warning( - f"⚠ Failed to connect to {url} (attempt {attempt + 1}/{max_retries}): {e}" - ) - await asyncio.sleep(retry_delay) - - logger.error(f"❌ Failed to connect to {url} after {max_retries} attempts: {last_error}") - return None - - async def _connect_servers_parallel( - self, urls: List[str], max_retries: int = 2, retry_delay: float = 1.0 - ) -> Dict[str, MCPConnection]: - """Connect to multiple servers in parallel.""" - if not urls: - return {} - - tasks = [self._connect_single_server(url, max_retries, retry_delay) for url in urls] - - results = await asyncio.gather(*tasks, return_exceptions=True) - - connections = {} - for url, result in zip(urls, results): - if isinstance(result, Exception): - logger.error(f"❌ Error connecting to {url}: {result}") - elif result is not None: - connections[url] = result - - return connections + return MCPConnection( + client=client, + tools=tools, + tool_names={t.name 
for t in tools}, + url=url, + ) - async def _verify_connection_async(self, connection: MCPConnection) -> bool: - """Verify that a connection is still valid.""" + async def _verify_connection(self, conn: MCPConnection) -> bool: + """Check if connection is still valid.""" try: - await asyncio.wait_for(connection.client.list_tools(), timeout=5.0) + await asyncio.wait_for(conn.client.list_tools(), timeout=5.0) return True - except Exception as e: - logger.warning(f"⚠ Connection to {connection.url} is no longer valid: {e}") + except Exception: return False - async def _disconnect_async(self, connection: MCPConnection): - """Disconnect from a single MCP server.""" + async def _close_connection(self, conn: MCPConnection): + """Close a connection gracefully.""" try: - client = connection.client - if hasattr(client, 'close') and callable(getattr(client, 'close', None)): - if asyncio.iscoroutinefunction(client.close): - await asyncio.wait_for(client.close(), timeout=5.0) - else: - client.close() + if hasattr(conn.client, 'close'): + await asyncio.wait_for(conn.client.close(), timeout=5.0) else: - await asyncio.wait_for(client.__aexit__(None, None, None), timeout=5.0) - logger.info(f"✓ Disconnected from {connection.url}") + await asyncio.wait_for(conn.client.__aexit__(None, None, None), timeout=5.0) except Exception as e: - logger.warning(f"⚠ Error disconnecting from {connection.url}: {e}") + logger.warning(f"Error closing connection to {conn.url}: {e}") - def _should_verify_connection(self, connection: MCPConnection) -> bool: + def _needs_verification(self, conn: MCPConnection) -> bool: """Check if connection needs verification based on idle time.""" - idle_time = time.time() - connection.last_used - return idle_time > self._verify_threshold - - def _update_tool_cache(self, connections: Dict[str, MCPConnection]): - """Update the tool-to-URL cache from connections.""" - with self._cache_lock: - for url, conn in connections.items(): - for tool_name in conn.tool_names: - self._tool_to_url_cache[tool_name] = url - - def _invalidate_tool_cache(self, url: str): - """Remove tools from cache when a connection is removed.""" - with self._cache_lock: - self._tool_to_url_cache = { - name: cached_url - for name, cached_url in self._tool_to_url_cache.items() - if cached_url != url + return time.time() - conn.last_used > self.VERIFY_IDLE_THRESHOLD + + def _update_tool_cache(self, conn: MCPConnection): + """Update tool caches from connection.""" + for tool in conn.tools: + self._tool_to_url[tool.name] = conn.url + self._all_tools[tool.name] = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.inputSchema, + }, } - def get_connections( - self, mcp_servers: List[str], max_retries: int = 2, retry_delay: float = 1.0 - ) -> Dict[str, MCPConnection]: - """Get connections for the specified MCP servers. + def get_connections(self, urls: List[str]) -> Dict[str, MCPConnection]: + """Get connections for URLs, creating/verifying as needed. - Uses lazy verification - only verifies connections that have been - idle longer than the verification threshold. + This is the main entry point. It: + 1. Returns cached connections if recently used + 2. Verifies cached connections if idle + 3. Creates new connections if needed + All operations happen in parallel where possible. 
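
        Example (illustrative; the URL is a placeholder):

            pool = MCPConnectionPool()
            conns = pool.get_connections(["https://mcp.example.com/mcp"])
            # A second call inside the idle window reuses the cached entry
            # without re-verifying it.
            conns = pool.get_connections(["https://mcp.example.com/mcp"])
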
""" - self._maybe_cleanup_stale_connections() - result = {} - urls_to_connect = [] - urls_to_verify = [] + to_verify = [] + to_create = [] - # First pass: categorize URLs - with self._global_lock: - for url in mcp_servers: + with self._lock: + for url in urls: if url in self._connections: - connection = self._connections[url] - if self._should_verify_connection(connection): - urls_to_verify.append(url) + conn = self._connections[url] + if self._needs_verification(conn): + to_verify.append(url) else: - # Recently used, assume still valid - connection.mark_used() - result[url] = connection + conn.touch() + result[url] = conn else: - urls_to_connect.append(url) + to_create.append(url) # Verify stale connections in parallel - if urls_to_verify: + if to_verify: async def verify_all(): - tasks = [] - for url in urls_to_verify: - conn = self._connections.get(url) - if conn: - tasks.append(self._verify_connection_async(conn)) - else: - tasks.append(asyncio.coroutine(lambda: False)()) - return await asyncio.gather(*tasks, return_exceptions=True) + tasks = { + url: self._verify_connection(self._connections[url]) + for url in to_verify + if url in self._connections + } + results = await asyncio.gather(*tasks.values(), return_exceptions=True) + return dict(zip(tasks.keys(), results)) try: - verify_results = self._run_coroutine(verify_all(), timeout=15.0) - - with self._global_lock: - for url, is_valid in zip(urls_to_verify, verify_results): - if isinstance(is_valid, Exception) or not is_valid: - # Connection is stale, need to reconnect - if url in self._connections: - self._invalidate_tool_cache(url) - del self._connections[url] - urls_to_connect.append(url) + verify_results = self._run_async(verify_all(), timeout=15.0) + with self._lock: + for url, is_valid in verify_results.items(): + if is_valid is True: + conn = self._connections[url] + conn.touch() + result[url] = conn else: - # Connection is valid - connection = self._connections[url] - connection.mark_used() - result[url] = connection + # Remove invalid connection, will recreate + self._connections.pop(url, None) + to_create.append(url) except Exception as e: - logger.error(f"❌ Error verifying connections: {e}") - # On verification failure, try to reconnect all - urls_to_connect.extend(urls_to_verify) + logger.error(f"Verification error: {e}") + to_create.extend(to_verify) - # Connect to new servers in parallel - if urls_to_connect: - try: - new_connections = self._run_coroutine( - self._connect_servers_parallel(urls_to_connect, max_retries, retry_delay), - timeout=self.DEFAULT_CONNECT_TIMEOUT * max_retries + 10, - ) - - with self._global_lock: - self._connections.update(new_connections) - result.update(new_connections) + # Create new connections in parallel + if to_create: - # Update tool cache - self._update_tool_cache(new_connections) + async def create_all(): + tasks = {url: self._create_connection(url) for url in to_create} + results = await asyncio.gather(*tasks.values(), return_exceptions=True) + return dict(zip(tasks.keys(), results)) + try: + create_results = self._run_async(create_all(), timeout=self.CONNECT_TIMEOUT + 5) + with self._lock: + for url, conn_or_error in create_results.items(): + if isinstance(conn_or_error, Exception): + logger.error(f"Failed to connect to {url}: {conn_or_error}") + else: + self._connections[url] = conn_or_error + self._update_tool_cache(conn_or_error) + result[url] = conn_or_error + logger.info(f"✓ Connected to {url} ({len(conn_or_error.tools)} tools)") except Exception as e: - logger.error(f"❌ Error 
connecting to MCP servers: {e}") + logger.error(f"Connection error: {e}") return result def get_tools_and_mapping( - self, mcp_servers: List[str] + self, urls: List[str] ) -> Tuple[List[dict], Dict[str, MCPConnection], Dict[str, str]]: - """Get tools and server mapping for the specified MCP servers. + """Get tools, connections, and tool-to-server mapping.""" + connections = self.get_connections(urls) - Returns: - Tuple of (tools in OpenAI format, connections dict, tool_to_server mapping) - """ - connections = self.get_connections(mcp_servers) - - all_tools = [] + tools = [] tool_to_server = {} - seen_tools = set() # Avoid duplicate tools + seen = set() for url, conn in connections.items(): for tool in conn.tools: - if tool.name not in seen_tools: - seen_tools.add(tool.name) - all_tools.append( - { + if tool.name not in seen: + seen.add(tool.name) + tools.append( + self._all_tools.get(tool.name) + or { "type": "function", "function": { "name": tool.name, @@ -383,13 +253,8 @@ def get_tools_and_mapping( ) tool_to_server[tool.name] = url - logger.info(f"Access to {len(all_tools)} tools from {len(connections)} servers") - return all_tools, connections, tool_to_server - - def get_cached_tool_url(self, tool_name: str) -> Optional[str]: - """Get URL for a tool from cache. O(1) lookup.""" - with self._cache_lock: - return self._tool_to_url_cache.get(tool_name) + logger.info(f"Loaded {len(tools)} tools from {len(connections)} servers") + return tools, connections, tool_to_server async def call_tool_async( self, @@ -398,50 +263,17 @@ async def call_tool_async( connections: Dict[str, MCPConnection], tool_to_server: Dict[str, str], ) -> Any: - """Call a tool on the appropriate MCP server.""" - # Try cached/mapped server first - server_url = tool_to_server.get(tool_name) or self.get_cached_tool_url(tool_name) - - if server_url and server_url in connections: - conn = connections[server_url] - with conn.lock: - try: - logger.info(f"Calling tool {tool_name} on {server_url}") - result = await asyncio.wait_for( - conn.client.call_tool(tool_name, arguments=arguments), - timeout=self.DEFAULT_TOOL_CALL_TIMEOUT, - ) - conn.mark_used() - return result - except asyncio.TimeoutError: - logger.error(f"❌ Timeout calling tool {tool_name}") - raise - except Exception as e: - logger.error(f"❌ Error calling tool {tool_name}: {e}") - # Fall through to try other servers - - # Fallback: find server with this tool - for url, conn in connections.items(): - if url == server_url: - continue # Already tried - if tool_name in conn.tool_names: - with conn.lock: - try: - logger.info(f"Calling tool {tool_name} on {url} (fallback)") - result = await asyncio.wait_for( - conn.client.call_tool(tool_name, arguments=arguments), - timeout=self.DEFAULT_TOOL_CALL_TIMEOUT, - ) - conn.mark_used() - # Update cache for future lookups - with self._cache_lock: - self._tool_to_url_cache[tool_name] = url - return result - except Exception as e: - logger.error(f"❌ Error calling tool {tool_name} on {url}: {e}") - continue - - raise Exception(f"Tool {tool_name} not found on any connected server") + """Call a tool asynchronously.""" + url = tool_to_server.get(tool_name) or self._tool_to_url.get(tool_name) + if not url or url not in connections: + raise ValueError(f"Tool '{tool_name}' not found") + + conn = connections[url] + result = await asyncio.wait_for( + conn.client.call_tool(tool_name, arguments=arguments), timeout=self.TOOL_CALL_TIMEOUT + ) + conn.touch() + return result def call_tool( self, @@ -450,982 +282,677 @@ def call_tool( connections: 
Dict[str, MCPConnection], tool_to_server: Dict[str, str], ) -> Any: - """Synchronous wrapper for call_tool_async.""" - return self._run_coroutine( + """Call a tool synchronously.""" + logger.info(f"Calling tool {tool_name}") + return self._run_async( self.call_tool_async(tool_name, arguments, connections, tool_to_server), - timeout=self.DEFAULT_TOOL_CALL_TIMEOUT + 5, + timeout=self.TOOL_CALL_TIMEOUT + 5, ) - async def call_tools_parallel( + async def call_tools_batch_async( self, - tool_calls: List[Tuple[str, Dict[str, Any]]], + calls: List[Tuple[str, str, Dict[str, Any]]], # [(id, name, args), ...] connections: Dict[str, MCPConnection], tool_to_server: Dict[str, str], - ) -> List[Tuple[str, Any, Optional[Exception]]]: - """Execute multiple tool calls in parallel. - - Args: - tool_calls: List of (tool_name, arguments) tuples - connections: Connection dictionary - tool_to_server: Tool to server mapping - - Returns: - List of (tool_name, result, exception) tuples - """ + ) -> List[Tuple[str, Optional[Any], Optional[str]]]: + """Call multiple tools in parallel. Returns [(id, result, error), ...]""" - async def call_single(tool_name: str, args: Dict[str, Any]): + async def call_one(call_id: str, name: str, args: Dict): try: - result = await self.call_tool_async(tool_name, args, connections, tool_to_server) - return (tool_name, result, None) + result = await self.call_tool_async(name, args, connections, tool_to_server) + return (call_id, result, None) except Exception as e: - return (tool_name, None, e) + logger.error(f"Error calling tool {name}: {e}") + return (call_id, None, str(e)) - tasks = [call_single(name, args) for name, args in tool_calls] + tasks = [call_one(cid, name, args) for cid, name, args in calls] return await asyncio.gather(*tasks) - def _maybe_cleanup_stale_connections(self): - """Clean up connections that have been idle for too long.""" - current_time = time.time() - - if current_time - self._last_cleanup < self._cleanup_interval: - return - - self._last_cleanup = current_time + def call_tools_batch( + self, + calls: List[Tuple[str, str, Dict[str, Any]]], + connections: Dict[str, MCPConnection], + tool_to_server: Dict[str, str], + ) -> List[Tuple[str, Optional[Any], Optional[str]]]: + """Call multiple tools in parallel (sync wrapper).""" + return self._run_async( + self.call_tools_batch_async(calls, connections, tool_to_server), + timeout=self.TOOL_CALL_TIMEOUT + 10, + ) - with self._global_lock: - urls_to_remove = [ + def cleanup_idle(self): + """Remove connections idle for too long.""" + now = time.time() + with self._lock: + to_remove = [ url for url, conn in self._connections.items() - if current_time - conn.last_used > self._max_idle_time + if now - conn.last_used > self.MAX_IDLE_TIME ] - - for url in urls_to_remove: - self._disconnect_url(url) - - def _disconnect_url(self, url: str): - """Disconnect from a specific URL.""" - with self._global_lock: - connection = self._connections.pop(url, None) - - if connection: - self._invalidate_tool_cache(url) - try: - self._run_coroutine(self._disconnect_async(connection), timeout=10.0) - except Exception as e: - logger.warning(f"⚠ Error during disconnect from {url}: {e}") - - def disconnect(self, url: str): - """Public method to disconnect from a specific MCP server.""" - self._disconnect_url(url) - - def disconnect_all(self): - """Disconnect from all MCP servers and cleanup.""" - with self._global_lock: - urls = list(self._connections.keys()) - - # Disconnect all in parallel - async def disconnect_all_async(): - tasks = [] - 
for url in urls: - conn = self._connections.get(url) - if conn: - tasks.append(self._disconnect_async(conn)) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - try: - self._run_coroutine(disconnect_all_async(), timeout=30.0) - except Exception as e: - logger.warning(f"⚠ Error during bulk disconnect: {e}") - - with self._global_lock: + for url in to_remove: + conn = self._connections.pop(url) + self._run_async(self._close_connection(conn), timeout=10.0) + logger.info(f"Removed idle connection to {url}") + + def shutdown(self): + """Close all connections and stop event loop.""" + with self._lock: + for conn in self._connections.values(): + try: + self._run_async(self._close_connection(conn), timeout=5.0) + except Exception: + pass self._connections.clear() + self._tool_to_url.clear() + self._all_tools.clear() - with self._cache_lock: - self._tool_to_url_cache.clear() + if self._loop and not self._loop.is_closed(): + self._loop.call_soon_threadsafe(self._loop.stop) - # Stop the event loop - with self._loop_lock: - if self._loop is not None and not self._loop.is_closed(): - self._loop.call_soon_threadsafe(self._loop.stop) - if self._loop_thread is not None: - self._loop_thread.join(timeout=5.0) - self._loop = None - self._loop_thread = None - # Shutdown executor - self._executor.shutdown(wait=False) - - def warm_up(self, mcp_servers: List[str]): - """Pre-establish connections to servers. +class AgenticModelClass(OpenAIModelClass): + """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support. - Call this during initialization to avoid connection latency on first request. - """ - logger.info(f"Warming up connections to {len(mcp_servers)} MCP servers") - self.get_connections(mcp_servers) - - def get_stats(self) -> Dict[str, Any]: - """Get pool statistics for monitoring.""" - with self._global_lock: - connections_info = [] - for url, conn in self._connections.items(): - connections_info.append( - { - "url": url, - "tools_count": len(conn.tools), - "use_count": conn.use_count, - "idle_seconds": time.time() - conn.last_used, - "connected_seconds": time.time() - conn.connected_at, - } - ) + This class extends OpenAIModelClass to enable agentic behavior by integrating LLMs with MCP servers. + It handles tool discovery, execution, and iterative tool calling for both chat completions and + responses endpoints, supporting both streaming and non-streaming modes. - return { - "total_connections": len(self._connections), - "cached_tools": len(self._tool_to_url_cache), - "connections": connections_info, - } + MCP connections are maintained persistently across requests using a connection pool, which + significantly improves performance by avoiding reconnection overhead. + To use this class, create a subclass and set the following class attributes: + - client: The OpenAI-compatible client instance + - model: The name of the model to use with the client -class AgenticModelClass(OpenAIModelClass): - """Base class for wrapping OpenAI-compatible servers with MCP support. 
- - Optimizations over base implementation: - - Persistent connection pool across requests - - Parallel tool execution - - Tool name caching for O(1) lookup - - Lazy connection verification - - Efficient streaming with queue-based async bridging + Example: + class MyAgenticModel(AgenticModelClass): + client = OpenAI(api_key="your-key") + model = "gpt-4" """ - _mcp_pool: Optional[MCPConnectionPool] = None + _pool: Optional[MCPConnectionPool] = None _pool_lock = threading.Lock() @classmethod - def _get_mcp_pool(cls) -> MCPConnectionPool: - """Get or create the MCP connection pool singleton.""" - if cls._mcp_pool is None: + def get_pool(cls) -> MCPConnectionPool: + """Get shared connection pool.""" + if cls._pool is None: with cls._pool_lock: - if cls._mcp_pool is None: - cls._mcp_pool = MCPConnectionPool() - return cls._mcp_pool + if cls._pool is None: + cls._pool = MCPConnectionPool() + return cls._pool @classmethod - def warm_up_mcp(cls, mcp_servers: List[str]): - """Pre-establish MCP connections during model initialization.""" - pool = cls._get_mcp_pool() - pool.warm_up(mcp_servers) + def warm_up(cls, mcp_servers: List[str]): + """Pre-connect to MCP servers.""" + cls.get_pool().get_connections(mcp_servers) - def _get_mcp_tools_and_clients( - self, mcp_servers: List[str] - ) -> Tuple[List[dict], Dict[str, MCPConnection], Dict[str, str]]: - """Get available tools and clients from MCP servers.""" - pool = self._get_mcp_pool() - return pool.get_tools_and_mapping(mcp_servers) - - def _init_token_accumulation(self): - """Initialize token accumulation for a new request.""" - if not hasattr(self._thread_local, 'accumulated_tokens'): - self._thread_local.accumulated_tokens = {'prompt_tokens': 0, 'completion_tokens': 0} - - def _accumulate_usage(self, resp): - """Accumulate token usage from response object.""" - has_usage = getattr(resp, "usage", None) - has_response_usage = getattr(resp, "response", None) and getattr( - resp.response, "usage", None - ) + # === Token Tracking === - if has_response_usage or has_usage: - prompt_tokens = 0 - completion_tokens = 0 - if has_usage: - prompt_tokens = getattr(resp.usage, "prompt_tokens", 0) or getattr( - resp.usage, "input_tokens", 0 - ) - completion_tokens = getattr(resp.usage, "completion_tokens", 0) or getattr( - resp.usage, "output_tokens", 0 - ) - else: - prompt_tokens = getattr(resp.response.usage, "input_tokens", 0) - completion_tokens = getattr(resp.response.usage, "output_tokens", 0) - - prompt_tokens = prompt_tokens or 0 - completion_tokens = completion_tokens or 0 - - if prompt_tokens > 0 or completion_tokens > 0: - self._init_token_accumulation() - self._thread_local.accumulated_tokens['prompt_tokens'] += prompt_tokens - self._thread_local.accumulated_tokens['completion_tokens'] += completion_tokens + def _init_tokens(self): + if not hasattr(self._thread_local, 'tokens'): + self._thread_local.tokens = {'prompt': 0, 'completion': 0} - def _finalize_token_usage(self): - """Finalize token accumulation and set the total in output context.""" - if hasattr(self._thread_local, 'accumulated_tokens'): - prompt_tokens = self._thread_local.accumulated_tokens['prompt_tokens'] - completion_tokens = self._thread_local.accumulated_tokens['completion_tokens'] + def _add_tokens(self, resp): + """Accumulate tokens from response.""" + usage = getattr(resp, 'usage', None) or ( + getattr(resp.response, 'usage', None) if hasattr(resp, 'response') else None + ) + if usage: + self._init_tokens() + self._thread_local.tokens['prompt'] += ( + getattr(usage, 
'prompt_tokens', 0) or getattr(usage, 'input_tokens', 0) or 0 + ) + self._thread_local.tokens['completion'] += ( + getattr(usage, 'completion_tokens', 0) or getattr(usage, 'output_tokens', 0) or 0 + ) - if prompt_tokens > 0 or completion_tokens > 0: + def _finalize_tokens(self): + """Send accumulated tokens to output context.""" + if hasattr(self._thread_local, 'tokens'): + t = self._thread_local.tokens + if t['prompt'] > 0 or t['completion'] > 0: self.set_output_context( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, + prompt_tokens=t['prompt'], completion_tokens=t['completion'] ) - del self._thread_local.accumulated_tokens + del self._thread_local.tokens def _set_usage(self, resp): - """Override _set_usage to accumulate tokens across multiple API calls.""" - self._accumulate_usage(resp) + self._add_tokens(resp) - def _handle_chat_completions( - self, - request_data: Dict[str, Any], - mcp_servers: List[str] = None, - mcp_clients: dict = None, - tools: List[dict] = None, - ): - """Handle chat completion requests with optional MCP tool support.""" - if mcp_servers and tools: - request_data = request_data.copy() - request_data["tools"] = tools - request_data["tool_choice"] = request_data.get("tool_choice", "auto") - return super()._handle_chat_completions(request_data) + # === Tool Format Conversion === - def _convert_tools_to_response_api_format(self, tools: List[dict]) -> List[dict]: - """Convert tools from chat completion format to response API format.""" - response_api_tools = [] - for tool in tools: - if isinstance(tool, dict): - tool_type = tool.get("type", "function") - if "function" in tool: - func = tool["function"] - response_api_tools.append( - { - "type": tool_type, - "name": func.get("name"), - "description": func.get("description", ""), - "parameters": func.get("parameters", {}), - } - ) - elif "name" in tool: - response_api_tools.append(tool) - return response_api_tools + def _to_response_api_tools(self, tools: List[dict]) -> List[dict]: + """Convert chat completion tools to response API format.""" + result = [] + for t in tools: + if "function" in t: + f = t["function"] + result.append( + { + "type": "function", + "name": f.get("name"), + "description": f.get("description", ""), + "parameters": f.get("parameters", {}), + } + ) + elif "name" in t: + result.append(t) + return result - def _handle_responses( - self, - request_data: Dict[str, Any], - mcp_servers: List[str] = None, - mcp_clients: dict = None, - tools: List[dict] = None, - ): - """Handle response API requests with optional MCP tool support.""" - if mcp_servers and tools: - request_data = request_data.copy() - response_api_tools = self._convert_tools_to_response_api_format(tools) - request_data["tools"] = response_api_tools - request_data["tool_choice"] = request_data.get("tool_choice", "auto") - return super()._handle_responses(request_data) + def _to_dict(self, obj) -> dict: + """Convert object to dict.""" + if isinstance(obj, dict): + return obj + for attr in ('model_dump', 'dict'): + if hasattr(obj, attr): + return getattr(obj, attr)() + return getattr(obj, '__dict__', {}) + + # === Tool Call Parsing === + + def _parse_chat_tool_calls(self, tool_calls) -> List[Tuple[str, str, Dict]]: + """Parse chat completion tool calls into (id, name, args) tuples.""" + result = [] + for tc in tool_calls: + if hasattr(tc, 'function'): + result.append((tc.id, tc.function.name, json.loads(tc.function.arguments))) + else: + result.append( + (tc['id'], tc['function']['name'], 
json.loads(tc['function']['arguments'])) + ) + return result - def _route_request( - self, - endpoint: str, - request_data: Dict[str, Any], - mcp_servers: List[str] = None, - mcp_clients: dict = None, - tools: List[dict] = None, - ): - """Route the request to appropriate handler based on endpoint.""" - if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: - return self._handle_chat_completions(request_data, mcp_servers, mcp_clients, tools) - if endpoint == self.ENDPOINT_RESPONSES: - return self._handle_responses(request_data, mcp_servers, mcp_clients, tools) - return super()._route_request(endpoint, request_data) + def _parse_response_tool_calls(self, items: List[dict]) -> List[Tuple[str, str, Dict]]: + """Parse response API tool calls into (call_id, name, args) tuples.""" + result = [] + for item in items: + d = self._to_dict(item) + if d.get('type') in ('function_tool_call', 'function_call', 'function', 'tool_call'): + status = d.get('status', '') + if status in ('pending', 'in_progress', '') or d.get('output') is None: + call_id = d.get('call_id') or d.get('id') + name = d.get('name') + args_str = d.get('arguments', '{}') + if call_id and name: + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except json.JSONDecodeError: + args = {} + result.append((call_id, name, args)) + return result + + # === Tool Execution === - def _execute_tool_calls( + def _execute_chat_tools( self, - tool_calls: List[Any], + tool_calls, connections: Dict[str, MCPConnection], messages: List[dict], tool_to_server: Dict[str, str], ): - """Execute tool calls from chat completion and add results to messages.""" - pool = self._get_mcp_pool() - - # Prepare tool calls for potential parallel execution - parsed_calls = [] - for tool_call in tool_calls: - if hasattr(tool_call, 'function'): - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) - tool_id = tool_call.id - else: - tool_name = tool_call['function']['name'] - tool_args = json.loads(tool_call['function']['arguments']) - tool_id = tool_call['id'] - parsed_calls.append((tool_id, tool_name, tool_args)) + """Execute chat completion tool calls and append results to messages.""" + pool = self.get_pool() + parsed = self._parse_chat_tool_calls(tool_calls) + results = pool.call_tools_batch(parsed, connections, tool_to_server) - # Execute tools (could be parallelized for independent tools) - for tool_id, tool_name, tool_args in parsed_calls: - try: - result = pool.call_tool(tool_name, tool_args, connections, tool_to_server) + for call_id, result, error in results: + if error: + content = f"Error: {error}" + else: content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - except Exception as e: - content = f"Error: {str(e)}" - - messages.append( - { - "role": "tool", - "tool_call_id": tool_id, - "content": content, - } - ) + messages.append({"role": "tool", "tool_call_id": call_id, "content": content}) - async def _execute_tool_calls_async( + async def _execute_chat_tools_async( self, - tool_calls: List[Any], + tool_calls, connections: Dict[str, MCPConnection], messages: List[dict], tool_to_server: Dict[str, str], ): - """Async version with parallel tool execution support.""" - pool = self._get_mcp_pool() - - # Parse all tool calls - parsed_calls = [] - for tool_call in tool_calls: - if hasattr(tool_call, 'function'): - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) - tool_id = tool_call.id - else: - tool_name = 
tool_call['function']['name'] - tool_args = json.loads(tool_call['function']['arguments']) - tool_id = tool_call['id'] - parsed_calls.append((tool_id, tool_name, tool_args)) + """Async version of chat tool execution.""" + pool = self.get_pool() + parsed = self._parse_chat_tool_calls(tool_calls) + results = await pool.call_tools_batch_async(parsed, connections, tool_to_server) - # Execute all tools in parallel - tool_inputs = [(name, args) for _, name, args in parsed_calls] - results = await pool.call_tools_parallel(tool_inputs, connections, tool_to_server) - - # Map results back to tool IDs - for (tool_id, tool_name, _), (_, result, error) in zip(parsed_calls, results): + for call_id, result, error in results: if error: - content = f"Error: {str(error)}" + content = f"Error: {error}" else: content = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) + messages.append({"role": "tool", "tool_call_id": call_id, "content": content}) - messages.append( - { - "role": "tool", - "tool_call_id": tool_id, - "content": content, - } - ) - - def _execute_response_api_tool_calls( + def _execute_response_tools( self, - tool_calls: List[Dict[str, Any]], + tool_calls: List[Tuple[str, str, Dict]], connections: Dict[str, MCPConnection], - input_items: List[Any], + input_items: List, tool_to_server: Dict[str, str], ): - """Execute tool calls from response API and add results to input items.""" - pool = self._get_mcp_pool() - - for tool_call in tool_calls: - tool_name = tool_call.get("name") - tool_args_str = tool_call.get("arguments", "{}") - tool_id = tool_call.get("id") - call_id = tool_call.get("call_id") + """Execute response API tool calls and append results to input_items.""" + pool = self.get_pool() + results = pool.call_tools_batch(tool_calls, connections, tool_to_server) - try: - tool_args = ( - json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str - ) - except json.JSONDecodeError: - tool_args = {} - - try: - result = pool.call_tool(tool_name, tool_args, connections, tool_to_server) - content = ( + for call_id, result, error in results: + if error: + output = f"Error: {error}" + else: + output = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - except Exception as e: - content = f"Error: {str(e)}" - - output_call_id = call_id or tool_id - if not output_call_id: - logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping") - continue - input_items.append( - { - "type": "function_call_output", - "call_id": output_call_id, - "output": content, - } + {"type": "function_call_output", "call_id": call_id, "output": output} ) - async def _execute_response_api_tool_calls_async( + async def _execute_response_tools_async( self, - tool_calls: List[Dict[str, Any]], + tool_calls: List[Tuple[str, str, Dict]], connections: Dict[str, MCPConnection], - input_items: List[Any], + input_items: List, tool_to_server: Dict[str, str], ): - """Async version with parallel tool execution support.""" - pool = self._get_mcp_pool() - - # Parse all tool calls - parsed_calls = [] - for tool_call in tool_calls: - tool_name = tool_call.get("name") - tool_args_str = tool_call.get("arguments", "{}") - tool_id = tool_call.get("id") - call_id = tool_call.get("call_id") - - try: - tool_args = ( - json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str - ) - except json.JSONDecodeError: - tool_args = {} - - parsed_calls.append((tool_name, tool_args, tool_id, call_id)) + """Async version of response API tool 
execution.""" + pool = self.get_pool() + results = await pool.call_tools_batch_async(tool_calls, connections, tool_to_server) - # Execute all tools in parallel - tool_inputs = [(name, args) for name, args, _, _ in parsed_calls] - results = await pool.call_tools_parallel(tool_inputs, connections, tool_to_server) - - # Map results back - for (tool_name, _, tool_id, call_id), (_, result, error) in zip(parsed_calls, results): + for call_id, result, error in results: if error: - content = f"Error: {str(error)}" + output = f"Error: {error}" else: - content = ( + output = ( result.content[0].text if hasattr(result, 'content') else str(result[0].text) ) - - output_call_id = call_id or tool_id - if not output_call_id: - logger.warning(f"⚠ No call_id or id found for tool {tool_name}, skipping") - continue - input_items.append( - { - "type": "function_call_output", - "call_id": output_call_id, - "output": content, - } + {"type": "function_call_output", "call_id": call_id, "output": output} ) - def _extract_tool_calls_from_response_output( - self, response_output: List[Any] - ) -> List[Dict[str, Any]]: - """Extract tool calls from response API output array.""" - tool_calls = [] - for item in response_output: - if not isinstance(item, dict): - if hasattr(item, 'model_dump'): - item = item.model_dump() - elif hasattr(item, 'dict'): - item = item.dict() - elif hasattr(item, '__dict__'): - item = item.__dict__ - else: - continue + # === Response Output Processing === - item_type = item.get("type") - if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: - status = item.get("status", "") - output = item.get("output") - if status in ["pending", "in_progress", ""] or output is None: - tool_calls.append(item) - return tool_calls - - def _convert_output_items_to_input_items( - self, response_output: List[Any] - ) -> List[Dict[str, Any]]: - """Convert response API output items to input items format.""" - input_items = [] - for item in response_output: - if not isinstance(item, dict): - if hasattr(item, 'model_dump'): - item = item.model_dump() - elif hasattr(item, 'dict'): - item = item.dict() - elif hasattr(item, '__dict__'): - item = item.__dict__ - else: - continue + def _convert_output_to_input(self, output: List) -> List[dict]: + """Convert response API output items to input items.""" + result = [] + for item in output: + d = self._to_dict(item) + t = d.get('type') + if t in ('message', 'reasoning'): + result.append(d) + elif t in ('function_tool_call', 'function_call', 'function', 'tool_call'): + if d.get('output') is not None or d.get('status') in ('completed', 'done'): + result.append(d) + return result + + # === Request Handlers === + + def _handle_chat_completions( + self, request_data: Dict, mcp_servers=None, connections=None, tools=None + ): + if mcp_servers and tools: + request_data = { + **request_data, + "tools": tools, + "tool_choice": request_data.get("tool_choice", "auto"), + } + return super()._handle_chat_completions(request_data) + + def _handle_responses( + self, request_data: Dict, mcp_servers=None, connections=None, tools=None + ): + if mcp_servers and tools: + request_data = { + **request_data, + "tools": self._to_response_api_tools(tools), + "tool_choice": request_data.get("tool_choice", "auto"), + } + return super()._handle_responses(request_data) + + def _route_request( + self, endpoint: str, request_data: Dict, mcp_servers=None, connections=None, tools=None + ): + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + return 
self._handle_chat_completions(request_data, mcp_servers, connections, tools) + if endpoint == self.ENDPOINT_RESPONSES: + return self._handle_responses(request_data, mcp_servers, connections, tools) + return super()._route_request(endpoint, request_data) - item_type = item.get("type") - if item_type in ["message", "reasoning"]: - input_items.append(item) - elif item_type in ["function_tool_call", "function_call", "function", "tool_call"]: - status = item.get("status", "") - output = item.get("output") - if output is not None or status in ["completed", "done"]: - input_items.append(item) - return input_items - - def _accumulate_tool_call_delta(self, tool_call_delta, tool_calls_accumulated: dict): - """Accumulate tool call data from a streaming delta.""" - index = tool_call_delta.index - if index not in tool_calls_accumulated: - tool_calls_accumulated[index] = { - "id": tool_call_delta.id, + # === Streaming Helpers === + + def _accumulate_tool_delta(self, delta, accumulated: dict): + """Accumulate streaming tool call deltas.""" + idx = delta.index + if idx not in accumulated: + accumulated[idx] = { + "id": delta.id, "type": "function", "function": {"name": "", "arguments": ""}, } - if tool_call_delta.id: - tool_calls_accumulated[index]["id"] = tool_call_delta.id - if tool_call_delta.function: - if tool_call_delta.function.name: - tool_calls_accumulated[index]["function"]["name"] = tool_call_delta.function.name - if tool_call_delta.function.arguments: - tool_calls_accumulated[index]["function"]["arguments"] += ( - tool_call_delta.function.arguments - ) - - def _convert_accumulated_tool_calls(self, tool_calls_accumulated: dict) -> List[dict]: - """Convert accumulated tool calls dictionary to list format.""" + if delta.id: + accumulated[idx]["id"] = delta.id + if delta.function: + if delta.function.name: + accumulated[idx]["function"]["name"] = delta.function.name + if delta.function.arguments: + accumulated[idx]["function"]["arguments"] += delta.function.arguments + + def _finalize_tool_calls(self, accumulated: dict) -> List[dict]: + """Convert accumulated tool calls to list.""" return [ - { - "id": tc["id"], - "type": tc["type"], - "function": { - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }, - } - for tc in ( - tool_calls_accumulated[idx] for idx in sorted(tool_calls_accumulated.keys()) - ) + {"id": v["id"], "type": "function", "function": v["function"]} + for v in (accumulated[k] for k in sorted(accumulated)) ] - def _create_completion_request( - self, - messages: List[dict], - tools: List[dict], - max_tokens: int, - temperature: float, - top_p: float, - stream: bool = False, - ): - """Create a completion request with common parameters.""" + def _create_stream_request(self, messages, tools, max_tokens, temperature, top_p): + """Create streaming chat completion request.""" kwargs = { "model": self.model, "messages": messages, "max_completion_tokens": max_tokens, "temperature": temperature, "top_p": top_p, + "stream": True, + "stream_options": {"include_usage": True}, } if tools: kwargs["tools"] = tools kwargs["tool_choice"] = "auto" - if stream: - kwargs["stream"] = True - kwargs["stream_options"] = {"include_usage": True} return self.client.chat.completions.create(**kwargs) - def _bridge_async_generator(self, async_gen_func): - """Bridge an async generator to a sync generator using efficient queue-based approach.""" - pool = self._get_mcp_pool() - loop = pool._get_event_loop() - - # Use a bounded queue to apply backpressure - queue: asyncio.Queue = 
asyncio.Queue(maxsize=100) - done = threading.Event() - exception_holder: List[Optional[Exception]] = [None] + def _bridge_async_gen(self, async_gen_fn): + """Bridge async generator to sync generator.""" + pool = self.get_pool() + loop = pool._loop + queue = asyncio.Queue() + error_holder = [None] async def producer(): try: - async for item in async_gen_func(): + async for item in async_gen_fn(): await queue.put(item) except Exception as e: - exception_holder[0] = e + error_holder[0] = e finally: - await queue.put(None) # Sentinel - done.set() + await queue.put(None) - # Start producer asyncio.run_coroutine_threadsafe(producer(), loop) - try: - while True: - # Get from queue with timeout - try: - get_future = asyncio.run_coroutine_threadsafe(queue.get(), loop) - item = get_future.result(timeout=60.0) - except Exception: - if exception_holder[0]: - raise exception_holder[0] - raise - - if item is None: - break - yield item + while True: + future = asyncio.run_coroutine_threadsafe(queue.get(), loop) + item = future.result(timeout=120.0) + if item is None: + if error_holder[0]: + raise error_holder[0] + break + yield item - if exception_holder[0]: - raise exception_holder[0] - finally: - done.wait(timeout=1.0) + # === Streaming with MCP === - async def _stream_with_mcp_tools_json( - self, - openai_messages: List[dict], - tools: List[dict], - connections: Dict[str, MCPConnection], - max_tokens: int, - temperature: float, - top_p: float, - tool_to_server: Dict[str, str], + async def _stream_chat_with_tools( + self, messages, tools, connections, tool_to_server, max_tokens, temperature, top_p ): - """Async generator for streaming chat completions with MCP tools.""" - tool_calls_accumulated = {} - streaming_response = "" - - stream = self._create_completion_request( - openai_messages, tools, max_tokens, temperature, top_p, stream=True - ) + """Stream chat completions with MCP tool support.""" + accumulated_tools = {} + assistant_content = "" - for chunk in stream: + for chunk in self._create_stream_request(messages, tools, max_tokens, temperature, top_p): self._set_usage(chunk) yield chunk.model_dump_json() if chunk.choices: delta = chunk.choices[0].delta if delta.tool_calls: - for tool_call_delta in delta.tool_calls: - self._accumulate_tool_call_delta(tool_call_delta, tool_calls_accumulated) + for tc in delta.tool_calls: + self._accumulate_tool_delta(tc, accumulated_tools) if delta.content: - streaming_response += delta.content + assistant_content += delta.content - if tool_calls_accumulated: - tool_calls_list = self._convert_accumulated_tool_calls(tool_calls_accumulated) - openai_messages.append( + if accumulated_tools: + tool_calls = self._finalize_tool_calls(accumulated_tools) + messages.append( { "role": "assistant", - "content": streaming_response or None, - "tool_calls": tool_calls_list, + "content": assistant_content or None, + "tool_calls": tool_calls, } ) + await self._execute_chat_tools_async(tool_calls, connections, messages, tool_to_server) - # Execute tools in parallel - await self._execute_tool_calls_async( - tool_calls_list, connections, openai_messages, tool_to_server - ) - - async for chunk_json in self._stream_with_mcp_tools_json( - openai_messages, tools, connections, max_tokens, temperature, top_p, tool_to_server + async for chunk in self._stream_chat_with_tools( + messages, tools, connections, tool_to_server, max_tokens, temperature, top_p ): - yield chunk_json + yield chunk - async def _stream_responses_with_mcp_tools_json( - self, - request_data: Dict[str, Any], - tools: 
List[dict], - connections: Dict[str, MCPConnection], - tool_to_server: Dict[str, str], - ): - """Async generator for streaming response API with MCP tools.""" + async def _stream_responses_with_tools(self, request_data, tools, connections, tool_to_server): + """Stream response API with MCP tool support.""" input_data = request_data.get("input", "") - if isinstance(input_data, str): - input_items = [ + input_items = ( + [ { "type": "message", "role": "user", "content": [{"type": "input_text", "text": input_data}], } ] - else: - input_items = input_data if isinstance(input_data, list) else [] + if isinstance(input_data, str) + else (input_data if isinstance(input_data, list) else []) + ) response_args = {**request_data, "model": self.model} if tools: - response_api_tools = self._convert_tools_to_response_api_format(tools) - response_args["tools"] = response_api_tools + response_args["tools"] = self._to_response_api_tools(tools) response_args["tool_choice"] = response_args.get("tool_choice", "auto") - stream = self.client.responses.create(**response_args) accumulated_output = [] - tool_calls_accumulated = {} - original_to_filtered_index_map = {} + tool_calls_by_id = {} + msg_index_map = {} - for chunk in stream: + for chunk in self.client.responses.create(**response_args): self._set_usage(chunk) - chunk_type = getattr(chunk, 'type', None) or chunk.__class__.__name__ - - should_yield = True - item_to_check = None + chunk_type = getattr(chunk, 'type', '') or chunk.__class__.__name__ - # Process chunk based on type (condensed for brevity - same logic as before) + # Track message indices for filtering if chunk_type in ( 'response.output_item.added', 'ResponseOutputItemAddedEvent', ) and hasattr(chunk, 'item'): - item_to_check = chunk.item - if hasattr(chunk, 'output_index'): - item_dict = self._to_dict(item_to_check) - if item_dict.get("type") == "message": - original_to_filtered_index_map[chunk.output_index] = len( - original_to_filtered_index_map - ) + item_dict = self._to_dict(chunk.item) + if item_dict.get('type') == 'message' and hasattr(chunk, 'output_index'): + msg_index_map[chunk.output_index] = len(msg_index_map) + elif item_dict.get('type') in ( + 'function_tool_call', + 'function_call', + 'function', + 'tool_call', + ): + item_id = item_dict.get('id') or item_dict.get('call_id') + if item_id: + tool_calls_by_id[item_id] = { + 'id': item_id, + 'call_id': item_dict.get('call_id'), + 'type': item_dict.get('type'), + 'name': item_dict.get('name', ''), + 'arguments': item_dict.get('arguments', ''), + 'status': 'in_progress', + } - # Track tool calls - if item_dict.get("type") in [ - "function_tool_call", - "function_call", - "function", - "tool_call", - ]: - item_id = item_dict.get("id") or item_dict.get("call_id") - if item_id: - tool_calls_accumulated[item_id] = { - "id": item_id, - "call_id": item_dict.get("call_id"), - "type": item_dict.get("type"), - "name": item_dict.get("name", ""), - "arguments": item_dict.get("arguments", ""), - "status": item_dict.get("status", "in_progress"), - } + # Accumulate arguments + elif chunk_type in ( + 'response.function_call_arguments.delta', + 'ResponseFunctionCallArgumentsDeltaEvent', + ): + item_id = getattr(chunk, 'item_id', None) + if item_id and item_id in tool_calls_by_id: + tool_calls_by_id[item_id]['arguments'] += getattr(chunk, 'delta', '') + + elif chunk_type in ( + 'response.function_call_arguments.done', + 'ResponseFunctionCallArgumentsDoneEvent', + ): + item_id = getattr(chunk, 'item_id', None) + if item_id and item_id in 
tool_calls_by_id: + tool_calls_by_id[item_id]['arguments'] = getattr(chunk, 'arguments', '') + # Mark tool call complete elif chunk_type in ( 'response.output_item.done', 'ResponseOutputItemDoneEvent', ) and hasattr(chunk, 'item'): - item_to_check = chunk.item - - elif hasattr(chunk, 'output_index'): - if chunk.output_index not in original_to_filtered_index_map: - should_yield = False - - # Filter non-message items - if item_to_check: - item_dict = self._to_dict(item_to_check) - if item_dict.get("type") != "message": - should_yield = False + item_dict = self._to_dict(chunk.item) + item_type = item_dict.get('type') + if item_type in ('function_tool_call', 'function_call', 'function', 'tool_call'): + item_id = item_dict.get('id') + if item_id and item_id in tool_calls_by_id: + tool_calls_by_id[item_id]['status'] = 'completed' + if 'call_id' in item_dict: + tool_calls_by_id[item_id]['call_id'] = item_dict['call_id'] + accumulated_output.append(tool_calls_by_id[item_id]) + else: + accumulated_output.append(item_dict) - # Handle response.completed - if chunk_type in ('response.completed', 'ResponseCompletedEvent') and hasattr( + # Handle completed response - filter to messages only + elif chunk_type in ('response.completed', 'ResponseCompletedEvent') and hasattr( chunk, 'response' ): - response = chunk.response - if hasattr(response, 'output') and response.output: - filtered_output, accumulated_output = self._process_response_output( - response.output, accumulated_output, tool_calls_accumulated - ) - - response_dict = self._to_dict(response) - response_dict["output"] = filtered_output - + resp = chunk.response + if hasattr(resp, 'output') and resp.output: + filtered = [ + self._to_dict(i) + for i in resp.output + if self._to_dict(i).get('type') == 'message' + ] + resp_dict = self._to_dict(resp) + resp_dict['output'] = filtered yield json.dumps( { - "type": "response.completed", - "sequence_number": getattr(chunk, 'sequence_number', None), - "response": response_dict, + 'type': 'response.completed', + 'sequence_number': getattr(chunk, 'sequence_number', None), + 'response': resp_dict, } ) + continue + + # Yield message-related chunks with remapped indices + should_yield = True + if hasattr(chunk, 'output_index'): + if chunk.output_index not in msg_index_map: + should_yield = False else: - yield chunk.model_dump_json() - elif should_yield: - if ( - hasattr(chunk, 'output_index') - and chunk.output_index in original_to_filtered_index_map - ): chunk_dict = self._to_dict(chunk) - chunk_dict["output_index"] = original_to_filtered_index_map[chunk.output_index] + chunk_dict['output_index'] = msg_index_map[chunk.output_index] yield json.dumps(chunk_dict) - elif not hasattr(chunk, 'output_index'): - yield chunk.model_dump_json() + continue - # Handle argument deltas - if chunk_type in ( + if should_yield and chunk_type not in ( 'response.function_call_arguments.delta', 'ResponseFunctionCallArgumentsDeltaEvent', - ): - item_id = getattr(chunk, 'item_id', None) - if item_id and item_id in tool_calls_accumulated: - tool_calls_accumulated[item_id]["arguments"] += getattr(chunk, 'delta', '') - - elif chunk_type in ( 'response.function_call_arguments.done', 'ResponseFunctionCallArgumentsDoneEvent', ): - item_id = getattr(chunk, 'item_id', None) - if item_id and item_id in tool_calls_accumulated: - tool_calls_accumulated[item_id]["arguments"] = getattr(chunk, 'arguments', '') - - # Handle output item done - elif chunk_type in ( - 'response.output_item.done', - 'ResponseOutputItemDoneEvent', - ) and 
hasattr(chunk, 'item'): - item_dict = self._to_dict(chunk.item) - item_type = item_dict.get("type") - - if item_type in ["function_tool_call", "function_call", "function", "tool_call"]: - item_id = item_dict.get("id") - if item_id and item_id in tool_calls_accumulated: - tool_calls_accumulated[item_id]["status"] = item_dict.get( - "status", "completed" - ) - if "call_id" in item_dict: - tool_calls_accumulated[item_id]["call_id"] = item_dict.get("call_id") - accumulated_output.append(tool_calls_accumulated[item_id]) - else: - accumulated_output.append(item_dict) + item = getattr(chunk, 'item', None) + if item: + if self._to_dict(item).get('type') not in ( + 'function_tool_call', + 'function_call', + 'function', + 'tool_call', + ): + yield chunk.model_dump_json() else: - accumulated_output.append(item_dict) + yield chunk.model_dump_json() - # Add remaining accumulated tool calls - for call_id, call_data in tool_calls_accumulated.items(): - if call_data.get("name"): - existing_ids = {self._get_id(i) for i in accumulated_output} - if call_id not in existing_ids: - accumulated_output.append(call_data) + # Add any remaining tool calls + for tc in tool_calls_by_id.values(): + if tc.get('name') and tc['id'] not in { + self._to_dict(o).get('id') for o in accumulated_output + }: + accumulated_output.append(tc) # Execute tool calls if any - tool_calls = self._extract_tool_calls_from_response_output(accumulated_output) + tool_calls = self._parse_response_tool_calls(accumulated_output) if tool_calls: - model_output_items = self._convert_output_items_to_input_items(accumulated_output) - input_items.extend(model_output_items) - - # Execute tools in parallel - await self._execute_response_api_tool_calls_async( + input_items.extend(self._convert_output_to_input(accumulated_output)) + await self._execute_response_tools_async( tool_calls, connections, input_items, tool_to_server ) + request_data['input'] = input_items - request_data["input"] = input_items - - async for chunk_json in self._stream_responses_with_mcp_tools_json( + async for chunk in self._stream_responses_with_tools( request_data, tools, connections, tool_to_server ): - yield chunk_json + yield chunk - def _to_dict(self, obj: Any) -> dict: - """Convert object to dictionary.""" - if isinstance(obj, dict): - return obj - if hasattr(obj, 'model_dump'): - return obj.model_dump() - if hasattr(obj, 'dict'): - return obj.dict() - if hasattr(obj, '__dict__'): - return obj.__dict__ - return {} - - def _get_id(self, item: Any) -> Optional[str]: - """Get ID from an item.""" - if isinstance(item, dict): - return item.get("id") - return getattr(item, "id", None) - - def _process_response_output( - self, - output: List[Any], - accumulated_output: List[Dict], - tool_calls_accumulated: Dict[str, Dict], - ) -> Tuple[List[Dict], List[Dict]]: - """Process response output, filtering messages and accumulating tool calls.""" - filtered_output = [] - - for item in output: - item_dict = self._to_dict(item) - item_type = item_dict.get("type") - - if item_type == "message": - filtered_output.append(item_dict) - elif item_type in ["function_tool_call", "function_call", "function", "tool_call"]: - item_id = item_dict.get("id") - existing_ids = {self._get_id(i) for i in accumulated_output} - if not item_id or item_id not in existing_ids: - accumulated_output.append(item_dict) - else: - item_id = item_dict.get("id") - existing_ids = {self._get_id(i) for i in accumulated_output} - if not item_id or item_id not in existing_ids: - accumulated_output.append(item_dict) - - 
return filtered_output, accumulated_output + # === Main OpenAI Methods === @ModelClass.method def openai_transport(self, msg: str) -> str: - """Process an OpenAI-compatible request.""" + """Handle non-streaming OpenAI requests.""" try: - request_data = from_json(msg) - request_data = self._update_old_fields(request_data) - mcp_servers = request_data.pop("mcp_servers", None) - endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) - tools = request_data.get("tools") - - if mcp_servers and len(mcp_servers) > 0 and tools is None: - logger.info(f"Getting tools for MCP servers: {mcp_servers}") - tools_local, connections, tool_to_server = self._get_mcp_tools_and_clients( - mcp_servers - ) + data = from_json(msg) + data = self._update_old_fields(data) + mcp_servers = data.pop("mcp_servers", None) + endpoint = data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) + + if mcp_servers and data.get("tools") is None: + pool = self.get_pool() + tools, connections, tool_to_server = pool.get_tools_and_mapping(mcp_servers) if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: - response = self._route_request( - endpoint, request_data, mcp_servers, connections, tools_local - ) + response = self._route_request(endpoint, data, mcp_servers, connections, tools) while response.choices and response.choices[0].message.tool_calls: - messages = request_data.get("messages", []) + messages = data.get("messages", []) messages.append(response.choices[0].message) - self._execute_tool_calls( + self._execute_chat_tools( response.choices[0].message.tool_calls, connections, messages, tool_to_server, ) - request_data["messages"] = messages + data["messages"] = messages response = self._route_request( - endpoint, request_data, mcp_servers, connections, tools_local + endpoint, data, mcp_servers, connections, tools ) elif endpoint == self.ENDPOINT_RESPONSES: - response = self._route_request( - endpoint, request_data, mcp_servers, connections, tools_local - ) + response = self._route_request(endpoint, data, mcp_servers, connections, tools) - input_data = request_data.get("input", "") - if isinstance(input_data, str): - input_items = [ + input_data = data.get("input", "") + input_items = ( + [ { "type": "message", "role": "user", "content": [{"type": "input_text", "text": input_data}], } ] - else: - input_items = input_data if isinstance(input_data, list) else [] + if isinstance(input_data, str) + else (input_data if isinstance(input_data, list) else []) + ) - response_output = response.output if hasattr(response, 'output') else [] - tool_calls = self._extract_tool_calls_from_response_output(response_output) + output = response.output if hasattr(response, 'output') else [] + tool_calls = self._parse_response_tool_calls(output) while tool_calls: - model_output_items = self._convert_output_items_to_input_items( - response_output - ) - input_items.extend(model_output_items) - self._execute_response_api_tool_calls( + input_items.extend(self._convert_output_to_input(output)) + self._execute_response_tools( tool_calls, connections, input_items, tool_to_server ) - request_data["input"] = input_items + data["input"] = input_items response = self._route_request( - endpoint, request_data, mcp_servers, connections, tools_local + endpoint, data, mcp_servers, connections, tools ) - response_output = response.output if hasattr(response, 'output') else [] - tool_calls = self._extract_tool_calls_from_response_output(response_output) + output = response.output if hasattr(response, 'output') else [] + tool_calls = 
self._parse_response_tool_calls(output) else: - response = self._route_request(endpoint, request_data) + response = self._route_request(endpoint, data) else: - response = self._route_request(endpoint, request_data) + response = self._route_request(endpoint, data) - self._finalize_token_usage() + self._finalize_tokens() return response.model_dump_json() + except Exception as e: logger.exception(e) return to_json( @@ -1438,59 +965,55 @@ def openai_transport(self, msg: str) -> str: @ModelClass.method def openai_stream_transport(self, msg: str) -> Iterator[str]: - """Process an OpenAI-compatible request with streaming.""" + """Handle streaming OpenAI requests.""" try: - request_data = from_json(msg) - request_data = self._update_old_fields(request_data) - mcp_servers = request_data.pop("mcp_servers", None) - endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) - - if endpoint not in [self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES]: - raise ValueError("Streaming only supported for chat completions and responses.") - - if mcp_servers and len(mcp_servers) > 0 and request_data.get("tools") is None: - logger.info(f"Getting tools for MCP servers: {mcp_servers}") - tools_local, connections, tool_to_server = self._get_mcp_tools_and_clients( - mcp_servers - ) + data = from_json(msg) + data = self._update_old_fields(data) + mcp_servers = data.pop("mcp_servers", None) + endpoint = data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) - async def stream_generator(): - if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: - messages = request_data.get("messages", []) - async for chunk_json in self._stream_with_mcp_tools_json( - messages, - tools_local, + if endpoint not in (self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES): + raise ValueError("Streaming only for chat completions and responses") + + if mcp_servers and data.get("tools") is None: + pool = self.get_pool() + tools, connections, tool_to_server = pool.get_tools_and_mapping(mcp_servers) + + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + yield from self._bridge_async_gen( + lambda: self._stream_chat_with_tools( + data.get("messages", []), + tools, connections, - request_data.get("max_completion_tokens", 4096), - request_data.get("temperature", 1.0), - request_data.get("top_p", 1.0), tool_to_server, - ): - yield chunk_json - self._finalize_token_usage() - elif endpoint == self.ENDPOINT_RESPONSES: - async for chunk_json in self._stream_responses_with_mcp_tools_json( - request_data, tools_local, connections, tool_to_server - ): - yield chunk_json - self._finalize_token_usage() - - yield from self._bridge_async_generator(stream_generator) + data.get("max_completion_tokens", 4096), + data.get("temperature", 1.0), + data.get("top_p", 1.0), + ) + ) + else: + yield from self._bridge_async_gen( + lambda: self._stream_responses_with_tools( + data, tools, connections, tool_to_server + ) + ) + + self._finalize_tokens() return - # Non-MCP path + # Non-MCP streaming if endpoint == self.ENDPOINT_RESPONSES: - response_args = {**request_data, "model": self.model} - for chunk in self.client.responses.create(**response_args): + for chunk in self.client.responses.create(**{**data, "model": self.model}): self._set_usage(chunk) yield chunk.model_dump_json() else: - completion_args = self._create_completion_args(request_data) - for chunk in self.client.chat.completions.create(**completion_args): + for chunk in self.client.chat.completions.create( + **self._create_completion_args(data) + ): self._set_usage(chunk) yield chunk.model_dump_json() - 
self._finalize_token_usage() + self._finalize_tokens() except Exception as e: logger.exception(e) From a21787b7619e6c53916a2c6209c921b165d04521 Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Tue, 9 Dec 2025 18:35:38 +0530 Subject: [PATCH 05/13] simplify code --- clarifai/runners/models/agentic_class.py | 150 +++++++++++++---------- 1 file changed, 83 insertions(+), 67 deletions(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index a6d9b06d..19f14509 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -18,7 +18,7 @@ @dataclass class MCPConnection: - """Single MCP server connection with metadata.""" + """Single MCP server connection.""" client: Any tools: List[Any] @@ -27,21 +27,21 @@ class MCPConnection: last_used: float = field(default_factory=time.time) def touch(self): - """Mark as recently used.""" self.last_used = time.time() class MCPConnectionPool: - """Thread-safe connection pool for MCP servers.""" + """Thread-safe connection pool with passive idle cleanup.""" _instance: Optional['MCPConnectionPool'] = None _instance_lock = threading.Lock() - # Timeouts + # Timeouts and thresholds CONNECT_TIMEOUT = 30.0 TOOL_CALL_TIMEOUT = 60.0 - VERIFY_IDLE_THRESHOLD = 60.0 # Verify if idle > 60s - MAX_IDLE_TIME = 600.0 # Remove if idle > 10min + VERIFY_IDLE_THRESHOLD = 60.0 # Verify connections idle > 60s + MAX_IDLE_TIME = 600.0 # Remove connections idle > 10min + CLEANUP_INTERVAL = 120.0 # Check cleanup at most every 2min def __new__(cls): if cls._instance is None: @@ -58,9 +58,12 @@ def __init__(self): self._connections: Dict[str, MCPConnection] = {} self._lock = threading.RLock() - # Tool caches for O(1) lookup + # Tool caches self._tool_to_url: Dict[str, str] = {} - self._all_tools: Dict[str, dict] = {} # tool_name -> OpenAI format + self._all_tools: Dict[str, dict] = {} + + # Cleanup tracking + self._last_cleanup = 0.0 # Background event loop self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -70,7 +73,7 @@ def __init__(self): self._initialized = True def _start_event_loop(self): - """Start background event loop for async operations.""" + """Start background event loop.""" ready = threading.Event() def run(): @@ -83,13 +86,71 @@ def run(): self._loop_thread.start() ready.wait(timeout=5.0) - def _run_async(self, coro, timeout: float = CONNECT_TIMEOUT) -> Any: + def _run_async(self, coro, timeout: float = 30.0) -> Any: """Run coroutine in background loop.""" if self._loop is None or self._loop.is_closed(): self._start_event_loop() future = asyncio.run_coroutine_threadsafe(coro, self._loop) return future.result(timeout=timeout) + # ==================== Cleanup Logic ==================== + + def _maybe_cleanup_idle(self): + """Passive cleanup - removes connections idle too long. + + Called at the start of get_connections() to clean up + without needing a background thread. 
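+
+        Illustrative timeline (thresholds are the class defaults, shown
+        here only as an example):
+
+            pool.get_connections(urls)   # t=0: connections created and cached
+            # ... no requests for longer than MAX_IDLE_TIME ...
+            pool.get_connections(urls)   # stale entries are closed first,
+                                         # then fresh connections are made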
+ """ + now = time.time() + + # Rate limit cleanup checks + if now - self._last_cleanup < self.CLEANUP_INTERVAL: + return + + self._last_cleanup = now + + # Find idle connections + with self._lock: + to_remove = [ + url + for url, conn in self._connections.items() + if now - conn.last_used > self.MAX_IDLE_TIME + ] + + # Remove them (outside lock to avoid deadlock during async close) + for url in to_remove: + self._disconnect(url) + + def _disconnect(self, url: str): + """Disconnect and remove a connection.""" + with self._lock: + conn = self._connections.pop(url, None) + + # Invalidate tool cache entries for this URL + if conn: + for tool_name in conn.tool_names: + self._tool_to_url.pop(tool_name, None) + self._all_tools.pop(tool_name, None) + + if conn: + try: + self._run_async(self._close_connection(conn), timeout=10.0) + logger.info(f"Disconnected idle connection from {url}") + except Exception as e: + logger.warning(f"Error disconnecting from {url}: {e}") + + async def _close_connection(self, conn: MCPConnection): + """Close a connection gracefully.""" + try: + if hasattr(conn.client, 'close'): + await asyncio.wait_for(conn.client.close(), timeout=5.0) + else: + await asyncio.wait_for(conn.client.__aexit__(None, None, None), timeout=5.0) + except Exception as e: + logger.warning(f"Error closing connection to {conn.url}: {e}") + + # ==================== Connection Management ==================== + async def _create_connection(self, url: str) -> MCPConnection: """Create new MCP connection.""" try: @@ -122,22 +183,12 @@ async def _verify_connection(self, conn: MCPConnection) -> bool: except Exception: return False - async def _close_connection(self, conn: MCPConnection): - """Close a connection gracefully.""" - try: - if hasattr(conn.client, 'close'): - await asyncio.wait_for(conn.client.close(), timeout=5.0) - else: - await asyncio.wait_for(conn.client.__aexit__(None, None, None), timeout=5.0) - except Exception as e: - logger.warning(f"Error closing connection to {conn.url}: {e}") - def _needs_verification(self, conn: MCPConnection) -> bool: - """Check if connection needs verification based on idle time.""" + """Check if connection should be verified.""" return time.time() - conn.last_used > self.VERIFY_IDLE_THRESHOLD def _update_tool_cache(self, conn: MCPConnection): - """Update tool caches from connection.""" + """Cache tool info from connection.""" for tool in conn.tools: self._tool_to_url[tool.name] = conn.url self._all_tools[tool.name] = { @@ -150,18 +201,15 @@ def _update_tool_cache(self, conn: MCPConnection): } def get_connections(self, urls: List[str]) -> Dict[str, MCPConnection]: - """Get connections for URLs, creating/verifying as needed. + """Get connections for URLs, with passive cleanup.""" + # Passive cleanup of idle connections + self._maybe_cleanup_idle() - This is the main entry point. It: - 1. Returns cached connections if recently used - 2. Verifies cached connections if idle - 3. Creates new connections if needed - All operations happen in parallel where possible. 
- """ result = {} to_verify = [] to_create = [] + # Categorize URLs with self._lock: for url in urls: if url in self._connections: @@ -195,7 +243,7 @@ async def verify_all(): conn.touch() result[url] = conn else: - # Remove invalid connection, will recreate + # Invalid - remove and recreate self._connections.pop(url, None) to_create.append(url) except Exception as e: @@ -229,7 +277,7 @@ async def create_all(): def get_tools_and_mapping( self, urls: List[str] ) -> Tuple[List[dict], Dict[str, MCPConnection], Dict[str, str]]: - """Get tools, connections, and tool-to-server mapping.""" + """Get tools, connections, and mapping.""" connections = self.get_connections(urls) tools = [] @@ -256,6 +304,8 @@ def get_tools_and_mapping( logger.info(f"Loaded {len(tools)} tools from {len(connections)} servers") return tools, connections, tool_to_server + # ==================== Tool Execution ==================== + async def call_tool_async( self, tool_name: str, @@ -264,6 +314,7 @@ async def call_tool_async( tool_to_server: Dict[str, str], ) -> Any: """Call a tool asynchronously.""" + logger.info(f"Calling tool {tool_name}") url = tool_to_server.get(tool_name) or self._tool_to_url.get(tool_name) if not url or url not in connections: raise ValueError(f"Tool '{tool_name}' not found") @@ -302,7 +353,6 @@ async def call_one(call_id: str, name: str, args: Dict): result = await self.call_tool_async(name, args, connections, tool_to_server) return (call_id, result, None) except Exception as e: - logger.error(f"Error calling tool {name}: {e}") return (call_id, None, str(e)) tasks = [call_one(cid, name, args) for cid, name, args in calls] @@ -314,41 +364,12 @@ def call_tools_batch( connections: Dict[str, MCPConnection], tool_to_server: Dict[str, str], ) -> List[Tuple[str, Optional[Any], Optional[str]]]: - """Call multiple tools in parallel (sync wrapper).""" + """Call multiple tools in parallel (sync).""" return self._run_async( self.call_tools_batch_async(calls, connections, tool_to_server), timeout=self.TOOL_CALL_TIMEOUT + 10, ) - def cleanup_idle(self): - """Remove connections idle for too long.""" - now = time.time() - with self._lock: - to_remove = [ - url - for url, conn in self._connections.items() - if now - conn.last_used > self.MAX_IDLE_TIME - ] - for url in to_remove: - conn = self._connections.pop(url) - self._run_async(self._close_connection(conn), timeout=10.0) - logger.info(f"Removed idle connection to {url}") - - def shutdown(self): - """Close all connections and stop event loop.""" - with self._lock: - for conn in self._connections.values(): - try: - self._run_async(self._close_connection(conn), timeout=5.0) - except Exception: - pass - self._connections.clear() - self._tool_to_url.clear() - self._all_tools.clear() - - if self._loop and not self._loop.is_closed(): - self._loop.call_soon_threadsafe(self._loop.stop) - class AgenticModelClass(OpenAIModelClass): """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support. 
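As a quick orientation for the API this patch settles on, here is a minimal usage
sketch (the server URL and tool name are illustrative placeholders, not part of
the patch):

    from clarifai.runners.models.agentic_class import MCPConnectionPool

    pool = MCPConnectionPool()  # singleton: repeated construction returns the same pool
    tools, connections, tool_to_server = pool.get_tools_and_mapping(
        ["https://example.com/mcp"]
    )

    # call_tools_batch takes (call_id, tool_name, arguments) tuples and returns
    # (call_id, result, error) tuples, with exactly one of result/error populated.
    calls = [("call_1", "search", {"query": "clarifai"})]
    for call_id, result, error in pool.call_tools_batch(calls, connections, tool_to_server):
        print(call_id, error or result)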
@@ -382,11 +403,6 @@ def get_pool(cls) -> MCPConnectionPool:
             cls._pool = MCPConnectionPool()
         return cls._pool
 
-    @classmethod
-    def warm_up(cls, mcp_servers: List[str]):
-        """Pre-connect to MCP servers."""
-        cls.get_pool().get_connections(mcp_servers)
-
     # === Token Tracking ===
 
     def _init_tokens(self):

From d806288fbc08ad40f0b510ce2061aeba3a259aa8 Mon Sep 17 00:00:00 2001
From: Luv Bansal
Date: Mon, 15 Dec 2025 13:14:09 +0530
Subject: [PATCH 06/13] increase idle time and verify duration

---
 clarifai/runners/models/agentic_class.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py
index 19f14509..981364c2 100644
--- a/clarifai/runners/models/agentic_class.py
+++ b/clarifai/runners/models/agentic_class.py
@@ -39,9 +39,9 @@ class MCPConnectionPool:
     # Timeouts and thresholds
     CONNECT_TIMEOUT = 30.0
     TOOL_CALL_TIMEOUT = 60.0
-    VERIFY_IDLE_THRESHOLD = 60.0  # Verify connections idle > 60s
-    MAX_IDLE_TIME = 600.0  # Remove connections idle > 10min
-    CLEANUP_INTERVAL = 120.0  # Check cleanup at most every 2min
+    VERIFY_IDLE_THRESHOLD = 60 * 2  # Verify connections idle > 2min
+    MAX_IDLE_TIME = 60 * 15  # Remove connections idle > 15min
+    CLEANUP_INTERVAL = 60 * 2  # Check cleanup at most every 2min

From 91368128a13f2ca493e6f1e38d1713e698fe3427 Mon Sep 17 00:00:00 2001
From: Luv Bansal
Date: Mon, 15 Dec 2025 14:07:28 +0530
Subject: [PATCH 07/13] Fix co-pilot comments

---
 clarifai/runners/models/agentic_class.py | 119 ++++++++++++++++++-----
 1 file changed, 94 insertions(+), 25 deletions(-)

diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py
index 981364c2..eda4bb8f 100644
--- a/clarifai/runners/models/agentic_class.py
+++ b/clarifai/runners/models/agentic_class.py
@@ -36,12 +36,17 @@ class MCPConnectionPool:
     _instance: Optional['MCPConnectionPool'] = None
     _instance_lock = threading.Lock()
 
-    # Timeouts and thresholds
-    CONNECT_TIMEOUT = 30.0
-    TOOL_CALL_TIMEOUT = 60.0
-    VERIFY_IDLE_THRESHOLD = 60 * 2  # Verify connections idle > 2min
-    MAX_IDLE_TIME = 60 * 15  # Remove connections idle > 15min
-    CLEANUP_INTERVAL = 60 * 2  # Check cleanup at most every 2min
+    # Timeouts and thresholds (configurable via environment variables)
+    # Default: 30s. Time to wait for a connection to be established. Increase if MCP servers are slow to respond.
+    CONNECT_TIMEOUT = float(os.environ.get("CLARIFAI_MCP_CONNECT_TIMEOUT", 30.0))
+    # Default: 60s. Maximum time to wait for a tool call to complete. Increase for long-running tools.
+    TOOL_CALL_TIMEOUT = float(os.environ.get("CLARIFAI_MCP_TOOL_CALL_TIMEOUT", 60.0))
+    # Default: 2min. Connections idle for more than this are verified before reuse.
+    VERIFY_IDLE_THRESHOLD = float(os.environ.get("CLARIFAI_MCP_VERIFY_IDLE_THRESHOLD", 60 * 2))
+    # Default: 15min. Connections idle for more than this are removed from the pool.
+    MAX_IDLE_TIME = float(os.environ.get("CLARIFAI_MCP_MAX_IDLE_TIME", 15 * 60))
+    # Default: 2min. Cleanup runs at most this often to remove idle connections.
+    CLEANUP_INTERVAL = float(os.environ.get("CLARIFAI_MCP_CLEANUP_INTERVAL", 2 * 60))
 
     def __new__(cls):
         if cls._instance is None:
@@ -84,7 +89,8 @@ def run():
 
         self._loop_thread = threading.Thread(target=run, daemon=True, name="mcp_pool")
         self._loop_thread.start()
-        ready.wait(timeout=5.0)
+        if not ready.wait(timeout=5.0):
+            raise RuntimeError("Background event loop failed to start within 5 seconds")
 
     def _run_async(self, coro, timeout: float = 30.0) -> Any:
         """Run coroutine in background loop."""
@@ -515,10 +521,16 @@ def _execute_chat_tools(
         for call_id, result, error in results:
             if error:
                 content = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and getattr(result.content[0], 'text', None)
+            ):
+                content = result.content[0].text
+            elif isinstance(result, list) and result and getattr(result[0], 'text', None):
+                content = result[0].text
             else:
-                content = (
-                    result.content[0].text if hasattr(result, 'content') else str(result[0].text)
-                )
+                content = None
             messages.append({"role": "tool", "tool_call_id": call_id, "content": content})
 
     async def _execute_chat_tools_async(
         for call_id, result, error in results:
             if error:
                 content = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and getattr(result.content[0], 'text', None)
+            ):
+                content = result.content[0].text
+            elif isinstance(result, list) and result and getattr(result[0], 'text', None):
+                content = result[0].text
             else:
-                content = (
-                    result.content[0].text if hasattr(result, 'content') else str(result[0].text)
-                )
+                content = None
             messages.append({"role": "tool", "tool_call_id": call_id, "content": content})
 
     def _execute_response_tools(
         for call_id, result, error in results:
             if error:
                 output = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and getattr(result.content[0], 'text', None)
+            ):
+                output = result.content[0].text
+            elif isinstance(result, list) and result and getattr(result[0], 'text', None):
+                output = result[0].text
             else:
-                output = (
-                    result.content[0].text if hasattr(result, 'content') else str(result[0].text)
-                )
+                output = None
             input_items.append(
                 {"type": "function_call_output", "call_id": call_id, "output": output}
             )
 
    async def _execute_response_tools_async(
         for call_id, result, error in results:
             if error:
                 output = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and getattr(result.content[0], 'text', None)
+            ):
+                output = result.content[0].text
+            elif isinstance(result, list) and result and getattr(result[0], 'text', None):
+                output = result[0].text
             else:
-                output = (
-                    result.content[0].text if hasattr(result, 'content') else str(result[0].text)
-                )
+                output = None
             input_items.append(
                 {"type": "function_call_output", "call_id": call_id, "output": output}
             )
@@ -676,7 +706,7 @@ def _create_stream_request(self, messages, tools, max_tokens, temperature, top_p
         kwargs["tool_choice"] = "auto"
         return self.client.chat.completions.create(**kwargs)
 
-    def _bridge_async_gen(self, async_gen_fn):
+    def _async_to_sync_generator(self, async_gen_fn):
         """Bridge async generator to sync generator."""
         pool = self.get_pool()
         loop = pool._loop
@@ -708,7 +738,23 @@ async def producer():
 
     async def _stream_chat_with_tools(
         self, messages, tools, connections, tool_to_server, max_tokens, temperature, top_p
     ):
-        """Stream chat completions with MCP tool support."""
+        """
+        Stream chat completions with MCP tool support, recursively handling tool calls.
+ This method streams chat completion chunks, accumulating any tool calls generated by the model. + If tool calls are present after streaming, it executes those tools and recursively continues + streaming with the updated messages (including tool call results). The recursion terminates + when no further tool calls are generated in the streamed response. + Args: + messages: The list of chat messages so far. + tools: The list of available tools. + connections: MCP tool connections. + tool_to_server: Mapping of tool names to server URLs. + max_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling parameter. + Yields: + JSON-serialized chat completion chunks. + """ accumulated_tools = {} assistant_content = "" @@ -741,7 +787,26 @@ async def _stream_chat_with_tools( yield chunk async def _stream_responses_with_tools(self, request_data, tools, connections, tool_to_server): - """Stream response API with MCP tool support.""" + """ + Streams responses for the API with MCP (Model Context Protocol) tool support. + This method processes the incoming request data, which may include user messages or input items, + and streams back responses that may involve multiple event types, such as user messages, assistant + responses, and tool call events. It supports both single string input and a list of message objects. + Event Handling Flow: + - Parses the input data into a standardized list of message items. + - Prepares response arguments, including tool definitions and tool choice if tools are provided. + - Accumulates output chunks as they are generated, yielding each chunk as a JSON-encoded string. + - Handles tool call events by invoking the appropriate tools via MCP connections, and streams + the results back as part of the response. + - May recursively invoke itself to handle follow-up tool calls or multi-turn interactions. + Args: + request_data (dict): The incoming request payload, including input messages. + tools (list): List of tool definitions available for invocation. + connections (dict): MCP connections for tool execution. + tool_to_server (dict): Mapping from tool names to server endpoints. + Yields: + str: JSON-encoded response chunks, streamed as they become available. 
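+
+        Example (illustrative only; the exact event payloads depend on the
+        backing model, and tools/connections/tool_to_server come from the pool):
+
+            request_data = {"input": "What's the weather in Paris?"}
+            async for chunk in self._stream_responses_with_tools(
+                request_data, tools, connections, tool_to_server
+            ):
+                event = json.loads(chunk)  # each chunk is one JSON-encoded event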
+        """
+        input_data = request_data.get("input", "")
+        input_items = (
+            [
@@ -917,7 +982,11 @@ def openai_transport(self, msg: str) -> str:
         if endpoint == self.ENDPOINT_CHAT_COMPLETIONS:
             response = self._route_request(endpoint, data, mcp_servers, connections, tools)
 
-            while response.choices and response.choices[0].message.tool_calls:
+            while (
+                response.choices
+                and len(response.choices) > 0
+                and getattr(response.choices[0].message, 'tool_calls', None)
+            ):
                 messages = data.get("messages", [])
                 messages.append(response.choices[0].message)
                 self._execute_chat_tools(
@@ -996,7 +1065,7 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]:
             tools, connections, tool_to_server = pool.get_tools_and_mapping(mcp_servers)
 
             if endpoint == self.ENDPOINT_CHAT_COMPLETIONS:
-                yield from self._bridge_async_gen(
+                yield from self._async_to_sync_generator(
                     lambda: self._stream_chat_with_tools(
                         data.get("messages", []),
                         tools,
                         connections,
                         tool_to_server,
                         max_tokens,
                         temperature,
                         top_p,
                     )
                 )
             else:
-                yield from self._bridge_async_gen(
+                yield from self._async_to_sync_generator(
                     lambda: self._stream_responses_with_tools(
                         data, tools, connections, tool_to_server
                     )
                 )

From d977ec5e1f68bb4314b2cffa77776f2acacef314 Mon Sep 17 00:00:00 2001
From: Luv Bansal
Date: Mon, 15 Dec 2025 14:11:57 +0530
Subject: [PATCH 08/13] co-pilot comments

---
 clarifai/runners/models/agentic_class.py | 26 +++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py
index eda4bb8f..0ef02ad3 100644
--- a/clarifai/runners/models/agentic_class.py
+++ b/clarifai/runners/models/agentic_class.py
@@ -31,7 +31,31 @@
 
 class MCPConnectionPool:
-    """Thread-safe connection pool with passive idle cleanup."""
+    """
+    Singleton, thread-safe connection pool for managing MCP server connections.
+    Lifecycle:
+    - The pool is implemented as a singleton. The first instantiation creates the instance;
+      subsequent instantiations return the same object.
+    - Initialization sets up internal data structures, a background asyncio event loop,
+      and a dedicated thread for running asynchronous tasks.
+    - The event loop is started in a background daemon thread and is used to run async
+      operations (such as connecting and disconnecting).
+    Thread Safety:
+    - All access to shared state (connections, tool caches) is protected by a reentrant lock (`self._lock`).
+    - The singleton instance is protected by a class-level lock (`_instance_lock`) to ensure only one instance is created.
+    - The background event loop is started and accessed in a thread-safe manner.
+    Cleanup Behavior:
+    - Idle connections are cleaned up passively: whenever `get_connections()` is called, the pool checks for
+      connections that have been idle longer than `MAX_IDLE_TIME` and disconnects them.
+    - Cleanup is rate-limited by `CLEANUP_INTERVAL` to avoid excessive checks.
+    - Disconnection is performed asynchronously in the background event loop.
+    - Tool caches are invalidated when a connection is removed.
+    - There is no explicit shutdown; background threads and event loops are daemonized and will exit with the process.
+    Usage Notes:
+    - Users do not need to manage the pool directly; it is managed automatically as a singleton.
+    - Connections are created, reused, and cleaned up transparently.
+    - The pool is safe for concurrent use from multiple threads.
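+    Example (a minimal sketch; the URL is a placeholder, and a real server also
+    requires a valid CLARIFAI_PAT in the environment):
+        pool = MCPConnectionPool()  # always returns the shared singleton
+        tools, conns, mapping = pool.get_tools_and_mapping(["https://example.com/mcp"])
+        result = pool.call_tool("some_tool", {"arg": "value"}, conns, mapping)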
+ """ _instance: Optional['MCPConnectionPool'] = None _instance_lock = threading.Lock() From 2bd46970d323d4275a751338e6f2227590711b6e Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Mon, 15 Dec 2025 14:54:24 +0530 Subject: [PATCH 09/13] co-pilot comments --- clarifai/runners/models/agentic_class.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index 0ef02ad3..301e9351 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -502,11 +502,23 @@ def _parse_chat_tool_calls(self, tool_calls) -> List[Tuple[str, str, Dict]]: result = [] for tc in tool_calls: if hasattr(tc, 'function'): - result.append((tc.id, tc.function.name, json.loads(tc.function.arguments))) + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError: + logger.warning( + f"Malformed tool arguments for tool '{getattr(tc.function, 'name', None)}': {tc.function.arguments!r}" + ) + args = {} + result.append((tc.id, tc.function.name, args)) else: - result.append( - (tc['id'], tc['function']['name'], json.loads(tc['function']['arguments'])) - ) + try: + args = json.loads(tc['function']['arguments']) + except json.JSONDecodeError: + logger.warning( + f"Malformed tool arguments for tool '{tc['function'].get('name', None)}': {tc['function']['arguments']!r}" + ) + args = {} + result.append((tc['id'], tc['function']['name'], args)) return result def _parse_response_tool_calls(self, items: List[dict]) -> List[Tuple[str, str, Dict]]: From ff94a7d7a10a5f309dd0304da0db1a8ff2508af7 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:11:10 +0530 Subject: [PATCH 10/13] Add comprehensive test coverage for MCPConnectionPool connection lifecycle (#875) * Initial plan * Add comprehensive test coverage for MCPConnectionPool Co-authored-by: luv-bansal <70321430+luv-bansal@users.noreply.github.com> * Improve test robustness based on code review feedback Co-authored-by: luv-bansal <70321430+luv-bansal@users.noreply.github.com> * Optimize test execution time and improve mock specifications Co-authored-by: luv-bansal <70321430+luv-bansal@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: luv-bansal <70321430+luv-bansal@users.noreply.github.com> --- tests/runners/test_mcp_connection_pool.py | 532 ++++++++++++++++++++++ 1 file changed, 532 insertions(+) create mode 100644 tests/runners/test_mcp_connection_pool.py diff --git a/tests/runners/test_mcp_connection_pool.py b/tests/runners/test_mcp_connection_pool.py new file mode 100644 index 00000000..be682532 --- /dev/null +++ b/tests/runners/test_mcp_connection_pool.py @@ -0,0 +1,532 @@ +"""Test cases for MCPConnectionPool singleton and connection lifecycle management.""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from clarifai.runners.models.agentic_class import MCPConnection, MCPConnectionPool + +# Test constants for timing margins +IDLE_MARGIN_SECONDS = 10 # Margin to add beyond thresholds for robust idle tests + + +class TestMCPConnectionPool: + """Tests for MCPConnectionPool singleton and connection management.""" + + @pytest.fixture(autouse=True) + def reset_pool(self): + """Reset the singleton instance before each test.""" + # Clear singleton instance + MCPConnectionPool._instance = None + yield + # Clean up after 
test + MCPConnectionPool._instance = None + + def test_singleton_pattern(self): + """Test that MCPConnectionPool is a singleton.""" + pool1 = MCPConnectionPool() + pool2 = MCPConnectionPool() + + assert pool1 is pool2 + assert id(pool1) == id(pool2) + + def test_singleton_initialization_once(self): + """Test that singleton is initialized only once.""" + pool1 = MCPConnectionPool() + initial_connections = pool1._connections + + pool2 = MCPConnectionPool() + + # Should have the same connections dictionary + assert pool2._connections is initial_connections + + def test_event_loop_initialization(self): + """Test that background event loop is started on init.""" + pool = MCPConnectionPool() + + assert pool._loop is not None + assert pool._loop_thread is not None + assert pool._loop_thread.is_alive() + assert not pool._loop.is_closed() + + def test_connection_cleanup_idle_timeout(self): + """Test that connections idle > MAX_IDLE_TIME are removed.""" + pool = MCPConnectionPool() + + # Create a mock connection that's been idle for too long + mock_client = MagicMock() + mock_client.close = AsyncMock() + old_conn = MCPConnection( + client=mock_client, + tools=[], + tool_names=set(), + url="http://old-server", + last_used=time.time() - pool.MAX_IDLE_TIME - IDLE_MARGIN_SECONDS, + ) + + with pool._lock: + pool._connections["http://old-server"] = old_conn + + # Force cleanup to run immediately + pool._last_cleanup = 0 + pool._maybe_cleanup_idle() + + # Connection should be removed + with pool._lock: + assert "http://old-server" not in pool._connections + + def test_cleanup_interval_rate_limiting(self): + """Test that cleanup checks are rate limited.""" + pool = MCPConnectionPool() + + # Create an idle connection + mock_client = MagicMock() + old_conn = MCPConnection( + client=mock_client, + tools=[], + tool_names=set(), + url="http://server", + last_used=time.time() - pool.MAX_IDLE_TIME - IDLE_MARGIN_SECONDS, + ) + + with pool._lock: + pool._connections["http://server"] = old_conn + + # Set last cleanup to recent time + pool._last_cleanup = time.time() + + # Try to cleanup - should be skipped due to rate limiting + pool._maybe_cleanup_idle() + + # Connection should still exist (cleanup was skipped) + with pool._lock: + assert "http://server" in pool._connections + + def test_connection_verification_valid(self): + """Test connection verification for valid connections.""" + pool = MCPConnectionPool() + + # Create mock connection + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[]) + + conn = MCPConnection( + client=mock_client, tools=[], tool_names=set(), url="http://valid-server" + ) + + # Run verification + is_valid = pool._run_async(pool._verify_connection(conn)) + + assert is_valid is True + mock_client.list_tools.assert_called_once() + + def test_connection_verification_invalid(self): + """Test connection verification for invalid connections.""" + pool = MCPConnectionPool() + + # Create mock connection that fails verification + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(side_effect=Exception("Connection lost")) + + conn = MCPConnection( + client=mock_client, tools=[], tool_names=set(), url="http://invalid-server" + ) + + # Run verification + is_valid = pool._run_async(pool._verify_connection(conn)) + + assert is_valid is False + + def test_needs_verification(self): + """Test _needs_verification logic.""" + pool = MCPConnectionPool() + + # Fresh connection should not need verification + fresh_conn = MCPConnection( + client=MagicMock(), tools=[], 
tool_names=set(), url="http://fresh" + ) + assert pool._needs_verification(fresh_conn) is False + + # Old connection should need verification + old_conn = MCPConnection( + client=MagicMock(), + tools=[], + tool_names=set(), + url="http://old", + last_used=time.time() - pool.VERIFY_IDLE_THRESHOLD - IDLE_MARGIN_SECONDS, + ) + assert pool._needs_verification(old_conn) is True + + def test_parallel_connection_creation(self): + """Test that connections are created in parallel.""" + pool = MCPConnectionPool() + + # Create mock connections and add them directly + mock_client1 = MagicMock() + mock_client1.list_tools = AsyncMock(return_value=[]) + conn1 = MCPConnection( + client=mock_client1, tools=[], tool_names=set(), url="http://server1" + ) + + mock_client2 = MagicMock() + mock_client2.list_tools = AsyncMock(return_value=[]) + conn2 = MCPConnection( + client=mock_client2, tools=[], tool_names=set(), url="http://server2" + ) + + mock_client3 = MagicMock() + mock_client3.list_tools = AsyncMock(return_value=[]) + conn3 = MCPConnection( + client=mock_client3, tools=[], tool_names=set(), url="http://server3" + ) + + with pool._lock: + pool._connections["http://server1"] = conn1 + pool._connections["http://server2"] = conn2 + pool._connections["http://server3"] = conn3 + + # Get connections - should reuse existing ones + urls = ["http://server1", "http://server2", "http://server3"] + connections = pool.get_connections(urls) + + # All connections should be returned + assert len(connections) == 3 + for url in urls: + assert url in connections + + def test_connection_creation_error_handling(self): + """Test error handling when connection creation fails.""" + pool = MCPConnectionPool() + + # Try to create connection for a URL that's not already in pool + # This will fail because fastmcp is not actually installed + urls = ["http://bad-server"] + connections = pool.get_connections(urls) + + # Should handle error gracefully and return empty dict + assert len(connections) == 0 + + def test_parallel_connection_creation_partial_failure(self): + """Test parallel creation with some failures.""" + pool = MCPConnectionPool() + + # Add two valid connections + mock_client1 = MagicMock() + mock_client1.list_tools = AsyncMock(return_value=[]) + conn1 = MCPConnection( + client=mock_client1, tools=[], tool_names=set(), url="http://server1" + ) + + mock_client3 = MagicMock() + mock_client3.list_tools = AsyncMock(return_value=[]) + conn3 = MCPConnection( + client=mock_client3, tools=[], tool_names=set(), url="http://server3" + ) + + with pool._lock: + pool._connections["http://server1"] = conn1 + pool._connections["http://server3"] = conn3 + + # Try to get connections including one that doesn't exist (will fail to create) + urls = ["http://server1", "http://server2", "http://server3"] + connections = pool.get_connections(urls) + + # Should have 2 successful connections (1st and 3rd) + assert len(connections) == 2 + assert "http://server1" in connections + assert "http://server3" in connections + assert "http://server2" not in connections + + def test_connection_touch_mechanism(self): + """Test that connections are touched when accessed.""" + pool = MCPConnectionPool() + + # Create connection with old timestamp + mock_client = MagicMock() + old_time = time.time() - 100 + conn = MCPConnection( + client=mock_client, tools=[], tool_names=set(), url="http://server", last_used=old_time + ) + + with pool._lock: + pool._connections["http://server"] = conn + + # Access connection (should touch it) + connections = 
pool.get_connections(["http://server"]) + + # Last used should be updated + assert connections["http://server"].last_used > old_time + + def test_tool_cache_update(self): + """Test that tool cache is updated when connections are created.""" + pool = MCPConnectionPool() + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "Test tool 1" + mock_tool1.inputSchema = {"type": "object"} + + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "Test tool 2" + mock_tool2.inputSchema = {"type": "object"} + + conn = MCPConnection( + client=MagicMock(), + tools=[mock_tool1, mock_tool2], + tool_names={"tool1", "tool2"}, + url="http://server", + ) + + # Update cache + pool._update_tool_cache(conn) + + # Verify cache contents + assert "tool1" in pool._tool_to_url + assert "tool2" in pool._tool_to_url + assert pool._tool_to_url["tool1"] == "http://server" + assert pool._tool_to_url["tool2"] == "http://server" + assert "tool1" in pool._all_tools + assert "tool2" in pool._all_tools + + def test_tool_cache_invalidation_on_disconnect(self): + """Test that tool cache is invalidated when connection is removed.""" + pool = MCPConnectionPool() + + # Create connection with tools + mock_client = MagicMock() + mock_client.close = AsyncMock() + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + + conn = MCPConnection( + client=mock_client, tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + with pool._lock: + pool._connections["http://server"] = conn + pool._tool_to_url["test_tool"] = "http://server" + pool._all_tools["test_tool"] = {"type": "function"} + + # Disconnect + pool._disconnect("http://server") + + # Tool cache should be cleared + assert "test_tool" not in pool._tool_to_url + assert "test_tool" not in pool._all_tools + + def test_connection_reuse(self): + """Test that existing connections are reused.""" + pool = MCPConnectionPool() + + # Create initial connection + mock_client = MagicMock() + conn = MCPConnection(client=mock_client, tools=[], tool_names=set(), url="http://server") + + with pool._lock: + pool._connections["http://server"] = conn + + # Get connection again + connections = pool.get_connections(["http://server"]) + + # Should reuse same connection + assert connections["http://server"] is conn + + def test_stale_connection_recreation(self): + """Test that stale connections are verified and removed if invalid.""" + pool = MCPConnectionPool() + + # Create stale connection that will fail verification + mock_old_client = MagicMock() + mock_old_client.list_tools = AsyncMock(side_effect=Exception("Connection lost")) + + old_conn = MCPConnection( + client=mock_old_client, + tools=[], + tool_names=set(), + url="http://server", + last_used=time.time() - pool.VERIFY_IDLE_THRESHOLD - IDLE_MARGIN_SECONDS, + ) + + with pool._lock: + pool._connections["http://server"] = old_conn + + # Get connection (should verify and fail, then try to recreate but fail due to missing fastmcp) + connections = pool.get_connections(["http://server"]) + + # Should not have connection since verification failed and recreation failed + # (fastmcp is not installed so recreation will fail) + assert "http://server" not in connections + + # Original connection should have been removed from pool + with pool._lock: + assert "http://server" not in pool._connections + + def test_close_connection_with_close_method(self): + """Test closing connection with close() method.""" + pool = MCPConnectionPool() + + # Create connection with 
close method + mock_client = MagicMock() + mock_client.close = AsyncMock() + + conn = MCPConnection(client=mock_client, tools=[], tool_names=set(), url="http://server") + + # Close connection + pool._run_async(pool._close_connection(conn)) + + # close() should be called + mock_client.close.assert_called_once() + + def test_close_connection_with_aexit(self): + """Test closing connection with __aexit__ method.""" + pool = MCPConnectionPool() + + # Create connection without close method but with __aexit__ + # Use spec to explicitly define available methods (common client methods) + mock_client = MagicMock(spec=['__aexit__', '__aenter__', 'list_tools', 'call_tool']) + mock_client.__aexit__ = AsyncMock() + + conn = MCPConnection(client=mock_client, tools=[], tool_names=set(), url="http://server") + + # Close connection + pool._run_async(pool._close_connection(conn)) + + # __aexit__ should be called + mock_client.__aexit__.assert_called_once_with(None, None, None) + + def test_get_tools_and_mapping(self): + """Test getting tools and connection mapping.""" + pool = MCPConnectionPool() + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "Test tool 1" + mock_tool1.inputSchema = {"type": "object"} + + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "Test tool 2" + mock_tool2.inputSchema = {"type": "object"} + + # Create connections + conn1 = MCPConnection( + client=MagicMock(), tools=[mock_tool1], tool_names={"tool1"}, url="http://server1" + ) + + conn2 = MCPConnection( + client=MagicMock(), tools=[mock_tool2], tool_names={"tool2"}, url="http://server2" + ) + + with pool._lock: + pool._connections["http://server1"] = conn1 + pool._connections["http://server2"] = conn2 + + # Get tools and mapping + tools, connections, tool_to_server = pool.get_tools_and_mapping( + ["http://server1", "http://server2"] + ) + + # Verify results + assert len(tools) == 2 + assert len(connections) == 2 + assert "tool1" in tool_to_server + assert "tool2" in tool_to_server + assert tool_to_server["tool1"] == "http://server1" + assert tool_to_server["tool2"] == "http://server2" + + def test_call_tool_sync(self): + """Test synchronous tool calling.""" + pool = MCPConnectionPool() + + # Create mock connection + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Result")] + mock_client.call_tool = AsyncMock(return_value=mock_result) + + conn = MCPConnection( + client=mock_client, tools=[], tool_names={"test_tool"}, url="http://server" + ) + + connections = {"http://server": conn} + tool_to_server = {"test_tool": "http://server"} + + # Call tool + result = pool.call_tool("test_tool", {"arg": "value"}, connections, tool_to_server) + + # Verify call + assert result is mock_result + mock_client.call_tool.assert_called_once_with("test_tool", arguments={"arg": "value"}) + + def test_call_tools_batch(self): + """Test batch tool calling.""" + pool = MCPConnectionPool() + + # Create mock connection + mock_client = MagicMock() + mock_result1 = MagicMock() + mock_result1.content = [MagicMock(text="Result1")] + mock_result2 = MagicMock() + mock_result2.content = [MagicMock(text="Result2")] + + async def mock_call_tool(name, arguments): + if name == "tool1": + return mock_result1 + return mock_result2 + + mock_client.call_tool = mock_call_tool + + conn = MCPConnection( + client=mock_client, tools=[], tool_names={"tool1", "tool2"}, url="http://server" + ) + + connections = {"http://server": conn} + tool_to_server = 
{"tool1": "http://server", "tool2": "http://server"} + + # Call tools in batch + calls = [("id1", "tool1", {"arg": "value1"}), ("id2", "tool2", {"arg": "value2"})] + results = pool.call_tools_batch(calls, connections, tool_to_server) + + # Verify results + assert len(results) == 2 + assert results[0][0] == "id1" + assert results[1][0] == "id2" + + def test_tool_call_timeout(self): + """Test that tool calls timeout appropriately.""" + pool = MCPConnectionPool() + + # Temporarily reduce timeout for faster test execution + original_timeout = pool.TOOL_CALL_TIMEOUT + pool.TOOL_CALL_TIMEOUT = 2.0 # Use 2 second timeout for test + + try: + # Create mock connection with slow tool + mock_client = MagicMock() + + async def slow_tool(*args, **kwargs): + # Sleep just beyond the reduced timeout + await asyncio.sleep(3.0) + return MagicMock() + + mock_client.call_tool = slow_tool + + conn = MCPConnection( + client=mock_client, tools=[], tool_names={"slow_tool"}, url="http://server" + ) + + connections = {"http://server": conn} + tool_to_server = {"slow_tool": "http://server"} + + # Call should timeout + with pytest.raises(asyncio.TimeoutError): + pool.call_tool("slow_tool", {}, connections, tool_to_server) + finally: + # Restore original timeout + pool.TOOL_CALL_TIMEOUT = original_timeout From 917060510f29b1dc9b4acca8ae670a3870de22ea Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Mon, 15 Dec 2025 16:37:44 +0530 Subject: [PATCH 11/13] Fix co-pilot suggestions --- clarifai/runners/models/agentic_class.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index 301e9351..6577b0ae 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -118,8 +118,13 @@ def run(): def _run_async(self, coro, timeout: float = 30.0) -> Any: """Run coroutine in background loop.""" + # Double-checked locking pattern to prevent race condition + # when multiple threads try to restart a closed loop if self._loop is None or self._loop.is_closed(): - self._start_event_loop() + with self._lock: + # Check again after acquiring lock (another thread may have started it) + if self._loop is None or self._loop.is_closed(): + self._start_event_loop() future = asyncio.run_coroutine_threadsafe(coro, self._loop) return future.result(timeout=timeout) From 9b5fc4387a589810f0d745bd199a8bc160549fb0 Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Mon, 15 Dec 2025 17:45:57 +0530 Subject: [PATCH 12/13] Agentic Class tests --- tests/runners/test_agentic_class.py | 1210 +++++++++++++++++++++++++++ 1 file changed, 1210 insertions(+) create mode 100644 tests/runners/test_agentic_class.py diff --git a/tests/runners/test_agentic_class.py b/tests/runners/test_agentic_class.py new file mode 100644 index 00000000..6ba7a554 --- /dev/null +++ b/tests/runners/test_agentic_class.py @@ -0,0 +1,1210 @@ +"""Comprehensive tests for AgenticModelClass and MCPConnectionPool integration.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic_core import to_json + +from clarifai.runners.models.agentic_class import ( + AgenticModelClass, + MCPConnection, + MCPConnectionPool, +) +from clarifai.runners.models.dummy_openai_model import MockOpenAIClient + + +class DummyAgenticModel(AgenticModelClass): + """Dummy agentic model for testing.""" + + client = MockOpenAIClient() + model = "dummy-agentic-model" + + +class TestAgenticModelClass: + """Tests for 
AgenticModelClass functionality.""" + + @pytest.fixture(autouse=True) + def reset_pool(self): + """Reset the singleton instance before each test.""" + AgenticModelClass._pool = None + MCPConnectionPool._instance = None + yield + AgenticModelClass._pool = None + MCPConnectionPool._instance = None + + @pytest.fixture + def model(self): + """Create a test model instance.""" + return DummyAgenticModel() + + @pytest.fixture + def mock_pool(self): + """Create a mock connection pool.""" + pool = MagicMock(spec=MCPConnectionPool) + pool.get_tools_and_mapping = MagicMock(return_value=([], {}, {})) + pool.call_tools_batch = MagicMock(return_value=[]) + pool.call_tools_batch_async = AsyncMock(return_value=[]) + pool._loop = asyncio.new_event_loop() + return pool + + # === Token Tracking Tests === + + def test_init_tokens(self, model): + """Test token initialization.""" + model._init_tokens() + assert hasattr(model._thread_local, 'tokens') + assert model._thread_local.tokens == {'prompt': 0, 'completion': 0} + + def test_add_tokens_from_usage(self, model): + """Test adding tokens from response with usage attribute.""" + mock_response = MagicMock() + mock_usage = MagicMock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 20 + mock_response.usage = mock_usage + + model._add_tokens(mock_response) + assert model._thread_local.tokens['prompt'] == 10 + assert model._thread_local.tokens['completion'] == 20 + + def test_add_tokens_from_response_usage(self, model): + """Test adding tokens from response.response.usage.""" + mock_response = MagicMock() + mock_response.usage = None + mock_inner_response = MagicMock() + + # Use a simple object instead of MagicMock to avoid mock attribute issues + class MockUsage: + input_tokens = 15 + output_tokens = 25 + + mock_usage = MockUsage() + mock_inner_response.usage = mock_usage + mock_response.response = mock_inner_response + + model._add_tokens(mock_response) + assert model._thread_local.tokens['prompt'] == 15 + assert model._thread_local.tokens['completion'] == 25 + + def test_add_tokens_accumulates(self, model): + """Test that tokens accumulate across multiple calls.""" + mock_response1 = MagicMock() + mock_usage1 = MagicMock() + mock_usage1.prompt_tokens = 10 + mock_usage1.completion_tokens = 20 + mock_response1.usage = mock_usage1 + + mock_response2 = MagicMock() + mock_usage2 = MagicMock() + mock_usage2.prompt_tokens = 5 + mock_usage2.completion_tokens = 10 + mock_response2.usage = mock_usage2 + + model._add_tokens(mock_response1) + model._add_tokens(mock_response2) + + assert model._thread_local.tokens['prompt'] == 15 + assert model._thread_local.tokens['completion'] == 30 + + def test_finalize_tokens(self, model): + """Test finalizing tokens to output context.""" + model._init_tokens() + model._thread_local.tokens['prompt'] = 10 + model._thread_local.tokens['completion'] = 20 + + with patch.object(model, 'set_output_context') as mock_set: + model._finalize_tokens() + mock_set.assert_called_once_with(prompt_tokens=10, completion_tokens=20) + assert not hasattr(model._thread_local, 'tokens') + + def test_finalize_tokens_no_tokens(self, model): + """Test finalizing when no tokens were tracked.""" + with patch.object(model, 'set_output_context') as mock_set: + model._finalize_tokens() + mock_set.assert_not_called() + + # === Tool Format Conversion Tests === + + def test_to_response_api_tools_with_function(self, model): + """Test converting chat completion tools to response API format.""" + tools = [ + { + "type": "function", + "function": { + 
"name": "test_tool", + "description": "A test tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + result = model._to_response_api_tools(tools) + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["name"] == "test_tool" + assert result[0]["description"] == "A test tool" + + def test_to_response_api_tools_with_name(self, model): + """Test converting tools that already have name field.""" + tools = [ + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": {}, + } + ] + result = model._to_response_api_tools(tools) + assert len(result) == 1 + assert result[0]["name"] == "test_tool" + + def test_to_dict_from_dict(self, model): + """Test _to_dict with dict input.""" + obj = {"key": "value"} + result = model._to_dict(obj) + assert result == obj + + def test_to_dict_with_model_dump(self, model): + """Test _to_dict with object that has model_dump method.""" + obj = MagicMock() + obj.model_dump.return_value = {"key": "value"} + result = model._to_dict(obj) + assert result == {"key": "value"} + + def test_to_dict_with_dict_method(self, model): + """Test _to_dict with object that has dict method.""" + + # Use a simple class to avoid MagicMock issues + class TestObj: + def dict(self): + return {"key": "value"} + + obj = TestObj() + result = model._to_dict(obj) + assert result == {"key": "value"} + + def test_to_dict_with_dict_attribute(self, model): + """Test _to_dict with object that has __dict__ attribute.""" + + class TestObj: + def __init__(self): + self.key = "value" + + obj = TestObj() + result = model._to_dict(obj) + assert result == {"key": "value"} + + # === Tool Call Parsing Tests === + + def test_parse_chat_tool_calls_with_function_attribute(self, model): + """Test parsing chat completion tool calls with function attribute.""" + mock_tool_call = MagicMock() + mock_tool_call.id = "call_123" + mock_tool_call.function.name = "test_tool" + mock_tool_call.function.arguments = '{"arg": "value"}' + + tool_calls = [mock_tool_call] + result = model._parse_chat_tool_calls(tool_calls) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {"arg": "value"}) + + def test_parse_chat_tool_calls_with_dict(self, model): + """Test parsing chat completion tool calls from dict.""" + tool_calls = [ + {"id": "call_123", "function": {"name": "test_tool", "arguments": '{"arg": "value"}'}} + ] + result = model._parse_chat_tool_calls(tool_calls) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {"arg": "value"}) + + def test_parse_chat_tool_calls_invalid_json(self, model): + """Test parsing tool calls with invalid JSON arguments.""" + mock_tool_call = MagicMock() + mock_tool_call.id = "call_123" + mock_tool_call.function.name = "test_tool" + mock_tool_call.function.arguments = "invalid json" + + tool_calls = [mock_tool_call] + result = model._parse_chat_tool_calls(tool_calls) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {}) + + def test_parse_response_tool_calls(self, model): + """Test parsing response API tool calls.""" + items = [ + { + "type": "function_tool_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": '{"arg": "value"}', + "status": "pending", + } + ] + result = model._parse_response_tool_calls(items) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {"arg": "value"}) + + def test_parse_response_tool_calls_with_id(self, model): + """Test parsing response API tool calls using id instead of 
call_id.""" + items = [ + { + "type": "function_call", + "id": "call_123", + "name": "test_tool", + "arguments": '{"arg": "value"}', + "output": None, + } + ] + result = model._parse_response_tool_calls(items) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {"arg": "value"}) + + def test_parse_response_tool_calls_skips_completed(self, model): + """Test that completed tool calls are skipped.""" + items = [ + { + "type": "function_tool_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": '{"arg": "value"}', + "status": "completed", + "output": "result", + } + ] + result = model._parse_response_tool_calls(items) + assert len(result) == 0 + + def test_parse_response_tool_calls_with_dict_arguments(self, model): + """Test parsing response tool calls with dict arguments instead of string.""" + items = [ + { + "type": "function_tool_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": {"arg": "value"}, # Already a dict + "status": "pending", + } + ] + result = model._parse_response_tool_calls(items) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {"arg": "value"}) + + def test_parse_response_tool_calls_empty_string_arguments(self, model): + """Test parsing response tool calls with empty string arguments.""" + items = [ + { + "type": "function_tool_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": "", + "status": "pending", + } + ] + result = model._parse_response_tool_calls(items) + + assert len(result) == 1 + assert result[0] == ("call_123", "test_tool", {}) + + # === Tool Execution Tests === + + def test_execute_chat_tools(self, model): + """Test executing chat completion tool calls.""" + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool result")] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [MagicMock()] + tool_calls[0].id = "call_123" + tool_calls[0].function.name = "test_tool" + tool_calls[0].function.arguments = '{}' + + connections = {} + messages = [] + tool_to_server = {} + + model._execute_chat_tools(tool_calls, connections, messages, tool_to_server) + + assert len(messages) == 1 + assert messages[0]["role"] == "tool" + assert messages[0]["tool_call_id"] == "call_123" + assert messages[0]["content"] == "Tool result" + + def test_execute_chat_tools_with_error(self, model): + """Test executing chat tools with error.""" + mock_pool = MagicMock() + mock_pool.call_tools_batch.return_value = [("call_123", None, "Tool error")] + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [MagicMock()] + tool_calls[0].id = "call_123" + tool_calls[0].function.name = "test_tool" + tool_calls[0].function.arguments = '{}' + + connections = {} + messages = [] + tool_to_server = {} + + model._execute_chat_tools(tool_calls, connections, messages, tool_to_server) + + assert len(messages) == 1 + assert messages[0]["content"] == "Error: Tool error" + + def test_execute_chat_tools_with_list_result(self, model): + """Test executing chat tools with list result format.""" + mock_pool = MagicMock() + + # Use a dict-like object that supports both .get() and .text access + class TextDict: + def __init__(self, text): + self._text = text + + def get(self, key, default=None): + if key == 'text': + return self._text + return default + + @property + def text(self): + return self._text + + mock_result = [TextDict("List result")] + 
mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [MagicMock()] + tool_calls[0].id = "call_123" + tool_calls[0].function.name = "test_tool" + tool_calls[0].function.arguments = '{}' + + connections = {} + messages = [] + tool_to_server = {} + + model._execute_chat_tools(tool_calls, connections, messages, tool_to_server) + + assert len(messages) == 1 + assert messages[0]["content"] == "List result" + + def test_execute_chat_tools_with_none_content(self, model): + """Test executing chat tools when result has no content.""" + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [MagicMock()] + tool_calls[0].id = "call_123" + tool_calls[0].function.name = "test_tool" + tool_calls[0].function.arguments = '{}' + + connections = {} + messages = [] + tool_to_server = {} + + model._execute_chat_tools(tool_calls, connections, messages, tool_to_server) + + assert len(messages) == 1 + assert messages[0]["content"] is None + + @pytest.mark.asyncio + async def test_execute_chat_tools_async(self, model): + """Test async execution of chat tools.""" + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Async result")] + mock_pool.call_tools_batch_async = AsyncMock( + return_value=[("call_123", mock_result, None)] + ) + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [MagicMock()] + tool_calls[0].id = "call_123" + tool_calls[0].function.name = "test_tool" + tool_calls[0].function.arguments = '{}' + + connections = {} + messages = [] + tool_to_server = {} + + await model._execute_chat_tools_async( + tool_calls, connections, messages, tool_to_server + ) + + assert len(messages) == 1 + assert messages[0]["content"] == "Async result" + + def test_execute_response_tools(self, model): + """Test executing response API tool calls.""" + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Response result")] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [("call_123", "test_tool", {})] + connections = {} + input_items = [] + tool_to_server = {} + + model._execute_response_tools(tool_calls, connections, input_items, tool_to_server) + + assert len(input_items) == 1 + assert input_items[0]["type"] == "function_call_output" + assert input_items[0]["call_id"] == "call_123" + assert input_items[0]["output"] == "Response result" + + @pytest.mark.asyncio + async def test_execute_response_tools_async(self, model): + """Test async execution of response tools.""" + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Async response result")] + mock_pool.call_tools_batch_async = AsyncMock( + return_value=[("call_123", mock_result, None)] + ) + + with patch.object(model, 'get_pool', return_value=mock_pool): + tool_calls = [("call_123", "test_tool", {})] + connections = {} + input_items = [] + tool_to_server = {} + + await model._execute_response_tools_async( + tool_calls, connections, input_items, tool_to_server + ) + + assert len(input_items) == 1 + assert input_items[0]["output"] == "Async response result" + + # === Response Output Processing Tests === + + def 
test_convert_output_to_input(self, model): + """Test converting response API output to input items.""" + output = [ + {"type": "message", "role": "assistant", "content": "Hello"}, + {"type": "function_tool_call", "name": "tool1", "output": "result"}, + {"type": "reasoning", "content": "Thinking..."}, + ] + result = model._convert_output_to_input(output) + + assert len(result) == 3 + assert result[0]["type"] == "message" + assert result[1]["type"] == "function_tool_call" + assert result[2]["type"] == "reasoning" + + def test_convert_output_to_input_filters_pending(self, model): + """Test that pending tool calls are filtered out.""" + output = [ + {"type": "function_tool_call", "name": "tool1", "status": "pending", "output": None} + ] + result = model._convert_output_to_input(output) + assert len(result) == 0 + + def test_convert_output_to_input_empty_list(self, model): + """Test converting empty output list.""" + result = model._convert_output_to_input([]) + assert result == [] + + # === Request Handler Tests === + + def test_handle_chat_completions_with_tools(self, model): + """Test handling chat completions with MCP tools.""" + tools = [{"type": "function", "function": {"name": "test_tool"}}] + request_data = {"messages": [{"role": "user", "content": "Hello"}]} + + with patch.object( + AgenticModelClass.__bases__[0], '_handle_chat_completions' + ) as mock_super_handle: + mock_response = MagicMock() + mock_super_handle.return_value = mock_response + + result = model._handle_chat_completions( + request_data, mcp_servers=["http://server"], connections={}, tools=tools + ) + + # The method creates a new dict, so check what was passed to super + call_args = mock_super_handle.call_args[0][0] + assert "tools" in call_args + assert call_args["tool_choice"] == "auto" + + def test_handle_chat_completions_without_tools(self, model): + """Test handling chat completions without MCP tools.""" + request_data = {"messages": [{"role": "user", "content": "Hello"}]} + + with patch.object(model, '_handle_chat_completions') as mock_handle: + mock_response = MagicMock() + mock_handle.return_value = mock_response + result = model._handle_chat_completions(request_data, None, None, None) + + # Should not modify request_data when no tools + assert "tools" not in request_data + + def test_handle_responses_with_tools(self, model): + """Test handling responses with MCP tools.""" + tools = [{"type": "function", "function": {"name": "test_tool"}}] + request_data = {"input": "Hello"} + + with patch.object( + AgenticModelClass.__bases__[0], '_handle_responses' + ) as mock_super_handle: + mock_response = MagicMock() + mock_super_handle.return_value = mock_response + + result = model._handle_responses( + request_data, mcp_servers=["http://server"], connections={}, tools=tools + ) + + # The method creates a new dict, so check what was passed to super + call_args = mock_super_handle.call_args[0][0] + assert "tools" in call_args + assert call_args["tool_choice"] == "auto" + + def test_handle_responses_without_tools(self, model): + """Test handling responses without MCP tools.""" + request_data = {"input": "Hello"} + + with patch.object(model, '_handle_responses') as mock_handle: + mock_response = MagicMock() + mock_handle.return_value = mock_response + result = model._handle_responses(request_data, None, None, None) + + # Should not modify request_data when no tools + assert "tools" not in request_data + + def test_route_request_chat_completions(self, model): + """Test routing to chat completions endpoint.""" + request_data = 
{"messages": [{"role": "user", "content": "Hello"}]} + with patch.object(model, '_handle_chat_completions') as mock_handle: + mock_handle.return_value = MagicMock() + model._route_request(model.ENDPOINT_CHAT_COMPLETIONS, request_data, None, None, None) + mock_handle.assert_called_once() + + def test_route_request_responses(self, model): + """Test routing to responses endpoint.""" + request_data = {"input": "Hello"} + with patch.object(model, '_handle_responses') as mock_handle: + mock_handle.return_value = MagicMock() + model._route_request(model.ENDPOINT_RESPONSES, request_data, None, None, None) + mock_handle.assert_called_once() + + # === Streaming Helper Tests === + + def test_accumulate_tool_delta(self, model): + """Test accumulating streaming tool call deltas.""" + accumulated = {} + delta = MagicMock() + delta.index = 0 + delta.id = "call_123" + delta.function.name = "test_tool" + delta.function.arguments = '{"arg": "value"}' + + model._accumulate_tool_delta(delta, accumulated) + + assert 0 in accumulated + assert accumulated[0]["id"] == "call_123" + assert accumulated[0]["function"]["name"] == "test_tool" + assert accumulated[0]["function"]["arguments"] == '{"arg": "value"}' + + def test_accumulate_tool_delta_incremental(self, model): + """Test accumulating incremental tool call arguments.""" + accumulated = { + 0: { + "id": "call_123", + "type": "function", + "function": {"name": "test_tool", "arguments": ""}, + } + } + delta = MagicMock() + delta.index = 0 + delta.id = None + delta.function.name = None + delta.function.arguments = '{"arg": "value"}' + + model._accumulate_tool_delta(delta, accumulated) + + assert accumulated[0]["function"]["arguments"] == '{"arg": "value"}' + + def test_finalize_tool_calls(self, model): + """Test finalizing accumulated tool calls.""" + accumulated = { + 0: { + "id": "call_1", + "type": "function", + "function": {"name": "tool1", "arguments": "{}"}, + }, + 1: { + "id": "call_2", + "type": "function", + "function": {"name": "tool2", "arguments": "{}"}, + }, + } + result = model._finalize_tool_calls(accumulated) + + assert len(result) == 2 + assert result[0]["id"] == "call_1" + assert result[1]["id"] == "call_2" + + def test_create_stream_request(self, model): + """Test creating streaming chat completion request.""" + messages = [{"role": "user", "content": "Hello"}] + tools = [{"type": "function", "function": {"name": "test_tool"}}] + + with patch.object(model.client.chat.completions, 'create') as mock_create: + mock_create.return_value = iter([]) + result = model._create_stream_request(messages, tools, 100, 0.7, 0.9) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args[1] + assert call_kwargs["stream"] is True + assert call_kwargs["tools"] == tools + assert call_kwargs["tool_choice"] == "auto" + + def test_async_to_sync_generator(self, model): + """Test bridging async generator to sync generator.""" + + async def async_gen(): + yield "chunk1" + yield "chunk2" + yield "chunk3" + + pool = model.get_pool() + result = list(model._async_to_sync_generator(async_gen)) + + assert len(result) == 3 + assert result == ["chunk1", "chunk2", "chunk3"] + + def test_async_to_sync_generator_with_error(self, model): + """Test async to sync generator with error.""" + + async def async_gen(): + yield "chunk1" + raise ValueError("Test error") + + with pytest.raises(ValueError, match="Test error"): + list(model._async_to_sync_generator(async_gen)) + + # === Streaming with MCP Tests === + + @pytest.mark.asyncio + async def 
test_stream_chat_with_tools_no_tool_calls(self, model): + """Test streaming chat with tools when no tool calls are generated.""" + messages = [{"role": "user", "content": "Hello"}] + tools = [] + connections = {} + tool_to_server = {} + + mock_chunk = MagicMock() + mock_chunk.model_dump_json.return_value = '{"id": "chunk1"}' + mock_chunk.choices = [MagicMock()] + mock_chunk.choices[0].delta = MagicMock() + mock_chunk.choices[0].delta.tool_calls = None + mock_chunk.choices[0].delta.content = "Hello" + + with patch.object(model, '_create_stream_request', return_value=[mock_chunk]): + chunks = [] + async for chunk in model._stream_chat_with_tools( + messages, tools, connections, tool_to_server, 100, 0.7, 0.9 + ): + chunks.append(chunk) + + assert len(chunks) == 1 + + @pytest.mark.asyncio + async def test_stream_chat_with_tools_with_tool_calls(self, model): + """Test streaming chat with tools when tool calls are generated.""" + messages = [{"role": "user", "content": "Hello"}] + tools = [{"type": "function", "function": {"name": "test_tool"}}] + connections = {} + tool_to_server = {} + + # Create mock tool call delta + mock_tool_call_delta = MagicMock() + mock_tool_call_delta.index = 0 + mock_tool_call_delta.id = "call_123" + mock_tool_call_delta.function.name = "test_tool" + mock_tool_call_delta.function.arguments = '{}' + + # First chunk with tool calls + mock_chunk_with_tools = MagicMock() + mock_chunk_with_tools.model_dump_json.return_value = '{"id": "chunk1"}' + mock_chunk_with_tools.choices = [MagicMock()] + mock_chunk_with_tools.choices[0].delta = MagicMock() + mock_chunk_with_tools.choices[0].delta.tool_calls = [mock_tool_call_delta] + mock_chunk_with_tools.choices[0].delta.content = None + + # Second chunk without tool calls (to prevent recursion) + mock_chunk_no_tools = MagicMock() + mock_chunk_no_tools.model_dump_json.return_value = '{"id": "chunk2"}' + mock_chunk_no_tools.choices = [MagicMock()] + mock_chunk_no_tools.choices[0].delta = MagicMock() + mock_chunk_no_tools.choices[0].delta.tool_calls = None + mock_chunk_no_tools.choices[0].delta.content = "Response" + + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool result")] + mock_pool.call_tools_batch_async = AsyncMock( + return_value=[("call_123", mock_result, None)] + ) + + # Create a generator that yields chunks + def chunk_generator(): + yield mock_chunk_with_tools + + # For recursive call, return chunk without tool calls + def chunk_generator_no_tools(): + yield mock_chunk_no_tools + + with ( + patch.object( + model, + '_create_stream_request', + side_effect=[chunk_generator(), chunk_generator_no_tools()], + ), + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object( + model, + '_finalize_tool_calls', + return_value=[ + { + "id": "call_123", + "type": "function", + "function": {"name": "test_tool", "arguments": "{}"}, + } + ], + ), + ): + chunks = [] + async for chunk in model._stream_chat_with_tools( + messages, tools, connections, tool_to_server, 100, 0.7, 0.9 + ): + chunks.append(chunk) + + # Should have initial chunk plus recursive call chunks + assert len(chunks) >= 1 + + @pytest.mark.asyncio + async def test_stream_responses_with_tools(self, model): + """Test streaming responses with MCP tools.""" + request_data = {"input": "Hello"} + tools = [{"type": "function", "function": {"name": "test_tool"}}] + connections = {} + tool_to_server = {} + + # Create a simple event class that will be yielded + # The code checks chunk_type and yields if it doesn't match 
certain patterns + # We need to avoid output_index being detected by hasattr + class MockEvent: + def __init__(self): + self.type = "response.created" + self.response = None + + def model_dump_json(self): + return '{"type": "response.created"}' + + # Set the class name for chunk_type detection + MockEvent.__name__ = "ResponseCreatedEvent" + mock_event = MockEvent() + + with patch.object(model.client.responses, 'create', return_value=[mock_event]): + chunks = [] + async for chunk in model._stream_responses_with_tools( + request_data, tools, connections, tool_to_server + ): + chunks.append(chunk) + + # Should yield at least the event as JSON + assert len(chunks) >= 1 + + # === Main OpenAI Methods Tests === + + def test_openai_transport_without_mcp(self, model): + """Test openai_transport without MCP servers.""" + request = { + "messages": [{"role": "user", "content": "Hello"}], + "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS, + } + + with patch.object(model, '_route_request') as mock_route: + mock_response = MagicMock() + mock_response.model_dump_json.return_value = '{"id": "test"}' + mock_route.return_value = mock_response + + result = model.openai_transport(to_json(request)) + assert json.loads(result)["id"] == "test" + + def test_openai_transport_with_mcp_chat_completions(self, model): + """Test openai_transport with MCP servers for chat completions.""" + request = { + "messages": [{"role": "user", "content": "Hello"}], + "mcp_servers": ["http://server"], + "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS, + } + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + mock_tool.inputSchema = {} + + mock_conn = MCPConnection( + client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + mock_pool = MagicMock() + mock_pool.get_tools_and_mapping.return_value = ( + [{"type": "function", "function": {"name": "test_tool"}}], + {"http://server": mock_conn}, + {"test_tool": "http://server"}, + ) + + mock_response = MagicMock() + mock_response.model_dump_json.return_value = '{"id": "test"}' + mock_response.choices = [MagicMock()] + mock_response.choices[0].get.return_value = {} + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.get.return_value = None + + with ( + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object(model, '_route_request', return_value=mock_response), + ): + result = model.openai_transport(to_json(request)) + assert json.loads(result)["id"] == "test" + + def test_openai_transport_with_mcp_responses(self, model): + """Test openai_transport with MCP servers for responses API.""" + request = { + "input": "Hello", + "mcp_servers": ["http://server"], + "openai_endpoint": model.ENDPOINT_RESPONSES, + } + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + mock_tool.inputSchema = {} + + mock_conn = MCPConnection( + client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + mock_pool = MagicMock() + mock_pool.get_tools_and_mapping.return_value = ( + [{"type": "function", "function": {"name": "test_tool"}}], + {"http://server": mock_conn}, + {"test_tool": "http://server"}, + ) + + mock_response = MagicMock() + mock_response.model_dump_json.return_value = '{"id": "test"}' + mock_response.output = [] + + with ( + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object(model, '_route_request', return_value=mock_response), + ): + result = 
model.openai_transport(to_json(request))
+            assert json.loads(result)["id"] == "test"
+
+    def test_openai_transport_with_existing_tools(self, model):
+        """Test openai_transport when tools are already provided (should not use MCP)."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "mcp_servers": ["http://server"],
+            "tools": [{"type": "function", "function": {"name": "existing_tool"}}],
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        with (
+            patch.object(model, 'get_pool') as mock_get_pool,
+            patch.object(model, '_route_request') as mock_route,
+        ):
+            mock_response = MagicMock()
+            mock_response.model_dump_json.return_value = '{"id": "test"}'
+            mock_route.return_value = mock_response
+
+            result = model.openai_transport(to_json(request))
+            # Should not call get_pool when tools are already provided
+            mock_get_pool.assert_not_called()
+            assert json.loads(result)["id"] == "test"
+
+    def test_openai_transport_error_handling(self, model):
+        """Test error handling in openai_transport."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        with patch.object(model, '_route_request', side_effect=Exception("Test error")):
+            result = model.openai_transport(to_json(request))
+            error = json.loads(result)
+            # Check that it's an error code (could be 22000 or 21313 depending on status_code_pb2)
+            assert "code" in error
+            assert error["code"] > 0
+            assert "Test error" in error["details"]
+
+    def test_openai_stream_transport_without_mcp(self, model):
+        """Test openai_stream_transport without MCP servers."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "stream": True,
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        mock_chunk = MagicMock()
+        mock_chunk.model_dump_json.return_value = '{"id": "chunk1"}'
+        # Ensure usage is None to avoid token finalization issues
+        mock_chunk.usage = None
+
+        with patch.object(model.client.chat.completions, 'create', return_value=[mock_chunk]):
+            chunks = list(model.openai_stream_transport(to_json(request)))
+            # Filter out error chunks if any
+            chunks = [c for c in chunks if not (isinstance(c, bytes) and b'"code"' in c)]
+            assert len(chunks) == 1
+
+    def test_openai_stream_transport_with_mcp(self, model):
+        """Test openai_stream_transport with MCP servers."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "mcp_servers": ["http://server"],
+            "stream": True,
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        mock_tool = MagicMock()
+        mock_tool.name = "test_tool"
+        mock_tool.description = "Test tool"
+        mock_tool.inputSchema = {}
+
+        mock_conn = MCPConnection(
+            client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server"
+        )
+
+        mock_pool = MagicMock()
+        mock_pool.get_tools_and_mapping.return_value = (
+            [{"type": "function", "function": {"name": "test_tool"}}],
+            {"http://server": mock_conn},
+            {"test_tool": "http://server"},
+        )
+        mock_pool._loop = asyncio.new_event_loop()
+
+        mock_chunk = MagicMock()
+        mock_chunk.model_dump_json.return_value = '{"id": "chunk1"}'
+        mock_chunk.choices = [MagicMock()]
+        mock_chunk.choices[0].delta = MagicMock()
+        mock_chunk.choices[0].delta.tool_calls = None
+        mock_chunk.choices[0].delta.content = "Hello"
+
+        with (
+            patch.object(model, 'get_pool', return_value=mock_pool),
+            patch.object(model, '_async_to_sync_generator') as mock_gen,
+        ):
+            mock_gen.return_value = iter(['{"id": "chunk1"}'])
+            chunks = list(model.openai_stream_transport(to_json(request)))
+            assert len(chunks) == 1
+
+    def 
test_openai_stream_transport_error_handling(self, model): + """Test error handling in openai_stream_transport.""" + request = { + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS, + } + + # Patch the client method to raise an error + def raise_error(**kwargs): + raise Exception("Test error") + + with patch.object(model.client.chat.completions, 'create', side_effect=raise_error): + chunks = list(model.openai_stream_transport(to_json(request))) + # Should have at least one error chunk + assert len(chunks) >= 1 + # The error should be in JSON format + error_str = ( + chunks[-1] + if isinstance(chunks[-1], str) + else chunks[-1].decode('utf-8') + if isinstance(chunks[-1], bytes) + else str(chunks[-1]) + ) + error = json.loads(error_str) + assert "code" in error + assert "Test error" in error.get("details", "") + + # === Pool Management Tests === + + def test_get_pool_singleton(self, model): + """Test that get_pool returns a singleton.""" + pool1 = model.get_pool() + pool2 = model.get_pool() + assert pool1 is pool2 + + def test_get_pool_thread_safe(self, model): + """Test that get_pool is thread-safe.""" + import threading + + pools = [] + + def get_pool(): + pools.append(model.get_pool()) + + threads = [threading.Thread(target=get_pool) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All should be the same instance + assert all(p is pools[0] for p in pools) + + # === Integration Tests === + + def test_full_flow_chat_completions_with_tools(self, model): + """Test full flow: chat completions with tool calls.""" + request = { + "messages": [{"role": "user", "content": "Use test_tool"}], + "mcp_servers": ["http://server"], + "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS, + } + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + mock_tool.inputSchema = {} + + mock_conn = MCPConnection( + client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + # First response has tool calls + mock_tool_call = MagicMock() + mock_tool_call.id = "call_123" + mock_tool_call.function.name = "test_tool" + mock_tool_call.function.arguments = '{}' + + mock_message = MagicMock() + mock_message.tool_calls = [mock_tool_call] + + # Make message behave like a dict for .get() calls + def message_get(key, default=None): + if key == "tool_calls": + return [mock_tool_call] + return default + + mock_message.get = message_get + + mock_response1 = MagicMock() + mock_response1.model_dump_json.return_value = '{"id": "test1"}' + mock_response1.choices = [MagicMock()] + mock_response1.choices[0].message = mock_message + + # Make choices[0] behave like a dict for .get() calls + def mock_get(key, default=None): + if key == "message": + return mock_message + return default + + mock_response1.choices[0].get = mock_get + + # Second response (after tool execution) has no tool calls + mock_message2 = MagicMock() + mock_message2.tool_calls = None + + def message_get2(key, default=None): + if key == "tool_calls": + return None + return default + + mock_message2.get = message_get2 + + mock_response2 = MagicMock() + mock_response2.model_dump_json.return_value = '{"id": "test2"}' + mock_response2.choices = [MagicMock()] + mock_response2.choices[0].message = mock_message2 + + def mock_get2(key, default=None): + if key == "message": + return mock_message2 + return default + + mock_response2.choices[0].get = mock_get2 + + mock_pool = MagicMock() + 
mock_pool.get_tools_and_mapping.return_value = ( + [{"type": "function", "function": {"name": "test_tool"}}], + {"http://server": mock_conn}, + {"test_tool": "http://server"}, + ) + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool executed")] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with ( + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object(model, '_route_request', side_effect=[mock_response1, mock_response2]), + ): + result = model.openai_transport(to_json(request)) + assert json.loads(result)["id"] == "test2" + + def test_full_flow_responses_with_tools(self, model): + """Test full flow: responses API with tool calls.""" + request = { + "input": "Use test_tool", + "mcp_servers": ["http://server"], + "openai_endpoint": model.ENDPOINT_RESPONSES, + } + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + mock_tool.inputSchema = {} + + mock_conn = MCPConnection( + client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + # First response has tool calls + mock_response1 = MagicMock() + mock_response1.model_dump_json.return_value = '{"id": "test1"}' + mock_response1.output = [ + { + "type": "function_tool_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": '{}', + "output": None, + } + ] + + # Second response (after tool execution) has no tool calls + mock_response2 = MagicMock() + mock_response2.model_dump_json.return_value = '{"id": "test2"}' + mock_response2.output = [] + + mock_pool = MagicMock() + mock_pool.get_tools_and_mapping.return_value = ( + [{"type": "function", "function": {"name": "test_tool"}}], + {"http://server": mock_conn}, + {"test_tool": "http://server"}, + ) + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool executed")] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with ( + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object(model, '_route_request', side_effect=[mock_response1, mock_response2]), + ): + result = model.openai_transport(to_json(request)) + assert json.loads(result)["id"] == "test2" + + def test_to_response_api_tools_empty_list(self, model): + """Test converting empty tools list.""" + result = model._to_response_api_tools([]) + assert result == [] From 9c8e0dc09d9e39222be2c519159d118be6dd0b97 Mon Sep 17 00:00:00 2001 From: Luv Bansal Date: Tue, 16 Dec 2025 15:54:15 +0530 Subject: [PATCH 13/13] fix agentic class --- clarifai/runners/models/agentic_class.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index 6577b0ae..e3aa2b91 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -565,10 +565,10 @@ def _execute_chat_tools( elif ( hasattr(result, 'content') and len(result.content) > 0 - and result.content[0].get('text') + and hasattr(result.content[0], 'text') ): content = result.content[0].text - elif len(result) > 0 and result[0].get('text'): + elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: content = result[0].text else: content = None @@ -592,10 +592,11 @@ async def _execute_chat_tools_async( elif ( hasattr(result, 'content') and len(result.content) > 0 - and result.content[0].get('text') + and hasattr(result.content[0], 'text') + and result.content[0].text ): content = result.content[0].text - elif 
len(result) > 0 and result[0].get('text'): + elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: content = result[0].text else: content = None @@ -618,10 +619,11 @@ def _execute_response_tools( elif ( hasattr(result, 'content') and len(result.content) > 0 - and result.content[0].get('text') + and hasattr(result.content[0], 'text') + and result.content[0].text ): output = result.content[0].text - elif len(result) > 0 and result[0].get('text'): + elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: output = result[0].text else: output = None @@ -646,10 +648,11 @@ async def _execute_response_tools_async( elif ( hasattr(result, 'content') and len(result.content) > 0 - and result.content[0].get('text') + and hasattr(result.content[0], 'text') + and result.content[0].text ): output = result.content[0].text - elif len(result) > 0 and result[0].get('text'): + elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: output = result[0].text else: output = None
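
A note on the recurring change in the final patch: all four hunks replace dict-style access (`result.content[0].get('text')`) with attribute checks. MCP tool results carry typed content objects (for example `TextContent`), which expose `text` as an attribute rather than a dict key, so the old `.get('text')` calls raised `AttributeError` at runtime. The sketch below isolates that pattern outside the diff context; it is a minimal illustration, not code from the patch, and the helper name `extract_text` plus the `_FakeTextContent` stand-in are assumptions for demonstration only.

    def extract_text(result):
        """Best-effort extraction of a text payload from an MCP tool result.

        Mirrors the condition the patch converges on: attribute checks plus a
        truthiness test, first on a CallToolResult-style wrapper, then on a
        bare list of content items.
        """
        if (
            hasattr(result, 'content')
            and len(result.content) > 0
            and hasattr(result.content[0], 'text')
            and result.content[0].text
        ):
            return result.content[0].text
        elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text:
            return result[0].text
        return None


    class _FakeTextContent:
        """Stand-in for an MCP text content item: attribute access, not dict access."""

        def __init__(self, text):
            self.text = text


    assert extract_text([_FakeTextContent("Tool executed")]) == "Tool executed"
    assert extract_text([]) is None

The `hasattr` check alone still admits empty strings, which is why three of the four hunks also add `and result.content[0].text`: without the truthiness test, an empty `text` field would be forwarded as the tool output instead of falling through to `None`.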