diff --git a/README.md b/README.md index 3bb228d..8bc5ccc 100644 --- a/README.md +++ b/README.md @@ -487,9 +487,9 @@ The following environment variables are used to configure the ClickHouse and chD * `CLICKHOUSE_CONNECT_TIMEOUT`: Connection timeout in seconds * Default: `"30"` * Increase this value if you experience connection timeouts -* `CLICKHOUSE_SEND_RECEIVE_TIMEOUT`: Send/receive timeout in seconds - * Default: `"300"` - * Increase this value for long-running queries +* `CLICKHOUSE_SEND_RECEIVE_TIMEOUT`: Send/receive timeout in seconds for the underlying HTTP connection + * Default: `"300"`, automatically lowered to `CLICKHOUSE_MCP_QUERY_TIMEOUT + 5` when that is smaller, so worker threads unblock shortly after a query timeout + * If explicitly set, the value is used as-is (e.g. `"300"` for long-running queries) * `CLICKHOUSE_DATABASE`: Default database to use * Default: None (uses server default) * Set this to automatically connect to a specific database @@ -503,9 +503,14 @@ The following environment variables are used to configure the ClickHouse and chD * `CLICKHOUSE_MCP_BIND_PORT`: Port to bind the MCP server to when using HTTP or SSE transport * Default: `"8000"` * Only used when transport is `"http"` or `"sse"` -* `CLICKHOUSE_MCP_QUERY_TIMEOUT`: Timeout in seconds for SELECT tools +* `CLICKHOUSE_MCP_QUERY_TIMEOUT`: Timeout in seconds for query tool calls * Default: `"30"` * Increase this if you see `Query timed out after ...` errors for heavy queries + * When a query times out, the server issues a `KILL QUERY` on the ClickHouse server to cancel it + * Unless `CLICKHOUSE_SEND_RECEIVE_TIMEOUT` is explicitly set, the HTTP read timeout is automatically capped at this value plus 5 seconds (it is only lowered, never raised), so worker threads unblock shortly after a timeout +* `CLICKHOUSE_MCP_MAX_WORKERS`: Maximum number of concurrent query worker threads + * Default: `"10"` + * Increase if your workload requires many concurrent tool calls * `CLICKHOUSE_MCP_AUTH_TOKEN`: Authentication token for HTTP/SSE transports * 
Default: None * **Required** when using HTTP or SSE transport (unless `CLICKHOUSE_MCP_AUTH_DISABLED=true`) diff --git a/mcp_clickhouse/__init__.py b/mcp_clickhouse/__init__.py index 549195c..e9d29b7 100644 --- a/mcp_clickhouse/__init__.py +++ b/mcp_clickhouse/__init__.py @@ -2,6 +2,8 @@ from .mcp_server import ( create_clickhouse_client, + _clear_client_cache, + _resolve_client_config, list_databases, list_tables, run_query, @@ -27,6 +29,8 @@ "list_tables", "run_query", "create_clickhouse_client", + "_clear_client_cache", + "_resolve_client_config", "create_chdb_client", "run_chdb_select_query", "chdb_initial_prompt", diff --git a/mcp_clickhouse/mcp_env.py b/mcp_clickhouse/mcp_env.py index 49d5e9b..cec2fa1 100644 --- a/mcp_clickhouse/mcp_env.py +++ b/mcp_clickhouse/mcp_env.py @@ -291,6 +291,7 @@ class MCPServerConfig: CLICKHOUSE_MCP_BIND_HOST: Bind host for HTTP/SSE (default: 127.0.0.1) CLICKHOUSE_MCP_BIND_PORT: Bind port for HTTP/SSE (default: 8000) CLICKHOUSE_MCP_QUERY_TIMEOUT: SELECT tool timeout in seconds (default: 30) + CLICKHOUSE_MCP_MAX_WORKERS: Maximum thread pool workers for query execution (default: 10) CLICKHOUSE_MCP_AUTH_TOKEN: Authentication token for HTTP/SSE transports (required unless CLICKHOUSE_MCP_AUTH_DISABLED=true) CLICKHOUSE_MCP_AUTH_DISABLED: Disable authentication (default: false, use @@ -317,6 +318,14 @@ def bind_port(self) -> int: def query_timeout(self) -> int: return int(os.getenv("CLICKHOUSE_MCP_QUERY_TIMEOUT", "30")) + @property + def max_workers(self) -> int: + """Maximum thread pool workers for query execution. 
+ + Default: 10 + """ + return int(os.getenv("CLICKHOUSE_MCP_MAX_WORKERS", "10")) + @property def auth_token(self) -> Optional[str]: """Get the authentication token for HTTP/SSE transports.""" diff --git a/mcp_clickhouse/mcp_server.py b/mcp_clickhouse/mcp_server.py index e4cf184..c9f7b3d 100644 --- a/mcp_clickhouse/mcp_server.py +++ b/mcp_clickhouse/mcp_server.py @@ -1,10 +1,12 @@ import logging import json -from typing import Optional, List, Any, Dict +from typing import Optional, List, Any, Dict, Tuple import concurrent.futures import atexit import os import re +import threading +import time import uuid import clickhouse_connect @@ -66,10 +68,24 @@ class Table: ) logger = logging.getLogger(MCP_SERVER_NAME) -QUERY_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10) +load_dotenv() + +_max_workers = get_mcp_config().max_workers +QUERY_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=_max_workers) atexit.register(lambda: QUERY_EXECUTOR.shutdown(wait=True)) -load_dotenv() +# --- Client cache --- +# Cache of ClickHouse clients keyed by frozen config, enabling client reuse +# across tool calls. Each entry is (client, last_used_timestamp). +_client_cache: Dict[Tuple, Tuple] = {} +_client_cache_lock = threading.Lock() +_CLIENT_IDLE_PING_THRESHOLD = 60 # seconds before we ping to verify liveness + +# --- Active query tracker --- +# Maps query_id -> (cache_key, query_text) so we can KILL QUERY on the +# correct server when a timeout fires. 
+_active_queries: Dict[str, Tuple] = {} +_active_queries_lock = threading.Lock() # Configure authentication for HTTP/SSE transports auth_provider = None @@ -161,8 +177,19 @@ def to_json(obj: Any) -> str: def list_databases(): """List available ClickHouse databases""" logger.info("Listing all databases") - client = create_clickhouse_client() - result = client.command("SHOW DATABASES") + config = _resolve_client_config() + + for attempt in range(2): + try: + client = create_clickhouse_client(config=config) + result = client.command("SHOW DATABASES") + break + except Exception as err: + if attempt == 0 and _is_connection_error(err): + logger.warning("list_databases connection error, retrying: %s", err) + _evict_cached_client(config) + continue + raise # Convert newline-separated string to list and trim whitespace if isinstance(result, str): @@ -337,8 +364,33 @@ def list_tables( page_size, include_detailed_columns, ) - client = create_clickhouse_client() + config = _resolve_client_config() + + for attempt in range(2): + try: + client = create_clickhouse_client(config=config) + return _list_tables_impl( + client, database, like, not_like, page_token, + page_size, include_detailed_columns, + ) + except Exception as err: + if attempt == 0 and _is_connection_error(err): + logger.warning("list_tables connection error, retrying: %s", err) + _evict_cached_client(config) + continue + raise + +def _list_tables_impl( + client, + database: str, + like: Optional[str], + not_like: Optional[str], + page_token: Optional[str], + page_size: int, + include_detailed_columns: bool, +) -> Dict[str, Any]: + """Inner implementation of list_tables, separated for retry logic.""" if page_token and page_token in table_pagination_cache: cached_state = table_pagination_cache[page_token] cached_include_detailed = cached_state.get("include_detailed_columns", True) @@ -449,20 +501,82 @@ def _validate_query_for_destructive_ops(query: str) -> None: ) -def execute_query(query: str): - client = 
create_clickhouse_client() +def _is_connection_error(err: Exception) -> bool: + """Check if an exception indicates a broken connection rather than a query error.""" + from clickhouse_connect.driver.exceptions import OperationalError + if isinstance(err, (OSError, ConnectionError, OperationalError)): + return True + err_str = str(err).lower() + return any(s in err_str for s in ("connection", "timed out", "reset by peer", "eof")) + + +def execute_query(query: str, query_id: str, client_config: dict): + """Execute a query in a worker thread. + + Args: + query: SQL to execute. + query_id: Unique identifier for server-side tracking / cancellation. + client_config: Pre-resolved config dict (resolved on the request thread). + """ + cache_key = _config_to_cache_key(client_config) + with _active_queries_lock: + _active_queries[query_id] = (cache_key, query) + try: + client = create_clickhouse_client(config=client_config) _validate_query_for_destructive_ops(query) query_settings = build_query_settings(client) + query_settings["query_id"] = query_id res = client.query(query, settings=query_settings) - logger.info(f"Query returned {len(res.result_rows)} rows") + logger.info(f"Query {query_id} returned {len(res.result_rows)} rows") return {"columns": res.column_names, "rows": res.result_rows} except ToolError: raise except Exception as err: - logger.error(f"Error executing query: {err}") + # Evict the cached client on connection errors so the next call + # creates a fresh one. We do NOT retry here because the query may + # involve writes and retrying could duplicate side effects. + if _is_connection_error(err): + _evict_cached_client(client_config) + logger.error(f"Error executing query {query_id}: {err}") raise ToolError(f"Query execution failed: {str(err)}") + finally: + with _active_queries_lock: + _active_queries.pop(query_id, None) + + +def _cancel_query(query_id: str): + """Issue KILL QUERY on the ClickHouse server for a timed-out query. 
+ + Uses the same cached client (same server/credentials) that originated + the query. Failures are logged but never raised — cancellation errors + must not mask the original timeout. + """ + with _active_queries_lock: + entry = _active_queries.pop(query_id, None) + + if entry is None: + logger.debug("Query %s already completed, nothing to cancel", query_id) + return + + cache_key, _query_text = entry + try: + with _client_cache_lock: + cached = _client_cache.get(cache_key) + if cached is None: + logger.warning( + "No cached client for query %s cancel — server-side query may still run", + query_id, + ) + return + + client, _ = cached + logger.info("Cancelling query %s via KILL QUERY", query_id) + client.command(f"KILL QUERY WHERE query_id = '{query_id}'") + logger.info("Successfully cancelled query %s", query_id) + except Exception as e: + logger.warning("Failed to cancel query %s: %s", query_id, e) def run_query(query: str): @@ -472,8 +586,23 @@ def run_query(query: str): to allow DDL and DML statements when your ClickHouse server permits them. 
""" logger.info(f"Executing query: {query}") + + # Resolve config on the request thread where FastMCP Context is available + client_config = _resolve_client_config() + query_id = str(uuid.uuid4()) + try: - future = QUERY_EXECUTOR.submit(execute_query, query) + # Log pool utilization for observability using the pending work queue + pending = QUERY_EXECUTOR._work_queue.qsize() + with _active_queries_lock: + in_flight = len(_active_queries) + if in_flight + pending >= _max_workers: + logger.warning( + "Thread pool saturated: %d in-flight + %d queued vs %d workers", + in_flight, pending, _max_workers, + ) + + future = QUERY_EXECUTOR.submit(execute_query, query, query_id, client_config) try: timeout_secs = get_mcp_config().query_timeout result = future.result(timeout=timeout_secs) @@ -488,8 +617,10 @@ def run_query(query: str): } return result except concurrent.futures.TimeoutError: - logger.warning(f"Query timed out after {timeout_secs} seconds: {query}") - future.cancel() + logger.warning( + "Query %s timed out after %s seconds: %s", query_id, timeout_secs, query + ) + _cancel_query(query_id) raise ToolError(f"Query timed out after {timeout_secs} seconds") except ToolError: raise @@ -498,39 +629,182 @@ def run_query(query: str): raise RuntimeError(f"Unexpected error during query execution: {str(e)}") -def create_clickhouse_client(): +def _config_to_cache_key(config: dict) -> tuple: + """Convert a client config dict into a hashable cache key. + + Handles nested dicts (e.g. 'settings') by recursively sorting items. + """ + items = [] + for k, v in sorted(config.items()): + if isinstance(v, dict): + v = _config_to_cache_key(v) + items.append((k, v)) + return tuple(items) + + +def _resolve_client_config() -> dict: + """Build the merged client config on the request thread. + + Must be called from the request thread where FastMCP Context is available. + Merges base config with any per-session overrides, then aligns + send_receive_timeout with the MCP query timeout. 
+ """ client_config = get_config().get_client_config() + srt_explicitly_set = "CLICKHOUSE_SEND_RECEIVE_TIMEOUT" in os.environ try: ctx = get_context() session_config_overrides = ctx.get_state(CLIENT_CONFIG_OVERRIDES_KEY) if session_config_overrides and not isinstance(session_config_overrides, dict): - logger.warning(f"{CLIENT_CONFIG_OVERRIDES_KEY} must be a dict, got {type(session_config_overrides).__name__}. Ignoring.") + logger.warning( + f"{CLIENT_CONFIG_OVERRIDES_KEY} must be a dict, " + f"got {type(session_config_overrides).__name__}. Ignoring." + ) elif session_config_overrides: - logger.debug(f"Applying session-specific ClickHouse client config overrides: {list(session_config_overrides.keys())}") + logger.debug( + "Applying session-specific ClickHouse client config overrides: %s", + list(session_config_overrides.keys()), + ) + if "send_receive_timeout" in session_config_overrides: + srt_explicitly_set = True client_config.update(session_config_overrides) except RuntimeError: - # If we're outside a request context, just proceed with the default config + # Outside a request context — proceed with base config pass + # Align send_receive_timeout with MCP query timeout so worker threads + # unblock shortly after the MCP-level timeout fires, preventing zombie threads. + # Only auto-cap when neither env var nor session override explicitly set it. + if not srt_explicitly_set: + query_timeout = get_mcp_config().query_timeout + effective_srt = client_config.get("send_receive_timeout", 300) + if effective_srt > query_timeout + 5: + client_config["send_receive_timeout"] = query_timeout + 5 + + return client_config + + +def _evict_cached_client(config: dict) -> None: + """Evict a cached client for the given config, closing it. + + Call this when a query or command fails with a connection error so the + next call creates a fresh client instead of reusing the broken one. 
+ """ + cache_key = _config_to_cache_key(config) + with _client_cache_lock: + entry = _client_cache.pop(cache_key, None) + if entry is not None: + client, _ = entry + logger.info("Evicted stale cached client for %s", config.get("host", "?")) + try: + client.close() + except Exception: + pass + + +def create_clickhouse_client(config: Optional[dict] = None): + """Get or create a cached ClickHouse client for the given config. + + Args: + config: Pre-resolved client config dict. When None the config is + resolved from env + session overrides (requires request context). + Pass an explicit config when calling from a worker thread. + """ + if config is None: + config = _resolve_client_config() + + cache_key = _config_to_cache_key(config) + + # Check cache — extract candidate without holding the lock during ping + candidate = None + with _client_cache_lock: + if cache_key in _client_cache: + client, last_used = _client_cache[cache_key] + if time.time() - last_used > _CLIENT_IDLE_PING_THRESHOLD: + candidate = client + else: + _client_cache[cache_key] = (client, time.time()) + logger.debug("Reusing cached client") + return client + + # Ping outside the lock so we don't serialize unrelated configs + if candidate is not None: + try: + alive = candidate.ping() + except Exception: + alive = False + + if alive: + with _client_cache_lock: + # Re-check: another thread may have evicted while we pinged + if cache_key in _client_cache: + _client_cache[cache_key] = (candidate, time.time()) + logger.debug("Reusing cached client (ping OK after idle)") + return candidate + # Was evicted by another thread; fall through to create new + else: + logger.warning("Cached client failed ping, creating new client") + with _client_cache_lock: + _client_cache.pop(cache_key, None) + try: + candidate.close() + except Exception: + pass + + # Create new client outside the lock (client creation is slow) logger.info( - f"Creating ClickHouse client connection to 
{client_config['host']}:{client_config['port']} " - f"as {client_config['username']} " - f"(secure={client_config['secure']}, verify={client_config['verify']}, " - f"connect_timeout={client_config['connect_timeout']}s, " - f"send_receive_timeout={client_config['send_receive_timeout']}s)" + "Creating ClickHouse client connection to %s:%s as %s " + "(secure=%s, verify=%s, connect_timeout=%ss, send_receive_timeout=%ss)", + config['host'], config['port'], config['username'], + config['secure'], config['verify'], + config['connect_timeout'], config['send_receive_timeout'], ) try: - client = clickhouse_connect.get_client(**client_config) - # Test the connection + # Disable autogenerate_session_id so the client is safe for concurrent + # use from the thread pool. clickhouse_connect rejects concurrent queries + # on the same session_id, but with this disabled each query runs + # without session affinity. + client = clickhouse_connect.get_client( + **config, autogenerate_session_id=False + ) version = client.server_version logger.info(f"Successfully connected to ClickHouse server version {version}") - return client except Exception as e: logger.error(f"Failed to connect to ClickHouse: {str(e)}") raise + with _client_cache_lock: + # Another thread may have raced and cached a client for this key + if cache_key in _client_cache: + try: + client.close() + except Exception: + pass + client, _ = _client_cache[cache_key] + _client_cache[cache_key] = (client, time.time()) + return client + _client_cache[cache_key] = (client, time.time()) + + return client + + +def _clear_client_cache(): + """Clear the client cache, closing all cached clients. + + Used during shutdown and for testing. 
+ """ + with _client_cache_lock: + for _, (client, _) in list(_client_cache.items()): + try: + client.close() + except Exception: + pass + _client_cache.clear() + + +atexit.register(_clear_client_cache) + def build_query_settings(client) -> dict[str, str]: """Build query settings dict for ClickHouse queries. diff --git a/tests/test_client_cache.py b/tests/test_client_cache.py new file mode 100644 index 0000000..4530f0e --- /dev/null +++ b/tests/test_client_cache.py @@ -0,0 +1,267 @@ +"""Tests for ClickHouse client caching and reuse.""" + +import time +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp.exceptions import ToolError + +from mcp_clickhouse.mcp_env import get_mcp_config +from mcp_clickhouse.mcp_server import ( + _active_queries, + _active_queries_lock, + _clear_client_cache, + _client_cache, + _client_cache_lock, + _config_to_cache_key, + _resolve_client_config, + create_clickhouse_client, + execute_query, +) + + +class TestConfigToCacheKey: + """Tests for the _config_to_cache_key helper.""" + + def test_deterministic_key(self): + config = {"host": "localhost", "port": 8443, "username": "default"} + assert _config_to_cache_key(config) == _config_to_cache_key(config) + + def test_order_independent(self): + config_a = {"host": "localhost", "port": 8443} + config_b = {"port": 8443, "host": "localhost"} + assert _config_to_cache_key(config_a) == _config_to_cache_key(config_b) + + def test_nested_dict(self): + config = {"host": "localhost", "settings": {"role": "admin", "readonly": "1"}} + key = _config_to_cache_key(config) + assert isinstance(key, tuple) + # Nested dict should also be a tuple + for k, v in key: + if k == "settings": + assert isinstance(v, tuple) + + def test_different_configs_different_keys(self): + config_a = {"host": "host1", "port": 8443} + config_b = {"host": "host2", "port": 8443} + assert _config_to_cache_key(config_a) != _config_to_cache_key(config_b) + + +class TestClientCaching: + """Tests for client cache 
behavior.""" + + def setup_method(self): + _clear_client_cache() + + def teardown_method(self): + _clear_client_cache() + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_same_config_returns_cached_client(self, _mock_ctx, mock_cc): + """Same config should return the cached client without creating a new one.""" + mock_client = MagicMock(server_version="24.1") + mock_cc.get_client.return_value = mock_client + + client1 = create_clickhouse_client() + client2 = create_clickhouse_client() + + assert client1 is client2 + # get_client should only be called once + assert mock_cc.get_client.call_count == 1 + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context") + def test_different_config_creates_new_client(self, mock_get_context, mock_cc): + """Different session configs should produce different cached clients.""" + mock_client_a = MagicMock(server_version="24.1") + mock_client_b = MagicMock(server_version="24.1") + mock_cc.get_client.side_effect = [mock_client_a, mock_client_b] + + # First call: no overrides + mock_ctx = MagicMock() + mock_ctx.get_state.return_value = None + mock_get_context.return_value = mock_ctx + client1 = create_clickhouse_client() + + _clear_client_cache() + + # Second call: with override that changes the config key + mock_ctx2 = MagicMock() + mock_ctx2.get_state.return_value = {"connect_timeout": 99} + mock_get_context.return_value = mock_ctx2 + client2 = create_clickhouse_client() + + assert client1 is not client2 + assert mock_cc.get_client.call_count == 2 + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_stale_client_evicted_on_ping_failure(self, _mock_ctx, mock_cc): + """Client that fails ping after idle should be evicted and recreated.""" + mock_client_old = MagicMock(server_version="24.1") + 
mock_client_old.ping.return_value = False + mock_client_new = MagicMock(server_version="24.2") + mock_cc.get_client.side_effect = [mock_client_old, mock_client_new] + + client1 = create_clickhouse_client() + assert client1 is mock_client_old + + # Simulate idle time exceeding threshold + with _client_cache_lock: + for key, val in _client_cache.items(): + client, _ = val + _client_cache[key] = (client, time.time() - 120) + + client2 = create_clickhouse_client() + assert client2 is mock_client_new + assert mock_cc.get_client.call_count == 2 + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_autogenerate_session_id_disabled(self, _mock_ctx, mock_cc): + """Cached clients should be created with autogenerate_session_id=False.""" + mock_cc.get_client.return_value = MagicMock(server_version="24.1") + + create_clickhouse_client() + + call_kwargs = mock_cc.get_client.call_args[1] + assert call_kwargs["autogenerate_session_id"] is False + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_clear_cache_closes_clients(self, _mock_ctx, mock_cc): + """_clear_client_cache should close all cached clients.""" + mock_client = MagicMock(server_version="24.1") + mock_cc.get_client.return_value = mock_client + + create_clickhouse_client() + _clear_client_cache() + + mock_client.close.assert_called_once() + + +class TestResolveClientConfig: + """Tests for _resolve_client_config.""" + + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_send_receive_timeout_capped_when_not_explicit(self, _mock_ctx): + """send_receive_timeout should be capped to query_timeout + 5 by default.""" + config = _resolve_client_config() + + expected = get_mcp_config().query_timeout + 5 + assert config["send_receive_timeout"] == expected + + @patch.dict("os.environ", 
{"CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "200"}) + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_send_receive_timeout_not_capped_when_explicit(self, _mock_ctx): + """Explicit env var should bypass the auto-cap.""" + config = _resolve_client_config() + assert config["send_receive_timeout"] == 200 + + @patch("mcp_clickhouse.mcp_server.get_context") + def test_session_override_timeout_not_capped(self, mock_get_context): + """Session override of send_receive_timeout should bypass the auto-cap.""" + mock_ctx = MagicMock() + mock_ctx.get_state.return_value = {"send_receive_timeout": 300} + mock_get_context.return_value = mock_ctx + + config = _resolve_client_config() + assert config["send_receive_timeout"] == 300 + + +class TestEvictionOnError: + """Tests for client eviction on connection errors.""" + + def setup_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + def teardown_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_execute_query_evicts_on_connection_error(self, _mock_ctx, mock_cc): + """execute_query should evict the cached client on connection errors.""" + mock_client = MagicMock(server_version="24.1") + mock_client.server_settings = {} + mock_client.query.side_effect = ConnectionError("connection reset") + mock_cc.get_client.return_value = mock_client + + config = _resolve_client_config() + + with pytest.raises(ToolError, match="connection reset"): + execute_query("SELECT 1", "evict-test", config) + + # Client should have been evicted — next call creates a new one + mock_client_new = MagicMock(server_version="24.2") + mock_client_new.server_settings = {} + mock_result = MagicMock() + mock_result.result_rows = [] + mock_result.column_names = [] + mock_client_new.query.return_value = mock_result + 
mock_cc.get_client.return_value = mock_client_new + + execute_query("SELECT 1", "evict-test-2", config) + assert mock_cc.get_client.call_count == 2 + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_execute_query_no_evict_on_sql_error(self, _mock_ctx, mock_cc): + """execute_query should NOT evict on normal SQL errors (not connection).""" + mock_client = MagicMock(server_version="24.1") + mock_client.server_settings = {} + mock_client.query.side_effect = Exception("Unknown column 'x'") + mock_cc.get_client.return_value = mock_client + + config = _resolve_client_config() + + with pytest.raises(ToolError): + execute_query("SELECT x", "no-evict-test", config) + + # Client should still be cached, second call reuses it + mock_client.query.side_effect = None + mock_result = MagicMock() + mock_result.result_rows = [] + mock_result.column_names = [] + mock_client.query.return_value = mock_result + execute_query("SELECT 1", "no-evict-test-2", config) + + # get_client only called once, reused from cache + assert mock_cc.get_client.call_count == 1 + + +class TestPingExceptionHandling: + """Tests for ping exception handling in create_clickhouse_client.""" + + def setup_method(self): + _clear_client_cache() + + def teardown_method(self): + _clear_client_cache() + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_ping_exception_evicts_and_recreates(self, _mock_ctx, mock_cc): + """A ping() that raises should evict the client and create a new one.""" + mock_client_old = MagicMock(server_version="24.1") + mock_client_old.ping.side_effect = Exception("boom") + mock_client_new = MagicMock(server_version="24.2") + mock_cc.get_client.side_effect = [mock_client_old, mock_client_new] + + client1 = create_clickhouse_client() + assert client1 is mock_client_old + + # Simulate idle time exceeding threshold 
+ with _client_cache_lock: + for key, val in _client_cache.items(): + client, _ = val + _client_cache[key] = (client, time.time() - 120) + + client2 = create_clickhouse_client() + assert client2 is mock_client_new + assert mock_cc.get_client.call_count == 2 diff --git a/tests/test_context_config_override.py b/tests/test_context_config_override.py index 789d131..0248ae9 100644 --- a/tests/test_context_config_override.py +++ b/tests/test_context_config_override.py @@ -10,6 +10,7 @@ from mcp_clickhouse.mcp_server import ( mcp, create_clickhouse_client, + _clear_client_cache, CLIENT_CONFIG_OVERRIDES_KEY, ) @@ -29,6 +30,14 @@ async def on_call_tool(self, context: MiddlewareContext, call_next: CallNext): class TestConfigOverrideUnit: """Unit tests for the config override merge logic in create_clickhouse_client.""" + def setup_method(self): + """Clear the client cache before each test.""" + _clear_client_cache() + + def teardown_method(self): + """Clear the client cache after each test.""" + _clear_client_cache() + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") @patch("mcp_clickhouse.mcp_server.get_context") def test_overrides_merged_into_client_config(self, mock_get_context, mock_cc): diff --git a/tests/test_query_cancellation.py b/tests/test_query_cancellation.py new file mode 100644 index 0000000..d3a57f9 --- /dev/null +++ b/tests/test_query_cancellation.py @@ -0,0 +1,189 @@ +"""Tests for query ID tracking and server-side cancellation.""" + +import concurrent.futures +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp.exceptions import ToolError + +from mcp_clickhouse.mcp_server import ( + _active_queries, + _active_queries_lock, + _cancel_query, + _clear_client_cache, + _client_cache, + _client_cache_lock, + _resolve_client_config, + execute_query, + run_query, +) + + +class TestQueryIdTracking: + """Tests for query_id propagation through execute_query.""" + + def setup_method(self): + _clear_client_cache() + with _active_queries_lock: + 
_active_queries.clear() + + def teardown_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_query_id_passed_in_settings(self, _mock_ctx, mock_cc): + """query_id should be included in the settings dict passed to client.query().""" + mock_client = MagicMock(server_version="24.1") + mock_client.server_settings = {} + mock_result = MagicMock() + mock_result.result_rows = [("row1",)] + mock_result.column_names = ["col1"] + mock_client.query.return_value = mock_result + mock_cc.get_client.return_value = mock_client + + config = _resolve_client_config() + execute_query("SELECT 1", "test-query-id-123", config) + + # Verify query_id was passed in settings + call_args = mock_client.query.call_args + settings = call_args[1].get("settings") or call_args.kwargs.get("settings") + assert settings["query_id"] == "test-query-id-123" + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_active_queries_tracked_and_cleaned(self, _mock_ctx, mock_cc): + """execute_query should register in _active_queries and clean up on completion.""" + mock_client = MagicMock(server_version="24.1") + mock_client.server_settings = {} + mock_result = MagicMock() + mock_result.result_rows = [] + mock_result.column_names = [] + mock_client.query.return_value = mock_result + mock_cc.get_client.return_value = mock_client + + config = _resolve_client_config() + + # Before execution + with _active_queries_lock: + assert "tracking-test-id" not in _active_queries + + execute_query("SELECT 1", "tracking-test-id", config) + + # After completion, should be cleaned up + with _active_queries_lock: + assert "tracking-test-id" not in _active_queries + + @patch("mcp_clickhouse.mcp_server.clickhouse_connect") + 
@patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_active_queries_cleaned_on_error(self, _mock_ctx, mock_cc): + """execute_query should clean up _active_queries even on error.""" + mock_client = MagicMock(server_version="24.1") + mock_client.server_settings = {} + mock_client.query.side_effect = Exception("DB error") + mock_cc.get_client.return_value = mock_client + + config = _resolve_client_config() + + with pytest.raises(ToolError): + execute_query("SELECT bad", "error-test-id", config) + + with _active_queries_lock: + assert "error-test-id" not in _active_queries + + +class TestCancelQuery: + """Tests for _cancel_query server-side cancellation.""" + + def setup_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + def teardown_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + def test_cancel_issues_kill_query(self): + """_cancel_query should issue KILL QUERY via the cached client.""" + mock_client = MagicMock() + cache_key = (("host", "localhost"), ("port", 8443)) + + # Set up cached client and active query + with _client_cache_lock: + _client_cache[cache_key] = (mock_client, 0) + with _active_queries_lock: + _active_queries["kill-test-id"] = (cache_key, "SELECT sleep(60)") + + _cancel_query("kill-test-id") + + mock_client.command.assert_called_once_with("KILL QUERY WHERE query_id = 'kill-test-id'") + # Should be removed from active queries + with _active_queries_lock: + assert "kill-test-id" not in _active_queries + + def test_cancel_noop_for_completed_query(self): + """_cancel_query should be a no-op if the query already completed.""" + # No entry in _active_queries + _cancel_query("nonexistent-id") # Should not raise + + def test_cancel_warns_without_cached_client(self): + """_cancel_query should log warning if no cached client is available.""" + cache_key = (("host", "gone"),) + with _active_queries_lock: + _active_queries["orphan-id"] 
= (cache_key, "SELECT 1") + + # No client in cache for this key + _cancel_query("orphan-id") # Should not raise + + with _active_queries_lock: + assert "orphan-id" not in _active_queries + + def test_cancel_failure_does_not_raise(self): + """_cancel_query should swallow exceptions from KILL QUERY.""" + mock_client = MagicMock() + mock_client.command.side_effect = Exception("Permission denied") + cache_key = (("host", "localhost"),) + + with _client_cache_lock: + _client_cache[cache_key] = (mock_client, 0) + with _active_queries_lock: + _active_queries["fail-id"] = (cache_key, "SELECT 1") + + _cancel_query("fail-id") # Should not raise + + +class TestRunQueryTimeout: + """Tests for run_query timeout triggering _cancel_query.""" + + def setup_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + def teardown_method(self): + _clear_client_cache() + with _active_queries_lock: + _active_queries.clear() + + @patch("mcp_clickhouse.mcp_server._cancel_query") + @patch("mcp_clickhouse.mcp_server.QUERY_EXECUTOR") + @patch("mcp_clickhouse.mcp_server.get_context", side_effect=RuntimeError) + def test_timeout_triggers_cancel(self, _mock_ctx, mock_executor, mock_cancel): + """When run_query times out, it should call _cancel_query with the query_id.""" + mock_future = MagicMock() + mock_future.result.side_effect = concurrent.futures.TimeoutError() + mock_executor.submit.return_value = mock_future + mock_executor._work_queue.qsize.return_value = 0 + + with pytest.raises(ToolError, match="timed out"): + run_query("SELECT sleep(999)") + + # _cancel_query should have been called with the generated query_id + mock_cancel.assert_called_once() + query_id = mock_cancel.call_args[0][0] + assert isinstance(query_id, str) + assert len(query_id) > 0