diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py
new file mode 100644
index 00000000..e3aa2b91
--- /dev/null
+++ b/clarifai/runners/models/agentic_class.py
@@ -0,0 +1,1155 @@
+"""Base class for creating an OpenAI-compatible API server with MCP (Model Context Protocol) support."""
+
+import asyncio
+import json
+import os
+import threading
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
+
+from clarifai_grpc.grpc.api.status import status_code_pb2
+from pydantic_core import from_json, to_json
+
+from clarifai.runners.models.model_class import ModelClass
+from clarifai.runners.models.openai_class import OpenAIModelClass
+from clarifai.utils.logging import logger
+
+
+@dataclass
+class MCPConnection:
+    """Single MCP server connection."""
+
+    client: Any
+    tools: List[Any]
+    tool_names: Set[str]
+    url: str
+    last_used: float = field(default_factory=time.time)
+
+    def touch(self):
+        self.last_used = time.time()
+
+
+class MCPConnectionPool:
+    """
+    Singleton, thread-safe connection pool for managing MCP server connections.
+
+    Lifecycle:
+    - The pool is implemented as a singleton. The first instantiation creates the instance;
+      subsequent instantiations return the same object.
+    - Initialization sets up internal data structures, a background asyncio event loop,
+      and a dedicated thread for running asynchronous tasks.
+    - The event loop is started in a background daemon thread and is used to run async
+      operations (such as connecting and disconnecting).
+
+    Thread Safety:
+    - All access to shared state (connections, tool caches) is protected by a reentrant lock (`self._lock`).
+    - The singleton instance is protected by a class-level lock (`_instance_lock`) to ensure
+      only one instance is created.
+    - The background event loop is started and accessed in a thread-safe manner.
+
+    Cleanup Behavior:
+    - Idle connections are cleaned up passively: whenever `get_connections()` is called, the pool
+      checks for connections that have been idle longer than `MAX_IDLE_TIME` and disconnects them.
+    - Cleanup is rate-limited by `CLEANUP_INTERVAL` to avoid excessive checks.
+    - Disconnection is performed asynchronously in the background event loop.
+    - Tool caches are invalidated when a connection is removed.
+    - There is no explicit shutdown; background threads and event loops are daemonized and will
+      exit with the process.
+
+    Usage Notes:
+    - Users do not need to manage the pool directly; it is managed automatically as a singleton.
+    - Connections are created, reused, and cleaned up transparently.
+    - The pool is safe for concurrent use from multiple threads.
+    """
+
+    _instance: Optional['MCPConnectionPool'] = None
+    _instance_lock = threading.Lock()
+
+    # Timeouts and thresholds (configurable via environment variables)
+    # Default: 30s. Time to wait for a connection to be established. Increase if MCP servers are slow to respond.
+    CONNECT_TIMEOUT = float(os.environ.get("CLARIFAI_MCP_CONNECT_TIMEOUT", 30.0))
+    # Default: 60s. Maximum time to wait for a tool call to complete. Increase for long-running tools.
+    TOOL_CALL_TIMEOUT = float(os.environ.get("CLARIFAI_MCP_TOOL_CALL_TIMEOUT", 60.0))
+    # Default: 2min. Connections idle for more than this are verified before reuse.
+    VERIFY_IDLE_THRESHOLD = float(os.environ.get("CLARIFAI_MCP_VERIFY_IDLE_THRESHOLD", 60 * 2))
+    # Default: 15min. Connections idle for more than this are removed from the pool.
+    MAX_IDLE_TIME = float(os.environ.get("CLARIFAI_MCP_MAX_IDLE_TIME", 15 * 60))
+    # Default: 2min. Cleanup runs at most this often to remove idle connections.
+    CLEANUP_INTERVAL = float(os.environ.get("CLARIFAI_MCP_CLEANUP_INTERVAL", 2 * 60))
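+
+    # Illustrative sketch (not executed here): these class attributes are read once at
+    # import time, so environment overrides must be set before this module is imported:
+    #
+    #   import os
+    #   os.environ["CLARIFAI_MCP_CONNECT_TIMEOUT"] = "60"     # slow MCP servers
+    #   os.environ["CLARIFAI_MCP_TOOL_CALL_TIMEOUT"] = "300"  # long-running tools
+    #   from clarifai.runners.models import agentic_class     # now picks up the overrides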
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._instance_lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+                    cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+
+        self._connections: Dict[str, MCPConnection] = {}
+        self._lock = threading.RLock()
+
+        # Tool caches
+        self._tool_to_url: Dict[str, str] = {}
+        self._all_tools: Dict[str, dict] = {}
+
+        # Cleanup tracking
+        self._last_cleanup = 0.0
+
+        # Background event loop
+        self._loop: Optional[asyncio.AbstractEventLoop] = None
+        self._loop_thread: Optional[threading.Thread] = None
+        self._start_event_loop()
+
+        self._initialized = True
+
+    def _start_event_loop(self):
+        """Start background event loop."""
+        ready = threading.Event()
+
+        def run():
+            self._loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self._loop)
+            ready.set()
+            self._loop.run_forever()
+
+        self._loop_thread = threading.Thread(target=run, daemon=True, name="mcp_pool")
+        self._loop_thread.start()
+        if not ready.wait(timeout=5.0):
+            raise RuntimeError("Background event loop failed to start within 5 seconds")
+
+    def _run_async(self, coro, timeout: float = 30.0) -> Any:
+        """Run coroutine in background loop."""
+        # Double-checked locking pattern to prevent race condition
+        # when multiple threads try to restart a closed loop
+        if self._loop is None or self._loop.is_closed():
+            with self._lock:
+                # Check again after acquiring lock (another thread may have started it)
+                if self._loop is None or self._loop.is_closed():
+                    self._start_event_loop()
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+        return future.result(timeout=timeout)
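+
+    # Illustrative sketch (hypothetical coroutine): _run_async wraps exactly this
+    # sync-to-async hand-off; the calling thread blocks while the background loop works.
+    #
+    #   future = asyncio.run_coroutine_threadsafe(some_coroutine(), pool._loop)
+    #   result = future.result(timeout=10.0)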
+ """ + now = time.time() + + # Rate limit cleanup checks + if now - self._last_cleanup < self.CLEANUP_INTERVAL: + return + + self._last_cleanup = now + + # Find idle connections + with self._lock: + to_remove = [ + url + for url, conn in self._connections.items() + if now - conn.last_used > self.MAX_IDLE_TIME + ] + + # Remove them (outside lock to avoid deadlock during async close) + for url in to_remove: + self._disconnect(url) + + def _disconnect(self, url: str): + """Disconnect and remove a connection.""" + with self._lock: + conn = self._connections.pop(url, None) + + # Invalidate tool cache entries for this URL + if conn: + for tool_name in conn.tool_names: + self._tool_to_url.pop(tool_name, None) + self._all_tools.pop(tool_name, None) + + if conn: + try: + self._run_async(self._close_connection(conn), timeout=10.0) + logger.info(f"Disconnected idle connection from {url}") + except Exception as e: + logger.warning(f"Error disconnecting from {url}: {e}") + + async def _close_connection(self, conn: MCPConnection): + """Close a connection gracefully.""" + try: + if hasattr(conn.client, 'close'): + await asyncio.wait_for(conn.client.close(), timeout=5.0) + else: + await asyncio.wait_for(conn.client.__aexit__(None, None, None), timeout=5.0) + except Exception as e: + logger.warning(f"Error closing connection to {conn.url}: {e}") + + # ==================== Connection Management ==================== + + async def _create_connection(self, url: str) -> MCPConnection: + """Create new MCP connection.""" + try: + from fastmcp import Client + from fastmcp.client.transports import StreamableHttpTransport + except ImportError: + raise ImportError("fastmcp required: pip install fastmcp") + + transport = StreamableHttpTransport( + url=url, + headers={"Authorization": "Bearer " + os.environ.get("CLARIFAI_PAT", "")}, + ) + + client = Client(transport) + await asyncio.wait_for(client.__aenter__(), timeout=self.CONNECT_TIMEOUT) + tools = await asyncio.wait_for(client.list_tools(), timeout=10.0) + + return MCPConnection( + client=client, + tools=tools, + tool_names={t.name for t in tools}, + url=url, + ) + + async def _verify_connection(self, conn: MCPConnection) -> bool: + """Check if connection is still valid.""" + try: + await asyncio.wait_for(conn.client.list_tools(), timeout=5.0) + return True + except Exception: + return False + + def _needs_verification(self, conn: MCPConnection) -> bool: + """Check if connection should be verified.""" + return time.time() - conn.last_used > self.VERIFY_IDLE_THRESHOLD + + def _update_tool_cache(self, conn: MCPConnection): + """Cache tool info from connection.""" + for tool in conn.tools: + self._tool_to_url[tool.name] = conn.url + self._all_tools[tool.name] = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.inputSchema, + }, + } + + def get_connections(self, urls: List[str]) -> Dict[str, MCPConnection]: + """Get connections for URLs, with passive cleanup.""" + # Passive cleanup of idle connections + self._maybe_cleanup_idle() + + result = {} + to_verify = [] + to_create = [] + + # Categorize URLs + with self._lock: + for url in urls: + if url in self._connections: + conn = self._connections[url] + if self._needs_verification(conn): + to_verify.append(url) + else: + conn.touch() + result[url] = conn + else: + to_create.append(url) + + # Verify stale connections in parallel + if to_verify: + + async def verify_all(): + tasks = { + url: self._verify_connection(self._connections[url]) + for url 
+
+    def get_connections(self, urls: List[str]) -> Dict[str, MCPConnection]:
+        """Get connections for URLs, with passive cleanup."""
+        # Passive cleanup of idle connections
+        self._maybe_cleanup_idle()
+
+        result = {}
+        to_verify = []
+        to_create = []
+
+        # Categorize URLs
+        with self._lock:
+            for url in urls:
+                if url in self._connections:
+                    conn = self._connections[url]
+                    if self._needs_verification(conn):
+                        to_verify.append(url)
+                    else:
+                        conn.touch()
+                        result[url] = conn
+                else:
+                    to_create.append(url)
+
+        # Verify stale connections in parallel
+        if to_verify:
+
+            async def verify_all():
+                tasks = {
+                    url: self._verify_connection(self._connections[url])
+                    for url in to_verify
+                    if url in self._connections
+                }
+                results = await asyncio.gather(*tasks.values(), return_exceptions=True)
+                return dict(zip(tasks.keys(), results))
+
+            try:
+                verify_results = self._run_async(verify_all(), timeout=15.0)
+                with self._lock:
+                    for url, is_valid in verify_results.items():
+                        if is_valid is True:
+                            conn = self._connections[url]
+                            conn.touch()
+                            result[url] = conn
+                        else:
+                            # Invalid - remove and recreate
+                            self._connections.pop(url, None)
+                            to_create.append(url)
+            except Exception as e:
+                logger.error(f"Verification error: {e}")
+                to_create.extend(to_verify)
+
+        # Create new connections in parallel
+        if to_create:
+
+            async def create_all():
+                tasks = {url: self._create_connection(url) for url in to_create}
+                results = await asyncio.gather(*tasks.values(), return_exceptions=True)
+                return dict(zip(tasks.keys(), results))
+
+            try:
+                create_results = self._run_async(create_all(), timeout=self.CONNECT_TIMEOUT + 5)
+                with self._lock:
+                    for url, conn_or_error in create_results.items():
+                        if isinstance(conn_or_error, Exception):
+                            logger.error(f"Failed to connect to {url}: {conn_or_error}")
+                        else:
+                            self._connections[url] = conn_or_error
+                            self._update_tool_cache(conn_or_error)
+                            result[url] = conn_or_error
+                            logger.info(f"✓ Connected to {url} ({len(conn_or_error.tools)} tools)")
+            except Exception as e:
+                logger.error(f"Connection error: {e}")
+
+        return result
+
+    def get_tools_and_mapping(
+        self, urls: List[str]
+    ) -> Tuple[List[dict], Dict[str, MCPConnection], Dict[str, str]]:
+        """Get tools, connections, and mapping."""
+        connections = self.get_connections(urls)
+
+        tools = []
+        tool_to_server = {}
+        seen = set()
+
+        for url, conn in connections.items():
+            for tool in conn.tools:
+                if tool.name not in seen:
+                    seen.add(tool.name)
+                    tools.append(
+                        self._all_tools.get(tool.name)
+                        or {
+                            "type": "function",
+                            "function": {
+                                "name": tool.name,
+                                "description": tool.description or "",
+                                "parameters": tool.inputSchema,
+                            },
+                        }
+                    )
+                    tool_to_server[tool.name] = url
+
+        logger.info(f"Loaded {len(tools)} tools from {len(connections)} servers")
+        return tools, connections, tool_to_server
+
+    # ==================== Tool Execution ====================
+
+    async def call_tool_async(
+        self,
+        tool_name: str,
+        arguments: Dict[str, Any],
+        connections: Dict[str, MCPConnection],
+        tool_to_server: Dict[str, str],
+    ) -> Any:
+        """Call a tool asynchronously."""
+        logger.info(f"Calling tool {tool_name}")
+        url = tool_to_server.get(tool_name) or self._tool_to_url.get(tool_name)
+        if not url or url not in connections:
+            raise ValueError(f"Tool '{tool_name}' not found")
+
+        conn = connections[url]
+        result = await asyncio.wait_for(
+            conn.client.call_tool(tool_name, arguments=arguments), timeout=self.TOOL_CALL_TIMEOUT
+        )
+        conn.touch()
+        return result
+
+    def call_tool(
+        self,
+        tool_name: str,
+        arguments: Dict[str, Any],
+        connections: Dict[str, MCPConnection],
+        tool_to_server: Dict[str, str],
+    ) -> Any:
+        """Call a tool synchronously."""
+        logger.info(f"Calling tool {tool_name}")
+        return self._run_async(
+            self.call_tool_async(tool_name, arguments, connections, tool_to_server),
+            timeout=self.TOOL_CALL_TIMEOUT + 5,
+        )
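+
+    # Illustrative sketch (hypothetical URL and tool): typical synchronous pool usage.
+    #
+    #   pool = MCPConnectionPool()
+    #   tools, connections, tool_to_server = pool.get_tools_and_mapping(
+    #       ["https://example.com/mcp"]
+    #   )
+    #   result = pool.call_tool("get_weather", {"city": "Oslo"}, connections, tool_to_server)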
+
+    async def call_tools_batch_async(
+        self,
+        calls: List[Tuple[str, str, Dict[str, Any]]],  # [(id, name, args), ...]
+        connections: Dict[str, MCPConnection],
+        tool_to_server: Dict[str, str],
+    ) -> List[Tuple[str, Optional[Any], Optional[str]]]:
+        """Call multiple tools in parallel. Returns [(id, result, error), ...]."""
+
+        async def call_one(call_id: str, name: str, args: Dict):
+            try:
+                result = await self.call_tool_async(name, args, connections, tool_to_server)
+                return (call_id, result, None)
+            except Exception as e:
+                return (call_id, None, str(e))
+
+        tasks = [call_one(cid, name, args) for cid, name, args in calls]
+        return await asyncio.gather(*tasks)
+
+    def call_tools_batch(
+        self,
+        calls: List[Tuple[str, str, Dict[str, Any]]],
+        connections: Dict[str, MCPConnection],
+        tool_to_server: Dict[str, str],
+    ) -> List[Tuple[str, Optional[Any], Optional[str]]]:
+        """Call multiple tools in parallel (sync)."""
+        return self._run_async(
+            self.call_tools_batch_async(calls, connections, tool_to_server),
+            timeout=self.TOOL_CALL_TIMEOUT + 10,
+        )
+
+
+class AgenticModelClass(OpenAIModelClass):
+    """Base class for wrapping OpenAI-compatible servers with MCP (Model Context Protocol) support.
+
+    This class extends OpenAIModelClass to enable agentic behavior by integrating LLMs with MCP servers.
+    It handles tool discovery, execution, and iterative tool calling for both the chat completions and
+    responses endpoints, supporting both streaming and non-streaming modes.
+
+    MCP connections are maintained persistently across requests using a connection pool, which
+    significantly improves performance by avoiding reconnection overhead.
+
+    To use this class, create a subclass and set the following class attributes:
+    - client: The OpenAI-compatible client instance
+    - model: The name of the model to use with the client
+
+    Example:
+        class MyAgenticModel(AgenticModelClass):
+            client = OpenAI(api_key="your-key")
+            model = "gpt-4"
+    """
+
+    _pool: Optional[MCPConnectionPool] = None
+    _pool_lock = threading.Lock()
+
+    @classmethod
+    def get_pool(cls) -> MCPConnectionPool:
+        """Get shared connection pool."""
+        if cls._pool is None:
+            with cls._pool_lock:
+                if cls._pool is None:
+                    cls._pool = MCPConnectionPool()
+        return cls._pool
+
+    # === Token Tracking ===
+
+    def _init_tokens(self):
+        if not hasattr(self._thread_local, 'tokens'):
+            self._thread_local.tokens = {'prompt': 0, 'completion': 0}
+
+    def _add_tokens(self, resp):
+        """Accumulate tokens from response."""
+        usage = getattr(resp, 'usage', None) or (
+            getattr(resp.response, 'usage', None) if hasattr(resp, 'response') else None
+        )
+        if usage:
+            self._init_tokens()
+            self._thread_local.tokens['prompt'] += (
+                getattr(usage, 'prompt_tokens', 0) or getattr(usage, 'input_tokens', 0) or 0
+            )
+            self._thread_local.tokens['completion'] += (
+                getattr(usage, 'completion_tokens', 0) or getattr(usage, 'output_tokens', 0) or 0
+            )
+
+    def _finalize_tokens(self):
+        """Send accumulated tokens to output context."""
+        if hasattr(self._thread_local, 'tokens'):
+            t = self._thread_local.tokens
+            if t['prompt'] > 0 or t['completion'] > 0:
+                self.set_output_context(
+                    prompt_tokens=t['prompt'], completion_tokens=t['completion']
+                )
+            del self._thread_local.tokens
+
+    def _set_usage(self, resp):
+        self._add_tokens(resp)
+
+    # === Tool Format Conversion ===
+
+    def _to_response_api_tools(self, tools: List[dict]) -> List[dict]:
+        """Convert chat completion tools to response API format."""
+        result = []
+        for t in tools:
+            if "function" in t:
+                f = t["function"]
+                result.append(
+                    {
+                        "type": "function",
+                        "name": f.get("name"),
+                        "description": f.get("description", ""),
+                        "parameters": f.get("parameters", {}),
+                    }
+                )
+            elif "name" in t:
+                result.append(t)
+        return result
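+
+    # Illustrative sketch: the same hypothetical tool in both wire formats. The chat
+    # completions format nests the schema under "function"; the responses API flattens it.
+    #
+    #   chat completions: {"type": "function",
+    #                      "function": {"name": "get_weather", "parameters": {...}}}
+    #   responses API:    {"type": "function", "name": "get_weather", "parameters": {...}}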
+
+    def _to_dict(self, obj) -> dict:
+        """Convert object to dict."""
+        if isinstance(obj, dict):
+            return obj
+        for attr in ('model_dump', 'dict'):
+            if hasattr(obj, attr):
+                return getattr(obj, attr)()
+        return getattr(obj, '__dict__', {})
+
+    # === Tool Call Parsing ===
+
+    def _parse_chat_tool_calls(self, tool_calls) -> List[Tuple[str, str, Dict]]:
+        """Parse chat completion tool calls into (id, name, args) tuples."""
+        result = []
+        for tc in tool_calls:
+            if hasattr(tc, 'function'):
+                try:
+                    args = json.loads(tc.function.arguments)
+                except json.JSONDecodeError:
+                    logger.warning(
+                        f"Malformed tool arguments for tool '{getattr(tc.function, 'name', None)}': {tc.function.arguments!r}"
+                    )
+                    args = {}
+                result.append((tc.id, tc.function.name, args))
+            else:
+                try:
+                    args = json.loads(tc['function']['arguments'])
+                except json.JSONDecodeError:
+                    logger.warning(
+                        f"Malformed tool arguments for tool '{tc['function'].get('name', None)}': {tc['function']['arguments']!r}"
+                    )
+                    args = {}
+                result.append((tc['id'], tc['function']['name'], args))
+        return result
+
+    def _parse_response_tool_calls(self, items: List[dict]) -> List[Tuple[str, str, Dict]]:
+        """Parse response API tool calls into (call_id, name, args) tuples."""
+        result = []
+        for item in items:
+            d = self._to_dict(item)
+            if d.get('type') in ('function_tool_call', 'function_call', 'function', 'tool_call'):
+                status = d.get('status', '')
+                if status in ('pending', 'in_progress', '') or d.get('output') is None:
+                    call_id = d.get('call_id') or d.get('id')
+                    name = d.get('name')
+                    args_str = d.get('arguments', '{}')
+                    if call_id and name:
+                        try:
+                            args = json.loads(args_str) if isinstance(args_str, str) else args_str
+                        except json.JSONDecodeError:
+                            args = {}
+                        result.append((call_id, name, args))
+        return result
+
+    # === Tool Execution ===
+
+    def _execute_chat_tools(
+        self,
+        tool_calls,
+        connections: Dict[str, MCPConnection],
+        messages: List[dict],
+        tool_to_server: Dict[str, str],
+    ):
+        """Execute chat completion tool calls and append results to messages."""
+        pool = self.get_pool()
+        parsed = self._parse_chat_tool_calls(tool_calls)
+        results = pool.call_tools_batch(parsed, connections, tool_to_server)
+
+        for call_id, result, error in results:
+            if error:
+                content = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and hasattr(result.content[0], 'text')
+                and result.content[0].text
+            ):
+                content = result.content[0].text
+            elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text:
+                content = result[0].text
+            else:
+                content = None
+            messages.append({"role": "tool", "tool_call_id": call_id, "content": content})
+
+    async def _execute_chat_tools_async(
+        self,
+        tool_calls,
+        connections: Dict[str, MCPConnection],
+        messages: List[dict],
+        tool_to_server: Dict[str, str],
+    ):
+        """Async version of chat tool execution."""
+        pool = self.get_pool()
+        parsed = self._parse_chat_tool_calls(tool_calls)
+        results = await pool.call_tools_batch_async(parsed, connections, tool_to_server)
+
+        for call_id, result, error in results:
+            if error:
+                content = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and hasattr(result.content[0], 'text')
+                and result.content[0].text
+            ):
+                content = result.content[0].text
+            elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text:
+                content = result[0].text
+            else:
+                content = None
+            messages.append({"role": "tool", "tool_call_id": call_id, "content": content})
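+
+    # Illustrative sketch (hypothetical values): each executed tool call lands in the
+    # message list as a standard chat-completions tool message:
+    #
+    #   {"role": "tool", "tool_call_id": "call_abc123", "content": "72°F and sunny"}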
+
+    def _execute_response_tools(
+        self,
+        tool_calls: List[Tuple[str, str, Dict]],
+        connections: Dict[str, MCPConnection],
+        input_items: List,
+        tool_to_server: Dict[str, str],
+    ):
+        """Execute response API tool calls and append results to input_items."""
+        pool = self.get_pool()
+        results = pool.call_tools_batch(tool_calls, connections, tool_to_server)
+
+        for call_id, result, error in results:
+            if error:
+                output = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and hasattr(result.content[0], 'text')
+                and result.content[0].text
+            ):
+                output = result.content[0].text
+            elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text:
+                output = result[0].text
+            else:
+                output = None
+            input_items.append(
+                {"type": "function_call_output", "call_id": call_id, "output": output}
+            )
+
+    async def _execute_response_tools_async(
+        self,
+        tool_calls: List[Tuple[str, str, Dict]],
+        connections: Dict[str, MCPConnection],
+        input_items: List,
+        tool_to_server: Dict[str, str],
+    ):
+        """Async version of response API tool execution."""
+        pool = self.get_pool()
+        results = await pool.call_tools_batch_async(tool_calls, connections, tool_to_server)
+
+        for call_id, result, error in results:
+            if error:
+                output = f"Error: {error}"
+            elif (
+                hasattr(result, 'content')
+                and len(result.content) > 0
+                and hasattr(result.content[0], 'text')
+                and result.content[0].text
+            ):
+                output = result.content[0].text
+            elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text:
+                output = result[0].text
+            else:
+                output = None
+            input_items.append(
+                {"type": "function_call_output", "call_id": call_id, "output": output}
+            )
+
+    # === Response Output Processing ===
+
+    def _convert_output_to_input(self, output: List) -> List[dict]:
+        """Convert response API output items to input items."""
+        result = []
+        for item in output:
+            d = self._to_dict(item)
+            t = d.get('type')
+            if t in ('message', 'reasoning'):
+                result.append(d)
+            elif t in ('function_tool_call', 'function_call', 'function', 'tool_call'):
+                if d.get('output') is not None or d.get('status') in ('completed', 'done'):
+                    result.append(d)
+        return result
+
+    # === Request Handlers ===
+
+    def _handle_chat_completions(
+        self, request_data: Dict, mcp_servers=None, connections=None, tools=None
+    ):
+        if mcp_servers and tools:
+            request_data = {
+                **request_data,
+                "tools": tools,
+                "tool_choice": request_data.get("tool_choice", "auto"),
+            }
+        return super()._handle_chat_completions(request_data)
+
+    def _handle_responses(
+        self, request_data: Dict, mcp_servers=None, connections=None, tools=None
+    ):
+        if mcp_servers and tools:
+            request_data = {
+                **request_data,
+                "tools": self._to_response_api_tools(tools),
+                "tool_choice": request_data.get("tool_choice", "auto"),
+            }
+        return super()._handle_responses(request_data)
+
+    def _route_request(
+        self, endpoint: str, request_data: Dict, mcp_servers=None, connections=None, tools=None
+    ):
+        if endpoint == self.ENDPOINT_CHAT_COMPLETIONS:
+            return self._handle_chat_completions(request_data, mcp_servers, connections, tools)
+        if endpoint == self.ENDPOINT_RESPONSES:
+            return self._handle_responses(request_data, mcp_servers, connections, tools)
+        return super()._route_request(endpoint, request_data)
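+
+    # Illustrative sketch (hypothetical payload): with MCP servers attached, the handlers
+    # above forward an augmented copy of the request to the parent class, e.g.
+    #
+    #   {"messages": [...], "tools": [<discovered MCP tools>], "tool_choice": "auto"}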
+ accumulated[idx]["function"]["name"] = delta.function.name + if delta.function.arguments: + accumulated[idx]["function"]["arguments"] += delta.function.arguments + + def _finalize_tool_calls(self, accumulated: dict) -> List[dict]: + """Convert accumulated tool calls to list.""" + return [ + {"id": v["id"], "type": "function", "function": v["function"]} + for v in (accumulated[k] for k in sorted(accumulated)) + ] + + def _create_stream_request(self, messages, tools, max_tokens, temperature, top_p): + """Create streaming chat completion request.""" + kwargs = { + "model": self.model, + "messages": messages, + "max_completion_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "stream": True, + "stream_options": {"include_usage": True}, + } + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = "auto" + return self.client.chat.completions.create(**kwargs) + + def _async_to_sync_generator(self, async_gen_fn): + """Bridge async generator to sync generator.""" + pool = self.get_pool() + loop = pool._loop + queue = asyncio.Queue() + error_holder = [None] + + async def producer(): + try: + async for item in async_gen_fn(): + await queue.put(item) + except Exception as e: + error_holder[0] = e + finally: + await queue.put(None) + + asyncio.run_coroutine_threadsafe(producer(), loop) + + while True: + future = asyncio.run_coroutine_threadsafe(queue.get(), loop) + item = future.result(timeout=120.0) + if item is None: + if error_holder[0]: + raise error_holder[0] + break + yield item + + # === Streaming with MCP === + + async def _stream_chat_with_tools( + self, messages, tools, connections, tool_to_server, max_tokens, temperature, top_p + ): + """ + Stream chat completions with MCP tool support, recursively handling tool calls. + This method streams chat completion chunks, accumulating any tool calls generated by the model. + If tool calls are present after streaming, it executes those tools and recursively continues + streaming with the updated messages (including tool call results). The recursion terminates + when no further tool calls are generated in the streamed response. + Args: + messages: The list of chat messages so far. + tools: The list of available tools. + connections: MCP tool connections. + tool_to_server: Mapping of tool names to server URLs. + max_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling parameter. + Yields: + JSON-serialized chat completion chunks. 
+ """ + accumulated_tools = {} + assistant_content = "" + + for chunk in self._create_stream_request(messages, tools, max_tokens, temperature, top_p): + self._set_usage(chunk) + yield chunk.model_dump_json() + + if chunk.choices: + delta = chunk.choices[0].delta + if delta.tool_calls: + for tc in delta.tool_calls: + self._accumulate_tool_delta(tc, accumulated_tools) + if delta.content: + assistant_content += delta.content + + if accumulated_tools: + tool_calls = self._finalize_tool_calls(accumulated_tools) + messages.append( + { + "role": "assistant", + "content": assistant_content or None, + "tool_calls": tool_calls, + } + ) + await self._execute_chat_tools_async(tool_calls, connections, messages, tool_to_server) + + async for chunk in self._stream_chat_with_tools( + messages, tools, connections, tool_to_server, max_tokens, temperature, top_p + ): + yield chunk + + async def _stream_responses_with_tools(self, request_data, tools, connections, tool_to_server): + """ + Streams responses for the API with MCP (Model Context Protocol) tool support. + This method processes the incoming request data, which may include user messages or input items, + and streams back responses that may involve multiple event types, such as user messages, assistant + responses, and tool call events. It supports both single string input and a list of message objects. + Event Handling Flow: + - Parses the input data into a standardized list of message items. + - Prepares response arguments, including tool definitions and tool choice if tools are provided. + - Accumulates output chunks as they are generated, yielding each chunk as a JSON-encoded string. + - Handles tool call events by invoking the appropriate tools via MCP connections, and streams + the results back as part of the response. + - May recursively invoke itself to handle follow-up tool calls or multi-turn interactions. + Args: + request_data (dict): The incoming request payload, including input messages. + tools (list): List of tool definitions available for invocation. + connections (dict): MCP connections for tool execution. + tool_to_server (dict): Mapping from tool names to server endpoints. + Yields: + str: JSON-encoded response chunks, streamed as they become available. 
+ """ + input_data = request_data.get("input", "") + input_items = ( + [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input_data}], + } + ] + if isinstance(input_data, str) + else (input_data if isinstance(input_data, list) else []) + ) + + response_args = {**request_data, "model": self.model} + if tools: + response_args["tools"] = self._to_response_api_tools(tools) + response_args["tool_choice"] = response_args.get("tool_choice", "auto") + + accumulated_output = [] + tool_calls_by_id = {} + msg_index_map = {} + + for chunk in self.client.responses.create(**response_args): + self._set_usage(chunk) + chunk_type = getattr(chunk, 'type', '') or chunk.__class__.__name__ + + # Track message indices for filtering + if chunk_type in ( + 'response.output_item.added', + 'ResponseOutputItemAddedEvent', + ) and hasattr(chunk, 'item'): + item_dict = self._to_dict(chunk.item) + if item_dict.get('type') == 'message' and hasattr(chunk, 'output_index'): + msg_index_map[chunk.output_index] = len(msg_index_map) + elif item_dict.get('type') in ( + 'function_tool_call', + 'function_call', + 'function', + 'tool_call', + ): + item_id = item_dict.get('id') or item_dict.get('call_id') + if item_id: + tool_calls_by_id[item_id] = { + 'id': item_id, + 'call_id': item_dict.get('call_id'), + 'type': item_dict.get('type'), + 'name': item_dict.get('name', ''), + 'arguments': item_dict.get('arguments', ''), + 'status': 'in_progress', + } + + # Accumulate arguments + elif chunk_type in ( + 'response.function_call_arguments.delta', + 'ResponseFunctionCallArgumentsDeltaEvent', + ): + item_id = getattr(chunk, 'item_id', None) + if item_id and item_id in tool_calls_by_id: + tool_calls_by_id[item_id]['arguments'] += getattr(chunk, 'delta', '') + + elif chunk_type in ( + 'response.function_call_arguments.done', + 'ResponseFunctionCallArgumentsDoneEvent', + ): + item_id = getattr(chunk, 'item_id', None) + if item_id and item_id in tool_calls_by_id: + tool_calls_by_id[item_id]['arguments'] = getattr(chunk, 'arguments', '') + + # Mark tool call complete + elif chunk_type in ( + 'response.output_item.done', + 'ResponseOutputItemDoneEvent', + ) and hasattr(chunk, 'item'): + item_dict = self._to_dict(chunk.item) + item_type = item_dict.get('type') + if item_type in ('function_tool_call', 'function_call', 'function', 'tool_call'): + item_id = item_dict.get('id') + if item_id and item_id in tool_calls_by_id: + tool_calls_by_id[item_id]['status'] = 'completed' + if 'call_id' in item_dict: + tool_calls_by_id[item_id]['call_id'] = item_dict['call_id'] + accumulated_output.append(tool_calls_by_id[item_id]) + else: + accumulated_output.append(item_dict) + + # Handle completed response - filter to messages only + elif chunk_type in ('response.completed', 'ResponseCompletedEvent') and hasattr( + chunk, 'response' + ): + resp = chunk.response + if hasattr(resp, 'output') and resp.output: + filtered = [ + self._to_dict(i) + for i in resp.output + if self._to_dict(i).get('type') == 'message' + ] + resp_dict = self._to_dict(resp) + resp_dict['output'] = filtered + yield json.dumps( + { + 'type': 'response.completed', + 'sequence_number': getattr(chunk, 'sequence_number', None), + 'response': resp_dict, + } + ) + continue + + # Yield message-related chunks with remapped indices + should_yield = True + if hasattr(chunk, 'output_index'): + if chunk.output_index not in msg_index_map: + should_yield = False + else: + chunk_dict = self._to_dict(chunk) + chunk_dict['output_index'] = 
+
+            # Yield message-related chunks with remapped indices
+            should_yield = True
+            if hasattr(chunk, 'output_index'):
+                if chunk.output_index not in msg_index_map:
+                    should_yield = False
+                else:
+                    chunk_dict = self._to_dict(chunk)
+                    chunk_dict['output_index'] = msg_index_map[chunk.output_index]
+                    yield json.dumps(chunk_dict)
+                    continue
+
+            if should_yield and chunk_type not in (
+                'response.function_call_arguments.delta',
+                'ResponseFunctionCallArgumentsDeltaEvent',
+                'response.function_call_arguments.done',
+                'ResponseFunctionCallArgumentsDoneEvent',
+            ):
+                item = getattr(chunk, 'item', None)
+                if item:
+                    if self._to_dict(item).get('type') not in (
+                        'function_tool_call',
+                        'function_call',
+                        'function',
+                        'tool_call',
+                    ):
+                        yield chunk.model_dump_json()
+                else:
+                    yield chunk.model_dump_json()
+
+        # Add any remaining tool calls
+        for tc in tool_calls_by_id.values():
+            if tc.get('name') and tc['id'] not in {
+                self._to_dict(o).get('id') for o in accumulated_output
+            }:
+                accumulated_output.append(tc)
+
+        # Execute tool calls if any
+        tool_calls = self._parse_response_tool_calls(accumulated_output)
+        if tool_calls:
+            input_items.extend(self._convert_output_to_input(accumulated_output))
+            await self._execute_response_tools_async(
+                tool_calls, connections, input_items, tool_to_server
+            )
+            request_data['input'] = input_items
+
+            async for chunk in self._stream_responses_with_tools(
+                request_data, tools, connections, tool_to_server
+            ):
+                yield chunk
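+
+    # Illustrative sketch (hypothetical payload): what the transport methods below receive
+    # once a caller attaches MCP servers. "openai_endpoint" selects the route; "mcp_servers"
+    # is popped before the request reaches the OpenAI client.
+    #
+    #   {
+    #       "openai_endpoint": ENDPOINT_CHAT_COMPLETIONS,
+    #       "mcp_servers": ["https://example.com/mcp"],
+    #       "messages": [{"role": "user", "content": "What's the weather in Oslo?"}],
+    #   }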
failed", + "details": str(e), + } + ) + + @ModelClass.method + def openai_stream_transport(self, msg: str) -> Iterator[str]: + """Handle streaming OpenAI requests.""" + try: + data = from_json(msg) + data = self._update_old_fields(data) + mcp_servers = data.pop("mcp_servers", None) + endpoint = data.pop("openai_endpoint", self.DEFAULT_ENDPOINT) + + if endpoint not in (self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES): + raise ValueError("Streaming only for chat completions and responses") + + if mcp_servers and data.get("tools") is None: + pool = self.get_pool() + tools, connections, tool_to_server = pool.get_tools_and_mapping(mcp_servers) + + if endpoint == self.ENDPOINT_CHAT_COMPLETIONS: + yield from self._async_to_sync_generator( + lambda: self._stream_chat_with_tools( + data.get("messages", []), + tools, + connections, + tool_to_server, + data.get("max_completion_tokens", 4096), + data.get("temperature", 1.0), + data.get("top_p", 1.0), + ) + ) + else: + yield from self._async_to_sync_generator( + lambda: self._stream_responses_with_tools( + data, tools, connections, tool_to_server + ) + ) + + self._finalize_tokens() + return + + # Non-MCP streaming + if endpoint == self.ENDPOINT_RESPONSES: + for chunk in self.client.responses.create(**{**data, "model": self.model}): + self._set_usage(chunk) + yield chunk.model_dump_json() + else: + for chunk in self.client.chat.completions.create( + **self._create_completion_args(data) + ): + self._set_usage(chunk) + yield chunk.model_dump_json() + + self._finalize_tokens() + + except Exception as e: + logger.exception(e) + yield to_json( + { + "code": status_code_pb2.MODEL_PREDICTION_FAILED, + "description": "Model prediction failed", + "details": str(e), + } + ) diff --git a/tests/runners/test_agentic_class.py b/tests/runners/test_agentic_class.py new file mode 100644 index 00000000..6ba7a554 --- /dev/null +++ b/tests/runners/test_agentic_class.py @@ -0,0 +1,1210 @@ +"""Comprehensive tests for AgenticModelClass and MCPConnectionPool integration.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic_core import to_json + +from clarifai.runners.models.agentic_class import ( + AgenticModelClass, + MCPConnection, + MCPConnectionPool, +) +from clarifai.runners.models.dummy_openai_model import MockOpenAIClient + + +class DummyAgenticModel(AgenticModelClass): + """Dummy agentic model for testing.""" + + client = MockOpenAIClient() + model = "dummy-agentic-model" + + +class TestAgenticModelClass: + """Tests for AgenticModelClass functionality.""" + + @pytest.fixture(autouse=True) + def reset_pool(self): + """Reset the singleton instance before each test.""" + AgenticModelClass._pool = None + MCPConnectionPool._instance = None + yield + AgenticModelClass._pool = None + MCPConnectionPool._instance = None + + @pytest.fixture + def model(self): + """Create a test model instance.""" + return DummyAgenticModel() + + @pytest.fixture + def mock_pool(self): + """Create a mock connection pool.""" + pool = MagicMock(spec=MCPConnectionPool) + pool.get_tools_and_mapping = MagicMock(return_value=([], {}, {})) + pool.call_tools_batch = MagicMock(return_value=[]) + pool.call_tools_batch_async = AsyncMock(return_value=[]) + pool._loop = asyncio.new_event_loop() + return pool + + # === Token Tracking Tests === + + def test_init_tokens(self, model): + """Test token initialization.""" + model._init_tokens() + assert hasattr(model._thread_local, 'tokens') + assert model._thread_local.tokens == {'prompt': 
+
+
+class TestAgenticModelClass:
+    """Tests for AgenticModelClass functionality."""
+
+    @pytest.fixture(autouse=True)
+    def reset_pool(self):
+        """Reset the singleton instance before each test."""
+        AgenticModelClass._pool = None
+        MCPConnectionPool._instance = None
+        yield
+        AgenticModelClass._pool = None
+        MCPConnectionPool._instance = None
+
+    @pytest.fixture
+    def model(self):
+        """Create a test model instance."""
+        return DummyAgenticModel()
+
+    @pytest.fixture
+    def mock_pool(self):
+        """Create a mock connection pool."""
+        pool = MagicMock(spec=MCPConnectionPool)
+        pool.get_tools_and_mapping = MagicMock(return_value=([], {}, {}))
+        pool.call_tools_batch = MagicMock(return_value=[])
+        pool.call_tools_batch_async = AsyncMock(return_value=[])
+        pool._loop = asyncio.new_event_loop()
+        return pool
+
+    # === Token Tracking Tests ===
+
+    def test_init_tokens(self, model):
+        """Test token initialization."""
+        model._init_tokens()
+        assert hasattr(model._thread_local, 'tokens')
+        assert model._thread_local.tokens == {'prompt': 0, 'completion': 0}
+
+    def test_add_tokens_from_usage(self, model):
+        """Test adding tokens from response with usage attribute."""
+        mock_response = MagicMock()
+        mock_usage = MagicMock()
+        mock_usage.prompt_tokens = 10
+        mock_usage.completion_tokens = 20
+        mock_response.usage = mock_usage
+
+        model._add_tokens(mock_response)
+        assert model._thread_local.tokens['prompt'] == 10
+        assert model._thread_local.tokens['completion'] == 20
+
+    def test_add_tokens_from_response_usage(self, model):
+        """Test adding tokens from response.response.usage."""
+        mock_response = MagicMock()
+        mock_response.usage = None
+        mock_inner_response = MagicMock()
+
+        # Use a simple object instead of MagicMock to avoid mock attribute issues
+        class MockUsage:
+            input_tokens = 15
+            output_tokens = 25
+
+        mock_usage = MockUsage()
+        mock_inner_response.usage = mock_usage
+        mock_response.response = mock_inner_response
+
+        model._add_tokens(mock_response)
+        assert model._thread_local.tokens['prompt'] == 15
+        assert model._thread_local.tokens['completion'] == 25
+
+    def test_add_tokens_accumulates(self, model):
+        """Test that tokens accumulate across multiple calls."""
+        mock_response1 = MagicMock()
+        mock_usage1 = MagicMock()
+        mock_usage1.prompt_tokens = 10
+        mock_usage1.completion_tokens = 20
+        mock_response1.usage = mock_usage1
+
+        mock_response2 = MagicMock()
+        mock_usage2 = MagicMock()
+        mock_usage2.prompt_tokens = 5
+        mock_usage2.completion_tokens = 10
+        mock_response2.usage = mock_usage2
+
+        model._add_tokens(mock_response1)
+        model._add_tokens(mock_response2)
+
+        assert model._thread_local.tokens['prompt'] == 15
+        assert model._thread_local.tokens['completion'] == 30
+
+    def test_finalize_tokens(self, model):
+        """Test finalizing tokens to output context."""
+        model._init_tokens()
+        model._thread_local.tokens['prompt'] = 10
+        model._thread_local.tokens['completion'] = 20
+
+        with patch.object(model, 'set_output_context') as mock_set:
+            model._finalize_tokens()
+            mock_set.assert_called_once_with(prompt_tokens=10, completion_tokens=20)
+            assert not hasattr(model._thread_local, 'tokens')
+
+    def test_finalize_tokens_no_tokens(self, model):
+        """Test finalizing when no tokens were tracked."""
+        with patch.object(model, 'set_output_context') as mock_set:
+            model._finalize_tokens()
+            mock_set.assert_not_called()
+
+    # === Tool Format Conversion Tests ===
+
+    def test_to_response_api_tools_with_function(self, model):
+        """Test converting chat completion tools to response API format."""
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "test_tool",
+                    "description": "A test tool",
+                    "parameters": {"type": "object", "properties": {}},
+                },
+            }
+        ]
+        result = model._to_response_api_tools(tools)
+        assert len(result) == 1
+        assert result[0]["type"] == "function"
+        assert result[0]["name"] == "test_tool"
+        assert result[0]["description"] == "A test tool"
+
+    def test_to_response_api_tools_with_name(self, model):
+        """Test converting tools that already have name field."""
+        tools = [
+            {
+                "type": "function",
+                "name": "test_tool",
+                "description": "A test tool",
+                "parameters": {},
+            }
+        ]
+        result = model._to_response_api_tools(tools)
+        assert len(result) == 1
+        assert result[0]["name"] == "test_tool"
+
+    def test_to_dict_from_dict(self, model):
+        """Test _to_dict with dict input."""
+        obj = {"key": "value"}
+        result = model._to_dict(obj)
+        assert result == obj
+
+    def test_to_dict_with_model_dump(self, model):
+        """Test _to_dict with object that has model_dump method."""
+        obj = MagicMock()
+        obj.model_dump.return_value = {"key": "value"}
+        result = model._to_dict(obj)
+        assert result == {"key": "value"}
+
+    def test_to_dict_with_dict_method(self, model):
+        """Test _to_dict with object that has dict method."""
+
+        # Use a simple class to avoid MagicMock issues
+        class TestObj:
+            def dict(self):
+                return {"key": "value"}
+
+        obj = TestObj()
+        result = model._to_dict(obj)
+        assert result == {"key": "value"}
+
+    def test_to_dict_with_dict_attribute(self, model):
+        """Test _to_dict with object that has __dict__ attribute."""
+
+        class TestObj:
+            def __init__(self):
+                self.key = "value"
+
+        obj = TestObj()
+        result = model._to_dict(obj)
+        assert result == {"key": "value"}
+
+    # === Tool Call Parsing Tests ===
+
+    def test_parse_chat_tool_calls_with_function_attribute(self, model):
+        """Test parsing chat completion tool calls with function attribute."""
+        mock_tool_call = MagicMock()
+        mock_tool_call.id = "call_123"
+        mock_tool_call.function.name = "test_tool"
+        mock_tool_call.function.arguments = '{"arg": "value"}'
+
+        tool_calls = [mock_tool_call]
+        result = model._parse_chat_tool_calls(tool_calls)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {"arg": "value"})
+
+    def test_parse_chat_tool_calls_with_dict(self, model):
+        """Test parsing chat completion tool calls from dict."""
+        tool_calls = [
+            {"id": "call_123", "function": {"name": "test_tool", "arguments": '{"arg": "value"}'}}
+        ]
+        result = model._parse_chat_tool_calls(tool_calls)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {"arg": "value"})
+
+    def test_parse_chat_tool_calls_invalid_json(self, model):
+        """Test parsing tool calls with invalid JSON arguments."""
+        mock_tool_call = MagicMock()
+        mock_tool_call.id = "call_123"
+        mock_tool_call.function.name = "test_tool"
+        mock_tool_call.function.arguments = "invalid json"
+
+        tool_calls = [mock_tool_call]
+        result = model._parse_chat_tool_calls(tool_calls)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {})
+
+    def test_parse_response_tool_calls(self, model):
+        """Test parsing response API tool calls."""
+        items = [
+            {
+                "type": "function_tool_call",
+                "call_id": "call_123",
+                "name": "test_tool",
+                "arguments": '{"arg": "value"}',
+                "status": "pending",
+            }
+        ]
+        result = model._parse_response_tool_calls(items)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {"arg": "value"})
+
+    def test_parse_response_tool_calls_with_id(self, model):
+        """Test parsing response API tool calls using id instead of call_id."""
+        items = [
+            {
+                "type": "function_call",
+                "id": "call_123",
+                "name": "test_tool",
+                "arguments": '{"arg": "value"}',
+                "output": None,
+            }
+        ]
+        result = model._parse_response_tool_calls(items)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {"arg": "value"})
+
+    def test_parse_response_tool_calls_skips_completed(self, model):
+        """Test that completed tool calls are skipped."""
+        items = [
+            {
+                "type": "function_tool_call",
+                "call_id": "call_123",
+                "name": "test_tool",
+                "arguments": '{"arg": "value"}',
+                "status": "completed",
+                "output": "result",
+            }
+        ]
+        result = model._parse_response_tool_calls(items)
+        assert len(result) == 0
+
+    def test_parse_response_tool_calls_with_dict_arguments(self, model):
+        """Test parsing response tool calls with dict arguments instead of string."""
+        items = [
+            {
+                "type": "function_tool_call",
+                "call_id": "call_123",
+                "name": "test_tool",
+                "arguments": {"arg": "value"},  # Already a dict
+                "status": "pending",
+            }
+        ]
+        result = model._parse_response_tool_calls(items)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {"arg": "value"})
+
+    def test_parse_response_tool_calls_empty_string_arguments(self, model):
+        """Test parsing response tool calls with empty string arguments."""
+        items = [
+            {
+                "type": "function_tool_call",
+                "call_id": "call_123",
+                "name": "test_tool",
+                "arguments": "",
+                "status": "pending",
+            }
+        ]
+        result = model._parse_response_tool_calls(items)
+
+        assert len(result) == 1
+        assert result[0] == ("call_123", "test_tool", {})
+
+    # === Tool Execution Tests ===
+
+    def test_execute_chat_tools(self, model):
+        """Test executing chat completion tool calls."""
+        mock_pool = MagicMock()
+        mock_result = MagicMock()
+        mock_result.content = [MagicMock(text="Tool result")]
+        mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)]
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [MagicMock()]
+            tool_calls[0].id = "call_123"
+            tool_calls[0].function.name = "test_tool"
+            tool_calls[0].function.arguments = '{}'
+
+            connections = {}
+            messages = []
+            tool_to_server = {}
+
+            model._execute_chat_tools(tool_calls, connections, messages, tool_to_server)
+
+            assert len(messages) == 1
+            assert messages[0]["role"] == "tool"
+            assert messages[0]["tool_call_id"] == "call_123"
+            assert messages[0]["content"] == "Tool result"
+
+    def test_execute_chat_tools_with_error(self, model):
+        """Test executing chat tools with error."""
+        mock_pool = MagicMock()
+        mock_pool.call_tools_batch.return_value = [("call_123", None, "Tool error")]
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [MagicMock()]
+            tool_calls[0].id = "call_123"
+            tool_calls[0].function.name = "test_tool"
+            tool_calls[0].function.arguments = '{}'
+
+            connections = {}
+            messages = []
+            tool_to_server = {}
+
+            model._execute_chat_tools(tool_calls, connections, messages, tool_to_server)
+
+            assert len(messages) == 1
+            assert messages[0]["content"] == "Error: Tool error"
+
+    def test_execute_chat_tools_with_list_result(self, model):
+        """Test executing chat tools with list result format."""
+        mock_pool = MagicMock()
+
+        # Use a dict-like object that supports both .get() and .text access
+        class TextDict:
+            def __init__(self, text):
+                self._text = text
+
+            def get(self, key, default=None):
+                if key == 'text':
+                    return self._text
+                return default
+
+            @property
+            def text(self):
+                return self._text
+
+        mock_result = [TextDict("List result")]
+        mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)]
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [MagicMock()]
+            tool_calls[0].id = "call_123"
+            tool_calls[0].function.name = "test_tool"
+            tool_calls[0].function.arguments = '{}'
+
+            connections = {}
+            messages = []
+            tool_to_server = {}
+
+            model._execute_chat_tools(tool_calls, connections, messages, tool_to_server)
+
+            assert len(messages) == 1
+            assert messages[0]["content"] == "List result"
+
+    def test_execute_chat_tools_with_none_content(self, model):
+        """Test executing chat tools when result has no content."""
+        mock_pool = MagicMock()
+        mock_result = MagicMock()
+        mock_result.content = []
+        mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)]
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [MagicMock()]
+            tool_calls[0].id = "call_123"
+            tool_calls[0].function.name = "test_tool"
+            tool_calls[0].function.arguments = '{}'
+
+            connections = {}
+            messages = []
+            tool_to_server = {}
+
+            model._execute_chat_tools(tool_calls, connections, messages, tool_to_server)
+
+            assert len(messages) == 1
+            assert messages[0]["content"] is None
+
+    @pytest.mark.asyncio
+    async def test_execute_chat_tools_async(self, model):
+        """Test async execution of chat tools."""
+        mock_pool = MagicMock()
+        mock_result = MagicMock()
+        mock_result.content = [MagicMock(text="Async result")]
+        mock_pool.call_tools_batch_async = AsyncMock(
+            return_value=[("call_123", mock_result, None)]
+        )
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [MagicMock()]
+            tool_calls[0].id = "call_123"
+            tool_calls[0].function.name = "test_tool"
+            tool_calls[0].function.arguments = '{}'
+
+            connections = {}
+            messages = []
+            tool_to_server = {}
+
+            await model._execute_chat_tools_async(
+                tool_calls, connections, messages, tool_to_server
+            )
+
+            assert len(messages) == 1
+            assert messages[0]["content"] == "Async result"
+
+    def test_execute_response_tools(self, model):
+        """Test executing response API tool calls."""
+        mock_pool = MagicMock()
+        mock_result = MagicMock()
+        mock_result.content = [MagicMock(text="Response result")]
+        mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)]
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [("call_123", "test_tool", {})]
+            connections = {}
+            input_items = []
+            tool_to_server = {}
+
+            model._execute_response_tools(tool_calls, connections, input_items, tool_to_server)
+
+            assert len(input_items) == 1
+            assert input_items[0]["type"] == "function_call_output"
+            assert input_items[0]["call_id"] == "call_123"
+            assert input_items[0]["output"] == "Response result"
+
+    @pytest.mark.asyncio
+    async def test_execute_response_tools_async(self, model):
+        """Test async execution of response tools."""
+        mock_pool = MagicMock()
+        mock_result = MagicMock()
+        mock_result.content = [MagicMock(text="Async response result")]
+        mock_pool.call_tools_batch_async = AsyncMock(
+            return_value=[("call_123", mock_result, None)]
+        )
+
+        with patch.object(model, 'get_pool', return_value=mock_pool):
+            tool_calls = [("call_123", "test_tool", {})]
+            connections = {}
+            input_items = []
+            tool_to_server = {}
+
+            await model._execute_response_tools_async(
+                tool_calls, connections, input_items, tool_to_server
+            )
+
+            assert len(input_items) == 1
+            assert input_items[0]["output"] == "Async response result"
+
+    # === Response Output Processing Tests ===
+
+    def test_convert_output_to_input(self, model):
+        """Test converting response API output to input items."""
+        output = [
+            {"type": "message", "role": "assistant", "content": "Hello"},
+            {"type": "function_tool_call", "name": "tool1", "output": "result"},
+            {"type": "reasoning", "content": "Thinking..."},
+        ]
+        result = model._convert_output_to_input(output)
+
+        assert len(result) == 3
+        assert result[0]["type"] == "message"
+        assert result[1]["type"] == "function_tool_call"
+        assert result[2]["type"] == "reasoning"
+
+    def test_convert_output_to_input_filters_pending(self, model):
+        """Test that pending tool calls are filtered out."""
+        output = [
+            {"type": "function_tool_call", "name": "tool1", "status": "pending", "output": None}
+        ]
+        result = model._convert_output_to_input(output)
+        assert len(result) == 0
+
+    def test_convert_output_to_input_empty_list(self, model):
+        """Test converting empty output list."""
+        result = model._convert_output_to_input([])
+        assert result == []
+
+    # === Request Handler Tests ===
+
+    def test_handle_chat_completions_with_tools(self, model):
+        """Test handling chat completions with MCP tools."""
+        tools = [{"type": "function", "function": {"name": "test_tool"}}]
+        request_data = {"messages": [{"role": "user", "content": "Hello"}]}
+
+        with patch.object(
+            AgenticModelClass.__bases__[0], '_handle_chat_completions'
+        ) as mock_super_handle:
+            mock_response = MagicMock()
+            mock_super_handle.return_value = mock_response
+
+            result = model._handle_chat_completions(
+                request_data, mcp_servers=["http://server"], connections={}, tools=tools
+            )
+
+            # The method creates a new dict, so check what was passed to super
+            call_args = mock_super_handle.call_args[0][0]
+            assert "tools" in call_args
+            assert call_args["tool_choice"] == "auto"
+
+    def test_handle_chat_completions_without_tools(self, model):
+        """Test handling chat completions without MCP tools."""
+        request_data = {"messages": [{"role": "user", "content": "Hello"}]}
+
+        with patch.object(model, '_handle_chat_completions') as mock_handle:
+            mock_response = MagicMock()
+            mock_handle.return_value = mock_response
+            result = model._handle_chat_completions(request_data, None, None, None)
+
+            # Should not modify request_data when no tools
+            assert "tools" not in request_data
+
+    def test_handle_responses_with_tools(self, model):
+        """Test handling responses with MCP tools."""
+        tools = [{"type": "function", "function": {"name": "test_tool"}}]
+        request_data = {"input": "Hello"}
+
+        with patch.object(
+            AgenticModelClass.__bases__[0], '_handle_responses'
+        ) as mock_super_handle:
+            mock_response = MagicMock()
+            mock_super_handle.return_value = mock_response
+
+            result = model._handle_responses(
+                request_data, mcp_servers=["http://server"], connections={}, tools=tools
+            )
+
+            # The method creates a new dict, so check what was passed to super
+            call_args = mock_super_handle.call_args[0][0]
+            assert "tools" in call_args
+            assert call_args["tool_choice"] == "auto"
+
+    def test_handle_responses_without_tools(self, model):
+        """Test handling responses without MCP tools."""
+        request_data = {"input": "Hello"}
+
+        with patch.object(model, '_handle_responses') as mock_handle:
+            mock_response = MagicMock()
+            mock_handle.return_value = mock_response
+            result = model._handle_responses(request_data, None, None, None)
+
+            # Should not modify request_data when no tools
+            assert "tools" not in request_data
+
+    def test_route_request_chat_completions(self, model):
+        """Test routing to chat completions endpoint."""
+        request_data = {"messages": [{"role": "user", "content": "Hello"}]}
+        with patch.object(model, '_handle_chat_completions') as mock_handle:
+            mock_handle.return_value = MagicMock()
+            model._route_request(model.ENDPOINT_CHAT_COMPLETIONS, request_data, None, None, None)
+            mock_handle.assert_called_once()
+
+    def test_route_request_responses(self, model):
+        """Test routing to responses endpoint."""
+        request_data = {"input": "Hello"}
+        with patch.object(model, '_handle_responses') as mock_handle:
+            mock_handle.return_value = MagicMock()
+            model._route_request(model.ENDPOINT_RESPONSES, request_data, None, None, None)
+            mock_handle.assert_called_once()
+
+    # === Streaming Helper Tests ===
+
+    def test_accumulate_tool_delta(self, model):
+        """Test accumulating streaming tool call deltas."""
+        accumulated = {}
+        delta = MagicMock()
+        delta.index = 0
+        delta.id = "call_123"
+        delta.function.name = "test_tool"
+        delta.function.arguments = '{"arg": "value"}'
+
+        model._accumulate_tool_delta(delta, accumulated)
+
+        assert 0 in accumulated
+        assert accumulated[0]["id"] == "call_123"
+        assert accumulated[0]["function"]["name"] == "test_tool"
+        assert accumulated[0]["function"]["arguments"] == '{"arg": "value"}'
+
+    def test_accumulate_tool_delta_incremental(self, model):
+        """Test accumulating incremental tool call arguments."""
+        accumulated = {
+            0: {
+                "id": "call_123",
+                "type": "function",
+                "function": {"name": "test_tool", "arguments": ""},
+            }
+        }
+        delta = MagicMock()
+        delta.index = 0
+        delta.id = None
+        delta.function.name = None
+        delta.function.arguments = '{"arg": "value"}'
+
+        model._accumulate_tool_delta(delta, accumulated)
+
+        assert accumulated[0]["function"]["arguments"] == '{"arg": "value"}'
+
+    def test_finalize_tool_calls(self, model):
+        """Test finalizing accumulated tool calls."""
+        accumulated = {
+            0: {
+                "id": "call_1",
+                "type": "function",
+                "function": {"name": "tool1", "arguments": "{}"},
+            },
+            1: {
+                "id": "call_2",
+                "type": "function",
+                "function": {"name": "tool2", "arguments": "{}"},
+            },
+        }
+        result = model._finalize_tool_calls(accumulated)
+
+        assert len(result) == 2
+        assert result[0]["id"] == "call_1"
+        assert result[1]["id"] == "call_2"
+
+    def test_create_stream_request(self, model):
+        """Test creating streaming chat completion request."""
+        messages = [{"role": "user", "content": "Hello"}]
+        tools = [{"type": "function", "function": {"name": "test_tool"}}]
+
+        with patch.object(model.client.chat.completions, 'create') as mock_create:
+            mock_create.return_value = iter([])
+            result = model._create_stream_request(messages, tools, 100, 0.7, 0.9)
+
+            mock_create.assert_called_once()
+            call_kwargs = mock_create.call_args[1]
+            assert call_kwargs["stream"] is True
+            assert call_kwargs["tools"] == tools
+            assert call_kwargs["tool_choice"] == "auto"
+
+    def test_async_to_sync_generator(self, model):
+        """Test bridging async generator to sync generator."""
+
+        async def async_gen():
+            yield "chunk1"
+            yield "chunk2"
+            yield "chunk3"
+
+        pool = model.get_pool()
+        result = list(model._async_to_sync_generator(async_gen))
+
+        assert len(result) == 3
+        assert result == ["chunk1", "chunk2", "chunk3"]
+
+    def test_async_to_sync_generator_with_error(self, model):
+        """Test async to sync generator with error."""
+
+        async def async_gen():
+            yield "chunk1"
+            raise ValueError("Test error")
+
+        with pytest.raises(ValueError, match="Test error"):
+            list(model._async_to_sync_generator(async_gen))
+
+    # === Streaming with MCP Tests ===
+
+    @pytest.mark.asyncio
+    async def test_stream_chat_with_tools_no_tool_calls(self, model):
+        """Test streaming chat with tools when no tool calls are generated."""
+        messages = [{"role": "user", "content": "Hello"}]
+        tools = []
+        connections = {}
+        tool_to_server = {}
+
+        mock_chunk = MagicMock()
+        mock_chunk.model_dump_json.return_value = '{"id": "chunk1"}'
+        mock_chunk.choices = [MagicMock()]
+        mock_chunk.choices[0].delta = MagicMock()
+        mock_chunk.choices[0].delta.tool_calls = None
+        mock_chunk.choices[0].delta.content = "Hello"
+
+        with patch.object(model, '_create_stream_request', return_value=[mock_chunk]):
+            chunks = []
+            async for chunk in model._stream_chat_with_tools(
+                messages, tools, connections, tool_to_server, 100, 0.7, 0.9
+            ):
+                chunks.append(chunk)
+
+            assert len(chunks) == 1
"function", "function": {"name": "test_tool"}}] + connections = {} + tool_to_server = {} + + # Create mock tool call delta + mock_tool_call_delta = MagicMock() + mock_tool_call_delta.index = 0 + mock_tool_call_delta.id = "call_123" + mock_tool_call_delta.function.name = "test_tool" + mock_tool_call_delta.function.arguments = '{}' + + # First chunk with tool calls + mock_chunk_with_tools = MagicMock() + mock_chunk_with_tools.model_dump_json.return_value = '{"id": "chunk1"}' + mock_chunk_with_tools.choices = [MagicMock()] + mock_chunk_with_tools.choices[0].delta = MagicMock() + mock_chunk_with_tools.choices[0].delta.tool_calls = [mock_tool_call_delta] + mock_chunk_with_tools.choices[0].delta.content = None + + # Second chunk without tool calls (to prevent recursion) + mock_chunk_no_tools = MagicMock() + mock_chunk_no_tools.model_dump_json.return_value = '{"id": "chunk2"}' + mock_chunk_no_tools.choices = [MagicMock()] + mock_chunk_no_tools.choices[0].delta = MagicMock() + mock_chunk_no_tools.choices[0].delta.tool_calls = None + mock_chunk_no_tools.choices[0].delta.content = "Response" + + mock_pool = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool result")] + mock_pool.call_tools_batch_async = AsyncMock( + return_value=[("call_123", mock_result, None)] + ) + + # Create a generator that yields chunks + def chunk_generator(): + yield mock_chunk_with_tools + + # For recursive call, return chunk without tool calls + def chunk_generator_no_tools(): + yield mock_chunk_no_tools + + with ( + patch.object( + model, + '_create_stream_request', + side_effect=[chunk_generator(), chunk_generator_no_tools()], + ), + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object( + model, + '_finalize_tool_calls', + return_value=[ + { + "id": "call_123", + "type": "function", + "function": {"name": "test_tool", "arguments": "{}"}, + } + ], + ), + ): + chunks = [] + async for chunk in model._stream_chat_with_tools( + messages, tools, connections, tool_to_server, 100, 0.7, 0.9 + ): + chunks.append(chunk) + + # Should have initial chunk plus recursive call chunks + assert len(chunks) >= 1 + + @pytest.mark.asyncio + async def test_stream_responses_with_tools(self, model): + """Test streaming responses with MCP tools.""" + request_data = {"input": "Hello"} + tools = [{"type": "function", "function": {"name": "test_tool"}}] + connections = {} + tool_to_server = {} + + # Create a simple event class that will be yielded + # The code checks chunk_type and yields if it doesn't match certain patterns + # We need to avoid output_index being detected by hasattr + class MockEvent: + def __init__(self): + self.type = "response.created" + self.response = None + + def model_dump_json(self): + return '{"type": "response.created"}' + + # Set the class name for chunk_type detection + MockEvent.__name__ = "ResponseCreatedEvent" + mock_event = MockEvent() + + with patch.object(model.client.responses, 'create', return_value=[mock_event]): + chunks = [] + async for chunk in model._stream_responses_with_tools( + request_data, tools, connections, tool_to_server + ): + chunks.append(chunk) + + # Should yield at least the event as JSON + assert len(chunks) >= 1 + + # === Main OpenAI Methods Tests === + + def test_openai_transport_without_mcp(self, model): + """Test openai_transport without MCP servers.""" + request = { + "messages": [{"role": "user", "content": "Hello"}], + "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS, + } + + with patch.object(model, '_route_request') as 
mock_route:
+            mock_response = MagicMock()
+            mock_response.model_dump_json.return_value = '{"id": "test"}'
+            mock_route.return_value = mock_response
+
+            result = model.openai_transport(to_json(request))
+            assert json.loads(result)["id"] == "test"
+
+    def test_openai_transport_with_mcp_chat_completions(self, model):
+        """Test openai_transport with MCP servers for chat completions."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "mcp_servers": ["http://server"],
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        mock_tool = MagicMock()
+        mock_tool.name = "test_tool"
+        mock_tool.description = "Test tool"
+        mock_tool.inputSchema = {}
+
+        mock_conn = MCPConnection(
+            client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server"
+        )
+
+        mock_pool = MagicMock()
+        mock_pool.get_tools_and_mapping.return_value = (
+            [{"type": "function", "function": {"name": "test_tool"}}],
+            {"http://server": mock_conn},
+            {"test_tool": "http://server"},
+        )
+
+        mock_response = MagicMock()
+        mock_response.model_dump_json.return_value = '{"id": "test"}'
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].get.return_value = {}
+        mock_response.choices[0].message = MagicMock()
+        mock_response.choices[0].message.get.return_value = None
+
+        with (
+            patch.object(model, 'get_pool', return_value=mock_pool),
+            patch.object(model, '_route_request', return_value=mock_response),
+        ):
+            result = model.openai_transport(to_json(request))
+            assert json.loads(result)["id"] == "test"
+
+    def test_openai_transport_with_mcp_responses(self, model):
+        """Test openai_transport with MCP servers for responses API."""
+        request = {
+            "input": "Hello",
+            "mcp_servers": ["http://server"],
+            "openai_endpoint": model.ENDPOINT_RESPONSES,
+        }
+
+        mock_tool = MagicMock()
+        mock_tool.name = "test_tool"
+        mock_tool.description = "Test tool"
+        mock_tool.inputSchema = {}
+
+        mock_conn = MCPConnection(
+            client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server"
+        )
+
+        mock_pool = MagicMock()
+        mock_pool.get_tools_and_mapping.return_value = (
+            [{"type": "function", "function": {"name": "test_tool"}}],
+            {"http://server": mock_conn},
+            {"test_tool": "http://server"},
+        )
+
+        mock_response = MagicMock()
+        mock_response.model_dump_json.return_value = '{"id": "test"}'
+        mock_response.output = []
+
+        with (
+            patch.object(model, 'get_pool', return_value=mock_pool),
+            patch.object(model, '_route_request', return_value=mock_response),
+        ):
+            result = model.openai_transport(to_json(request))
+            assert json.loads(result)["id"] == "test"
+
+    def test_openai_transport_with_existing_tools(self, model):
+        """Test openai_transport when tools are already provided (should not use MCP)."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "mcp_servers": ["http://server"],
+            "tools": [{"type": "function", "function": {"name": "existing_tool"}}],
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        with (
+            patch.object(model, 'get_pool') as mock_get_pool,
+            patch.object(model, '_route_request') as mock_route,
+        ):
+            mock_response = MagicMock()
+            mock_response.model_dump_json.return_value = '{"id": "test"}'
+            mock_route.return_value = mock_response
+
+            result = model.openai_transport(to_json(request))
+            assert json.loads(result)["id"] == "test"
+            # Should not consult the pool when tools are already provided
+            mock_get_pool.assert_not_called()
+
+    def test_openai_transport_error_handling(self, model):
+        """Test error handling in openai_transport."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "openai_endpoint": 
model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        with patch.object(model, '_route_request', side_effect=Exception("Test error")):
+            result = model.openai_transport(to_json(request))
+            error = json.loads(result)
+            # Check that it's an error code (could be 22000 or 21313 depending on status_code_pb2)
+            assert "code" in error
+            assert error["code"] > 0
+            assert "Test error" in error["details"]
+
+    def test_openai_stream_transport_without_mcp(self, model):
+        """Test openai_stream_transport without MCP servers."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "stream": True,
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        mock_chunk = MagicMock()
+        mock_chunk.model_dump_json.return_value = '{"id": "chunk1"}'
+        # Ensure usage is None to avoid token finalization issues
+        mock_chunk.usage = None
+
+        with patch.object(model.client.chat.completions, 'create', return_value=[mock_chunk]):
+            chunks = list(model.openai_stream_transport(to_json(request)))
+            # Filter out error chunks if any
+            chunks = [c for c in chunks if not (isinstance(c, bytes) and b'"code"' in c)]
+            assert len(chunks) == 1
+
+    def test_openai_stream_transport_with_mcp(self, model):
+        """Test openai_stream_transport with MCP servers."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "mcp_servers": ["http://server"],
+            "stream": True,
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        mock_tool = MagicMock()
+        mock_tool.name = "test_tool"
+        mock_tool.description = "Test tool"
+        mock_tool.inputSchema = {}
+
+        mock_conn = MCPConnection(
+            client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server"
+        )
+
+        mock_pool = MagicMock()
+        mock_pool.get_tools_and_mapping.return_value = (
+            [{"type": "function", "function": {"name": "test_tool"}}],
+            {"http://server": mock_conn},
+            {"test_tool": "http://server"},
+        )
+        mock_pool._loop = asyncio.new_event_loop()
+
+        try:
+            with (
+                patch.object(model, 'get_pool', return_value=mock_pool),
+                patch.object(model, '_async_to_sync_generator') as mock_gen,
+            ):
+                mock_gen.return_value = iter(['{"id": "chunk1"}'])
+                chunks = list(model.openai_stream_transport(to_json(request)))
+                assert len(chunks) == 1
+        finally:
+            # Close the helper loop so the test does not leak an event loop
+            mock_pool._loop.close()
+
+    def test_openai_stream_transport_error_handling(self, model):
+        """Test error handling in openai_stream_transport."""
+        request = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "stream": True,
+            "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS,
+        }
+
+        # Patch the client method to raise an error
+        def raise_error(**kwargs):
+            raise Exception("Test error")
+
+        with patch.object(model.client.chat.completions, 'create', side_effect=raise_error):
+            chunks = list(model.openai_stream_transport(to_json(request)))
+            # Should have at least one error chunk
+            assert len(chunks) >= 1
+            # The error should be in JSON format
+            error_str = (
+                chunks[-1]
+                if isinstance(chunks[-1], str)
+                else chunks[-1].decode('utf-8')
+                if isinstance(chunks[-1], bytes)
+                else str(chunks[-1])
+            )
+            error = json.loads(error_str)
+            assert "code" in error
+            assert "Test error" in error.get("details", "")
+
+    # === Pool Management Tests ===
+
+    def test_get_pool_singleton(self, model):
+        """Test that get_pool returns a singleton."""
+        pool1 = model.get_pool()
+        pool2 = 
model.get_pool() + assert pool1 is pool2 + + def test_get_pool_thread_safe(self, model): + """Test that get_pool is thread-safe.""" + import threading + + pools = [] + + def get_pool(): + pools.append(model.get_pool()) + + threads = [threading.Thread(target=get_pool) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All should be the same instance + assert all(p is pools[0] for p in pools) + + # === Integration Tests === + + def test_full_flow_chat_completions_with_tools(self, model): + """Test full flow: chat completions with tool calls.""" + request = { + "messages": [{"role": "user", "content": "Use test_tool"}], + "mcp_servers": ["http://server"], + "openai_endpoint": model.ENDPOINT_CHAT_COMPLETIONS, + } + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + mock_tool.inputSchema = {} + + mock_conn = MCPConnection( + client=MagicMock(), tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + # First response has tool calls + mock_tool_call = MagicMock() + mock_tool_call.id = "call_123" + mock_tool_call.function.name = "test_tool" + mock_tool_call.function.arguments = '{}' + + mock_message = MagicMock() + mock_message.tool_calls = [mock_tool_call] + + # Make message behave like a dict for .get() calls + def message_get(key, default=None): + if key == "tool_calls": + return [mock_tool_call] + return default + + mock_message.get = message_get + + mock_response1 = MagicMock() + mock_response1.model_dump_json.return_value = '{"id": "test1"}' + mock_response1.choices = [MagicMock()] + mock_response1.choices[0].message = mock_message + + # Make choices[0] behave like a dict for .get() calls + def mock_get(key, default=None): + if key == "message": + return mock_message + return default + + mock_response1.choices[0].get = mock_get + + # Second response (after tool execution) has no tool calls + mock_message2 = MagicMock() + mock_message2.tool_calls = None + + def message_get2(key, default=None): + if key == "tool_calls": + return None + return default + + mock_message2.get = message_get2 + + mock_response2 = MagicMock() + mock_response2.model_dump_json.return_value = '{"id": "test2"}' + mock_response2.choices = [MagicMock()] + mock_response2.choices[0].message = mock_message2 + + def mock_get2(key, default=None): + if key == "message": + return mock_message2 + return default + + mock_response2.choices[0].get = mock_get2 + + mock_pool = MagicMock() + mock_pool.get_tools_and_mapping.return_value = ( + [{"type": "function", "function": {"name": "test_tool"}}], + {"http://server": mock_conn}, + {"test_tool": "http://server"}, + ) + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool executed")] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with ( + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object(model, '_route_request', side_effect=[mock_response1, mock_response2]), + ): + result = model.openai_transport(to_json(request)) + assert json.loads(result)["id"] == "test2" + + def test_full_flow_responses_with_tools(self, model): + """Test full flow: responses API with tool calls.""" + request = { + "input": "Use test_tool", + "mcp_servers": ["http://server"], + "openai_endpoint": model.ENDPOINT_RESPONSES, + } + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + mock_tool.inputSchema = {} + + mock_conn = MCPConnection( + client=MagicMock(), tools=[mock_tool], 
tool_names={"test_tool"}, url="http://server" + ) + + # First response has tool calls + mock_response1 = MagicMock() + mock_response1.model_dump_json.return_value = '{"id": "test1"}' + mock_response1.output = [ + { + "type": "function_tool_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": '{}', + "output": None, + } + ] + + # Second response (after tool execution) has no tool calls + mock_response2 = MagicMock() + mock_response2.model_dump_json.return_value = '{"id": "test2"}' + mock_response2.output = [] + + mock_pool = MagicMock() + mock_pool.get_tools_and_mapping.return_value = ( + [{"type": "function", "function": {"name": "test_tool"}}], + {"http://server": mock_conn}, + {"test_tool": "http://server"}, + ) + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Tool executed")] + mock_pool.call_tools_batch.return_value = [("call_123", mock_result, None)] + + with ( + patch.object(model, 'get_pool', return_value=mock_pool), + patch.object(model, '_route_request', side_effect=[mock_response1, mock_response2]), + ): + result = model.openai_transport(to_json(request)) + assert json.loads(result)["id"] == "test2" + + def test_to_response_api_tools_empty_list(self, model): + """Test converting empty tools list.""" + result = model._to_response_api_tools([]) + assert result == [] diff --git a/tests/runners/test_mcp_connection_pool.py b/tests/runners/test_mcp_connection_pool.py new file mode 100644 index 00000000..be682532 --- /dev/null +++ b/tests/runners/test_mcp_connection_pool.py @@ -0,0 +1,532 @@ +"""Test cases for MCPConnectionPool singleton and connection lifecycle management.""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from clarifai.runners.models.agentic_class import MCPConnection, MCPConnectionPool + +# Test constants for timing margins +IDLE_MARGIN_SECONDS = 10 # Margin to add beyond thresholds for robust idle tests + + +class TestMCPConnectionPool: + """Tests for MCPConnectionPool singleton and connection management.""" + + @pytest.fixture(autouse=True) + def reset_pool(self): + """Reset the singleton instance before each test.""" + # Clear singleton instance + MCPConnectionPool._instance = None + yield + # Clean up after test + MCPConnectionPool._instance = None + + def test_singleton_pattern(self): + """Test that MCPConnectionPool is a singleton.""" + pool1 = MCPConnectionPool() + pool2 = MCPConnectionPool() + + assert pool1 is pool2 + assert id(pool1) == id(pool2) + + def test_singleton_initialization_once(self): + """Test that singleton is initialized only once.""" + pool1 = MCPConnectionPool() + initial_connections = pool1._connections + + pool2 = MCPConnectionPool() + + # Should have the same connections dictionary + assert pool2._connections is initial_connections + + def test_event_loop_initialization(self): + """Test that background event loop is started on init.""" + pool = MCPConnectionPool() + + assert pool._loop is not None + assert pool._loop_thread is not None + assert pool._loop_thread.is_alive() + assert not pool._loop.is_closed() + + def test_connection_cleanup_idle_timeout(self): + """Test that connections idle > MAX_IDLE_TIME are removed.""" + pool = MCPConnectionPool() + + # Create a mock connection that's been idle for too long + mock_client = MagicMock() + mock_client.close = AsyncMock() + old_conn = MCPConnection( + client=mock_client, + tools=[], + tool_names=set(), + url="http://old-server", + last_used=time.time() - pool.MAX_IDLE_TIME - IDLE_MARGIN_SECONDS, + ) + + 
with pool._lock: + pool._connections["http://old-server"] = old_conn + + # Force cleanup to run immediately + pool._last_cleanup = 0 + pool._maybe_cleanup_idle() + + # Connection should be removed + with pool._lock: + assert "http://old-server" not in pool._connections + + def test_cleanup_interval_rate_limiting(self): + """Test that cleanup checks are rate limited.""" + pool = MCPConnectionPool() + + # Create an idle connection + mock_client = MagicMock() + old_conn = MCPConnection( + client=mock_client, + tools=[], + tool_names=set(), + url="http://server", + last_used=time.time() - pool.MAX_IDLE_TIME - IDLE_MARGIN_SECONDS, + ) + + with pool._lock: + pool._connections["http://server"] = old_conn + + # Set last cleanup to recent time + pool._last_cleanup = time.time() + + # Try to cleanup - should be skipped due to rate limiting + pool._maybe_cleanup_idle() + + # Connection should still exist (cleanup was skipped) + with pool._lock: + assert "http://server" in pool._connections + + def test_connection_verification_valid(self): + """Test connection verification for valid connections.""" + pool = MCPConnectionPool() + + # Create mock connection + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[]) + + conn = MCPConnection( + client=mock_client, tools=[], tool_names=set(), url="http://valid-server" + ) + + # Run verification + is_valid = pool._run_async(pool._verify_connection(conn)) + + assert is_valid is True + mock_client.list_tools.assert_called_once() + + def test_connection_verification_invalid(self): + """Test connection verification for invalid connections.""" + pool = MCPConnectionPool() + + # Create mock connection that fails verification + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(side_effect=Exception("Connection lost")) + + conn = MCPConnection( + client=mock_client, tools=[], tool_names=set(), url="http://invalid-server" + ) + + # Run verification + is_valid = pool._run_async(pool._verify_connection(conn)) + + assert is_valid is False + + def test_needs_verification(self): + """Test _needs_verification logic.""" + pool = MCPConnectionPool() + + # Fresh connection should not need verification + fresh_conn = MCPConnection( + client=MagicMock(), tools=[], tool_names=set(), url="http://fresh" + ) + assert pool._needs_verification(fresh_conn) is False + + # Old connection should need verification + old_conn = MCPConnection( + client=MagicMock(), + tools=[], + tool_names=set(), + url="http://old", + last_used=time.time() - pool.VERIFY_IDLE_THRESHOLD - IDLE_MARGIN_SECONDS, + ) + assert pool._needs_verification(old_conn) is True + + def test_parallel_connection_creation(self): + """Test that connections are created in parallel.""" + pool = MCPConnectionPool() + + # Create mock connections and add them directly + mock_client1 = MagicMock() + mock_client1.list_tools = AsyncMock(return_value=[]) + conn1 = MCPConnection( + client=mock_client1, tools=[], tool_names=set(), url="http://server1" + ) + + mock_client2 = MagicMock() + mock_client2.list_tools = AsyncMock(return_value=[]) + conn2 = MCPConnection( + client=mock_client2, tools=[], tool_names=set(), url="http://server2" + ) + + mock_client3 = MagicMock() + mock_client3.list_tools = AsyncMock(return_value=[]) + conn3 = MCPConnection( + client=mock_client3, tools=[], tool_names=set(), url="http://server3" + ) + + with pool._lock: + pool._connections["http://server1"] = conn1 + pool._connections["http://server2"] = conn2 + pool._connections["http://server3"] = conn3 + + # Get connections - 
should reuse existing ones + urls = ["http://server1", "http://server2", "http://server3"] + connections = pool.get_connections(urls) + + # All connections should be returned + assert len(connections) == 3 + for url in urls: + assert url in connections + + def test_connection_creation_error_handling(self): + """Test error handling when connection creation fails.""" + pool = MCPConnectionPool() + + # Try to create connection for a URL that's not already in pool + # This will fail because fastmcp is not actually installed + urls = ["http://bad-server"] + connections = pool.get_connections(urls) + + # Should handle error gracefully and return empty dict + assert len(connections) == 0 + + def test_parallel_connection_creation_partial_failure(self): + """Test parallel creation with some failures.""" + pool = MCPConnectionPool() + + # Add two valid connections + mock_client1 = MagicMock() + mock_client1.list_tools = AsyncMock(return_value=[]) + conn1 = MCPConnection( + client=mock_client1, tools=[], tool_names=set(), url="http://server1" + ) + + mock_client3 = MagicMock() + mock_client3.list_tools = AsyncMock(return_value=[]) + conn3 = MCPConnection( + client=mock_client3, tools=[], tool_names=set(), url="http://server3" + ) + + with pool._lock: + pool._connections["http://server1"] = conn1 + pool._connections["http://server3"] = conn3 + + # Try to get connections including one that doesn't exist (will fail to create) + urls = ["http://server1", "http://server2", "http://server3"] + connections = pool.get_connections(urls) + + # Should have 2 successful connections (1st and 3rd) + assert len(connections) == 2 + assert "http://server1" in connections + assert "http://server3" in connections + assert "http://server2" not in connections + + def test_connection_touch_mechanism(self): + """Test that connections are touched when accessed.""" + pool = MCPConnectionPool() + + # Create connection with old timestamp + mock_client = MagicMock() + old_time = time.time() - 100 + conn = MCPConnection( + client=mock_client, tools=[], tool_names=set(), url="http://server", last_used=old_time + ) + + with pool._lock: + pool._connections["http://server"] = conn + + # Access connection (should touch it) + connections = pool.get_connections(["http://server"]) + + # Last used should be updated + assert connections["http://server"].last_used > old_time + + def test_tool_cache_update(self): + """Test that tool cache is updated when connections are created.""" + pool = MCPConnectionPool() + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "Test tool 1" + mock_tool1.inputSchema = {"type": "object"} + + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "Test tool 2" + mock_tool2.inputSchema = {"type": "object"} + + conn = MCPConnection( + client=MagicMock(), + tools=[mock_tool1, mock_tool2], + tool_names={"tool1", "tool2"}, + url="http://server", + ) + + # Update cache + pool._update_tool_cache(conn) + + # Verify cache contents + assert "tool1" in pool._tool_to_url + assert "tool2" in pool._tool_to_url + assert pool._tool_to_url["tool1"] == "http://server" + assert pool._tool_to_url["tool2"] == "http://server" + assert "tool1" in pool._all_tools + assert "tool2" in pool._all_tools + + def test_tool_cache_invalidation_on_disconnect(self): + """Test that tool cache is invalidated when connection is removed.""" + pool = MCPConnectionPool() + + # Create connection with tools + mock_client = MagicMock() + mock_client.close = AsyncMock() + + 
mock_tool = MagicMock() + mock_tool.name = "test_tool" + + conn = MCPConnection( + client=mock_client, tools=[mock_tool], tool_names={"test_tool"}, url="http://server" + ) + + with pool._lock: + pool._connections["http://server"] = conn + pool._tool_to_url["test_tool"] = "http://server" + pool._all_tools["test_tool"] = {"type": "function"} + + # Disconnect + pool._disconnect("http://server") + + # Tool cache should be cleared + assert "test_tool" not in pool._tool_to_url + assert "test_tool" not in pool._all_tools + + def test_connection_reuse(self): + """Test that existing connections are reused.""" + pool = MCPConnectionPool() + + # Create initial connection + mock_client = MagicMock() + conn = MCPConnection(client=mock_client, tools=[], tool_names=set(), url="http://server") + + with pool._lock: + pool._connections["http://server"] = conn + + # Get connection again + connections = pool.get_connections(["http://server"]) + + # Should reuse same connection + assert connections["http://server"] is conn + + def test_stale_connection_recreation(self): + """Test that stale connections are verified and removed if invalid.""" + pool = MCPConnectionPool() + + # Create stale connection that will fail verification + mock_old_client = MagicMock() + mock_old_client.list_tools = AsyncMock(side_effect=Exception("Connection lost")) + + old_conn = MCPConnection( + client=mock_old_client, + tools=[], + tool_names=set(), + url="http://server", + last_used=time.time() - pool.VERIFY_IDLE_THRESHOLD - IDLE_MARGIN_SECONDS, + ) + + with pool._lock: + pool._connections["http://server"] = old_conn + + # Get connection (should verify and fail, then try to recreate but fail due to missing fastmcp) + connections = pool.get_connections(["http://server"]) + + # Should not have connection since verification failed and recreation failed + # (fastmcp is not installed so recreation will fail) + assert "http://server" not in connections + + # Original connection should have been removed from pool + with pool._lock: + assert "http://server" not in pool._connections + + def test_close_connection_with_close_method(self): + """Test closing connection with close() method.""" + pool = MCPConnectionPool() + + # Create connection with close method + mock_client = MagicMock() + mock_client.close = AsyncMock() + + conn = MCPConnection(client=mock_client, tools=[], tool_names=set(), url="http://server") + + # Close connection + pool._run_async(pool._close_connection(conn)) + + # close() should be called + mock_client.close.assert_called_once() + + def test_close_connection_with_aexit(self): + """Test closing connection with __aexit__ method.""" + pool = MCPConnectionPool() + + # Create connection without close method but with __aexit__ + # Use spec to explicitly define available methods (common client methods) + mock_client = MagicMock(spec=['__aexit__', '__aenter__', 'list_tools', 'call_tool']) + mock_client.__aexit__ = AsyncMock() + + conn = MCPConnection(client=mock_client, tools=[], tool_names=set(), url="http://server") + + # Close connection + pool._run_async(pool._close_connection(conn)) + + # __aexit__ should be called + mock_client.__aexit__.assert_called_once_with(None, None, None) + + def test_get_tools_and_mapping(self): + """Test getting tools and connection mapping.""" + pool = MCPConnectionPool() + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "Test tool 1" + mock_tool1.inputSchema = {"type": "object"} + + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + 
mock_tool2.description = "Test tool 2" + mock_tool2.inputSchema = {"type": "object"} + + # Create connections + conn1 = MCPConnection( + client=MagicMock(), tools=[mock_tool1], tool_names={"tool1"}, url="http://server1" + ) + + conn2 = MCPConnection( + client=MagicMock(), tools=[mock_tool2], tool_names={"tool2"}, url="http://server2" + ) + + with pool._lock: + pool._connections["http://server1"] = conn1 + pool._connections["http://server2"] = conn2 + + # Get tools and mapping + tools, connections, tool_to_server = pool.get_tools_and_mapping( + ["http://server1", "http://server2"] + ) + + # Verify results + assert len(tools) == 2 + assert len(connections) == 2 + assert "tool1" in tool_to_server + assert "tool2" in tool_to_server + assert tool_to_server["tool1"] == "http://server1" + assert tool_to_server["tool2"] == "http://server2" + + def test_call_tool_sync(self): + """Test synchronous tool calling.""" + pool = MCPConnectionPool() + + # Create mock connection + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.content = [MagicMock(text="Result")] + mock_client.call_tool = AsyncMock(return_value=mock_result) + + conn = MCPConnection( + client=mock_client, tools=[], tool_names={"test_tool"}, url="http://server" + ) + + connections = {"http://server": conn} + tool_to_server = {"test_tool": "http://server"} + + # Call tool + result = pool.call_tool("test_tool", {"arg": "value"}, connections, tool_to_server) + + # Verify call + assert result is mock_result + mock_client.call_tool.assert_called_once_with("test_tool", arguments={"arg": "value"}) + + def test_call_tools_batch(self): + """Test batch tool calling.""" + pool = MCPConnectionPool() + + # Create mock connection + mock_client = MagicMock() + mock_result1 = MagicMock() + mock_result1.content = [MagicMock(text="Result1")] + mock_result2 = MagicMock() + mock_result2.content = [MagicMock(text="Result2")] + + async def mock_call_tool(name, arguments): + if name == "tool1": + return mock_result1 + return mock_result2 + + mock_client.call_tool = mock_call_tool + + conn = MCPConnection( + client=mock_client, tools=[], tool_names={"tool1", "tool2"}, url="http://server" + ) + + connections = {"http://server": conn} + tool_to_server = {"tool1": "http://server", "tool2": "http://server"} + + # Call tools in batch + calls = [("id1", "tool1", {"arg": "value1"}), ("id2", "tool2", {"arg": "value2"})] + results = pool.call_tools_batch(calls, connections, tool_to_server) + + # Verify results + assert len(results) == 2 + assert results[0][0] == "id1" + assert results[1][0] == "id2" + + def test_tool_call_timeout(self): + """Test that tool calls timeout appropriately.""" + pool = MCPConnectionPool() + + # Temporarily reduce timeout for faster test execution + original_timeout = pool.TOOL_CALL_TIMEOUT + pool.TOOL_CALL_TIMEOUT = 2.0 # Use 2 second timeout for test + + try: + # Create mock connection with slow tool + mock_client = MagicMock() + + async def slow_tool(*args, **kwargs): + # Sleep just beyond the reduced timeout + await asyncio.sleep(3.0) + return MagicMock() + + mock_client.call_tool = slow_tool + + conn = MCPConnection( + client=mock_client, tools=[], tool_names={"slow_tool"}, url="http://server" + ) + + connections = {"http://server": conn} + tool_to_server = {"slow_tool": "http://server"} + + # Call should timeout + with pytest.raises(asyncio.TimeoutError): + pool.call_tool("slow_tool", {}, connections, tool_to_server) + finally: + # Restore original timeout + pool.TOOL_CALL_TIMEOUT = original_timeout
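+
+    # --- Editorial addition: a minimal sketch, not part of the original suite. ---
+    # The pool's _run_async docstring describes a double-checked-locking restart of a
+    # closed background loop; this sketch exercises that recovery path directly. It
+    # assumes only attributes shown above (_loop, _loop_thread, _run_async).
+    def test_run_async_restarts_closed_loop(self):
+        """Sketch: _run_async should recover after the background loop is closed."""
+        pool = MCPConnectionPool()
+        old_loop = pool._loop
+        old_thread = pool._loop_thread
+
+        # Stop the background loop, wait for its thread to exit, then close the
+        # loop so is_closed() becomes True.
+        old_loop.call_soon_threadsafe(old_loop.stop)
+        old_thread.join(timeout=5.0)
+        assert not old_thread.is_alive()
+        old_loop.close()
+
+        async def probe():
+            return "ok"
+
+        # _run_async should detect the closed loop and transparently start a new one.
+        assert pool._run_async(probe()) == "ok"
+        assert pool._loop is not old_loop
+        assert not pool._loop.is_closed()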
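+
+    # --- Editorial addition: a concurrency sketch, assuming get_connections() is safe
+    # to call from multiple threads (as the pool's docstring promises). Every worker
+    # should receive the same pre-seeded MCPConnection instance. ---
+    def test_get_connections_concurrent_reuse(self):
+        """Sketch: concurrent get_connections() calls reuse one connection."""
+        import threading
+
+        pool = MCPConnectionPool()
+        conn = MCPConnection(client=MagicMock(), tools=[], tool_names=set(), url="http://server")
+        with pool._lock:
+            pool._connections["http://server"] = conn
+
+        results = []
+
+        def worker():
+            results.append(pool.get_connections(["http://server"])["http://server"])
+
+        threads = [threading.Thread(target=worker) for _ in range(5)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        # All threads should get the identical pooled connection object
+        assert len(results) == 5
+        assert all(r is conn for r in results)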
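+
+    # --- Editorial addition complementing test_connection_cleanup_idle_timeout:
+    # cleanup should be selective, removing only connections idle beyond
+    # MAX_IDLE_TIME while leaving recently used ones in place. ---
+    def test_cleanup_preserves_active_connections(self):
+        """Sketch: idle cleanup removes stale connections but keeps fresh ones."""
+        pool = MCPConnectionPool()
+
+        stale_client = MagicMock()
+        stale_client.close = AsyncMock()
+        stale = MCPConnection(
+            client=stale_client,
+            tools=[],
+            tool_names=set(),
+            url="http://stale",
+            last_used=time.time() - pool.MAX_IDLE_TIME - IDLE_MARGIN_SECONDS,
+        )
+        fresh = MCPConnection(client=MagicMock(), tools=[], tool_names=set(), url="http://fresh")
+
+        with pool._lock:
+            pool._connections["http://stale"] = stale
+            pool._connections["http://fresh"] = fresh
+
+        # Bypass the CLEANUP_INTERVAL rate limit so cleanup runs now
+        pool._last_cleanup = 0
+        pool._maybe_cleanup_idle()
+
+        with pool._lock:
+            assert "http://stale" not in pool._connections
+            assert "http://fresh" in pool._connections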