# ---- clarifai/runners/dockerfile_template/Dockerfile.node.template ----
# syntax=docker/dockerfile:1.13-labs

FROM --platform=$TARGETPLATFORM python:${PYTHON_VERSION}-slim

# Install Node.js & npm (npx is included with npm)
# This is required to run the Node.js MCP server via npx
# NOTE(review): the tarball below is hardcoded to linux-x64 even though the base
# image honors $TARGETPLATFORM — confirm arm64 builds are not expected here.
RUN apt-get update && apt-get install -y curl xz-utils \
    && curl -fsSL https://nodejs.org/dist/v${NODE_VERSION}/node-v${NODE_VERSION}-linux-x64.tar.xz -o node.tar.xz \
    && tar -xJf node.tar.xz -C /usr/local --strip-components=1 \
    && rm node.tar.xz \
    && node -v && npm -v && npx -v \
    && rm -rf /var/lib/apt/lists/*

COPY --link requirements.txt /home/nonroot/requirements.txt

# Update clarifai package so we always have latest protocol to the API. Everything should land in /venv
RUN ["pip", "install", "--no-cache-dir", "-r", "/home/nonroot/requirements.txt"]
RUN ["pip", "show", "--no-cache-dir", "clarifai"]

# Set the NUMBA cache dir to /tmp
# Set the TORCHINDUCTOR cache dir to /tmp
# The CLARIFAI* will be set by the templating system.
ENV NUMBA_CACHE_DIR=/tmp/numba_cache \
    TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_cache \
    HOME=/tmp \
    DEBIAN_FRONTEND=noninteractive

#####
# Copy the files needed to download
#####
# This creates the directory that HF downloader will populate and with nonroot:nonroot permissions up.
# COPY --chown=nonroot:nonroot downloader/unused.yaml /home/nonroot/main/1/checkpoints/.cache/unused.yaml

#####
# Download checkpoints if config.yaml has checkpoints.when = "build"
COPY --link=true config.yaml /home/nonroot/main/
# RUN ["python", "-m", "clarifai.cli", "model", "download-checkpoints", "/home/nonroot/main", "--out_path", "/home/nonroot/main/1/checkpoints", "--stage", "build"]

#####
# Copy in the actual files like config.yaml, requirements.txt, and most importantly 1/model.py
# for the actual model.
# If checkpoints aren't downloaded since a checkpoints: block is not provided, then they will
# be in the build context and copied here as well.
COPY --link=true 1 /home/nonroot/main/1

# At this point we only need these for validation in the SDK.
COPY --link=true requirements.txt config.yaml /home/nonroot/main/

# Add the model directory to the python path.
ENV PYTHONPATH=${PYTHONPATH}:/home/nonroot/main \
    CLARIFAI_PAT=${CLARIFAI_PAT} \
    CLARIFAI_USER_ID=${CLARIFAI_USER_ID} \
    CLARIFAI_RUNNER_ID=${CLARIFAI_RUNNER_ID} \
    CLARIFAI_NODEPOOL_ID=${CLARIFAI_NODEPOOL_ID} \
    CLARIFAI_COMPUTE_CLUSTER_ID=${CLARIFAI_COMPUTE_CLUSTER_ID} \
    CLARIFAI_API_BASE=${CLARIFAI_API_BASE:-https://api.clarifai.com}

WORKDIR /home/nonroot/main

# Finally run the clarifai entrypoint to start the runner loop and local runner server.
# Note(zeiler): we may want to make this a clarifai CLI call.
ENTRYPOINT ["python", "-m", "clarifai.runners.server"]
CMD ["--model_path", "/home/nonroot/main"]
#############################
# ---- clarifai/runners/models/mcp_class.py (reconstructed from patch) ----
"""MCP model base class with a *single* long-lived FastMCP client.

The implementation creates a **background thread + event-loop** that owns
the FastMCP client. All incoming MCP calls are forwarded to that loop via
``asyncio.run_coroutine_threadsafe`` and the result is returned synchronously.
"""

import asyncio
import json
import logging
import threading
from typing import TYPE_CHECKING, Any, Optional

from clarifai.runners.models.model_class import ModelClass

if TYPE_CHECKING:  # pragma: no cover
    from fastmcp import Client, FastMCP
    from mcp.client.session import ClientSession

logger = logging.getLogger(__name__)


class MCPModelClass(ModelClass):
    """Base class for wrapping a FastMCP server as a Clarifai model.

    Sub-classes must implement :meth:`get_server` and return a ready-to-use
    ``FastMCP`` instance.
    """

    def __init__(self):
        super().__init__()
        # FastMCP server instance returned by get_server() (created in the background loop).
        self._fastmcp_server: Optional["FastMCP"] = None

        # FastMCP client that talks to the server (created inside the background loop).
        self._client: Optional["Client"] = None
        self._client_session: Optional["ClientSession"] = None

        # Background thread / loop handling.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._initialized = threading.Event()
        self._init_error: Optional[Exception] = None

    def load_model(self):
        """Called by Clarifai to initialize the model. Starts the background loop."""
        self._start_background_loop()

        # Wait for initialization to complete (with timeout).
        if not self._initialized.wait(timeout=60):
            raise RuntimeError("Background MCP initialization timed out")

        if self._init_error is not None:
            raise self._init_error

    def get_server(self) -> "FastMCP":
        """Required method for each subclass to implement to return the FastMCP server to use."""
        raise NotImplementedError("Subclasses must implement get_server() method")

    def _start_background_loop(self) -> None:
        """Spin up a daemon thread that runs its own asyncio event-loop."""

        def runner():
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            try:
                # Run the initialisation coroutine.
                self._loop.run_until_complete(self._background_initialise())
            except Exception as e:
                # Surface the failure to load_model() via _init_error.
                self._init_error = e
                self._initialized.set()
                return

            self._initialized.set()

            # Keep the loop alive until we are asked to stop.
            self._loop.run_forever()

            # Clean-up when loop stops.
            self._loop.run_until_complete(self._background_shutdown())
            self._loop.close()
            logger.debug("Background MCP thread stopped")

        self._thread = threading.Thread(target=runner, name="MCP-background-loop", daemon=True)
        self._thread.start()
        logger.debug("Background MCP thread started")

    async def _background_initialise(self) -> None:
        """Create the FastMCP server and client.

        All objects are bound to the *same* event-loop (the background loop).
        """
        try:
            from fastmcp import Client
        except ImportError:
            raise ImportError(
                "fastmcp package is required to use MCP functionality. "
                "Install it with: pip install fastmcp"
            )

        # Create FastMCP server (this triggers lifespan which may do tool discovery).
        self._fastmcp_server = self.get_server()

        # Create FastMCP client (in-memory transport provided by fastmcp v2).
        self._client = Client(self._fastmcp_server)
        await self._client.__aenter__()
        self._client_session = self._client.session

        logger.debug("Background MCP initialisation complete")

    async def _background_shutdown(self) -> None:
        """Clean up resources when shutting down."""
        # Close FastMCP client.
        if self._client is not None:
            try:
                await self._client.__aexit__(None, None, None)
            except Exception:
                logger.exception("Error while closing FastMCP client")
            self._client = None
            self._client_session = None

        logger.debug("Background MCP shutdown complete")

    def _run_in_background(self, coro) -> Any:
        """Schedule *coro* on the background loop and block until it finishes."""
        if self._loop is None or not self._loop.is_running():
            raise RuntimeError("Background event loop not running")
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout=300)  # 5 minute timeout

    async def _bg_send_request(self, payload: dict) -> Any:
        """Runs inside the background loop. Forwards the request to FastMCP client."""
        from mcp import types
        from mcp.shared.exceptions import McpError

        # Fix: fail with a clear error instead of AttributeError on None when
        # initialization did not complete.
        if self._client_session is None:
            raise RuntimeError("MCP client session is not initialized")

        msg_id = payload.get("id", "")

        # Strip the jsonrpc and id fields as send_request sets them again.
        raw_msg = types.ClientRequest.model_validate(payload)
        clean_dict = raw_msg.model_dump(
            by_alias=True,
            mode="json",
            exclude_none=True,
            exclude={"jsonrpc", "id"},
        )
        client_message = types.ClientRequest.model_validate(clean_dict)

        # Determine the expected result type.
        result_type = self._get_result_type(client_message)

        if result_type is None:
            # Special case: InitializeRequest.
            if isinstance(client_message.root, types.InitializeRequest):
                return await self._client_session.initialize()
            raise NotImplementedError(
                f"Method {getattr(client_message, 'method', 'unknown')} not implemented"
            )

        try:
            return await self._client_session.send_request(client_message, result_type)
        except McpError as e:
            return types.JSONRPCError(jsonrpc="2.0", id=msg_id, error=e.error)

    def _get_result_type(self, client_message):
        """Map request type to result type (None => special handling / unsupported)."""
        from mcp import types

        type_map = {
            types.PingRequest: types.EmptyResult,
            types.InitializeRequest: None,  # Special handling
            types.SetLevelRequest: types.EmptyResult,
            types.ListResourcesRequest: types.ListResourcesResult,
            types.ListResourceTemplatesRequest: types.ListResourceTemplatesResult,
            types.ReadResourceRequest: types.ReadResourceResult,
            types.SubscribeRequest: types.EmptyResult,
            types.UnsubscribeRequest: types.EmptyResult,
            types.ListPromptsRequest: types.ListPromptsResult,
            types.GetPromptRequest: types.GetPromptResult,
            types.CompleteRequest: types.CompleteResult,
            types.ListToolsRequest: types.ListToolsResult,
            types.CallToolRequest: types.CallToolResult,
        }

        for req_type, res_type in type_map.items():
            if isinstance(client_message.root, req_type):
                return res_type
        return None

    async def _bg_send_notification(self, payload: dict) -> None:
        """Runs inside the background loop. Forwards notification to FastMCP client."""
        from mcp import types
        from mcp.shared.exceptions import McpError

        if self._client_session is None:
            raise RuntimeError("MCP client session is not initialized")

        # Strip the jsonrpc field since send_notification passes it in again.
        raw_msg = types.ClientNotification.model_validate(payload)
        clean_dict = raw_msg.model_dump(
            by_alias=True,
            mode="json",
            exclude_none=True,
            exclude={"jsonrpc"},
        )
        client_message = types.ClientNotification.model_validate(clean_dict)

        try:
            await self._client_session.send_notification(client_message)
        except McpError:
            logger.exception("Error while sending notification to FastMCP")

    @ModelClass.method
    def mcp_transport(self, msg: str) -> str:
        """Synchronous entry point used by Clarifai.

        Receives a serialized JSON-RPC message, forwards it to the FastMCP
        server on the background loop and returns the serialized response.
        """
        from mcp import types

        payload = json.loads(msg)

        if not payload.get("method", "").startswith("notifications/"):
            # Normal request - we need a response.
            result = self._run_in_background(self._bg_send_request(payload))

            if result is None:
                result = types.JSONRPCError(
                    jsonrpc="2.0",
                    id=payload.get("id", ""),
                    error=types.ErrorData(
                        code=types.INTERNAL_ERROR,
                        message="Empty response from MCP server.",
                    ),
                )
            return result.model_dump_json(by_alias=True, exclude_none=True)
        else:
            # Notification - fire-and-forget.
            self._run_in_background(self._bg_send_notification(payload))
            return "{}"

    def shutdown(self) -> None:
        """Stop the background thread and close everything."""
        if self._loop is None:
            return

        self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread is not None:
            self._thread.join(timeout=10)
        self._loop = None
        self._thread = None
        logger.info("MCP bridge shut down")

    def __del__(self):
        # Best-effort cleanup; never raise from a finalizer.
        try:
            self.shutdown()
        except Exception:
            pass
""" - dockerfile_template = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - 'dockerfile_template', - 'Dockerfile.template', - ) + # Get the Python version from the config file + build_info = self.config.get('build_info', {}) + + # Check if node_version is specified - if so, use the Node.js Dockerfile template + node_version = build_info.get('node_version', '') or '' + use_node_template = bool(node_version and str(node_version).strip()) + + if use_node_template: + dockerfile_template_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + 'dockerfile_template', + 'Dockerfile.node.template', + ) + logger.info( + f"Setup: Node version {node_version} specified in config.yaml, using Node.js Dockerfile template" + ) + else: + dockerfile_template_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + 'dockerfile_template', + 'Dockerfile.template', + ) - with open(dockerfile_template, 'r') as template_file: + with open(dockerfile_template_path, 'r') as template_file: dockerfile_template = template_file.read() dockerfile_template = Template(dockerfile_template) - # Get the Python version from the config file - build_info = self.config.get('build_info', {}) if 'python_version' in build_info: python_version = build_info['python_version'] if python_version not in AVAILABLE_PYTHON_IMAGES: @@ -1122,6 +1137,43 @@ def _generate_dockerfile_content(self): # Parse the requirements.txt file to determine the base image dependencies = self._parse_requirements() + # If using Node.js template, use simpler substitution + if use_node_template: + if 'clarifai' not in dependencies: + raise Exception( + f"clarifai not found in requirements.txt, please add clarifai to the requirements.txt file with a fixed version. 
Current version is clarifai=={CLARIFAI_LATEST_VERSION}" + ) + clarifai_version = dependencies['clarifai'] + if not clarifai_version: + logger.warn( + f"clarifai version not found in requirements.txt, using the latest version {CLARIFAI_LATEST_VERSION}" + ) + clarifai_version = CLARIFAI_LATEST_VERSION + lines = [] + with open(os.path.join(self.folder, 'requirements.txt'), 'r') as file: + for line in file: + # if the line without whitespace is "clarifai" + dependency, version = self._match_req_line(line) + if dependency and dependency == "clarifai": + lines.append( + line.replace("clarifai", f"clarifai=={CLARIFAI_LATEST_VERSION}") + ) + else: + lines.append(line) + with open(os.path.join(self.folder, 'requirements.txt'), 'w') as file: + file.writelines(lines) + logger.warn( + f"Updated requirements.txt to have clarifai=={CLARIFAI_LATEST_VERSION}" + ) + + # Replace placeholders with actual values for Node.js template + dockerfile_content = dockerfile_template.safe_substitute( + PYTHON_VERSION=python_version, + NODE_VERSION=str(node_version).strip(), + ) + return dockerfile_content + + # Standard template logic (multi-stage build) is_amd_gpu = self._is_amd() if is_amd_gpu: final_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version) diff --git a/clarifai/runners/models/stdio_mcp_class.py b/clarifai/runners/models/stdio_mcp_class.py new file mode 100644 index 00000000..d2857e07 --- /dev/null +++ b/clarifai/runners/models/stdio_mcp_class.py @@ -0,0 +1,391 @@ +"""Bridge a stdio MCP server to a Python FastMCP server. + +The implementation keeps a **single long‑lived session** for the whole +FastMCP server lifetime: + +* The stdio process is started once (the first time it is needed). +* The same `ClientSession` object is reused for every subsequent call. +* The stdio process is shut down cleanly when the FastMCP server's lifespan + context exits. 
+""" + +import asyncio +import inspect +import os +import traceback +from contextlib import asynccontextmanager +from typing import Any, Optional + +import yaml + +from clarifai.runners.models.mcp_class import MCPModelClass +from clarifai.utils.logging import logger + +try: + from fastmcp import FastMCP + from mcp import ClientSession, StdioServerParameters + from mcp.client.stdio import stdio_client + from mcp.types import TextContent, Tool +except ImportError: + FastMCP = None + ClientSession = None + StdioServerParameters = None + stdio_client = None + Tool = None + TextContent = None + + +class StdioMCPClient: + """A thin wrapper around a stdio MCP server that re‑uses a single session.""" + + def __init__( + self, + command: str = "npx", + args: list[str] | None = None, + env: dict[str, str] | None = None, + ): + self.command = command + self.args = args or [] + if not self.args: + raise ValueError("args must be provided") + self.env = env or {} + + self._stdio_ctx = None + self._session_ctx = None + self._session: Optional[ClientSession] = None + self._started = False + self._lock = asyncio.Lock() + + async def _ensure_started(self) -> None: + """Start the stdio process and MCP session if not already running.""" + if self._started: + return + + async with self._lock: + # Double-check after acquiring lock + if self._started: + return + + server_params = StdioServerParameters( + command=self.command, + args=self.args, + env=self.env, + ) + self._stdio_ctx = stdio_client(server_params) + stdio_transport = await self._stdio_ctx.__aenter__() + read_stream, write_stream = stdio_transport + + self._session_ctx = ClientSession(read_stream, write_stream) + self._session = await self._session_ctx.__aenter__() + await self._session.initialize() + + self._started = True + logger.debug("StdioMCPClient: stdio process + MCP session started") + + async def close(self) -> None: + """Gracefully shut down the stdio process and MCP session.""" + async with self._lock: + if not 
self._started: + return + + if self._session_ctx is not None: + try: + await self._session_ctx.__aexit__(None, None, None) + except Exception: + logger.exception("Error while closing MCP session") + + if self._stdio_ctx is not None: + try: + await self._stdio_ctx.__aexit__(None, None, None) + except Exception: + logger.exception("Error while closing stdio transport") + + self._started = False + self._session = None + self._session_ctx = None + self._stdio_ctx = None + logger.debug("StdioMCPClient: stdio process + MCP session stopped") + + async def list_tools(self) -> list[Tool]: + """List all tools from the stdio MCP server.""" + await self._ensure_started() + result = await self._session.list_tools() + return result.tools + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> str: + """Call a tool on the stdio MCP server.""" + await self._ensure_started() + result = await self._session.call_tool(name, arguments) + + output_parts = [] + for content in result.content: + if isinstance(content, TextContent): + output_parts.append(content.text) + elif hasattr(content, "text"): + output_parts.append(content.text) + else: + output_parts.append(str(content)) + + output = "\n".join(output_parts) + + if result.isError: + raise RuntimeError(f"Tool error: {output}") + + return output + + +class StdioMCPModelClass(MCPModelClass): + """Base class for bridging stdio MCP servers with Python FastMCP servers. + + This class automatically discovers and registers all tools from a stdio MCP server + into a Python FastMCP server, making them available to MCP clients. 
+ + Configuration is read from config.yaml in the 'mcp' section: + + Example config.yaml: + mcp: + command: "npx" + args: ["-y", "@modelcontextprotocol/server-github"] + env: + GITHUB_PERSONAL_ACCESS_TOKEN: "your-token-here" + + Subclasses should simply inherit from this class: + + class MCPModel(StdioMCPModelClass): + pass + """ + + def __init__(self): + super().__init__() + self._stdio_client: Optional[StdioMCPClient] = None + self._server: Optional[FastMCP] = None + # Flag to indicate whether tools have been registered with the FastMCP server. + # Prevents duplicate registration. Reset to False on shutdown to allow re-registration if restarted. + self._tools_registered = False + + def _json_type_to_python(self, json_type: str) -> type: + return { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + }.get(json_type, str) + + def _create_tool_function( + self, + tool_name: str, + properties: dict, + required: list, + stdio_client: StdioMCPClient, + ) -> callable: + """Create a FastMCP‑compatible coroutine function dynamically.""" + func_name = tool_name.replace("-", "_").replace(".", "_") + + required_params = [p for p in properties if p in required] + optional_params = [f"{p}=None" for p in properties if p not in required] + params_str = ", ".join(required_params + optional_params) + + body = [ + " try:", + " args = {}", + ] + for param in properties: + body.append(f" if {param} is not None:") + body.append(f" args['{param}'] = {param}") + body.append(f" return await stdio_client.call_tool('{tool_name}', args)") + body.extend( + [ + " except Exception as e:", + " import traceback", + " error_type = type(e).__name__", + " error_msg = str(e) if str(e) else repr(e)", + " tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))", + f" return f'Error executing {tool_name} ({{error_type}}): {{error_msg}}\\n\\nTraceback:\\n{{tb}}'", + ] + ) + + code = f"async def {func_name}({params_str}) -> str:\n" + 
"\n".join(body) + + namespace = {"stdio_client": stdio_client} + exec(code, namespace) + func = namespace[func_name] + + annotations = {"return": str} + for param, schema in properties.items(): + py_type = self._json_type_to_python(schema.get("type", "string")) + annotations[param] = py_type if param in required else Optional[py_type] + func.__annotations__ = annotations + + return func + + def _find_config_file(self) -> Optional[str]: + """Find config.yaml file in the same directory as the class file.""" + # Get the file path of the actual class (subclass) being instantiated + try: + class_file = inspect.getfile(self.__class__) + class_dir = os.path.dirname(os.path.abspath(class_file)) + + config_path = os.path.join(os.path.dirname(class_dir), "config.yaml") + if os.path.exists(config_path): + return config_path + + except (OSError, TypeError) as e: + logger.warning(f"Could not determine class file location: {e}") + + return None + + def _load_secrets(self) -> list[dict[str, Any]]: + config_path = self._find_config_file() + if not config_path: + raise FileNotFoundError("config.yaml not found") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + return [ + { + "id": s.get("id"), + "value": s.get("value"), + "env_var": s.get("env_var"), + } + for s in config.get("secrets", []) + ] + + def _load_mcp_config(self) -> dict[str, Any]: + config_path = self._find_config_file() + if not config_path: + raise FileNotFoundError("config.yaml not found") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + mcp_cfg = config.get("mcp_server") + if not mcp_cfg: + raise ValueError("Missing 'mcp_server' section in config.yaml") + + command = mcp_cfg.get("command") + args = mcp_cfg.get("args") + env = mcp_cfg.get("env", {}) + + if not command: + raise ValueError("'command' missing in mcp_server") + if not args: + raise ValueError("'args' missing in mcp_server") + + return { + "command": command, + "args": args if isinstance(args, list) else [args], + 
"env": env if isinstance(env, dict) else {}, + } + + def _get_stdio_client(self) -> StdioMCPClient: + """Get or create the stdio MCP client.""" + if self._stdio_client is None: + cfg = self._load_mcp_config() + env = dict(cfg["env"]) + + for secret in self._load_secrets(): + env_var = secret.get("env_var") + if not env_var: + continue + if secret.get("value") is not None: + env[env_var] = secret["value"] + elif os.getenv(env_var) is not None: + env[env_var] = os.getenv(env_var) + + self._stdio_client = StdioMCPClient( + command=cfg["command"], + args=cfg["args"], + env=env, + ) + return self._stdio_client + + def get_server(self) -> FastMCP: + """Return the FastMCP server instance.""" + if self._server is not None: + return self._server + + if FastMCP is None: + raise ImportError("fastmcp package is required – install with `pip install fastmcp`") + + @asynccontextmanager + async def lifespan(server: FastMCP): + """Discover stdio tools and keep the session alive.""" + if self._tools_registered: + yield + return + + logger.info("🚀 Starting stdio MCP bridge...") + stdio_client = self._get_stdio_client() + + try: + tools = await stdio_client.list_tools() + logger.info(f"✅ Discovered {len(tools)} tools from stdio MCP") + + for tool in tools: + name = tool.name + desc = tool.description or f"Tool: {name}" + schema = tool.inputSchema or {} + props = schema.get("properties", {}) + required = schema.get("required", []) + + func = self._create_tool_function(name, props, required, stdio_client) + func.__doc__ = desc + server.add_tool(func, name=name, description=desc) + + # Preserve original JSON‑schema + if hasattr(server, "_tool_manager") and hasattr( + server._tool_manager, "_tools" + ): + reg = server._tool_manager._tools.get(name) + if reg and schema: + reg.parameters = schema + + logger.debug(f" ✅ Registered {name}") + + self._tools_registered = True + logger.info("✅ Bridge server ready") + + except Exception as exc: + logger.error(f"❌ Error during bridge startup: 
{exc}") + traceback.print_exc() + raise + + try: + yield + finally: + logger.info("🛑 Shutting down stdio MCP bridge...") + try: + await stdio_client.close() + except Exception: + logger.exception("Error while closing StdioMCPClient") + self._stdio_client = None + logger.info("🛑 Bridge shutdown complete") + + self._server = FastMCP( + "stdio-mcp-bridge", + instructions="Bridge to a stdio MCP server. All tools are automatically available.", + lifespan=lifespan, + ) + return self._server + + async def _background_shutdown(self) -> None: + """Override to also close the stdio client.""" + # Close stdio client first + if self._stdio_client is not None: + try: + await self._stdio_client.close() + except Exception: + logger.exception("Error while closing StdioMCPClient") + self._stdio_client = None + + # Then call parent shutdown + await super()._background_shutdown() + + def shutdown(self): + """ + Cleanly shut down the server and reset the tools registration flag. + Call this when the FastMCP server is shutting down. 
+ """ + self._tools_registered = False + super().shutdown() diff --git a/tests/runners/test_stdio_mcp_class.py b/tests/runners/test_stdio_mcp_class.py new file mode 100644 index 00000000..edafa22c --- /dev/null +++ b/tests/runners/test_stdio_mcp_class.py @@ -0,0 +1,637 @@ +"""Comprehensive tests for stdio_mcp_class.py""" + +import os +import tempfile +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import yaml + +from clarifai.runners.models.stdio_mcp_class import StdioMCPClient, StdioMCPModelClass + + +@pytest.fixture +def temp_config_file(): + """Create a temporary config.yaml file for testing.""" + config_data = { + "mcp_server": { + "command": "uvx", + "args": ["mcp-server-calculator"], + "env": {}, + }, + "secrets": [], + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + yield f.name + os.unlink(f.name) + + +@pytest.fixture +def temp_config_file_with_secrets(): + """Create a temporary config.yaml file with secrets.""" + config_data = { + "mcp_server": { + "command": "uvx", + "args": ["mcp-server-calculator"], + "env": {}, + }, + "secrets": [ + {"id": "secret1", "value": "secret_value", "env_var": "SECRET_ENV_VAR"}, + {"id": "secret2", "value": None, "env_var": "SECRET_ENV_VAR_2"}, + ], + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + yield f.name + os.unlink(f.name) + + +@pytest.fixture +def mock_mcp_imports(): + """Mock MCP-related imports.""" + + # Create a mock TextContent class that works with isinstance + class MockTextContent: + pass + + with ( + patch("clarifai.runners.models.stdio_mcp_class.FastMCP") as mock_fastmcp, + patch("clarifai.runners.models.stdio_mcp_class.ClientSession") as mock_session, + patch("clarifai.runners.models.stdio_mcp_class.StdioServerParameters") as mock_params, + patch("clarifai.runners.models.stdio_mcp_class.stdio_client") as mock_stdio_client, + 
patch("clarifai.runners.models.stdio_mcp_class.Tool") as mock_tool, + patch( + "clarifai.runners.models.stdio_mcp_class.TextContent", MockTextContent + ) as mock_text_content, + ): + yield { + "FastMCP": mock_fastmcp, + "ClientSession": mock_session, + "StdioServerParameters": mock_params, + "stdio_client": mock_stdio_client, + "Tool": mock_tool, + "TextContent": MockTextContent, + } + + +class TestStdioMCPClient: + """Test cases for StdioMCPClient class.""" + + def test_init_with_args(self): + """Test StdioMCPClient initialization with args.""" + client = StdioMCPClient(command="uvx", args=["mcp-server-calculator"], env={}) + assert client.command == "uvx" + assert client.args == ["mcp-server-calculator"] + assert client.env == {} + assert client._started is False + assert client._session is None + + def test_init_without_args_raises_error(self): + """Test that initialization without args raises ValueError.""" + with pytest.raises(ValueError, match="args must be provided"): + StdioMCPClient(command="uvx", args=None) + + def test_init_with_empty_args_raises_error(self): + """Test that initialization with empty args raises ValueError.""" + with pytest.raises(ValueError, match="args must be provided"): + StdioMCPClient(command="uvx", args=[]) + + @pytest.mark.asyncio + async def test_ensure_started_creates_session(self, mock_mcp_imports): + """Test that _ensure_started creates a session.""" + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + + mock_session_ctx = AsyncMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + + mock_mcp_imports["stdio_client"].return_value = mock_stdio_ctx + mock_mcp_imports["ClientSession"].return_value = mock_session_ctx + + client = StdioMCPClient(command="uvx", args=["test"], env={}) + await client._ensure_started() + + assert client._started is True 
+ assert client._session is not None + mock_session.initialize.assert_called_once() + + @pytest.mark.asyncio + async def test_ensure_started_idempotent(self, mock_mcp_imports): + """Test that _ensure_started is idempotent.""" + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + + mock_session_ctx = AsyncMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + + mock_mcp_imports["stdio_client"].return_value = mock_stdio_ctx + mock_mcp_imports["ClientSession"].return_value = mock_session_ctx + + client = StdioMCPClient(command="uvx", args=["test"], env={}) + await client._ensure_started() + await client._ensure_started() # Call again + + # Should only initialize once + assert mock_session.initialize.call_count == 1 + + @pytest.mark.asyncio + async def test_close_cleans_up_resources(self, mock_mcp_imports): + """Test that close properly cleans up resources.""" + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + mock_stdio_ctx.__aexit__ = AsyncMock() + + mock_session_ctx = AsyncMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_ctx.__aexit__ = AsyncMock() + + mock_mcp_imports["stdio_client"].return_value = mock_stdio_ctx + mock_mcp_imports["ClientSession"].return_value = mock_session_ctx + + client = StdioMCPClient(command="uvx", args=["test"], env={}) + await client._ensure_started() + await client.close() + + assert client._started is False + assert client._session is None + mock_session_ctx.__aexit__.assert_called_once() + mock_stdio_ctx.__aexit__.assert_called_once() + + @pytest.mark.asyncio + async def test_close_when_not_started(self): + """Test that close does nothing when not 
started.""" + client = StdioMCPClient(command="uvx", args=["test"], env={}) + await client.close() # Should not raise + + @pytest.mark.asyncio + async def test_list_tools(self, mock_mcp_imports): + """Test listing tools from the stdio MCP server.""" + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "A test tool" + + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + + mock_session_ctx = AsyncMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[mock_tool])) + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + + mock_mcp_imports["stdio_client"].return_value = mock_stdio_ctx + mock_mcp_imports["ClientSession"].return_value = mock_session_ctx + + client = StdioMCPClient(command="uvx", args=["test"], env={}) + tools = await client.list_tools() + + assert len(tools) == 1 + assert tools[0].name == "test_tool" + + @pytest.mark.asyncio + async def test_call_tool_success(self, mock_mcp_imports): + """Test calling a tool successfully.""" + # Create a mock that is an instance of TextContent + TextContent = mock_mcp_imports["TextContent"] + mock_text_content = MagicMock(spec=TextContent) + mock_text_content.text = "result text" + + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + + mock_session_ctx = AsyncMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_result = MagicMock() + mock_result.content = [mock_text_content] + mock_result.isError = False + mock_session.call_tool = AsyncMock(return_value=mock_result) + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + + mock_mcp_imports["stdio_client"].return_value = mock_stdio_ctx + mock_mcp_imports["ClientSession"].return_value = mock_session_ctx + + 
client = StdioMCPClient(command="uvx", args=["test"], env={}) + result = await client.call_tool("test_tool", {"arg1": "value1"}) + + assert result == "result text" + mock_session.call_tool.assert_called_once_with("test_tool", {"arg1": "value1"}) + + @pytest.mark.asyncio + async def test_call_tool_with_error(self, mock_mcp_imports): + """Test calling a tool that returns an error.""" + # Create a mock that is an instance of TextContent + TextContent = mock_mcp_imports["TextContent"] + mock_text_content = MagicMock(spec=TextContent) + mock_text_content.text = "error message" + + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + + mock_session_ctx = AsyncMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_result = MagicMock() + mock_result.content = [mock_text_content] + mock_result.isError = True + mock_session.call_tool = AsyncMock(return_value=mock_result) + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + + mock_mcp_imports["stdio_client"].return_value = mock_stdio_ctx + mock_mcp_imports["ClientSession"].return_value = mock_session_ctx + + client = StdioMCPClient(command="uvx", args=["test"], env={}) + with pytest.raises(RuntimeError, match="Tool error: error message"): + await client.call_tool("test_tool", {"arg1": "value1"}) + + @pytest.mark.asyncio + async def test_call_tool_with_multiple_content_parts(self, mock_mcp_imports): + """Test calling a tool with multiple content parts.""" + # Create mocks that are instances of TextContent + TextContent = mock_mcp_imports["TextContent"] + mock_text_content1 = MagicMock(spec=TextContent) + mock_text_content1.text = "part1" + mock_text_content2 = MagicMock(spec=TextContent) + mock_text_content2.text = "part2" + + mock_stdio_ctx = AsyncMock() + mock_transport = (AsyncMock(), AsyncMock()) + mock_stdio_ctx.__aenter__ = AsyncMock(return_value=mock_transport) + + mock_session_ctx = 
class TestStdioMCPModelClass:
    """Test cases for StdioMCPModelClass."""

    @staticmethod
    def _write_config(config_data):
        """Dump *config_data* to a temporary YAML file and return its path.

        The caller is responsible for unlinking the file (see the
        ``try/finally`` blocks below). Replaces the NamedTemporaryFile
        boilerplate that was duplicated in each _load_mcp_config test.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            yaml.dump(config_data, f)
        return f.name

    def test_init(self):
        """Test StdioMCPModelClass initialization."""
        model = StdioMCPModelClass()
        assert model._stdio_client is None
        assert model._server is None
        assert model._tools_registered is False

    def test_json_type_to_python(self):
        """Test _json_type_to_python method."""
        model = StdioMCPModelClass()
        assert model._json_type_to_python("string") is str
        assert model._json_type_to_python("integer") is int
        assert model._json_type_to_python("number") is float
        assert model._json_type_to_python("boolean") is bool
        assert model._json_type_to_python("array") is list
        assert model._json_type_to_python("object") is dict
        assert model._json_type_to_python("unknown") is str  # Default

    @pytest.mark.asyncio
    async def test_create_tool_function(self):
        """Test _create_tool_function creates a valid function."""
        mock_stdio_client = AsyncMock()
        mock_stdio_client.call_tool = AsyncMock(return_value="result")

        model = StdioMCPModelClass()
        properties = {
            "a": {"type": "string"},
            "b": {"type": "integer"},
        }
        required = ["a"]

        func = model._create_tool_function("test-tool", properties, required, mock_stdio_client)

        # Hyphens in the tool name must be sanitized in the function name.
        assert callable(func)
        assert func.__name__ == "test_tool"
        assert "a" in func.__annotations__
        assert "b" in func.__annotations__

        # Calling the generated function forwards to the stdio client.
        result = await func("value_a", 42)
        assert result == "result"
        mock_stdio_client.call_tool.assert_called_once_with("test-tool", {"a": "value_a", "b": 42})

    @pytest.mark.asyncio
    async def test_create_tool_function_with_optional_params(self):
        """Test _create_tool_function with optional parameters."""
        mock_stdio_client = AsyncMock()
        mock_stdio_client.call_tool = AsyncMock(return_value="result")

        model = StdioMCPModelClass()
        properties = {
            "a": {"type": "string"},
            "b": {"type": "integer"},
        }
        required = []  # All optional

        func = model._create_tool_function("test-tool", properties, required, mock_stdio_client)

        # Omitted optional parameters are not forwarded.
        result = await func("value_a")
        assert result == "result"
        mock_stdio_client.call_tool.assert_called_once_with("test-tool", {"a": "value_a"})

    @pytest.mark.asyncio
    async def test_create_tool_function_with_error(self):
        """Test _create_tool_function error handling."""
        mock_stdio_client = AsyncMock()
        mock_stdio_client.call_tool = AsyncMock(side_effect=RuntimeError("Test error"))

        model = StdioMCPModelClass()
        properties = {"a": {"type": "string"}}
        required = ["a"]

        func = model._create_tool_function("test-tool", properties, required, mock_stdio_client)

        # Tool failures are reported in the returned text, not raised.
        result = await func("value_a")
        assert "Error executing test-tool" in result
        assert "RuntimeError" in result
        assert "Test error" in result

    def test_find_config_file(self, temp_config_file):
        """Test _find_config_file finds the config file."""

        class TestModel(StdioMCPModelClass):
            pass

        # The lookup expects: a model at /path/to/1/model.py implies a config
        # at /path/to/config.yaml, so fabricate a model path one level below
        # the directory holding the temp config.
        config_dir = os.path.dirname(temp_config_file)
        model_file = os.path.join(config_dir, "1", "model.py")

        with (
            patch("clarifai.runners.models.stdio_mcp_class.inspect.getfile") as mock_getfile,
            patch("clarifai.runners.models.stdio_mcp_class.os.path.join") as mock_join,
            patch("clarifai.runners.models.stdio_mcp_class.os.path.exists") as mock_exists,
        ):
            mock_getfile.return_value = model_file

            # Joining the parent dir with "config.yaml" must yield the temp file.
            def join_side_effect(*args):
                if len(args) == 2 and args[1] == "config.yaml":
                    return temp_config_file
                return os.path.join(*args)

            mock_join.side_effect = join_side_effect
            # Only the temp config path "exists".
            mock_exists.side_effect = lambda path: path == temp_config_file

            model = TestModel()
            config_path = model._find_config_file()

        assert config_path is not None
        assert config_path == temp_config_file

    def test_find_config_file_not_found(self):
        """Test _find_config_file when config file doesn't exist."""

        class TestModel(StdioMCPModelClass):
            pass

        with (
            patch("clarifai.runners.models.stdio_mcp_class.inspect.getfile") as mock_getfile,
            patch("clarifai.runners.models.stdio_mcp_class.os.path.exists") as mock_exists,
        ):
            mock_getfile.return_value = "/some/path/model.py"
            mock_exists.return_value = False

            model = TestModel()
            assert model._find_config_file() is None

    def test_load_secrets(self, temp_config_file_with_secrets):
        """Test _load_secrets loads secrets from config."""
        model = StdioMCPModelClass()
        # Point the model straight at the temp config. The original test also
        # patched getfile/dirname/exists/abspath, but those patches were dead
        # weight (the dirname lambda computed the same value on both branches)
        # because _find_config_file is overridden and the file really exists.
        model._find_config_file = lambda: temp_config_file_with_secrets

        secrets = model._load_secrets()

        assert len(secrets) == 2
        assert secrets[0]["id"] == "secret1"
        assert secrets[0]["value"] == "secret_value"
        assert secrets[0]["env_var"] == "SECRET_ENV_VAR"

    def test_load_secrets_file_not_found(self):
        """Test _load_secrets when config file doesn't exist."""
        model = StdioMCPModelClass()
        model._find_config_file = lambda: None

        with pytest.raises(FileNotFoundError, match="config.yaml not found"):
            model._load_secrets()

    def test_load_mcp_config(self, temp_config_file):
        """Test _load_mcp_config loads MCP config."""
        model = StdioMCPModelClass()
        model._find_config_file = lambda: temp_config_file

        config = model._load_mcp_config()

        assert config["command"] == "uvx"
        assert config["args"] == ["mcp-server-calculator"]
        assert config["env"] == {}

    def test_load_mcp_config_missing_section(self):
        """Test _load_mcp_config when mcp_server section is missing."""
        temp_file = self._write_config({"secrets": []})
        try:
            model = StdioMCPModelClass()
            model._find_config_file = lambda: temp_file
            with pytest.raises(ValueError, match="Missing 'mcp_server' section"):
                model._load_mcp_config()
        finally:
            os.unlink(temp_file)

    def test_load_mcp_config_missing_command(self):
        """Test _load_mcp_config when command is missing."""
        temp_file = self._write_config({"mcp_server": {"args": ["test"]}, "secrets": []})
        try:
            model = StdioMCPModelClass()
            model._find_config_file = lambda: temp_file
            with pytest.raises(ValueError, match="'command' missing"):
                model._load_mcp_config()
        finally:
            os.unlink(temp_file)

    def test_load_mcp_config_missing_args(self):
        """Test _load_mcp_config when args is missing."""
        temp_file = self._write_config({"mcp_server": {"command": "uvx"}, "secrets": []})
        try:
            model = StdioMCPModelClass()
            model._find_config_file = lambda: temp_file
            with pytest.raises(ValueError, match="'args' missing"):
                model._load_mcp_config()
        finally:
            os.unlink(temp_file)

    def test_load_mcp_config_with_string_args(self):
        """Test _load_mcp_config when args is a string (should be converted to list)."""
        temp_file = self._write_config(
            {"mcp_server": {"command": "uvx", "args": "single-arg"}, "secrets": []}
        )
        try:
            model = StdioMCPModelClass()
            model._find_config_file = lambda: temp_file
            assert model._load_mcp_config()["args"] == ["single-arg"]
        finally:
            os.unlink(temp_file)

    def test_get_stdio_client(self, temp_config_file):
        """Test _get_stdio_client creates and returns a client."""
        model = StdioMCPModelClass()
        model._find_config_file = lambda: temp_config_file
        model._load_secrets = lambda: []

        client = model._get_stdio_client()

        assert client is not None
        assert isinstance(client, StdioMCPClient)
        assert client.command == "uvx"
        assert client.args == ["mcp-server-calculator"]

    def test_get_stdio_client_with_secrets(self, temp_config_file_with_secrets):
        """Test _get_stdio_client includes secrets in environment."""
        model = StdioMCPModelClass()
        model._find_config_file = lambda: temp_config_file_with_secrets

        # secret2 has no inline value, so it must be resolved from os.environ.
        with patch.dict(os.environ, {"SECRET_ENV_VAR_2": "env_value"}):
            client = model._get_stdio_client()

        assert client is not None
        assert client.env.get("SECRET_ENV_VAR") == "secret_value"
        assert client.env.get("SECRET_ENV_VAR_2") == "env_value"

    def test_get_stdio_client_cached(self, temp_config_file):
        """Test _get_stdio_client returns cached client."""
        model = StdioMCPModelClass()
        model._find_config_file = lambda: temp_config_file
        model._load_secrets = lambda: []

        # Repeated calls must hand back the same instance.
        assert model._get_stdio_client() is model._get_stdio_client()

    def test_get_server_creates_fastmcp_server(self, mock_mcp_imports, temp_config_file):
        """Test get_server creates a FastMCP server."""
        fastmcp_instance = MagicMock()
        mock_mcp_imports["FastMCP"].return_value = fastmcp_instance

        model = StdioMCPModelClass()
        model._find_config_file = lambda: temp_config_file

        server = model.get_server()

        assert server is fastmcp_instance
        mock_mcp_imports["FastMCP"].assert_called_once()

    def test_get_server_import_error(self):
        """Test get_server raises ImportError when fastmcp is not available."""
        with patch("clarifai.runners.models.stdio_mcp_class.FastMCP", None):
            model = StdioMCPModelClass()
            with pytest.raises(ImportError, match="fastmcp package is required"):
                model.get_server()

    @pytest.mark.asyncio
    async def test_background_shutdown(self):
        """Test _background_shutdown closes stdio client."""
        stdio_client_mock = AsyncMock()
        stdio_client_mock.close = AsyncMock()

        model = StdioMCPModelClass()
        model._stdio_client = stdio_client_mock

        # Patch the parent implementation so no real shutdown work happens.
        with patch.object(
            StdioMCPModelClass.__bases__[0], "_background_shutdown", new_callable=AsyncMock
        ) as parent_shutdown:
            await model._background_shutdown()

            stdio_client_mock.close.assert_called_once()
            assert model._stdio_client is None
            # Verify parent shutdown is also called
            parent_shutdown.assert_called_once()

    def test_shutdown_resets_tools_registered(self):
        """Test shutdown resets tools_registered flag."""
        model = StdioMCPModelClass()
        model._tools_registered = True

        # Patch the parent implementation so no real shutdown work happens.
        with patch.object(StdioMCPModelClass.__bases__[0], "shutdown") as parent_shutdown:
            model.shutdown()

        assert model._tools_registered is False
        # Verify parent shutdown is also called
        parent_shutdown.assert_called_once()