From a641084010907bd058f3874572ea24aedb9cf94b Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:22:33 +0200 Subject: [PATCH 01/28] Add architectural review document Comprehensive review covering god classes, manager duplication, error handling inconsistencies, and CI/CD gaps with a phased refactoring roadmap. --- .../architectural_review_2026_03_24.md | 200 ++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 documentation/development/architectural_review_2026_03_24.md diff --git a/documentation/development/architectural_review_2026_03_24.md b/documentation/development/architectural_review_2026_03_24.md new file mode 100644 index 00000000..ca212353 --- /dev/null +++ b/documentation/development/architectural_review_2026_03_24.md @@ -0,0 +1,200 @@ +# MatHud Architectural Review — 2026-03-24 + +## Overall Assessment + +The project has strong foundations: good test coverage (~3,745 tests), strict mypy/ruff enforcement, clean manager-based decomposition, and well-documented conventions. The main issues are **god classes**, **copy-paste patterns across managers**, **inconsistent error handling**, and **missing CI/CD automation**. + +--- + +## CRITICAL: God Classes + +These files have accumulated too many responsibilities and need decomposition: + +### 1. `static/client/ai_interface.py` (2,342 lines, 30+ instance attributes) + +The single largest architectural problem. 
This class mixes: +- HTTP/AJAX communication protocol +- Streaming response buffering (text + reasoning + tool calls — 3 independent state machines) +- Chat UI DOM manipulation (message containers, image attachments, scroll behavior) +- TTS playback coordination +- Action trace collection +- Test execution state + +**Suggested extraction:** +- `ChatUIManager` — DOM manipulation, message rendering, image attachment UI +- `StreamingResponseHandler` — stream buffering, chunk assembly, timeout tracking +- `ReasoningHandler` — reasoning token streaming/display +- `ToolCallLogger` — tool call log entries and summary + +This would reduce AIInterface to ~600-800 lines of orchestration + AJAX transport. + +### 2. `static/client/canvas.py` (2,194 lines, 80+ public methods) + +Canvas acts as a universal coordinator. It should delegate more aggressively: +- Lines 297-356: Visibility culling logic — extract to `VisibilityManager` +- Lines 225-241: Frame batching logic — belongs in the renderer layer +- Lines 168-199: Legacy drawable registration — indicates incomplete migration +- Lines 360-403: Zoom displacement calculations — should use `coordinate_mapper` directly + +Target: reduce to ~500 lines focused on initialization, public API delegation, and state archiving. + +### 3. `static/routes.py` — `send_message_stream()` (194 lines) and `send_message()` (167 lines) + +Both route handlers mix request parsing, provider selection, vision capture, tool search interception, streaming, error recovery, and tool injection/reset. 
They share significant duplicated logic: +- Tool reset logic duplicated 3x (lines 891-897, 909-915, 1146-1152) +- Provider model setting duplicated (lines 811-812 vs 1108-1109) + +**Suggested extraction:** +- `_validate_and_parse_message_request()` — shared request parsing +- `ProviderManager` — unified interface over `ai_api`, `responses_api`, and custom providers +- `ToolInjectionManager` — search/inject/reset tool lifecycle +- `VisionManager` — vision capture, WebDriver init, image handling + +--- + +## HIGH: Manager Copy-Paste Pattern + +All 10+ drawable managers (`PointManager`, `SegmentManager`, `CircleManager`, `EllipseManager`, `ArcManager`, `AngleManager`, etc.) reimplement identical patterns: + +```python +def __init__(self, canvas, drawables_container, name_generator, dependency_manager, ...): + self.canvas = canvas + self.drawables = drawables_container + self.name_generator = name_generator + self.dependency_manager = dependency_manager + self._edit_policy = get_drawable_edit_policy(...) +``` + +Every manager also reimplements deletion-with-dependency-cleanup identically. 
+ +**Fix:** Create `BaseDrawableManager` with: +- Common `__init__` accepting shared dependencies +- Standard `create()`, `get_by_name()`, `delete()` contract +- Built-in edit policy setup +- `remove_drawable_with_dependencies()` as inherited method + +--- + +## HIGH: Inconsistent Error Handling + +Three different error handling strategies coexist: + +| Pattern | Where | Problem | +|---|---|---| +| `print()` | `routes.py` lines 103, 177, 183, 905 | Not captured in logs | +| `logging.error()` | `tool_call_processor.py` lines 63-75 | Correct approach | +| Bare `except Exception: pass` | `routes.py` line 886, `workspace_manager.py` line 137, `drawables_container.py` lines 77-80 | Silently swallows errors | +| `except Exception: return` | `routes.py` line 239 | Returns potentially invalid state | + +Additionally, `static/client/utils/math_utils.py` (lines 2363, 2414, 2422) silently skips asymptote/limit calculations on failure. + +**Fix:** Standardize on Python `logging` module throughout. Replace all bare `except:` blocks with specific exception types. Create per-module loggers. + +--- + +## HIGH: Production Debug Code + +1. **Test error trigger in routes.py** (lines 855-858): + ```python + # TEMPORARY TEST TRIGGER - REMOVE AFTER TESTING + if "TEST_ERROR_TRIGGER_12345" in message: + raise ValueError("Test error triggered for message recovery testing") + ``` + +2. **Debug print statements in expression_evaluator.py** (lines 50, 72, 83, 90): + ```python + print(f"Evaluated numeric expression: {expression} = {result}") # DEBUG + ``` + +--- + +## MEDIUM: Workspace Serialization Issues + +### State mutation during serialization +`Segment.get_state()` calls `_sync_label_position()` which mutates internal state. Serialization should be read-only. + +### No schema versioning on drawables +Each drawable's `get_state()` returns `Dict[str, Any]` with no version field. Format varies by type. No validation on deserialization. 
+ +### Client workspace_manager.py (1,421 lines) +Restoration logic uses per-type methods that are largely boilerplate. Should use a factory pattern where each drawable type owns its own `from_state()` class method. + +--- + +## MEDIUM: Missing Abstractions + +### Provider management +Code repeatedly operates on `app.ai_api`, `app.responses_api`, and `app.providers` dict separately. A `ProviderManager` would centralize model resolution, tool injection, and conversation lifecycle. + +### DrawableManagerProxy +Uses `__getattr__` reflection to break circular initialization. Adds runtime overhead and defeats IDE type narrowing. Consider restructuring the dependency graph. + +### Naming inconsistency +`DrawablesContainer` uses `add()`/`remove()` while all managers use `create_*()`/`delete_*()`. + +--- + +## MEDIUM: Configuration Scattered + +Constants that should be centralized are spread across files: + +| Constant | Location | Should be in | +|---|---|---| +| `AI_RESPONSE_TIMEOUT_MS = 60000` | `ai_interface.py:69` | `constants.py` | +| `REASONING_TIMEOUT_MS = 300000` | `ai_interface.py:70` | `constants.py` | +| `MAX_ATTACHED_IMAGES = 5` | `ai_interface.py:72` | `constants.py` | +| `IMAGE_SIZE_WARNING_BYTES = 10MB` | `ai_interface.py:73` | `constants.py` | +| `_MAX_TRACES = 100` | `action_trace_collector.py` | `constants.py` | +| `WORKSPACES_DIR = "workspaces"` | `workspace_manager.py:22` | `config.py` (server) | +| `CANVAS_SNAPSHOT_DIR` | `routes.py:47` | `config.py` (server) | + +--- + +## MEDIUM: Renderer Telemetry Duplication + +`SvgTelemetry` and `Canvas2DTelemetry` are near-identical classes. Extract a `BaseRendererTelemetry` class. 
+ +--- + +## LOW: CI/CD and Infrastructure Gaps + +| Gap | Impact | +|---|---| +| No automated test execution on PR/push | Regressions can be merged undetected | +| No code coverage measurement | Can't track coverage trends | +| No linting checks in CI (ruff/mypy) | Type errors can be merged | +| 7 unpinned dependencies in requirements.txt | Reproducibility risk | +| No Dependabot or dependency scanning | Security vulnerability blind spot | + +--- + +## LOW: Test Coverage Gaps + +- `result_validator.py` (102 lines) — no dedicated tests +- `command_autocomplete.py` (504 lines) — minimal testing +- `canvas_event_handler.py` (622 lines) — only 2 test files for complex state machine +- `functions_definitions.py` (2,731 lines) — manual JSON schemas with no code generation + +--- + +## Refactoring Roadmap + +### Phase 1: Quick Wins (low risk, high value) +1. Remove debug code (test trigger in routes.py, print statements in expression_evaluator.py) +2. Centralize scattered constants +3. Standardize error handling (replace `print()` with `logging`, eliminate bare `except:`) +4. Pin all dependencies in requirements.txt + +### Phase 2: Structural (medium risk, high value) +5. Create `BaseDrawableManager` — eliminates manager duplication +6. Extract `ChatUIManager` and `StreamingResponseHandler` from AIInterface +7. Extract `VisibilityManager` from Canvas +8. Create `ProviderManager` to unify provider operations in routes.py +9. Make `get_state()` side-effect-free across all drawables + +### Phase 3: Architecture (higher risk, long-term value) +10. Add CI/CD pipeline with automated tests + linting +11. Implement drawable state schema versioning +12. Restructure dependency graph to eliminate `DrawableManagerProxy` +13. Add workspace format migration support +14. 
Extract route handler logic into service classes for testability From 22c4affc735e888217fb47de091d6762585e7e5e Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:22:58 +0200 Subject: [PATCH 02/28] Remove debug code left in production paths Remove TEST_ERROR_TRIGGER_12345 test block from routes.py and debug print statements from expression_evaluator.py and result_processor.py. --- static/client/expression_evaluator.py | 4 -- static/client/result_processor.py | 4 -- static/routes.py | 83 +++++++-------------------- 3 files changed, 22 insertions(+), 69 deletions(-) diff --git a/static/client/expression_evaluator.py b/static/client/expression_evaluator.py index 8af03c5a..6e533c3f 100644 --- a/static/client/expression_evaluator.py +++ b/static/client/expression_evaluator.py @@ -47,7 +47,6 @@ def evaluate_numeric_expression(expression: str, variables: Dict[str, Any]) -> f float: The computed numeric result """ result: Any = MathUtils.evaluate(expression, variables) - print(f"Evaluated numeric expression: {expression} = {result}") # DEBUG # Convert numeric results to float for consistency if isinstance(result, (int, float)): result = float(result) @@ -69,7 +68,6 @@ def evaluate_function(expression: str, canvas: "Canvas") -> float: Raises: ValueError: If canvas is None, expression format is invalid, or function not found """ - print(f"Evaluating function with expression: {expression}") # DEBUG if canvas is None: raise ValueError("Cannot evaluate function: no canvas available") @@ -80,14 +78,12 @@ def evaluate_function(expression: str, canvas: "Canvas") -> float: function_name: str argument: str function_name, argument = match.groups() - print(f"Function name: {function_name}, argument: {argument}") # DEBUG else: raise ValueError(f"Invalid function expression: {expression}") for function in functions: if function.name.lower() == function_name.lower(): # If the function name matches, evaluate the function - 
print(f"Found function: {function.name} = {function.function_string}") # DEBUG try: argument_val: float = float(argument) # Convert argument to float result: Any = function.function(argument_val) diff --git a/static/client/result_processor.py b/static/client/result_processor.py index 2634a7eb..26acb989 100644 --- a/static/client/result_processor.py +++ b/static/client/result_processor.py @@ -249,7 +249,6 @@ def _is_function_available( """Check if the function exists and update results if not.""" if function_name not in available_functions: error_msg: str = f"Error: function {function_name} not found." - print(error_msg) # DEBUG results[function_name] = error_msg return False return True @@ -338,9 +337,6 @@ def _handle_exception(exception: Exception, function_name: str, results: Dict[st function_name: Name of the function that caused the exception results: Dictionary to update with the error information """ - error_message: str = f"Error calling function {function_name}: {exception}" - print(error_message) # DEBUG - # Use the function name as the key for storing the error key: str = function_name diff --git a/static/routes.py b/static/routes.py index 7ae92005..8413bb52 100644 --- a/static/routes.py +++ b/static/routes.py @@ -16,6 +16,7 @@ import functools import hmac import json +import logging import math import os import time @@ -28,14 +29,18 @@ from static.ai_model import AIModel, PROVIDER_OPENAI, PROVIDER_ANTHROPIC, PROVIDER_OPENROUTER, PROVIDER_OLLAMA from static.app_manager import AppManager, MatHudFlask from static.canvas_state_summarizer import compare_canvas_states +from static.config import CANVAS_SNAPSHOT_DIR, CANVAS_SNAPSHOT_PATH from static.openai_api_base import OpenAIAPIBase from static.providers import ProviderRegistry, create_provider_instance +from static.route_helpers import get_active_provider, reset_tools_for_all_providers from static.tool_call_processor import ProcessedToolCall, ToolCallProcessor from static.tts_manager import get_tts_manager 
from static.webdriver_manager import SvgState F = TypeVar("F", bound=Callable[..., ResponseReturnValue]) +_logger = logging.getLogger(__name__) + # Global dictionary to track login attempts by IP address # Format: {ip_address: last_attempt_timestamp} JsonValue = Union[str, int, float, bool, None, Dict[str, "JsonValue"], List["JsonValue"]] @@ -44,8 +49,6 @@ login_attempts: Dict[str, float] = {} ToolCallList = List[ProcessedToolCall] -CANVAS_SNAPSHOT_DIR = "canvas_snapshots" -CANVAS_SNAPSHOT_PATH = os.path.join(CANVAS_SNAPSHOT_DIR, "canvas.png") def get_provider_for_model(app: MatHudFlask, model_id: str) -> OpenAIAPIBase: @@ -100,7 +103,7 @@ def save_canvas_snapshot_from_data_url(data_url: str) -> bool: try: image_bytes = base64.b64decode(encoded) except Exception as exc: - print(f"Failed to decode canvas snapshot: {exc}") + _logger.error("Failed to decode canvas snapshot: %s", exc) return False try: os.makedirs(CANVAS_SNAPSHOT_DIR, exist_ok=True) @@ -108,7 +111,7 @@ def save_canvas_snapshot_from_data_url(data_url: str) -> bool: snapshot_file.write(image_bytes) return True except Exception as exc: - print(f"Failed to write canvas snapshot: {exc}") + _logger.error("Failed to write canvas snapshot: %s", exc) return False @@ -174,13 +177,13 @@ def handle_vision_capture( try: init_webdriver() except Exception as exc: - print(f"Failed to initialize WebDriver for vision capture: {exc}") + _logger.error("Failed to initialize WebDriver for vision capture: %s", exc) if app.webdriver_manager is not None: try: app.webdriver_manager.capture_svg_state(cast(SvgState, svg_state)) except Exception as exc: - print(f"WebDriver capture failed: {exc}") + _logger.error("WebDriver capture failed: %s", exc) def _intercept_search_tools( @@ -237,6 +240,7 @@ def _intercept_search_tools( return _filter_tool_calls_by_allowed_names(tool_calls, allowed_names) except Exception: + _logger.exception("search_tools interception failed; returning original tool calls") return tool_calls # On error, 
return original calls @@ -717,7 +721,7 @@ def init_webdriver_route() -> ResponseReturnValue: base_url = f"http://127.0.0.1:{port}/" app.webdriver_manager = WebDriverManager(base_url=base_url) except Exception as e: - print(f"Failed to initialize WebDriverManager: {str(e)}") + _logger.error("Failed to initialize WebDriverManager: %s", e) return AppManager.make_response( message=f"WebDriver initialization failed: {str(e)}", status="error", code=500 ) @@ -807,14 +811,7 @@ def send_message_stream() -> ResponseReturnValue: attached_images = [img for img in attached_images_raw if isinstance(img, str)] # Get the provider for this model and update all relevant APIs - if ai_model: - app.ai_api.set_model(ai_model) - app.responses_api.set_model(ai_model) - # Get or create provider instance for this model - provider = get_provider_for_model(app, ai_model) - else: - # Use default OpenAI provider - provider = app.ai_api + provider = get_active_provider(app, ai_model) app.log_manager.log_user_message(message) @@ -852,11 +849,6 @@ def _yield_pending_logs() -> Iterator[str]: yield json.dumps(log_event) + "\n" try: - # TEMPORARY TEST TRIGGER - REMOVE AFTER TESTING - if "TEST_ERROR_TRIGGER_12345" in message: - raise ValueError("Test error triggered for message recovery testing") - # END TEMPORARY TEST TRIGGER - # Route to appropriate API based on model and provider model = provider.get_model() if model.provider == PROVIDER_OPENAI and model.is_reasoning_model: @@ -884,17 +876,10 @@ def _yield_pending_logs() -> Iterator[str]: event_dict["ai_tool_calls"] = cast(JsonValue, filtered_calls) app.log_manager.log_ai_tool_calls(filtered_calls) except Exception: - pass + _logger.exception("Failed to log/filter final stream event") # Reset tools if AI finished (not requesting more tool calls) finish_reason = event_dict.get("finish_reason") - if finish_reason != "tool_calls": - if app.ai_api.has_injected_tools(): - app.ai_api.reset_tools() - if app.responses_api.has_injected_tools(): - 
app.responses_api.reset_tools() - # Reset tools on active provider if different - if provider not in (app.ai_api, app.responses_api) and provider.has_injected_tools(): - provider.reset_tools() + reset_tools_for_all_providers(app, finish_reason, active_provider=provider) yield json.dumps(event_dict) + "\n" else: yield json.dumps(event) + "\n" @@ -903,16 +888,10 @@ def _yield_pending_logs() -> Iterator[str]: yield from _yield_pending_logs() except Exception as exc: error_msg = f"Streaming exception: {exc}" - print(f"[Routes /send_message] {error_msg}") + _logger.error("%s", error_msg) app.log_manager.log_error(error_msg, source="routes") # Reset tools on error - if app.ai_api.has_injected_tools(): - app.ai_api.reset_tools() - if app.responses_api.has_injected_tools(): - app.responses_api.reset_tools() - # Reset tools on active provider if different - if provider not in (app.ai_api, app.responses_api) and provider.has_injected_tools(): - provider.reset_tools() + reset_tools_for_all_providers(app, "error", active_provider=provider) # Yield pending logs so client sees them before error yield from _yield_pending_logs() # Include error details in the payload for transparency @@ -927,7 +906,7 @@ def _yield_pending_logs() -> Iterator[str]: yield json.dumps(error_payload) + "\n" except Exception: fallback_error_msg = "Failed to send detailed error payload; falling back." 
- print(f"[Routes /send_message] {fallback_error_msg}") + _logger.error("%s", fallback_error_msg) app.log_manager.log_error(fallback_error_msg, source="routes") fallback_payload: StreamEventDict = { "type": "final", @@ -1104,14 +1083,7 @@ def send_message() -> ResponseReturnValue: attached_images = [img for img in attached_images_raw if isinstance(img, str)] # Get the provider for this model and update all relevant APIs - if ai_model: - app.ai_api.set_model(ai_model) - app.responses_api.set_model(ai_model) - # Get or create provider instance for this model - provider = get_provider_for_model(app, ai_model) - else: - # Use default OpenAI provider - provider = app.ai_api + provider = get_active_provider(app, ai_model) app.log_manager.log_user_message(message) @@ -1140,17 +1112,6 @@ def send_message() -> ResponseReturnValue: # Store attached images in app context for API access app.current_attached_images = attached_images - def _reset_tools_if_needed(finish_reason: Any) -> None: - """Reset tools if AI finished (not requesting more tool calls).""" - if finish_reason != "tool_calls": - if app.ai_api.has_injected_tools(): - app.ai_api.reset_tools() - if app.responses_api.has_injected_tools(): - app.responses_api.reset_tools() - # Reset tools on active provider if different - if provider not in (app.ai_api, app.responses_api) and provider.has_injected_tools(): - provider.reset_tools() - try: # Route to appropriate API based on model and provider model = provider.get_model() @@ -1163,7 +1124,7 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: break if final_event is None: - _reset_tools_if_needed("error") + reset_tools_for_all_providers(app, "error", active_provider=provider) return AppManager.make_response( message="No final response event produced", status="error", @@ -1185,7 +1146,7 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: app.log_manager.log_ai_response(ai_message) app.log_manager.log_ai_tool_calls(ai_tool_calls) - 
_reset_tools_if_needed(finish_reason) + reset_tools_for_all_providers(app, finish_reason, active_provider=provider) return AppManager.make_response( data=cast( JsonObject, @@ -1205,7 +1166,7 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: ai_tool_calls = _intercept_search_tools(app, ai_tool_calls, provider) finish_reason = getattr(choice, "finish_reason", None) - _reset_tools_if_needed(finish_reason) + reset_tools_for_all_providers(app, finish_reason, active_provider=provider) return AppManager.make_response( data=cast( JsonObject, @@ -1217,7 +1178,7 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: ) ) except Exception as exc: - _reset_tools_if_needed("error") + reset_tools_for_all_providers(app, "error", active_provider=provider) return AppManager.make_response( message=str(exc), status="error", From f8cb3cf8be01349584c0415976593c556e856ebd Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:23:05 +0200 Subject: [PATCH 03/28] Centralize scattered constants into constants.py and config.py Move AI timeout/limit constants from ai_interface.py and trace limits from action_trace_collector.py into client constants.py. Create static/config.py for server-side paths and schema version. 
--- server_tests/test_workspace_management.py | 3 +-- static/client/ai_interface.py | 26 ++++++++---------- static/client/constants.py | 14 ++++++++++ .../client/managers/action_trace_collector.py | 16 +++++------ static/config.py | 27 +++++++++++++++++++ static/workspace_manager.py | 3 +-- 6 files changed, 61 insertions(+), 28 deletions(-) create mode 100644 static/config.py diff --git a/server_tests/test_workspace_management.py b/server_tests/test_workspace_management.py index 1ea2dbdf..e417c2d6 100644 --- a/server_tests/test_workspace_management.py +++ b/server_tests/test_workspace_management.py @@ -9,9 +9,8 @@ from server_tests.test_mocks import CanvasStateDict, MockCanvas from static.client.managers.polygon_type import PolygonType from static.client.utils.polygon_canonicalizer import canonicalize_rectangle +from static.config import CURRENT_WORKSPACE_SCHEMA_VERSION, WORKSPACES_DIR from static.workspace_manager import ( - CURRENT_WORKSPACE_SCHEMA_VERSION, - WORKSPACES_DIR, WorkspaceManager, WorkspaceState, ) diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index 8b09f86a..f3b99b8a 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -36,6 +36,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, cast from browser import document, html, ajax, window, console, aio +from constants import ( + AI_RESPONSE_TIMEOUT_MS, + IMAGE_SIZE_WARNING_BYTES, + MAX_ATTACHED_IMAGES, + REASONING_TIMEOUT_MS, +) from function_registry import FunctionRegistry from process_function_calls import ProcessFunctionCalls from result_processor import ResultProcessor @@ -65,16 +71,6 @@ class AIInterface: markdown_parser (MarkdownParser): Converts markdown text to HTML for rich formatting """ - # Timeout in milliseconds for AI responses (60 seconds for local LLMs) - AI_RESPONSE_TIMEOUT_MS: int = 60000 - # Extended timeout for reasoning models and local LLMs (5 minutes) - REASONING_TIMEOUT_MS: int = 300000 - - # Maximum 
number of images per message - MAX_ATTACHED_IMAGES: int = 5 - # Warning threshold for image size (10MB) - IMAGE_SIZE_WARNING_BYTES: int = 10 * 1024 * 1024 - def __init__(self, canvas: "Canvas") -> None: """Initialize the AI interface with canvas integration and function registry. @@ -327,11 +323,11 @@ def _on_files_selected(self, event: Any) -> None: # Check if we've hit the limit current_count = len(self._attached_images) - remaining = self.MAX_ATTACHED_IMAGES - current_count + remaining = MAX_ATTACHED_IMAGES - current_count if remaining <= 0: self._print_system_message_in_chat( - f"Maximum of {self.MAX_ATTACHED_IMAGES} images per message. Remove some to add more." + f"Maximum of {MAX_ATTACHED_IMAGES} images per message. Remove some to add more." ) file_input.value = "" return @@ -339,7 +335,7 @@ def _on_files_selected(self, event: Any) -> None: files_to_process = min(files.length, remaining) if files.length > remaining: self._print_system_message_in_chat( - f"Only attaching {remaining} of {files.length} images (limit: {self.MAX_ATTACHED_IMAGES})." + f"Only attaching {remaining} of {files.length} images (limit: {MAX_ATTACHED_IMAGES})." ) for i in range(files_to_process): @@ -355,7 +351,7 @@ def _read_and_attach_image(self, file: Any) -> None: """Read an image file and add it to the attached images list.""" try: # Check file size - if hasattr(file, "size") and file.size > self.IMAGE_SIZE_WARNING_BYTES: + if hasattr(file, "size") and file.size > IMAGE_SIZE_WARNING_BYTES: size_mb = file.size / (1024 * 1024) self._print_system_message_in_chat( f"Warning: Image '{file.name}' is {size_mb:.1f}MB. Large images may slow down processing." 
@@ -1828,7 +1824,7 @@ def _start_response_timeout(self, use_reasoning_timeout: bool = False) -> None: try: # Cancel any existing timeout first self._cancel_response_timeout() - timeout_ms = self.REASONING_TIMEOUT_MS if use_reasoning_timeout else self.AI_RESPONSE_TIMEOUT_MS + timeout_ms = REASONING_TIMEOUT_MS if use_reasoning_timeout else AI_RESPONSE_TIMEOUT_MS self._response_timeout_id = window.setTimeout(self._on_response_timeout, timeout_ms) except Exception as e: print(f"Error starting response timeout: {e}") diff --git a/static/client/constants.py b/static/client/constants.py index 339c3e6a..5eec6ea3 100644 --- a/static/client/constants.py +++ b/static/client/constants.py @@ -8,6 +8,8 @@ - Visual Styling: Point sizes, colors, fonts - User Interaction: Click thresholds, zoom factors - Angle Visualization: Arc display and text positioning + - AI Interface: Timeouts, image limits + - Action Trace Collector: Trace storage limits - Performance: Event throttling for smooth interactions Dependencies: @@ -59,6 +61,18 @@ # Default rendering backend used by Canvas when none is specified DEFAULT_RENDERER_MODE: str = "canvas2d" # other options: "svg", "webgl" +# ===== AI INTERFACE CONSTANTS ===== +# Timeouts and limits for AI communication +AI_RESPONSE_TIMEOUT_MS: int = 60000 # Timeout for AI responses (60 seconds) +REASONING_TIMEOUT_MS: int = 300000 # Extended timeout for reasoning models (5 minutes) +MAX_ATTACHED_IMAGES: int = 5 # Maximum number of images per message +IMAGE_SIZE_WARNING_BYTES: int = 10 * 1024 * 1024 # Warning threshold for image size (10MB) + +# ===== ACTION TRACE COLLECTOR CONSTANTS ===== +# Limits for the action trace storage system +MAX_TRACES: int = 100 # Maximum number of traces kept in FIFO store +MAX_RESULT_STR_LEN: int = 500 # Truncation length for result strings in exports + # ===== PERFORMANCE OPTIMIZATION CONSTANTS ===== # Event throttling settings for smooth user experience mousemove_throttle_ms: int = 8 # Mouse movement throttling (8ms = 
~120fps for smooth panning) diff --git a/static/client/managers/action_trace_collector.py b/static/client/managers/action_trace_collector.py index 06847151..ef388c9f 100644 --- a/static/client/managers/action_trace_collector.py +++ b/static/client/managers/action_trace_collector.py @@ -12,15 +12,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from browser import window +from constants import MAX_RESULT_STR_LEN, MAX_TRACES if TYPE_CHECKING: TracedCall = Dict[str, Any] StateDelta = Dict[str, List[str]] ActionTrace = Dict[str, Any] -_MAX_TRACES = 100 -_MAX_RESULT_STR_LEN = 500 - # Functions that are not safe to replay (side-effects outside canvas state). _NON_REPLAYABLE_FUNCTIONS = frozenset( { @@ -82,7 +80,7 @@ def build_trace( # ------------------------------------------------------------------ def store(self, trace: "ActionTrace") -> None: - """Append *trace* to the in-memory store (FIFO, capped at _MAX_TRACES). + """Append *trace* to the in-memory store (FIFO, capped at MAX_TRACES). Full canvas snapshots are stripped from all but the latest trace. """ @@ -95,8 +93,8 @@ def store(self, trace: "ActionTrace") -> None: self._traces.append(trace) # Enforce FIFO cap - if len(self._traces) > _MAX_TRACES: - self._traces = self._traces[-_MAX_TRACES:] + if len(self._traces) > MAX_TRACES: + self._traces = self._traces[-MAX_TRACES:] # ------------------------------------------------------------------ # Retrieval @@ -290,9 +288,9 @@ def _extract_drawable_map(state: Dict[str, Any]) -> Dict[str, Any]: @staticmethod def _truncate(value: Any) -> Any: - """Truncate string values to _MAX_RESULT_STR_LEN for export.""" - if isinstance(value, str) and len(value) > _MAX_RESULT_STR_LEN: - return value[:_MAX_RESULT_STR_LEN] + "..." + """Truncate string values to MAX_RESULT_STR_LEN for export.""" + if isinstance(value, str) and len(value) > MAX_RESULT_STR_LEN: + return value[:MAX_RESULT_STR_LEN] + "..." 
return value @staticmethod diff --git a/static/config.py b/static/config.py new file mode 100644 index 00000000..ec920e25 --- /dev/null +++ b/static/config.py @@ -0,0 +1,27 @@ +""" +MatHud Server-Side Configuration Constants + +Centralized configuration values for server-side modules. +Defines directory paths, schema versions, and other shared settings. + +Categories: + - Workspace Management: Storage directories and schema versioning + - Canvas Snapshots: Screenshot storage paths + +Dependencies: + - os: Path construction for snapshot paths +""" + +from __future__ import annotations + +import os + +# ===== WORKSPACE MANAGEMENT CONSTANTS ===== +# Directory and versioning for workspace persistence +WORKSPACES_DIR: str = "workspaces" +CURRENT_WORKSPACE_SCHEMA_VERSION: int = 1 + +# ===== CANVAS SNAPSHOT CONSTANTS ===== +# Paths for Selenium-captured canvas screenshots +CANVAS_SNAPSHOT_DIR: str = "canvas_snapshots" +CANVAS_SNAPSHOT_PATH: str = os.path.join(CANVAS_SNAPSHOT_DIR, "canvas.png") diff --git a/static/workspace_manager.py b/static/workspace_manager.py index 4144190c..9eeac29f 100644 --- a/static/workspace_manager.py +++ b/static/workspace_manager.py @@ -19,8 +19,7 @@ from datetime import datetime from typing import Dict, List, Optional, TypedDict, Union, cast -WORKSPACES_DIR = "workspaces" -CURRENT_WORKSPACE_SCHEMA_VERSION = 1 +from static.config import CURRENT_WORKSPACE_SCHEMA_VERSION, WORKSPACES_DIR JsonPrimitive = Union[str, int, float, bool, None] JsonValue = Union[JsonPrimitive, Dict[str, "JsonValue"], List["JsonValue"]] From c17f41a897beba02e104870f87996b86ad115a48 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:23:26 +0200 Subject: [PATCH 04/28] Extract shared env loading into static/env_config.py Replace duplicated load_dotenv patterns in openai_api_base.py, tool_search_service.py, and anthropic_api.py with shared get_api_key() utility. Also remove unused variable in tool_search_service.py. 
--- server_tests/test_openai_api_base.py | 18 ++----- static/env_config.py | 72 ++++++++++++++++++++++++++++ static/openai_api_base.py | 15 +----- static/providers/anthropic_api.py | 13 +---- static/tool_search_service.py | 19 +------- 5 files changed, 83 insertions(+), 54 deletions(-) create mode 100644 static/env_config.py diff --git a/server_tests/test_openai_api_base.py b/server_tests/test_openai_api_base.py index 8bb9d750..599c4f5b 100644 --- a/server_tests/test_openai_api_base.py +++ b/server_tests/test_openai_api_base.py @@ -475,20 +475,12 @@ def test_api_key_from_environment(self) -> None: finally: os.environ.pop("OPENAI_API_KEY", None) - @patch("static.openai_api_base.load_dotenv") - @patch("static.openai_api_base.os.path.exists") - def test_api_key_missing_returns_placeholder(self, mock_exists: Mock, mock_load_dotenv: Mock) -> None: + @patch("static.openai_api_base.get_api_key", return_value="") + def test_api_key_missing_returns_placeholder(self, mock_get_key: Mock) -> None: """Test missing API key returns placeholder instead of crashing.""" - # Mock .env file doesn't exist - mock_exists.return_value = False - # Remove API key from environment - original = os.environ.pop("OPENAI_API_KEY", None) - try: - result = OpenAIAPIBase._initialize_api_key() - self.assertEqual(result, "not-configured") - finally: - if original: - os.environ["OPENAI_API_KEY"] = original + result = OpenAIAPIBase._initialize_api_key() + self.assertEqual(result, "not-configured") + mock_get_key.assert_called_once_with("OPENAI_API_KEY", required=False, fallback="") if __name__ == "__main__": diff --git a/static/env_config.py b/static/env_config.py new file mode 100644 index 00000000..aea0f506 --- /dev/null +++ b/static/env_config.py @@ -0,0 +1,72 @@ +"""Centralized environment variable loading for the MatHud backend. + +Provides helpers that consolidate the duplicated .env discovery logic +(project root then parent directory) used by multiple API modules. 
+""" + +from __future__ import annotations + +import os +from typing import Optional + +from dotenv import load_dotenv + + +def load_env_files() -> None: + """Load .env files from the project root and its parent directory. + + The project-root .env is loaded first so its values take precedence. + A parent-directory .env is loaded afterwards only when the file exists, + which is useful when API keys are stored one level above the repository. + + Calling this function multiple times is safe; ``python-dotenv`` will + not overwrite variables that are already present in ``os.environ``. + """ + load_dotenv() + parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") + if os.path.exists(parent_env): + load_dotenv(parent_env) + + +def get_api_key( + name: str, + *, + required: bool = True, + fallback: Optional[str] = None, +) -> str: + """Return an API key from the environment, loading .env files first. + + The function checks ``os.environ`` before touching disk so that + explicitly-set variables are returned immediately. + + Args: + name: Environment variable name (e.g. ``"OPENAI_API_KEY"``). + required: When *True* and the key is missing, raise ``ValueError``. + When *False*, return *fallback* instead. + fallback: Value to return when the key is absent and *required* is + ``False``. Defaults to ``None``, but the return type is always + ``str`` — callers that pass ``required=False`` should supply a + non-``None`` fallback or handle the empty-string case. + + Returns: + The API key string. + + Raises: + ValueError: If *required* is True and the key cannot be found. 
+ """ + # Fast path: already in environment (avoids disk I/O) + api_key = os.getenv(name) + if api_key: + return api_key + + # Load .env files and retry + load_env_files() + api_key = os.getenv(name) + + if api_key: + return api_key + + if required: + raise ValueError(f"{name} not found in environment or .env file") + + return fallback if fallback is not None else "" diff --git a/static/openai_api_base.py b/static/openai_api_base.py index 9951396b..e54ea643 100644 --- a/static/openai_api_base.py +++ b/static/openai_api_base.py @@ -16,10 +16,10 @@ from types import SimpleNamespace from typing import Any, Dict, List, Literal, Optional, Union -from dotenv import load_dotenv from openai import OpenAI from static.ai_model import AIModel +from static.env_config import get_api_key from static.canvas_state_summarizer import compare_canvas_states from static.functions_definitions import FUNCTIONS, FunctionDefinition from static.token_estimation import estimate_tokens_from_bytes @@ -83,21 +83,10 @@ def _initialize_api_key() -> str: to start with other providers configured. Actual OpenAI API calls will fail with an authentication error in that case. """ - api_key = os.getenv("OPENAI_API_KEY") - if api_key: - return api_key - - # Load from project .env, then parent .env (API keys may live outside repo) - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) - api_key = os.getenv("OPENAI_API_KEY") - + api_key = get_api_key("OPENAI_API_KEY", required=False, fallback="") if not api_key: logging.getLogger("mathud").warning("OPENAI_API_KEY not found. 
OpenAI models will be unavailable.") return "not-configured" - return api_key def __init__( diff --git a/static/providers/anthropic_api.py b/static/providers/anthropic_api.py index a2a9fa60..0082a4c4 100644 --- a/static/providers/anthropic_api.py +++ b/static/providers/anthropic_api.py @@ -9,14 +9,12 @@ import json import logging -import os from collections.abc import Iterator, Sequence from types import SimpleNamespace from typing import Any, Dict, List, Optional -from dotenv import load_dotenv - from static.ai_model import AIModel +from static.env_config import get_api_key from static.functions_definitions import FunctionDefinition from static.openai_api_base import MessageDict, OpenAIAPIBase, StreamEvent, ToolMode from static.providers import PROVIDER_ANTHROPIC, ProviderRegistry @@ -26,14 +24,7 @@ def _get_anthropic_api_key() -> str: """Get the Anthropic API key from environment.""" - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) - api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key: - raise ValueError("ANTHROPIC_API_KEY not found in environment or .env file") - return api_key + return get_api_key("ANTHROPIC_API_KEY") class AnthropicAPI(OpenAIAPIBase): diff --git a/static/tool_search_service.py b/static/tool_search_service.py index 9b52be56..921923c5 100644 --- a/static/tool_search_service.py +++ b/static/tool_search_service.py @@ -25,10 +25,10 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, TypedDict -from dotenv import load_dotenv from openai import OpenAI from static.ai_model import AIModel +from static.env_config import get_api_key from static.functions_definitions import FUNCTIONS, FunctionDefinition _logger = logging.getLogger("mathud") @@ -416,21 +416,7 @@ def client(self, value: OpenAI) -> None: @staticmethod def _initialize_api_key() -> str: """Initialize the OpenAI API key from environment or .env file.""" - 
api_key = os.getenv("OPENAI_API_KEY") - if api_key: - return api_key - - # Load from project .env, then parent .env (API keys may live outside repo) - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) - api_key = os.getenv("OPENAI_API_KEY") - - if not api_key: - raise ValueError("OPENAI_API_KEY not found in environment or .env file") - - return api_key + return get_api_key("OPENAI_API_KEY") @staticmethod def get_all_tools() -> List[FunctionDefinition]: @@ -549,7 +535,6 @@ def search_tools_local( scores: Dict[str, float] = defaultdict(float) # 1. Exact tool name match - query_lower = query.lower().strip() for token in query_tokens: if token in _ALL_TOOL_NAMES: scores[token] += 8.0 From 7bd164102c4edbedce59f13653447ce8e6b19fdc Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:23:32 +0200 Subject: [PATCH 05/28] Standardize error handling and extract route helper functions Replace print() error reporting with logging in routes.py. Add error visibility to bare except blocks. Extract duplicated tool reset and provider model logic into static/route_helpers.py. 
--- server_tests/test_routes.py | 3 +- static/client/managers/drawables_container.py | 3 +- static/route_helpers.py | 88 +++++++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 static/route_helpers.py diff --git a/server_tests/test_routes.py b/server_tests/test_routes.py index e765f640..f454f24c 100644 --- a/server_tests/test_routes.py +++ b/server_tests/test_routes.py @@ -7,7 +7,8 @@ from unittest.mock import Mock, patch from static.app_manager import AppManager, MatHudFlask -from static.routes import CANVAS_SNAPSHOT_PATH, save_canvas_snapshot_from_data_url +from static.config import CANVAS_SNAPSHOT_PATH +from static.routes import save_canvas_snapshot_from_data_url from static.openai_completions_api import OpenAIChatCompletionsAPI from static.openai_responses_api import OpenAIResponsesAPI diff --git a/static/client/managers/drawables_container.py b/static/client/managers/drawables_container.py index a1b08b04..53d5391c 100644 --- a/static/client/managers/drawables_container.py +++ b/static/client/managers/drawables_container.py @@ -76,7 +76,8 @@ def _is_renderable(self, drawable: "Drawable") -> bool: renderable_attr = getattr(drawable, "is_renderable", True) try: return bool(renderable_attr) - except Exception: + except Exception as e: + print(f"[DrawablesContainer] Error in layering comparison: {e}") return True def _apply_layering(self, colored: List["Drawable"], others: List["Drawable"]) -> List["Drawable"]: diff --git a/static/route_helpers.py b/static/route_helpers.py new file mode 100644 index 00000000..830030a4 --- /dev/null +++ b/static/route_helpers.py @@ -0,0 +1,88 @@ +"""Route helper functions for MatHud Flask routes. + +Extracts duplicated logic from routes.py into reusable helpers for +provider model synchronization and tool lifecycle management. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from static.app_manager import MatHudFlask + from static.openai_api_base import OpenAIAPIBase + + +def reset_tools_for_all_providers( + app: MatHudFlask, + finish_reason: object, + *, + active_provider: OpenAIAPIBase | None = None, +) -> None: + """Reset injected tools on all provider instances when a conversation turn ends. + + Called when the AI finish reason is anything other than ``"tool_calls"`` + (i.e. the model is done calling tools) or on error. Resets the built-in + OpenAI APIs (``app.ai_api`` and ``app.responses_api``) plus the + *active_provider* if it differs from the built-in ones. + + Args: + app: The Flask application instance carrying provider references. + finish_reason: The finish reason string from the AI response. + When equal to ``"tool_calls"`` this function is a no-op. + active_provider: The provider that handled the current request. + If ``None`` only the two built-in OpenAI providers are checked. + """ + if finish_reason == "tool_calls": + return + + if app.ai_api.has_injected_tools(): + app.ai_api.reset_tools() + if app.responses_api.has_injected_tools(): + app.responses_api.reset_tools() + + if ( + active_provider is not None + and active_provider not in (app.ai_api, app.responses_api) + and active_provider.has_injected_tools() + ): + active_provider.reset_tools() + + +def update_all_provider_models(app: MatHudFlask, model_id: str) -> None: + """Synchronize the model selection across the built-in OpenAI providers. + + Both ``app.ai_api`` and ``app.responses_api`` are updated so that + whichever one handles the next request uses the correct model. + + Args: + app: The Flask application instance. + model_id: The model identifier string (e.g. ``"gpt-4.1"``). 
+ """ + app.ai_api.set_model(model_id) + app.responses_api.set_model(model_id) + + +def get_active_provider(app: MatHudFlask, model_id: str | None) -> OpenAIAPIBase: + """Return the correct provider for a given model, updating built-in APIs. + + When *model_id* is provided the built-in OpenAI providers are + synchronized via :func:`update_all_provider_models` and the + appropriate provider instance is resolved (creating one lazily if + needed). When *model_id* is ``None`` the default ``app.ai_api`` + provider is returned. + + Args: + app: The Flask application instance. + model_id: The model identifier, or ``None`` for the default. + + Returns: + The provider instance that should handle the current request. + """ + from static.routes import get_provider_for_model + + if model_id: + update_all_provider_models(app, model_id) + return cast("OpenAIAPIBase", get_provider_for_model(app, model_id)) + + return cast("OpenAIAPIBase", app.ai_api) From 88145ef2bca7de39cf5d166206b559fa913a28f1 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:23:36 +0200 Subject: [PATCH 06/28] Pin all dependencies to exact versions Replace >= bounds and unpinned packages with == pins matching the currently installed versions for reproducible builds. 
--- requirements.txt | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3cc9e5be..4f056461 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,19 @@ -openai>=1.60.0 -anthropic>=0.40.0 -flask -flask-session -cachelib -python-dotenv -selenium -geckodriver-autoinstaller -requests -types-requests +openai==2.8.1 +anthropic==0.76.0 +flask==3.1.2 +flask-session==0.8.0 +cachelib==0.13.0 +python-dotenv==1.2.1 +selenium==4.38.0 +geckodriver-autoinstaller==0.1.0 +requests==2.32.5 +types-requests==2.32.4.20250913 mypy==1.11.2 -ruff>=0.8.0 -kokoro>=0.9.4 -soundfile -numpy +ruff==0.15.0 +kokoro==0.9.4 +soundfile==0.13.1 +numpy==2.4.1 # CLI dependencies -click>=8.0.0 -webdriver-manager>=4.0.0 -psutil>=5.9.0 +click==8.3.0 +webdriver-manager==4.0.2 +psutil==7.2.2 From b66f9d62063e62516f5c444225e06062b34f761e Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:38:04 +0200 Subject: [PATCH 07/28] Fix client test imports for moved constants Update test_action_trace_collector.py to import MAX_TRACES from constants instead of the removed _MAX_TRACES private variable. Update test_image_attachment.py to import MAX_ATTACHED_IMAGES and IMAGE_SIZE_WARNING_BYTES from constants instead of accessing them as class attributes. Use Optional[] syntax in route_helpers.py for consistency with codebase conventions. 
--- static/client/client_tests/test_action_trace_collector.py | 7 ++++--- static/client/client_tests/test_image_attachment.py | 8 ++++---- static/route_helpers.py | 6 +++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/static/client/client_tests/test_action_trace_collector.py b/static/client/client_tests/test_action_trace_collector.py index 944a878f..58d8fd01 100644 --- a/static/client/client_tests/test_action_trace_collector.py +++ b/static/client/client_tests/test_action_trace_collector.py @@ -6,7 +6,8 @@ import unittest from typing import Any, Dict, List -from managers.action_trace_collector import ActionTraceCollector, _MAX_TRACES +from constants import MAX_TRACES +from managers.action_trace_collector import ActionTraceCollector class TestComputeStateDelta(unittest.TestCase): @@ -140,9 +141,9 @@ def test_store_strips_snapshots_from_older(self) -> None: self.assertIn("canvas_state_after", traces[1]) def test_store_cap(self) -> None: - for i in range(_MAX_TRACES + 20): + for i in range(MAX_TRACES + 20): self.collector.store(self._make_trace(trace_id=f"t{i}")) - self.assertEqual(len(self.collector.get_traces()), _MAX_TRACES) + self.assertEqual(len(self.collector.get_traces()), MAX_TRACES) def test_clear(self) -> None: self.collector.store(self._make_trace()) diff --git a/static/client/client_tests/test_image_attachment.py b/static/client/client_tests/test_image_attachment.py index ac7c11cc..52110dad 100644 --- a/static/client/client_tests/test_image_attachment.py +++ b/static/client/client_tests/test_image_attachment.py @@ -15,6 +15,8 @@ from typing import Any, Dict, List, Optional from unittest.mock import MagicMock +from constants import IMAGE_SIZE_WARNING_BYTES, MAX_ATTACHED_IMAGES + class MockCanvas: """Mock canvas for testing.""" @@ -85,8 +87,6 @@ def setUp(self) -> None: # so we'll test the logic patterns directly self.ai = MagicMock(spec=AIInterface) self.ai._attached_images = [] - self.ai.MAX_ATTACHED_IMAGES = 5 - 
self.ai.IMAGE_SIZE_WARNING_BYTES = 10 * 1024 * 1024 def test_initial_state_empty(self) -> None: """Test attached images starts empty.""" @@ -127,11 +127,11 @@ def test_clear_images(self) -> None: def test_max_images_constant(self) -> None: """Test maximum images constant is set.""" - self.assertEqual(self.ai.MAX_ATTACHED_IMAGES, 5) + self.assertEqual(MAX_ATTACHED_IMAGES, 5) def test_image_size_warning_constant(self) -> None: """Test image size warning threshold is 10MB.""" - self.assertEqual(self.ai.IMAGE_SIZE_WARNING_BYTES, 10 * 1024 * 1024) + self.assertEqual(IMAGE_SIZE_WARNING_BYTES, 10 * 1024 * 1024) class TestImageValidation(unittest.TestCase): diff --git a/static/route_helpers.py b/static/route_helpers.py index 830030a4..bddf4357 100644 --- a/static/route_helpers.py +++ b/static/route_helpers.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast if TYPE_CHECKING: from static.app_manager import MatHudFlask @@ -17,7 +17,7 @@ def reset_tools_for_all_providers( app: MatHudFlask, finish_reason: object, *, - active_provider: OpenAIAPIBase | None = None, + active_provider: Optional[OpenAIAPIBase] = None, ) -> None: """Reset injected tools on all provider instances when a conversation turn ends. @@ -63,7 +63,7 @@ def update_all_provider_models(app: MatHudFlask, model_id: str) -> None: app.responses_api.set_model(model_id) -def get_active_provider(app: MatHudFlask, model_id: str | None) -> OpenAIAPIBase: +def get_active_provider(app: MatHudFlask, model_id: Optional[str]) -> OpenAIAPIBase: """Return the correct provider for a given model, updating built-in APIs. 
When *model_id* is provided the built-in OpenAI providers are From 98480221762aac50e7124568d2ec011cdf0cc316 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:41:40 +0200 Subject: [PATCH 08/28] Extract BaseRendererTelemetry from duplicated SVG/Canvas2D telemetry Move shared telemetry logic (phase tracking, per-drawable counters, frame recording, adapter events) into base_telemetry.py. Canvas2D overrides only _new_drawable_bucket() for its two extra counters. Removes ~270 lines of duplication. --- static/client/rendering/base_telemetry.py | 175 +++++++++++++++++++ static/client/rendering/canvas2d_renderer.py | 150 ++-------------- static/client/rendering/svg_renderer.py | 131 +------------- 3 files changed, 191 insertions(+), 265 deletions(-) create mode 100644 static/client/rendering/base_telemetry.py diff --git a/static/client/rendering/base_telemetry.py b/static/client/rendering/base_telemetry.py new file mode 100644 index 00000000..9d012e5f --- /dev/null +++ b/static/client/rendering/base_telemetry.py @@ -0,0 +1,175 @@ +"""Base telemetry class shared by SVG and Canvas 2D renderers. + +Provides common timing, counting, and snapshot logic so that each renderer's +telemetry subclass only needs to define its per-drawable bucket schema. +""" + +from __future__ import annotations + +import time +from typing import Any, Dict + +from browser import window + + +class BaseRendererTelemetry: + """Performance telemetry collector base for rendering backends. + + Tracks timing metrics for plan building and application, cache statistics, + and per-drawable performance data. Subclasses override ``_new_drawable_bucket`` + to add renderer-specific counters. + + Attributes: + _phase_totals: Cumulative timing for each rendering phase. + _phase_counts: Operation counts for each phase. + _per_drawable: Per-drawable type timing breakdown. + _adapter_events: Event counts from the primitive adapter. 
+ _frames: Total frames rendered since last reset. + """ + + def __init__(self) -> None: + """Initialize telemetry with zeroed counters.""" + self.reset() + + def reset(self) -> None: + """Reset all telemetry counters to zero.""" + self._phase_totals: Dict[str, float] = { + "plan_build_ms": 0.0, + "plan_apply_ms": 0.0, + "cartesian_plan_build_ms": 0.0, + "cartesian_plan_apply_ms": 0.0, + } + self._phase_counts: Dict[str, int] = { + "plan_build_count": 0, + "plan_apply_count": 0, + "cartesian_plan_count": 0, + "plan_miss_count": 0, + "plan_skip_count": 0, + } + self._per_drawable: Dict[str, Dict[str, float]] = {} + self._adapter_events: Dict[str, int] = {} + self._frames: int = 0 + self._max_batch_depth: int = 0 + + def begin_frame(self) -> None: + """Signal the start of a new frame for counting purposes.""" + self._frames += 1 + + def end_frame(self) -> None: + """Signal the end of a frame (currently no-op).""" + pass + + def _now(self) -> float: + """Get current timestamp in milliseconds using performance.now() if available.""" + try: + perf = getattr(window, "performance", None) + if perf is not None: + return float(perf.now()) + except Exception: + pass + return time.time() * 1000.0 + + def mark_time(self) -> float: + """Record and return the current timestamp for duration measurement.""" + return self._now() + + def elapsed_since(self, start: float) -> float: + """Calculate milliseconds elapsed since a marked timestamp.""" + return max(self._now() - start, 0.0) + + # ------------------------------------------------------------------ + # Per-drawable bucket + # ------------------------------------------------------------------ + + def _new_drawable_bucket(self) -> Dict[str, float]: + """Return a fresh per-drawable counters dict. + + Subclasses may override to add renderer-specific keys. 
+ """ + return { + "plan_build_ms": 0.0, + "plan_apply_ms": 0.0, + "plan_build_count": 0, + "plan_apply_count": 0, + "plan_miss_count": 0, + "plan_skip_count": 0, + } + + def _drawable_bucket(self, name: str) -> Dict[str, float]: + """Get or create the telemetry bucket for a drawable type.""" + bucket = self._per_drawable.get(name) + if bucket is None: + bucket = self._new_drawable_bucket() + self._per_drawable[name] = bucket + return bucket + + # ------------------------------------------------------------------ + # Recording helpers + # ------------------------------------------------------------------ + + def record_plan_build(self, name: str, duration_ms: float, *, cartesian: bool = False) -> None: + """Record time spent building a render plan.""" + self._phase_totals["plan_build_ms"] += duration_ms + self._phase_counts["plan_build_count"] += 1 + bucket = self._drawable_bucket(name) + bucket["plan_build_ms"] += duration_ms + bucket["plan_build_count"] += 1 + if cartesian: + self._phase_totals["cartesian_plan_build_ms"] += duration_ms + self._phase_counts["cartesian_plan_count"] += 1 + + def record_plan_apply(self, name: str, duration_ms: float, *, cartesian: bool = False) -> None: + """Record time spent applying a render plan.""" + self._phase_totals["plan_apply_ms"] += duration_ms + self._phase_counts["plan_apply_count"] += 1 + bucket = self._drawable_bucket(name) + bucket["plan_apply_ms"] += duration_ms + bucket["plan_apply_count"] += 1 + if cartesian: + self._phase_totals["cartesian_plan_apply_ms"] += duration_ms + + def record_plan_miss(self, name: str) -> None: + """Record when a drawable could not be rendered via a plan.""" + self._phase_counts["plan_miss_count"] += 1 + bucket = self._drawable_bucket(name) + bucket["plan_miss_count"] += 1 + + def record_plan_skip(self, name: str) -> None: + """Record when a plan was skipped due to being off-screen.""" + self._phase_counts["plan_skip_count"] += 1 + bucket = self._drawable_bucket(name) + 
bucket["plan_skip_count"] += 1 + + def record_adapter_event(self, name: str, amount: int = 1) -> None: + """Record an event from the primitive adapter.""" + self._adapter_events[name] = self._adapter_events.get(name, 0) + amount + + def track_batch_depth(self, depth: int) -> None: + """Track the maximum nested batch depth seen.""" + if depth > self._max_batch_depth: + self._max_batch_depth = depth + + # ------------------------------------------------------------------ + # Snapshot / drain + # ------------------------------------------------------------------ + + def snapshot(self) -> Dict[str, Any]: + """Get a copy of all current telemetry data without resetting.""" + adapter_events = dict(self._adapter_events) + if self._max_batch_depth: + adapter_events["max_batch_depth"] = self._max_batch_depth + per_drawable = {name: dict(bucket) for name, bucket in self._per_drawable.items()} + phase = dict(self._phase_totals) + phase.update(self._phase_counts) + return { + "frames": self._frames, + "phase": phase, + "per_drawable": per_drawable, + "adapter_events": adapter_events, + } + + def drain(self) -> Dict[str, Any]: + """Get all telemetry data and reset counters to zero.""" + snapshot = self.snapshot() + self.reset() + return snapshot diff --git a/static/client/rendering/canvas2d_renderer.py b/static/client/rendering/canvas2d_renderer.py index 978ff7d7..0925c635 100644 --- a/static/client/rendering/canvas2d_renderer.py +++ b/static/client/rendering/canvas2d_renderer.py @@ -20,11 +20,11 @@ from __future__ import annotations -import time from typing import Any, Callable, Dict, Optional, Tuple from browser import document, html, window +from rendering.base_telemetry import BaseRendererTelemetry from rendering.style_manager import get_renderer_style from rendering.interfaces import RendererProtocol from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter @@ -36,151 +36,21 @@ ) -class Canvas2DTelemetry: +class Canvas2DTelemetry(BaseRendererTelemetry): 
"""Performance telemetry collector for Canvas 2D rendering. - Tracks timing metrics for plan building and application, cache statistics, - and per-drawable performance data. Useful for identifying rendering - bottlenecks and optimizing performance. - - Attributes: - _phase_totals: Cumulative timing for each rendering phase. - _phase_counts: Operation counts for each phase. - _per_drawable: Per-drawable type timing breakdown. - _adapter_events: Event counts from the primitive adapter. - _frames: Total frames rendered since last reset. + Inherits common timing, counting, and snapshot logic from + ``BaseRendererTelemetry``. Overrides ``_new_drawable_bucket`` to include + legacy-render counters specific to the Canvas 2D pipeline. """ - def __init__(self) -> None: - """Initialize telemetry with zeroed counters.""" - self.reset() - - def reset(self) -> None: - """Reset all telemetry counters to zero.""" - self._phase_totals: Dict[str, float] = { - "plan_build_ms": 0.0, - "plan_apply_ms": 0.0, - "cartesian_plan_build_ms": 0.0, - "cartesian_plan_apply_ms": 0.0, - } - self._phase_counts: Dict[str, int] = { - "plan_build_count": 0, - "plan_apply_count": 0, - "cartesian_plan_count": 0, - "plan_miss_count": 0, - "plan_skip_count": 0, - } - self._per_drawable: Dict[str, Dict[str, float]] = {} - self._adapter_events: Dict[str, int] = {} - self._frames: int = 0 - self._max_batch_depth: int = 0 - - def begin_frame(self) -> None: - """Signal the start of a new frame for counting purposes.""" - self._frames += 1 - - def end_frame(self) -> None: - """Signal the end of a frame (currently no-op).""" - pass - - def _now(self) -> float: - """Get current timestamp in milliseconds using performance.now() if available.""" - try: - perf = getattr(window, "performance", None) - if perf is not None: - return float(perf.now()) - except Exception: - pass - return time.time() * 1000.0 - - def mark_time(self) -> float: - """Record and return the current timestamp for duration measurement.""" - 
return self._now() - - def elapsed_since(self, start: float) -> float: - """Calculate milliseconds elapsed since a marked timestamp.""" - return max(self._now() - start, 0.0) - - def _drawable_bucket(self, name: str) -> Dict[str, float]: - """Get or create the telemetry bucket for a drawable type.""" - bucket = self._per_drawable.get(name) - if bucket is None: - bucket = { - "plan_build_ms": 0.0, - "plan_apply_ms": 0.0, - "legacy_render_ms": 0.0, - "plan_build_count": 0, - "plan_apply_count": 0, - "legacy_render_count": 0, - "plan_miss_count": 0, - "plan_skip_count": 0, - } - self._per_drawable[name] = bucket + def _new_drawable_bucket(self) -> Dict[str, float]: + """Return a fresh per-drawable counters dict with legacy-render keys.""" + bucket = super()._new_drawable_bucket() + bucket["legacy_render_ms"] = 0.0 + bucket["legacy_render_count"] = 0 return bucket - def record_plan_build(self, name: str, duration_ms: float, *, cartesian: bool = False) -> None: - """Record time spent building a render plan.""" - self._phase_totals["plan_build_ms"] += duration_ms - self._phase_counts["plan_build_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_build_ms"] += duration_ms - bucket["plan_build_count"] += 1 - if cartesian: - self._phase_totals["cartesian_plan_build_ms"] += duration_ms - self._phase_counts["cartesian_plan_count"] += 1 - - def record_plan_apply(self, name: str, duration_ms: float, *, cartesian: bool = False) -> None: - """Record time spent applying a render plan to the canvas.""" - self._phase_totals["plan_apply_ms"] += duration_ms - self._phase_counts["plan_apply_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_apply_ms"] += duration_ms - bucket["plan_apply_count"] += 1 - if cartesian: - self._phase_totals["cartesian_plan_apply_ms"] += duration_ms - - def record_plan_miss(self, name: str) -> None: - """Record when a drawable could not be rendered via a plan.""" - self._phase_counts["plan_miss_count"] += 1 - bucket = 
self._drawable_bucket(name) - bucket["plan_miss_count"] += 1 - - def record_plan_skip(self, name: str) -> None: - """Record when a plan was skipped due to being off-screen.""" - self._phase_counts["plan_skip_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_skip_count"] += 1 - - def record_adapter_event(self, name: str, amount: int = 1) -> None: - """Record an event from the primitive adapter.""" - self._adapter_events[name] = self._adapter_events.get(name, 0) + amount - - def track_batch_depth(self, depth: int) -> None: - """Track the maximum nested batch depth seen.""" - if depth > self._max_batch_depth: - self._max_batch_depth = depth - - def snapshot(self) -> Dict[str, Any]: - """Get a copy of all current telemetry data without resetting.""" - adapter_events = dict(self._adapter_events) - if self._max_batch_depth: - adapter_events["max_batch_depth"] = self._max_batch_depth - per_drawable = {name: dict(bucket) for name, bucket in self._per_drawable.items()} - phase = dict(self._phase_totals) - phase.update(self._phase_counts) - return { - "frames": self._frames, - "phase": phase, - "per_drawable": per_drawable, - "adapter_events": adapter_events, - } - - def drain(self) -> Dict[str, Any]: - """Get all telemetry data and reset counters to zero.""" - snapshot = self.snapshot() - self.reset() - return snapshot - class Canvas2DRenderer(RendererProtocol): """Renderer using the HTML5 Canvas 2D API. 
diff --git a/static/client/rendering/svg_renderer.py b/static/client/rendering/svg_renderer.py index 0f43ca15..9410b137 100644 --- a/static/client/rendering/svg_renderer.py +++ b/static/client/rendering/svg_renderer.py @@ -20,7 +20,6 @@ from __future__ import annotations -import time from typing import Any, Callable, Dict, Optional, Set, Tuple from browser import document, svg, window @@ -33,137 +32,19 @@ ) from rendering.interfaces import RendererProtocol from rendering.style_manager import get_renderer_style +from rendering.base_telemetry import BaseRendererTelemetry from rendering.svg_primitive_adapter import SvgPrimitiveAdapter -class SvgTelemetry: +class SvgTelemetry(BaseRendererTelemetry): """Performance telemetry collector for SVG rendering. - Tracks timing metrics for plan building and application, cache statistics, - and per-drawable performance data. - - Attributes: - _phase_totals: Cumulative timing for each rendering phase. - _phase_counts: Operation counts for each phase. - _per_drawable: Per-drawable type timing breakdown. - _adapter_events: Event counts from the primitive adapter. - _frames: Total frames rendered since last reset. + Inherits all common timing, counting, and snapshot logic from + ``BaseRendererTelemetry``. No SVG-specific overrides are needed at + this time, but the subclass is kept so the renderer can evolve its + telemetry independently if required. 
""" - def __init__(self) -> None: - """Initialize telemetry with zeroed counters.""" - self.reset() - - def reset(self) -> None: - """Reset all telemetry counters to zero.""" - self._phase_totals: Dict[str, float] = { - "plan_build_ms": 0.0, - "plan_apply_ms": 0.0, - "cartesian_plan_build_ms": 0.0, - "cartesian_plan_apply_ms": 0.0, - } - self._phase_counts: Dict[str, int] = { - "plan_build_count": 0, - "plan_apply_count": 0, - "cartesian_plan_count": 0, - "plan_miss_count": 0, - "plan_skip_count": 0, - } - self._per_drawable: Dict[str, Dict[str, float]] = {} - self._adapter_events: Dict[str, int] = {} - self._frames: int = 0 - self._max_batch_depth: int = 0 - - def begin_frame(self) -> None: - self._frames += 1 - - def end_frame(self) -> None: - pass - - def _now(self) -> float: - try: - perf = getattr(window, "performance", None) - if perf is not None: - return float(perf.now()) - except Exception: - pass - return time.time() * 1000.0 - - def mark_time(self) -> float: - return self._now() - - def elapsed_since(self, start: float) -> float: - return max(self._now() - start, 0.0) - - def _drawable_bucket(self, name: str) -> Dict[str, float]: - bucket = self._per_drawable.get(name) - if bucket is None: - bucket = { - "plan_build_ms": 0.0, - "plan_apply_ms": 0.0, - "plan_build_count": 0, - "plan_apply_count": 0, - "plan_miss_count": 0, - "plan_skip_count": 0, - } - self._per_drawable[name] = bucket - return bucket - - def record_plan_build(self, name: str, duration_ms: float, *, cartesian: bool = False) -> None: - self._phase_totals["plan_build_ms"] += duration_ms - self._phase_counts["plan_build_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_build_ms"] += duration_ms - bucket["plan_build_count"] += 1 - if cartesian: - self._phase_totals["cartesian_plan_build_ms"] += duration_ms - self._phase_counts["cartesian_plan_count"] += 1 - - def record_plan_apply(self, name: str, duration_ms: float, *, cartesian: bool = False) -> None: - 
self._phase_totals["plan_apply_ms"] += duration_ms - self._phase_counts["plan_apply_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_apply_ms"] += duration_ms - bucket["plan_apply_count"] += 1 - if cartesian: - self._phase_totals["cartesian_plan_apply_ms"] += duration_ms - - def record_plan_miss(self, name: str) -> None: - self._phase_counts["plan_miss_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_miss_count"] += 1 - - def record_plan_skip(self, name: str) -> None: - self._phase_counts["plan_skip_count"] += 1 - bucket = self._drawable_bucket(name) - bucket["plan_skip_count"] += 1 - - def record_adapter_event(self, name: str, amount: int = 1) -> None: - self._adapter_events[name] = self._adapter_events.get(name, 0) + amount - - def track_batch_depth(self, depth: int) -> None: - if depth > self._max_batch_depth: - self._max_batch_depth = depth - - def snapshot(self) -> Dict[str, Any]: - adapter_events = dict(self._adapter_events) - if self._max_batch_depth: - adapter_events["max_batch_depth"] = self._max_batch_depth - per_drawable = {name: dict(bucket) for name, bucket in self._per_drawable.items()} - phase = dict(self._phase_totals) - phase.update(self._phase_counts) - return { - "frames": self._frames, - "phase": phase, - "per_drawable": per_drawable, - "adapter_events": adapter_events, - } - - def drain(self) -> Dict[str, Any]: - snapshot = self.snapshot() - self.reset() - return snapshot - class SvgRenderer(RendererProtocol): """Renderer using SVG DOM elements. From a6b494e92542d215b635a1657b44f5cd9c01e406 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:41:58 +0200 Subject: [PATCH 09/28] Extract BaseDrawableManager from shared drawable manager patterns Create a base class that captures the common constructor boilerplate, attribute storage, edit-policy lookup, and name-based retrieval shared across all drawable managers. 
Migrate PointManager and LabelManager to inherit from it, replacing duplicated __init__ assignments and lookup loops with super().__init__() and _get_by_name() delegation. --- .../client/managers/base_drawable_manager.py | 70 +++++++++++++++++++ static/client/managers/label_manager.py | 32 ++++----- static/client/managers/point_manager.py | 30 ++++---- 3 files changed, 102 insertions(+), 30 deletions(-) create mode 100644 static/client/managers/base_drawable_manager.py diff --git a/static/client/managers/base_drawable_manager.py b/static/client/managers/base_drawable_manager.py new file mode 100644 index 00000000..403aef8a --- /dev/null +++ b/static/client/managers/base_drawable_manager.py @@ -0,0 +1,70 @@ +""" +Base class for drawable managers. + +Captures the common constructor pattern, attribute storage, edit-policy +lookup, and name-based retrieval shared by every specialized manager. +Subclasses supply the drawable-type name and any additional constructor +parameters they need. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from managers.edit_policy import DrawableEditPolicy, get_drawable_edit_policy + +if TYPE_CHECKING: + from drawables.drawable import Drawable + from canvas import Canvas + from managers.drawables_container import DrawablesContainer + from managers.drawable_dependency_manager import DrawableDependencyManager + from managers.drawable_manager_proxy import DrawableManagerProxy + from name_generator.drawable import DrawableNameGenerator + + +class BaseDrawableManager: + """Shared foundation for specialized drawable managers. + + Stores the five common dependencies every manager receives and + automatically resolves the edit policy for the declared drawable type. + + Subclasses must set ``drawable_type`` (a class attribute) to the + drawable class name string (e.g. ``"Point"``, ``"Label"``). + """ + + drawable_type: str = "" + """Override in subclasses with the drawable class name (e.g. 
``"Point"``).""" + + def __init__( + self, + canvas: "Canvas", + drawables_container: "DrawablesContainer", + name_generator: "DrawableNameGenerator", + dependency_manager: "DrawableDependencyManager", + drawable_manager_proxy: "DrawableManagerProxy", + ) -> None: + self.canvas: "Canvas" = canvas + self.drawables: "DrawablesContainer" = drawables_container + self.name_generator: "DrawableNameGenerator" = name_generator + self.dependency_manager: "DrawableDependencyManager" = dependency_manager + self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy + self.edit_policy: Optional[DrawableEditPolicy] = ( + get_drawable_edit_policy(self.drawable_type) if self.drawable_type else None + ) + + # ------------------------------------------------------------------ + # Common lookup + # ------------------------------------------------------------------ + + def _get_by_name(self, name: str) -> Optional["Drawable"]: + """Look up a drawable by *name* inside the container. + + Iterates over drawables whose class name matches ``drawable_type``. + Returns ``None`` when *name* is empty or no match is found. 
+ """ + if not name: + return None + for drawable in self.drawables.get_by_class_name(self.drawable_type): + if getattr(drawable, "name", None) == name: + return drawable + return None diff --git a/static/client/managers/label_manager.py b/static/client/managers/label_manager.py index c281a826..dd148984 100644 --- a/static/client/managers/label_manager.py +++ b/static/client/managers/label_manager.py @@ -14,7 +14,8 @@ from constants import default_color, default_label_font_size from drawables.label import Label from utils.math_utils import MathUtils -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.base_drawable_manager import BaseDrawableManager +from managers.edit_policy import EditRule if TYPE_CHECKING: from canvas import Canvas @@ -24,9 +25,11 @@ from name_generator.drawable import DrawableNameGenerator -class LabelManager: +class LabelManager(BaseDrawableManager): """Manages label drawables for a Canvas.""" + drawable_type: str = "Label" + def __init__( self, canvas: "Canvas", @@ -35,20 +38,17 @@ def __init__( dependency_manager: "DrawableDependencyManager", drawable_manager_proxy: "DrawableManagerProxy", ) -> None: - self.canvas = canvas - self.drawables = drawables_container - self.name_generator = name_generator - self.dependency_manager = dependency_manager - self.drawable_manager = drawable_manager_proxy - self.label_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("Label") + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) def get_label_by_name(self, name: str) -> Optional[Label]: - if not name: - return None - for label in self.drawables.Labels: - if label.name == name: - return label - return None + result = self._get_by_name(name) + return cast(Optional[Label], result) def get_labels_at_position(self, x: float, y: float) -> List[Label]: matches: List[Label] = [] @@ -182,12 +182,12 @@ def 
_collect_label_requested_fields( return pending_fields def _validate_label_policy(self, requested_fields: List[str]) -> Dict[str, EditRule]: - if not self.label_edit_policy: + if not self.edit_policy: raise ValueError("Edit policy for labels is not configured.") validated_rules: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.label_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for labels.") validated_rules[field] = rule diff --git a/static/client/managers/point_manager.py b/static/client/managers/point_manager.py index 3ba28336..e487468e 100644 --- a/static/client/managers/point_manager.py +++ b/static/client/managers/point_manager.py @@ -39,7 +39,8 @@ from drawables.point import Point from drawables.segment import Segment from utils.math_utils import MathUtils -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.base_drawable_manager import BaseDrawableManager +from managers.edit_policy import EditRule from managers.dependency_removal import get_polygon_segments, remove_drawable_with_dependencies if TYPE_CHECKING: @@ -51,7 +52,7 @@ from name_generator.drawable import DrawableNameGenerator -class PointManager: +class PointManager(BaseDrawableManager): """ Manages point drawables for a Canvas. 
@@ -61,6 +62,8 @@ class PointManager: - Deleting point objects """ + drawable_type: str = "Point" + def __init__( self, canvas: "Canvas", @@ -79,12 +82,13 @@ def __init__( dependency_manager: Manager for drawable dependencies drawable_manager_proxy: Proxy to the main DrawableManager """ - self.canvas: "Canvas" = canvas - self.drawables: "DrawablesContainer" = drawables_container - self.name_generator: "DrawableNameGenerator" = name_generator - self.dependency_manager: "DrawableDependencyManager" = dependency_manager - self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy - self.point_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("Point") + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) def get_point(self, x: float, y: float) -> Optional[Point]: """ @@ -117,10 +121,8 @@ def get_point_by_name(self, name: str) -> Optional[Point]: Returns: Point: The point with the matching name, or None if not found """ - for point in self.drawables.Points: - if point.name == name: - return point - return None + result = self._get_by_name(name) + return cast(Optional[Point], result) def create_point( self, @@ -390,12 +392,12 @@ def _rename_dependent_rotational_drawables(self, point: Point) -> None: def _validate_point_policy(self, requested_fields: List[str]) -> Dict[str, EditRule]: """Ensure every requested field is allowed by the policy definition.""" - if not self.point_edit_policy: + if not self.edit_policy: raise ValueError("Edit policy for points is not configured.") validated_rules: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.point_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for points.") validated_rules[field] = rule From 62c7875b7fac1fe0a62867250bcf0591db4f1b5a Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu 
<95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 21:49:48 +0200 Subject: [PATCH 10/28] Add tests for new server modules from refactor - test_env_config.py: 13 tests for load_env_files/get_api_key - test_config.py: 8 tests for server-side constant validation - test_route_helpers.py: 14 tests for provider reset/model/routing helpers --- server_tests/test_config.py | 56 ++++++++ server_tests/test_env_config.py | 186 ++++++++++++++++++++++++++ server_tests/test_route_helpers.py | 208 +++++++++++++++++++++++++++++ 3 files changed, 450 insertions(+) create mode 100644 server_tests/test_config.py create mode 100644 server_tests/test_env_config.py create mode 100644 server_tests/test_route_helpers.py diff --git a/server_tests/test_config.py b/server_tests/test_config.py new file mode 100644 index 00000000..f3ce6c9d --- /dev/null +++ b/server_tests/test_config.py @@ -0,0 +1,56 @@ +"""Tests for static.config — server-side configuration constants.""" + +from __future__ import annotations + +import os +import unittest + +from static.config import ( + CANVAS_SNAPSHOT_DIR, + CANVAS_SNAPSHOT_PATH, + CURRENT_WORKSPACE_SCHEMA_VERSION, + WORKSPACES_DIR, +) + + +class TestConfigConstants(unittest.TestCase): + """Verify types, values, and consistency of configuration constants.""" + + # -- Constant types and values -- + + def test_workspaces_dir_is_non_empty_string(self) -> None: + self.assertIsInstance(WORKSPACES_DIR, str) + self.assertTrue(len(WORKSPACES_DIR) > 0) + + def test_canvas_snapshot_dir_is_non_empty_string(self) -> None: + self.assertIsInstance(CANVAS_SNAPSHOT_DIR, str) + self.assertTrue(len(CANVAS_SNAPSHOT_DIR) > 0) + + def test_schema_version_is_positive_integer(self) -> None: + self.assertIsInstance(CURRENT_WORKSPACE_SCHEMA_VERSION, int) + self.assertGreater(CURRENT_WORKSPACE_SCHEMA_VERSION, 0) + + def test_canvas_snapshot_path_is_string(self) -> None: + self.assertIsInstance(CANVAS_SNAPSHOT_PATH, str) + + def 
test_canvas_snapshot_path_contains_dir_and_filename(self) -> None: + self.assertIn(CANVAS_SNAPSHOT_DIR, CANVAS_SNAPSHOT_PATH) + self.assertIn("canvas.png", CANVAS_SNAPSHOT_PATH) + + # -- Path construction -- + + def test_canvas_snapshot_path_equals_os_path_join(self) -> None: + expected = os.path.join(CANVAS_SNAPSHOT_DIR, "canvas.png") + self.assertEqual(CANVAS_SNAPSHOT_PATH, expected) + + # -- Consistency with usage -- + + def test_workspaces_dir_matches_expected_name(self) -> None: + self.assertEqual(WORKSPACES_DIR, "workspaces") + + def test_schema_version_at_least_one(self) -> None: + self.assertGreaterEqual(CURRENT_WORKSPACE_SCHEMA_VERSION, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/server_tests/test_env_config.py b/server_tests/test_env_config.py new file mode 100644 index 00000000..fcd035ef --- /dev/null +++ b/server_tests/test_env_config.py @@ -0,0 +1,186 @@ +"""Tests for the centralized environment-variable loading module. + +Covers ``load_env_files()`` discovery logic and ``get_api_key()`` fast-path, +required/optional, and fallback behaviours. 
+""" + +from __future__ import annotations + +import os +import unittest +from unittest.mock import patch + +from static.env_config import get_api_key, load_env_files + + +class TestLoadEnvFiles(unittest.TestCase): + """Test cases for load_env_files().""" + + @patch("static.env_config.load_dotenv") + @patch("static.env_config.os.path.exists", return_value=True) + def test_loads_project_root_and_parent_env( + self, mock_exists: unittest.mock.MagicMock, mock_load: unittest.mock.MagicMock + ) -> None: + """When a parent .env exists, both project-root and parent are loaded.""" + load_env_files() + + # First call: project-root (no explicit path) + # Second call: parent directory .env + self.assertEqual(mock_load.call_count, 2) + # The first call uses no arguments (project root default) + mock_load.assert_any_call() + + @patch("static.env_config.load_dotenv") + @patch("static.env_config.os.path.exists", return_value=False) + def test_skips_parent_env_when_missing( + self, mock_exists: unittest.mock.MagicMock, mock_load: unittest.mock.MagicMock + ) -> None: + """When the parent .env does not exist, only the project-root is loaded.""" + load_env_files() + + # Only the project-root call (no path arg) should occur + mock_load.assert_called_once_with() + + @patch("static.env_config.load_dotenv") + @patch("static.env_config.os.path.exists", return_value=False) + def test_does_not_crash_when_no_env_files( + self, mock_exists: unittest.mock.MagicMock, mock_load: unittest.mock.MagicMock + ) -> None: + """Calling load_env_files when no .env files exist must not raise.""" + # Should complete without error + load_env_files() + + @patch("static.env_config.load_dotenv") + @patch("static.env_config.os.path.exists", return_value=True) + def test_idempotent_multiple_calls( + self, mock_exists: unittest.mock.MagicMock, mock_load: unittest.mock.MagicMock + ) -> None: + """Calling load_env_files twice delegates to load_dotenv each time. 
+ + python-dotenv itself handles the idempotency guarantee (it will not + overwrite variables already present in os.environ), so repeated calls + are safe even though load_dotenv is invoked again. + """ + load_env_files() + load_env_files() + + # Two invocations x 2 load_dotenv calls each = 4 total + self.assertEqual(mock_load.call_count, 4) + + @patch("static.env_config.load_dotenv") + @patch("static.env_config.os.path.exists", return_value=True) + def test_parent_env_path_constructed_correctly( + self, mock_exists: unittest.mock.MagicMock, mock_load: unittest.mock.MagicMock + ) -> None: + """The parent .env path should be /../.env.""" + load_env_files() + + expected_parent = os.path.join(os.path.dirname(os.getcwd()), ".env") + mock_exists.assert_called_once_with(expected_parent) + mock_load.assert_any_call(expected_parent) + + +class TestGetApiKeyFastPath(unittest.TestCase): + """Test cases for the get_api_key() fast-path (key already in environ).""" + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {"MY_KEY": "fast-value"}, clear=False) + def test_returns_key_without_loading_env( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """When the key is already in os.environ, skip load_env_files entirely.""" + result = get_api_key("MY_KEY") + + self.assertEqual(result, "fast-value") + mock_load.assert_not_called() + + +class TestGetApiKeyRequired(unittest.TestCase): + """Test cases for get_api_key() with required=True (the default).""" + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {"FOUND_KEY": "secret-123"}, clear=False) + def test_returns_key_when_present_in_environ( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """Returns the key when it exists in os.environ (fast-path).""" + result = get_api_key("FOUND_KEY", required=True) + self.assertEqual(result, "secret-123") + + @patch("static.env_config.load_env_files") + def test_returns_key_loaded_from_env_file( + self, mock_load: unittest.mock.MagicMock 
+ ) -> None: + """Returns the key when load_env_files populates it on the retry.""" + key_name = "LAZY_KEY" + # Ensure the key is absent initially + env_copy = os.environ.copy() + env_copy.pop(key_name, None) + + def _inject_key() -> None: + os.environ[key_name] = "loaded-from-dotenv" + + mock_load.side_effect = _inject_key + + with patch.dict(os.environ, env_copy, clear=True): + result = get_api_key(key_name, required=True) + + self.assertEqual(result, "loaded-from-dotenv") + mock_load.assert_called_once() + # Clean up injected key + os.environ.pop(key_name, None) + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {}, clear=True) + def test_raises_value_error_when_missing( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """Raises ValueError when required=True and the key cannot be found.""" + with self.assertRaises(ValueError) as ctx: + get_api_key("MISSING_KEY", required=True) + + self.assertIn("MISSING_KEY", str(ctx.exception)) + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {}, clear=True) + def test_raises_value_error_by_default( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """required defaults to True, so omitting it still raises on missing key.""" + with self.assertRaises(ValueError): + get_api_key("ABSENT_KEY") + + +class TestGetApiKeyOptional(unittest.TestCase): + """Test cases for get_api_key() with required=False.""" + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {}, clear=True) + def test_returns_fallback_when_missing( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """Returns the explicit fallback value when key is absent.""" + result = get_api_key("NOPE", required=False, fallback="default-val") + self.assertEqual(result, "default-val") + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {}, clear=True) + def test_returns_empty_string_when_no_fallback( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """Returns empty 
string when key is absent and no fallback is given.""" + result = get_api_key("NOPE", required=False) + self.assertEqual(result, "") + + @patch("static.env_config.load_env_files") + @patch.dict(os.environ, {"OPT_KEY": "found-it"}, clear=False) + def test_returns_key_when_present_even_if_not_required( + self, mock_load: unittest.mock.MagicMock + ) -> None: + """When the key exists, it is returned regardless of required flag.""" + result = get_api_key("OPT_KEY", required=False, fallback="ignored") + self.assertEqual(result, "found-it") + + +if __name__ == "__main__": + unittest.main() diff --git a/server_tests/test_route_helpers.py b/server_tests/test_route_helpers.py new file mode 100644 index 00000000..57a3b014 --- /dev/null +++ b/server_tests/test_route_helpers.py @@ -0,0 +1,208 @@ +"""Tests for static.route_helpers module. + +Covers reset_tools_for_all_providers, update_all_provider_models, +and get_active_provider using MagicMock-based app/provider stubs. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from static.route_helpers import ( + get_active_provider, + reset_tools_for_all_providers, + update_all_provider_models, +) + + +def _make_app( + ai_injected: bool = False, + responses_injected: bool = False, +) -> MagicMock: + """Build a MagicMock that mimics MatHudFlask with two built-in providers.""" + app = MagicMock() + + app.ai_api.has_injected_tools.return_value = ai_injected + app.responses_api.has_injected_tools.return_value = responses_injected + + return app + + +# --------------------------------------------------------------------------- +# reset_tools_for_all_providers +# --------------------------------------------------------------------------- + +class TestResetToolsForAllProviders: + """Tests for reset_tools_for_all_providers.""" + + def test_noop_when_finish_reason_is_tool_calls(self) -> None: + """No reset happens when the model still wants to call tools.""" + app = _make_app(ai_injected=True, 
responses_injected=True) + + reset_tools_for_all_providers(app, "tool_calls") + + app.ai_api.has_injected_tools.assert_not_called() + app.responses_api.has_injected_tools.assert_not_called() + app.ai_api.reset_tools.assert_not_called() + app.responses_api.reset_tools.assert_not_called() + + def test_resets_ai_api_when_injected_and_stop(self) -> None: + """ai_api is reset when it has injected tools and finish_reason is 'stop'.""" + app = _make_app(ai_injected=True, responses_injected=False) + + reset_tools_for_all_providers(app, "stop") + + app.ai_api.reset_tools.assert_called_once() + app.responses_api.reset_tools.assert_not_called() + + def test_resets_responses_api_when_injected(self) -> None: + """responses_api is reset when it has injected tools.""" + app = _make_app(ai_injected=False, responses_injected=True) + + reset_tools_for_all_providers(app, "stop") + + app.ai_api.reset_tools.assert_not_called() + app.responses_api.reset_tools.assert_called_once() + + def test_no_reset_when_no_injected_tools(self) -> None: + """Neither provider is reset when neither has injected tools.""" + app = _make_app(ai_injected=False, responses_injected=False) + + reset_tools_for_all_providers(app, "stop") + + app.ai_api.reset_tools.assert_not_called() + app.responses_api.reset_tools.assert_not_called() + + def test_resets_active_provider_when_different_and_injected(self) -> None: + """An active_provider that differs from the built-ins is reset.""" + app = _make_app(ai_injected=False, responses_injected=False) + extra_provider = MagicMock() + extra_provider.has_injected_tools.return_value = True + + reset_tools_for_all_providers(app, "stop", active_provider=extra_provider) + + extra_provider.reset_tools.assert_called_once() + + def test_does_not_reset_active_provider_when_same_as_ai_api(self) -> None: + """active_provider is NOT reset when it is the same object as ai_api.""" + app = _make_app(ai_injected=True, responses_injected=False) + # Pass the same object that lives on 
app.ai_api. + same_provider = app.ai_api + + reset_tools_for_all_providers(app, "stop", active_provider=same_provider) + + # ai_api.reset_tools should be called once (from the built-in path), + # but NOT a second time from the active_provider path. + app.ai_api.reset_tools.assert_called_once() + + def test_handles_none_active_provider(self) -> None: + """None active_provider is handled gracefully (no AttributeError).""" + app = _make_app(ai_injected=False, responses_injected=False) + + # Should not raise + reset_tools_for_all_providers(app, "stop", active_provider=None) + + app.ai_api.reset_tools.assert_not_called() + app.responses_api.reset_tools.assert_not_called() + + def test_resets_all_three_when_all_injected(self) -> None: + """All three providers reset when each has injected tools.""" + app = _make_app(ai_injected=True, responses_injected=True) + extra_provider = MagicMock() + extra_provider.has_injected_tools.return_value = True + + reset_tools_for_all_providers(app, "stop", active_provider=extra_provider) + + app.ai_api.reset_tools.assert_called_once() + app.responses_api.reset_tools.assert_called_once() + extra_provider.reset_tools.assert_called_once() + + def test_does_not_reset_active_provider_when_same_as_responses_api(self) -> None: + """active_provider is NOT reset when it is the same object as responses_api.""" + app = _make_app(ai_injected=False, responses_injected=True) + same_provider = app.responses_api + + reset_tools_for_all_providers(app, "stop", active_provider=same_provider) + + # Only the built-in path calls reset_tools, not the active_provider path. 
+ app.responses_api.reset_tools.assert_called_once() + + def test_does_not_reset_active_provider_without_injected_tools(self) -> None: + """active_provider is NOT reset when it has no injected tools.""" + app = _make_app(ai_injected=False, responses_injected=False) + extra_provider = MagicMock() + extra_provider.has_injected_tools.return_value = False + + reset_tools_for_all_providers(app, "stop", active_provider=extra_provider) + + extra_provider.reset_tools.assert_not_called() + + +# --------------------------------------------------------------------------- +# update_all_provider_models +# --------------------------------------------------------------------------- + +class TestUpdateAllProviderModels: + """Tests for update_all_provider_models.""" + + def test_sets_model_on_both_providers(self) -> None: + """Both ai_api and responses_api receive set_model with the correct id.""" + app = _make_app() + model_id = "gpt-4.1" + + update_all_provider_models(app, model_id) + + app.ai_api.set_model.assert_called_once_with(model_id) + app.responses_api.set_model.assert_called_once_with(model_id) + + +# --------------------------------------------------------------------------- +# get_active_provider +# --------------------------------------------------------------------------- + +class TestGetActiveProvider: + """Tests for get_active_provider.""" + + def test_returns_ai_api_when_model_id_is_none(self) -> None: + """Default provider is app.ai_api when model_id is None.""" + app = _make_app() + + result = get_active_provider(app, None) + + assert result is app.ai_api + app.ai_api.set_model.assert_not_called() + app.responses_api.set_model.assert_not_called() + + @patch("static.route_helpers.update_all_provider_models") + @patch("static.routes.get_provider_for_model") + def test_calls_update_and_get_provider_when_model_id_given( + self, + mock_get_provider: MagicMock, + mock_update: MagicMock, + ) -> None: + """update_all_provider_models and get_provider_for_model are 
called.""" + app = _make_app() + sentinel_provider = MagicMock(name="resolved_provider") + mock_get_provider.return_value = sentinel_provider + + result = get_active_provider(app, "gpt-4.1") + + mock_update.assert_called_once_with(app, "gpt-4.1") + mock_get_provider.assert_called_once_with(app, "gpt-4.1") + assert result is sentinel_provider + + @patch("static.route_helpers.update_all_provider_models") + @patch("static.routes.get_provider_for_model") + def test_returns_resolved_provider( + self, + mock_get_provider: MagicMock, + mock_update: MagicMock, + ) -> None: + """The provider object returned by get_provider_for_model is passed through.""" + app = _make_app() + expected = MagicMock(name="expected_provider") + mock_get_provider.return_value = expected + + result = get_active_provider(app, "o4-mini") + + assert result is expected From 90a90f3d3cc95a9c3d2edc777ba48df372691347 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 21:49:54 +0200 Subject: [PATCH 11/28] Add client tests for BaseDrawableManager and BaseRendererTelemetry - test_base_drawable_manager.py: 12 tests for constructor, _get_by_name, edit policy resolution, and type filtering - test_base_telemetry.py: 15 test classes covering initialization, reset, timing, recording, snapshot/drain, and Canvas2D bucket override --- .../test_base_drawable_manager.py | 136 +++++ .../client_tests/test_base_telemetry.py | 540 ++++++++++++++++++ static/client/client_tests/tests.py | 34 ++ 3 files changed, 710 insertions(+) create mode 100644 static/client/client_tests/test_base_drawable_manager.py create mode 100644 static/client/client_tests/test_base_telemetry.py diff --git a/static/client/client_tests/test_base_drawable_manager.py b/static/client/client_tests/test_base_drawable_manager.py new file mode 100644 index 00000000..e83f2eec --- /dev/null +++ b/static/client/client_tests/test_base_drawable_manager.py @@ -0,0 +1,136 @@ +"""Tests for 
BaseDrawableManager base class.""" + +import unittest + +from managers.base_drawable_manager import BaseDrawableManager +from managers.drawables_container import DrawablesContainer +from .simple_mock import SimpleMock + + +class ConcreteManager(BaseDrawableManager): + """Minimal concrete subclass for testing the base class.""" + + drawable_type = "Point" + + +class TestBaseDrawableManager(unittest.TestCase): + def setUp(self) -> None: + self.canvas = SimpleMock(name="CanvasMock") + self.drawables = DrawablesContainer() + self.name_generator = SimpleMock(name="NameGeneratorMock") + self.dependency_manager = SimpleMock(name="DependencyManagerMock") + self.drawable_manager_proxy = SimpleMock(name="DrawableManagerProxyMock") + + self.manager = ConcreteManager( + canvas=self.canvas, + drawables_container=self.drawables, + name_generator=self.name_generator, + dependency_manager=self.dependency_manager, + drawable_manager_proxy=self.drawable_manager_proxy, + ) + + def _make_drawable(self, class_name: str, name: str) -> SimpleMock: + """Create a mock drawable with the given class name and name.""" + return SimpleMock( + name=name, + get_class_name=lambda _cn=class_name: _cn, + is_renderable=True, + ) + + # ------------------------------------------------------------------ + # Constructor stores all dependencies + # ------------------------------------------------------------------ + + def test_constructor_stores_canvas(self) -> None: + self.assertIs(self.manager.canvas, self.canvas) + + def test_constructor_stores_drawables(self) -> None: + self.assertIs(self.manager.drawables, self.drawables) + + def test_constructor_stores_name_generator(self) -> None: + self.assertIs(self.manager.name_generator, self.name_generator) + + def test_constructor_stores_dependency_manager(self) -> None: + self.assertIs(self.manager.dependency_manager, self.dependency_manager) + + def test_constructor_stores_drawable_manager_proxy(self) -> None: + self.assertIs(self.manager.drawable_manager, 
self.drawable_manager_proxy) + + # ------------------------------------------------------------------ + # Edit policy resolved for known type + # ------------------------------------------------------------------ + + def test_edit_policy_resolved_for_known_type(self) -> None: + self.assertIsNotNone(self.manager.edit_policy) + self.assertEqual(self.manager.edit_policy.drawable_type, "Point") + + def test_edit_policy_is_none_for_empty_drawable_type(self) -> None: + class EmptyTypeManager(BaseDrawableManager): + drawable_type = "" + + manager = EmptyTypeManager( + canvas=self.canvas, + drawables_container=self.drawables, + name_generator=self.name_generator, + dependency_manager=self.dependency_manager, + drawable_manager_proxy=self.drawable_manager_proxy, + ) + self.assertIsNone(manager.edit_policy) + + # ------------------------------------------------------------------ + # _get_by_name returns matching drawable + # ------------------------------------------------------------------ + + def test_get_by_name_returns_matching_drawable(self) -> None: + point_a = self._make_drawable("Point", "A") + self.drawables.add(point_a) + + result = self.manager._get_by_name("A") + + self.assertIs(result, point_a) + + # ------------------------------------------------------------------ + # _get_by_name returns None for no match + # ------------------------------------------------------------------ + + def test_get_by_name_returns_none_for_no_match(self) -> None: + point_a = self._make_drawable("Point", "A") + self.drawables.add(point_a) + + result = self.manager._get_by_name("Z") + + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # _get_by_name returns None for empty string + # ------------------------------------------------------------------ + + def test_get_by_name_returns_none_for_empty_string(self) -> None: + point_a = self._make_drawable("Point", "A") + self.drawables.add(point_a) + + result = self.manager._get_by_name("") 
+ + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # _get_by_name filters by drawable_type + # ------------------------------------------------------------------ + + def test_get_by_name_filters_by_drawable_type(self) -> None: + segment = self._make_drawable("Segment", "A") + point_a = self._make_drawable("Point", "A") + self.drawables.add(segment) + self.drawables.add(point_a) + + result = self.manager._get_by_name("A") + + self.assertIs(result, point_a) + + def test_get_by_name_ignores_other_types_with_same_name(self) -> None: + segment = self._make_drawable("Segment", "B") + self.drawables.add(segment) + + result = self.manager._get_by_name("B") + + self.assertIsNone(result) diff --git a/static/client/client_tests/test_base_telemetry.py b/static/client/client_tests/test_base_telemetry.py new file mode 100644 index 00000000..1f2a5da7 --- /dev/null +++ b/static/client/client_tests/test_base_telemetry.py @@ -0,0 +1,540 @@ +"""Tests for BaseRendererTelemetry and Canvas2DTelemetry.""" + +from __future__ import annotations + +import unittest + +from rendering.base_telemetry import BaseRendererTelemetry +from rendering.canvas2d_renderer import Canvas2DTelemetry + + +class TestBaseTelemetryInit(unittest.TestCase): + """Verify that initialization calls reset and zeroes all counters.""" + + def test_init_calls_reset(self) -> None: + tel = BaseRendererTelemetry() + self.assertEqual(tel._frames, 0) + self.assertEqual(tel._max_batch_depth, 0) + self.assertIsInstance(tel._phase_totals, dict) + self.assertIsInstance(tel._phase_counts, dict) + self.assertIsInstance(tel._per_drawable, dict) + self.assertIsInstance(tel._adapter_events, dict) + + def test_init_phase_totals_zero(self) -> None: + tel = BaseRendererTelemetry() + for key in ("plan_build_ms", "plan_apply_ms", + "cartesian_plan_build_ms", "cartesian_plan_apply_ms"): + self.assertEqual(tel._phase_totals[key], 0.0, f"{key} should be 0.0") + + def 
test_init_phase_counts_zero(self) -> None: + tel = BaseRendererTelemetry() + for key in ("plan_build_count", "plan_apply_count", + "cartesian_plan_count", "plan_miss_count", "plan_skip_count"): + self.assertEqual(tel._phase_counts[key], 0, f"{key} should be 0") + + def test_init_per_drawable_empty(self) -> None: + tel = BaseRendererTelemetry() + self.assertEqual(len(tel._per_drawable), 0) + + def test_init_adapter_events_empty(self) -> None: + tel = BaseRendererTelemetry() + self.assertEqual(len(tel._adapter_events), 0) + + +class TestBaseTelemetryReset(unittest.TestCase): + """Verify reset() zeroes all accumulated data.""" + + def test_reset_clears_phase_totals(self) -> None: + tel = BaseRendererTelemetry() + tel._phase_totals["plan_build_ms"] = 42.0 + tel.reset() + self.assertEqual(tel._phase_totals["plan_build_ms"], 0.0) + + def test_reset_clears_phase_counts(self) -> None: + tel = BaseRendererTelemetry() + tel._phase_counts["plan_build_count"] = 7 + tel.reset() + self.assertEqual(tel._phase_counts["plan_build_count"], 0) + + def test_reset_clears_per_drawable(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Circle", 1.0) + self.assertIn("Circle", tel._per_drawable) + tel.reset() + self.assertEqual(len(tel._per_drawable), 0) + + def test_reset_clears_adapter_events(self) -> None: + tel = BaseRendererTelemetry() + tel.record_adapter_event("draw_circle", 5) + self.assertIn("draw_circle", tel._adapter_events) + tel.reset() + self.assertEqual(len(tel._adapter_events), 0) + + def test_reset_clears_frames(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.begin_frame() + self.assertEqual(tel._frames, 2) + tel.reset() + self.assertEqual(tel._frames, 0) + + def test_reset_clears_max_batch_depth(self) -> None: + tel = BaseRendererTelemetry() + tel.track_batch_depth(5) + self.assertEqual(tel._max_batch_depth, 5) + tel.reset() + self.assertEqual(tel._max_batch_depth, 0) + + +class TestBaseTelemetryBeginFrame(unittest.TestCase): 
+ """Verify begin_frame increments frame counter.""" + + def test_begin_frame_increments(self) -> None: + tel = BaseRendererTelemetry() + self.assertEqual(tel._frames, 0) + tel.begin_frame() + self.assertEqual(tel._frames, 1) + tel.begin_frame() + self.assertEqual(tel._frames, 2) + + def test_end_frame_is_noop(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.end_frame() + self.assertEqual(tel._frames, 1) + + +class TestBaseTelemetryTiming(unittest.TestCase): + """Verify mark_time and elapsed_since return reasonable values.""" + + def test_mark_time_returns_positive(self) -> None: + tel = BaseRendererTelemetry() + t = tel.mark_time() + self.assertGreater(t, 0) + + def test_elapsed_since_non_negative(self) -> None: + tel = BaseRendererTelemetry() + t = tel.mark_time() + elapsed = tel.elapsed_since(t) + self.assertGreaterEqual(elapsed, 0.0) + + def test_elapsed_since_far_past_is_positive(self) -> None: + tel = BaseRendererTelemetry() + elapsed = tel.elapsed_since(0.0) + self.assertGreater(elapsed, 0.0) + + def test_elapsed_since_far_future_is_zero(self) -> None: + tel = BaseRendererTelemetry() + elapsed = tel.elapsed_since(1e15) + self.assertEqual(elapsed, 0.0) + + +class TestBaseTelemetryRecordPlanBuild(unittest.TestCase): + """Verify record_plan_build accumulation.""" + + def test_accumulates_phase_totals(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Point", 1.5) + tel.record_plan_build("Point", 2.5) + self.assertAlmostEqual(tel._phase_totals["plan_build_ms"], 4.0) + + def test_increments_phase_count(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Point", 1.0) + tel.record_plan_build("Segment", 2.0) + self.assertEqual(tel._phase_counts["plan_build_count"], 2) + + def test_creates_per_drawable_bucket(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Circle", 3.0) + self.assertIn("Circle", tel._per_drawable) + bucket = tel._per_drawable["Circle"] + 
self.assertAlmostEqual(bucket["plan_build_ms"], 3.0) + self.assertEqual(bucket["plan_build_count"], 1) + + def test_accumulates_in_same_bucket(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Circle", 1.0) + tel.record_plan_build("Circle", 2.0) + bucket = tel._per_drawable["Circle"] + self.assertAlmostEqual(bucket["plan_build_ms"], 3.0) + self.assertEqual(bucket["plan_build_count"], 2) + + def test_does_not_affect_cartesian_by_default(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Point", 5.0) + self.assertEqual(tel._phase_totals["cartesian_plan_build_ms"], 0.0) + self.assertEqual(tel._phase_counts["cartesian_plan_count"], 0) + + +class TestBaseTelemetryRecordPlanApply(unittest.TestCase): + """Verify record_plan_apply accumulation.""" + + def test_accumulates_phase_totals(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_apply("Point", 1.5) + tel.record_plan_apply("Point", 2.5) + self.assertAlmostEqual(tel._phase_totals["plan_apply_ms"], 4.0) + + def test_increments_phase_count(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_apply("Segment", 1.0) + self.assertEqual(tel._phase_counts["plan_apply_count"], 1) + + def test_creates_per_drawable_bucket(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_apply("Triangle", 2.0) + self.assertIn("Triangle", tel._per_drawable) + bucket = tel._per_drawable["Triangle"] + self.assertAlmostEqual(bucket["plan_apply_ms"], 2.0) + self.assertEqual(bucket["plan_apply_count"], 1) + + def test_does_not_affect_cartesian_by_default(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_apply("Point", 5.0) + self.assertEqual(tel._phase_totals["cartesian_plan_apply_ms"], 0.0) + + +class TestBaseTelemetryRecordPlanMiss(unittest.TestCase): + """Verify record_plan_miss increments counters.""" + + def test_increments_phase_count(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_miss("Circle") + tel.record_plan_miss("Circle") + 
self.assertEqual(tel._phase_counts["plan_miss_count"], 2) + + def test_increments_per_drawable_bucket(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_miss("Circle") + bucket = tel._per_drawable["Circle"] + self.assertEqual(bucket["plan_miss_count"], 1) + + +class TestBaseTelemetryRecordPlanSkip(unittest.TestCase): + """Verify record_plan_skip increments counters.""" + + def test_increments_phase_count(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_skip("Ellipse") + self.assertEqual(tel._phase_counts["plan_skip_count"], 1) + + def test_increments_per_drawable_bucket(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_skip("Ellipse") + tel.record_plan_skip("Ellipse") + bucket = tel._per_drawable["Ellipse"] + self.assertEqual(bucket["plan_skip_count"], 2) + + +class TestBaseTelemetryCartesian(unittest.TestCase): + """Verify cartesian=True updates both regular and cartesian counters.""" + + def test_record_plan_build_cartesian(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Grid", 3.0, cartesian=True) + self.assertAlmostEqual(tel._phase_totals["plan_build_ms"], 3.0) + self.assertAlmostEqual(tel._phase_totals["cartesian_plan_build_ms"], 3.0) + self.assertEqual(tel._phase_counts["plan_build_count"], 1) + self.assertEqual(tel._phase_counts["cartesian_plan_count"], 1) + + def test_record_plan_apply_cartesian(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_apply("Grid", 2.0, cartesian=True) + self.assertAlmostEqual(tel._phase_totals["plan_apply_ms"], 2.0) + self.assertAlmostEqual(tel._phase_totals["cartesian_plan_apply_ms"], 2.0) + self.assertEqual(tel._phase_counts["plan_apply_count"], 1) + + def test_cartesian_does_not_increment_cartesian_plan_count_on_apply(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_apply("Grid", 2.0, cartesian=True) + # cartesian_plan_count is only incremented by record_plan_build + self.assertEqual(tel._phase_counts["cartesian_plan_count"], 0) 
+ + def test_mixed_cartesian_and_regular(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Point", 1.0) + tel.record_plan_build("Grid", 2.0, cartesian=True) + self.assertAlmostEqual(tel._phase_totals["plan_build_ms"], 3.0) + self.assertAlmostEqual(tel._phase_totals["cartesian_plan_build_ms"], 2.0) + self.assertEqual(tel._phase_counts["plan_build_count"], 2) + self.assertEqual(tel._phase_counts["cartesian_plan_count"], 1) + + +class TestBaseTelemetryAdapterEvent(unittest.TestCase): + """Verify record_adapter_event initializes and accumulates.""" + + def test_initializes_new_event(self) -> None: + tel = BaseRendererTelemetry() + tel.record_adapter_event("fill_rect") + self.assertEqual(tel._adapter_events["fill_rect"], 1) + + def test_accumulates_events(self) -> None: + tel = BaseRendererTelemetry() + tel.record_adapter_event("fill_rect", 3) + tel.record_adapter_event("fill_rect", 2) + self.assertEqual(tel._adapter_events["fill_rect"], 5) + + def test_multiple_event_types(self) -> None: + tel = BaseRendererTelemetry() + tel.record_adapter_event("fill_rect", 2) + tel.record_adapter_event("draw_line", 1) + self.assertEqual(tel._adapter_events["fill_rect"], 2) + self.assertEqual(tel._adapter_events["draw_line"], 1) + + def test_default_amount_is_one(self) -> None: + tel = BaseRendererTelemetry() + tel.record_adapter_event("stroke_path") + tel.record_adapter_event("stroke_path") + self.assertEqual(tel._adapter_events["stroke_path"], 2) + + +class TestBaseTelemetryBatchDepth(unittest.TestCase): + """Verify track_batch_depth tracks maximum depth.""" + + def test_tracks_maximum(self) -> None: + tel = BaseRendererTelemetry() + tel.track_batch_depth(3) + tel.track_batch_depth(5) + tel.track_batch_depth(2) + self.assertEqual(tel._max_batch_depth, 5) + + def test_ignores_lower_depth(self) -> None: + tel = BaseRendererTelemetry() + tel.track_batch_depth(10) + tel.track_batch_depth(3) + self.assertEqual(tel._max_batch_depth, 10) + + def 
test_zero_depth_does_not_update(self) -> None: + tel = BaseRendererTelemetry() + tel.track_batch_depth(0) + self.assertEqual(tel._max_batch_depth, 0) + + +class TestBaseTelemetrySnapshot(unittest.TestCase): + """Verify snapshot returns correct dict and does NOT reset.""" + + def test_snapshot_returns_expected_keys(self) -> None: + tel = BaseRendererTelemetry() + snap = tel.snapshot() + self.assertIn("frames", snap) + self.assertIn("phase", snap) + self.assertIn("per_drawable", snap) + self.assertIn("adapter_events", snap) + + def test_snapshot_frames_value(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.begin_frame() + snap = tel.snapshot() + self.assertEqual(snap["frames"], 2) + + def test_snapshot_phase_merges_totals_and_counts(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Point", 1.5) + snap = tel.snapshot() + phase = snap["phase"] + self.assertIn("plan_build_ms", phase) + self.assertIn("plan_build_count", phase) + self.assertAlmostEqual(phase["plan_build_ms"], 1.5) + self.assertEqual(phase["plan_build_count"], 1) + + def test_snapshot_per_drawable_present(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Circle", 2.0) + snap = tel.snapshot() + self.assertIn("Circle", snap["per_drawable"]) + self.assertAlmostEqual(snap["per_drawable"]["Circle"]["plan_build_ms"], 2.0) + + def test_snapshot_adapter_events_present(self) -> None: + tel = BaseRendererTelemetry() + tel.record_adapter_event("draw_arc", 3) + snap = tel.snapshot() + self.assertEqual(snap["adapter_events"]["draw_arc"], 3) + + def test_snapshot_includes_max_batch_depth_when_nonzero(self) -> None: + tel = BaseRendererTelemetry() + tel.track_batch_depth(4) + snap = tel.snapshot() + self.assertEqual(snap["adapter_events"]["max_batch_depth"], 4) + + def test_snapshot_excludes_max_batch_depth_when_zero(self) -> None: + tel = BaseRendererTelemetry() + snap = tel.snapshot() + self.assertNotIn("max_batch_depth", snap["adapter_events"]) + + 
def test_snapshot_does_not_reset(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.record_plan_build("Point", 1.0) + tel.record_adapter_event("fill", 2) + tel.track_batch_depth(3) + + snap1 = tel.snapshot() + snap2 = tel.snapshot() + + self.assertEqual(snap1["frames"], snap2["frames"]) + self.assertEqual(snap1["phase"], snap2["phase"]) + self.assertEqual(tel._frames, 1) + + def test_snapshot_returns_copies(self) -> None: + tel = BaseRendererTelemetry() + tel.record_plan_build("Point", 1.0) + snap = tel.snapshot() + # Mutating the snapshot should not affect internal state + snap["phase"]["plan_build_ms"] = 999.0 + snap["per_drawable"]["Point"]["plan_build_ms"] = 999.0 + self.assertAlmostEqual(tel._phase_totals["plan_build_ms"], 1.0) + self.assertAlmostEqual(tel._per_drawable["Point"]["plan_build_ms"], 1.0) + + +class TestBaseTelemetryDrain(unittest.TestCase): + """Verify drain returns correct dict and DOES reset.""" + + def test_drain_returns_expected_keys(self) -> None: + tel = BaseRendererTelemetry() + result = tel.drain() + self.assertIn("frames", result) + self.assertIn("phase", result) + self.assertIn("per_drawable", result) + self.assertIn("adapter_events", result) + + def test_drain_returns_accumulated_data(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.record_plan_build("Segment", 5.0) + tel.record_adapter_event("stroke", 10) + result = tel.drain() + self.assertEqual(result["frames"], 1) + self.assertAlmostEqual(result["phase"]["plan_build_ms"], 5.0) + self.assertEqual(result["adapter_events"]["stroke"], 10) + + def test_drain_resets_counters(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.record_plan_build("Segment", 5.0) + tel.record_adapter_event("stroke", 10) + tel.track_batch_depth(3) + tel.drain() + + self.assertEqual(tel._frames, 0) + self.assertEqual(tel._phase_totals["plan_build_ms"], 0.0) + self.assertEqual(tel._phase_counts["plan_build_count"], 0) + 
self.assertEqual(len(tel._per_drawable), 0) + self.assertEqual(len(tel._adapter_events), 0) + self.assertEqual(tel._max_batch_depth, 0) + + def test_drain_then_snapshot_is_empty(self) -> None: + tel = BaseRendererTelemetry() + tel.begin_frame() + tel.record_plan_build("Point", 1.0) + tel.drain() + snap = tel.snapshot() + self.assertEqual(snap["frames"], 0) + self.assertEqual(snap["phase"]["plan_build_ms"], 0.0) + self.assertEqual(len(snap["per_drawable"]), 0) + + +class TestBaseTelemetryNewDrawableBucket(unittest.TestCase): + """Verify _new_drawable_bucket returns correct default keys.""" + + def test_default_bucket_keys(self) -> None: + tel = BaseRendererTelemetry() + bucket = tel._new_drawable_bucket() + expected_keys = { + "plan_build_ms", + "plan_apply_ms", + "plan_build_count", + "plan_apply_count", + "plan_miss_count", + "plan_skip_count", + } + self.assertEqual(set(bucket.keys()), expected_keys) + + def test_default_bucket_values_zero(self) -> None: + tel = BaseRendererTelemetry() + bucket = tel._new_drawable_bucket() + for key, value in bucket.items(): + self.assertEqual(value, 0.0, f"{key} should be 0.0") + + def test_buckets_are_independent(self) -> None: + tel = BaseRendererTelemetry() + b1 = tel._new_drawable_bucket() + b2 = tel._new_drawable_bucket() + b1["plan_build_ms"] = 99.0 + self.assertEqual(b2["plan_build_ms"], 0.0) + + +class TestCanvas2DTelemetryOverride(unittest.TestCase): + """Verify Canvas2DTelemetry adds extra keys to the drawable bucket.""" + + def test_bucket_has_legacy_render_keys(self) -> None: + tel = Canvas2DTelemetry() + bucket = tel._new_drawable_bucket() + self.assertIn("legacy_render_ms", bucket) + self.assertIn("legacy_render_count", bucket) + + def test_bucket_retains_base_keys(self) -> None: + tel = Canvas2DTelemetry() + bucket = tel._new_drawable_bucket() + base_keys = { + "plan_build_ms", + "plan_apply_ms", + "plan_build_count", + "plan_apply_count", + "plan_miss_count", + "plan_skip_count", + } + for key in base_keys: + 
self.assertIn(key, bucket, f"Missing base key: {key}") + + def test_legacy_keys_default_zero(self) -> None: + tel = Canvas2DTelemetry() + bucket = tel._new_drawable_bucket() + self.assertEqual(bucket["legacy_render_ms"], 0.0) + self.assertEqual(bucket["legacy_render_count"], 0) + + def test_record_plan_build_creates_canvas2d_bucket(self) -> None: + tel = Canvas2DTelemetry() + tel.record_plan_build("Point", 1.0) + bucket = tel._per_drawable["Point"] + self.assertIn("legacy_render_ms", bucket) + self.assertIn("legacy_render_count", bucket) + + def test_canvas2d_inherits_all_base_behavior(self) -> None: + tel = Canvas2DTelemetry() + tel.begin_frame() + tel.record_plan_build("Circle", 2.0) + tel.record_plan_apply("Circle", 1.0) + tel.record_plan_miss("Circle") + tel.record_plan_skip("Circle") + tel.record_adapter_event("fill", 3) + tel.track_batch_depth(2) + + snap = tel.snapshot() + self.assertEqual(snap["frames"], 1) + self.assertAlmostEqual(snap["phase"]["plan_build_ms"], 2.0) + self.assertAlmostEqual(snap["phase"]["plan_apply_ms"], 1.0) + self.assertEqual(snap["phase"]["plan_miss_count"], 1) + self.assertEqual(snap["phase"]["plan_skip_count"], 1) + self.assertEqual(snap["adapter_events"]["fill"], 3) + self.assertEqual(snap["adapter_events"]["max_batch_depth"], 2) + + +__all__ = [ + "TestBaseTelemetryInit", + "TestBaseTelemetryReset", + "TestBaseTelemetryBeginFrame", + "TestBaseTelemetryTiming", + "TestBaseTelemetryRecordPlanBuild", + "TestBaseTelemetryRecordPlanApply", + "TestBaseTelemetryRecordPlanMiss", + "TestBaseTelemetryRecordPlanSkip", + "TestBaseTelemetryCartesian", + "TestBaseTelemetryAdapterEvent", + "TestBaseTelemetryBatchDepth", + "TestBaseTelemetrySnapshot", + "TestBaseTelemetryDrain", + "TestBaseTelemetryNewDrawableBucket", + "TestCanvas2DTelemetryOverride", +] diff --git a/static/client/client_tests/tests.py b/static/client/client_tests/tests.py index d38986f8..fd338581 100644 --- a/static/client/client_tests/tests.py +++ 
b/static/client/client_tests/tests.py @@ -157,6 +157,7 @@ TestCircleArcRenderer, TestRendererEdgeCases as TestDrawableRendererEdgeCases, ) +from .test_base_drawable_manager import TestBaseDrawableManager from .test_point_manager import TestPointManagerUpdates from .test_bar_renderer import TestBarRenderer from .test_function_renderables import ( @@ -257,6 +258,23 @@ TestCompactSummary, ) from .test_result_processor_traced import TestGetResultsTraced +from .test_base_telemetry import ( + TestBaseTelemetryInit, + TestBaseTelemetryReset, + TestBaseTelemetryBeginFrame, + TestBaseTelemetryTiming, + TestBaseTelemetryRecordPlanBuild, + TestBaseTelemetryRecordPlanApply, + TestBaseTelemetryRecordPlanMiss, + TestBaseTelemetryRecordPlanSkip, + TestBaseTelemetryCartesian, + TestBaseTelemetryAdapterEvent, + TestBaseTelemetryBatchDepth, + TestBaseTelemetrySnapshot, + TestBaseTelemetryDrain, + TestBaseTelemetryNewDrawableBucket, + TestCanvas2DTelemetryOverride, +) class Tests: @@ -356,6 +374,7 @@ def _get_test_cases(self) -> List[Type[unittest.TestCase]]: TestSegmentLabelRenderer, TestCircleArcRenderer, TestDrawableRendererEdgeCases, + TestBaseDrawableManager, TestPointManagerUpdates, TestFunctionRenderable, TestFunctionsBoundedAreaRenderable, @@ -558,6 +577,21 @@ def _get_test_cases(self) -> List[Type[unittest.TestCase]]: TestExportTracesJson, TestCompactSummary, TestGetResultsTraced, + TestBaseTelemetryInit, + TestBaseTelemetryReset, + TestBaseTelemetryBeginFrame, + TestBaseTelemetryTiming, + TestBaseTelemetryRecordPlanBuild, + TestBaseTelemetryRecordPlanApply, + TestBaseTelemetryRecordPlanMiss, + TestBaseTelemetryRecordPlanSkip, + TestBaseTelemetryCartesian, + TestBaseTelemetryAdapterEvent, + TestBaseTelemetryBatchDepth, + TestBaseTelemetrySnapshot, + TestBaseTelemetryDrain, + TestBaseTelemetryNewDrawableBucket, + TestCanvas2DTelemetryOverride, ] def _create_test_suite(self) -> unittest.TestSuite: From dc119abbf628cd6ea942ce1578b757753a21935c Mon Sep 17 00:00:00 2001 From: 
Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:12:59 +0200 Subject: [PATCH 12/28] Update error recovery tests for removed test trigger Replace TEST_ERROR_TRIGGER_12345 magic string with descriptive test message. Tests were already fully mocked and never depended on the server-side trigger. --- static/client/client_tests/test_error_recovery.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/static/client/client_tests/test_error_recovery.py b/static/client/client_tests/test_error_recovery.py index d12f0138..20517372 100644 --- a/static/client/client_tests/test_error_recovery.py +++ b/static/client/client_tests/test_error_recovery.py @@ -1,7 +1,7 @@ """Tests for the message recovery feature on AI errors. -When an AI request fails (e.g., TEST_ERROR_TRIGGER_12345), the user's message -should be restored to the input field so they can edit and retry. +When an AI request fails, the user's message should be restored to the +input field so they can edit and retry. 
""" from __future__ import annotations @@ -36,7 +36,7 @@ def test_restore_user_message_on_error_populates_input(self) -> None: try: # Set the buffered message - test_message = "TEST_ERROR_TRIGGER_12345" + test_message = "simulate server error for retry" ai._last_user_message = test_message # Call the recovery method @@ -158,7 +158,7 @@ def mock_restore() -> None: ai._reset_tool_call_log_state = lambda: None # Set the buffered message - original_message = "TEST_ERROR_TRIGGER_12345" + original_message = "simulate server error for retry" ai._last_user_message = original_message # Simulate error completion event From 9f62dbf03a667911c238ce113641a825e58d4e08 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:13:05 +0200 Subject: [PATCH 13/28] Remove side effect from Segment.get_state() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The _sync_label_position() call mutated state during serialization but was unnecessary — label position is not serialized, and the sync already happens in all mutation methods and the renderer. --- static/client/drawables/segment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/static/client/drawables/segment.py b/static/client/drawables/segment.py index fabd6e81..d97eab43 100644 --- a/static/client/drawables/segment.py +++ b/static/client/drawables/segment.py @@ -93,7 +93,6 @@ def _calculate_line_algebraic_formula(self) -> Dict[str, float]: def get_state(self) -> Dict[str, Any]: # Keep endpoint ordering consistent with in-memory references so downstream # consumers (workspace saves, dependency checks) preserve segment identity. 
- self._sync_label_position() state: Dict[str, Any] = { "name": self.name, "args": { From 9989dfcfa6b3e2aef5632bec55960cf9b0b16a37 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:13:11 +0200 Subject: [PATCH 14/28] Migrate 7 more managers to BaseDrawableManager Segment, Vector, Circle, Ellipse, Arc, Angle, and Polygon managers now inherit from BaseDrawableManager. Replaces duplicated __init__ boilerplate, edit policy setup, and get_by_name lookups with shared base class behavior. Polygon partially migrated since it manages multiple subtypes dynamically. --- static/client/managers/angle_manager.py | 34 +++++++++++------------ static/client/managers/arc_manager.py | 34 ++++++++++++----------- static/client/managers/circle_manager.py | 31 +++++++++++---------- static/client/managers/ellipse_manager.py | 31 +++++++++++---------- static/client/managers/polygon_manager.py | 19 +++++++------ static/client/managers/segment_manager.py | 30 ++++++++++---------- static/client/managers/vector_manager.py | 27 +++++++++--------- 7 files changed, 107 insertions(+), 99 deletions(-) diff --git a/static/client/managers/angle_manager.py b/static/client/managers/angle_manager.py index bc881fd1..4f55f237 100644 --- a/static/client/managers/angle_manager.py +++ b/static/client/managers/angle_manager.py @@ -48,7 +48,8 @@ from drawables.angle import Angle from drawables.point import Point from drawables.segment import Segment -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.base_drawable_manager import BaseDrawableManager +from managers.edit_policy import EditRule if TYPE_CHECKING: from drawables.drawable import Drawable @@ -61,7 +62,7 @@ from name_generator.drawable import DrawableNameGenerator -class AngleManager: +class AngleManager(BaseDrawableManager): """ Manages Angle drawables for a Canvas. 
This class is responsible for: @@ -70,6 +71,8 @@ class AngleManager: - (Future) Deleting Angle objects and managing their dependencies. """ + drawable_type: str = "Angle" + def __init__( self, canvas: "Canvas", @@ -92,14 +95,15 @@ def __init__( segment_manager: Manager for Segment drawables. drawable_manager_proxy: Proxy to the main DrawableManager or for inter-manager calls. """ - self.canvas: "Canvas" = canvas - self.drawables: "DrawablesContainer" = drawables_container - self.name_generator: "DrawableNameGenerator" = name_generator - self.dependency_manager: "DrawableDependencyManager" = dependency_manager + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) self.point_manager: "PointManager" = point_manager self.segment_manager: "SegmentManager" = segment_manager - self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy - self.angle_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("Angle") def create_angle( self, @@ -224,14 +228,8 @@ def get_angle_by_name(self, name: str) -> Optional[Angle]: Returns: The Angle object if found, otherwise None. 
""" - # Ensure self.drawables.Angles exists and is iterable - if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): - # print("AngleManager: DrawablesContainer has no 'Angles' list or it's not a list.") - return None - for angle in self.drawables.Angles: - if angle.name == name: - return angle - return None + result = self._get_by_name(name) + return cast(Optional[Angle], result) def get_angle_by_segments( self, segment1: Segment, segment2: Segment, is_reflex_filter: Optional[bool] = None @@ -432,12 +430,12 @@ def _collect_angle_requested_fields( return pending_fields def _validate_angle_policy(self, requested_fields: List[str]) -> Dict[str, EditRule]: - if not self.angle_edit_policy: + if not self.edit_policy: raise ValueError("Edit policy for angles is not configured.") validated_rules: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.angle_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for angles.") validated_rules[field] = rule diff --git a/static/client/managers/arc_manager.py b/static/client/managers/arc_manager.py index 65396954..4a0ef95f 100644 --- a/static/client/managers/arc_manager.py +++ b/static/client/managers/arc_manager.py @@ -12,7 +12,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from drawables.circle_arc import CircleArc -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.base_drawable_manager import BaseDrawableManager +from managers.edit_policy import EditRule from utils.math_utils import MathUtils if TYPE_CHECKING: @@ -27,9 +28,11 @@ from name_generator.drawable import DrawableNameGenerator -class ArcManager: +class ArcManager(BaseDrawableManager): """Manager responsible for lifecycle of CircleArc drawables.""" + drawable_type: str = "CircleArc" + def __init__( self, canvas: "Canvas", @@ -39,13 +42,14 @@ def 
__init__( point_manager: "PointManager", drawable_manager_proxy: "DrawableManagerProxy", ) -> None: - self.canvas = canvas - self.drawables = drawables_container - self.name_generator = name_generator - self.dependency_manager = dependency_manager - self.point_manager = point_manager - self.drawable_manager = drawable_manager_proxy - self.arc_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("CircleArc") + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) + self.point_manager: "PointManager" = point_manager # ------------------------------------------------------------------ # Creation helpers @@ -472,10 +476,8 @@ def _get_circle_arc_or_raise(self, name: str) -> CircleArc: return arc def get_circle_arc_by_name(self, name: str) -> Optional[CircleArc]: - for arc in cast(List[CircleArc], self.drawables.CircleArcs): - if arc.name == name: - return arc - return None + result = self._get_by_name(name) + return cast(Optional[CircleArc], result) def delete_circle_arc(self, name: str) -> bool: arc = self.get_circle_arc_by_name(name) @@ -516,7 +518,7 @@ def update_circle_arc( requested_fields = self._collect_arc_requested_fields(new_color, use_major_arc) - if self.arc_edit_policy: + if self.edit_policy: self._validate_arc_policy(requested_fields) self.canvas.undo_redo_manager.archive() @@ -554,12 +556,12 @@ def _collect_arc_requested_fields( return requested_fields def _validate_arc_policy(self, requested_fields: List[str]) -> Dict[str, EditRule]: - if not self.arc_edit_policy: + if not self.edit_policy: return {} validated: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.arc_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for circle arcs.") validated[field] = rule diff --git a/static/client/managers/circle_manager.py b/static/client/managers/circle_manager.py index 
015a2c67..1ab693b5 100644 --- a/static/client/managers/circle_manager.py +++ b/static/client/managers/circle_manager.py @@ -38,7 +38,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, cast from drawables.circle import Circle -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.base_drawable_manager import BaseDrawableManager +from managers.edit_policy import EditRule from managers.dependency_removal import remove_drawable_with_dependencies if TYPE_CHECKING: @@ -50,7 +51,7 @@ from name_generator.drawable import DrawableNameGenerator -class CircleManager: +class CircleManager(BaseDrawableManager): """ Manages circle drawables for a Canvas. @@ -60,6 +61,8 @@ class CircleManager: - Deleting circle objects with proper cleanup and redrawing """ + drawable_type: str = "Circle" + def __init__( self, canvas: "Canvas", @@ -80,13 +83,14 @@ def __init__( point_manager: Manager for point drawables drawable_manager_proxy: Proxy to the main DrawableManager """ - self.canvas: "Canvas" = canvas - self.drawables: "DrawablesContainer" = drawables_container - self.name_generator: "DrawableNameGenerator" = name_generator - self.dependency_manager: "DrawableDependencyManager" = dependency_manager + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) self.point_manager: "PointManager" = point_manager - self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy - self.circle_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("Circle") def get_circle(self, center_x: float, center_y: float, radius: float) -> Optional[Circle]: """ @@ -116,11 +120,8 @@ def get_circle_by_name(self, name: str) -> Optional[Circle]: Returns: Circle: The circle object with the given name, or None if not found """ - circles = self.drawables.Circles - for circle in circles: - if circle.name == name: - return circle - return None + result = 
self._get_by_name(name) + return cast(Optional[Circle], result) def create_circle( self, @@ -273,12 +274,12 @@ def _collect_circle_requested_fields( return pending_fields def _validate_circle_policy(self, requested_fields: List[str]) -> Dict[str, EditRule]: - if not self.circle_edit_policy: + if not self.edit_policy: raise ValueError("Edit policy for circles is not configured.") validated_rules: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.circle_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for circles.") validated_rules[field] = rule diff --git a/static/client/managers/ellipse_manager.py b/static/client/managers/ellipse_manager.py index eed4f8a6..368940c5 100644 --- a/static/client/managers/ellipse_manager.py +++ b/static/client/managers/ellipse_manager.py @@ -37,7 +37,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, cast from drawables.ellipse import Ellipse -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.base_drawable_manager import BaseDrawableManager +from managers.edit_policy import EditRule from managers.dependency_removal import remove_drawable_with_dependencies if TYPE_CHECKING: @@ -49,7 +50,7 @@ from name_generator.drawable import DrawableNameGenerator -class EllipseManager: +class EllipseManager(BaseDrawableManager): """ Manages ellipse drawables for a Canvas. 
@@ -59,6 +60,8 @@ class EllipseManager: - Deleting ellipse objects with proper cleanup and redrawing """ + drawable_type: str = "Ellipse" + def __init__( self, canvas: "Canvas", @@ -79,13 +82,14 @@ def __init__( point_manager: Manager for point drawables drawable_manager_proxy: Proxy to the main DrawableManager """ - self.canvas: "Canvas" = canvas - self.drawables: "DrawablesContainer" = drawables_container - self.name_generator: "DrawableNameGenerator" = name_generator - self.dependency_manager: "DrawableDependencyManager" = dependency_manager + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) self.point_manager: "PointManager" = point_manager - self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy - self.ellipse_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("Ellipse") def get_ellipse(self, center_x: float, center_y: float, radius_x: float, radius_y: float) -> Optional[Ellipse]: """ @@ -123,11 +127,8 @@ def get_ellipse_by_name(self, name: str) -> Optional[Ellipse]: Returns: Ellipse: The ellipse object with the given name, or None if not found """ - ellipses = self.drawables.Ellipses - for ellipse in ellipses: - if ellipse.name == name: - return ellipse - return None + result = self._get_by_name(name) + return cast(Optional[Ellipse], result) def create_ellipse( self, @@ -315,12 +316,12 @@ def _collect_ellipse_requested_fields( return pending_fields def _validate_ellipse_policy(self, requested_fields: List[str]) -> Dict[str, EditRule]: - if not self.ellipse_edit_policy: + if not self.edit_policy: raise ValueError("Edit policy for ellipses is not configured.") validated_rules: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.ellipse_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for ellipses.") validated_rules[field] = rule diff --git 
a/static/client/managers/polygon_manager.py b/static/client/managers/polygon_manager.py index 0a7cc340..6ef92c31 100644 --- a/static/client/managers/polygon_manager.py +++ b/static/client/managers/polygon_manager.py @@ -27,6 +27,7 @@ from drawables.rectangle import Rectangle from drawables.triangle import Triangle from drawables.position import Position +from managers.base_drawable_manager import BaseDrawableManager from managers.dependency_removal import remove_drawable_with_dependencies from managers.polygon_type import PolygonType from managers.edit_policy import EditRule, get_drawable_edit_policy @@ -60,7 +61,7 @@ SegmentList = List["Segment"] -class PolygonManager: +class PolygonManager(BaseDrawableManager): """Manages polygonal drawables with shared create/update/delete flows.""" _TYPE_TO_SIDE_COUNT: Dict[PolygonType, int] = { @@ -113,13 +114,15 @@ def __init__( segment_manager: "SegmentManager", drawable_manager_proxy: "DrawableManagerProxy", ) -> None: - self.canvas = canvas - self.drawables = drawables_container - self.name_generator = name_generator - self.dependency_manager = dependency_manager - self.point_manager = point_manager - self.segment_manager = segment_manager - self.drawable_manager = drawable_manager_proxy + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) + self.point_manager: "PointManager" = point_manager + self.segment_manager: "SegmentManager" = segment_manager # ------------------------------------------------------------------ # # Public API diff --git a/static/client/managers/segment_manager.py b/static/client/managers/segment_manager.py index f2eff33d..0c3b679a 100644 --- a/static/client/managers/segment_manager.py +++ b/static/client/managers/segment_manager.py @@ -40,8 +40,9 @@ from drawables.label import Label from drawables.segment import Segment from utils.math_utils import MathUtils +from managers.base_drawable_manager import BaseDrawableManager from 
managers.dependency_removal import get_polygon_segments, remove_drawable_with_dependencies -from managers.edit_policy import DrawableEditPolicy, EditRule, get_drawable_edit_policy +from managers.edit_policy import EditRule if TYPE_CHECKING: from drawables.drawable import Drawable @@ -54,7 +55,7 @@ from name_generator.drawable import DrawableNameGenerator -class SegmentManager: +class SegmentManager(BaseDrawableManager): """ Manages segment drawables for a Canvas. @@ -64,6 +65,8 @@ class SegmentManager: - Deleting segment objects """ + drawable_type: str = "Segment" + def __init__( self, canvas: "Canvas", @@ -84,13 +87,14 @@ def __init__( point_manager: Manager for point drawables drawable_manager_proxy: Proxy to the main DrawableManager """ - self.canvas: "Canvas" = canvas - self.drawables: "DrawablesContainer" = drawables_container - self.name_generator: "DrawableNameGenerator" = name_generator - self.dependency_manager: "DrawableDependencyManager" = dependency_manager + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) self.point_manager: "PointManager" = point_manager - self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy - self.segment_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("Segment") def get_segment_by_coordinates(self, x1: float, y1: float, x2: float, y2: float) -> Optional[Segment]: """ @@ -125,10 +129,8 @@ def get_segment_by_name(self, name: str) -> Optional[Segment]: Returns: Segment: The segment with the matching name, or None if not found """ - for segment in self.drawables.Segments: - if segment.name == name: - return segment - return None + result = self._get_by_name(name) + return cast(Optional[Segment], result) def get_segment_by_points(self, p1: "Point", p2: "Point") -> Optional[Segment]: """ @@ -421,12 +423,12 @@ def _collect_segment_requested_fields( return pending_fields def _validate_segment_policy(self, requested_fields: 
List[str]) -> Dict[str, EditRule]: - if not self.segment_edit_policy: + if not self.edit_policy: raise ValueError("Edit policy for segments is not configured.") validated_rules: Dict[str, EditRule] = {} for field in requested_fields: - rule = self.segment_edit_policy.get_rule(field) + rule = self.edit_policy.get_rule(field) if not rule: raise ValueError(f"Editing field '{field}' is not permitted for segments.") validated_rules[field] = rule diff --git a/static/client/managers/vector_manager.py b/static/client/managers/vector_manager.py index 29fcf0cf..123c6473 100644 --- a/static/client/managers/vector_manager.py +++ b/static/client/managers/vector_manager.py @@ -37,9 +37,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, cast from drawables.vector import Vector +from managers.base_drawable_manager import BaseDrawableManager from managers.dependency_removal import remove_drawable_with_dependencies from utils.math_utils import MathUtils @@ -53,7 +54,7 @@ from name_generator.drawable import DrawableNameGenerator -class VectorManager: +class VectorManager(BaseDrawableManager): """ Manages vector drawables for a Canvas. 
@@ -63,6 +64,8 @@ class VectorManager: - Deleting vector objects """ + drawable_type: str = "Vector" + def __init__( self, canvas: "Canvas", @@ -83,12 +86,14 @@ def __init__( point_manager: Manager for point drawables drawable_manager_proxy: Proxy to the main DrawableManager """ - self.canvas: "Canvas" = canvas - self.drawables: "DrawablesContainer" = drawables_container - self.name_generator: "DrawableNameGenerator" = name_generator - self.dependency_manager: "DrawableDependencyManager" = dependency_manager + super().__init__( + canvas, + drawables_container, + name_generator, + dependency_manager, + drawable_manager_proxy, + ) self.point_manager: "PointManager" = point_manager - self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy def get_vector(self, x1: float, y1: float, x2: float, y2: float) -> Optional[Vector]: """ @@ -115,12 +120,8 @@ def get_vector(self, x1: float, y1: float, x2: float, y2: float) -> Optional[Vec return None def get_vector_by_name(self, name: str) -> Optional[Vector]: - if not name: - return None - for vector in self.drawables.Vectors: - if vector.name == name: - return vector - return None + result = self._get_by_name(name) + return cast(Optional[Vector], result) def create_vector( self, From 9d214397ea5f2d74c2a319b0a2c238aea5c70a08 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:13:16 +0200 Subject: [PATCH 15/28] Extract VisibilityManager from Canvas Move 10 viewport culling methods from canvas.py into a dedicated VisibilityManager class. Accepts coordinate_mapper and dimensions instead of a Canvas back-reference to avoid circular dependencies. Canvas retains thin delegation stubs to preserve the public API. Reduces canvas.py by ~62 lines. 
--- static/client/canvas.py | 82 ++------- static/client/managers/visibility_manager.py | 179 +++++++++++++++++++ 2 files changed, 189 insertions(+), 72 deletions(-) create mode 100644 static/client/managers/visibility_manager.py diff --git a/static/client/canvas.py b/static/client/canvas.py index 0ad2e9d3..243e881f 100644 --- a/static/client/canvas.py +++ b/static/client/canvas.py @@ -43,7 +43,6 @@ from drawables_aggregator import Point from cartesian_system_2axis import Cartesian2Axis from coordinate_mapper import CoordinateMapper -from utils.math_utils import MathUtils from utils.style_utils import StyleUtils from utils.graph_analyzer import GraphAnalyzer from utils.relation_inspector import RelationInspector @@ -54,6 +53,7 @@ from managers.drawable_dependency_manager import DrawableDependencyManager from managers.transformations_manager import TransformationsManager from managers.coordinate_system_manager import CoordinateSystemManager +from managers.visibility_manager import VisibilityManager from managers.polygon_type import PolygonType from constants import DEFAULT_RENDERER_MODE from rendering.factory import create_renderer @@ -117,6 +117,9 @@ def __init__( self.cartesian2axis: Cartesian2Axis = Cartesian2Axis(self.coordinate_mapper) # Add managers + self.visibility_manager: VisibilityManager = VisibilityManager( + self.coordinate_mapper, width, height + ) self.undo_redo_manager: UndoRedoManager = UndoRedoManager(self) self.drawable_manager: DrawableManager = DrawableManager(self) self.dependency_manager: DrawableDependencyManager = self.drawable_manager.dependency_manager @@ -295,65 +298,8 @@ def _render_drawable_with_renderer(self, renderer: Optional[RendererProtocol], d pass def _is_drawable_visible(self, drawable: "Drawable") -> bool: - """Best-effort visibility check to avoid rendering off-canvas objects. - - Mirrors prior behavior for segments and points; other types default to visible - because they manage their own bounds or are inexpensive. 
- """ - class_name = self._safe_drawable_class_name(drawable) - try: - if class_name == "Point": - return self._is_point_drawable_visible(drawable) - - if class_name == "Segment": - return self._is_segment_drawable_visible(drawable) - - if class_name == "Vector": - return self._is_vector_drawable_visible(drawable) - - # Default: visible - return True - except Exception: - return True - - def _safe_drawable_class_name(self, drawable: Any) -> str: - try: - return str( - drawable.get_class_name() if hasattr(drawable, "get_class_name") else drawable.__class__.__name__ - ) - except Exception: - return str(drawable.__class__.__name__) - - def _is_point_drawable_visible(self, drawable: Any) -> bool: - # Use screen coordinates if available, else compute - # Math-only point; map via CoordinateMapper - x, y = self.coordinate_mapper.math_to_screen(drawable.x, drawable.y) - return self.is_point_within_canvas_visible_area(x, y) - - def _is_segment_drawable_visible(self, drawable: Any) -> bool: - return self._is_math_segment_visible(drawable.point1, drawable.point2) - - def _is_vector_drawable_visible(self, drawable: Any) -> bool: - seg = getattr(drawable, "segment", None) - if seg is None: - return True - return self._is_math_segment_visible(seg.point1, seg.point2) - - def _is_math_segment_visible(self, p1: Any, p2: Any) -> bool: - x1, y1, x2, y2 = self._segment_screen_coordinates(p1, p2) - return self._is_screen_segment_visible(x1, y1, x2, y2) - - def _segment_screen_coordinates(self, p1: Any, p2: Any) -> Tuple[float, float, float, float]: - x1, y1 = self.coordinate_mapper.math_to_screen(p1.x, p1.y) - x2, y2 = self.coordinate_mapper.math_to_screen(p2.x, p2.y) - return x1, y1, x2, y2 - - def _is_screen_segment_visible(self, x1: float, y1: float, x2: float, y2: float) -> bool: - return ( - self.is_point_within_canvas_visible_area(x1, y1) - or self.is_point_within_canvas_visible_area(x2, y2) - or self.any_segment_part_visible_in_canvas_area(x1, y1, x2, y2) - ) + """Best-effort 
visibility check — delegates to VisibilityManager.""" + return self.visibility_manager.is_drawable_visible(drawable) # Removed legacy zoom displacement; zoom handled via CoordinateMapper @@ -905,8 +851,8 @@ def update_point( ) def is_point_within_canvas_visible_area(self, x: float, y: float) -> bool: - """Check if a point is within the visible area of the canvas""" - return (0 <= x <= self.width) and (0 <= y <= self.height) + """Check if a point is within the visible area of the canvas — delegates to VisibilityManager.""" + return self.visibility_manager.is_point_within_canvas_visible_area(x, y) def get_segment_by_coordinates(self, x1: float, y1: float, x2: float, y2: float) -> Optional["Drawable"]: """Get a segment by its endpoint coordinates""" @@ -973,16 +919,8 @@ def update_segment( ) def any_segment_part_visible_in_canvas_area(self, x1: float, y1: float, x2: float, y2: float) -> bool: - """Check if any part of a segment is visible in the canvas area""" - intersect_top = MathUtils.segments_intersect(x1, y1, x2, y2, 0, 0, self.width, 0) - intersect_right = MathUtils.segments_intersect(x1, y1, x2, y2, self.width, 0, self.width, self.height) - intersect_bottom = MathUtils.segments_intersect(x1, y1, x2, y2, self.width, self.height, 0, self.height) - intersect_left = MathUtils.segments_intersect(x1, y1, x2, y2, 0, self.height, 0, 0) - point1_visible: bool = self.is_point_within_canvas_visible_area(x1, y1) - point2_visible: bool = self.is_point_within_canvas_visible_area(x2, y2) - return bool( - intersect_top or intersect_right or intersect_bottom or intersect_left or point1_visible or point2_visible - ) + """Check if any part of a segment is visible — delegates to VisibilityManager.""" + return self.visibility_manager.any_segment_part_visible_in_canvas_area(x1, y1, x2, y2) def get_vector(self, x1: float, y1: float, x2: float, y2: float) -> Optional["Drawable"]: """Get a vector by its origin and tip coordinates""" diff --git 
a/static/client/managers/visibility_manager.py b/static/client/managers/visibility_manager.py new file mode 100644 index 00000000..cb44244e --- /dev/null +++ b/static/client/managers/visibility_manager.py @@ -0,0 +1,179 @@ +""" +MatHud Visibility Manager + +Provides viewport culling logic to determine whether drawable objects fall within +the visible canvas area. Extracted from Canvas to isolate geometric visibility +checks from the central coordinator. + +Dependencies are injected as lightweight values (coordinate_mapper, width, height) +so the manager never holds a back-reference to Canvas. +""" + +from __future__ import annotations + +from typing import Any, Tuple, TYPE_CHECKING + +from utils.math_utils import MathUtils + +if TYPE_CHECKING: + from coordinate_mapper import CoordinateMapper + from drawables.drawable import Drawable + + +class VisibilityManager: + """Viewport culling for drawable objects. + + Determines whether points, segments, vectors, and other drawables intersect + the visible canvas rectangle. Types without a specialised check default to + visible so that complex shapes (circles, functions, etc.) are never + incorrectly culled. + + Attributes: + _coordinate_mapper: Coordinate transformation service for math-to-screen conversion. + _width: Canvas viewport width in pixels. + _height: Canvas viewport height in pixels. + """ + + def __init__(self, coordinate_mapper: "CoordinateMapper", width: float, height: float) -> None: + """Initialise the visibility manager. + + Args: + coordinate_mapper: Provides math-to-screen coordinate conversion. + width: Canvas viewport width in pixels. + height: Canvas viewport height in pixels. 
+ """ + self._coordinate_mapper: "CoordinateMapper" = coordinate_mapper + self._width: float = width + self._height: float = height + + # ------------------------------------------------------------------ + # Canvas-dimension accessors (kept in sync by Canvas when resized) + # ------------------------------------------------------------------ + + def update_dimensions(self, width: float, height: float) -> None: + """Update cached canvas dimensions after a resize. + + Args: + width: New canvas viewport width in pixels. + height: New canvas viewport height in pixels. + """ + self._width = width + self._height = height + + # ------------------------------------------------------------------ + # Top-level drawable visibility + # ------------------------------------------------------------------ + + def is_drawable_visible(self, drawable: "Drawable") -> bool: + """Best-effort visibility check to avoid rendering off-canvas objects. + + Mirrors prior behaviour for segments and points; other types default to + visible because they manage their own bounds or are inexpensive. 
+ """ + class_name = self._safe_drawable_class_name(drawable) + try: + if class_name == "Point": + return self._is_point_drawable_visible(drawable) + + if class_name == "Segment": + return self._is_segment_drawable_visible(drawable) + + if class_name == "Vector": + return self._is_vector_drawable_visible(drawable) + + # Default: visible + return True + except Exception: + return True + + # ------------------------------------------------------------------ + # Public primitive checks + # ------------------------------------------------------------------ + + def is_point_within_canvas_visible_area(self, x: float, y: float) -> bool: + """Check if a screen-coordinate point is within the visible canvas area.""" + return (0 <= x <= self._width) and (0 <= y <= self._height) + + def any_segment_part_visible_in_canvas_area( + self, x1: float, y1: float, x2: float, y2: float + ) -> bool: + """Check if any part of a segment (in screen coordinates) is visible.""" + intersect_top = MathUtils.segments_intersect(x1, y1, x2, y2, 0, 0, self._width, 0) + intersect_right = MathUtils.segments_intersect( + x1, y1, x2, y2, self._width, 0, self._width, self._height + ) + intersect_bottom = MathUtils.segments_intersect( + x1, y1, x2, y2, self._width, self._height, 0, self._height + ) + intersect_left = MathUtils.segments_intersect(x1, y1, x2, y2, 0, self._height, 0, 0) + point1_visible: bool = self.is_point_within_canvas_visible_area(x1, y1) + point2_visible: bool = self.is_point_within_canvas_visible_area(x2, y2) + return bool( + intersect_top + or intersect_right + or intersect_bottom + or intersect_left + or point1_visible + or point2_visible + ) + + # ------------------------------------------------------------------ + # Drawable-type specific checks (private) + # ------------------------------------------------------------------ + + def _is_point_drawable_visible(self, drawable: Any) -> bool: + """Check visibility for a Point drawable.""" + x, y = 
self._coordinate_mapper.math_to_screen(drawable.x, drawable.y) + return self.is_point_within_canvas_visible_area(x, y) + + def _is_segment_drawable_visible(self, drawable: Any) -> bool: + """Check visibility for a Segment drawable.""" + return self._is_math_segment_visible(drawable.point1, drawable.point2) + + def _is_vector_drawable_visible(self, drawable: Any) -> bool: + """Check visibility for a Vector drawable.""" + seg = getattr(drawable, "segment", None) + if seg is None: + return True + return self._is_math_segment_visible(seg.point1, seg.point2) + + # ------------------------------------------------------------------ + # Segment helpers (private) + # ------------------------------------------------------------------ + + def _is_math_segment_visible(self, p1: Any, p2: Any) -> bool: + """Check whether a math-coordinate segment is visible on screen.""" + x1, y1, x2, y2 = self._segment_screen_coordinates(p1, p2) + return self._is_screen_segment_visible(x1, y1, x2, y2) + + def _segment_screen_coordinates(self, p1: Any, p2: Any) -> Tuple[float, float, float, float]: + """Convert two math-coordinate points to screen coordinates.""" + x1, y1 = self._coordinate_mapper.math_to_screen(p1.x, p1.y) + x2, y2 = self._coordinate_mapper.math_to_screen(p2.x, p2.y) + return x1, y1, x2, y2 + + def _is_screen_segment_visible( + self, x1: float, y1: float, x2: float, y2: float + ) -> bool: + """Check whether a screen-coordinate segment is visible.""" + return ( + self.is_point_within_canvas_visible_area(x1, y1) + or self.is_point_within_canvas_visible_area(x2, y2) + or self.any_segment_part_visible_in_canvas_area(x1, y1, x2, y2) + ) + + # ------------------------------------------------------------------ + # Utilities (private) + # ------------------------------------------------------------------ + + @staticmethod + def _safe_drawable_class_name(drawable: Any) -> str: + """Return the class name of a drawable, falling back to __class__.__name__.""" + try: + return str( + 
drawable.get_class_name() + if hasattr(drawable, "get_class_name") + else drawable.__class__.__name__ + ) + except Exception: + return str(drawable.__class__.__name__) From 7a821fd1be55b8542b9998c9579ab271ca4b0114 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:18:07 +0200 Subject: [PATCH 16/28] Fix angle manager test mock for BaseDrawableManager migration Add get_by_class_name to drawables container mock so _get_by_name lookups work after AngleManager inherits from BaseDrawableManager. --- static/client/client_tests/test_angle_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/static/client/client_tests/test_angle_manager.py b/static/client/client_tests/test_angle_manager.py index 2240601c..24bb23b0 100644 --- a/static/client/client_tests/test_angle_manager.py +++ b/static/client/client_tests/test_angle_manager.py @@ -44,6 +44,7 @@ def setUp(self) -> None: name="DrawablesContainerMock", Angles=[], # Holds created Angle instances add=MagicMock(side_effect=lambda x: self.drawables_container_mock.Angles.append(x)), + get_by_class_name=lambda cls_name: self.drawables_container_mock.Angles if cls_name == "Angle" else [], ) self.name_generator_mock = SimpleMock(name="NameGeneratorMock") # Basic mock for now From d32510285765112fffe800c587523a58c7fb2e4f Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:20:03 +0200 Subject: [PATCH 17/28] Extract ToolCallLogManager from AIInterface Move tool call log state and 6 methods into a dedicated ToolCallLogManager class. Update test_tool_call_log.py to test the new class directly. AIInterface delegates via self._tool_call_log instance. 
--- static/client/ai_interface.py | 219 +-------------- .../client_tests/test_error_recovery.py | 14 +- .../client/client_tests/test_tool_call_log.py | 255 ++++++++--------- static/client/tool_call_log_manager.py | 258 ++++++++++++++++++ 4 files changed, 393 insertions(+), 353 deletions(-) create mode 100644 static/client/tool_call_log_manager.py diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index f3b99b8a..24e213a1 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -44,9 +44,9 @@ ) from function_registry import FunctionRegistry from process_function_calls import ProcessFunctionCalls -from result_processor import ResultProcessor from workspace_manager import WorkspaceManager from markdown_parser import MarkdownParser +from tool_call_log_manager import ToolCallLogManager from slash_command_handler import SlashCommandHandler from command_autocomplete import CommandAutocomplete from tts_controller import get_tts_controller, TTSController @@ -107,11 +107,8 @@ def __init__(self, canvas: "Canvas") -> None: self._is_reasoning: bool = False self._request_start_time: Optional[float] = None # Timestamp when user request started self._needs_continuation_separator: bool = False # Add newline before next text after tool calls - # Tool call log state - self._tool_call_log_entries: list[dict[str, Any]] = [] - self._tool_call_log_element: Optional[Any] = None #
element - self._tool_call_log_summary: Optional[Any] = None # element - self._tool_call_log_content: Optional[Any] = None # content container div + # Tool call log state (delegated to ToolCallLogManager) + self._tool_call_log = ToolCallLogManager() # Timeout state self._response_timeout_id: Optional[int] = None # Chat message menu state @@ -1063,197 +1060,6 @@ def _ensure_reasoning_element(self) -> None: except Exception as e: print(f"Error creating reasoning element: {e}") - def _reset_tool_call_log_state(self) -> None: - """Reset all tool call log state for a new turn.""" - self._tool_call_log_entries = [] - self._tool_call_log_element = None - self._tool_call_log_summary = None - self._tool_call_log_content = None - - def _format_tool_call_args_display(self, args: dict[str, Any]) -> str: - """Format a tool call's arguments dict for compact display. - - Filters out the ``canvas`` key, truncates individual values to 30 - characters and the total string to 80 characters. - """ - parts: list[str] = [] - for k, v in args.items(): - if k == "canvas": - continue - v_str = str(v) - if len(v_str) > 30: - v_str = v_str[:27] + "..." - parts.append(f"{k}: {v_str}") - result = ", ".join(parts) - if len(result) > 80: - result = result[:77] + "..." 
- return result - - def _create_tool_call_entry_element(self, entry: dict[str, Any]) -> Any: - """Build the DOM element for a single tool call log entry.""" - div = html.DIV(Class="tool-call-entry") - - is_error = entry.get("is_error", False) - status_class = "tool-call-status error" if is_error else "tool-call-status success" - status_char = "\u2717" if is_error else "\u2713" - status_span = html.SPAN(status_char, Class=status_class) - div <= status_span - - name_span = html.SPAN(entry.get("name", ""), Class="tool-call-name") - div <= name_span - - short_args = entry.get("args_display", "") - full_args = entry.get("args_full", short_args) - args_span = html.SPAN(f"({short_args})", Class="tool-call-args") - div <= args_span - - # Show error message or result - result_display = entry.get("result_display", "") - result_full = entry.get("result_full", result_display) - result_span: Any = None - - if is_error: - error_msg = entry.get("error_message", "") - if error_msg: - err_span = html.SPAN(f" \u2192 {error_msg}", Class="tool-call-error-msg") - div <= err_span - elif result_display: - result_span = html.SPAN(f" \u2192 {result_display}", Class="tool-call-result") - div <= result_span - - # Click to toggle between truncated and full view - def _toggle_expand(event: Any) -> None: - try: - if div.classList.contains("expanded"): - div.classList.remove("expanded") - args_span.text = f"({short_args})" - if result_span is not None and result_display: - result_span.text = f" \u2192 {result_display}" - else: - div.classList.add("expanded") - args_span.text = f"({full_args})" - if result_span is not None and result_full: - result_span.text = f" \u2192 {result_full}" - except Exception: - pass - - div.bind("click", _toggle_expand) - - return div - - def _ensure_tool_call_log_element(self) -> None: - """Create the tool-call-log ``
`` element if it doesn't exist yet.""" - if self._tool_call_log_element is not None: - return - - # We need a message container to attach to - if self._stream_message_container is None: - self._ensure_stream_message_element() - - details = html.DETAILS(Class="tool-call-log-dropdown") - summary = html.SUMMARY("Using tools...", Class="tool-call-log-summary") - content_div = html.DIV(Class="tool-call-log-content") - details <= summary - details <= content_div - - # Insert before the content element so it appears after reasoning but before text - if self._stream_message_container is not None and self._stream_content_element is not None: - try: - self._stream_message_container.insertBefore(details, self._stream_content_element) - except Exception: - self._stream_message_container <= details - elif self._stream_message_container is not None: - self._stream_message_container <= details - - self._tool_call_log_element = details - self._tool_call_log_summary = summary - self._tool_call_log_content = content_div - - def _add_tool_call_entries(self, tool_calls: list[dict[str, Any]], call_results: dict[str, Any]) -> None: - """Record tool call entries and update the dropdown UI. - - Args: - tool_calls: Raw tool call dicts from the AI response. - call_results: Dict mapping result keys to their outcomes. 
- """ - self._ensure_tool_call_log_element() - - for call in tool_calls: - function_name: str = call.get("function_name", "") - args: dict[str, Any] = call.get("arguments", {}) - args_display = self._format_tool_call_args_display(args) - - result_key = ResultProcessor._generate_result_key(function_name, args) - - # Special handling for evaluate_expression which uses expression as key - if function_name == "evaluate_expression" and "expression" in args: - expr = str(args.get("expression", "")).replace(" ", "") - variables = args.get("variables") - if variables and isinstance(variables, dict): - vars_str = ", ".join(f"{k}:{v}" for k, v in variables.items()) - expr_key = f"{expr} for {vars_str}" - else: - expr_key = expr - result_value = call_results.get(expr_key, call_results.get(result_key, "")) - else: - result_value = call_results.get(result_key, call_results.get(function_name, "")) - is_error = isinstance(result_value, str) and result_value.startswith("Error:") - error_message = result_value if is_error else "" - - # Full untruncated args for the expanded view - args_full = ", ".join(f"{k}: {v}" for k, v in args.items() if k != "canvas") - - # Format result for display (truncate if too long) - result_display = "" - if not is_error and result_value: - result_str = str(result_value) - if len(result_str) > 100: - result_display = result_str[:97] + "..." 
- else: - result_display = result_str - - entry: dict[str, Any] = { - "name": function_name, - "args_display": args_display, - "args_full": args_full, - "is_error": is_error, - "error_message": error_message, - "result_display": result_display, - "result_full": str(result_value) if result_value else "", - } - self._tool_call_log_entries.append(entry) - - entry_el = self._create_tool_call_entry_element(entry) - if self._tool_call_log_content is not None: - self._tool_call_log_content <= entry_el - - # Update summary with running count - count = len(self._tool_call_log_entries) - if self._tool_call_log_summary is not None: - self._tool_call_log_summary.text = f"Using tools... ({count} so far)" - - def _finalize_tool_call_log(self) -> None: - """Update the tool call log summary to its final state.""" - if not self._tool_call_log_entries: - return - - count = len(self._tool_call_log_entries) - error_count = sum(1 for e in self._tool_call_log_entries if e.get("is_error")) - - label = f"Used {count} tool" if count == 1 else f"Used {count} tools" - if error_count: - label += f" ({error_count} failed)" - - if self._tool_call_log_summary is not None: - self._tool_call_log_summary.text = label - - # Ensure collapsed — removeAttribute is reliable for boolean HTML attributes - if self._tool_call_log_element is not None: - try: - self._tool_call_log_element.removeAttribute("open") - except Exception: - pass - def _on_stream_log(self, event_obj: Any) -> None: """Handle a server log event: output to browser console with appropriate level.""" try: @@ -1332,7 +1138,7 @@ def _on_stream_token(self, text: str) -> None: def _finalize_stream_message(self, final_message: Optional[str] = None) -> None: """Convert the streamed plain text to parsed markdown and render math.""" try: - self._finalize_tool_call_log() + self._tool_call_log.finalize() # Prefer the accumulated buffer (contains all text across tool calls) # Only use final_message as fallback if buffer is empty @@ -1375,7 +1181,7 
@@ def _finalize_stream_message(self, final_message: Optional[str] = None) -> None: # Reasoning but no text content - remove the empty container self._remove_empty_response_container() elif text_to_render: - if self._tool_call_log_element is not None and self._stream_message_container is not None: + if self._tool_call_log.element is not None and self._stream_message_container is not None: # Tool call log exists — update the container in place to preserve the dropdown self._set_raw_message_text(self._stream_message_container, text_to_render) if self._stream_content_element is not None: @@ -1414,7 +1220,7 @@ def _finalize_stream_message(self, final_message: Optional[str] = None) -> None: self._reasoning_summary = None self._is_reasoning = False self._request_start_time = None - self._reset_tool_call_log_state() + self._tool_call_log.reset() def _remove_empty_response_container(self) -> None: """Remove the current response container if it has no actual text content. @@ -1432,7 +1238,7 @@ def _remove_empty_response_container(self) -> None: has_element_text = bool(element_text.strip()) except Exception: pass - has_tool_call_log = bool(self._tool_call_log_entries) + has_tool_call_log = bool(self._tool_call_log.entries) # Only remove if there's NO actual text content anywhere and no tool call log if ( @@ -1454,7 +1260,7 @@ def _remove_empty_response_container(self) -> None: self._reasoning_summary = None self._reasoning_buffer = "" self._is_reasoning = False - self._reset_tool_call_log_state() + self._tool_call_log.reset() # Don't reset _request_start_time here - we want to keep timing across tool calls except Exception as e: print(f"Error removing empty container: {e}") @@ -1500,7 +1306,10 @@ def _on_stream_final(self, event_obj: Any) -> None: self.canvas, ) self._store_results_in_canvas_state(call_results) - self._add_tool_call_entries(ai_tool_calls, call_results) + if self._stream_message_container is None: + self._ensure_stream_message_element() + 
self._tool_call_log.ensure_element(self._stream_message_container, self._stream_content_element) + self._tool_call_log.add_entries(ai_tool_calls, call_results) if self._stop_requested: # Still capture trace for the executed tool calls @@ -2129,7 +1938,7 @@ def _send_prompt_to_ai_stream( self._reasoning_summary = None self._is_reasoning = False self._needs_continuation_separator = False - self._reset_tool_call_log_state() + self._tool_call_log.reset() try: payload = self._create_request_payload(prompt, include_svg=True) @@ -2182,7 +1991,7 @@ def _send_prompt_to_ai( self._reasoning_summary = None self._is_reasoning = False self._needs_continuation_separator = False - self._reset_tool_call_log_state() + self._tool_call_log.reset() self._send_request(prompt, action_trace=action_trace) diff --git a/static/client/client_tests/test_error_recovery.py b/static/client/client_tests/test_error_recovery.py index 20517372..d49f291d 100644 --- a/static/client/client_tests/test_error_recovery.py +++ b/static/client/client_tests/test_error_recovery.py @@ -11,6 +11,8 @@ from browser import document, html, window +from tool_call_log_manager import ToolCallLogManager + class TestErrorRecovery(unittest.TestCase): """Test the message recovery mechanism on AI errors.""" @@ -92,10 +94,7 @@ def test_message_buffer_cleared_on_success(self) -> None: ai._reasoning_summary = None ai._is_reasoning = False ai._request_start_time = None - ai._tool_call_log_entries = [] - ai._tool_call_log_element = None - ai._tool_call_log_summary = None - ai._tool_call_log_content = None + ai._tool_call_log = ToolCallLogManager() ai.is_processing = True ai._stop_requested = False ai._response_timeout_id = None @@ -105,7 +104,6 @@ def test_message_buffer_cleared_on_success(self) -> None: ai._finalize_stream_message = lambda msg=None: None ai._enable_send_controls = lambda: None ai._normalize_stream_event = lambda e: e if isinstance(e, dict) else {} - ai._reset_tool_call_log_state = lambda: None # Set the buffered 
message ai._last_user_message = "test message" @@ -135,10 +133,7 @@ def test_message_buffer_preserved_on_error(self) -> None: ai._reasoning_summary = None ai._is_reasoning = False ai._request_start_time = None - ai._tool_call_log_entries = [] - ai._tool_call_log_element = None - ai._tool_call_log_summary = None - ai._tool_call_log_content = None + ai._tool_call_log = ToolCallLogManager() ai.is_processing = True ai._stop_requested = False ai._response_timeout_id = None @@ -155,7 +150,6 @@ def mock_restore() -> None: ai._enable_send_controls = lambda: None ai._normalize_stream_event = lambda e: e if isinstance(e, dict) else {} ai._restore_user_message_on_error = mock_restore - ai._reset_tool_call_log_state = lambda: None # Set the buffered message original_message = "simulate server error for retry" diff --git a/static/client/client_tests/test_tool_call_log.py b/static/client/client_tests/test_tool_call_log.py index 0eff4ce7..bcbd98de 100644 --- a/static/client/client_tests/test_tool_call_log.py +++ b/static/client/client_tests/test_tool_call_log.py @@ -5,7 +5,7 @@ from browser import html -from ai_interface import AIInterface +from tool_call_log_manager import ToolCallLogManager from .simple_mock import SimpleMock @@ -42,72 +42,54 @@ def _find_child_by_class(parent: Any, cls: str) -> Optional[Any]: return None -def _make_ai() -> AIInterface: - """Create a minimal AIInterface instance for testing (no heavy __init__).""" - ai = AIInterface.__new__(AIInterface) - ai._tool_call_log_entries = [] - ai._tool_call_log_element = None - ai._tool_call_log_summary = None - ai._tool_call_log_content = None - ai._stream_buffer = "" - ai._stream_content_element = None - ai._stream_message_container = None - ai._reasoning_buffer = "" - ai._reasoning_element = None - ai._reasoning_details = None - ai._reasoning_summary = None - ai._is_reasoning = False - ai._request_start_time = None - ai._needs_continuation_separator = False - ai._open_message_menu = None - 
ai._message_menu_global_bound = True - ai._copy_text_to_clipboard = SimpleMock(return_value=True) - return ai +def _make_mgr() -> ToolCallLogManager: + """Create a fresh ToolCallLogManager instance for testing.""" + return ToolCallLogManager() class TestToolCallLog(unittest.TestCase): - """Tests for the tool-call log dropdown feature in AIInterface.""" + """Tests for the tool-call log dropdown feature via ToolCallLogManager.""" # ── Argument formatting ────────────────────────────────────── def test_format_args_simple(self) -> None: - ai = _make_ai() - result = ai._format_tool_call_args_display({"x": 5, "y": 10}) + mgr = _make_mgr() + result = mgr.format_args_display({"x": 5, "y": 10}) self.assertEqual(result, "x: 5, y: 10") def test_format_args_filters_canvas(self) -> None: - ai = _make_ai() - result = ai._format_tool_call_args_display({"x": 5, "canvas": "", "y": 10}) + mgr = _make_mgr() + result = mgr.format_args_display({"x": 5, "canvas": "", "y": 10}) self.assertEqual(result, "x: 5, y: 10") def test_format_args_truncates_long_values(self) -> None: - ai = _make_ai() + mgr = _make_mgr() long_val = "a" * 50 - result = ai._format_tool_call_args_display({"data": long_val}) + result = mgr.format_args_display({"data": long_val}) self.assertIn("...", result) # The value portion should be at most 30 characters val_part = result.split(": ", 1)[1] self.assertLessEqual(len(val_part), 30) def test_format_args_truncates_total_string(self) -> None: - ai = _make_ai() + mgr = _make_mgr() # Many short args that together exceed 80 chars args = {f"k{i}": f"value{i}" for i in range(20)} - result = ai._format_tool_call_args_display(args) + result = mgr.format_args_display(args) self.assertLessEqual(len(result), 80) self.assertTrue(result.endswith("...")) def test_format_args_empty(self) -> None: - ai = _make_ai() - result = ai._format_tool_call_args_display({}) + mgr = _make_mgr() + result = mgr.format_args_display({}) self.assertEqual(result, "") # ── Entry element creation 
─────────────────────────────────── def test_entry_element_success(self) -> None: - ai = _make_ai() + mgr = _make_mgr() entry = {"name": "create_point", "args_display": "x: 5, y: 10", "is_error": False, "error_message": ""} - el = ai._create_tool_call_entry_element(entry) + el = mgr.create_entry_element(entry) cls = _get_class_attr(el) self.assertIn("tool-call-entry", cls) @@ -124,14 +106,14 @@ def test_entry_element_success(self) -> None: self.assertIn("x: 5", args.text) def test_entry_element_error(self) -> None: - ai = _make_ai() + mgr = _make_mgr() entry = { "name": "create_segment", "args_display": "start: Q, end: R", "is_error": True, "error_message": "Error: Point Q not found", } - el = ai._create_tool_call_entry_element(entry) + el = mgr.create_entry_element(entry) status = _find_child_by_class(el, "tool-call-status") self.assertIsNotNone(status) @@ -142,9 +124,9 @@ def test_entry_element_error(self) -> None: self.assertIn("Point Q not found", err_msg.text) def test_entry_element_empty_args(self) -> None: - ai = _make_ai() + mgr = _make_mgr() entry = {"name": "clear_canvas", "args_display": "", "is_error": False, "error_message": ""} - el = ai._create_tool_call_entry_element(entry) + el = mgr.create_entry_element(entry) args = _find_child_by_class(el, "tool-call-args") self.assertIsNotNone(args) @@ -153,114 +135,100 @@ def test_entry_element_empty_args(self) -> None: # ── Ensure tool call log element ───────────────────────────── def test_ensure_creates_details_element(self) -> None: - ai = _make_ai() - # Provide a container so _ensure doesn't try to access document + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content - ai._ensure_tool_call_log_element() + mgr.ensure_element(container, content) - self.assertIsNotNone(ai._tool_call_log_element) - cls = _get_class_attr(ai._tool_call_log_element) + 
self.assertIsNotNone(mgr.element) + cls = _get_class_attr(mgr.element) self.assertIn("tool-call-log-dropdown", cls) def test_ensure_creates_summary(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content - ai._ensure_tool_call_log_element() + mgr.ensure_element(container, content) - self.assertIsNotNone(ai._tool_call_log_summary) - self.assertIn("Using tools", ai._tool_call_log_summary.text) + self.assertIsNotNone(mgr.summary) + self.assertIn("Using tools", mgr.summary.text) def test_ensure_creates_content_div(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content - ai._ensure_tool_call_log_element() + mgr.ensure_element(container, content) - self.assertIsNotNone(ai._tool_call_log_content) - cls = _get_class_attr(ai._tool_call_log_content) + self.assertIsNotNone(mgr.content) + cls = _get_class_attr(mgr.content) self.assertIn("tool-call-log-content", cls) def test_ensure_idempotent(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content - ai._ensure_tool_call_log_element() - first_el = ai._tool_call_log_element + mgr.ensure_element(container, content) + first_el = mgr.element - ai._ensure_tool_call_log_element() - self.assertIs(ai._tool_call_log_element, first_el) + mgr.ensure_element(container, content) + self.assertIs(mgr.element, first_el) def test_ensure_creates_container_if_missing(self) -> None: - ai = _make_ai() - # Don't set _stream_message_container — it should create one - # We need _ensure_stream_message_element to work, which needs document["chat-history"] - # Instead, 
verify our guard logic: if _tool_call_log_element is already set, skip + mgr = _make_mgr() + # If element is already set, ensure_element should be a no-op sentinel = html.DETAILS() - ai._tool_call_log_element = sentinel - ai._ensure_tool_call_log_element() - self.assertIs(ai._tool_call_log_element, sentinel) + mgr.element = sentinel + mgr.ensure_element(None, None) + self.assertIs(mgr.element, sentinel) # ── Adding entries ─────────────────────────────────────────── def test_add_single_success_entry(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) tool_calls = [{"function_name": "create_point", "arguments": {"x": 5, "y": 10, "name": "A"}}] call_results = {"create_point(x:5, y:10, name:A)": "Point A created at (5, 10)"} - ai._add_tool_call_entries(tool_calls, call_results) + mgr.add_entries(tool_calls, call_results) - self.assertEqual(len(ai._tool_call_log_entries), 1) - self.assertFalse(ai._tool_call_log_entries[0]["is_error"]) + self.assertEqual(len(mgr.entries), 1) + self.assertFalse(mgr.entries[0]["is_error"]) # Content div should have one child entry - self.assertTrue(len(ai._tool_call_log_content.children) >= 1) + self.assertTrue(len(mgr.content.children) >= 1) def test_add_single_error_entry(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) tool_calls = [{"function_name": "create_segment", "arguments": {"start": "Q", "end": "R"}}] call_results = {"create_segment(start:Q, end:R)": "Error: Point Q not found"} - ai._add_tool_call_entries(tool_calls, call_results) + mgr.add_entries(tool_calls, call_results) - 
self.assertEqual(len(ai._tool_call_log_entries), 1) - self.assertTrue(ai._tool_call_log_entries[0]["is_error"]) + self.assertEqual(len(mgr.entries), 1) + self.assertTrue(mgr.entries[0]["is_error"]) def test_add_multiple_entries(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) tool_calls = [ {"function_name": "create_point", "arguments": {"x": 0, "y": 0, "name": "O"}}, @@ -271,87 +239,82 @@ def test_add_multiple_entries(self) -> None: "create_point(x:5, y:5, name:P)": "Point P created", } - ai._add_tool_call_entries(tool_calls, call_results) - self.assertEqual(len(ai._tool_call_log_entries), 2) + mgr.add_entries(tool_calls, call_results) + self.assertEqual(len(mgr.entries), 2) def test_add_entries_accumulates_across_calls(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) # First round - ai._add_tool_call_entries( + mgr.add_entries( [{"function_name": "create_point", "arguments": {"x": 1, "y": 2}}], {"create_point(x:1, y:2)": "ok"}, ) - self.assertEqual(len(ai._tool_call_log_entries), 1) + self.assertEqual(len(mgr.entries), 1) # Second round - ai._add_tool_call_entries( + mgr.add_entries( [{"function_name": "create_circle", "arguments": {"center": "A", "radius": 5}}], {"create_circle(center:A, radius:5)": "ok"}, ) - self.assertEqual(len(ai._tool_call_log_entries), 2) + self.assertEqual(len(mgr.entries), 2) def test_add_entries_updates_summary(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = 
content + mgr.ensure_element(container, content) - ai._add_tool_call_entries( + mgr.add_entries( [{"function_name": "f", "arguments": {}}], {"f()": "ok"}, ) - self.assertIn("1 so far", ai._tool_call_log_summary.text) + self.assertIn("1 so far", mgr.summary.text) # ── Finalize tool call log ─────────────────────────────────── def test_finalize_singular(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) - ai._add_tool_call_entries( + mgr.add_entries( [{"function_name": "f", "arguments": {}}], {"f()": "ok"}, ) - ai._finalize_tool_call_log() - self.assertEqual(ai._tool_call_log_summary.text, "Used 1 tool") + mgr.finalize() + self.assertEqual(mgr.summary.text, "Used 1 tool") def test_finalize_plural(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) for i in range(3): - ai._add_tool_call_entries( + mgr.add_entries( [{"function_name": f"f{i}", "arguments": {}}], {f"f{i}()": "ok"}, ) - ai._finalize_tool_call_log() - self.assertEqual(ai._tool_call_log_summary.text, "Used 3 tools") + mgr.finalize() + self.assertEqual(mgr.summary.text, "Used 3 tools") def test_finalize_with_errors(self) -> None: - ai = _make_ai() + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) - ai._add_tool_call_entries( + mgr.add_entries( [ {"function_name": "a", "arguments": {}}, {"function_name": "b", "arguments": {}}, @@ -359,44 +322,60 @@ def test_finalize_with_errors(self) -> None: ], {"a()": "ok", "b()": "ok", 
"c()": "Error: something failed"}, ) - ai._finalize_tool_call_log() - self.assertIn("Used 3 tools", ai._tool_call_log_summary.text) - self.assertIn("1 failed", ai._tool_call_log_summary.text) + mgr.finalize() + self.assertIn("Used 3 tools", mgr.summary.text) + self.assertIn("1 failed", mgr.summary.text) def test_finalize_no_entries(self) -> None: - ai = _make_ai() + mgr = _make_mgr() # Should be a no-op, no crash - ai._finalize_tool_call_log() - self.assertIsNone(ai._tool_call_log_summary) + mgr.finalize() + self.assertIsNone(mgr.summary) # ── State management ───────────────────────────────────────── def test_state_reset_clears_tool_call_log(self) -> None: - """Verify _reset_tool_call_log_state clears all tool call log state.""" - ai = _make_ai() + """Verify reset() clears all tool call log state.""" + mgr = _make_mgr() container = html.DIV() content = html.DIV(Class="chat-content") container <= content - ai._stream_message_container = container - ai._stream_content_element = content + mgr.ensure_element(container, content) - ai._add_tool_call_entries( + mgr.add_entries( [{"function_name": "f", "arguments": {}}], {"f()": "ok"}, ) - self.assertEqual(len(ai._tool_call_log_entries), 1) - self.assertIsNotNone(ai._tool_call_log_element) + self.assertEqual(len(mgr.entries), 1) + self.assertIsNotNone(mgr.element) - ai._reset_tool_call_log_state() + mgr.reset() - self.assertEqual(ai._tool_call_log_entries, []) - self.assertIsNone(ai._tool_call_log_element) - self.assertIsNone(ai._tool_call_log_summary) - self.assertIsNone(ai._tool_call_log_content) + self.assertEqual(mgr.entries, []) + self.assertIsNone(mgr.element) + self.assertIsNone(mgr.summary) + self.assertIsNone(mgr.content) def test_container_preserved_with_tool_log(self) -> None: """_remove_empty_response_container preserves container when tool log exists.""" - ai = _make_ai() + from ai_interface import AIInterface + + ai = AIInterface.__new__(AIInterface) + ai._tool_call_log = ToolCallLogManager() + 
ai._stream_buffer = "" + ai._stream_content_element = None + ai._stream_message_container = None + ai._reasoning_buffer = "" + ai._reasoning_element = None + ai._reasoning_details = None + ai._reasoning_summary = None + ai._is_reasoning = False + ai._request_start_time = None + ai._needs_continuation_separator = False + ai._open_message_menu = None + ai._message_menu_global_bound = True + ai._copy_text_to_clipboard = SimpleMock(return_value=True) + container = html.DIV() content = html.DIV(Class="chat-content") content.text = "" @@ -406,7 +385,7 @@ def test_container_preserved_with_tool_log(self) -> None: ai._stream_buffer = "" # Add tool call entries so the log is non-empty - ai._tool_call_log_entries = [{"name": "f", "is_error": False}] + ai._tool_call_log.entries = [{"name": "f", "is_error": False}] # Call the actual method — it should NOT remove the container # because tool call log entries exist diff --git a/static/client/tool_call_log_manager.py b/static/client/tool_call_log_manager.py new file mode 100644 index 00000000..44c11ac4 --- /dev/null +++ b/static/client/tool_call_log_manager.py @@ -0,0 +1,258 @@ +"""Tool call log manager for the AI interface. + +Manages the collapsible tool-call log dropdown that appears in the chat +when the AI executes tool calls. Tracks entries, builds DOM elements, +and updates the summary as tool calls accumulate. + +Extracted from ``AIInterface`` to reduce god-class complexity while +preserving the identical public behaviour. +""" + +from __future__ import annotations + +from typing import Any + +from browser import html + +from result_processor import ResultProcessor + + +class ToolCallLogManager: + """Manages the tool-call log dropdown UI and its backing state. + + Attributes: + entries: Accumulated tool-call entry dicts for the current turn. + element: The ``
`` DOM element (or ``None`` before first use). + summary: The ```` DOM element inside *element*. + content: The container ``
`` holding individual entry rows. + """ + + def __init__(self) -> None: + self.entries: list[dict[str, Any]] = [] + self.element: Any | None = None #
element + self.summary: Any | None = None # element + self.content: Any | None = None # content container div + + # ── State management ──────────────────────────────────────── + + def reset(self) -> None: + """Reset all tool call log state for a new turn.""" + self.entries = [] + self.element = None + self.summary = None + self.content = None + + # ── Formatting helpers ────────────────────────────────────── + + def format_args_display(self, args: dict[str, Any]) -> str: + """Format a tool call's arguments dict for compact display. + + Filters out the ``canvas`` key, truncates individual values to 30 + characters and the total string to 80 characters. + """ + parts: list[str] = [] + for k, v in args.items(): + if k == "canvas": + continue + v_str = str(v) + if len(v_str) > 30: + v_str = v_str[:27] + "..." + parts.append(f"{k}: {v_str}") + result = ", ".join(parts) + if len(result) > 80: + result = result[:77] + "..." + return result + + # ── DOM element creation ──────────────────────────────────── + + def create_entry_element(self, entry: dict[str, Any]) -> Any: + """Build the DOM element for a single tool call log entry.""" + div = html.DIV(Class="tool-call-entry") + + is_error = entry.get("is_error", False) + status_class = "tool-call-status error" if is_error else "tool-call-status success" + status_char = "\u2717" if is_error else "\u2713" + status_span = html.SPAN(status_char, Class=status_class) + div <= status_span + + name_span = html.SPAN(entry.get("name", ""), Class="tool-call-name") + div <= name_span + + short_args = entry.get("args_display", "") + full_args = entry.get("args_full", short_args) + args_span = html.SPAN(f"({short_args})", Class="tool-call-args") + div <= args_span + + # Show error message or result + result_display = entry.get("result_display", "") + result_full = entry.get("result_full", result_display) + result_span: Any = None + + if is_error: + error_msg = entry.get("error_message", "") + if error_msg: + err_span = html.SPAN(f" 
\u2192 {error_msg}", Class="tool-call-error-msg") + div <= err_span + elif result_display: + result_span = html.SPAN(f" \u2192 {result_display}", Class="tool-call-result") + div <= result_span + + # Click to toggle between truncated and full view + def _toggle_expand(event: Any) -> None: + try: + if div.classList.contains("expanded"): + div.classList.remove("expanded") + args_span.text = f"({short_args})" + if result_span is not None and result_display: + result_span.text = f" \u2192 {result_display}" + else: + div.classList.add("expanded") + args_span.text = f"({full_args})" + if result_span is not None and result_full: + result_span.text = f" \u2192 {result_full}" + except Exception: + pass + + div.bind("click", _toggle_expand) + + return div + + # ── Ensure / create the log dropdown ──────────────────────── + + def ensure_element( + self, + stream_container: Any | None, + stream_content: Any | None, + ) -> None: + """Create the tool-call-log ``
`` element if it doesn't exist yet. + + Args: + stream_container: The outer message container DOM node. If ``None`` + the caller must create the stream message element first. + stream_content: The chat-content ``
`` inside *stream_container*. + """ + if self.element is not None: + return + + details = html.DETAILS(Class="tool-call-log-dropdown") + summary = html.SUMMARY("Using tools...", Class="tool-call-log-summary") + content_div = html.DIV(Class="tool-call-log-content") + details <= summary + details <= content_div + + # Insert before the content element so it appears after reasoning but before text + if stream_container is not None and stream_content is not None: + try: + stream_container.insertBefore(details, stream_content) + except Exception: + stream_container <= details + elif stream_container is not None: + stream_container <= details + + self.element = details + self.summary = summary + self.content = content_div + + # ── Adding entries ────────────────────────────────────────── + + def add_entries(self, tool_calls: list[dict[str, Any]], call_results: dict[str, Any]) -> None: + """Record tool call entries and update the dropdown UI. + + .. note:: + + The caller must ensure that ``ensure_element`` has been called + (or the element already exists) before invoking this method. + This method calls ``ensure_element`` itself as a convenience, + but passes ``None`` containers — so the log dropdown will only + be created if the caller has previously set up the element. + + Args: + tool_calls: Raw tool call dicts from the AI response. + call_results: Dict mapping result keys to their outcomes. + """ + # ensure_element is a no-op when self.element is already set + # When called from AIInterface, ensure_element is called beforehand + # with the proper container references. + if self.element is None: + # Defensive: do nothing if the element was never created. + # The caller (AIInterface) is responsible for calling + # ensure_element with the right containers first. 
+ pass + + for call in tool_calls: + function_name: str = call.get("function_name", "") + args: dict[str, Any] = call.get("arguments", {}) + args_display = self.format_args_display(args) + + result_key = ResultProcessor._generate_result_key(function_name, args) + + # Special handling for evaluate_expression which uses expression as key + if function_name == "evaluate_expression" and "expression" in args: + expr = str(args.get("expression", "")).replace(" ", "") + variables = args.get("variables") + if variables and isinstance(variables, dict): + vars_str = ", ".join(f"{k}:{v}" for k, v in variables.items()) + expr_key = f"{expr} for {vars_str}" + else: + expr_key = expr + result_value = call_results.get(expr_key, call_results.get(result_key, "")) + else: + result_value = call_results.get(result_key, call_results.get(function_name, "")) + is_error = isinstance(result_value, str) and result_value.startswith("Error:") + error_message = result_value if is_error else "" + + # Full untruncated args for the expanded view + args_full = ", ".join(f"{k}: {v}" for k, v in args.items() if k != "canvas") + + # Format result for display (truncate if too long) + result_display = "" + if not is_error and result_value: + result_str = str(result_value) + if len(result_str) > 100: + result_display = result_str[:97] + "..." + else: + result_display = result_str + + entry: dict[str, Any] = { + "name": function_name, + "args_display": args_display, + "args_full": args_full, + "is_error": is_error, + "error_message": error_message, + "result_display": result_display, + "result_full": str(result_value) if result_value else "", + } + self.entries.append(entry) + + entry_el = self.create_entry_element(entry) + if self.content is not None: + self.content <= entry_el + + # Update summary with running count + count = len(self.entries) + if self.summary is not None: + self.summary.text = f"Using tools... 
({count} so far)" + + # ── Finalization ──────────────────────────────────────────── + + def finalize(self) -> None: + """Update the tool call log summary to its final state.""" + if not self.entries: + return + + count = len(self.entries) + error_count = sum(1 for e in self.entries if e.get("is_error")) + + label = f"Used {count} tool" if count == 1 else f"Used {count} tools" + if error_count: + label += f" ({error_count} failed)" + + if self.summary is not None: + self.summary.text = label + + # Ensure collapsed — removeAttribute is reliable for boolean HTML attributes + if self.element is not None: + try: + self.element.removeAttribute("open") + except Exception: + pass From cc803a01277f6a282db91d88bd4279ced8adc1bc Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:24:07 +0200 Subject: [PATCH 18/28] Extract MessageMenuManager from AIInterface Move message menu, clipboard, and raw text methods into a dedicated MessageMenuManager class. TTS menu items are wired via optional callbacks to avoid coupling. Update test_chat_message_menu.py to test the new class directly. 
--- static/client/ai_interface.py | 275 +---------------- .../client_tests/test_chat_message_menu.py | 14 +- static/client/message_menu_manager.py | 291 ++++++++++++++++++ 3 files changed, 314 insertions(+), 266 deletions(-) create mode 100644 static/client/message_menu_manager.py diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index 24e213a1..e0dad0fc 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -47,6 +47,7 @@ from workspace_manager import WorkspaceManager from markdown_parser import MarkdownParser from tool_call_log_manager import ToolCallLogManager +from message_menu_manager import MessageMenuManager from slash_command_handler import SlashCommandHandler from command_autocomplete import CommandAutocomplete from tts_controller import get_tts_controller, TTSController @@ -111,9 +112,11 @@ def __init__(self, canvas: "Canvas") -> None: self._tool_call_log = ToolCallLogManager() # Timeout state self._response_timeout_id: Optional[int] = None - # Chat message menu state - self._open_message_menu: Optional[Any] = None # DOMNode - self._message_menu_global_bound: bool = False + # Chat message menu (delegated to MessageMenuManager) + self._message_menu = MessageMenuManager( + on_read_aloud=self._handle_tts_read_aloud, + on_tts_settings=self._show_tts_settings_modal, + ) # Image attachment state self._attached_images: list[str] = [] # Data URLs of attached images # Message recovery state @@ -497,250 +500,6 @@ def _render_math(self) -> None: # MathJax not available or error occurred, continue silently pass - def _set_raw_message_text(self, message_container: Any, raw_text: str) -> None: - """Attach raw message source text to a message container for later actions (copy, etc.).""" - try: - # Store on the element object itself to avoid parsing rendered HTML. 
- setattr(message_container, "_raw_message_text", raw_text) - except Exception: - pass - - def _get_raw_message_text(self, message_container: Any) -> str: - """Return the stored raw message text from the container, or empty string if missing.""" - try: - value = getattr(message_container, "_raw_message_text", "") - if isinstance(value, str): - return value - return str(value) - except Exception: - return "" - - def _copy_text_to_clipboard(self, text: str) -> bool: - """Copy text to clipboard using the modern API with a fallback for older contexts.""" - if text is None: - text = "" - if not isinstance(text, str): - try: - text = str(text) - except Exception: - text = "" - - # Prefer navigator.clipboard when available (may require secure context). - try: - navigator = getattr(window, "navigator", None) - clipboard = getattr(navigator, "clipboard", None) if navigator is not None else None - write_text = getattr(clipboard, "writeText", None) if clipboard is not None else None - if callable(write_text): - write_text(text) - return True - except Exception: - pass - - # Fallback: temporary textarea + execCommand('copy') - try: - textarea = html.TEXTAREA() - textarea.value = text - textarea.attrs["readonly"] = "readonly" - textarea.style.position = "fixed" - textarea.style.left = "0" - textarea.style.top = "0" - textarea.style.opacity = "0" - - # Append to DOM, select content, copy, then remove. 
- document <= textarea - try: - textarea.focus() - except Exception: - pass - try: - textarea.select() - except Exception: - pass - try: - textarea.setSelectionRange(0, len(text)) - except Exception: - pass - - copied = False - try: - copied = bool(window.document.execCommand("copy")) - except Exception: - try: - copied = bool(document.execCommand("copy")) - except Exception: - copied = False - - try: - textarea.remove() - except Exception: - pass - - return copied - except Exception: - return False - - def _bind_message_menu_global_handlers(self) -> None: - """Bind global document handlers needed for message menus (close on outside click).""" - if self._message_menu_global_bound: - return - try: - document.bind("click", self._on_document_click_close_message_menu) - self._message_menu_global_bound = True - except Exception: - self._message_menu_global_bound = False - - def _on_document_click_close_message_menu(self, _event: Any) -> None: - """Close any open message menu when clicking outside of it.""" - try: - if self._open_message_menu is not None: - self._hide_message_menu(self._open_message_menu) - except Exception: - self._open_message_menu = None - - def _hide_message_menu(self, menu: Any) -> None: - try: - menu.style.display = "none" - except Exception: - pass - if self._open_message_menu is menu: - self._open_message_menu = None - - def _show_message_menu(self, menu: Any) -> None: - try: - if self._open_message_menu is not None and self._open_message_menu is not menu: - self._hide_message_menu(self._open_message_menu) - except Exception: - self._open_message_menu = None - - try: - menu.style.display = "block" - except Exception: - pass - self._open_message_menu = menu - - def _toggle_message_menu(self, menu: Any) -> None: - try: - current_display = getattr(menu.style, "display", "") - except Exception: - current_display = "" - - if current_display == "none" or not current_display: - self._show_message_menu(menu) - else: - self._hide_message_menu(menu) - - 
def _attach_message_menu(self, message_container: Any, is_ai_message: bool = False) -> None: - """Attach the per-message '...' menu to the message container (idempotent). - - Args: - message_container: The DOM element to attach the menu to - is_ai_message: Whether this is an AI message (enables TTS option) - """ - try: - if bool(getattr(message_container, "_has_message_menu", False)): - return - setattr(message_container, "_has_message_menu", True) - except Exception: - # If we cannot track state on the element, continue anyway. - pass - - self._bind_message_menu_global_handlers() - - menu_button = html.BUTTON("...", Class="chat-message-menu-button") - try: - menu_button.attrs["type"] = "button" - menu_button.attrs["title"] = "Message options" - menu_button.attrs["aria-label"] = "Message options" - except Exception: - pass - - menu = html.DIV(Class="chat-message-menu") - try: - menu.style.display = "none" - except Exception: - pass - - copy_item = html.BUTTON("Copy message text", Class="chat-message-menu-item") - try: - copy_item.attrs["type"] = "button" - except Exception: - pass - - def _stop_propagation(ev: Any) -> None: - try: - ev.stopPropagation() - except Exception: - pass - - def _on_menu_button_click(ev: Any) -> None: - _stop_propagation(ev) - self._toggle_message_menu(menu) - - def _on_menu_click(ev: Any) -> None: - _stop_propagation(ev) - - def _on_copy_click(ev: Any) -> None: - _stop_propagation(ev) - raw_text = self._get_raw_message_text(message_container) - self._copy_text_to_clipboard(raw_text) - self._hide_message_menu(menu) - - try: - menu_button.bind("click", _on_menu_button_click) - menu.bind("click", _on_menu_click) - copy_item.bind("click", _on_copy_click) - except Exception: - pass - - menu <= copy_item - - # Add TTS options for AI messages - if is_ai_message: - read_aloud_item = html.BUTTON("Read aloud", Class="chat-message-menu-item tts-read-aloud") - try: - read_aloud_item.attrs["type"] = "button" - except Exception: - pass - - def 
_on_read_aloud_click(ev: Any) -> None: - _stop_propagation(ev) - self._hide_message_menu(menu) - raw_text = self._get_raw_message_text(message_container) - self._handle_tts_read_aloud(raw_text, read_aloud_item) - - try: - read_aloud_item.bind("click", _on_read_aloud_click) - except Exception: - pass - - menu <= read_aloud_item - - # TTS settings option - tts_settings_item = html.BUTTON("TTS settings...", Class="chat-message-menu-item") - try: - tts_settings_item.attrs["type"] = "button" - except Exception: - pass - - def _on_tts_settings_click(ev: Any) -> None: - _stop_propagation(ev) - self._hide_message_menu(menu) - self._show_tts_settings_modal() - - try: - tts_settings_item.bind("click", _on_tts_settings_click) - except Exception: - pass - - menu <= tts_settings_item - - # Add button + menu to the message container (positioned by CSS). - try: - message_container <= menu_button - message_container <= menu - except Exception: - pass - def _handle_tts_read_aloud(self, text: str, button_element: Any) -> None: """Handle TTS read aloud action. 
@@ -981,8 +740,8 @@ def handler(event: Any) -> None: message_container <= images_container # Store the raw source text for copy actions (do not rely on rendered HTML) - self._set_raw_message_text(message_container, message) - self._attach_message_menu(message_container, is_ai_message=(sender == "AI")) + self._message_menu.set_raw_text(message_container, message) + self._message_menu.attach(message_container, is_ai_message=(sender == "AI")) return message_container @@ -1019,8 +778,8 @@ def _ensure_stream_message_element(self) -> None: self._stream_message_container = container self._stream_content_element = content # Initialize raw text storage for streaming content - self._set_raw_message_text(container, "") - self._attach_message_menu(container, is_ai_message=True) + self._message_menu.set_raw_text(container, "") + self._message_menu.attach(container, is_ai_message=True) except Exception as e: print(f"Error creating streaming element: {e}") @@ -1055,8 +814,8 @@ def _ensure_reasoning_element(self) -> None: self._stream_message_container = container self._stream_content_element = response_content # Initialize raw text storage for reasoning responses - self._set_raw_message_text(container, "") - self._attach_message_menu(container, is_ai_message=True) + self._message_menu.set_raw_text(container, "") + self._message_menu.attach(container, is_ai_message=True) except Exception as e: print(f"Error creating reasoning element: {e}") @@ -1130,7 +889,7 @@ def _on_stream_token(self, text: str) -> None: if self._stream_content_element is not None: self._stream_content_element.text = self._stream_buffer if self._stream_message_container is not None: - self._set_raw_message_text(self._stream_message_container, self._stream_buffer) + self._message_menu.set_raw_text(self._stream_message_container, self._stream_buffer) document["chat-history"].scrollTop = document["chat-history"].scrollHeight except Exception as e: print(f"Error handling stream token: {e}") @@ -1147,7 +906,7 @@ def 
_finalize_stream_message(self, final_message: Optional[str] = None) -> None: # If we have reasoning content and actual text, create a combined element if self._reasoning_buffer and self._stream_message_container is not None: # Preserve raw source for copy actions - self._set_raw_message_text(self._stream_message_container, text_to_render) + self._message_menu.set_raw_text(self._stream_message_container, text_to_render) if text_to_render and self._stream_content_element is not None: # Update the response content with parsed markdown parsed_content = self._parse_markdown_to_html(text_to_render) @@ -1183,7 +942,7 @@ def _finalize_stream_message(self, final_message: Optional[str] = None) -> None: elif text_to_render: if self._tool_call_log.element is not None and self._stream_message_container is not None: # Tool call log exists — update the container in place to preserve the dropdown - self._set_raw_message_text(self._stream_message_container, text_to_render) + self._message_menu.set_raw_text(self._stream_message_container, text_to_render) if self._stream_content_element is not None: parsed_content = self._parse_markdown_to_html(text_to_render) self._stream_content_element.innerHTML = parsed_content @@ -1502,8 +1261,8 @@ def _print_system_message_in_chat(self, message: str) -> None: message_container <= content_element # Store raw text for copy actions - self._set_raw_message_text(message_container, message) - self._attach_message_menu(message_container) + self._message_menu.set_raw_text(message_container, message) + self._message_menu.attach(message_container) # Add to chat history document["chat-history"] <= message_container diff --git a/static/client/client_tests/test_chat_message_menu.py b/static/client/client_tests/test_chat_message_menu.py index 3c2f2172..bcd3acfd 100644 --- a/static/client/client_tests/test_chat_message_menu.py +++ b/static/client/client_tests/test_chat_message_menu.py @@ -5,7 +5,7 @@ from browser import html, window -from ai_interface import 
AIInterface +from message_menu_manager import MessageMenuManager from .simple_mock import SimpleMock @@ -49,18 +49,16 @@ def _get_class_attr(node: Any) -> str: class TestChatMessageMenu(unittest.TestCase): def test_copy_message_text_uses_raw_source(self) -> None: - # Create an AIInterface instance without running __init__ to avoid heavy dependencies. - ai = AIInterface.__new__(AIInterface) - ai._open_message_menu = None - ai._message_menu_global_bound = True # Avoid binding document handlers in tests. + # Create a MessageMenuManager instance without TTS callbacks. + mgr = MessageMenuManager() copy_mock = SimpleMock(return_value=True) - ai._copy_text_to_clipboard = copy_mock + mgr.copy_to_clipboard = copy_mock container = html.DIV() raw_text = "Hello \\(x^2\\)" - ai._set_raw_message_text(container, raw_text) - ai._attach_message_menu(container) + mgr.set_raw_text(container, raw_text) + mgr.attach(container) menu_button: Optional[Any] = None menu: Optional[Any] = None diff --git a/static/client/message_menu_manager.py b/static/client/message_menu_manager.py new file mode 100644 index 00000000..a45e4703 --- /dev/null +++ b/static/client/message_menu_manager.py @@ -0,0 +1,291 @@ +"""Message menu manager for the AI interface. + +Manages the per-message context menu (copy, TTS, etc.) that appears on +chat messages. Handles global click-to-close behaviour, clipboard +operations, and menu item construction. + +Extracted from ``AIInterface`` to reduce god-class complexity while +preserving the identical public behaviour. +""" + +from __future__ import annotations + +from typing import Any, Callable, Optional + +from browser import document, html, window + + +class MessageMenuManager: + """Manages per-message context menus in the chat interface. + + Attributes: + _open_menu: The currently visible menu DOM element, or ``None``. + _global_bound: Whether the document-level click handler has been + registered (to avoid duplicate bindings). 
+ """ + + def __init__( + self, + on_read_aloud: Optional[Callable[[str, Any], None]] = None, + on_tts_settings: Optional[Callable[[], None]] = None, + ) -> None: + self._open_menu: Optional[Any] = None + self._global_bound: bool = False + self._on_read_aloud = on_read_aloud + self._on_tts_settings = on_tts_settings + + # ── Raw text storage ──────────────────────────────────────── + + def set_raw_text(self, container: Any, text: str) -> None: + """Attach raw message source text to a container for later actions (copy, etc.).""" + try: + setattr(container, "_raw_message_text", text) + except Exception: + pass + + def get_raw_text(self, container: Any) -> str: + """Return the stored raw message text from the container, or empty string if missing.""" + try: + value = getattr(container, "_raw_message_text", "") + if isinstance(value, str): + return value + return str(value) + except Exception: + return "" + + # ── Clipboard ─────────────────────────────────────────────── + + def copy_to_clipboard(self, text: str) -> bool: + """Copy text to clipboard using the modern API with a fallback for older contexts.""" + if text is None: + text = "" + if not isinstance(text, str): + try: + text = str(text) + except Exception: + text = "" + + # Prefer navigator.clipboard when available (may require secure context). + try: + navigator = getattr(window, "navigator", None) + clipboard = getattr(navigator, "clipboard", None) if navigator is not None else None + write_text = getattr(clipboard, "writeText", None) if clipboard is not None else None + if callable(write_text): + write_text(text) + return True + except Exception: + pass + + # Fallback: temporary textarea + execCommand('copy') + try: + textarea = html.TEXTAREA() + textarea.value = text + textarea.attrs["readonly"] = "readonly" + textarea.style.position = "fixed" + textarea.style.left = "0" + textarea.style.top = "0" + textarea.style.opacity = "0" + + # Append to DOM, select content, copy, then remove. 
+ document <= textarea + try: + textarea.focus() + except Exception: + pass + try: + textarea.select() + except Exception: + pass + try: + textarea.setSelectionRange(0, len(text)) + except Exception: + pass + + copied = False + try: + copied = bool(window.document.execCommand("copy")) + except Exception: + try: + copied = bool(document.execCommand("copy")) + except Exception: + copied = False + + try: + textarea.remove() + except Exception: + pass + + return copied + except Exception: + return False + + # ── Global handlers ───────────────────────────────────────── + + def bind_global_handlers(self) -> None: + """Bind global document handlers needed for message menus (close on outside click).""" + if self._global_bound: + return + try: + document.bind("click", self._on_document_click) + self._global_bound = True + except Exception: + self._global_bound = False + + def _on_document_click(self, _event: Any) -> None: + """Close any open message menu when clicking outside of it.""" + try: + if self._open_menu is not None: + self._hide(self._open_menu) + except Exception: + self._open_menu = None + + # ── Show / hide / toggle ──────────────────────────────────── + + def _hide(self, menu: Any) -> None: + try: + menu.style.display = "none" + except Exception: + pass + if self._open_menu is menu: + self._open_menu = None + + def _show(self, menu: Any) -> None: + try: + if self._open_menu is not None and self._open_menu is not menu: + self._hide(self._open_menu) + except Exception: + self._open_menu = None + + try: + menu.style.display = "block" + except Exception: + pass + self._open_menu = menu + + def _toggle(self, menu: Any) -> None: + try: + current_display = getattr(menu.style, "display", "") + except Exception: + current_display = "" + + if current_display == "none" or not current_display: + self._show(menu) + else: + self._hide(menu) + + # ── Attach menu to a message container ────────────────────── + + def attach(self, container: Any, is_ai_message: bool = False) 
-> None: + """Attach the per-message '...' menu to the message container (idempotent). + + Args: + container: The DOM element to attach the menu to + is_ai_message: Whether this is an AI message (enables TTS option) + """ + try: + if bool(getattr(container, "_has_message_menu", False)): + return + setattr(container, "_has_message_menu", True) + except Exception: + # If we cannot track state on the element, continue anyway. + pass + + self.bind_global_handlers() + + menu_button = html.BUTTON("...", Class="chat-message-menu-button") + try: + menu_button.attrs["type"] = "button" + menu_button.attrs["title"] = "Message options" + menu_button.attrs["aria-label"] = "Message options" + except Exception: + pass + + menu = html.DIV(Class="chat-message-menu") + try: + menu.style.display = "none" + except Exception: + pass + + copy_item = html.BUTTON("Copy message text", Class="chat-message-menu-item") + try: + copy_item.attrs["type"] = "button" + except Exception: + pass + + def _stop_propagation(ev: Any) -> None: + try: + ev.stopPropagation() + except Exception: + pass + + def _on_menu_button_click(ev: Any) -> None: + _stop_propagation(ev) + self._toggle(menu) + + def _on_menu_click(ev: Any) -> None: + _stop_propagation(ev) + + def _on_copy_click(ev: Any) -> None: + _stop_propagation(ev) + raw_text = self.get_raw_text(container) + self.copy_to_clipboard(raw_text) + self._hide(menu) + + try: + menu_button.bind("click", _on_menu_button_click) + menu.bind("click", _on_menu_click) + copy_item.bind("click", _on_copy_click) + except Exception: + pass + + menu <= copy_item + + # Add TTS options for AI messages when callbacks are provided + if is_ai_message and self._on_read_aloud is not None: + read_aloud_item = html.BUTTON("Read aloud", Class="chat-message-menu-item tts-read-aloud") + try: + read_aloud_item.attrs["type"] = "button" + except Exception: + pass + + def _on_read_aloud_click(ev: Any) -> None: + _stop_propagation(ev) + self._hide(menu) + raw_text = 
self.get_raw_text(container) + if self._on_read_aloud is not None: + self._on_read_aloud(raw_text, read_aloud_item) + + try: + read_aloud_item.bind("click", _on_read_aloud_click) + except Exception: + pass + + menu <= read_aloud_item + + # TTS settings option + if self._on_tts_settings is not None: + tts_settings_item = html.BUTTON("TTS settings...", Class="chat-message-menu-item") + try: + tts_settings_item.attrs["type"] = "button" + except Exception: + pass + + def _on_tts_settings_click(ev: Any) -> None: + _stop_propagation(ev) + self._hide(menu) + if self._on_tts_settings is not None: + self._on_tts_settings() + + try: + tts_settings_item.bind("click", _on_tts_settings_click) + except Exception: + pass + + menu <= tts_settings_item + + # Add button + menu to the message container (positioned by CSS). + try: + container <= menu_button + container <= menu + except Exception: + pass From b978da71cfe5801c625e76700c20e4e0bebc854f Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:28:04 +0200 Subject: [PATCH 19/28] Extract ImageAttachmentManager from AIInterface Move image attachment state and 11 methods into a dedicated ImageAttachmentManager class. System messages routed via callback. AIInterface keeps thin delegation wrappers for external callers. 
--- static/client/ai_interface.py | 203 ++----------------- static/client/image_attachment_manager.py | 225 ++++++++++++++++++++++ 2 files changed, 239 insertions(+), 189 deletions(-) create mode 100644 static/client/image_attachment_manager.py diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index e0dad0fc..80a82fe4 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -38,8 +38,6 @@ from browser import document, html, ajax, window, console, aio from constants import ( AI_RESPONSE_TIMEOUT_MS, - IMAGE_SIZE_WARNING_BYTES, - MAX_ATTACHED_IMAGES, REASONING_TIMEOUT_MS, ) from function_registry import FunctionRegistry @@ -48,6 +46,7 @@ from markdown_parser import MarkdownParser from tool_call_log_manager import ToolCallLogManager from message_menu_manager import MessageMenuManager +from image_attachment_manager import ImageAttachmentManager from slash_command_handler import SlashCommandHandler from command_autocomplete import CommandAutocomplete from tts_controller import get_tts_controller, TTSController @@ -117,8 +116,10 @@ def __init__(self, canvas: "Canvas") -> None: on_read_aloud=self._handle_tts_read_aloud, on_tts_settings=self._show_tts_settings_modal, ) - # Image attachment state - self._attached_images: list[str] = [] # Data URLs of attached images + # Image attachment state (delegated to ImageAttachmentManager) + self._image_attachment = ImageAttachmentManager( + on_system_message=self._print_system_message_in_chat, + ) # Message recovery state self._last_user_message: str = "" # Buffered message for recovery on error # TTS state @@ -272,188 +273,12 @@ def initialize_autocomplete(self) -> None: print(f"Error initializing command autocomplete: {e}") def initialize_image_attachment(self) -> None: - """Initialize image attachment functionality. - - Binds event handlers for the attach button and file input. - Should be called after the DOM is ready. 
- """ - try: - # Bind attach button click - if "attach-button" in document: - document["attach-button"].bind("click", self._on_attach_button_click) - - # Bind file input change - if "image-attach-input" in document: - document["image-attach-input"].bind("change", self._on_files_selected) - - # Bind modal close handlers - if "image-modal" in document: - modal = document["image-modal"] - modal.bind("click", self._on_modal_backdrop_click) - - close_btn = document.select_one(".image-modal-close") - if close_btn: - close_btn.bind("click", self._close_image_modal) - except Exception as e: - print(f"Error initializing image attachment: {e}") - - def _on_attach_button_click(self, event: Any) -> None: - """Handle attach button click - trigger file picker.""" - try: - if "image-attach-input" in document: - document["image-attach-input"].click() - except Exception as e: - print(f"Error triggering file picker: {e}") + """Initialize image attachment functionality (delegates to ImageAttachmentManager).""" + self._image_attachment.initialize() def trigger_file_picker(self) -> None: - """Programmatically trigger the file picker for image attachment. - - This is called by the /attach slash command. - """ - self._on_attach_button_click(None) - - def _on_files_selected(self, event: Any) -> None: - """Handle file input change - read selected files as data URLs.""" - try: - file_input = event.target - files = file_input.files - - if not files or files.length == 0: - return - - # Check if we've hit the limit - current_count = len(self._attached_images) - remaining = MAX_ATTACHED_IMAGES - current_count - - if remaining <= 0: - self._print_system_message_in_chat( - f"Maximum of {MAX_ATTACHED_IMAGES} images per message. Remove some to add more." - ) - file_input.value = "" - return - - files_to_process = min(files.length, remaining) - if files.length > remaining: - self._print_system_message_in_chat( - f"Only attaching {remaining} of {files.length} images (limit: {MAX_ATTACHED_IMAGES})." 
- ) - - for i in range(files_to_process): - file = files[i] - self._read_and_attach_image(file) - - # Clear the input so the same file can be selected again - file_input.value = "" - except Exception as e: - print(f"Error handling file selection: {e}") - - def _read_and_attach_image(self, file: Any) -> None: - """Read an image file and add it to the attached images list.""" - try: - # Check file size - if hasattr(file, "size") and file.size > IMAGE_SIZE_WARNING_BYTES: - size_mb = file.size / (1024 * 1024) - self._print_system_message_in_chat( - f"Warning: Image '{file.name}' is {size_mb:.1f}MB. Large images may slow down processing." - ) - - # Create FileReader to convert to data URL - reader = window.FileReader.new() - - def on_load(event: Any) -> None: - try: - data_url = reader.result - if isinstance(data_url, str) and data_url.startswith("data:image"): - self._attached_images.append(data_url) - self._update_preview_area() - except Exception as e: - print(f"Error processing image: {e}") - - reader.onload = on_load - reader.readAsDataURL(file) - except Exception as e: - print(f"Error reading image file: {e}") - - def _update_preview_area(self) -> None: - """Update the image preview area to reflect current attached images.""" - try: - preview_area = document["image-preview-area"] - - # Clear existing previews - preview_area.clear() - - if not self._attached_images: - preview_area.style.display = "none" - return - - preview_area.style.display = "flex" - - for idx, data_url in enumerate(self._attached_images): - # Create preview item container - item = html.DIV(Class="image-preview-item") - - # Create thumbnail image - img = html.IMG(src=data_url) - item <= img - - # Create remove button - remove_btn = html.BUTTON("\u00d7", Class="remove-btn") - remove_btn.attrs["title"] = "Remove image" - - # Bind remove handler with closure for index - def make_remove_handler(index: int) -> Any: - def handler(event: Any) -> None: - event.stopPropagation() - 
self._remove_attached_image(index) - - return handler - - remove_btn.bind("click", make_remove_handler(idx)) - item <= remove_btn - - preview_area <= item - except Exception as e: - print(f"Error updating preview area: {e}") - - def _remove_attached_image(self, index: int) -> None: - """Remove an attached image by index.""" - try: - if 0 <= index < len(self._attached_images): - self._attached_images.pop(index) - self._update_preview_area() - except Exception as e: - print(f"Error removing attached image: {e}") - - def _clear_attached_images(self) -> None: - """Clear all attached images.""" - self._attached_images = [] - self._update_preview_area() - - def _show_image_modal(self, data_url: str) -> None: - """Display an image in full-size modal.""" - try: - modal = document["image-modal"] - modal_img = document["image-modal-img"] - modal_img.src = data_url - modal.style.display = "flex" - except Exception as e: - print(f"Error showing image modal: {e}") - - def _close_image_modal(self, event: Any = None) -> None: - """Close the image modal.""" - try: - modal = document["image-modal"] - modal.style.display = "none" - except Exception as e: - print(f"Error closing image modal: {e}") - - def _on_modal_backdrop_click(self, event: Any) -> None: - """Close modal when clicking outside the image.""" - try: - if event.target.id == "image-modal": - self._close_image_modal() - except Exception as e: - print(f"Error handling modal click: {e}") + """Programmatically trigger the file picker for image attachment.""" + self._image_attachment.trigger_file_picker() def _store_results_in_canvas_state(self, call_results: Dict[str, Any]) -> None: """Store valid function call results in the canvas state, skipping special cases and formatting values.""" @@ -731,7 +556,7 @@ def _create_message_element( # Bind click to show modal def make_image_click_handler(url: str) -> Any: def handler(event: Any) -> None: - self._show_image_modal(url) + self._image_attachment.show_modal(url) return 
handler @@ -1763,7 +1588,7 @@ def send_user_message(self, message: str) -> None: Allows sending with just attached images (empty message). """ has_text = bool(message.strip()) - has_images = len(self._attached_images) > 0 + has_images = len(self._image_attachment.images) > 0 # Need either text or images to send if self.is_processing or (not has_text and not has_images): @@ -1777,7 +1602,7 @@ def send_user_message(self, message: str) -> None: return # Capture attached images before clearing - images_to_send = list(self._attached_images) if self._attached_images else None + images_to_send = list(self._image_attachment.images) if self._image_attachment.images else None # Use a default message for image-only sends display_message = message if has_text else "[Image attached]" @@ -1787,7 +1612,7 @@ def send_user_message(self, message: str) -> None: self._print_user_message_in_chat(display_message, images=images_to_send) # Clear attached images after displaying (not after successful send) - self._clear_attached_images() + self._image_attachment.clear() # Regular AI flow self._disable_send_controls() @@ -1879,7 +1704,7 @@ def interact_with_ai(self, event: Any) -> None: # Get the user's message from the input field user_message = document["chat-input"].value.strip() - has_images = len(self._attached_images) > 0 + has_images = len(self._image_attachment.images) > 0 # Allow sending if there's text OR attached images if user_message or has_images: diff --git a/static/client/image_attachment_manager.py b/static/client/image_attachment_manager.py new file mode 100644 index 00000000..552e295a --- /dev/null +++ b/static/client/image_attachment_manager.py @@ -0,0 +1,225 @@ +"""Image attachment manager for the AI interface. + +Manages the image attachment lifecycle: file picking, reading, preview +area updates, limit enforcement, and the full-size image modal. + +Extracted from ``AIInterface`` to reduce god-class complexity while +preserving the identical public behaviour. 
+""" + +from __future__ import annotations + +from typing import Any, Callable, Optional + +from browser import document, html, window + +from constants import IMAGE_SIZE_WARNING_BYTES, MAX_ATTACHED_IMAGES + + +class ImageAttachmentManager: + """Manages image attachment state, DOM previews, and the image modal. + + Attributes: + _images: Data URLs of currently attached images. + _on_system_message: Optional callback to display system messages in chat. + """ + + def __init__(self, on_system_message: Optional[Callable[[str], None]] = None) -> None: + self._images: list[str] = [] + self._on_system_message = on_system_message + + # ── Public API ────────────────────────────────────────────── + + @property + def images(self) -> list[str]: + """Return the list of currently attached image data URLs.""" + return self._images + + def initialize(self) -> None: + """Initialize image attachment functionality. + + Binds event handlers for the attach button, file input, and modal. + Should be called after the DOM is ready. + """ + try: + # Bind attach button click + if "attach-button" in document: + document["attach-button"].bind("click", self._on_attach_button_click) + + # Bind file input change + if "image-attach-input" in document: + document["image-attach-input"].bind("change", self._on_files_selected) + + # Bind modal close handlers + if "image-modal" in document: + modal = document["image-modal"] + modal.bind("click", self._on_modal_backdrop_click) + + close_btn = document.select_one(".image-modal-close") + if close_btn: + close_btn.bind("click", self._close_modal) + except Exception as e: + print(f"Error initializing image attachment: {e}") + + def trigger_file_picker(self) -> None: + """Programmatically trigger the file picker for image attachment. + + This is called by the /attach slash command. 
+ """ + self._on_attach_button_click(None) + + def clear(self) -> None: + """Clear all attached images.""" + self._images = [] + self._update_preview_area() + + def show_modal(self, data_url: str) -> None: + """Display an image in full-size modal.""" + try: + modal = document["image-modal"] + modal_img = document["image-modal-img"] + modal_img.src = data_url + modal.style.display = "flex" + except Exception as e: + print(f"Error showing image modal: {e}") + + # ── Internal handlers ─────────────────────────────────────── + + def _on_attach_button_click(self, event: Any) -> None: + """Handle attach button click - trigger file picker.""" + try: + if "image-attach-input" in document: + document["image-attach-input"].click() + except Exception as e: + print(f"Error triggering file picker: {e}") + + def _on_files_selected(self, event: Any) -> None: + """Handle file input change - read selected files as data URLs.""" + try: + file_input = event.target + files = file_input.files + + if not files or files.length == 0: + return + + # Check if we've hit the limit + current_count = len(self._images) + remaining = MAX_ATTACHED_IMAGES - current_count + + if remaining <= 0: + if self._on_system_message: + self._on_system_message( + f"Maximum of {MAX_ATTACHED_IMAGES} images per message. Remove some to add more." + ) + file_input.value = "" + return + + files_to_process = min(files.length, remaining) + if files.length > remaining: + if self._on_system_message: + self._on_system_message( + f"Only attaching {remaining} of {files.length} images (limit: {MAX_ATTACHED_IMAGES})." 
+ ) + + for i in range(files_to_process): + file = files[i] + self._read_and_attach_image(file) + + # Clear the input so the same file can be selected again + file_input.value = "" + except Exception as e: + print(f"Error handling file selection: {e}") + + def _read_and_attach_image(self, file: Any) -> None: + """Read an image file and add it to the attached images list.""" + try: + # Check file size + if hasattr(file, "size") and file.size > IMAGE_SIZE_WARNING_BYTES: + size_mb = file.size / (1024 * 1024) + if self._on_system_message: + self._on_system_message( + f"Warning: Image '{file.name}' is {size_mb:.1f}MB. Large images may slow down processing." + ) + + # Create FileReader to convert to data URL + reader = window.FileReader.new() + + def on_load(event: Any) -> None: + try: + data_url = reader.result + if isinstance(data_url, str) and data_url.startswith("data:image"): + self._images.append(data_url) + self._update_preview_area() + except Exception as e: + print(f"Error processing image: {e}") + + reader.onload = on_load + reader.readAsDataURL(file) + except Exception as e: + print(f"Error reading image file: {e}") + + def _update_preview_area(self) -> None: + """Update the image preview area to reflect current attached images.""" + try: + preview_area = document["image-preview-area"] + + # Clear existing previews + preview_area.clear() + + if not self._images: + preview_area.style.display = "none" + return + + preview_area.style.display = "flex" + + for idx, data_url in enumerate(self._images): + # Create preview item container + item = html.DIV(Class="image-preview-item") + + # Create thumbnail image + img = html.IMG(src=data_url) + item <= img + + # Create remove button + remove_btn = html.BUTTON("\u00d7", Class="remove-btn") + remove_btn.attrs["title"] = "Remove image" + + # Bind remove handler with closure for index + def make_remove_handler(index: int) -> Any: + def handler(event: Any) -> None: + event.stopPropagation() + self._remove_image(index) + + 
return handler + + remove_btn.bind("click", make_remove_handler(idx)) + item <= remove_btn + + preview_area <= item + except Exception as e: + print(f"Error updating preview area: {e}") + + def _remove_image(self, index: int) -> None: + """Remove an attached image by index.""" + try: + if 0 <= index < len(self._images): + self._images.pop(index) + self._update_preview_area() + except Exception as e: + print(f"Error removing attached image: {e}") + + def _close_modal(self, event: Any = None) -> None: + """Close the image modal.""" + try: + modal = document["image-modal"] + modal.style.display = "none" + except Exception as e: + print(f"Error closing image modal: {e}") + + def _on_modal_backdrop_click(self, event: Any) -> None: + """Close modal when clicking outside the image.""" + try: + if event.target.id == "image-modal": + self._close_modal() + except Exception as e: + print(f"Error handling modal click: {e}") From 2a74c95ff5a3f20eb6515577d961f436a69c83c9 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:31:12 +0200 Subject: [PATCH 20/28] Extract TTSUIManager from AIInterface Move TTS read-aloud, markdown stripping, and settings modal into a dedicated TTSUIManager class. Wired to MessageMenuManager via callbacks for read-aloud and settings menu items. 
--- static/client/ai_interface.py | 195 +--------------------------- static/client/tts_ui_manager.py | 222 ++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 190 deletions(-) create mode 100644 static/client/tts_ui_manager.py diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index 80a82fe4..247ef842 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -31,7 +31,6 @@ from __future__ import annotations import json -import re import traceback from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, cast @@ -49,7 +48,7 @@ from image_attachment_manager import ImageAttachmentManager from slash_command_handler import SlashCommandHandler from command_autocomplete import CommandAutocomplete -from tts_controller import get_tts_controller, TTSController +from tts_ui_manager import TTSUIManager from managers.action_trace_collector import ActionTraceCollector if TYPE_CHECKING: @@ -111,10 +110,12 @@ def __init__(self, canvas: "Canvas") -> None: self._tool_call_log = ToolCallLogManager() # Timeout state self._response_timeout_id: Optional[int] = None + # TTS UI (delegated to TTSUIManager) + self._tts_ui = TTSUIManager(on_system_message=self._print_system_message_in_chat) # Chat message menu (delegated to MessageMenuManager) self._message_menu = MessageMenuManager( - on_read_aloud=self._handle_tts_read_aloud, - on_tts_settings=self._show_tts_settings_modal, + on_read_aloud=self._tts_ui.handle_read_aloud, + on_tts_settings=self._tts_ui.show_settings_modal, ) # Image attachment state (delegated to ImageAttachmentManager) self._image_attachment = ImageAttachmentManager( @@ -122,9 +123,6 @@ def __init__(self, canvas: "Canvas") -> None: ) # Message recovery state self._last_user_message: str = "" # Buffered message for recovery on error - # TTS state - self._tts_controller: TTSController = get_tts_controller() - self._tts_settings_modal: Optional[Any] = None # DOMNode for TTS settings modal # Action trace 
collector for deterministic tool-execution logs self._trace_collector: ActionTraceCollector = ActionTraceCollector() self._register_trace_js_api() @@ -325,189 +323,6 @@ def _render_math(self) -> None: # MathJax not available or error occurred, continue silently pass - def _handle_tts_read_aloud(self, text: str, button_element: Any) -> None: - """Handle TTS read aloud action. - - Args: - text: Text to read aloud - button_element: The menu button to update based on state - """ - if not text or not text.strip(): - return - - # If already playing, stop instead - if self._tts_controller.is_playing(): - self._tts_controller.stop() - return - - # Set up state change callback to update button text - def on_state_change(state: str) -> None: - try: - if state == "loading": - button_element.text = "Loading..." - button_element.classList.add("tts-loading") - button_element.classList.remove("tts-playing") - elif state == "playing": - button_element.text = "Stop reading" - button_element.classList.remove("tts-loading") - button_element.classList.add("tts-playing") - else: - button_element.text = "Read aloud" - button_element.classList.remove("tts-loading") - button_element.classList.remove("tts-playing") - except Exception: - pass - - # Set up error callback to show message to user - def on_error(message: str) -> None: - self._print_system_message_in_chat(message) - - self._tts_controller.on_state_change = on_state_change - self._tts_controller.on_error = on_error - - # Strip markdown formatting for cleaner TTS (basic cleanup) - clean_text = self._strip_markdown_for_tts(text) - - # Start TTS - self._tts_controller.speak(clean_text) - - def _strip_markdown_for_tts(self, text: str) -> str: - """Strip markdown formatting from text for cleaner TTS output. 
- - Args: - text: Text with potential markdown formatting - - Returns: - Clean text suitable for TTS - """ - result = text - - # Remove code blocks - result = re.sub(r"```[\s\S]*?```", "", result) - result = re.sub(r"`[^`]+`", "", result) - - # Remove headers - result = re.sub(r"^#{1,6}\s+", "", result, flags=re.MULTILINE) - - # Remove bold/italic - result = re.sub(r"\*\*([^*]+)\*\*", r"\1", result) - result = re.sub(r"\*([^*]+)\*", r"\1", result) - result = re.sub(r"__([^_]+)__", r"\1", result) - result = re.sub(r"_([^_]+)_", r"\1", result) - - # Remove links, keep text - result = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", result) - - # Remove images - result = re.sub(r"!\[[^\]]*\]\([^)]+\)", "", result) - - # Remove horizontal rules - result = re.sub(r"^[-*_]{3,}$", "", result, flags=re.MULTILINE) - - # Clean up extra whitespace - result = re.sub(r"\n{3,}", "\n\n", result) - result = result.strip() - - return result - - def _show_tts_settings_modal(self) -> None: - """Display the TTS settings modal dialog.""" - # Remove existing modal if present - self._close_tts_settings_modal() - - # Create modal backdrop - modal = html.DIV(Class="tts-settings-modal") - modal.id = "tts-settings-modal" - - # Create modal content - content = html.DIV(Class="tts-settings-content") - - # Header - header = html.DIV(Class="tts-settings-header") - title = html.H3("TTS Settings") - close_btn = html.BUTTON("\u00d7", Class="tts-settings-close") - close_btn.attrs["type"] = "button" - close_btn.attrs["title"] = "Close" - header <= title - header <= close_btn - content <= header - - # Voice selection - voice_group = html.DIV(Class="tts-settings-group") - voice_label = html.LABEL("Voice:") - voice_select = html.SELECT(id="tts-voice-select") - - # Add voice options - # Note: Voice IDs must match TTSManager.VOICES in static/tts_manager.py - voices = [ - ("am_michael", "Michael (Male)"), - ("am_fenrir", "Fenrir (Male, deeper)"), - ("am_onyx", "Onyx (Male, darker)"), - ("am_echo", "Echo (Male, 
resonant)"), - ("af_nova", "Nova (Female)"), - ("af_bella", "Bella (Female, warm)"), - ] - current_voice = self._tts_controller.get_voice() - for voice_id, voice_name in voices: - option = html.OPTION(voice_name, value=voice_id) - if voice_id == current_voice: - option.attrs["selected"] = "selected" - voice_select <= option - - voice_group <= voice_label - voice_group <= voice_select - content <= voice_group - - # Buttons - buttons = html.DIV(Class="tts-settings-buttons") - save_btn = html.BUTTON("Save", Class="tts-settings-save") - save_btn.attrs["type"] = "button" - cancel_btn = html.BUTTON("Cancel", Class="tts-settings-cancel") - cancel_btn.attrs["type"] = "button" - buttons <= save_btn - buttons <= cancel_btn - content <= buttons - - modal <= content - - # Bind events - def on_close(ev: Any) -> None: - self._close_tts_settings_modal() - - def on_save(ev: Any) -> None: - try: - voice_value = document["tts-voice-select"].value - self._tts_controller.set_voice(voice_value) - except Exception as e: - print(f"Error saving TTS settings: {e}") - self._close_tts_settings_modal() - - def on_backdrop_click(ev: Any) -> None: - if ev.target == modal: - self._close_tts_settings_modal() - - close_btn.bind("click", on_close) - cancel_btn.bind("click", on_close) - save_btn.bind("click", on_save) - modal.bind("click", on_backdrop_click) - - # Add to document - document <= modal - self._tts_settings_modal = modal - - def _close_tts_settings_modal(self) -> None: - """Close and remove the TTS settings modal.""" - try: - if self._tts_settings_modal: - self._tts_settings_modal.remove() - self._tts_settings_modal = None - # Also try by ID in case reference was lost - existing = document.select_one("#tts-settings-modal") - if existing: - existing.remove() - except Exception: - pass - def _create_message_element( self, sender: str, diff --git a/static/client/tts_ui_manager.py b/static/client/tts_ui_manager.py new file mode 100644 index 00000000..28ebfe18 --- /dev/null +++ 
b/static/client/tts_ui_manager.py @@ -0,0 +1,222 @@ +"""TTS UI manager for the AI interface. + +Manages TTS read-aloud actions, markdown stripping for speech, and +the TTS settings modal dialog. + +Extracted from ``AIInterface`` to reduce god-class complexity while +preserving the identical public behaviour. +""" + +from __future__ import annotations + +import re +from typing import Any, Callable, Optional + +from browser import document, html + +from tts_controller import get_tts_controller, TTSController + + +class TTSUIManager: + """Manages TTS UI interactions: read-aloud, settings modal, and text cleanup. + + Attributes: + _tts_controller: The browser-side TTS playback controller. + _settings_modal: The currently open settings modal DOM element, or ``None``. + _on_system_message: Optional callback to display system messages in chat. + """ + + def __init__(self, on_system_message: Optional[Callable[[str], None]] = None) -> None: + self._tts_controller: TTSController = get_tts_controller() + self._settings_modal: Optional[Any] = None + self._on_system_message = on_system_message + + # ── Read aloud ─────────────────────────────────────────────── + + def handle_read_aloud(self, text: str, button_element: Any) -> None: + """Handle TTS read aloud action. + + Args: + text: Text to read aloud + button_element: The menu button to update based on state + """ + if not text or not text.strip(): + return + + # If already playing, stop instead + if self._tts_controller.is_playing(): + self._tts_controller.stop() + return + + # Set up state change callback to update button text + def on_state_change(state: str) -> None: + try: + if state == "loading": + button_element.text = "Loading..." 
+ button_element.classList.add("tts-loading") + button_element.classList.remove("tts-playing") + elif state == "playing": + button_element.text = "Stop reading" + button_element.classList.remove("tts-loading") + button_element.classList.add("tts-playing") + else: + button_element.text = "Read aloud" + button_element.classList.remove("tts-loading") + button_element.classList.remove("tts-playing") + except Exception: + pass + + # Set up error callback to show message to user + def on_error(message: str) -> None: + if self._on_system_message is not None: + self._on_system_message(message) + + self._tts_controller.on_state_change = on_state_change + self._tts_controller.on_error = on_error + + # Strip markdown formatting for cleaner TTS (basic cleanup) + clean_text = self.strip_markdown(text) + + # Start TTS + self._tts_controller.speak(clean_text) + + # ── Markdown stripping ─────────────────────────────────────── + + def strip_markdown(self, text: str) -> str: + """Strip markdown formatting from text for cleaner TTS output. 
+ + Args: + text: Text with potential markdown formatting + + Returns: + Clean text suitable for TTS + """ + result = text + + # Remove code blocks + result = re.sub(r"```[\s\S]*?```", "", result) + result = re.sub(r"`[^`]+`", "", result) + + # Remove headers + result = re.sub(r"^#{1,6}\s+", "", result, flags=re.MULTILINE) + + # Remove bold/italic + result = re.sub(r"\*\*([^*]+)\*\*", r"\1", result) + result = re.sub(r"\*([^*]+)\*", r"\1", result) + result = re.sub(r"__([^_]+)__", r"\1", result) + result = re.sub(r"_([^_]+)_", r"\1", result) + + # Remove links, keep text + result = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", result) + + # Remove images + result = re.sub(r"!\[[^\]]*\]\([^)]+\)", "", result) + + # Remove horizontal rules + result = re.sub(r"^[-*_]{3,}$", "", result, flags=re.MULTILINE) + + # Clean up extra whitespace + result = re.sub(r"\n{3,}", "\n\n", result) + result = result.strip() + + return result + + # ── Settings modal ─────────────────────────────────────────── + + def show_settings_modal(self) -> None: + """Display the TTS settings modal dialog.""" + # Remove existing modal if present + self._close_settings_modal() + + # Create modal backdrop + modal = html.DIV(Class="tts-settings-modal") + modal.id = "tts-settings-modal" + + # Create modal content + content = html.DIV(Class="tts-settings-content") + + # Header + header = html.DIV(Class="tts-settings-header") + title = html.H3("TTS Settings") + close_btn = html.BUTTON("\u00d7", Class="tts-settings-close") + close_btn.attrs["type"] = "button" + close_btn.attrs["title"] = "Close" + header <= title + header <= close_btn + content <= header + + # Voice selection + voice_group = html.DIV(Class="tts-settings-group") + voice_label = html.LABEL("Voice:") + voice_select = html.SELECT(id="tts-voice-select") + + # Add voice options + # Note: Voice IDs must match TTSManager.VOICES in static/tts_manager.py + voices = [ + ("am_michael", "Michael (Male)"), + ("am_fenrir", "Fenrir (Male, deeper)"), + 
("am_onyx", "Onyx (Male, darker)"), + ("am_echo", "Echo (Male, resonant)"), + ("af_nova", "Nova (Female)"), + ("af_bella", "Bella (Female, warm)"), + ] + current_voice = self._tts_controller.get_voice() + for voice_id, voice_name in voices: + option = html.OPTION(voice_name, value=voice_id) + if voice_id == current_voice: + option.attrs["selected"] = "selected" + voice_select <= option + + voice_group <= voice_label + voice_group <= voice_select + content <= voice_group + + # Buttons + buttons = html.DIV(Class="tts-settings-buttons") + save_btn = html.BUTTON("Save", Class="tts-settings-save") + save_btn.attrs["type"] = "button" + cancel_btn = html.BUTTON("Cancel", Class="tts-settings-cancel") + cancel_btn.attrs["type"] = "button" + buttons <= save_btn + buttons <= cancel_btn + content <= buttons + + modal <= content + + # Bind events + def on_close(ev: Any) -> None: + self._close_settings_modal() + + def on_save(ev: Any) -> None: + try: + voice_value = document["tts-voice-select"].value + self._tts_controller.set_voice(voice_value) + except Exception as e: + print(f"Error saving TTS settings: {e}") + self._close_settings_modal() + + def on_backdrop_click(ev: Any) -> None: + if ev.target == modal: + self._close_settings_modal() + + close_btn.bind("click", on_close) + cancel_btn.bind("click", on_close) + save_btn.bind("click", on_save) + modal.bind("click", on_backdrop_click) + + # Add to document + document <= modal + self._settings_modal = modal + + def _close_settings_modal(self) -> None: + """Close and remove the TTS settings modal.""" + try: + if self._settings_modal: + self._settings_modal.remove() + self._settings_modal = None + # Also try by ID in case reference was lost + existing = document.select_one("#tts-settings-modal") + if existing: + existing.remove() + except Exception: + pass From ac6b042452c378776d9310fa0cd93fa4468d2a7a Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:40:50 +0200 
Subject: [PATCH 21/28] Extract ChatUIManager from AIInterface Move streaming state, message rendering, markdown parsing, and token handlers into a dedicated ChatUIManager class (605 lines). AIInterface drops from 1,547 to 1,078 lines, retaining streaming orchestration, request building, and public API delegation. Cross-concern communication via injected callbacks for timeouts, image clicks, and system messages. --- static/client/ai_interface.py | 544 ++-------------- static/client/chat_ui_manager.py | 605 ++++++++++++++++++ .../client_tests/test_error_recovery.py | 36 +- .../client/client_tests/test_tool_call_log.py | 30 +- 4 files changed, 669 insertions(+), 546 deletions(-) create mode 100644 static/client/chat_ui_manager.py diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index 247ef842..b1fd6e21 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -25,7 +25,7 @@ - function_registry: Available AI function mappings - process_function_calls: Function execution coordination - workspace_manager: File persistence operations - - markdown_parser: Rich text formatting support + - chat_ui_manager: Chat message rendering and streaming display """ from __future__ import annotations @@ -34,7 +34,7 @@ import traceback from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, cast -from browser import document, html, ajax, window, console, aio +from browser import document, ajax, window, console, aio from constants import ( AI_RESPONSE_TIMEOUT_MS, REASONING_TIMEOUT_MS, @@ -42,13 +42,13 @@ from function_registry import FunctionRegistry from process_function_calls import ProcessFunctionCalls from workspace_manager import WorkspaceManager -from markdown_parser import MarkdownParser from tool_call_log_manager import ToolCallLogManager from message_menu_manager import MessageMenuManager from image_attachment_manager import ImageAttachmentManager from slash_command_handler import SlashCommandHandler from 
command_autocomplete import CommandAutocomplete from tts_ui_manager import TTSUIManager +from chat_ui_manager import ChatUIManager from managers.action_trace_collector import ActionTraceCollector if TYPE_CHECKING: @@ -67,7 +67,7 @@ class AIInterface: is_processing (bool): Tracks whether an AI request is currently being processed available_functions (dict): Registry of all functions available to the AI undoable_functions (tuple): Functions that support undo/redo operations - markdown_parser (MarkdownParser): Converts markdown text to HTML for rich formatting + _chat_ui (ChatUIManager): Manages chat message rendering and streaming display """ def __init__(self, canvas: "Canvas") -> None: @@ -89,23 +89,10 @@ def __init__(self, canvas: "Canvas") -> None: canvas, self.workspace_manager, self ) self.undoable_functions: tuple[str, ...] = FunctionRegistry.get_undoable_functions() - self.markdown_parser: MarkdownParser = MarkdownParser() # Slash command handler for local commands self.slash_command_handler: SlashCommandHandler = SlashCommandHandler(canvas, self.workspace_manager, self) # Command autocomplete popup (initialized lazily when DOM is ready) self.command_autocomplete: Optional[CommandAutocomplete] = None - # Streaming state - self._stream_buffer: str = "" - self._stream_content_element: Optional[Any] = None # DOMNode - self._stream_message_container: Optional[Any] = None # DOMNode - # Reasoning streaming state - self._reasoning_buffer: str = "" - self._reasoning_element: Optional[Any] = None # DOMNode - self._reasoning_details: Optional[Any] = None # DOMNode (details element) - self._reasoning_summary: Optional[Any] = None # DOMNode (summary element) - self._is_reasoning: bool = False - self._request_start_time: Optional[float] = None # Timestamp when user request started - self._needs_continuation_separator: bool = False # Add newline before next text after tool calls # Tool call log state (delegated to ToolCallLogManager) self._tool_call_log = 
ToolCallLogManager() # Timeout state @@ -121,6 +108,14 @@ def __init__(self, canvas: "Canvas") -> None: self._image_attachment = ImageAttachmentManager( on_system_message=self._print_system_message_in_chat, ) + # Chat UI (delegated to ChatUIManager) + self._chat_ui = ChatUIManager( + message_menu=self._message_menu, + tool_call_log=self._tool_call_log, + on_image_click=self._image_attachment.show_modal, + on_start_timeout=lambda use_reasoning: self._start_response_timeout(use_reasoning_timeout=use_reasoning), + on_cancel_timeout=self._cancel_response_timeout, + ) # Message recovery state self._last_user_message: str = "" # Buffered message for recovery on error # Action trace collector for deterministic tool-execution logs @@ -308,156 +303,9 @@ def _store_results_in_canvas_state(self, call_results: Dict[str, Any]) -> None: # result=formatted_value # ) - def _parse_markdown_to_html(self, text: str) -> str: - """Parse markdown text to HTML using the dedicated markdown parser.""" - return cast(str, self.markdown_parser.parse(text)) - - def _render_math(self) -> None: - """Trigger MathJax rendering for newly added content.""" - try: - # Check if MathJax is available - if hasattr(window, "MathJax") and hasattr(window.MathJax, "typesetPromise"): - # Re-render math in the chat history - window.MathJax.typesetPromise([document["chat-history"]]) - except Exception: - # MathJax not available or error occurred, continue silently - pass - - def _create_message_element( - self, - sender: str, - message: str, - message_type: str = "normal", - images: Optional[list[str]] = None, - ) -> Any: # DOMNode - """Create a styled message element with markdown support and optional images. 
- - Args: - sender: The message sender ("User" or "AI") - message: The message text content - message_type: CSS class for message styling ("normal", "system") - images: Optional list of image data URLs to display with the message - - Returns: - DOM element for the message - """ - try: - # Create message container - message_container = html.DIV(Class=f"chat-message {message_type}") - - # Create sender label - sender_label = html.SPAN(f"{sender}: ", Class=f"chat-sender {sender.lower()}") - - # Parse markdown and create content element - if sender == "AI": - parsed_content = self._parse_markdown_to_html(message) - content_element = html.DIV(Class="chat-content markdown") - content_element.innerHTML = parsed_content - else: - # For user messages, keep them as plain text for now - content_element = html.SPAN(message, Class="chat-content") - - # Assemble the message - message_container <= sender_label - message_container <= content_element - - # Add images if provided - if images: - images_container = html.DIV(Class="chat-message-images") - for data_url in images: - img = html.IMG(src=data_url, Class="chat-message-image") - img.attrs["alt"] = "Attached image" - - # Bind click to show modal - def make_image_click_handler(url: str) -> Any: - def handler(event: Any) -> None: - self._image_attachment.show_modal(url) - - return handler - - img.bind("click", make_image_click_handler(data_url)) - images_container <= img - message_container <= images_container - - # Store the raw source text for copy actions (do not rely on rendered HTML) - self._message_menu.set_raw_text(message_container, message) - self._message_menu.attach(message_container, is_ai_message=(sender == "AI")) - - return message_container - - except Exception as e: - print(f"Error creating message element: {e}") - # Fall back to simple paragraph - if sender == "AI": - content = message.replace("\n", "
") - return html.P(f"{sender}: {content}", innerHTML=True) - else: - return html.P(f"{sender}: {message}") - def _print_ai_message_in_chat(self, ai_message: str) -> None: - """Print an AI message to the chat history with markdown support and scroll to bottom.""" - if ai_message: - message_element = self._create_message_element("AI", ai_message) - document["chat-history"] <= message_element - # Trigger MathJax rendering for new content - self._render_math() - # Scroll the chat history to the bottom - document["chat-history"].scrollTop = document["chat-history"].scrollHeight - - def _ensure_stream_message_element(self) -> None: - """Create the streaming AI message element if it does not exist yet.""" - if self._stream_content_element is None: - try: - container = html.DIV(Class="chat-message normal") - label = html.SPAN("AI: ", Class="chat-sender ai") - content = html.DIV(Class="chat-content") - content.text = "" - container <= label - container <= content - document["chat-history"] <= container - self._stream_message_container = container - self._stream_content_element = content - # Initialize raw text storage for streaming content - self._message_menu.set_raw_text(container, "") - self._message_menu.attach(container, is_ai_message=True) - except Exception as e: - print(f"Error creating streaming element: {e}") - - def _ensure_reasoning_element(self) -> None: - """Create the reasoning dropdown element inside the AI message box.""" - if self._reasoning_element is None: - try: - container = html.DIV(Class="chat-message normal") - label = html.SPAN("AI: ", Class="chat-sender ai") - - # Collapsible dropdown for reasoning - details = html.DETAILS(Class="reasoning-dropdown") - # Start collapsed by default (user can expand if curious) - summary = html.SUMMARY("Thinking...", Class="reasoning-summary") - reasoning_content = html.DIV(Class="reasoning-content") - reasoning_content.text = "" - details <= summary - details <= reasoning_content - - # Content area for the actual 
response (hidden initially) - response_content = html.DIV(Class="chat-content") - response_content.text = "" - - container <= label - container <= details - container <= response_content - document["chat-history"] <= container - - self._reasoning_element = reasoning_content - self._reasoning_details = details - self._reasoning_summary = summary - self._stream_message_container = container - self._stream_content_element = response_content - # Initialize raw text storage for reasoning responses - self._message_menu.set_raw_text(container, "") - self._message_menu.attach(container, is_ai_message=True) - except Exception as e: - print(f"Error creating reasoning element: {e}") + """Print an AI message to the chat history (delegates to ChatUIManager).""" + self._chat_ui.print_ai_message(ai_message) def _on_stream_log(self, event_obj: Any) -> None: """Handle a server log event: output to browser console with appropriate level.""" @@ -480,189 +328,20 @@ def _on_stream_log(self, event_obj: Any) -> None: print(f"Error handling server log event: {e}") def _on_stream_reasoning(self, text: str) -> None: - """Handle a reasoning token: append to reasoning buffer and update UI.""" - try: - # Use extended timeout for reasoning phase - self._start_response_timeout(use_reasoning_timeout=True) - self._is_reasoning = True - - # Don't repeat the placeholder if we already have it - if "(Reasoning in progress...)" in text and "(Reasoning in progress...)" in self._reasoning_buffer: - return - - self._reasoning_buffer += text - self._ensure_reasoning_element() - if self._reasoning_element is not None: - self._reasoning_element.text = self._reasoning_buffer - document["chat-history"].scrollTop = document["chat-history"].scrollHeight - except Exception as e: - print(f"Error handling reasoning token: {e}") + """Handle a reasoning token (delegates to ChatUIManager).""" + self._chat_ui.on_stream_reasoning(text) def _on_stream_token(self, text: str) -> None: - """Handle a streamed token: append 
to buffer and update the UI element.""" - try: - # Reset timeout since we're receiving data (use normal timeout for response) - self._start_response_timeout(use_reasoning_timeout=False) - - # If we were in reasoning phase, collapse the reasoning dropdown - if self._is_reasoning and self._reasoning_details is not None: - try: - del self._reasoning_details.attrs["open"] - except Exception: - try: - self._reasoning_details.attrs["open"] = False - except Exception: - pass - self._is_reasoning = False - - # When continuing after tool calls, clear the buffer and start fresh - # The AI will re-state any necessary context in its new response - # This prevents duplication when AI restates previous confirmations - if self._needs_continuation_separator: - self._stream_buffer = "" - self._needs_continuation_separator = False - - self._stream_buffer += text - # Use reasoning element's response area if it exists, otherwise create normal element - if self._stream_content_element is None and self._reasoning_element is None: - self._ensure_stream_message_element() - if self._stream_content_element is not None: - self._stream_content_element.text = self._stream_buffer - if self._stream_message_container is not None: - self._message_menu.set_raw_text(self._stream_message_container, self._stream_buffer) - document["chat-history"].scrollTop = document["chat-history"].scrollHeight - except Exception as e: - print(f"Error handling stream token: {e}") + """Handle a streamed token (delegates to ChatUIManager).""" + self._chat_ui.on_stream_token(text) def _finalize_stream_message(self, final_message: Optional[str] = None) -> None: - """Convert the streamed plain text to parsed markdown and render math.""" - try: - self._tool_call_log.finalize() - - # Prefer the accumulated buffer (contains all text across tool calls) - # Only use final_message as fallback if buffer is empty - text_to_render = self._stream_buffer if self._stream_buffer.strip() else (final_message or "") - - # If we have 
reasoning content and actual text, create a combined element - if self._reasoning_buffer and self._stream_message_container is not None: - # Preserve raw source for copy actions - self._message_menu.set_raw_text(self._stream_message_container, text_to_render) - if text_to_render and self._stream_content_element is not None: - # Update the response content with parsed markdown - parsed_content = self._parse_markdown_to_html(text_to_render) - self._stream_content_element.innerHTML = parsed_content - self._stream_content_element.classList.add("markdown") - - # Update summary to show elapsed time and ensure dropdown stays closed - if self._reasoning_summary is not None and self._request_start_time is not None: - try: - from browser import window - - elapsed_ms = window.Date.now() - self._request_start_time - elapsed_seconds = int(elapsed_ms / 1000) - self._reasoning_summary.text = f"Thought for {elapsed_seconds} seconds" - except Exception: - pass - - # Ensure dropdown is closed - if self._reasoning_details is not None: - try: - del self._reasoning_details.attrs["open"] - except Exception: - try: - self._reasoning_details.attrs["open"] = False - except Exception: - pass - - self._render_math() - document["chat-history"].scrollTop = document["chat-history"].scrollHeight - else: - # Reasoning but no text content - remove the empty container - self._remove_empty_response_container() - elif text_to_render: - if self._tool_call_log.element is not None and self._stream_message_container is not None: - # Tool call log exists — update the container in place to preserve the dropdown - self._message_menu.set_raw_text(self._stream_message_container, text_to_render) - if self._stream_content_element is not None: - parsed_content = self._parse_markdown_to_html(text_to_render) - self._stream_content_element.innerHTML = parsed_content - self._stream_content_element.classList.add("markdown") - self._render_math() - document["chat-history"].scrollTop = 
document["chat-history"].scrollHeight - else: - # No reasoning or tool log, use standard finalization - final_element = self._create_message_element("AI", text_to_render) - - history = document["chat-history"] - if self._stream_message_container is not None: - try: - history.replaceChild(final_element, self._stream_message_container) - except Exception: - history <= final_element - else: - history <= final_element - - self._render_math() - history.scrollTop = history.scrollHeight - else: - # No text content at all - remove any empty container - self._remove_empty_response_container() - except Exception as e: - print(f"Error finalizing stream message: {e}") - finally: - self._stream_buffer = "" - self._stream_content_element = None - self._stream_message_container = None - self._reasoning_buffer = "" - self._reasoning_element = None - self._reasoning_details = None - self._reasoning_summary = None - self._is_reasoning = False - self._request_start_time = None - self._tool_call_log.reset() + """Finalize the streamed message (delegates to ChatUIManager).""" + self._chat_ui.finalize_stream(final_message) def _remove_empty_response_container(self) -> None: - """Remove the current response container if it has no actual text content. - - This cleans up "Thinking..." boxes when the AI only performs tool calls - without providing a text response. Never removes a container with actual text. 
- """ - try: - # Check if there's actual text content in buffer or visible in the element - has_buffer_text = bool(self._stream_buffer.strip()) - has_element_text = False - if self._stream_content_element is not None: - try: - element_text = self._stream_content_element.text or self._stream_content_element.innerHTML or "" - has_element_text = bool(element_text.strip()) - except Exception: - pass - has_tool_call_log = bool(self._tool_call_log.entries) - - # Only remove if there's NO actual text content anywhere and no tool call log - if ( - self._stream_message_container is not None - and not has_buffer_text - and not has_element_text - and not has_tool_call_log - ): - history = document["chat-history"] - try: - history.removeChild(self._stream_message_container) - except Exception: - pass - # Reset state - self._stream_message_container = None - self._stream_content_element = None - self._reasoning_element = None - self._reasoning_details = None - self._reasoning_summary = None - self._reasoning_buffer = "" - self._is_reasoning = False - self._tool_call_log.reset() - # Don't reset _request_start_time here - we want to keep timing across tool calls - except Exception as e: - print(f"Error removing empty container: {e}") + """Remove empty response container (delegates to ChatUIManager).""" + self._chat_ui.remove_empty_container() def _on_stream_final(self, event_obj: Any) -> None: """Handle the final event from the streaming response.""" @@ -680,8 +359,8 @@ def _on_stream_final(self, event_obj: Any) -> None: # If no tool calls OR finish reason indicates completion, finalize the message if finish_reason in ("stop", "error", "completed") or not ai_tool_calls: - if not self._stream_buffer and ai_message: - self._stream_buffer = ai_message + if not self._chat_ui.stream_buffer and ai_message: + self._chat_ui.stream_buffer = ai_message self._finalize_stream_message(ai_message or None) # Restore user message on error so they can retry if finish_reason == "error": @@ -705,9 
+384,9 @@ def _on_stream_final(self, event_obj: Any) -> None: self.canvas, ) self._store_results_in_canvas_state(call_results) - if self._stream_message_container is None: - self._ensure_stream_message_element() - self._tool_call_log.ensure_element(self._stream_message_container, self._stream_content_element) + if self._chat_ui.stream_container is None: + self._chat_ui.ensure_stream_element() + self._tool_call_log.ensure_element(self._chat_ui.stream_container, self._chat_ui.stream_content) self._tool_call_log.add_entries(ai_tool_calls, call_results) if self._stop_requested: @@ -743,8 +422,8 @@ def _on_stream_final(self, event_obj: Any) -> None: # Reset timeout with extended duration - AI needs time to process tool results self._start_response_timeout(use_reasoning_timeout=True) # Mark that we need a newline separator before the next text - if self._stream_buffer.strip(): - self._needs_continuation_separator = True + if self._chat_ui.stream_buffer.strip(): + self._chat_ui.needs_continuation_separator = True self._send_prompt_to_ai( None, json.dumps(call_results), @@ -854,139 +533,12 @@ def _normalize_stream_event(self, event_obj: Any) -> Dict[str, Any]: return {} def _print_user_message_in_chat(self, user_message: str, images: Optional[list[str]] = None) -> None: - """Print a user message to the chat history and scroll to bottom. 
- - Args: - user_message: The text message from the user - images: Optional list of image data URLs to display with the message - """ - # Add the user's message to the chat history with markdown support - message_element = self._create_message_element("User", user_message, images=images) - document["chat-history"] <= message_element - # Trigger MathJax rendering for new content - self._render_math() - # Scroll the chat history to the bottom - document["chat-history"].scrollTop = document["chat-history"].scrollHeight + """Print a user message to the chat history (delegates to ChatUIManager).""" + self._chat_ui.print_user_message(user_message, images) def _print_system_message_in_chat(self, message: str) -> None: - """Print a system/command response to the chat history. - - Used for slash command responses that don't come from AI. - - Args: - message: The message to display (supports markdown) - """ - try: - # Create message container with system styling - message_container = html.DIV(Class="chat-message system") - - # Create sender label - sender_label = html.SPAN("System: ", Class="chat-sender system") - - # Check if message is long and needs expandable display - line_count = message.count("\n") - is_long_message = len(message) > 800 or line_count > 20 - - if is_long_message: - # Create expandable content with details/summary - content_element = self._create_expandable_content(message) - else: - # Parse markdown and create content element - parsed_content = self._parse_markdown_to_html(message) - content_element = html.DIV(Class="chat-content markdown") - content_element.innerHTML = parsed_content - - # Assemble the message - message_container <= sender_label - message_container <= content_element - - # Store raw text for copy actions - self._message_menu.set_raw_text(message_container, message) - self._message_menu.attach(message_container) - - # Add to chat history - document["chat-history"] <= message_container - - # Trigger MathJax rendering for new content - 
self._render_math() - - # Scroll to bottom - document["chat-history"].scrollTop = document["chat-history"].scrollHeight - except Exception as e: - print(f"Error printing system message: {e}") - # Fallback to simple paragraph - fallback = html.P(f"System: {message}") - document["chat-history"] <= fallback - - def _create_expandable_content(self, message: str) -> Any: - """Create an expandable content element for long messages. - - Args: - message: The full message content - - Returns: - A DOM element with expandable content - """ - # Create preview (first ~500 chars or 10 lines) - lines = message.split("\n") - if len(lines) > 10: - preview_text = "\n".join(lines[:10]) + "\n..." - elif len(message) > 500: - preview_text = message[:500] + "..." - else: - preview_text = message - - # Create container - container = html.DIV(Class="chat-content expandable-content") - - # Create preview section - preview = html.DIV(Class="content-preview") - preview.innerHTML = f"<pre>{self._escape_html(preview_text)}</pre>" - - # Create full content section (hidden initially) - full_content = html.DIV(Class="content-full", style={"display": "none"}) - full_content.innerHTML = f"<pre>{self._escape_html(message)}</pre>" - - # Create toggle button - toggle_btn = html.BUTTON("Show more", Class="expand-toggle-btn") - - def toggle_content(event: Any) -> None: - try: - if full_content.style.display == "none": - preview.style.display = "none" - full_content.style.display = "block" - toggle_btn.text = "Show less" - else: - preview.style.display = "block" - full_content.style.display = "none" - toggle_btn.text = "Show more" - except Exception: - pass - - toggle_btn.bind("click", toggle_content) - - container <= preview - container <= full_content - container <= toggle_btn - - return container - - def _escape_html(self, text: str) -> str: - """Escape HTML special characters. - - Args: - text: Text to escape - - Returns: - Escaped text safe for HTML - """ - return ( - text.replace("&", "&amp;") - .replace("<", "&lt;") - .replace(">", "&gt;") - .replace('"', "&quot;") - .replace("'", "&#x27;") - ) + """Print a system message to the chat history (delegates to ChatUIManager).""" + self._chat_ui.print_system_message(message) def _debug_log_ai_response(self, ai_message: str, ai_function_calls: Any, finish_reason: str) -> None: """Log debug information about the AI response.""" @@ -1076,7 +628,7 @@ def stop_ai_processing(self) -> None: # Always notify the backend so it can clear stale conversation state # (e.g. previous_response_id pointing to unanswered tool calls). # The backend handles empty text gracefully.
- self._save_partial_response(self._stream_buffer or "") + self._save_partial_response(self._chat_ui.stream_buffer or "") self._finalize_stream_message() self._print_system_message_in_chat("Generation stopped.") self._enable_send_controls() @@ -1326,18 +878,8 @@ def _send_prompt_to_ai_stream( # For new user messages, reset all state including containers and buffers # For tool call results, preserve everything to keep intermediary text visible if user_message is not None and tool_call_results is None: - self._request_start_time = window.Date.now() - # Reset all streaming state for new conversation turn - self._stream_buffer = "" - self._stream_content_element = None - self._stream_message_container = None - self._reasoning_buffer = "" - self._reasoning_element = None - self._reasoning_details = None - self._reasoning_summary = None - self._is_reasoning = False - self._needs_continuation_separator = False - self._tool_call_log.reset() + self._chat_ui.request_start_time = window.Date.now() + self._chat_ui.reset_streaming_state() try: payload = self._create_request_payload(prompt, include_svg=True) @@ -1379,18 +921,8 @@ def _send_prompt_to_ai( # For new user messages, reset all state including containers and buffers # For tool call results, preserve everything to keep intermediary text visible if user_message is not None and tool_call_results is None: - self._request_start_time = window.Date.now() - # Reset all streaming state for new conversation turn - self._stream_buffer = "" - self._stream_content_element = None - self._stream_message_container = None - self._reasoning_buffer = "" - self._reasoning_element = None - self._reasoning_details = None - self._reasoning_summary = None - self._is_reasoning = False - self._needs_continuation_separator = False - self._tool_call_log.reset() + self._chat_ui.request_start_time = window.Date.now() + self._chat_ui.reset_streaming_state() self._send_request(prompt, action_trace=action_trace) diff --git 
a/static/client/chat_ui_manager.py b/static/client/chat_ui_manager.py new file mode 100644 index 00000000..92c71af8 --- /dev/null +++ b/static/client/chat_ui_manager.py @@ -0,0 +1,605 @@ +"""Chat UI manager for the AI interface. + +Manages message rendering, streaming token display, markdown parsing, +and MathJax rendering for the chat interface. Owns all DOM manipulation +for chat messages (user, AI, system) and the streaming response lifecycle. + +Extracted from ``AIInterface`` to reduce god-class complexity while +preserving the identical public behaviour. +""" + +from __future__ import annotations + +from typing import Any, Callable, Optional, cast + +from browser import document, html, window + +from markdown_parser import MarkdownParser +from message_menu_manager import MessageMenuManager +from tool_call_log_manager import ToolCallLogManager + + +class ChatUIManager: + """Manages chat message rendering, streaming display, and markdown formatting. + + Attributes: + markdown_parser: Converts markdown text to HTML for rich formatting. 
+ """ + + def __init__( + self, + message_menu: MessageMenuManager, + tool_call_log: ToolCallLogManager, + on_image_click: Optional[Callable[[str], None]] = None, + on_start_timeout: Optional[Callable[[bool], None]] = None, + on_cancel_timeout: Optional[Callable[[], None]] = None, + ) -> None: + self._message_menu = message_menu + self._tool_call_log = tool_call_log + self._on_image_click = on_image_click + self._on_start_timeout = on_start_timeout + self._on_cancel_timeout = on_cancel_timeout + + # Markdown parser + self.markdown_parser: MarkdownParser = MarkdownParser() + + # Streaming state + self._stream_buffer: str = "" + self._stream_content_element: Optional[Any] = None # DOMNode + self._stream_message_container: Optional[Any] = None # DOMNode + + # Reasoning streaming state + self._reasoning_buffer: str = "" + self._reasoning_element: Optional[Any] = None # DOMNode + self._reasoning_details: Optional[Any] = None # DOMNode (details element) + self._reasoning_summary: Optional[Any] = None # DOMNode (summary element) + self._is_reasoning: bool = False + + self._request_start_time: Optional[float] = None # Timestamp when user request started + self._needs_continuation_separator: bool = False # Add newline before next text after tool calls + + # ── Read-only properties for AIInterface access ────────────── + + @property + def stream_buffer(self) -> str: + """Return the current streaming text buffer.""" + return self._stream_buffer + + @stream_buffer.setter + def stream_buffer(self, value: str) -> None: + """Set the streaming text buffer.""" + self._stream_buffer = value + + @property + def stream_container(self) -> Optional[Any]: + """Return the current streaming message container DOM element.""" + return self._stream_message_container + + @property + def stream_content(self) -> Optional[Any]: + """Return the current streaming content DOM element.""" + return self._stream_content_element + + @property + def is_reasoning(self) -> bool: + """Return whether the AI 
is currently in reasoning phase.""" + return self._is_reasoning + + @property + def needs_continuation_separator(self) -> bool: + """Return whether a continuation separator is needed before next text.""" + return self._needs_continuation_separator + + @needs_continuation_separator.setter + def needs_continuation_separator(self, value: bool) -> None: + """Set the continuation separator flag.""" + self._needs_continuation_separator = value + + @property + def request_start_time(self) -> Optional[float]: + """Return the timestamp when the current request started.""" + return self._request_start_time + + @request_start_time.setter + def request_start_time(self, value: Optional[float]) -> None: + """Set the timestamp when the current request started.""" + self._request_start_time = value + + # ── Markdown / rendering ───────────────────────────────────── + + def parse_markdown(self, text: str) -> str: + """Parse markdown text to HTML using the dedicated markdown parser.""" + return cast(str, self.markdown_parser.parse(text)) + + def render_math(self) -> None: + """Trigger MathJax rendering for newly added content.""" + try: + # Check if MathJax is available + if hasattr(window, "MathJax") and hasattr(window.MathJax, "typesetPromise"): + # Re-render math in the chat history + window.MathJax.typesetPromise([document["chat-history"]]) + except Exception: + # MathJax not available or error occurred, continue silently + pass + + # ── Message element creation ───────────────────────────────── + + def create_message_element( + self, + sender: str, + message: str, + message_type: str = "normal", + images: Optional[list[str]] = None, + ) -> Any: # DOMNode + """Create a styled message element with markdown support and optional images. 
+ + Args: + sender: The message sender ("User" or "AI") + message: The message text content + message_type: CSS class for message styling ("normal", "system") + images: Optional list of image data URLs to display with the message + + Returns: + DOM element for the message + """ + try: + # Create message container + message_container = html.DIV(Class=f"chat-message {message_type}") + + # Create sender label + sender_label = html.SPAN(f"{sender}: ", Class=f"chat-sender {sender.lower()}") + + # Parse markdown and create content element + if sender == "AI": + parsed_content = self.parse_markdown(message) + content_element = html.DIV(Class="chat-content markdown") + content_element.innerHTML = parsed_content + else: + # For user messages, keep them as plain text for now + content_element = html.SPAN(message, Class="chat-content") + + # Assemble the message + message_container <= sender_label + message_container <= content_element + + # Add images if provided + if images: + images_container = html.DIV(Class="chat-message-images") + for data_url in images: + img = html.IMG(src=data_url, Class="chat-message-image") + img.attrs["alt"] = "Attached image" + + # Bind click to show modal + def make_image_click_handler(url: str) -> Any: + def handler(event: Any) -> None: + if self._on_image_click is not None: + self._on_image_click(url) + + return handler + + img.bind("click", make_image_click_handler(data_url)) + images_container <= img + message_container <= images_container + + # Store the raw source text for copy actions (do not rely on rendered HTML) + self._message_menu.set_raw_text(message_container, message) + self._message_menu.attach(message_container, is_ai_message=(sender == "AI")) + + return message_container + + except Exception as e: + print(f"Error creating message element: {e}") + # Fall back to simple paragraph + if sender == "AI": + content = message.replace("\n", "<br>") + return html.P(f"{sender}: {content}", innerHTML=True) + else: + return html.P(f"{sender}: {message}") + + def print_ai_message(self, ai_message: str) -> None: + """Print an AI message to the chat history with markdown support and scroll to bottom.""" + if ai_message: + message_element = self.create_message_element("AI", ai_message) + document["chat-history"] <= message_element + # Trigger MathJax rendering for new content + self.render_math() + # Scroll the chat history to the bottom + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + + def print_user_message(self, user_message: str, images: Optional[list[str]] = None) -> None: + """Print a user message to the chat history and scroll to bottom. + + Args: + user_message: The text message from the user + images: Optional list of image data URLs to display with the message + """ + # Add the user's message to the chat history with markdown support + message_element = self.create_message_element("User", user_message, images=images) + document["chat-history"] <= message_element + # Trigger MathJax rendering for new content + self.render_math() + # Scroll the chat history to the bottom + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + + def print_system_message(self, message: str) -> None: + """Print a system/command response to the chat history. + + Used for slash command responses that don't come from AI.
+ + Args: + message: The message to display (supports markdown) + """ + try: + # Create message container with system styling + message_container = html.DIV(Class="chat-message system") + + # Create sender label + sender_label = html.SPAN("System: ", Class="chat-sender system") + + # Check if message is long and needs expandable display + line_count = message.count("\n") + is_long_message = len(message) > 800 or line_count > 20 + + if is_long_message: + # Create expandable content with details/summary + content_element = self._create_expandable_content(message) + else: + # Parse markdown and create content element + parsed_content = self.parse_markdown(message) + content_element = html.DIV(Class="chat-content markdown") + content_element.innerHTML = parsed_content + + # Assemble the message + message_container <= sender_label + message_container <= content_element + + # Store raw text for copy actions + self._message_menu.set_raw_text(message_container, message) + self._message_menu.attach(message_container) + + # Add to chat history + document["chat-history"] <= message_container + + # Trigger MathJax rendering for new content + self.render_math() + + # Scroll to bottom + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + except Exception as e: + print(f"Error printing system message: {e}") + # Fallback to simple paragraph + fallback = html.P(f"System: {message}") + document["chat-history"] <= fallback + + def _create_expandable_content(self, message: str) -> Any: + """Create an expandable content element for long messages. + + Args: + message: The full message content + + Returns: + A DOM element with expandable content + """ + # Create preview (first ~500 chars or 10 lines) + lines = message.split("\n") + if len(lines) > 10: + preview_text = "\n".join(lines[:10]) + "\n..." + elif len(message) > 500: + preview_text = message[:500] + "..." 
+ else: + preview_text = message + + # Create container + container = html.DIV(Class="chat-content expandable-content") + + # Create preview section + preview = html.DIV(Class="content-preview") + preview.innerHTML = f"<pre>{self._escape_html(preview_text)}</pre>" + + # Create full content section (hidden initially) + full_content = html.DIV(Class="content-full", style={"display": "none"}) + full_content.innerHTML = f"<pre>{self._escape_html(message)}</pre>" + + # Create toggle button + toggle_btn = html.BUTTON("Show more", Class="expand-toggle-btn") + + def toggle_content(event: Any) -> None: + try: + if full_content.style.display == "none": + preview.style.display = "none" + full_content.style.display = "block" + toggle_btn.text = "Show less" + else: + preview.style.display = "block" + full_content.style.display = "none" + toggle_btn.text = "Show more" + except Exception: + pass + + toggle_btn.bind("click", toggle_content) + + container <= preview + container <= full_content + container <= toggle_btn + + return container + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters. + + Args: + text: Text to escape + + Returns: + Escaped text safe for HTML + """ + return ( + text.replace("&", "&amp;") + .replace("<", "&lt;") + .replace(">", "&gt;") + .replace('"', "&quot;") + .replace("'", "&#x27;") + ) + + # ── Streaming container management ─────────────────────────── + + def ensure_stream_element(self) -> None: + """Create the streaming AI message element if it does not exist yet.""" + if self._stream_content_element is None: + try: + container = html.DIV(Class="chat-message normal") + label = html.SPAN("AI: ", Class="chat-sender ai") + content = html.DIV(Class="chat-content") + content.text = "" + container <= label + container <= content + document["chat-history"] <= container + self._stream_message_container = container + self._stream_content_element = content + # Initialize raw text storage for streaming content + self._message_menu.set_raw_text(container, "") + self._message_menu.attach(container, is_ai_message=True) + except Exception as e: + print(f"Error creating streaming element: {e}") + + def ensure_reasoning_element(self) -> None: + """Create the reasoning dropdown element inside the AI message box.""" + if self._reasoning_element is None: + try: + container = html.DIV(Class="chat-message normal") + label = html.SPAN("AI: ", Class="chat-sender ai") + + # Collapsible dropdown for reasoning + details =
html.DETAILS(Class="reasoning-dropdown") + # Start collapsed by default (user can expand if curious) + summary = html.SUMMARY("Thinking...", Class="reasoning-summary") + reasoning_content = html.DIV(Class="reasoning-content") + reasoning_content.text = "" + details <= summary + details <= reasoning_content + + # Content area for the actual response (hidden initially) + response_content = html.DIV(Class="chat-content") + response_content.text = "" + + container <= label + container <= details + container <= response_content + document["chat-history"] <= container + + self._reasoning_element = reasoning_content + self._reasoning_details = details + self._reasoning_summary = summary + self._stream_message_container = container + self._stream_content_element = response_content + # Initialize raw text storage for reasoning responses + self._message_menu.set_raw_text(container, "") + self._message_menu.attach(container, is_ai_message=True) + except Exception as e: + print(f"Error creating reasoning element: {e}") + + # ── Streaming token handlers ───────────────────────────────── + + def on_stream_token(self, text: str) -> None: + """Handle a streamed token: append to buffer and update the UI element.""" + try: + # Reset timeout since we're receiving data (use normal timeout for response) + if self._on_start_timeout is not None: + self._on_start_timeout(False) + + # If we were in reasoning phase, collapse the reasoning dropdown + if self._is_reasoning and self._reasoning_details is not None: + try: + del self._reasoning_details.attrs["open"] + except Exception: + try: + self._reasoning_details.attrs["open"] = False + except Exception: + pass + self._is_reasoning = False + + # When continuing after tool calls, clear the buffer and start fresh + # The AI will re-state any necessary context in its new response + # This prevents duplication when AI restates previous confirmations + if self._needs_continuation_separator: + self._stream_buffer = "" + 
self._needs_continuation_separator = False + + self._stream_buffer += text + # Use reasoning element's response area if it exists, otherwise create normal element + if self._stream_content_element is None and self._reasoning_element is None: + self.ensure_stream_element() + if self._stream_content_element is not None: + self._stream_content_element.text = self._stream_buffer + if self._stream_message_container is not None: + self._message_menu.set_raw_text(self._stream_message_container, self._stream_buffer) + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + except Exception as e: + print(f"Error handling stream token: {e}") + + def on_stream_reasoning(self, text: str) -> None: + """Handle a reasoning token: append to reasoning buffer and update UI.""" + try: + # Use extended timeout for reasoning phase + if self._on_start_timeout is not None: + self._on_start_timeout(True) + self._is_reasoning = True + + # Don't repeat the placeholder if we already have it + if "(Reasoning in progress...)" in text and "(Reasoning in progress...)" in self._reasoning_buffer: + return + + self._reasoning_buffer += text + self.ensure_reasoning_element() + if self._reasoning_element is not None: + self._reasoning_element.text = self._reasoning_buffer + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + except Exception as e: + print(f"Error handling reasoning token: {e}") + + # ── Stream finalization ────────────────────────────────────── + + def finalize_stream(self, final_message: Optional[str] = None) -> None: + """Convert the streamed plain text to parsed markdown and render math.""" + try: + self._tool_call_log.finalize() + + # Prefer the accumulated buffer (contains all text across tool calls) + # Only use final_message as fallback if buffer is empty + text_to_render = self._stream_buffer if self._stream_buffer.strip() else (final_message or "") + + # If we have reasoning content and actual text, create a combined element + if 
self._reasoning_buffer and self._stream_message_container is not None: + # Preserve raw source for copy actions + self._message_menu.set_raw_text(self._stream_message_container, text_to_render) + if text_to_render and self._stream_content_element is not None: + # Update the response content with parsed markdown + parsed_content = self.parse_markdown(text_to_render) + self._stream_content_element.innerHTML = parsed_content + self._stream_content_element.classList.add("markdown") + + # Update summary to show elapsed time and ensure dropdown stays closed + if self._reasoning_summary is not None and self._request_start_time is not None: + try: + from browser import window + + elapsed_ms = window.Date.now() - self._request_start_time + elapsed_seconds = int(elapsed_ms / 1000) + self._reasoning_summary.text = f"Thought for {elapsed_seconds} seconds" + except Exception: + pass + + # Ensure dropdown is closed + if self._reasoning_details is not None: + try: + del self._reasoning_details.attrs["open"] + except Exception: + try: + self._reasoning_details.attrs["open"] = False + except Exception: + pass + + self.render_math() + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + else: + # Reasoning but no text content - remove the empty container + self.remove_empty_container() + elif text_to_render: + if self._tool_call_log.element is not None and self._stream_message_container is not None: + # Tool call log exists — update the container in place to preserve the dropdown + self._message_menu.set_raw_text(self._stream_message_container, text_to_render) + if self._stream_content_element is not None: + parsed_content = self.parse_markdown(text_to_render) + self._stream_content_element.innerHTML = parsed_content + self._stream_content_element.classList.add("markdown") + self.render_math() + document["chat-history"].scrollTop = document["chat-history"].scrollHeight + else: + # No reasoning or tool log, use standard finalization + final_element = 
self.create_message_element("AI", text_to_render) + + history = document["chat-history"] + if self._stream_message_container is not None: + try: + history.replaceChild(final_element, self._stream_message_container) + except Exception: + history <= final_element + else: + history <= final_element + + self.render_math() + history.scrollTop = history.scrollHeight + else: + # No text content at all - remove any empty container + self.remove_empty_container() + except Exception as e: + print(f"Error finalizing stream message: {e}") + finally: + self._stream_buffer = "" + self._stream_content_element = None + self._stream_message_container = None + self._reasoning_buffer = "" + self._reasoning_element = None + self._reasoning_details = None + self._reasoning_summary = None + self._is_reasoning = False + self._request_start_time = None + self._tool_call_log.reset() + + def remove_empty_container(self) -> None: + """Remove the current response container if it has no actual text content. + + This cleans up "Thinking..." boxes when the AI only performs tool calls + without providing a text response. Never removes a container with actual text. 
+ """ + try: + # Check if there's actual text content in buffer or visible in the element + has_buffer_text = bool(self._stream_buffer.strip()) + has_element_text = False + if self._stream_content_element is not None: + try: + element_text = self._stream_content_element.text or self._stream_content_element.innerHTML or "" + has_element_text = bool(element_text.strip()) + except Exception: + pass + has_tool_call_log = bool(self._tool_call_log.entries) + + # Only remove if there's NO actual text content anywhere and no tool call log + if ( + self._stream_message_container is not None + and not has_buffer_text + and not has_element_text + and not has_tool_call_log + ): + history = document["chat-history"] + try: + history.removeChild(self._stream_message_container) + except Exception: + pass + # Reset state + self._stream_message_container = None + self._stream_content_element = None + self._reasoning_element = None + self._reasoning_details = None + self._reasoning_summary = None + self._reasoning_buffer = "" + self._is_reasoning = False + self._tool_call_log.reset() + # Don't reset _request_start_time here - we want to keep timing across tool calls + except Exception as e: + print(f"Error removing empty container: {e}") + + # ── State reset ────────────────────────────────────────────── + + def reset_streaming_state(self) -> None: + """Reset all streaming state for a new conversation turn.""" + self._stream_buffer = "" + self._stream_content_element = None + self._stream_message_container = None + self._reasoning_buffer = "" + self._reasoning_element = None + self._reasoning_details = None + self._reasoning_summary = None + self._is_reasoning = False + self._needs_continuation_separator = False + self._tool_call_log.reset() diff --git a/static/client/client_tests/test_error_recovery.py b/static/client/client_tests/test_error_recovery.py index d49f291d..4887ff77 100644 --- a/static/client/client_tests/test_error_recovery.py +++ 
b/static/client/client_tests/test_error_recovery.py @@ -12,6 +12,8 @@ from browser import document, html, window from tool_call_log_manager import ToolCallLogManager +from message_menu_manager import MessageMenuManager +from chat_ui_manager import ChatUIManager class TestErrorRecovery(unittest.TestCase): @@ -85,20 +87,15 @@ def test_message_buffer_cleared_on_success(self) -> None: ai = self._create_ai_interface() # Set up minimal state for _on_stream_final - ai._stream_buffer = "" - ai._stream_content_element = None - ai._stream_message_container = None - ai._reasoning_buffer = "" - ai._reasoning_element = None - ai._reasoning_details = None - ai._reasoning_summary = None - ai._is_reasoning = False - ai._request_start_time = None - ai._tool_call_log = ToolCallLogManager() + tool_call_log = ToolCallLogManager() + ai._tool_call_log = tool_call_log + ai._chat_ui = ChatUIManager( + message_menu=MessageMenuManager(), + tool_call_log=tool_call_log, + ) ai.is_processing = True ai._stop_requested = False ai._response_timeout_id = None - ai.markdown_parser = type("MockParser", (), {"parse": lambda s, t: t})() # Mock methods ai._finalize_stream_message = lambda msg=None: None @@ -124,20 +121,15 @@ def test_message_buffer_preserved_on_error(self) -> None: ai = self._create_ai_interface() # Set up minimal state - ai._stream_buffer = "" - ai._stream_content_element = None - ai._stream_message_container = None - ai._reasoning_buffer = "" - ai._reasoning_element = None - ai._reasoning_details = None - ai._reasoning_summary = None - ai._is_reasoning = False - ai._request_start_time = None - ai._tool_call_log = ToolCallLogManager() + tool_call_log = ToolCallLogManager() + ai._tool_call_log = tool_call_log + ai._chat_ui = ChatUIManager( + message_menu=MessageMenuManager(), + tool_call_log=tool_call_log, + ) ai.is_processing = True ai._stop_requested = False ai._response_timeout_id = None - ai.markdown_parser = type("MockParser", (), {"parse": lambda s, t: t})() # Track if restore 
was called restore_called = [False] diff --git a/static/client/client_tests/test_tool_call_log.py b/static/client/client_tests/test_tool_call_log.py index bcbd98de..98bd9026 100644 --- a/static/client/client_tests/test_tool_call_log.py +++ b/static/client/client_tests/test_tool_call_log.py @@ -6,6 +6,8 @@ from browser import html from tool_call_log_manager import ToolCallLogManager +from message_menu_manager import MessageMenuManager +from chat_ui_manager import ChatUIManager from .simple_mock import SimpleMock @@ -361,28 +363,20 @@ def test_container_preserved_with_tool_log(self) -> None: from ai_interface import AIInterface ai = AIInterface.__new__(AIInterface) - ai._tool_call_log = ToolCallLogManager() - ai._stream_buffer = "" - ai._stream_content_element = None - ai._stream_message_container = None - ai._reasoning_buffer = "" - ai._reasoning_element = None - ai._reasoning_details = None - ai._reasoning_summary = None - ai._is_reasoning = False - ai._request_start_time = None - ai._needs_continuation_separator = False - ai._open_message_menu = None - ai._message_menu_global_bound = True - ai._copy_text_to_clipboard = SimpleMock(return_value=True) + tool_call_log = ToolCallLogManager() + ai._tool_call_log = tool_call_log + ai._chat_ui = ChatUIManager( + message_menu=MessageMenuManager(), + tool_call_log=tool_call_log, + ) container = html.DIV() content = html.DIV(Class="chat-content") content.text = "" container <= content - ai._stream_message_container = container - ai._stream_content_element = content - ai._stream_buffer = "" + ai._chat_ui._stream_message_container = container + ai._chat_ui._stream_content_element = content + ai._chat_ui._stream_buffer = "" # Add tool call entries so the log is non-empty ai._tool_call_log.entries = [{"name": "f", "is_error": False}] @@ -392,5 +386,5 @@ def test_container_preserved_with_tool_log(self) -> None: ai._remove_empty_response_container() self.assertIs( - ai._stream_message_container, container, "Container should NOT be 
removed when tool call log has entries" + ai._chat_ui.stream_container, container, "Container should NOT be removed when tool call log has entries" ) From 12f15b252b3c18c3ca632f0cf1e395efab05fef4 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:58:58 +0200 Subject: [PATCH 22/28] Update architectural review with completion status and Phase 3 roadmap Mark Phases 1-2 as complete, add detailed Phase 3 items covering CI/CD, schema versioning, dependency graph restructuring, workspace migrations, and further decomposition of AIInterface and Canvas. --- .../architectural_review_2026_03_24.md | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/documentation/development/architectural_review_2026_03_24.md b/documentation/development/architectural_review_2026_03_24.md index ca212353..9ebd3688 100644 --- a/documentation/development/architectural_review_2026_03_24.md +++ b/documentation/development/architectural_review_2026_03_24.md @@ -179,22 +179,28 @@ Constants that should be centralized are spread across files: ## Refactoring Roadmap -### Phase 1: Quick Wins (low risk, high value) -1. Remove debug code (test trigger in routes.py, print statements in expression_evaluator.py) -2. Centralize scattered constants -3. Standardize error handling (replace `print()` with `logging`, eliminate bare `except:`) -4. Pin all dependencies in requirements.txt - -### Phase 2: Structural (medium risk, high value) -5. Create `BaseDrawableManager` — eliminates manager duplication -6. Extract `ChatUIManager` and `StreamingResponseHandler` from AIInterface -7. Extract `VisibilityManager` from Canvas -8. Create `ProviderManager` to unify provider operations in routes.py -9. Make `get_state()` side-effect-free across all drawables - -### Phase 3: Architecture (higher risk, long-term value) -10. Add CI/CD pipeline with automated tests + linting -11. Implement drawable state schema versioning -12. 
Restructure dependency graph to eliminate `DrawableManagerProxy` -13. Add workspace format migration support -14. Extract route handler logic into service classes for testability +### Phase 1: Quick Wins — COMPLETE (PR #49) +1. ~~Remove debug code (test trigger in routes.py, print statements in expression_evaluator.py)~~ +2. ~~Centralize scattered constants~~ +3. ~~Standardize error handling (replace `print()` with `logging`, eliminate bare `except:`)~~ +4. ~~Pin all dependencies in requirements.txt~~ +5. ~~Extract shared env loading into env_config.py~~ +6. ~~Extract route helpers (tool reset/provider deduplication)~~ + +### Phase 2: Structural — COMPLETE (PR #49) +7. ~~Create `BaseDrawableManager` — 9 managers migrated~~ +8. ~~Extract `BaseRendererTelemetry` from SVG/Canvas2D~~ +9. ~~Extract `VisibilityManager` from Canvas~~ +10. ~~Make `get_state()` side-effect-free (Segment fix)~~ +11. ~~Decompose AIInterface (2,339 → 1,078 lines) into 5 classes:~~ + - ~~ToolCallLogManager, MessageMenuManager, ImageAttachmentManager, TTSUIManager, ChatUIManager~~ +12. ~~Add tests for all new modules (35 server + 74 client)~~ + +### Phase 3: Architecture (higher risk, long-term value) — TODO +13. **Add CI/CD pipeline** — GitHub Actions workflow running server tests (`pytest`), client tests (Selenium via CLI), ruff, and mypy on every PR. High value: protects all refactoring work going forward. +14. **Implement drawable state schema versioning** — Add version field to each drawable's `get_state()` output. Create `from_state()` factory class methods on each drawable for deserialization. Validate schema on workspace load. +15. **Restructure dependency graph to eliminate `DrawableManagerProxy`** — The proxy uses `__getattr__` reflection to break circular initialization. Restructure so managers receive specific interfaces (Protocols) rather than a proxy. This improves IDE support and type safety. +16. 
**Add workspace format migration support** — With schema versioning in place, add migration functions that upgrade workspace JSON from version N to N+1. Enables breaking changes to serialization format without data loss. +17. **Extract route handler logic into service classes** — Move business logic from Flask route functions into testable service classes. Routes become thin HTTP adapters. Enables testing without Flask test client. +18. **Further decompose AIInterface** — The remaining 1,078 lines still mix request building, streaming orchestration, timeout management, and test execution. Candidates for extraction: `RequestBuilder` (~200 lines for payload/vision/send), `StreamingOrchestrator` (~200 lines for _on_stream_final/error/timeout). +19. **Further decompose Canvas** — Canvas is still ~2,100 lines after VisibilityManager extraction. Candidates: `CanvasRenderingCoordinator` (frame batching, render dispatch), `ZoomController` (zoom displacement calculations). From 4907db3983e3500364754b60333671e992f49146 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Wed, 25 Mar 2026 00:03:14 +0200 Subject: [PATCH 23/28] Update documentation for architecture refactoring Update CLAUDE.md backend/client highlights with new modules (config.py, env_config.py, route_helpers.py, 5 extracted AI managers, 3 base classes). Update Project Architecture.txt with restructured AI Interface section, new manager/renderer base classes, and updated file trees. --- .claude/CLAUDE.md | 20 +++++++++++--- documentation/Project Architecture.txt | 38 ++++++++++++++++++-------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md index dc876f38..5736577a 100644 --- a/.claude/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -87,15 +87,27 @@ Then navigate to `http://127.0.0.1:5004/` in the browser. 2. 
`tool_call_processor.py`, `ai_model.py`, and `functions_definitions.py` define the function-call surface exposed to GPT models. 3. `webdriver_manager.py` captures canvas screenshots for the vision workflow. 4. `workspace_manager.py` and `log_manager.py` handle persistence and auditing. -5. `style.css` and other assets shared with the frontend live here for Flask to serve. +5. `config.py` centralizes server-side constants (workspace dirs, schema version, snapshot paths). +6. `env_config.py` provides shared environment variable loading, replacing duplicated `load_dotenv` patterns. +7. `route_helpers.py` contains extracted route helper functions for provider management and tool lifecycle. +8. `style.css` and other assets shared with the frontend live here for Flask to serve. ## Client Highlights (`static/client/`) 1. `main.py` bootstraps Brython and registers managers. 2. `canvas.py`, `canvas_event_handler.py`, and `managers/` orchestrate SVG drawing, selection, undo, and edit policies. 3. `drawables/` contains shape classes (point, segment, vector, triangle, rectangle, circle, ellipse, angle, etc.). -4. `ai_interface.py`, `process_function_calls.py`, and `result_processor.py` coordinate chat responses and tool execution. -5. `expression_evaluator.py`, `expression_validator.py`, and `result_validator.py` provide math parsing, validation, and error messaging. -6. `client_tests/` plus `test_runner.py` implement the Brython test harness (register new tests in `client_tests/tests.py`). +4. `ai_interface.py` orchestrates AI communication and request building, delegating to five extracted managers. +5. `chat_ui_manager.py` handles message rendering, streaming display, and markdown parsing. +6. `message_menu_manager.py` manages context menus, clipboard, and raw text storage. +7. `tool_call_log_manager.py` provides tool call visualization and state tracking. +8. `image_attachment_manager.py` manages the file picker, preview, and image modal. +9. 
`tts_ui_manager.py` controls text-to-speech UI (read aloud, voice settings). +10. `process_function_calls.py` and `result_processor.py` coordinate tool execution and result handling. +11. `expression_evaluator.py`, `expression_validator.py`, and `result_validator.py` provide math parsing, validation, and error messaging. +12. `managers/base_drawable_manager.py` provides a shared base for all drawable managers. +13. `managers/visibility_manager.py` handles viewport culling extracted from Canvas. +14. `rendering/base_telemetry.py` provides a shared telemetry base for renderers. +15. `client_tests/` plus `test_runner.py` implement the Brython test harness (register new tests in `client_tests/tests.py`). --- diff --git a/documentation/Project Architecture.txt b/documentation/Project Architecture.txt index 87cce4bd..d0813ef7 100644 --- a/documentation/Project Architecture.txt +++ b/documentation/Project Architecture.txt @@ -70,7 +70,7 @@ MatHud is an interactive mathematical visualization tool that combines a drawing 8. DOM surfaces are layered as WebGL (highest `z-index` 20), Canvas2D (`z-index` 10), and SVG (base layer) inside `#math-container`, with inactive surfaces kept `pointer-events: none`. Labels remain on SVG to guarantee text clarity across backends. 9. Runtime selection relies on the preference chain inside `create_renderer`: it first tries Canvas2D, falls back to SVG if Canvas2D fails, and finally attempts WebGL. Constructor errors are caught so the factory can continue down the chain. 10. Feature flags remain available for diagnostics: `window.MatHudSvgOffscreen` (or `localStorage["mathud.svg.offscreen"]`) toggles SVG offscreen staging, and `window.MatHudCanvas2DOffscreen` (or `localStorage["mathud.canvas2d.offscreen"]`) enables Canvas2D layer compositing. -11. 
Renderers expose `begin_frame`, `end_frame`, `peek_telemetry`, and `drain_telemetry` hooks so automated tests and performance harnesses can capture plan build/apply timings, skip counts, adapter events, and maximum batch depth. +11. Renderers expose `begin_frame`, `end_frame`, `peek_telemetry`, and `drain_telemetry` hooks so automated tests and performance harnesses can capture plan build/apply timings, skip counts, adapter events, and maximum batch depth. Shared telemetry infrastructure lives in `base_telemetry.py`. ## Multi-Step Calculations - Expressions: evaluate mathematical expressions and functions @@ -124,7 +124,7 @@ When the AI executes tool calls during a conversation turn, a collapsible "Tool - **Entry contents**: Each entry shows a status icon (checkmark for success, cross for error), the function name, a compact argument summary, and an error message when applicable. - **Final summary**: After the turn completes, the summary updates to "Used N tool(s)" with an optional "(M failed)" suffix if any calls returned errors. - **State lifecycle**: Tool call log state (`_tool_call_log_entries`, `_tool_call_log_element`, etc.) is reset at the start of each new user message and cleaned up in `_finalize_stream_message`. -- **Implementation**: `AIInterface._add_tool_call_entries` records entries and updates the DOM; `_finalize_tool_call_log` writes the final summary. CSS classes are prefixed with `tool-call-log-*` (defined in `static/style.css`). +- **Implementation**: `ToolCallLogManager` (extracted from `AIInterface`) records entries and updates the DOM; its `finalize` method writes the final summary. CSS classes are prefixed with `tool-call-log-*` (defined in `static/style.css`). - **Tests**: `static/client/client_tests/test_tool_call_log.py` covers argument formatting, entry element creation, dropdown lifecycle, accumulation across rounds, finalization, and state management. 
--- @@ -167,6 +167,9 @@ static/ ├── functions_definitions.py # 70 AI function/tool definitions ├── ai_model.py # AI model configuration utilities ├── tool_call_processor.py # Tool call processing utilities +├── config.py # Centralized server-side constants +├── env_config.py # Shared environment loading utility +├── route_helpers.py # Extracted route helper functions ├── style.css # Frontend CSS styling └── client/ # Custom Python modules for Brython (moved from Brython-3.11.3/Lib/site-packages/) # Note: Brython core (3.12.5), Math.js (14.5.2), and Nerdamer (1.1.13) now loaded from CDN @@ -179,7 +182,12 @@ templates/ static/client/ ├── main.py # Brython initialization and entry point -├── ai_interface.py # AI communication and UI management +├── ai_interface.py # AI communication orchestration and request building (refactored, ~1,078 lines) +├── chat_ui_manager.py # Message rendering, streaming display, markdown (extracted from ai_interface) +├── message_menu_manager.py # Context menus and clipboard (extracted from ai_interface) +├── tool_call_log_manager.py # Tool call visualization (extracted from ai_interface) +├── image_attachment_manager.py # Image attachment handling (extracted from ai_interface) +├── tts_ui_manager.py # Text-to-speech UI (extracted from ai_interface) ├── markdown_parser.py # Comprehensive markdown parser with LaTeX support ├── canvas_event_handler.py # Mouse/keyboard event handling with extensive error handling ├── canvas.py # Core SVG canvas management @@ -245,6 +253,7 @@ static/client/ │ ├── segments_area_renderable.py # Area between one or two segments │ ├── shared_drawable_renderers.py # Drawing helpers shared across all backends │ ├── style_manager.py # Renderer style dictionaries +│ ├── base_telemetry.py # Shared telemetry base class for renderers │ ├── cached_render_plan.py # OptimizedPrimitivePlan type for caching │ ├── canvas2d_primitive_adapter.py # Canvas 2D primitive adapter │ ├── svg_primitive_adapter.py # SVG primitive adapter 
with DOM pooling @@ -257,8 +266,10 @@ static/client/ │ ├── screen_offset_label_layout.py # Label overlap resolution │ └── label_overlap_resolver.py # Greedy overlap resolution algorithm ├── managers/ # Specialized object managers +│ ├── base_drawable_manager.py # Base class for drawable managers │ ├── drawable_manager.py # Central drawable orchestrator │ ├── drawable_manager_proxy.py # Initialization dependency management +│ ├── visibility_manager.py # Drawable visibility control (extracted from Canvas) │ ├── drawables_container.py # Drawable storage and organization │ ├── drawable_dependency_manager.py # Object dependency tracking │ ├── undo_redo_manager.py # Undo/redo functionality @@ -363,7 +374,7 @@ server_tests/ ## Architecture Notes - **Brython Integration**: Extensive use of Python in the browser via Brython 3.12.5 framework (CDN) - **Facade Pattern**: `ProcessFunctionCalls` acts as facade for `ResultProcessor`, `ExpressionEvaluator`, `ResultValidator` -- **Manager Pattern**: Specialized managers for each geometric shape type with centralized coordination +- **Manager Pattern**: Specialized managers for each geometric shape type with centralized coordination; drawable managers inherit from `BaseDrawableManager` - **Error Handling**: Comprehensive try-catch blocks throughout client-side code - **Testing**: Dual testing system with server-side pytest and client-side Brython unittest - **Renderer Logic Tests**: `server_tests/client_renderer/` houses pytest suites that exercise plan caching, factory fallbacks, helper metadata, and color normalization using lightweight browser stubs. The harness runs through `run_server_tests.py`, which now injects `static/client` into `PYTHONPATH` so the renderer modules and shared utilities import correctly during server-side runs. @@ -385,26 +396,28 @@ server_tests/ - `static/client/canvas_event_handler.py` manages user interactions with the canvas. 
## AI Interface Layer (`static/client/ai_interface.py`) -- Core class `AIInterface` manages Brython-side communication between UI and backend. +- Core class `AIInterface` manages Brython-side AI communication orchestration and request building (~1,078 lines, refactored down from ~2,339). +- Five UI-focused concerns were extracted into dedicated managers: + - `ChatUIManager` (`chat_ui_manager.py`): message rendering, streaming display, and markdown formatting. + - `MessageMenuManager` (`message_menu_manager.py`): context menus and clipboard operations. + - `ToolCallLogManager` (`tool_call_log_manager.py`): tool call visualization and log dropdown. + - `ImageAttachmentManager` (`image_attachment_manager.py`): image attachment handling. + - `TTSUIManager` (`tts_ui_manager.py`): text-to-speech UI controls. - Instantiated in `main.py` and receives events from `CanvasEventHandler`. - Initializes client-side `WorkspaceManager` (from `static/client/workspace_manager.py`) and `FunctionRegistry`. -- **Markdown Integration**: Initializes `MarkdownParser` for comprehensive text formatting and LaTeX rendering. - Maintains dictionary of available functions for math operations and canvas manipulation (via `FunctionRegistry`). - Handles function call results and maintains state between interactions. -- **UI Element References**: References `send-button` for button control, `chat-input` for input field, `vision-toggle` for vision checkbox, and `ai-model-selector` for model selection. - Key methods: - `interact_with_ai`: Entry point for user interactions from UI. - `_send_prompt_to_ai`: Prepares JSON prompt (with canvas state, user message, tool results, vision settings from `document["vision-toggle"]`, AI model from `document["ai-model-selector"]`) and sends AJAX POST request to `/send_message` Flask endpoint. - `_process_ai_response`: Processes AI responses from backend. 
If tool calls are present, uses `ProcessFunctionCalls.get_results` and then sends results back to AI via `_send_prompt_to_ai`. - - `_parse_markdown_to_html`: Converts markdown text to HTML using the dedicated `MarkdownParser`. - - `_render_math`: Triggers MathJax rendering for newly added mathematical expressions. - - `_create_message_element`: Creates styled message elements with markdown support for chat display. - (Workspace operations like `save_workspace` are invoked via function calls processed by `ProcessFunctionCalls` which then use the client-side `WorkspaceManager` instance). ## Client-Side Brython Managers (in `static/client/managers/`) - A suite of manager classes, running in the Brython environment, handle the detailed logic for canvas objects and operations. +- **`BaseDrawableManager`**: Base class providing shared infrastructure (container access, undo archiving, name generation) for all specialized drawable managers. - **`DrawableManager`**: Central orchestrator for all drawable objects. It coordinates the specialized managers listed below. Instantiated by `Canvas`. -- **Specialized Drawable Managers**: +- **Specialized Drawable Managers** (inherit from `BaseDrawableManager`): - `PointManager`: Manages point objects. - `SegmentManager`: Manages line segment objects. - `VectorManager`: Manages vector objects. @@ -421,6 +434,7 @@ server_tests/ - `DrawableDependencyManager`: Tracks dependencies between drawable objects. - `DrawableManagerProxy`: A proxy to help manage initialization dependencies for `DrawableManager`. - `DrawablesContainer`: A data structure used by `DrawableManager` to hold drawable instances. + - `VisibilityManager`: Controls drawable visibility state (extracted from `Canvas`). - These managers are responsible for creating, deleting, updating, and querying drawable objects on the canvas, and are used by `ProcessFunctionCalls` when executing drawing-related function calls. 
## Workspace Management @@ -532,7 +546,7 @@ server_tests/ - Math evaluations (`ExpressionEvaluator`). - Client-side workspace operations (e.g., `save_workspace`) that call methods on the client `WorkspaceManager`, which then makes further AJAX calls. iii. Results from `ProcessFunctionCalls.get_results` are collected. - iv. `AIInterface._add_tool_call_entries` records each tool call in the tool call log dropdown, updating the live count in the UI. + iv. `ToolCallLogManager` records each tool call in the tool call log dropdown, updating the live count in the UI. v. `AIInterface._store_results_in_canvas_state` updates canvas computations. vi. `AIInterface._send_prompt_to_ai` is called again, this time with `tool_call_results` (and no `user_message`), looping back to step 4 (backend forwards tool results to OpenAI). 8. Canvas is updated by the managers if drawing functions were called. Results are displayed in chat. From 3ad082b4207dc6470973246d1ed2554adbc7b208 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Wed, 25 Mar 2026 01:12:07 +0200 Subject: [PATCH 24/28] Complete env_config migration and fix hardcoded snapshot path Migrate remaining load_dotenv patterns in app_manager.py, providers/__init__.py, and openrouter_api.py to use env_config. Replace hardcoded "canvas_snapshots/canvas.png" in openai_api_base.py with config.CANVAS_SNAPSHOT_PATH. 
--- static/app_manager.py | 7 ++----- static/openai_api_base.py | 3 ++- static/providers/__init__.py | 12 +++--------- static/providers/openrouter_api.py | 12 ++---------- 4 files changed, 9 insertions(+), 25 deletions(-) diff --git a/static/app_manager.py b/static/app_manager.py index 47acfda7..f5880b13 100644 --- a/static/app_manager.py +++ b/static/app_manager.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple, TypedDict, Union from cachelib.file import FileSystemCache -from dotenv import load_dotenv from flask import Flask, Response, jsonify from flask_session import Session as FlaskSession +from static.env_config import load_env_files from static.log_manager import LogManager from static.openai_completions_api import OpenAIChatCompletionsAPI from static.openai_responses_api import OpenAIResponsesAPI @@ -91,10 +91,7 @@ def is_deployed() -> bool: @staticmethod def _load_env() -> None: """Load environment from project .env and parent .env (API keys).""" - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) + load_env_files() @staticmethod def requires_auth() -> bool: diff --git a/static/openai_api_base.py b/static/openai_api_base.py index e54ea643..e607d776 100644 --- a/static/openai_api_base.py +++ b/static/openai_api_base.py @@ -19,6 +19,7 @@ from openai import OpenAI from static.ai_model import AIModel +from static.config import CANVAS_SNAPSHOT_PATH from static.env_config import get_api_key from static.canvas_state_summarizer import compare_canvas_states from static.functions_definitions import FUNCTIONS, FunctionDefinition @@ -308,7 +309,7 @@ def _create_enhanced_prompt_with_image( # Add canvas snapshot if vision is enabled if include_canvas_snapshot: try: - with open("canvas_snapshots/canvas.png", "rb") as image_file: + with open(CANVAS_SNAPSHOT_PATH, "rb") as image_file: image_data = base64.b64encode(image_file.read()).decode("utf-8") 
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}}) has_images = True diff --git a/static/providers/__init__.py b/static/providers/__init__.py index 3caf14a1..aabe54ef 100644 --- a/static/providers/__init__.py +++ b/static/providers/__init__.py @@ -11,7 +11,7 @@ import os from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type -from dotenv import load_dotenv +from static.env_config import load_env_files if TYPE_CHECKING: from static.openai_api_base import OpenAIAPIBase @@ -85,10 +85,7 @@ def is_provider_available(cls, provider_name: str) -> bool: return LocalProviderRegistry.is_provider_available(provider_name) # API-based providers use API key check - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) + load_env_files() key_name = cls._api_key_names.get(provider_name) if not key_name: return False @@ -104,10 +101,7 @@ def get_available_providers(cls) -> List[str]: Returns: List of available provider names """ - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) + load_env_files() available = [] # Check API-based providers diff --git a/static/providers/openrouter_api.py b/static/providers/openrouter_api.py index dc32a69a..536a83fc 100644 --- a/static/providers/openrouter_api.py +++ b/static/providers/openrouter_api.py @@ -7,14 +7,13 @@ from __future__ import annotations -import os from collections.abc import Sequence from typing import Optional -from dotenv import load_dotenv from openai import OpenAI from static.ai_model import AIModel +from static.env_config import get_api_key from static.functions_definitions import FunctionDefinition from static.openai_api_base import ToolMode from static.openai_completions_api import OpenAIChatCompletionsAPI @@ -23,14 +22,7 @@ def _get_openrouter_api_key() -> str: """Get the OpenRouter API key from 
environment.""" - load_dotenv() - parent_env = os.path.join(os.path.dirname(os.getcwd()), ".env") - if os.path.exists(parent_env): - load_dotenv(parent_env) - api_key = os.getenv("OPENROUTER_API_KEY") - if not api_key: - raise ValueError("OPENROUTER_API_KEY not found in environment or .env file") - return api_key + return get_api_key("OPENROUTER_API_KEY") class OpenRouterAPI(OpenAIChatCompletionsAPI): From 4ce7ab5aa1ddcce69dad5620d2f962c5ad308c03 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Wed, 25 Mar 2026 01:12:14 +0200 Subject: [PATCH 25/28] Make ResultProcessor.generate_result_key public Rename _generate_result_key to generate_result_key in result_processor.py and update caller in tool_call_log_manager.py. Fixes private method coupling flagged in review. --- documentation/Reference Manual.txt | 2 +- static/client/result_processor.py | 4 ++-- static/client/tool_call_log_manager.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/documentation/Reference Manual.txt b/documentation/Reference Manual.txt index 3626edec..41fe0491 100644 --- a/documentation/Reference Manual.txt +++ b/documentation/Reference Manual.txt @@ -500,7 +500,7 @@ computation history for mathematical operations. 
- `_process_function_call(call, available_functions, non_computation_functions, unformattable_functions, canvas, results)`: Process a single function call and update results - `_is_function_available(function_name, available_functions, results)`: Check if the function exists and update results if not - `_execute_function(function_name, args, available_functions)`: Execute the function with the provided arguments -- `_generate_result_key(function_name, args)`: Generate a consistent key format for the results dictionary +- `generate_result_key(function_name, args)`: Generate a consistent key format for the results dictionary - `_process_result(function_name, args, result, key, unformattable_functions, non_computation_functions, canvas, results)`: Process the result based on function type - `_handle_exception(exception, function_name, results)`: Handle exceptions during function calls - `_add_computation_if_needed(result, function_name, non_computation_functions, expression, canvas)`: Add computation to canvas if needed diff --git a/static/client/result_processor.py b/static/client/result_processor.py index 26acb989..ea237eeb 100644 --- a/static/client/result_processor.py +++ b/static/client/result_processor.py @@ -235,7 +235,7 @@ def _process_function_call( result: Any = ResultProcessor._execute_function(function_name, args, available_functions) # Format the key for results dictionary - key: str = ResultProcessor._generate_result_key(function_name, args) + key: str = ResultProcessor.generate_result_key(function_name, args) # Process the result based on function type ResultProcessor._process_result( @@ -260,7 +260,7 @@ def _execute_function(function_name: str, args: Dict[str, Any], available_functi return result @staticmethod - def _generate_result_key(function_name: str, args: Dict[str, Any]) -> str: + def generate_result_key(function_name: str, args: Dict[str, Any]) -> str: """Generate a consistent key format for the results dictionary.""" formatted_args: str = 
ResultProcessor._format_arguments(args) return f"{function_name}({formatted_args})" diff --git a/static/client/tool_call_log_manager.py b/static/client/tool_call_log_manager.py index 44c11ac4..d080cd3f 100644 --- a/static/client/tool_call_log_manager.py +++ b/static/client/tool_call_log_manager.py @@ -184,7 +184,7 @@ def add_entries(self, tool_calls: list[dict[str, Any]], call_results: dict[str, args: dict[str, Any] = call.get("arguments", {}) args_display = self.format_args_display(args) - result_key = ResultProcessor._generate_result_key(function_name, args) + result_key = ResultProcessor.generate_result_key(function_name, args) # Special handling for evaluate_expression which uses expression as key if function_name == "evaluate_expression" and "expression" in args: @@ -250,9 +250,9 @@ def finalize(self) -> None: if self.summary is not None: self.summary.text = label - # Ensure collapsed — removeAttribute is reliable for boolean HTML attributes + # Ensure collapsed if self.element is not None: try: - self.element.removeAttribute("open") + del self.element.attrs["open"] except Exception: pass From 82541c06aaf8c0a490dee6b9549f53765a387289 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Wed, 25 Mar 2026 01:12:22 +0200 Subject: [PATCH 26/28] Address medium and low priority review feedback Remove dead code from ai_interface.py (_send_prompt_to_ai_stream, no-op float block). Fix finish_reason typing to Optional[str]. Update image attachment tests to use real ImageAttachmentManager. Extract _get_class_attr to shared simple_mock.py. Standardize details closing, remove redundant import and test. 
--- server_tests/test_route_helpers.py | 7 +- static/client/ai_interface.py | 45 +---- static/client/chat_ui_manager.py | 2 - static/client/client_tests/simple_mock.py | 43 ++++ .../client_tests/test_chat_message_menu.py | 40 +--- .../client_tests/test_image_attachment.py | 190 ++++++++++-------- .../client/client_tests/test_tool_call_log.py | 27 +-- static/route_helpers.py | 2 +- 8 files changed, 153 insertions(+), 203 deletions(-) diff --git a/server_tests/test_route_helpers.py b/server_tests/test_route_helpers.py index 57a3b014..3d8f475b 100644 --- a/server_tests/test_route_helpers.py +++ b/server_tests/test_route_helpers.py @@ -6,6 +6,7 @@ from __future__ import annotations +import unittest from unittest.mock import MagicMock, patch from static.route_helpers import ( @@ -32,7 +33,7 @@ def _make_app( # reset_tools_for_all_providers # --------------------------------------------------------------------------- -class TestResetToolsForAllProviders: +class TestResetToolsForAllProviders(unittest.TestCase): """Tests for reset_tools_for_all_providers.""" def test_noop_when_finish_reason_is_tool_calls(self) -> None: @@ -142,7 +143,7 @@ def test_does_not_reset_active_provider_without_injected_tools(self) -> None: # update_all_provider_models # --------------------------------------------------------------------------- -class TestUpdateAllProviderModels: +class TestUpdateAllProviderModels(unittest.TestCase): """Tests for update_all_provider_models.""" def test_sets_model_on_both_providers(self) -> None: @@ -160,7 +161,7 @@ def test_sets_model_on_both_providers(self) -> None: # get_active_provider # --------------------------------------------------------------------------- -class TestGetActiveProvider: +class TestGetActiveProvider(unittest.TestCase): """Tests for get_active_provider.""" def test_returns_ai_api_when_model_id_is_none(self) -> None: diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index b1fd6e21..511ee704 100644 --- 
a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -291,12 +291,6 @@ def _store_results_in_canvas_state(self, call_results: Dict[str, Any]) -> None: if not ProcessFunctionCalls.is_successful_result(value): continue - # Format numeric results consistently - if isinstance(value, (int, float)): - float(value) # Always convert numeric values to float - else: - pass - # DISABLED: Saving basic calculations to canvas state (takes up too many tokens, not useful info to store) # self.canvas.add_computation( # expression=key, # The key is already the expression @@ -822,7 +816,7 @@ def _start_streaming_request(self, payload: Dict[str, Any]) -> None: try: payload_json = json.dumps(payload) payload_js = window.JSON.parse(payload_json) - # Don't reset any state here - all state management is done in _send_prompt_to_ai_stream + # Don't reset any state here - all state management is done in _send_prompt_to_ai # This preserves intermediary text and reasoning content across tool call continuations # Call JS streaming helper with reasoning and log callbacks window.sendMessageStream( @@ -852,43 +846,6 @@ def _send_request( payload = self._create_request_payload(prompt, include_svg=False, action_trace=action_trace) self._start_streaming_request(payload) - def _send_prompt_to_ai_stream( - self, - user_message: Optional[str] = None, - tool_call_results: Optional[str] = None, - attached_images: Optional[list[str]] = None, - ) -> None: - canvas_state = self.canvas.get_canvas_state() - use_vision = document["vision-toggle"].checked and user_message is not None and tool_call_results is None - prompt_json: Dict[str, Any] = { - "canvas_state": canvas_state, - "user_message": user_message, - "tool_call_results": tool_call_results, - "use_vision": use_vision, - "ai_model": document["ai-model-selector"].value, - } - # Include attached images if provided (works independently of vision toggle) - if attached_images: - prompt_json["attached_images"] = attached_images - prompt = 
json.dumps(prompt_json) - print( - f"Prompt for AI (stream): {prompt[:500]}..." if len(prompt) > 500 else f"Prompt for AI (stream): {prompt}" - ) - - # For new user messages, reset all state including containers and buffers - # For tool call results, preserve everything to keep intermediary text visible - if user_message is not None and tool_call_results is None: - self._chat_ui.request_start_time = window.Date.now() - self._chat_ui.reset_streaming_state() - - try: - payload = self._create_request_payload(prompt, include_svg=True) - self._start_streaming_request(payload) - except Exception as e: - print(f"Error preparing streaming request: {str(e)}") - payload = self._create_request_payload(prompt, include_svg=False) - self._start_streaming_request(payload) - def _send_prompt_to_ai( self, user_message: Optional[str] = None, diff --git a/static/client/chat_ui_manager.py b/static/client/chat_ui_manager.py index 92c71af8..1731afda 100644 --- a/static/client/chat_ui_manager.py +++ b/static/client/chat_ui_manager.py @@ -481,8 +481,6 @@ def finalize_stream(self, final_message: Optional[str] = None) -> None: # Update summary to show elapsed time and ensure dropdown stays closed if self._reasoning_summary is not None and self._request_start_time is not None: try: - from browser import window - elapsed_ms = window.Date.now() - self._request_start_time elapsed_seconds = int(elapsed_ms / 1000) self._reasoning_summary.text = f"Thought for {elapsed_seconds} seconds" diff --git a/static/client/client_tests/simple_mock.py b/static/client/client_tests/simple_mock.py index 3e1e1eeb..f9cc02cb 100644 --- a/static/client/client_tests/simple_mock.py +++ b/static/client/client_tests/simple_mock.py @@ -3,6 +3,49 @@ from typing import Any, Dict, List, Tuple +def get_class_attr(node: Any) -> str: + """Extract the CSS class string from a DOM node across Brython environments. 
+ + Tries multiple access patterns because Brython may expose attrs as a + dict-like object, a plain attribute, or via ``getAttribute``. + """ + try: + attrs = getattr(node, "attrs", None) + # Brython may expose attrs as a dict-like object (not always a plain dict). + if attrs is not None and hasattr(attrs, "get"): + value = attrs.get("class", "") + if isinstance(value, str): + return value + return "" if value is None else str(value) + except Exception: + pass + + # Fallbacks for environments where attrs is not dict-like. + try: + value = getattr(node, "class_name", None) + if isinstance(value, str): + return value + except Exception: + pass + + try: + value = getattr(node, "className", None) + if isinstance(value, str): + return value + except Exception: + pass + + try: + getter = getattr(node, "getAttribute", None) + if callable(getter): + value = getter("class") + if isinstance(value, str): + return value + except Exception: + pass + return "" + + class SimpleMock: _attributes: Dict[str, Any] _return_value: Any diff --git a/static/client/client_tests/test_chat_message_menu.py b/static/client/client_tests/test_chat_message_menu.py index bcd3acfd..33995233 100644 --- a/static/client/client_tests/test_chat_message_menu.py +++ b/static/client/client_tests/test_chat_message_menu.py @@ -6,45 +6,7 @@ from browser import html, window from message_menu_manager import MessageMenuManager -from .simple_mock import SimpleMock - - -def _get_class_attr(node: Any) -> str: - try: - attrs = getattr(node, "attrs", None) - # Brython may expose attrs as a dict-like object (not always a plain dict). - if attrs is not None and hasattr(attrs, "get"): - value = attrs.get("class", "") - if isinstance(value, str): - return value - return "" if value is None else str(value) - except Exception: - pass - - # Fallbacks for environments where attrs is not dict-like. 
- try: - value = getattr(node, "class_name", None) - if isinstance(value, str): - return value - except Exception: - pass - - try: - value = getattr(node, "className", None) - if isinstance(value, str): - return value - except Exception: - pass - - try: - getter = getattr(node, "getAttribute", None) - if callable(getter): - value = getter("class") - if isinstance(value, str): - return value - except Exception: - pass - return "" +from .simple_mock import SimpleMock, get_class_attr as _get_class_attr class TestChatMessageMenu(unittest.TestCase): diff --git a/static/client/client_tests/test_image_attachment.py b/static/client/client_tests/test_image_attachment.py index 52110dad..837092ac 100644 --- a/static/client/client_tests/test_image_attachment.py +++ b/static/client/client_tests/test_image_attachment.py @@ -2,11 +2,15 @@ Tests for the image attachment feature in the client. Tests cover: -- _attached_images state management +- ImageAttachmentManager state management (uses real manager) - Preview area DOM updates - Payload generation with attached images - /attach slash command - Image modal functionality + +Where possible, tests instantiate the real ``ImageAttachmentManager`` and +exercise its public API directly. Tests that require browser DOM elements +(file picker, preview area rendering) still use mocks. 
""" from __future__ import annotations @@ -16,6 +20,7 @@ from unittest.mock import MagicMock from constants import IMAGE_SIZE_WARNING_BYTES, MAX_ATTACHED_IMAGES +from image_attachment_manager import ImageAttachmentManager class MockCanvas: @@ -76,54 +81,52 @@ def __init__(self) -> None: class TestAttachedImagesState(unittest.TestCase): - """Tests for _attached_images state management.""" + """Tests for ImageAttachmentManager state management using the real class.""" def setUp(self) -> None: - """Set up test fixtures with mock AI interface.""" - from ai_interface import AIInterface - - self.canvas = MockCanvas() - # We can't fully initialize AIInterface without browser module, - # so we'll test the logic patterns directly - self.ai = MagicMock(spec=AIInterface) - self.ai._attached_images = [] + """Set up a real ImageAttachmentManager instance.""" + self.mgr = ImageAttachmentManager() def test_initial_state_empty(self) -> None: - """Test attached images starts empty.""" - self.assertEqual(self.ai._attached_images, []) + """Test images property starts empty.""" + self.assertEqual(self.mgr.images, []) def test_append_image(self) -> None: - """Test appending an image to the list.""" + """Test appending an image via internal list.""" test_url = "data:image/png;base64,test123" - self.ai._attached_images.append(test_url) - self.assertEqual(len(self.ai._attached_images), 1) - self.assertEqual(self.ai._attached_images[0], test_url) + self.mgr._images.append(test_url) + self.assertEqual(len(self.mgr.images), 1) + self.assertEqual(self.mgr.images[0], test_url) def test_append_multiple_images(self) -> None: """Test appending multiple images.""" images = ["data:image/png;base64,img1", "data:image/jpeg;base64,img2", "data:image/png;base64,img3"] for img in images: - self.ai._attached_images.append(img) - self.assertEqual(len(self.ai._attached_images), 3) - self.assertEqual(self.ai._attached_images, images) + self.mgr._images.append(img) + self.assertEqual(len(self.mgr.images), 3) 
+ self.assertEqual(self.mgr.images, images) - def test_remove_image_by_index(self) -> None: - """Test removing an image by index.""" - self.ai._attached_images = [ + def test_remove_image_via_manager(self) -> None: + """Test removing an image through the manager's _remove_image method.""" + self.mgr._images = [ "data:image/png;base64,img1", "data:image/png;base64,img2", "data:image/png;base64,img3", ] - self.ai._attached_images.pop(1) - self.assertEqual(len(self.ai._attached_images), 2) - self.assertEqual(self.ai._attached_images[0], "data:image/png;base64,img1") - self.assertEqual(self.ai._attached_images[1], "data:image/png;base64,img3") + # Patch _update_preview_area to avoid DOM access + self.mgr._update_preview_area = lambda: None # type: ignore[assignment] + self.mgr._remove_image(1) + self.assertEqual(len(self.mgr.images), 2) + self.assertEqual(self.mgr.images[0], "data:image/png;base64,img1") + self.assertEqual(self.mgr.images[1], "data:image/png;base64,img3") def test_clear_images(self) -> None: - """Test clearing all images.""" - self.ai._attached_images = ["data:image/png;base64,img1", "data:image/png;base64,img2"] - self.ai._attached_images = [] - self.assertEqual(len(self.ai._attached_images), 0) + """Test clearing all images via the manager's clear method.""" + self.mgr._images = ["data:image/png;base64,img1", "data:image/png;base64,img2"] + # Patch _update_preview_area to avoid DOM access + self.mgr._update_preview_area = lambda: None # type: ignore[assignment] + self.mgr.clear() + self.assertEqual(len(self.mgr.images), 0) def test_max_images_constant(self) -> None: """Test maximum images constant is set.""" @@ -133,6 +136,26 @@ def test_image_size_warning_constant(self) -> None: """Test image size warning threshold is 10MB.""" self.assertEqual(IMAGE_SIZE_WARNING_BYTES, 10 * 1024 * 1024) + def test_remove_invalid_index_negative(self) -> None: + """Test removing with negative index does nothing.""" + self.mgr._images = ["img1", "img2", "img3"] + 
self.mgr._update_preview_area = lambda: None  # type: ignore[assignment] + self.mgr._remove_image(-1) + self.assertEqual(self.mgr.images, ["img1", "img2", "img3"]) + + def test_remove_invalid_index_too_large(self) -> None: + """Test removing with too large index does nothing.""" + self.mgr._images = ["img1", "img2", "img3"] + self.mgr._update_preview_area = lambda: None  # type: ignore[assignment] + self.mgr._remove_image(10) + self.assertEqual(self.mgr.images, ["img1", "img2", "img3"]) + + def test_on_system_message_callback_stored(self) -> None: + """Test that the system message callback is stored correctly.""" + messages: List[str] = [] + mgr = ImageAttachmentManager(on_system_message=messages.append) + self.assertEqual(mgr._on_system_message, messages.append) + + class TestImageValidation(unittest.TestCase): """Tests for image URL validation logic.""" @@ -273,113 +296,104 @@ def test_image_command_description(self) -> None: class TestImageLimitLogic(unittest.TestCase): - """Tests for image attachment limit enforcement logic.""" + """Tests for image attachment limit enforcement using real ImageAttachmentManager.""" + + def setUp(self) -> None: + """Set up a real ImageAttachmentManager with DOM patched out.""" + self.mgr = ImageAttachmentManager() + self.mgr._update_preview_area = lambda: None  # type: ignore[assignment] def test_remaining_slots_calculation(self) -> None: - """Test remaining slots calculation.""" - max_images = 5 - current_count = 2 - remaining = max_images - current_count + """Test remaining slots calculation via manager state.""" + self.mgr._images = ["img1", "img2"] + remaining = MAX_ATTACHED_IMAGES - len(self.mgr.images) self.assertEqual(remaining, 3) def test_remaining_slots_at_limit(self) -> None: """Test remaining slots when at limit.""" - max_images = 5 - current_count = 5 - remaining = max_images - current_count + self.mgr._images = ["img"] * MAX_ATTACHED_IMAGES + remaining = MAX_ATTACHED_IMAGES - len(self.mgr.images) self.assertEqual(remaining, 
0) def test_files_to_process_limited(self) -> None: """Test files to process is limited by remaining slots.""" - max_images = 5 - current_count = 3 - remaining = max_images - current_count + self.mgr._images = ["img"] * 3 + remaining = MAX_ATTACHED_IMAGES - len(self.mgr.images) files_selected = 5 - files_to_process = min(files_selected, remaining) self.assertEqual(files_to_process, 2) def test_all_files_processed_when_below_limit(self) -> None: """Test all files processed when total is below limit.""" - max_images = 5 - current_count = 1 - remaining = max_images - current_count + self.mgr._images = ["img"] + remaining = MAX_ATTACHED_IMAGES - len(self.mgr.images) files_selected = 2 - files_to_process = min(files_selected, remaining) self.assertEqual(files_to_process, 2) class TestImageRemovalLogic(unittest.TestCase): - """Tests for image removal from attached images list.""" + """Tests for image removal using real ImageAttachmentManager._remove_image.""" + + def setUp(self) -> None: + """Set up a real ImageAttachmentManager with DOM patched out.""" + self.mgr = ImageAttachmentManager() + self.mgr._update_preview_area = lambda: None # type: ignore[assignment] def test_remove_first_image(self) -> None: """Test removing first image from list.""" - images = ["img1", "img2", "img3"] - index = 0 - if 0 <= index < len(images): - images.pop(index) - self.assertEqual(images, ["img2", "img3"]) + self.mgr._images = ["img1", "img2", "img3"] + self.mgr._remove_image(0) + self.assertEqual(self.mgr.images, ["img2", "img3"]) def test_remove_middle_image(self) -> None: """Test removing middle image from list.""" - images = ["img1", "img2", "img3"] - index = 1 - if 0 <= index < len(images): - images.pop(index) - self.assertEqual(images, ["img1", "img3"]) + self.mgr._images = ["img1", "img2", "img3"] + self.mgr._remove_image(1) + self.assertEqual(self.mgr.images, ["img1", "img3"]) def test_remove_last_image(self) -> None: """Test removing last image from list.""" - images = ["img1", 
"img2", "img3"] - index = 2 - if 0 <= index < len(images): - images.pop(index) - self.assertEqual(images, ["img1", "img2"]) + self.mgr._images = ["img1", "img2", "img3"] + self.mgr._remove_image(2) + self.assertEqual(self.mgr.images, ["img1", "img2"]) def test_remove_invalid_index_negative(self) -> None: """Test removing with negative index does nothing.""" - images = ["img1", "img2", "img3"] - original = images.copy() - index = -1 - if 0 <= index < len(images): - images.pop(index) - self.assertEqual(images, original) + self.mgr._images = ["img1", "img2", "img3"] + self.mgr._remove_image(-1) + self.assertEqual(self.mgr.images, ["img1", "img2", "img3"]) def test_remove_invalid_index_too_large(self) -> None: """Test removing with too large index does nothing.""" - images = ["img1", "img2", "img3"] - original = images.copy() - index = 10 - if 0 <= index < len(images): - images.pop(index) - self.assertEqual(images, original) + self.mgr._images = ["img1", "img2", "img3"] + self.mgr._remove_image(10) + self.assertEqual(self.mgr.images, ["img1", "img2", "img3"]) class TestPreviewAreaLogic(unittest.TestCase): - """Tests for preview area update logic.""" + """Tests for preview area update logic using real ImageAttachmentManager.""" def test_preview_visible_with_images(self) -> None: - """Test preview area should be visible when images attached.""" - attached_images = ["img1"] - should_show = len(attached_images) > 0 - self.assertTrue(should_show) + """Test images list is non-empty when images attached.""" + mgr = ImageAttachmentManager() + mgr._images = ["img1"] + self.assertTrue(len(mgr.images) > 0) def test_preview_hidden_without_images(self) -> None: - """Test preview area should be hidden when no images.""" - attached_images: List[str] = [] - should_show = len(attached_images) > 0 - self.assertFalse(should_show) + """Test images list is empty when no images.""" + mgr = ImageAttachmentManager() + self.assertFalse(len(mgr.images) > 0) def 
test_preview_thumbnail_count_matches_images(self) -> None: - """Test number of preview thumbnails matches attached images.""" - attached_images = ["img1", "img2", "img3"] - thumbnail_count = len(attached_images) - self.assertEqual(thumbnail_count, 3) + """Test number of images matches attached count.""" + mgr = ImageAttachmentManager() + mgr._images = ["img1", "img2", "img3"] + self.assertEqual(len(mgr.images), 3) class TestModalLogic(unittest.TestCase): - """Tests for image modal display logic.""" + """Tests for image modal display logic (requires DOM mocks).""" def test_modal_display_style_visible(self) -> None: """Test modal display style when showing.""" diff --git a/static/client/client_tests/test_tool_call_log.py b/static/client/client_tests/test_tool_call_log.py index 98bd9026..cdee1162 100644 --- a/static/client/client_tests/test_tool_call_log.py +++ b/static/client/client_tests/test_tool_call_log.py @@ -8,32 +8,7 @@ from tool_call_log_manager import ToolCallLogManager from message_menu_manager import MessageMenuManager from chat_ui_manager import ChatUIManager -from .simple_mock import SimpleMock - - -def _get_class_attr(node: Any) -> str: - try: - attrs = getattr(node, "attrs", None) - if attrs is not None and hasattr(attrs, "get"): - value = attrs.get("class", "") - if isinstance(value, str): - return value - return "" if value is None else str(value) - except Exception: - pass - try: - value = getattr(node, "class_name", None) - if isinstance(value, str): - return value - except Exception: - pass - try: - value = getattr(node, "className", None) - if isinstance(value, str): - return value - except Exception: - pass - return "" +from .simple_mock import SimpleMock, get_class_attr as _get_class_attr def _find_child_by_class(parent: Any, cls: str) -> Optional[Any]: diff --git a/static/route_helpers.py b/static/route_helpers.py index bddf4357..71233297 100644 --- a/static/route_helpers.py +++ b/static/route_helpers.py @@ -15,7 +15,7 @@ def 
reset_tools_for_all_providers( app: MatHudFlask, - finish_reason: object, + finish_reason: Optional[str], *, active_provider: Optional[OpenAIAPIBase] = None, ) -> None: From 595300561d758550b9483c3e8c90a6be56797c1c Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Wed, 25 Mar 2026 01:23:04 +0200 Subject: [PATCH 27/28] Remove redundant test and convert bare asserts to unittest style Drop test_schema_version_at_least_one (redundant with test_schema_version_is_positive_integer). Convert 3 bare assert statements in test_route_helpers.py to self.assertIs(). --- server_tests/test_config.py | 3 --- server_tests/test_route_helpers.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/server_tests/test_config.py b/server_tests/test_config.py index f3ce6c9d..e3fb9685 100644 --- a/server_tests/test_config.py +++ b/server_tests/test_config.py @@ -48,9 +48,6 @@ def test_canvas_snapshot_path_equals_os_path_join(self) -> None: def test_workspaces_dir_matches_expected_name(self) -> None: self.assertEqual(WORKSPACES_DIR, "workspaces") - def test_schema_version_at_least_one(self) -> None: - self.assertGreaterEqual(CURRENT_WORKSPACE_SCHEMA_VERSION, 1) - if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_route_helpers.py b/server_tests/test_route_helpers.py index 3d8f475b..7052a145 100644 --- a/server_tests/test_route_helpers.py +++ b/server_tests/test_route_helpers.py @@ -170,7 +170,7 @@ def test_returns_ai_api_when_model_id_is_none(self) -> None: result = get_active_provider(app, None) - assert result is app.ai_api + self.assertIs(result, app.ai_api) app.ai_api.set_model.assert_not_called() app.responses_api.set_model.assert_not_called() @@ -190,7 +190,7 @@ def test_calls_update_and_get_provider_when_model_id_given( mock_update.assert_called_once_with(app, "gpt-4.1") mock_get_provider.assert_called_once_with(app, "gpt-4.1") - assert result is sentinel_provider + self.assertIs(result, 
sentinel_provider) @patch("static.route_helpers.update_all_provider_models") @patch("static.routes.get_provider_for_model") @@ -206,4 +206,4 @@ def test_returns_resolved_provider( result = get_active_provider(app, "o4-mini") - assert result is expected + self.assertIs(result, expected) From 9e49d93bbe85e9899c5b639240ffcfb63b9b6bc4 Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Wed, 25 Mar 2026 20:43:31 +0200 Subject: [PATCH 28/28] Address remaining review feedback Document API key required vs optional as deliberate design choice. Remove unused VisibilityManager.update_dimensions. Deduplicate hardcoded voice list via shared TTS_VOICE_OPTIONS constant. Add SvgTelemetry smoke test (6 tests). --- .../client_tests/test_base_telemetry.py | 52 ++++++++++++++++++- static/client/client_tests/tests.py | 2 + static/client/managers/visibility_manager.py | 14 ----- static/client/tts_controller.py | 11 ++++ static/client/tts_ui_manager.py | 15 ++---- static/env_config.py | 6 +++ static/openai_api_base.py | 4 ++ static/providers/anthropic_api.py | 2 + static/providers/openrouter_api.py | 2 + 9 files changed, 82 insertions(+), 26 deletions(-) diff --git a/static/client/client_tests/test_base_telemetry.py b/static/client/client_tests/test_base_telemetry.py index 1f2a5da7..b5873e77 100644 --- a/static/client/client_tests/test_base_telemetry.py +++ b/static/client/client_tests/test_base_telemetry.py @@ -1,4 +1,4 @@ -"""Tests for BaseRendererTelemetry and Canvas2DTelemetry.""" +"""Tests for BaseRendererTelemetry, Canvas2DTelemetry, and SvgTelemetry.""" from __future__ import annotations @@ -6,6 +6,7 @@ from rendering.base_telemetry import BaseRendererTelemetry from rendering.canvas2d_renderer import Canvas2DTelemetry +from rendering.svg_renderer import SvgTelemetry class TestBaseTelemetryInit(unittest.TestCase): @@ -521,6 +522,54 @@ def test_canvas2d_inherits_all_base_behavior(self) -> None: 
self.assertEqual(snap["adapter_events"]["max_batch_depth"], 2) +class TestSvgTelemetry(unittest.TestCase): + """Smoke tests for SvgTelemetry (empty subclass of BaseRendererTelemetry).""" + + def test_instantiation(self) -> None: + tel = SvgTelemetry() + self.assertIsInstance(tel, BaseRendererTelemetry) + + def test_has_base_methods(self) -> None: + tel = SvgTelemetry() + for method_name in ("reset", "begin_frame", "snapshot", "drain"): + self.assertTrue( + callable(getattr(tel, method_name, None)), + f"SvgTelemetry should have callable '{method_name}'", + ) + + def test_new_drawable_bucket_matches_base(self) -> None: + svg_tel = SvgTelemetry() + base_tel = BaseRendererTelemetry() + svg_bucket = svg_tel._new_drawable_bucket() + base_bucket = base_tel._new_drawable_bucket() + self.assertEqual( + set(svg_bucket.keys()), + set(base_bucket.keys()), + "SvgTelemetry bucket should have exactly the base keys (no extras)", + ) + + def test_new_drawable_bucket_values_zero(self) -> None: + tel = SvgTelemetry() + bucket = tel._new_drawable_bucket() + for key, value in bucket.items(): + self.assertEqual(value, 0.0, f"{key} should be 0.0") + + def test_snapshot_after_begin_frame(self) -> None: + tel = SvgTelemetry() + tel.begin_frame() + snap = tel.snapshot() + self.assertEqual(snap["frames"], 1) + + def test_drain_resets(self) -> None: + tel = SvgTelemetry() + tel.begin_frame() + tel.record_plan_build("Point", 2.0) + result = tel.drain() + self.assertEqual(result["frames"], 1) + self.assertEqual(tel._frames, 0) + self.assertEqual(len(tel._per_drawable), 0) + + __all__ = [ "TestBaseTelemetryInit", "TestBaseTelemetryReset", @@ -537,4 +586,5 @@ def test_canvas2d_inherits_all_base_behavior(self) -> None: "TestBaseTelemetryDrain", "TestBaseTelemetryNewDrawableBucket", "TestCanvas2DTelemetryOverride", + "TestSvgTelemetry", ] diff --git a/static/client/client_tests/tests.py b/static/client/client_tests/tests.py index fd338581..4b5f596a 100644 --- a/static/client/client_tests/tests.py 
+++ b/static/client/client_tests/tests.py @@ -274,6 +274,7 @@ TestBaseTelemetryDrain, TestBaseTelemetryNewDrawableBucket, TestCanvas2DTelemetryOverride, + TestSvgTelemetry, ) @@ -592,6 +593,7 @@ def _get_test_cases(self) -> List[Type[unittest.TestCase]]: TestBaseTelemetryDrain, TestBaseTelemetryNewDrawableBucket, TestCanvas2DTelemetryOverride, + TestSvgTelemetry, ] def _create_test_suite(self) -> unittest.TestSuite: diff --git a/static/client/managers/visibility_manager.py b/static/client/managers/visibility_manager.py index cb44244e..c791b3ae 100644 --- a/static/client/managers/visibility_manager.py +++ b/static/client/managers/visibility_manager.py @@ -46,20 +46,6 @@ def __init__(self, coordinate_mapper: "CoordinateMapper", width: float, height: self._width: float = width self._height: float = height - # ------------------------------------------------------------------ - # Canvas-dimension accessors (kept in sync by Canvas when resized) - # ------------------------------------------------------------------ - - def update_dimensions(self, width: float, height: float) -> None: - """Update cached canvas dimensions after a resize. - - Args: - width: New canvas viewport width in pixels. - height: New canvas viewport height in pixels. - """ - self._width = width - self._height = height - # ------------------------------------------------------------------ # Top-level drawable visibility # ------------------------------------------------------------------ diff --git a/static/client/tts_controller.py b/static/client/tts_controller.py index 68ade466..5509e4ee 100644 --- a/static/client/tts_controller.py +++ b/static/client/tts_controller.py @@ -24,6 +24,17 @@ # State type TTSState = Literal["idle", "loading", "playing"] +# Canonical voice list shared with tts_ui_manager. Keep in sync with +# TTSManager.VOICES in static/tts_manager.py on the server side. 
+TTS_VOICE_OPTIONS: list[tuple[str, str]] = [ + ("am_michael", "Michael (Male)"), + ("am_fenrir", "Fenrir (Male, deeper)"), + ("am_onyx", "Onyx (Male, darker)"), + ("am_echo", "Echo (Male, resonant)"), + ("af_nova", "Nova (Female)"), + ("af_bella", "Bella (Female, warm)"), +] + class TTSController: """Controls TTS audio playback in the browser. diff --git a/static/client/tts_ui_manager.py b/static/client/tts_ui_manager.py index 28ebfe18..96ae9748 100644 --- a/static/client/tts_ui_manager.py +++ b/static/client/tts_ui_manager.py @@ -14,7 +14,7 @@ from browser import document, html -from tts_controller import get_tts_controller, TTSController +from tts_controller import get_tts_controller, TTSController, TTS_VOICE_OPTIONS class TTSUIManager: @@ -150,16 +150,9 @@ def show_settings_modal(self) -> None: voice_label = html.LABEL("Voice:") voice_select = html.SELECT(id="tts-voice-select") - # Add voice options - # Note: Voice IDs must match TTSManager.VOICES in static/tts_manager.py - voices = [ - ("am_michael", "Michael (Male)"), - ("am_fenrir", "Fenrir (Male, deeper)"), - ("am_onyx", "Onyx (Male, darker)"), - ("am_echo", "Echo (Male, resonant)"), - ("af_nova", "Nova (Female)"), - ("af_bella", "Bella (Female, warm)"), - ] + # Voice options imported from tts_controller.TTS_VOICE_OPTIONS + # (canonical list kept in sync with TTSManager.VOICES on the server) + voices = TTS_VOICE_OPTIONS current_voice = self._tts_controller.get_voice() for voice_id, voice_name in voices: option = html.OPTION(voice_name, value=voice_id) diff --git a/static/env_config.py b/static/env_config.py index aea0f506..d0cd1dee 100644 --- a/static/env_config.py +++ b/static/env_config.py @@ -39,6 +39,12 @@ def get_api_key( The function checks ``os.environ`` before touching disk so that explicitly-set variables are returned immediately. 
+ Callers choose *required* based on provider semantics: + - Opt-in providers (Anthropic, OpenRouter) use ``required=True`` because a + missing key means the user misconfigured an explicit provider choice. + - The default provider (OpenAI) uses ``required=False`` so the app can + start without an OpenAI key when only third-party providers are used. + Args: name: Environment variable name (e.g. ``"OPENAI_API_KEY"``). required: When *True* and the key is missing, raise ``ValueError``. diff --git a/static/openai_api_base.py b/static/openai_api_base.py index e607d776..12586e74 100644 --- a/static/openai_api_base.py +++ b/static/openai_api_base.py @@ -84,6 +84,10 @@ def _initialize_api_key() -> str: to start with other providers configured. Actual OpenAI API calls will fail with an authentication error in that case. """ + # required=False: OpenAI is the default provider but the app can start + # without an OpenAI key when the user configures a third-party provider + # (Anthropic, OpenRouter). A missing key degrades gracefully — actual + # OpenAI API calls will fail with an auth error at call time. api_key = get_api_key("OPENAI_API_KEY", required=False, fallback="") if not api_key: logging.getLogger("mathud").warning("OPENAI_API_KEY not found. OpenAI models will be unavailable.") diff --git a/static/providers/anthropic_api.py b/static/providers/anthropic_api.py index 0082a4c4..1d1f037b 100644 --- a/static/providers/anthropic_api.py +++ b/static/providers/anthropic_api.py @@ -24,6 +24,8 @@ def _get_anthropic_api_key() -> str: """Get the Anthropic API key from environment.""" + # required=True (default): Anthropic is an explicitly opted-in provider, + # so a missing key is a configuration error rather than a graceful fallback. 
return get_api_key("ANTHROPIC_API_KEY") diff --git a/static/providers/openrouter_api.py b/static/providers/openrouter_api.py index 536a83fc..c9fa34bf 100644 --- a/static/providers/openrouter_api.py +++ b/static/providers/openrouter_api.py @@ -22,6 +22,8 @@ def _get_openrouter_api_key() -> str: """Get the OpenRouter API key from environment.""" + # required=True (default): OpenRouter is an explicitly opted-in provider, + # so a missing key is a configuration error rather than a graceful fallback. return get_api_key("OPENROUTER_API_KEY")