diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md index 65309480..dc876f38 100644 --- a/.claude/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -53,6 +53,7 @@ AUTH_PIN=123456 # Optional: access code when auth is enabled REQUIRE_AUTH=true # Force authentication in local development PORT=5000 # Set by hosting platforms to indicate deployed mode SECRET_KEY=override-me # Optional: otherwise random key generated per launch +TOOL_SEARCH_MODE=hybrid # Tool discovery: local | api | hybrid (default: hybrid) ``` ## Running the App diff --git a/README.md b/README.md index af7ab713..90e495a7 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ MatHud pairs an interactive drawing canvas with an AI assistant to help visualiz REQUIRE_AUTH=true # Force authentication in local development PORT=5000 # Set by hosting platforms to indicate deployed mode SECRET_KEY=override-me # Optional: otherwise a random key is generated per launch + TOOL_SEARCH_MODE=hybrid # Tool discovery: local | api | hybrid (default: hybrid) ``` 2. Authentication rules (`static/app_manager.py`): 1. When `PORT` is set (typical in hosted deployments), authentication is enforced automatically. 
diff --git a/documentation/Project Architecture.txt b/documentation/Project Architecture.txt index d630312c..87cce4bd 100644 --- a/documentation/Project Architecture.txt +++ b/documentation/Project Architecture.txt @@ -38,7 +38,8 @@ MatHud is an interactive mathematical visualization tool that combines a drawing - Coordinate Conversion: rectangular to polar and polar to rectangular transformations - Equation Solving: linear equations, quadratic equations, systems of equations, numeric solving for transcendental/nonlinear systems via multi-start Newton-Raphson - Linear Algebra: matrix and vector arithmetic, transpose, inverse, determinant, dot/cross products, norms, traces, diagonal/identity/zeros/ones helpers, reshape/size queries, and grouped expressions evaluated via MathJS -- Advanced Operations: 70 AI tool definitions available in `static/functions_definitions.py` including canvas operations, geometric shape creation/deletion, mathematical operations, transformations, coordinate system management, and workspace management +- Advanced Operations: 90+ AI tool definitions available in `static/functions_definitions.py` including canvas operations, geometric shape creation/deletion, mathematical operations, transformations, coordinate system management, and workspace management +- Tool Discovery: `static/tool_search_service.py` provides fast local keyword/category search (hybrid mode by default) to select relevant tools per query without requiring an API call for every interaction ## Geometric Shape Management - Points: create, delete, translate points with coordinates and labels diff --git a/documentation/Reference Manual.txt b/documentation/Reference Manual.txt index c80c71f3..3626edec 100644 --- a/documentation/Reference Manual.txt +++ b/documentation/Reference Manual.txt @@ -3761,6 +3761,37 @@ Dependencies: - Parameter validation: Strict JSON schema enforcement for all function arguments - Type safety: Required parameters and type checking for robust AI 
integration +### Tool Search Service (`static/tool_search_service.py`) + +Provides tool discovery via fast local keyword/category matching, AI-powered semantic matching, or a hybrid of both (the default). The search mode is selected by the `TOOL_SEARCH_MODE` environment variable: + +- `local` — keyword + category index, no API call (~0.01ms per query) +- `api` — AI-powered semantic search via gpt-5-nano +- `hybrid` (default) — local first, falls back to API when confidence is low + +**Architecture:** +- 13 tool categories (geometry_create, geometry_delete, geometry_update, constructions, functions_plots, math, statistics, graph_theory, canvas, workspace, transforms, areas, inspection) with keyword triggers +- Inverted indices built at module load time for O(1) token lookups +- Multi-signal scoring: category boost, name index match, description index match, exact tool name match, action-verb alignment, and intent-based disambiguation boosts +- LRU result cache with 5-minute TTL (100 entries max) +- Lazy OpenAI client initialization — local mode never touches the network + +**Class: ToolSearchService** + +Key Methods: +- `search_tools(query, model=None, max_results=10)`: Main entry point; dispatches to local or API search based on mode +- `search_tools_local(query, max_results=10)`: Fast local keyword/category search (no API call) +- `search_tools_formatted(query, model=None, max_results=10)`: Returns dict with tools list, count, and query for AI consumption +- `get_tool_by_name(name)`: Look up a tool definition by name (static method) +- `build_tool_descriptions(exclude_meta_tools=True)`: Build compact string of tool names and descriptions (static method) +- `get_all_tools()`: Get all available tool definitions (static method) + +Module-level Functions: +- `clear_search_cache()`: Clear the search result cache + +**Environment Variables:** +- `TOOL_SEARCH_MODE`: Search mode selection (`local`, `api`, or `hybrid`; default: `hybrid`) + ### Flask Application Manager (`static/app_manager.py`) 
**File Header:** diff --git a/scripts/compare_search_modes.py b/scripts/compare_search_modes.py new file mode 100644 index 00000000..7d35a628 --- /dev/null +++ b/scripts/compare_search_modes.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +"""Compare local vs API tool search accuracy and latency. + +Runs the benchmark dataset against both search modes and outputs a +side-by-side comparison table with disagreement analysis. + +Usage:: + + # Local only (no API key needed) + python scripts/compare_search_modes.py --modes local + + # Full comparison (needs API key) + python scripts/compare_search_modes.py --modes local,api + + # Save disagreements to CSV + python scripts/compare_search_modes.py --modes local,api --disagreements /tmp/disagreements.csv +""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _project_root not in sys.path: + sys.path.insert(0, _project_root) + +from static.tool_search_service import ToolSearchService, clear_search_cache + +DATASET_PATH = Path("server_tests/data/tool_discovery_cases.yaml") + + +def _load_dataset() -> Dict[str, Any]: + raw = DATASET_PATH.read_text(encoding="utf-8") + parsed: Dict[str, Any] = json.loads(raw) + return parsed + + +def _get_tool_name(tool: Dict[str, Any]) -> str: + name: str = tool.get("function", {}).get("name", "") + return name + + +def _run_mode( + mode: str, + cases: List[Dict[str, Any]], + max_results: int = 10, +) -> Dict[str, Any]: + """Run benchmark for a single search mode.""" + os.environ["TOOL_SEARCH_MODE"] = mode + clear_search_cache() + + service = ToolSearchService.__new__(ToolSearchService) + service._client = None + service._client_initialized = False + service.default_model = None + service.last_error = None + + # For API mode, initialize the client properly + if mode == "api": + 
try: + service = ToolSearchService() + except ValueError: + print(f" [SKIP] API mode requires OPENAI_API_KEY") + return {"skipped": True} + + top1_hits = 0 + top3_hits = 0 + top5_hits = 0 + evaluated = 0 + latencies: List[float] = [] + case_results: List[Dict[str, Any]] = [] + + for case in cases: + expected_any = [str(x) for x in case.get("expected_any", []) if isinstance(x, str)] + if not expected_any: + continue + + query = str(case.get("query", "")).strip() + evaluated += 1 + + start = time.perf_counter() + if mode == "local": + results = service.search_tools_local(query, max_results) + else: + results = service.search_tools(query, max_results=max_results) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + ranked = [_get_tool_name(t) for t in results] + expected_set = set(expected_any) + + top1 = ranked[0] if ranked else "" + is_top1 = top1 in expected_set + is_top3 = bool(expected_set & set(ranked[:3])) + is_top5 = bool(expected_set & set(ranked[:5])) + + if is_top1: + top1_hits += 1 + if is_top3: + top3_hits += 1 + if is_top5: + top5_hits += 1 + + case_results.append({ + "id": case.get("id", "?"), + "query": query, + "expected": expected_any, + "top1": top1, + "top3": ranked[:3], + "is_top1": is_top1, + "is_top3": is_top3, + "is_top5": is_top5, + }) + + # Rate-limit API calls + if mode == "api": + time.sleep(0.5) + + avg_latency = sum(latencies) / len(latencies) if latencies else 0 + latencies.sort() + p50 = latencies[len(latencies) // 2] if latencies else 0 + p99_idx = min(int(len(latencies) * 0.99), len(latencies) - 1) if latencies else 0 + p99 = latencies[p99_idx] if latencies else 0 + + return { + "skipped": False, + "evaluated": evaluated, + "top1_hits": top1_hits, + "top3_hits": top3_hits, + "top5_hits": top5_hits, + "top1_rate": top1_hits / evaluated if evaluated else 0, + "top3_rate": top3_hits / evaluated if evaluated else 0, + "top5_rate": top5_hits / evaluated if evaluated else 0, + "avg_latency_ms": avg_latency, 
+ "p50_latency_ms": p50, + "p99_latency_ms": p99, + "case_results": case_results, + } + + +def _find_disagreements( + local_results: Dict[str, Any], + api_results: Dict[str, Any], +) -> List[Dict[str, Any]]: + """Find cases where local and API disagree.""" + disagreements: List[Dict[str, Any]] = [] + local_cases = local_results.get("case_results", []) + api_cases = api_results.get("case_results", []) + + api_by_id = {c["id"]: c for c in api_cases} + + for lc in local_cases: + ac = api_by_id.get(lc["id"]) + if ac is None: + continue + + if lc["is_top1"] != ac["is_top1"]: + disagreements.append({ + "id": lc["id"], + "query": lc["query"], + "expected": lc["expected"], + "local_top1": lc["top1"], + "api_top1": ac["top1"], + "local_correct": lc["is_top1"], + "api_correct": ac["is_top1"], + "winner": "local" if lc["is_top1"] else "api", + }) + + return disagreements + + +def main() -> int: + parser = argparse.ArgumentParser(description="Compare tool search modes") + parser.add_argument( + "--modes", + default="local", + help="Comma-separated search modes to compare (local,api)", + ) + parser.add_argument( + "--max-results", + type=int, + default=10, + help="Max results per search (default: 10)", + ) + parser.add_argument( + "--disagreements", + default="", + help="Path to write disagreements CSV", + ) + args = parser.parse_args() + + modes = [m.strip() for m in args.modes.split(",") if m.strip()] + if not modes: + print("No modes specified") + return 1 + + dataset = _load_dataset() + cases = dataset.get("cases", []) + + print(f"Dataset: {len(cases)} cases") + print() + + all_results: Dict[str, Dict[str, Any]] = {} + + for mode in modes: + print(f"Running mode: {mode}") + result = _run_mode(mode, cases, args.max_results) + all_results[mode] = result + + if result.get("skipped"): + print(f" Skipped (missing credentials)") + continue + + print(f" Evaluated: {result['evaluated']}") + print(f" Top-1: {result['top1_hits']}/{result['evaluated']} = {result['top1_rate']:.3f}") + 
print(f" Top-3: {result['top3_hits']}/{result['evaluated']} = {result['top3_rate']:.3f}") + print(f" Top-5: {result['top5_hits']}/{result['evaluated']} = {result['top5_rate']:.3f}") + print(f" Latency: avg={result['avg_latency_ms']:.1f}ms, " + f"p50={result['p50_latency_ms']:.1f}ms, p99={result['p99_latency_ms']:.1f}ms") + print() + + # Side-by-side comparison + if len(all_results) >= 2: + active = {k: v for k, v in all_results.items() if not v.get("skipped")} + if len(active) >= 2: + print("=" * 60) + print("Side-by-Side Comparison") + print("=" * 60) + header = f"{'Metric':<20}" + for mode in active: + header += f" {mode:>15}" + print(header) + print("-" * 60) + + for metric in ["top1_rate", "top3_rate", "top5_rate", "avg_latency_ms", "p50_latency_ms", "p99_latency_ms"]: + row = f"{metric:<20}" + for mode in active: + val = active[mode].get(metric, 0) + if "rate" in metric: + row += f" {val:>14.3f}" + else: + row += f" {val:>13.1f}ms" + print(row) + print() + + # Disagreement analysis + if "local" in all_results and "api" in all_results: + local_r = all_results["local"] + api_r = all_results["api"] + if not local_r.get("skipped") and not api_r.get("skipped"): + disagreements = _find_disagreements(local_r, api_r) + local_wins = [d for d in disagreements if d["winner"] == "local"] + api_wins = [d for d in disagreements if d["winner"] == "api"] + + print(f"Disagreements: {len(disagreements)} total") + print(f" Local wins: {len(local_wins)}") + print(f" API wins: {len(api_wins)}") + + if api_wins: + print(f"\nCases where API is right but local is wrong (tuning opportunities):") + for d in api_wins[:10]: + print(f" {d['id']}: {d['query']!r}") + print(f" expected={d['expected']}, local={d['local_top1']!r}, api={d['api_top1']!r}") + + if local_wins: + print(f"\nCases where local is right but API is wrong (local advantages):") + for d in local_wins[:10]: + print(f" {d['id']}: {d['query']!r}") + print(f" expected={d['expected']}, local={d['local_top1']!r}, 
api={d['api_top1']!r}") + + # Write disagreements CSV + if args.disagreements and disagreements: + dis_path = Path(args.disagreements) + dis_path.parent.mkdir(parents=True, exist_ok=True) + with dis_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=[ + "id", "query", "expected", "local_top1", "api_top1", + "local_correct", "api_correct", "winner", + ]) + writer.writeheader() + for d in disagreements: + csv_row: Dict[str, Any] = dict(d) + csv_row["expected"] = "|".join(csv_row["expected"]) + writer.writerow(csv_row) + print(f"\nDisagreements written to: {dis_path}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/server_tests/test_prompt_pipeline.py b/server_tests/test_prompt_pipeline.py new file mode 100644 index 00000000..877a7d2d --- /dev/null +++ b/server_tests/test_prompt_pipeline.py @@ -0,0 +1,418 @@ +"""Mocked end-to-end prompt pipeline tests. + +Verifies the full pipeline: natural-language prompt → real local tool search +→ real filtering → correct tool calls returned to the client. + +Only the OpenAI API call is mocked; ``_intercept_search_tools``, +``ToolSearchService.search_tools_local``, and the filtering helpers all run +for real with ``TOOL_SEARCH_MODE=local``. 
+""" + +from __future__ import annotations + +import json +import os +import unittest +from typing import Any, Dict, Iterator, List, Optional +from unittest.mock import Mock, patch + +from static.app_manager import AppManager, MatHudFlask +from static.openai_completions_api import OpenAIChatCompletionsAPI +from static.openai_responses_api import OpenAIResponsesAPI +from static.tool_search_service import clear_search_cache + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_payload(msg: str, model: str = "gpt-4.1") -> Dict[str, Any]: + """Build the POST body expected by ``/send_message_stream`` and ``/send_message``.""" + return { + "message": json.dumps( + {"user_message": msg, "use_vision": False, "ai_model": model} + ), + "svg_state": None, + } + + +def _search_call(query: str) -> Dict[str, Any]: + """Build a ``search_tools`` tool-call dict.""" + return {"function_name": "search_tools", "arguments": {"query": query}} + + +def _tool_call(name: str, args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Build an action tool-call dict.""" + return {"function_name": name, "arguments": args or {}} + + +def _mock_stream_final( + message: str, + tool_calls: List[Dict[str, Any]], +) -> Iterator[Dict[str, Any]]: + """Return a single-event stream matching the ``create_chat_completion_stream`` contract.""" + return iter( + [ + { + "type": "final", + "ai_message": message, + "ai_tool_calls": tool_calls, + "finish_reason": "tool_calls", + } + ] + ) + + +def _parse_ndjson_events(response_data: bytes) -> List[Dict[str, Any]]: + """Parse NDJSON response body into a list of event dicts.""" + return [ + json.loads(line) + for line in response_data.decode("utf-8").split("\n") + if line.strip() + ] + + +def _get_final_event(events: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Return the last ``final`` event, if any.""" + finals = [e 
for e in events if isinstance(e, dict) and e.get("type") == "final"] + return finals[-1] if finals else None + + +def _get_final_tool_names(events: List[Dict[str, Any]]) -> List[str]: + """Extract function names from the final event's tool calls.""" + final = _get_final_event(events) + if final is None: + return [] + return [ + tc.get("function_name", "") + for tc in final.get("ai_tool_calls", []) + if isinstance(tc, dict) + ] + + +# --------------------------------------------------------------------------- +# Streaming tests (14) +# --------------------------------------------------------------------------- + + +class TestPromptPipelineStream(unittest.TestCase): + """Full pipeline tests via ``/send_message_stream``. + + Mock only the OpenAI streaming call; everything else runs for real. + """ + + def setUp(self) -> None: + self._saved_env: Dict[str, Optional[str]] = {} + for key in ("REQUIRE_AUTH", "TOOL_SEARCH_MODE"): + self._saved_env[key] = os.environ.get(key) + os.environ["REQUIRE_AUTH"] = "false" + os.environ["TOOL_SEARCH_MODE"] = "local" + + clear_search_cache() + + self.app: MatHudFlask = AppManager.create_app() + self.app.config["TESTING"] = True + self.client = self.app.test_client() + + def tearDown(self) -> None: + for key, val in self._saved_env.items(): + if val is not None: + os.environ[key] = val + else: + os.environ.pop(key, None) + clear_search_cache() + + # -- helpers -- + + def _post_stream( + self, + msg: str, + tool_calls: List[Dict[str, Any]], + ai_message: str = "Using tools", + ) -> List[Dict[str, Any]]: + """POST to ``/send_message_stream`` with a mocked final event and return parsed events.""" + with patch.object( + OpenAIChatCompletionsAPI, + "create_chat_completion_stream", + return_value=_mock_stream_final(ai_message, tool_calls), + ): + resp = self.client.post( + "/send_message_stream", + json=_make_payload(msg), + ) + self.assertEqual(resp.status_code, 200) + return _parse_ndjson_events(resp.data) + + def _assert_tool_passes( + 
self, + msg: str, + search_query: str, + action_name: str, + action_args: Optional[Dict[str, Any]] = None, + ) -> None: + """Assert that ``action_name`` survives the pipeline.""" + events = self._post_stream( + msg, + [_search_call(search_query), _tool_call(action_name, action_args)], + ) + names = _get_final_tool_names(events) + self.assertIn(action_name, names) + + # -- individual tests -- + + def test_circle_creation(self) -> None: + self._assert_tool_passes( + "draw a circle", + "draw circle", + "create_circle", + {"center_x": 0, "center_y": 0, "radius": 5}, + ) + + def test_triangle_creation(self) -> None: + self._assert_tool_passes( + "create a triangle", + "create triangle", + "create_polygon", + {"vertices": [[0, 0], [4, 0], [2, 3]]}, + ) + + def test_derivative(self) -> None: + self._assert_tool_passes( + "find the derivative of x^2", + "derivative", + "derive", + {"expression": "x^2", "variable": "x"}, + ) + + def test_solve_equation(self) -> None: + self._assert_tool_passes( + "solve x^2 - 1 = 0", + "solve equation", + "solve", + {"expression": "x^2-1=0", "variable": "x"}, + ) + + def test_plot_distribution(self) -> None: + self._assert_tool_passes( + "plot a normal distribution", + "plot normal distribution", + "plot_distribution", + {"distribution_type": "normal", "mean": 0, "std_dev": 1}, + ) + + def test_descriptive_stats(self) -> None: + self._assert_tool_passes( + "compute descriptive statistics for [1,2,3]", + "descriptive statistics", + "compute_descriptive_statistics", + {"data": [1, 2, 3]}, + ) + + def test_create_graph(self) -> None: + self._assert_tool_passes( + "create a weighted graph", + "create weighted graph vertices edges", + "generate_graph", + {"graph_name": "G1", "vertices": ["A", "B"]}, + ) + + def test_undo(self) -> None: + self._assert_tool_passes("undo last action", "undo", "undo") + + def test_save_workspace(self) -> None: + self._assert_tool_passes( + "save my workspace", + "save workspace", + "save_workspace", + {"name": 
"MyProject"}, + ) + + def test_rotate_object(self) -> None: + self._assert_tool_passes( + "rotate the triangle", + "rotate triangle", + "rotate_object", + {"object_name": "t1", "angle": 45}, + ) + + def test_multi_tool(self) -> None: + """Both ``create_point`` and ``create_segment`` should pass.""" + events = self._post_stream( + "create a point and a segment", + [ + _search_call("create point segment"), + _tool_call("create_point", {"x": 0, "y": 0}), + _tool_call("create_segment", {"x1": 0, "y1": 0, "x2": 1, "y2": 1}), + ], + ) + names = _get_final_tool_names(events) + self.assertIn("create_point", names) + self.assertIn("create_segment", names) + + def test_filters_irrelevant_tool(self) -> None: + """``analyze_graph`` should be filtered out for a circle query.""" + events = self._post_stream( + "draw a circle", + [ + _search_call("draw circle"), + _tool_call("create_circle", {"center_x": 0, "center_y": 0, "radius": 5}), + _tool_call("analyze_graph", {"graph_name": "G1", "algorithm": "bfs"}), + ], + ) + names = _get_final_tool_names(events) + self.assertIn("create_circle", names) + self.assertNotIn("analyze_graph", names) + + def test_essential_passthrough(self) -> None: + """Essential tools pass even if not in search results.""" + events = self._post_stream( + "get canvas state and undo", + [ + _search_call("canvas state undo"), + _tool_call("get_current_canvas_state"), + _tool_call("undo"), + ], + ) + names = _get_final_tool_names(events) + self.assertIn("get_current_canvas_state", names) + self.assertIn("undo", names) + + def test_no_search_tools_passthrough(self) -> None: + """When no ``search_tools`` call is present, all tools pass unfiltered.""" + events = self._post_stream( + "create a point", + [_tool_call("create_point", {"x": 5, "y": 10})], + ) + names = _get_final_tool_names(events) + self.assertIn("create_point", names) + + +# --------------------------------------------------------------------------- +# Non-streaming tests (3) +# 
--------------------------------------------------------------------------- + + +class TestPromptPipelineNonStream(unittest.TestCase): + """Full pipeline tests via ``/send_message``. + + Mock only the OpenAI call; everything else runs for real. + """ + + def setUp(self) -> None: + self._saved_env: Dict[str, Optional[str]] = {} + for key in ("REQUIRE_AUTH", "TOOL_SEARCH_MODE"): + self._saved_env[key] = os.environ.get(key) + os.environ["REQUIRE_AUTH"] = "false" + os.environ["TOOL_SEARCH_MODE"] = "local" + + clear_search_cache() + + self.app: MatHudFlask = AppManager.create_app() + self.app.config["TESTING"] = True + self.client = self.app.test_client() + + def tearDown(self) -> None: + for key, val in self._saved_env.items(): + if val is not None: + os.environ[key] = val + else: + os.environ.pop(key, None) + clear_search_cache() + + # -- reasoning model (o3): uses create_response_stream, consumed via /send_message -- + + @patch.object(OpenAIResponsesAPI, "create_response_stream") + def test_reasoning_model_derivative(self, mock_stream: Mock) -> None: + mock_stream.return_value = _mock_stream_final( + "Taking derivative", + [ + _search_call("derivative"), + _tool_call("derive", {"expression": "x^3", "variable": "x"}), + ], + ) + resp = self.client.post("/send_message", json=_make_payload("derivative of x^3", "o3")) + data = json.loads(resp.data) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(data["status"], "success") + tool_names = [ + tc.get("function_name", "") + for tc in data["data"]["ai_tool_calls"] + if isinstance(tc, dict) + ] + self.assertIn("derive", tool_names) + + # -- chat completions model (gpt-4.1): uses create_chat_completion -- + + @patch.object(OpenAIChatCompletionsAPI, "create_chat_completion") + def test_chat_completion_circle(self, mock_completion: Mock) -> None: + mock_completion.return_value = self._make_mock_choice( + "Drawing circle", + [ + ("search_tools", json.dumps({"query": "draw circle"})), + ("create_circle", 
json.dumps({"center_x": 0, "center_y": 0, "radius": 5})), + ], + ) + resp = self.client.post("/send_message", json=_make_payload("draw a circle", "gpt-4.1")) + data = json.loads(resp.data) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(data["status"], "success") + tool_names = [ + tc.get("function_name", "") + for tc in data["data"]["ai_tool_calls"] + if isinstance(tc, dict) + ] + self.assertIn("create_circle", tool_names) + + @patch.object(OpenAIChatCompletionsAPI, "create_chat_completion") + def test_chat_completion_filters_irrelevant(self, mock_completion: Mock) -> None: + mock_completion.return_value = self._make_mock_choice( + "Drawing circle", + [ + ("search_tools", json.dumps({"query": "draw circle"})), + ("create_circle", json.dumps({"center_x": 0, "center_y": 0, "radius": 5})), + ("analyze_graph", json.dumps({"graph_name": "G1", "algorithm": "bfs"})), + ], + ) + resp = self.client.post("/send_message", json=_make_payload("draw a circle", "gpt-4.1")) + data = json.loads(resp.data) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(data["status"], "success") + tool_names = [ + tc.get("function_name", "") + for tc in data["data"]["ai_tool_calls"] + if isinstance(tc, dict) + ] + self.assertIn("create_circle", tool_names) + self.assertNotIn("analyze_graph", tool_names) + + # -- mock builder -- + + @staticmethod + def _make_mock_choice( + content: str, + tool_calls: List[tuple[str, str]], + ) -> Any: + """Build a ``SimpleNamespace`` choice matching the ``ToolCallObject`` protocol. + + Each entry in *tool_calls* is ``(function_name, arguments_json_str)``. 
+ """ + from types import SimpleNamespace + + mock_tool_calls = [ + SimpleNamespace( + id=f"call_{i}", + function=SimpleNamespace(name=name, arguments=args), + ) + for i, (name, args) in enumerate(tool_calls) + ] + return SimpleNamespace( + finish_reason="tool_calls", + message=SimpleNamespace(content=content, tool_calls=mock_tool_calls), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/server_tests/test_tool_discovery_live.py b/server_tests/test_tool_discovery_live.py index 29c096f5..a441b61c 100644 --- a/server_tests/test_tool_discovery_live.py +++ b/server_tests/test_tool_discovery_live.py @@ -31,6 +31,8 @@ "ranked", "status", "error", + "search_ms", + "search_mode", ] @@ -116,7 +118,7 @@ def _search_ranked_names_with_model( content = str(content) tool_names = service._parse_tool_names(content) - ranked: List[str] = [] + ranked = [] for name in tool_names: if name in ESSENTIAL_TOOLS: continue @@ -305,6 +307,7 @@ def test_live_tool_discovery_benchmark() -> None: for tool_name in expected_any: assert tool_name in all_tools, f"Unknown expected tool '{tool_name}' in {case_id}" + search_start = time.perf_counter() ranked, search_error = _search_ranked_names_with_model( service=service, model=model, @@ -312,6 +315,8 @@ def test_live_tool_discovery_benchmark() -> None: max_results=max_results, provider_instance=provider_instance, ) + search_ms = (time.perf_counter() - search_start) * 1000 + search_mode = os.getenv("TOOL_SEARCH_MODE", "local").strip().lower() blocked = _is_blocked_error(search_error) infra_blocked = _is_infra_error(search_error) @@ -381,6 +386,8 @@ def test_live_tool_discovery_benchmark() -> None: "ranked": "|".join(ranked), "status": status, "error": search_error or "", + "search_ms": f"{search_ms:.1f}", + "search_mode": search_mode, } rows.append(row) if csv_path is not None: diff --git a/server_tests/test_tool_search_local.py b/server_tests/test_tool_search_local.py new file mode 100644 index 00000000..8ad0767c --- /dev/null +++ 
b/server_tests/test_tool_search_local.py @@ -0,0 +1,638 @@ +""" +Offline benchmark and latency tests for the local tool search engine. + +Runs the 190-case benchmark dataset against ``search_tools_local()`` without +any API key. Executes with standard ``pytest`` in under a second. + +Usage:: + + python -m pytest server_tests/test_tool_search_local.py -v +""" + +from __future__ import annotations + +import json +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict, List + +import pytest + +_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _project_root not in sys.path: + sys.path.insert(0, _project_root) + +from static.tool_search_service import ToolSearchService, clear_search_cache + +DATASET_PATH = Path("server_tests/data/tool_discovery_cases.yaml") + + +def _load_dataset() -> Dict[str, Any]: + raw = DATASET_PATH.read_text(encoding="utf-8") + parsed: Dict[str, Any] = json.loads(raw) + return parsed + + +def _get_tool_name(tool: Dict[str, Any]) -> str: + name: str = tool.get("function", {}).get("name", "") + return name + + +class TestLocalSearchBenchmark: + """Run benchmark dataset against local search (no API key needed).""" + + @pytest.fixture(autouse=True) + def _setup(self) -> None: + clear_search_cache() + self.service = ToolSearchService.__new__(ToolSearchService) + self.service._client = None + self.service._client_initialized = False + self.service.default_model = None + self.service.last_error = None + + def _run_benchmark(self, max_results: int = 10) -> Dict[str, Any]: + """Run the full benchmark and return metrics.""" + dataset = _load_dataset() + cases = dataset.get("cases", []) + + positive_total = 0 + positive_evaluated = 0 + top1_hits = 0 + top3_hits = 0 + top5_hits = 0 + + negative_total = 0 + negative_pass = 0 + + failed_cases: List[Dict[str, Any]] = [] + + for case in cases: + query = str(case.get("query", "")).strip() + expected_any = [str(x) for x in case.get("expected_any", 
[]) if isinstance(x, str)] + + results = self.service.search_tools_local(query, max_results) + ranked = [_get_tool_name(t) for t in results] + + if not expected_any: + negative_total += 1 + if not ranked: + negative_pass += 1 + continue + + positive_total += 1 + positive_evaluated += 1 + + top1 = ranked[0] if ranked else "" + top3_set = set(ranked[:3]) + top5_set = set(ranked[:5]) + expected_set = set(expected_any) + + is_top1 = top1 in expected_set + is_top3 = bool(expected_set & top3_set) + is_top5 = bool(expected_set & top5_set) + + if is_top1: + top1_hits += 1 + if is_top3: + top3_hits += 1 + if is_top5: + top5_hits += 1 + + if not is_top5: + failed_cases.append({ + "id": case.get("id", "?"), + "query": query, + "expected": expected_any, + "got_top5": ranked[:5], + }) + + top1_rate = top1_hits / positive_evaluated if positive_evaluated else 0.0 + top3_rate = top3_hits / positive_evaluated if positive_evaluated else 0.0 + top5_rate = top5_hits / positive_evaluated if positive_evaluated else 0.0 + + return { + "positive_total": positive_total, + "positive_evaluated": positive_evaluated, + "top1_hits": top1_hits, + "top3_hits": top3_hits, + "top5_hits": top5_hits, + "top1_rate": top1_rate, + "top3_rate": top3_rate, + "top5_rate": top5_rate, + "negative_total": negative_total, + "negative_pass": negative_pass, + "failed_cases": failed_cases, + } + + def test_local_search_accuracy(self) -> None: + """Local search should meet accuracy thresholds on benchmark dataset.""" + metrics = self._run_benchmark() + + print( + f"\nLocal search benchmark: " + f"top1={metrics['top1_rate']:.3f} ({metrics['top1_hits']}/{metrics['positive_evaluated']}), " + f"top3={metrics['top3_rate']:.3f} ({metrics['top3_hits']}/{metrics['positive_evaluated']}), " + f"top5={metrics['top5_rate']:.3f} ({metrics['top5_hits']}/{metrics['positive_evaluated']})" + ) + + if metrics["failed_cases"]: + print(f"Failed cases ({len(metrics['failed_cases'])}):") + for fc in metrics["failed_cases"][:15]: + 
print(f" - {fc['id']}: query={fc['query']!r}") + print(f" expected={fc['expected']}, got_top5={fc['got_top5']}") + + assert metrics["positive_evaluated"] > 0, "No positive cases evaluated" + assert metrics["top1_rate"] >= 0.80, ( + f"Top-1 accuracy {metrics['top1_rate']:.3f} below threshold 0.80" + ) + assert metrics["top3_rate"] >= 0.88, ( + f"Top-3 accuracy {metrics['top3_rate']:.3f} below threshold 0.88" + ) + + def test_local_search_top5_rate(self) -> None: + """Local search top-5 accuracy should be reasonable.""" + metrics = self._run_benchmark() + assert metrics["top5_rate"] >= 0.90, ( + f"Top-5 accuracy {metrics['top5_rate']:.3f} below threshold 0.90" + ) + + +class TestLocalSearchLatency: + """Ensure local search completes fast enough.""" + + @pytest.fixture(autouse=True) + def _setup(self) -> None: + clear_search_cache() + self.service = ToolSearchService.__new__(ToolSearchService) + self.service._client = None + self.service._client_initialized = False + self.service.default_model = None + self.service.last_error = None + + def test_local_search_latency_p99(self) -> None: + """Local search should complete in under 5ms per query (p99).""" + queries = [ + "draw a circle", + "solve x^2 - 4 = 0", + "create a point at 3, 4", + "calculate derivative of sin(x)", + "plot normal distribution mean 0 sigma 1", + "generate a weighted graph with vertices A B C", + "undo the last action", + "save workspace as MyProject", + "shade area under curve", + "translate point A by dx=2 dy=3", + "fit regression to data", + "compute mean and median of [1,2,3,4,5]", + "draw a parametric curve", + "construct perpendicular bisector of segment AB", + "find shortest path from A to D", + "zoom in on the canvas", + "clear everything", + "create segment from origin to (5,5)", + "reflect triangle across y-axis", + "integrate x^2 from 0 to 1", + ] + + # Warm up + for q in queries[:3]: + clear_search_cache() + self.service.search_tools_local(q) + + # Measure + latencies: List[float] = [] + 
for q in queries: + clear_search_cache() + start = time.perf_counter() + self.service.search_tools_local(q) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + latencies.sort() + p50 = latencies[len(latencies) // 2] + p99_idx = min(int(len(latencies) * 0.99), len(latencies) - 1) + p99 = latencies[p99_idx] + avg = sum(latencies) / len(latencies) + + print(f"\nLocal search latency: avg={avg:.2f}ms, p50={p50:.2f}ms, p99={p99:.2f}ms") + + assert p99 < 5.0, f"p99 latency {p99:.2f}ms exceeds 5ms threshold" + + def test_local_search_consistent_results(self) -> None: + """Same query should return same results (deterministic).""" + query = "draw a circle at center 0,0 with radius 5" + + clear_search_cache() + result1 = self.service.search_tools_local(query) + clear_search_cache() + result2 = self.service.search_tools_local(query) + + names1 = [_get_tool_name(t) for t in result1] + names2 = [_get_tool_name(t) for t in result2] + assert names1 == names2 + + +# --------------------------------------------------------------------------- +# Creative real-world prompt tests +# --------------------------------------------------------------------------- + +def _top_n(service: ToolSearchService, query: str, n: int = 5) -> List[str]: + """Return top-n tool names for a query.""" + clear_search_cache() + results = service.search_tools_local(query, max_results=n) + return [_get_tool_name(t) for t in results] + + +class TestCreativePrompts: + """Test the classifier against creative, realistic, and tricky user prompts. + + These go beyond the benchmark dataset to probe edge cases, slang, + multi-step requests, and domain-specific phrasing. 
+ """ + + @pytest.fixture(autouse=True) + def _setup(self) -> None: + clear_search_cache() + self.service = ToolSearchService.__new__(ToolSearchService) + self.service._client = None + self.service._client_initialized = False + self.service.default_model = None + self.service.last_error = None + + # -- Casual / conversational phrasing -- + + def test_casual_circle(self) -> None: + """Casual 'gimme a circle' should find create_circle.""" + names = _top_n(self.service, "gimme a circle centered at the origin") + assert "create_circle" in names + + def test_casual_undo(self) -> None: + """'oops go back' should find undo.""" + names = _top_n(self.service, "oops go back") + assert "undo" in names + + def test_casual_delete(self) -> None: + """'get rid of that triangle' should find delete_polygon.""" + names = _top_n(self.service, "get rid of that triangle") + assert "delete_polygon" in names + + def test_casual_zoom(self) -> None: + """'I can't see anything, zoom out' should find zoom.""" + names = _top_n(self.service, "I can't see anything, zoom out") + assert "zoom" in names + + # -- Math class homework scenarios -- + + def test_homework_quadratic(self) -> None: + """Student solving a quadratic.""" + names = _top_n(self.service, "What are the roots of x^2 - 5x + 6?") + assert "solve" in names + + def test_homework_derivative_chain_rule(self) -> None: + """Chain rule derivative.""" + names = _top_n(self.service, "take the derivative of sin(3x^2 + 1)") + assert "derive" in names + + def test_homework_integral_area(self) -> None: + """Find area under a curve via integration.""" + names = _top_n(self.service, + "what's the area under the curve y=x^3 between 0 and 2?") + assert "integrate" in names or "create_colored_area" in names + + def test_homework_system_word_problem(self) -> None: + """Word-problem style system of equations.""" + names = _top_n(self.service, + "If apples cost $2 and bananas cost $3, and I spent $13 on " + "5 fruits, how many of each did I buy?") + 
assert "solve_system_of_equations" in names or "solve" in names + + def test_homework_factor_polynomial(self) -> None: + """Factor a cubic polynomial.""" + names = _top_n(self.service, "factor x^3 - 27 completely") + assert "factor" in names + + def test_homework_limit(self) -> None: + """L'Hopital's rule limit.""" + names = _top_n(self.service, + "What is the limit of sin(x)/x as x approaches 0?") + assert "limit" in names + + # -- Engineering / physics style -- + + def test_physics_projectile(self) -> None: + """Projectile motion parametric curve.""" + names = _top_n(self.service, + "plot the trajectory: x(t) = 10t, y(t) = 10t - 4.9t^2") + assert "draw_parametric_function" in names + + def test_physics_unit_conversion(self) -> None: + """Physics unit conversion.""" + names = _top_n(self.service, "convert 9.8 meters per second squared to feet") + assert "convert" in names + + def test_engineering_matrix(self) -> None: + """Stiffness matrix computation.""" + names = _top_n(self.service, + "find the inverse of the 3x3 matrix [[1,2,3],[0,1,4],[5,6,0]]") + assert "evaluate_linear_algebra_expression" in names + + # -- Geometry constructions -- + + def test_geometry_circumscribed_circle(self) -> None: + """Circumscribed circle of a triangle.""" + names = _top_n(self.service, + "construct the circumscribed circle of triangle ABC") + assert "construct_circumcircle" in names + + def test_geometry_incircle(self) -> None: + """Inscribed circle.""" + names = _top_n(self.service, + "draw the inscribed circle inside the triangle") + assert "construct_incircle" in names + + def test_geometry_midpoint(self) -> None: + """Find midpoint of a segment.""" + names = _top_n(self.service, "mark the midpoint of segment PQ") + assert "construct_midpoint" in names + + def test_geometry_parallel(self) -> None: + """Construct parallel line.""" + names = _top_n(self.service, + "draw a line through point C parallel to segment AB") + assert "construct_parallel_line" in names + + # -- Statistics 
/ data science -- + + def test_stats_bell_curve(self) -> None: + """Bell curve is a normal distribution.""" + names = _top_n(self.service, + "show me a bell curve with mean 100 and std dev 15") + assert "plot_distribution" in names + + def test_stats_bar_chart(self) -> None: + """Simple bar chart.""" + names = _top_n(self.service, + "make a bar chart comparing sales: Q1=100, Q2=150, Q3=80, Q4=200") + assert "plot_bars" in names + + def test_stats_regression_fit(self) -> None: + """Fit a trend line to data.""" + names = _top_n(self.service, + "fit a line of best fit through these data points") + assert "fit_regression" in names + + def test_stats_descriptive(self) -> None: + """Basic descriptive statistics.""" + names = _top_n(self.service, + "give me the mean, median, and standard deviation of " + "[88, 92, 76, 95, 83, 91, 78]") + assert "compute_descriptive_statistics" in names + + # -- Graph theory -- + + def test_graph_shortest_path(self) -> None: + """Dijkstra / shortest path.""" + names = _top_n(self.service, + "what's the shortest path from node S to node T in the graph?") + assert "analyze_graph" in names + + def test_graph_minimum_spanning_tree(self) -> None: + """MST.""" + names = _top_n(self.service, + "compute the minimum spanning tree of graph G1") + assert "analyze_graph" in names + + def test_graph_topological_sort(self) -> None: + """Topological sort.""" + names = _top_n(self.service, + "topologically sort the DAG") + assert "analyze_graph" in names + + def test_graph_create_network(self) -> None: + """Create a network graph.""" + names = _top_n(self.service, + "build a weighted undirected network with 6 nodes and 8 edges") + assert "generate_graph" in names + + # -- Workspace management -- + + def test_workspace_checkpoint(self) -> None: + """Save a checkpoint.""" + names = _top_n(self.service, "save my progress as 'homework_ch7'") + assert "save_workspace" in names + + def test_workspace_resume(self) -> None: + """Resume previous work.""" + names = 
_top_n(self.service, + "pick up where I left off on the calculus_project workspace") + assert "load_workspace" in names + + def test_workspace_browse(self) -> None: + """Browse workspaces.""" + names = _top_n(self.service, "what workspaces do I have saved?") + assert "list_workspaces" in names + + # -- Canvas operations -- + + def test_canvas_wipe(self) -> None: + """Wipe the canvas.""" + names = _top_n(self.service, "wipe everything clean and start fresh") + assert "clear_canvas" in names + + def test_canvas_grid_toggle(self) -> None: + """Toggle grid.""" + names = _top_n(self.service, "hide the grid lines") + assert "set_grid_visible" in names + + def test_canvas_polar_mode(self) -> None: + """Switch to polar coordinates.""" + names = _top_n(self.service, "switch to polar coordinate mode") + assert "set_coordinate_system" in names + + # -- Transforms -- + + def test_transform_slide(self) -> None: + """Translate using casual language.""" + names = _top_n(self.service, + "slide the rectangle 3 units to the right and 2 up") + assert "translate_object" in names + + def test_transform_flip(self) -> None: + """Reflect using 'flip'.""" + names = _top_n(self.service, "flip the triangle over the x-axis") + assert "reflect_object" in names + + def test_transform_double_size(self) -> None: + """Scale up.""" + names = _top_n(self.service, + "make the circle twice as big") + assert "scale_object" in names + + def test_transform_rotate_45(self) -> None: + """Rotate by 45 degrees.""" + names = _top_n(self.service, "rotate the square 45 degrees clockwise") + assert "rotate_object" in names + + # -- Colored areas -- + + def test_area_shade_between_curves(self) -> None: + """Shade between two curves.""" + names = _top_n(self.service, + "shade the region between y=x^2 and y=x") + assert "create_region_colored_area" in names or "create_colored_area" in names + + def test_area_highlight_integral(self) -> None: + """Highlight the area for a definite integral.""" + names = 
_top_n(self.service, + "highlight the area under e^(-x) from 0 to infinity") + assert "create_colored_area" in names + + # -- Edge cases and tricky phrasing -- + + def test_ambiguous_graph_word(self) -> None: + """'graph' meaning plot, not graph theory.""" + names = _top_n(self.service, + "graph the absolute value function |x|") + assert "draw_function" in names or "draw_piecewise_function" in names + + def test_ambiguous_normal(self) -> None: + """'normal' meaning distribution, not normal line.""" + names = _top_n(self.service, "plot a normal distribution") + assert "plot_distribution" in names + + def test_ambiguous_normal_line(self) -> None: + """'normal line' meaning perpendicular to tangent.""" + names = _top_n(self.service, + "draw the normal line to the curve at x=2") + assert "draw_normal_line" in names + + def test_no_verb_query(self) -> None: + """Query with no clear verb.""" + names = _top_n(self.service, "circle radius 5 center (0,0)") + assert "create_circle" in names + + def test_emoji_and_noise(self) -> None: + """Query with noise characters.""" + names = _top_n(self.service, "!!! draw a big triangle please!!!") + assert "create_polygon" in names + + def test_very_long_query(self) -> None: + """Verbose multi-sentence request.""" + names = _top_n(self.service, + "I'm working on my geometry homework and I need to create " + "a point at coordinates (3, 7). This point represents the " + "location of a lighthouse on my map. 
Could you help me " + "place it on the canvas?") + assert "create_point" in names + + def test_mixed_math_notation(self) -> None: + """Math notation mixed with words.""" + names = _top_n(self.service, + "compute integral from 0 to pi of sin(x) dx") + assert "integrate" in names + + def test_creative_lissajous(self) -> None: + """Lissajous curve (parametric).""" + names = _top_n(self.service, + "draw a Lissajous figure with x=sin(3t) y=sin(2t)") + assert "draw_parametric_function" in names + + def test_creative_rose_curve(self) -> None: + """Rose curve expressed parametrically.""" + names = _top_n(self.service, + "plot the rose curve r=cos(4*theta) in polar") + assert "draw_parametric_function" in names or "draw_function" in names + + def test_inspect_perpendicularity(self) -> None: + """Check if two segments are perpendicular.""" + names = _top_n(self.service, + "are segments AB and CD perpendicular to each other?") + assert "inspect_relation" in names + + def test_numeric_approximation(self) -> None: + """Numerical root finding.""" + names = _top_n(self.service, + "find an approximate solution to x = cos(x)") + assert "solve_numeric" in names + + def test_expand_binomial(self) -> None: + """Binomial expansion.""" + names = _top_n(self.service, "expand (2x - 3)^4") + assert "expand" in names + + def test_simplify_trig_identity(self) -> None: + """Simplify a trig expression.""" + names = _top_n(self.service, + "simplify sin^2(x) + cos^2(x) - 1") + assert "simplify" in names + + def test_delete_specific_plot(self) -> None: + """Delete a named plot.""" + names = _top_n(self.service, + "remove the distribution plot called 'bell1'") + assert "delete_plot" in names + + def test_update_circle_color(self) -> None: + """Change a circle's color.""" + names = _top_n(self.service, "change circle C1 to red") + assert "update_circle" in names + + def test_ellipse_creation(self) -> None: + """Create an ellipse.""" + names = _top_n(self.service, + "draw an ellipse with semi-major 
axis 5 and semi-minor axis 3") + assert "create_ellipse" in names + + def test_canvas_state_inspection(self) -> None: + """Inspect canvas state.""" + names = _top_n(self.service, "what objects are currently on the canvas?") + assert "get_current_canvas_state" in names + + def test_angle_creation(self) -> None: + """Create an angle.""" + names = _top_n(self.service, + "show the angle between rays BA and BC") + assert "create_angle" in names + + def test_vector_creation(self) -> None: + """Create a vector.""" + names = _top_n(self.service, + "draw a vector from (1,1) pointing to (4,5)") + assert "create_vector" in names + + def test_piecewise_absolute_value(self) -> None: + """Absolute value as piecewise.""" + names = _top_n(self.service, + "draw f(x) = x when x >= 0 and -x when x < 0") + assert "draw_piecewise_function" in names + + def test_tangent_line_at_point(self) -> None: + """Tangent line at a specific point.""" + names = _top_n(self.service, + "draw the tangent to y=x^3 at the point where x=1") + assert "draw_tangent_line" in names + + def test_convert_degrees_radians(self) -> None: + """Convert between angle units.""" + names = _top_n(self.service, "how many radians is 270 degrees?") + assert "convert" in names + + def test_bisect_angle(self) -> None: + """Bisect an angle.""" + names = _top_n(self.service, + "bisect the angle at vertex B in triangle ABC") + assert "construct_angle_bisector" in names + + def test_colored_area_function(self) -> None: + """Shade under a specific function.""" + names = _top_n(self.service, + "fill the area under y=1/x from x=1 to x=e with blue") + assert "create_colored_area" in names + + def test_delete_workspace(self) -> None: + """Delete a workspace.""" + names = _top_n(self.service, + "remove the workspace named 'old_draft'") + assert "delete_workspace" in names diff --git a/server_tests/test_tool_search_service.py b/server_tests/test_tool_search_service.py index 873cd7a2..b748656b 100644 --- 
a/server_tests/test_tool_search_service.py +++ b/server_tests/test_tool_search_service.py @@ -18,11 +18,24 @@ if _project_root not in sys.path: sys.path.insert(0, _project_root) -from static.tool_search_service import ToolSearchService +from static.tool_search_service import ( + TOOL_CATEGORIES, + ToolSearchService, + _TOOL_BY_NAME, + _TOOL_NAME_INDEX, + _search_cache, + clear_search_cache, +) from static.ai_model import AIModel from static.functions_definitions import FUNCTIONS +@pytest.fixture(autouse=True) +def _clear_cache() -> None: + """Clear the search cache before each test.""" + clear_search_cache() + + class TestToolSearchServiceBasics: """Test basic ToolSearchService functionality.""" @@ -136,7 +149,12 @@ def test_parse_single_item(self) -> None: class TestSearchToolsWithMock: - """Test search_tools with mocked OpenAI client.""" + """Test search_tools with mocked OpenAI client (API mode).""" + + @pytest.fixture(autouse=True) + def _api_mode(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Force API search mode for these tests.""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "api") @pytest.fixture def mock_client(self) -> MagicMock: @@ -307,18 +325,23 @@ def test_search_uses_reasoning_default_model_when_none(self, mock_client: MagicM assert "max_tokens" not in call_args.kwargs def test_search_uses_default_model_when_none(self, service: ToolSearchService, mock_client: MagicMock) -> None: - """search_tools should use gpt-4.1-mini when no model specified.""" + """search_tools should use gpt-5-nano when no model specified.""" self._setup_mock_response(mock_client, '["create_circle"]') service.search_tools("draw") call_args = mock_client.chat.completions.create.call_args - assert call_args.kwargs.get("model") == "gpt-4.1-mini" + assert call_args.kwargs.get("model") == "gpt-5-nano" class TestSearchToolsFormatted: """Test the search_tools_formatted method.""" + @pytest.fixture(autouse=True) + def _api_mode(self, monkeypatch: pytest.MonkeyPatch) -> None: + 
"""Force API search mode for these tests.""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "api") + @pytest.fixture def mock_client(self) -> MagicMock: """Create a mock OpenAI client.""" @@ -471,7 +494,7 @@ def test_set_tool_mode_invalid_raises(self) -> None: with patch.object(OpenAIAPIBase, "_initialize_api_key", return_value="test-key"): api = OpenAIAPIBase() with pytest.raises(ValueError): - api.set_tool_mode("invalid") # type: ignore + api.set_tool_mode("invalid") def test_custom_tools_override_mode(self) -> None: """Custom tools should override tool mode selection.""" @@ -549,8 +572,9 @@ def test_search_tools_included_when_requested(self) -> None: descriptions = ToolSearchService.build_tool_descriptions(exclude_meta_tools=False) assert "- search_tools:" in descriptions - def test_search_filters_out_search_tools(self) -> None: - """search_tools should not appear in search results.""" + def test_search_filters_out_search_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """search_tools should not appear in search results (API mode).""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "api") mock_client = MagicMock() mock_message = MagicMock() mock_message.content = '["search_tools", "create_circle", "create_point"]' @@ -740,3 +764,252 @@ def test_extract_from_none(self) -> None: """Should return empty for None.""" result = ToolSearchService._extract_list_from_parsed(None) assert result == [] + + +class TestLocalSearch: + """Test the local search engine.""" + + @pytest.fixture(autouse=True) + def _setup(self) -> None: + clear_search_cache() + self.service = ToolSearchService.__new__(ToolSearchService) + self.service._client = None + self.service._client_initialized = False + self.service.default_model = None + self.service.last_error = None + + def test_local_search_returns_results(self) -> None: + """search_tools_local should return matching tools.""" + results = self.service.search_tools_local("draw a circle") + assert len(results) > 0 + names = [t["function"]["name"] 
for t in results] + assert "create_circle" in names + + def test_local_search_respects_max_results(self) -> None: + """search_tools_local should limit results.""" + results = self.service.search_tools_local("create something", max_results=3) + assert len(results) <= 3 + + def test_local_search_empty_query(self) -> None: + """search_tools_local should return empty for empty tokens.""" + results = self.service.search_tools_local("") + assert results == [] + + def test_local_search_excludes_search_tools(self) -> None: + """search_tools_local should not return search_tools.""" + results = self.service.search_tools_local("search for tools", max_results=20) + names = [t["function"]["name"] for t in results] + assert "search_tools" not in names + + def test_local_search_finds_solve(self) -> None: + """Local search should find solve for equation queries.""" + results = self.service.search_tools_local("solve x^2 = 4") + names = [t["function"]["name"] for t in results[:3]] + assert "solve" in names + + def test_local_search_finds_workspace(self) -> None: + """Local search should find workspace tools.""" + results = self.service.search_tools_local("save workspace as test") + names = [t["function"]["name"] for t in results[:3]] + assert "save_workspace" in names + + def test_local_search_finds_parametric(self) -> None: + """Local search should find parametric functions.""" + results = self.service.search_tools_local("draw parametric curve x=cos(t) y=sin(t)") + names = [t["function"]["name"] for t in results[:3]] + assert "draw_parametric_function" in names + + def test_local_search_graph_vs_function(self) -> None: + """'graph f(x)' should match draw_function, not generate_graph.""" + results = self.service.search_tools_local("graph f(x) = sin(x)") + names = [t["function"]["name"] for t in results[:3]] + assert "draw_function" in names + + def test_local_search_graph_theory(self) -> None: + """Graph with vertices/edges should match graph theory tools.""" + results = 
self.service.search_tools_local( + "create a directed graph with vertices A B C and edges A-B B-C" + ) + names = [t["function"]["name"] for t in results[:3]] + assert "generate_graph" in names + + def test_local_search_deterministic(self) -> None: + """Same query should always return same results.""" + r1 = self.service.search_tools_local("draw a triangle") + clear_search_cache() + r2 = self.service.search_tools_local("draw a triangle") + n1 = [t["function"]["name"] for t in r1] + n2 = [t["function"]["name"] for t in r2] + assert n1 == n2 + + +class TestSearchCache: + """Test the LRU result cache.""" + + @pytest.fixture(autouse=True) + def _setup(self, monkeypatch: pytest.MonkeyPatch) -> None: + clear_search_cache() + monkeypatch.setenv("TOOL_SEARCH_MODE", "local") + + def test_cache_returns_same_results(self) -> None: + """Cached results should match original results.""" + service = ToolSearchService.__new__(ToolSearchService) + service._client = None + service._client_initialized = False + service.default_model = None + service.last_error = None + + r1 = service.search_tools("draw a circle") + r2 = service.search_tools("draw a circle") + n1 = [t["function"]["name"] for t in r1] + n2 = [t["function"]["name"] for t in r2] + assert n1 == n2 + + def test_cache_populated_after_search(self) -> None: + """Cache should have an entry after a search.""" + service = ToolSearchService.__new__(ToolSearchService) + service._client = None + service._client_initialized = False + service.default_model = None + service.last_error = None + + assert len(_search_cache) == 0 + service.search_tools("draw a circle") + assert len(_search_cache) == 1 + + def test_clear_cache(self) -> None: + """clear_search_cache should empty the cache.""" + service = ToolSearchService.__new__(ToolSearchService) + service._client = None + service._client_initialized = False + service.default_model = None + service.last_error = None + + service.search_tools("draw") + assert len(_search_cache) > 0 + 
clear_search_cache() + assert len(_search_cache) == 0 + + def test_different_queries_different_cache_keys(self) -> None: + """Different queries should use different cache keys.""" + service = ToolSearchService.__new__(ToolSearchService) + service._client = None + service._client_initialized = False + service.default_model = None + service.last_error = None + + service.search_tools("draw a circle") + service.search_tools("solve equation") + assert len(_search_cache) == 2 + + +class TestSearchModeSwitching: + """Test TOOL_SEARCH_MODE env var switching.""" + + @pytest.fixture(autouse=True) + def _setup(self) -> None: + clear_search_cache() + + def test_local_mode_no_api_call(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Local mode should not call the API.""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "local") + mock_client = MagicMock() + service = ToolSearchService(client=mock_client) + + results = service.search_tools("draw a circle") + + mock_client.chat.completions.create.assert_not_called() + assert len(results) > 0 + names = [t["function"]["name"] for t in results] + assert "create_circle" in names + + def test_api_mode_calls_api(self, monkeypatch: pytest.MonkeyPatch) -> None: + """API mode should call the OpenAI API.""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "api") + mock_client = MagicMock() + mock_message = MagicMock() + mock_message.content = '["create_circle"]' + mock_choice = MagicMock() + mock_choice.message = mock_message + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_client.chat.completions.create.return_value = mock_response + + service = ToolSearchService(client=mock_client) + results = service.search_tools("draw a circle") + + mock_client.chat.completions.create.assert_called_once() + assert len(results) == 1 + + def test_default_mode_is_hybrid(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Default mode (no env var) should be hybrid.""" + monkeypatch.delenv("TOOL_SEARCH_MODE", raising=False) + mock_client = 
MagicMock() + service = ToolSearchService(client=mock_client) + + # Hybrid uses local first; high-confidence queries won't hit the API + results = service.search_tools("draw a circle") + + mock_client.chat.completions.create.assert_not_called() + assert len(results) > 0 + + def test_hybrid_mode_uses_local_when_confident(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Hybrid mode should use local results when confidence is high.""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "hybrid") + mock_client = MagicMock() + service = ToolSearchService(client=mock_client) + + # "draw a circle" should have high confidence locally + results = service.search_tools("create a circle at 0,0 with radius 5") + + # Should not need API for a confident local match + mock_client.chat.completions.create.assert_not_called() + names = [t["function"]["name"] for t in results] + assert "create_circle" in names + + +class TestToolCategoryRegistry: + """Test the TOOL_CATEGORIES constant and indices.""" + + def test_categories_not_empty(self) -> None: + """TOOL_CATEGORIES should be non-empty.""" + assert len(TOOL_CATEGORIES) > 0 + + def test_all_category_tools_exist(self) -> None: + """Every tool listed in a category should exist in FUNCTIONS.""" + all_tool_names = {f.get("function", {}).get("name") for f in FUNCTIONS} + for cat_name, entry in TOOL_CATEGORIES.items(): + for tool_name in entry["tools"]: + assert tool_name in all_tool_names, ( + f"Tool '{tool_name}' in category '{cat_name}' not found in FUNCTIONS" + ) + + def test_name_index_populated(self) -> None: + """Tool name index should be populated.""" + assert len(_TOOL_NAME_INDEX) > 0 + # "create" should index multiple tools + assert len(_TOOL_NAME_INDEX.get("create", [])) > 5 + + def test_tool_by_name_lookup(self) -> None: + """_TOOL_BY_NAME should include all tools.""" + assert "create_circle" in _TOOL_BY_NAME + assert "search_tools" in _TOOL_BY_NAME # meta-tools still in lookup + assert "undo" in _TOOL_BY_NAME + + +class 
TestLazyClientInitialization: + """Test that OpenAI client is lazily initialized.""" + + def test_no_client_needed_for_local_mode(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Local mode should work without API key.""" + monkeypatch.setenv("TOOL_SEARCH_MODE", "local") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + service = ToolSearchService.__new__(ToolSearchService) + service._client = None + service._client_initialized = False + service.default_model = None + service.last_error = None + + results = service.search_tools("draw a circle") + assert len(results) > 0 diff --git a/static/tool_search_service.py b/static/tool_search_service.py index 90cfb188..9b52be56 100644 --- a/static/tool_search_service.py +++ b/static/tool_search_service.py @@ -1,14 +1,18 @@ """ MatHud Tool Search Service -Provides semantic tool discovery using AI-powered matching. -Given a user's description of what they want to accomplish, searches through -available tool definitions and returns the most relevant matches. +Provides tool discovery via fast local keyword/category matching (default) +or AI-powered semantic matching. 
The search mode is selected by the +``TOOL_SEARCH_MODE`` environment variable: + +* ``local`` -- keyword + category index, no API call +* ``api`` -- original GPT-based semantic search +* ``hybrid`` (default) -- local first, fall back to API when confidence is low Dependencies: - static.ai_model: AI model configuration - static.functions_definitions: Tool definitions to search through - - openai: OpenAI API client for semantic matching + - openai: OpenAI API client (only needed for ``api`` / ``hybrid`` modes) """ from __future__ import annotations @@ -17,7 +21,9 @@ import logging import os import re -from typing import Any, Dict, List, Optional +import time +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, TypedDict from dotenv import load_dotenv from openai import OpenAI @@ -30,15 +36,311 @@ # Tools to exclude from search results (meta-tools that shouldn't be recommended) EXCLUDED_FROM_SEARCH = frozenset({"search_tools"}) +# --------------------------------------------------------------------------- +# Result cache +# --------------------------------------------------------------------------- +CACHE_TTL = 300 # seconds +CACHE_MAX_SIZE = 100 + +_search_cache: Dict[str, Tuple[float, List[FunctionDefinition]]] = {} + + +def _cache_get(key: str) -> Optional[List[FunctionDefinition]]: + entry = _search_cache.get(key) + if entry is None: + return None + ts, results = entry + if time.monotonic() - ts > CACHE_TTL: + _search_cache.pop(key, None) + return None + return results + + +def _cache_put(key: str, results: List[FunctionDefinition]) -> None: + # Evict oldest if at capacity + if len(_search_cache) >= CACHE_MAX_SIZE and key not in _search_cache: + oldest_key = min(_search_cache, key=lambda k: _search_cache[k][0]) + _search_cache.pop(oldest_key, None) + _search_cache[key] = (time.monotonic(), results) + + +def clear_search_cache() -> None: + """Clear the search result cache.""" + _search_cache.clear() + + +# 
--------------------------------------------------------------------------- +# Tool category registry +# --------------------------------------------------------------------------- + +class CategoryEntry(TypedDict): + tools: List[str] + keywords: List[str] + + +TOOL_CATEGORIES: Dict[str, CategoryEntry] = { + "geometry_create": { + "tools": [ + "create_point", "create_segment", "create_vector", + "create_polygon", "create_circle", "create_circle_arc", + "create_ellipse", "create_label", "create_angle", + ], + "keywords": [ + "create", "draw", "make", "add", "place", "put", + "construct", "point", "segment", "vector", "polygon", + "triangle", "rectangle", "circle", "ellipse", "arc", + "label", "angle", "line", + ], + }, + "geometry_delete": { + "tools": [ + "delete_point", "delete_segment", "delete_vector", + "delete_polygon", "delete_circle", "delete_circle_arc", + "delete_ellipse", "delete_label", "delete_angle", + ], + "keywords": [ + "delete", "remove", "erase", "destroy", "rid", + ], + }, + "geometry_update": { + "tools": [ + "update_point", "update_segment", "update_vector", + "update_polygon", "update_circle", "update_circle_arc", + "update_ellipse", "update_label", "update_angle", + ], + "keywords": [ + "update", "change", "modify", "edit", "rename", "resize", "set", + "recolor", "reposition", + ], + }, + "constructions": { + "tools": [ + "construct_midpoint", "construct_perpendicular_bisector", + "construct_perpendicular_from_point", "construct_angle_bisector", + "construct_parallel_line", "construct_circumcircle", + "construct_incircle", + ], + "keywords": [ + "midpoint", "bisector", "perpendicular", "parallel", + "circumcircle", "incircle", "construct", "bisect", + "inscribed", "circumscribed", + ], + }, + "functions_plots": { + "tools": [ + "draw_function", "delete_function", "update_function", + "draw_piecewise_function", "delete_piecewise_function", + "update_piecewise_function", + "draw_parametric_function", "delete_parametric_function", + 
"update_parametric_function", + "draw_tangent_line", "draw_normal_line", + ], + "keywords": [ + "function", "plot", "curve", "equation", "parametric", + "piecewise", "tangent", "normal", "y=", "f(x)", + "parabola", "sine", "cosine", "exponential", "logarithm", + "polynomial", "lissajous", "spiral", + ], + }, + "math": { + "tools": [ + "evaluate_expression", "evaluate_linear_algebra_expression", + "convert", "limit", "derive", "integrate", "numeric_integrate", + "simplify", "expand", "factor", "solve", + "solve_system_of_equations", "solve_numeric", + ], + "keywords": [ + "calculate", "evaluate", "solve", "derivative", "integral", + "integrate", "simplify", "factor", "expand", "limit", + "differentiate", "calculus", "algebra", "expression", + "equation", "compute", "math", "formula", "linear", + "matrix", "determinant", "eigenvalue", "inverse", + "numeric", "numerical", "system", "simultaneous", + "find", "root", "roots", "reduce", "multiply", + ], + }, + "statistics": { + "tools": [ + "plot_distribution", "plot_bars", "delete_plot", + "fit_regression", "compute_descriptive_statistics", + ], + "keywords": [ + "statistics", "distribution", "bar", "chart", "mean", + "median", "regression", "histogram", "normal", "gaussian", + "standard", "deviation", "stdev", "quartile", "percentile", + "probability", "bell", "discrete", "continuous", "frequency", + "descriptive", "stats", "average", "variance", + ], + }, + "graph_theory": { + "tools": [ + "generate_graph", "delete_graph", "analyze_graph", + ], + "keywords": [ + "graph", "tree", "dag", "vertex", "vertices", "edge", + "node", "shortest", "path", "bfs", "dfs", "mst", + "topological", "sort", "spanning", "adjacency", + "directed", "undirected", "weighted", "network", + ], + }, + "canvas": { + "tools": [ + "zoom", "clear_canvas", "reset_canvas", + "set_coordinate_system", "set_grid_visible", + "get_current_canvas_state", "run_tests", + "undo", "redo", + ], + "keywords": [ + "zoom", "canvas", "grid", "coordinate", "reset", 
"clear", + "view", "axes", "viewport", "pan", "fit", + "undo", "redo", "state", "polar", "cartesian", + "wipe", "clean", "fresh", "oops", "back", "revert", + ], + }, + "workspace": { + "tools": [ + "save_workspace", "load_workspace", + "list_workspaces", "delete_workspace", + ], + "keywords": [ + "workspace", "save", "load", "export", "import", "project", + "session", "open", "restore", "persist", "store", + "inventory", "available", + ], + }, + "transforms": { + "tools": [ + "translate_object", "rotate_object", "reflect_object", + "scale_object", "shear_object", + ], + "keywords": [ + "translate", "rotate", "reflect", "mirror", "scale", + "shear", "transform", "move", "shift", "flip", + "enlarge", "shrink", "stretch", "turn", "spin", + "slide", "twice", "double", "bigger", "smaller", "larger", + ], + }, + "areas": { + "tools": [ + "create_colored_area", "create_region_colored_area", + "delete_colored_area", "update_colored_area", + "calculate_area", + ], + "keywords": [ + "area", "shade", "region", "color", "fill", "highlight", + "between", "under", "above", "bounded", + ], + }, + "inspection": { + "tools": ["inspect_relation"], + "keywords": [ + "inspect", "relation", "check", "verify", "collinear", + "concurrent", "tangent", "congruent", "similar", + "relationship", "distance", "measure", + ], + }, + "coordinates": { + "tools": [ + "set_coordinate_system", "convert_coordinates", + ], + "keywords": [ + "polar", "cartesian", "coordinate", "system", "convert", + "cylindrical", "spherical", + ], + }, +} + +# --------------------------------------------------------------------------- +# Inverted indices — built once at module load time +# --------------------------------------------------------------------------- + +# token -> list of tool names whose *name* contains the token +_TOOL_NAME_INDEX: Dict[str, List[str]] = defaultdict(list) + +# token -> list of tool names whose *description* contains the token +_TOOL_DESC_INDEX: Dict[str, List[str]] = defaultdict(list) + 
+# keyword -> list of category names +_CATEGORY_KEYWORD_INDEX: Dict[str, List[str]] = defaultdict(list) + +# tool_name -> FunctionDefinition (fast lookup) +_TOOL_BY_NAME: Dict[str, FunctionDefinition] = {} + +# Set of all searchable tool names +_ALL_TOOL_NAMES: frozenset[str] = frozenset() + + +def _build_indices() -> None: + """Populate inverted indices from FUNCTIONS and TOOL_CATEGORIES.""" + global _ALL_TOOL_NAMES + + names: List[str] = [] + for tool in FUNCTIONS: + func = tool.get("function", {}) + name = func.get("name", "") + if not name: + continue + # Always register in the name lookup (used by get_tool_by_name) + _TOOL_BY_NAME[name] = tool + + # Skip meta-tools for search indices + if name in EXCLUDED_FROM_SEARCH: + continue + + names.append(name) + + # Index name tokens + name_tokens = name.lower().replace("_", " ").split() + for token in name_tokens: + if token and len(token) > 1: + _TOOL_NAME_INDEX[token].append(name) + + # Index description tokens + description = func.get("description", "") + desc_tokens = set(re.findall(r"[a-z0-9]+", description.lower())) + for token in desc_tokens: + if len(token) > 1: + _TOOL_DESC_INDEX[token].append(name) + + _ALL_TOOL_NAMES = frozenset(names) + + # Build category keyword index + for cat_name, cat_entry in TOOL_CATEGORIES.items(): + for keyword in cat_entry["keywords"]: + _CATEGORY_KEYWORD_INDEX[keyword.lower()].append(cat_name) + + +_build_indices() + +# Confidence threshold for hybrid mode +CONFIDENCE_THRESHOLD = 6.0 + +# Action-verb to tool-name prefix mapping +_ACTION_VERB_MAP: Dict[str, str] = { + "create": "create_", + "draw": "draw_", + "make": "create_", + "add": "create_", + "delete": "delete_", + "remove": "delete_", + "erase": "delete_", + "update": "update_", + "change": "update_", + "modify": "update_", + "edit": "update_", + "construct": "construct_", + "plot": "draw_", + "slide": "translate_", +} + class ToolSearchService: - """Service for semantic tool discovery using AI-powered matching. 
+ """Service for tool discovery via local keyword matching or AI-powered search. - Uses the app's AI model to find the most relevant tools for a given query - by analyzing tool names and descriptions. + The search mode is controlled by the ``TOOL_SEARCH_MODE`` environment variable. """ - # System prompt for tool selection + # System prompt for tool selection (used in API mode) TOOL_SELECTOR_PROMPT = """You are a tool selector. Given a user's description of what they want to accomplish, select the most relevant tools from the list below. Return ONLY a JSON array of tool names, ordered by relevance (most relevant first). Available tools: @@ -70,11 +372,9 @@ class ToolSearchService: "on", "or", "please", - "show", "the", "to", "up", - "use", "with", "you", } @@ -88,18 +388,30 @@ def __init__( """Initialize the tool search service. Args: - client: Optional OpenAI-compatible client. If not provided, creates a new one. - default_model: Optional default model to use for search. If not provided, - uses gpt-4.1-mini for OpenAI or the client's configured model for local LLMs. + client: Optional OpenAI-compatible client. If not provided, creates + one only when API mode is actually needed. + default_model: Optional default model to use for API-based search. 
""" - if client is not None: - self.client = client - else: - api_key = self._initialize_api_key() - self.client = OpenAI(api_key=api_key) - + self._client = client + self._client_initialized = client is not None self.default_model = default_model self.last_error: Optional[str] = None + self._last_local_top_score: float = 0.0 + + @property + def client(self) -> OpenAI: + """Lazily initialize the OpenAI client on first access.""" + if not self._client_initialized: + api_key = self._initialize_api_key() + self._client = OpenAI(api_key=api_key) + self._client_initialized = True + assert self._client is not None + return self._client + + @client.setter + def client(self, value: OpenAI) -> None: + self._client = value + self._client_initialized = True @staticmethod def _initialize_api_key() -> str: @@ -122,11 +434,7 @@ def _initialize_api_key() -> str: @staticmethod def get_all_tools() -> List[FunctionDefinition]: - """Get all available tool definitions. - - Returns: - List of all function definitions. - """ + """Get all available tool definitions.""" return list(FUNCTIONS) @staticmethod @@ -155,19 +463,12 @@ def build_tool_descriptions(exclude_meta_tools: bool = True) -> str: @staticmethod def get_tool_by_name(name: str) -> Optional[FunctionDefinition]: - """Get a tool definition by its name. - - Args: - name: The tool name to look up. + """Get a tool definition by its name.""" + return _TOOL_BY_NAME.get(name) - Returns: - The tool definition if found, None otherwise. - """ - for tool in FUNCTIONS: - func = tool.get("function", {}) - if func.get("name") == name: - return tool - return None + # ------------------------------------------------------------------ + # Public search entry point + # ------------------------------------------------------------------ def search_tools( self, @@ -177,12 +478,11 @@ def search_tools( ) -> List[FunctionDefinition]: """Search for tools matching a query description. 
- Uses AI to semantically match the query against tool descriptions - and return the most relevant tool definitions. + Dispatches to local or API search based on ``TOOL_SEARCH_MODE`` env var. Args: query: Description of what the user wants to accomplish. - model: AI model to use for matching. Defaults to gpt-4.1-mini. + model: AI model to use for matching (API mode only). max_results: Maximum number of tools to return (1-20). Returns: @@ -196,9 +496,395 @@ # Clamp max_results to valid range max_results = max(1, min(20, max_results)) + mode = os.getenv("TOOL_SEARCH_MODE", "hybrid").strip().lower() + + # Check cache + cache_key = f"{mode}:{query.lower().strip()}:{max_results}" + cached = _cache_get(cache_key) + if cached is not None: + return cached + + if mode == "api": + results = self._search_tools_api(query, model, max_results) + elif mode == "hybrid": + results = self.search_tools_local(query, max_results) + if not results or self._last_local_top_score < CONFIDENCE_THRESHOLD: + results = self._search_tools_api(query, model, max_results) + else: # "local" + results = self.search_tools_local(query, max_results) + + _cache_put(cache_key, results) + return results + + # ------------------------------------------------------------------ + # Local search + # ------------------------------------------------------------------ + + def search_tools_local( + self, + query: str, + max_results: int = 10, + ) -> List[FunctionDefinition]: + """Search for tools using fast local keyword/category matching. + + No API call is made. Scoring uses: + 1. Exact tool name match (+8.0) + 2. Category keyword boost (+5.0) + 3. Inverted index name match (+3.0) + 4. Inverted index description match (+1.0) + 5. Action-verb alignment (+2.0) + 6. Intent boosts for confusion clusters + + Args: + query: Description of what the user wants to accomplish. + max_results: Maximum number of tools to return. + + Returns: + List of matching tool definitions, ordered by score.
+ """ + query_tokens = self._tokenize(query) + if not query_tokens: + return [] + + scores: Dict[str, float] = defaultdict(float) + + # 1. Exact tool name match + query_lower = query.lower().strip() + for token in query_tokens: + if token in _ALL_TOOL_NAMES: + scores[token] += 8.0 + + # Also check underscore-joined bigrams/trigrams for multi-word tool names + for i in range(len(query_tokens)): + for j in range(i + 1, min(i + 4, len(query_tokens) + 1)): + candidate = "_".join(query_tokens[i:j]) + if candidate in _ALL_TOOL_NAMES: + scores[candidate] += 8.0 + + # 2. Category keyword boost + matched_categories: set[str] = set() + for token in query_tokens: + cats = _CATEGORY_KEYWORD_INDEX.get(token, []) + matched_categories.update(cats) + + for cat_name in matched_categories: + cat_entry = TOOL_CATEGORIES[cat_name] + for tool_name in cat_entry["tools"]: + if tool_name in _ALL_TOOL_NAMES: + scores[tool_name] += 5.0 + + # 3. Inverted index name match + for token in query_tokens: + for tool_name in _TOOL_NAME_INDEX.get(token, []): + scores[tool_name] += 3.0 + + # 4. Inverted index description match + for token in query_tokens: + for tool_name in _TOOL_DESC_INDEX.get(token, []): + scores[tool_name] += 1.0 + + # 5. Action-verb alignment + for token in query_tokens: + prefix = _ACTION_VERB_MAP.get(token) + if prefix: + for tool_name in _ALL_TOOL_NAMES: + if tool_name.startswith(prefix): + scores[tool_name] += 2.0 + + # 6. 
Intent boosts (same as existing _tool_score confusion boosts) + self._apply_intent_boosts(query_tokens, scores, raw_query=query) + + # Sort by score descending, then alphabetically for ties + ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0])) + + # Return top results as full FunctionDefinition objects + results: List[FunctionDefinition] = [] + for tool_name, score in ranked: + if score <= 0: + break + tool = _TOOL_BY_NAME.get(tool_name) + if tool is not None: + results.append(tool) + if len(results) >= max_results: + break + + # Store top score for hybrid-mode confidence check + self._last_local_top_score = ranked[0][1] if ranked else 0.0 + + _logger.info( + f"Local tool search for '{query}' found {len(results)} tools" + ) + return results + + @staticmethod + def _apply_intent_boosts( + query_tokens: List[str], + scores: Dict[str, float], + raw_query: str = "", + ) -> None: + """Apply intent-based score boosts for known confusion clusters.""" + token_set = set(query_tokens) + raw_lower = raw_query.lower() if raw_query else " ".join(query_tokens) + + # -- Transforms -- + if token_set & {"move", "shift", "translate", "slide"}: + scores["translate_object"] += 6.0 + if token_set & {"turn", "spin", "rotate"}: + scores["rotate_object"] += 6.0 + if token_set & {"reflect", "mirror", "flip"}: + scores["reflect_object"] += 4.0 + if token_set & {"twice", "double", "bigger", "larger", "smaller", "big"}: + scores["scale_object"] += 8.0 + + # -- Areas / shading -- + if token_set & {"shade", "shading", "color", "fill", "highlight"}: + scores["create_colored_area"] += 5.0 + scores["create_region_colored_area"] += 4.0 + if token_set & {"area"}: + scores["calculate_area"] += 2.0 + scores["create_colored_area"] += 3.0 + scores["create_region_colored_area"] += 2.0 + if token_set & {"region", "between"}: + scores["create_region_colored_area"] += 5.0 + scores["create_colored_area"] += 3.0 + + # -- Statistics -- + if token_set & {"distribution", "gaussian", "bell"}: + 
scores["plot_distribution"] += 6.0 + if "normal" in token_set and not (token_set & {"line", "perpendicular"}): + scores["plot_distribution"] += 4.0 + if token_set & {"bar", "bars"}: + scores["plot_bars"] += 6.0 + if token_set & {"regression", "fit", "fitting"}: + scores["fit_regression"] += 6.0 + if token_set & {"descriptive", "stats", "statistics", "mean", "median", "stdev", "quartile", "average"}: + scores["compute_descriptive_statistics"] += 6.0 + + # -- Linear algebra -- + if token_set & {"determinant", "eigenvalue", "matrix", "matrices"}: + scores["evaluate_linear_algebra_expression"] += 6.0 + if "multiply" in token_set and token_set & {"matrix", "matrices"}: + scores["evaluate_linear_algebra_expression"] += 6.0 + + # -- Undo/redo -- + if "undo" in token_set: + scores["undo"] += 8.0 + if "redo" in token_set: + scores["redo"] += 8.0 + if token_set & {"oops", "revert"}: + scores["undo"] += 8.0 + if "back" in token_set and "go" in token_set: + scores["undo"] += 8.0 + + # -- Tangent/normal lines -- + if "tangent" in token_set: + scores["draw_tangent_line"] += 6.0 + if "normal" in token_set and token_set & {"line", "perpendicular"}: + scores["draw_normal_line"] += 6.0 + + # -- Constructions -- + if "inscribed" in token_set and token_set & {"circle"}: + scores["construct_incircle"] += 10.0 + if "circumscribed" in token_set and token_set & {"circle"}: + scores["construct_circumcircle"] += 10.0 + if token_set & {"perpendicular"}: + # Question form ("are X and Y perpendicular?") → inspect, not construct + is_question = ( + raw_lower.endswith("?") + or "each other" in raw_lower + or raw_lower.lstrip().startswith("are ") + or raw_lower.lstrip().startswith("is ") + or token_set & {"check", "verify", "whether"} + ) + if is_question: + scores["inspect_relation"] += 10.0 + else: + scores["construct_perpendicular_bisector"] += 3.0 + scores["construct_perpendicular_from_point"] += 3.0 + if token_set & {"midpoint", "middle"}: + scores["construct_midpoint"] += 6.0 + + # -- 
Delete with casual language -- + if "rid" in token_set: + # "get rid of" is a common idiom for delete + for tool_name in _ALL_TOOL_NAMES: + if tool_name.startswith("delete_"): + scores[tool_name] += 4.0 + # Penalize create_ tools when intent is clearly delete + if tool_name.startswith("create_"): + scores[tool_name] -= 6.0 + + # -- Solve family -- + if token_set & {"solve", "find"} and token_set & {"system", "simultaneous", "equations"}: + scores["solve_system_of_equations"] += 8.0 + elif token_set & {"system", "simultaneous"}: + scores["solve_system_of_equations"] += 6.0 + # Detect multiple equations: "... and ..." pattern with = signs + _eq_count = raw_lower.count("=") + if _eq_count >= 2 and token_set & {"solve", "find"}: + scores["solve_system_of_equations"] += 6.0 + elif _eq_count >= 2 and " and " in raw_lower: + scores["solve_system_of_equations"] += 5.0 + if token_set & {"solve", "find"} and token_set & {"numeric", "numerical", "numerically", "approximate", "root"}: + scores["solve_numeric"] += 6.0 + elif token_set & {"numeric", "numerical", "numerically", "approximate"}: + scores["solve_numeric"] += 4.0 + if token_set & {"solve"}: + scores["solve"] += 4.0 + if token_set & {"roots"}: + scores["solve"] += 6.0 + # Word-problem patterns (cost, spent, buy, how many) + if token_set & {"cost", "spent", "buy", "price", "total"} and token_set & {"how", "many", "each"}: + scores["solve_system_of_equations"] += 8.0 + scores["solve"] += 6.0 + elif "how" in token_set and "many" in token_set: + scores["solve"] += 4.0 + scores["solve_system_of_equations"] += 3.0 + # "find x" without system/numeric context -> solve + if "find" in token_set and not (token_set & {"system", "simultaneous", "numeric", "numerical", "numerically", + "approximate", "shortest", "path", "bfs", "dfs"}): + scores["solve"] += 3.0 + + # -- Calculus -- + if token_set & {"derivative", "differentiate", "diff", "d/dx"}: + scores["derive"] += 6.0 + if token_set & {"integral", "integrate", "integration"}: + 
scores["integrate"] += 6.0 + scores["numeric_integrate"] += 2.0 + if token_set & {"simplify", "reduce"}: + scores["simplify"] += 6.0 + if token_set & {"expand", "multiply"} and not (token_set & {"matrix", "matrices"}): + scores["expand"] += 6.0 + if token_set & {"factor", "factorize", "factorise", "factored"}: + scores["factor"] += 6.0 + if token_set & {"limit", "lim"}: + scores["limit"] += 6.0 + if token_set & {"compute", "calculate", "evaluate"} and not ( + token_set & {"area", "statistics", "descriptive", "stats", "mean", "median"} + ): + scores["evaluate_expression"] += 3.0 + + # -- Convert (unit vs coordinate) -- + if token_set & {"convert", "change"} and token_set & { + "unit", "units", "temperature", "celsius", "fahrenheit", + "miles", "km", "kilometers", "meters", "inches", "feet", + }: + scores["convert"] += 8.0 + if token_set & {"convert", "change"} and token_set & { + "polar", "cartesian", "coordinate", "coordinates", + "rectangular", + }: + scores["convert_coordinates"] += 8.0 + + # -- Functions/plotting -- + if token_set & {"parametric"}: + scores["draw_parametric_function"] += 6.0 + # Detect parametric-like patterns: x=f(t), y=g(t) or x(t)=..., y(t)=... + if ("x=" in raw_lower and "y=" in raw_lower and + any(fn in raw_lower for fn in ("cos(t)", "sin(t)", "t*", "(t)"))): + scores["draw_parametric_function"] += 8.0 + if token_set & {"piecewise", "rules"}: + scores["draw_piecewise_function"] += 6.0 + # Detect piecewise patterns: "for x<..." or "when x>..." with multiple conditions + if re.search(r"for\s+x\s*[<>]|when\s+x\s*[<>]|x\s*>=|x\s*<=", raw_lower): + scores["draw_piecewise_function"] += 6.0 + # "graph f(x)" / "graph y=" means draw_function, not graph theory + if "graph" in token_set and token_set & {"f(x)", "y=", "sin", "cos", "exp", "ln", "log"}: + scores["draw_function"] += 6.0 + # "graph f(x)=..." 
pattern: "graph" + words suggesting a function equation + if "graph" in token_set and not ( + token_set & {"vertex", "vertices", "edge", "edges", "node", "nodes", + "directed", "undirected", "weighted", "shortest", "path", + "bfs", "dfs", "mst", "topological", "spanning", "adjacency", + "degree", "dag", "tree", "network"} + ): + # Boost function plotting when "graph" appears without graph-theory context + scores["draw_function"] += 4.0 + scores["draw_piecewise_function"] += 2.0 + # Penalize graph-theory tools to avoid confusion + scores["generate_graph"] -= 3.0 + scores["analyze_graph"] -= 3.0 + scores["delete_graph"] -= 3.0 + if token_set & {"curve", "plot"} and not (token_set & {"bar", "bars", "distribution"}): + scores["draw_function"] += 3.0 + + # -- Graph theory (only with explicit graph-theory context) -- + if "graph" in token_set and token_set & { + "vertex", "vertices", "edge", "edges", "node", "nodes", + "directed", "undirected", "weighted", "network", + }: + scores["generate_graph"] += 8.0 + scores["analyze_graph"] += 4.0 + if token_set & {"shortest", "bfs", "dfs", "mst", "topological", "spanning", "degree"}: + scores["analyze_graph"] += 8.0 + if "graph" in token_set and token_set & {"analysis", "analyze", "statistics", "degree"}: + scores["analyze_graph"] += 6.0 + + # -- Inspection -- + if token_set & {"inspect", "relation", "relationship"}: + scores["inspect_relation"] += 6.0 + + # -- Workspace operations -- + if token_set & {"workspace", "project", "session"}: + for name in ("save_workspace", "load_workspace", "list_workspaces", "delete_workspace"): + scores[name] += 4.0 + if token_set & {"save", "persist", "store"}: + scores["save_workspace"] += 5.0 + if token_set & {"load", "open", "restore"}: + scores["load_workspace"] += 5.0 + if token_set & {"list", "inventory", "available", "names"}: + scores["list_workspaces"] += 5.0 + + # -- Canvas operations -- + if token_set & {"zoom", "viewport", "pan", "reframe", "narrow", "window"}: + scores["zoom"] += 6.0 
+ if "reset" in token_set and token_set & {"canvas", "zoom", "view", "default"}: + scores["reset_canvas"] += 6.0 + if token_set & {"wipe", "clean", "fresh"}: + scores["clear_canvas"] += 8.0 + scores["reset_canvas"] += 4.0 + if token_set & {"grid"}: + scores["set_grid_visible"] += 6.0 + + # -- Labels -- + if token_set & {"text", "annotation", "note", "label", "annotate"}: + scores["create_label"] += 5.0 + if token_set & {"delete"} and token_set & {"plot"}: + scores["delete_plot"] += 6.0 + + # -- Polygons -- + if token_set & {"triangle", "quadrilateral", "pentagon", "hexagon", "rectangle", "square"}: + scores["create_polygon"] += 4.0 + scores["delete_polygon"] += 4.0 + scores["update_polygon"] += 4.0 + + # -- Coordinate system -- + if token_set & {"polar", "cartesian"} and token_set & {"system", "mode", "switch"}: + scores["set_coordinate_system"] += 6.0 + + # ------------------------------------------------------------------ + # API-based search (original implementation) + # ------------------------------------------------------------------ + + def _search_tools_api( + self, + query: str, + model: Optional[AIModel] = None, + max_results: int = 10, + ) -> List[FunctionDefinition]: + """Search for tools using the AI API (original implementation). + + Args: + query: Description of what the user wants to accomplish. + model: AI model to use for matching. Defaults to gpt-4.1-mini. + max_results: Maximum number of tools to return (1-20). + + Returns: + List of matching tool definitions, ordered by relevance. + """ # Use provided model, instance default, or fallback to gpt-4.1-mini. if model is None: - model = self.default_model or AIModel.from_identifier("gpt-4.1-mini") + model = self.default_model or AIModel.from_identifier("gpt-4.1-mini") # Build the prompt tool_descriptions = self.build_tool_descriptions() @@ -321,12 +1007,6 @@ def _extract_list_from_parsed(parsed: Any) -> List[str]: """Extract a list of strings from a parsed JSON value.
Handles both direct arrays and objects with 'tools' key. - - Args: - parsed: The parsed JSON value. - - Returns: - List of tool name strings. """ # Direct array if isinstance(parsed, list): @@ -346,12 +1026,6 @@ def _parse_tool_names(content: str) -> List[str]: - JSON arrays: ["tool1", "tool2"] - JSON objects: {"tools": ["tool1", "tool2"]} - Markdown code blocks with JSON - - Args: - content: The AI response content. - - Returns: - List of tool names extracted from the response. """ content = content.strip()