Merged
1 change: 1 addition & 0 deletions .claude/CLAUDE.md
@@ -53,6 +53,7 @@
AUTH_PIN=123456 # Optional: access code when auth is enabled
REQUIRE_AUTH=true # Force authentication in local development
PORT=5000 # Set by hosting platforms to indicate deployed mode
SECRET_KEY=override-me # Optional: otherwise random key generated per launch
TOOL_SEARCH_MODE=hybrid # Tool discovery: local | api | hybrid (default: hybrid)
```

## Running the App
1 change: 1 addition & 0 deletions README.md
@@ -84,6 +84,7 @@
MatHud pairs an interactive drawing canvas with an AI assistant to help visualiz
REQUIRE_AUTH=true # Force authentication in local development
PORT=5000 # Set by hosting platforms to indicate deployed mode
SECRET_KEY=override-me # Optional: otherwise a random key is generated per launch
TOOL_SEARCH_MODE=hybrid # Tool discovery: local | api | hybrid (default: hybrid)
```
2. Authentication rules (`static/app_manager.py`):
1. When `PORT` is set (typical in hosted deployments), authentication is enforced automatically.
3 changes: 2 additions & 1 deletion documentation/Project Architecture.txt
@@ -38,7 +38,8 @@
MatHud is an interactive mathematical visualization tool that combines a drawing
- Coordinate Conversion: rectangular to polar and polar to rectangular transformations
- Equation Solving: linear equations, quadratic equations, systems of equations, numeric solving for transcendental/nonlinear systems via multi-start Newton-Raphson
- Linear Algebra: matrix and vector arithmetic, transpose, inverse, determinant, dot/cross products, norms, traces, diagonal/identity/zeros/ones helpers, reshape/size queries, and grouped expressions evaluated via MathJS
- Advanced Operations: 70 AI tool definitions available in `static/functions_definitions.py` including canvas operations, geometric shape creation/deletion, mathematical operations, transformations, coordinate system management, and workspace management
- Advanced Operations: 90+ AI tool definitions available in `static/functions_definitions.py` including canvas operations, geometric shape creation/deletion, mathematical operations, transformations, coordinate system management, and workspace management
- Tool Discovery: `static/tool_search_service.py` provides fast local keyword/category search (hybrid mode by default) to select relevant tools per query without requiring an API call for every interaction

## Geometric Shape Management
- Points: create, delete, translate points with coordinates and labels
31 changes: 31 additions & 0 deletions documentation/Reference Manual.txt
@@ -3761,6 +3761,37 @@
Dependencies:
- Parameter validation: Strict JSON schema enforcement for all function arguments
- Type safety: Required parameters and type checking for robust AI integration

### Tool Search Service (`static/tool_search_service.py`)

Provides tool discovery via fast local keyword/category matching (default) or AI-powered semantic matching. The search mode is selected by the `TOOL_SEARCH_MODE` environment variable:

- `local` — keyword + category index, no API call (~0.01ms per query)
- `api` — AI-powered semantic search via gpt-5-nano
- `hybrid` (default) — local first, falls back to API when confidence is low
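
The hybrid fallback can be sketched in a few lines. Everything here is a hypothetical stand-in for illustration — `local_search`, `api_search`, the returned confidence score, and the 0.5 threshold are not the service's actual interface:

```python
# Sketch of hybrid dispatch: try the local index first and only fall back
# to the API when the local result is not confident enough.
# All names and the threshold are illustrative, not the real service API.
from typing import Callable, List, Tuple

CONFIDENCE_THRESHOLD = 0.5  # hypothetical cut-off


def hybrid_search(
    query: str,
    local_search: Callable[[str], Tuple[List[str], float]],
    api_search: Callable[[str], List[str]],
) -> List[str]:
    tools, confidence = local_search(query)
    if confidence >= CONFIDENCE_THRESHOLD:
        return tools  # fast path: no network call
    return api_search(query)  # low confidence: defer to semantic search


# Stub backends standing in for the real local and API searches.
def fake_local(query: str) -> Tuple[List[str], float]:
    return (["create_point"], 0.9) if "point" in query else ([], 0.1)


def fake_api(query: str) -> List[str]:
    return ["solve_equation"]
```

With these stubs, a query mentioning "point" resolves locally and anything else falls through to the API path.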

**Architecture:**
- 13 tool categories (geometry_create, geometry_delete, geometry_update, constructions, functions_plots, math, statistics, graph_theory, canvas, workspace, transforms, areas, inspection) with keyword triggers
- Inverted indices built at module load time for O(1) token lookups
- Multi-signal scoring: category boost, name index match, description index match, exact tool name match, action-verb alignment, and intent-based disambiguation boosts
- LRU result cache with 5-minute TTL (100 entries max)
- Lazy OpenAI client initialization — local mode never touches the network
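
A minimal sketch of how such an inverted index and multi-signal scorer can work. The categories, keywords, tools, and weights below are invented for illustration; the real index is far larger and uses different signals and weights:

```python
# Sketch: build an inverted keyword index once (token -> categories),
# then score tools with two simple signals: category boost + name match.
# Data and weights here are illustrative only.
from collections import defaultdict
from typing import Dict, List, Set

CATEGORY_KEYWORDS = {
    "geometry_create": {"point", "segment", "circle", "create", "draw"},
    "math": {"solve", "equation", "derivative", "integral"},
}
TOOLS = {
    "create_point": "geometry_create",
    "create_circle": "geometry_create",
    "solve_equation": "math",
}

# Built at module load time so per-query lookups are O(1) per token.
token_to_categories: Dict[str, Set[str]] = defaultdict(set)
for category, keywords in CATEGORY_KEYWORDS.items():
    for kw in keywords:
        token_to_categories[kw].add(category)


def score_tools(query: str) -> List[str]:
    tokens = query.lower().split()
    scores: Dict[str, float] = defaultdict(float)
    for tool, category in TOOLS.items():
        for tok in tokens:
            if category in token_to_categories.get(tok, ()):
                scores[tool] += 1.0  # category boost
            if tok in tool:
                scores[tool] += 2.0  # name-index match
    return sorted(scores, key=lambda t: -scores[t])
```

The real scorer layers on more signals (description matches, exact-name matches, action-verb alignment, disambiguation boosts), but the ranking mechanics are the same shape.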

**Class: ToolSearchService**

Key Methods:
- `search_tools(query, model=None, max_results=10)`: Main entry point; dispatches to local or API search based on mode
- `search_tools_local(query, max_results=10)`: Fast local keyword/category search (no API call)
- `search_tools_formatted(query, model=None, max_results=10)`: Returns dict with tools list, count, and query for AI consumption
- `get_tool_by_name(name)`: Look up a tool definition by name (static method)
- `build_tool_descriptions(exclude_meta_tools=True)`: Build compact string of tool names and descriptions (static method)
- `get_all_tools()`: Get all available tool definitions (static method)

Module-level Functions:
- `clear_search_cache()`: Clear the search result cache

**Environment Variables:**
- `TOOL_SEARCH_MODE`: Search mode selection (`local`, `api`, or `hybrid`; default: `hybrid`)
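
The result cache described above (LRU eviction, 100 entries, 5-minute TTL) can be approximated with a small helper; this is a sketch of the idea, not the service's actual implementation:

```python
# Sketch of an LRU cache with TTL expiry, approximating the service's
# result cache (max 100 entries, 5-minute time-to-live).
import time
from collections import OrderedDict
from typing import Any, Optional


class TTLCache:
    def __init__(self, max_entries: int = 100, ttl_seconds: float = 300.0):
        self._data: "OrderedDict[str, tuple]" = OrderedDict()
        self.max_entries = max_entries
        self.ttl = ttl_seconds

    def get(self, key: str) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if time.monotonic() - stored_at > self.ttl:
            del self._data[key]  # expired: drop and report a miss
            return None
        self._data.move_to_end(key)  # mark as recently used
        return value

    def put(self, key: str, value: Any) -> None:
        self._data[key] = (time.monotonic(), value)
        self._data.move_to_end(key)
        if len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # evict least recently used
```

Reads refresh recency, writes evict the oldest entry once the cap is exceeded, and stale entries are dropped lazily on lookup.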

### Flask Application Manager (`static/app_manager.py`)

**File Header:**
298 changes: 298 additions & 0 deletions scripts/compare_search_modes.py
@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""Compare local vs API tool search accuracy and latency.

Runs the benchmark dataset against both search modes and outputs a
side-by-side comparison table with disagreement analysis.

Usage::

# Local only (no API key needed)
python scripts/compare_search_modes.py --modes local

# Full comparison (needs API key)
python scripts/compare_search_modes.py --modes local,api

# Save disagreements to CSV
python scripts/compare_search_modes.py --modes local,api --disagreements /tmp/disagreements.csv
"""

from __future__ import annotations

import argparse
import csv
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _project_root not in sys.path:
sys.path.insert(0, _project_root)

from static.tool_search_service import ToolSearchService, clear_search_cache

DATASET_PATH = Path("server_tests/data/tool_discovery_cases.yaml")


def _load_dataset() -> Dict[str, Any]:
    # The dataset is YAML, so parse it with PyYAML; json.loads would fail
    # on any non-JSON YAML syntax in the file.
    import yaml

    raw = DATASET_PATH.read_text(encoding="utf-8")
    parsed: Dict[str, Any] = yaml.safe_load(raw)
    return parsed


def _get_tool_name(tool: Dict[str, Any]) -> str:
name: str = tool.get("function", {}).get("name", "")
return name


def _run_mode(
mode: str,
cases: List[Dict[str, Any]],
max_results: int = 10,
) -> Dict[str, Any]:
"""Run benchmark for a single search mode."""
os.environ["TOOL_SEARCH_MODE"] = mode
clear_search_cache()

    # Build a bare instance for local mode so no OPENAI_API_KEY is needed;
    # API mode re-initializes a real client below.
    service = ToolSearchService.__new__(ToolSearchService)
    service._client = None
    service._client_initialized = False
    service.default_model = None
    service.last_error = None

# For API mode, initialize the client properly
if mode == "api":
try:
service = ToolSearchService()
except ValueError:
            print("  [SKIP] API mode requires OPENAI_API_KEY")
return {"skipped": True}

top1_hits = 0
top3_hits = 0
top5_hits = 0
evaluated = 0
latencies: List[float] = []
case_results: List[Dict[str, Any]] = []

for case in cases:
expected_any = [str(x) for x in case.get("expected_any", []) if isinstance(x, str)]
if not expected_any:
continue

query = str(case.get("query", "")).strip()
evaluated += 1

start = time.perf_counter()
if mode == "local":
results = service.search_tools_local(query, max_results)
else:
results = service.search_tools(query, max_results=max_results)
elapsed_ms = (time.perf_counter() - start) * 1000
latencies.append(elapsed_ms)

ranked = [_get_tool_name(t) for t in results]
expected_set = set(expected_any)

top1 = ranked[0] if ranked else ""
is_top1 = top1 in expected_set
is_top3 = bool(expected_set & set(ranked[:3]))
is_top5 = bool(expected_set & set(ranked[:5]))

if is_top1:
top1_hits += 1
if is_top3:
top3_hits += 1
if is_top5:
top5_hits += 1

case_results.append({
"id": case.get("id", "?"),
"query": query,
"expected": expected_any,
"top1": top1,
"top3": ranked[:3],
"is_top1": is_top1,
"is_top3": is_top3,
"is_top5": is_top5,
})

# Rate-limit API calls
if mode == "api":
time.sleep(0.5)

avg_latency = sum(latencies) / len(latencies) if latencies else 0
latencies.sort()
p50 = latencies[len(latencies) // 2] if latencies else 0
p99_idx = min(int(len(latencies) * 0.99), len(latencies) - 1) if latencies else 0
p99 = latencies[p99_idx] if latencies else 0

return {
"skipped": False,
"evaluated": evaluated,
"top1_hits": top1_hits,
"top3_hits": top3_hits,
"top5_hits": top5_hits,
"top1_rate": top1_hits / evaluated if evaluated else 0,
"top3_rate": top3_hits / evaluated if evaluated else 0,
"top5_rate": top5_hits / evaluated if evaluated else 0,
"avg_latency_ms": avg_latency,
"p50_latency_ms": p50,
"p99_latency_ms": p99,
"case_results": case_results,
}


def _find_disagreements(
local_results: Dict[str, Any],
api_results: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Find cases where local and API disagree."""
disagreements: List[Dict[str, Any]] = []
local_cases = local_results.get("case_results", [])
api_cases = api_results.get("case_results", [])

api_by_id = {c["id"]: c for c in api_cases}

for lc in local_cases:
ac = api_by_id.get(lc["id"])
if ac is None:
continue

if lc["is_top1"] != ac["is_top1"]:
disagreements.append({
"id": lc["id"],
"query": lc["query"],
"expected": lc["expected"],
"local_top1": lc["top1"],
"api_top1": ac["top1"],
"local_correct": lc["is_top1"],
"api_correct": ac["is_top1"],
"winner": "local" if lc["is_top1"] else "api",
})

return disagreements


def main() -> int:
parser = argparse.ArgumentParser(description="Compare tool search modes")
parser.add_argument(
"--modes",
default="local",
help="Comma-separated search modes to compare (local,api)",
)
parser.add_argument(
"--max-results",
type=int,
default=10,
help="Max results per search (default: 10)",
)
parser.add_argument(
"--disagreements",
default="",
help="Path to write disagreements CSV",
)
args = parser.parse_args()

modes = [m.strip() for m in args.modes.split(",") if m.strip()]
if not modes:
print("No modes specified")
return 1

dataset = _load_dataset()
cases = dataset.get("cases", [])

print(f"Dataset: {len(cases)} cases")
print()

all_results: Dict[str, Dict[str, Any]] = {}

for mode in modes:
print(f"Running mode: {mode}")
result = _run_mode(mode, cases, args.max_results)
all_results[mode] = result

if result.get("skipped"):
            print("  Skipped (missing credentials)")
continue

print(f" Evaluated: {result['evaluated']}")
print(f" Top-1: {result['top1_hits']}/{result['evaluated']} = {result['top1_rate']:.3f}")
print(f" Top-3: {result['top3_hits']}/{result['evaluated']} = {result['top3_rate']:.3f}")
print(f" Top-5: {result['top5_hits']}/{result['evaluated']} = {result['top5_rate']:.3f}")
print(f" Latency: avg={result['avg_latency_ms']:.1f}ms, "
f"p50={result['p50_latency_ms']:.1f}ms, p99={result['p99_latency_ms']:.1f}ms")
print()

# Side-by-side comparison
if len(all_results) >= 2:
active = {k: v for k, v in all_results.items() if not v.get("skipped")}
if len(active) >= 2:
print("=" * 60)
print("Side-by-Side Comparison")
print("=" * 60)
header = f"{'Metric':<20}"
for mode in active:
header += f" {mode:>15}"
print(header)
print("-" * 60)

for metric in ["top1_rate", "top3_rate", "top5_rate", "avg_latency_ms", "p50_latency_ms", "p99_latency_ms"]:
row = f"{metric:<20}"
for mode in active:
val = active[mode].get(metric, 0)
if "rate" in metric:
row += f" {val:>14.3f}"
else:
row += f" {val:>13.1f}ms"
print(row)
print()

# Disagreement analysis
if "local" in all_results and "api" in all_results:
local_r = all_results["local"]
api_r = all_results["api"]
if not local_r.get("skipped") and not api_r.get("skipped"):
disagreements = _find_disagreements(local_r, api_r)
local_wins = [d for d in disagreements if d["winner"] == "local"]
api_wins = [d for d in disagreements if d["winner"] == "api"]

print(f"Disagreements: {len(disagreements)} total")
print(f" Local wins: {len(local_wins)}")
print(f" API wins: {len(api_wins)}")

if api_wins:
print(f"\nCases where API is right but local is wrong (tuning opportunities):")
for d in api_wins[:10]:
print(f" {d['id']}: {d['query']!r}")
print(f" expected={d['expected']}, local={d['local_top1']!r}, api={d['api_top1']!r}")

if local_wins:
print(f"\nCases where local is right but API is wrong (local advantages):")
for d in local_wins[:10]:
print(f" {d['id']}: {d['query']!r}")
print(f" expected={d['expected']}, local={d['local_top1']!r}, api={d['api_top1']!r}")

# Write disagreements CSV
if args.disagreements and disagreements:
dis_path = Path(args.disagreements)
dis_path.parent.mkdir(parents=True, exist_ok=True)
with dis_path.open("w", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=[
"id", "query", "expected", "local_top1", "api_top1",
"local_correct", "api_correct", "winner",
])
writer.writeheader()
for d in disagreements:
csv_row: Dict[str, Any] = dict(d)
csv_row["expected"] = "|".join(csv_row["expected"])
writer.writerow(csv_row)
print(f"\nDisagreements written to: {dis_path}")

return 0


if __name__ == "__main__":
raise SystemExit(main())