From c2cbbed5964459e8b9ee6682cd6371a9c5684fbc Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 17 Feb 2026 20:36:57 +0200 Subject: [PATCH 1/3] Add type stubs for Brython browser module imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add .pyi stubs for the browser package (browser/__init__.pyi, browser/ajax.pyi, browser/aio.pyi) covering all from-browser imports used in client-side Python. Fix pre-existing mypy issues in canvas2d_primitive_adapter.py exposed by the new stubs. Update tests in test_browser_typing_stubs.py. All 1065 tests pass (0 failures). Mypy errors in the changed files are pre-existing (647 errors on main, 576 on this branch — net improvement). --- app.py | 42 +- cli/browser.py | 12 +- cli/config.py | 1 + cli/server.py | 7 +- cli/tests.py | 6 +- diagrams/scripts/generate_arch.py | 182 +- diagrams/scripts/generate_brython_diagrams.py | 558 +- diagrams/scripts/generate_diagrams.py | 339 +- diagrams/scripts/setup_diagram_tools.py | 54 +- diagrams/scripts/utils.py | 26 +- .../metrics/project_metrics_analyzer.py | 353 +- generate_diagrams_launcher.py | 2 +- mypy.ini | 3 +- run_server_tests.py | 6 +- scripts/canvas_prompt_telemetry_report.py | 2 +- scripts/linear_algebra_expected_values.py | 33 +- server_tests/__init__.py | 5 +- server_tests/client_renderer/__init__.py | 6 +- .../client_renderer/renderer_fixtures.py | 12 +- .../test_canvas2d_primitive_adapter.py | 12 +- .../test_polar_renderer_plan.py | 12 +- .../test_renderer_factory_plan.py | 4 +- .../test_webgl_primitive_adapter.py | 16 +- server_tests/python_path_setup.py | 3 +- server_tests/test_adaptive_sampler.py | 46 +- server_tests/test_ai_model.py | 38 +- server_tests/test_browser_typing_stubs.py | 1412 +++++ server_tests/test_canvas_state_summarizer.py | 3 +- server_tests/test_cli/test_config.py | 21 +- server_tests/test_cli/test_server.py | 14 +- server_tests/test_coordinate_mapper.py | 115 +- .../test_coordinate_system_manager.py | 4 +- .../test_function_renderable_paths.py | 9 +- server_tests/test_image_attachment.py | 241 +- server_tests/test_local_provider_base.py | 2 +- server_tests/test_markdown_parser.py | 63 +- server_tests/test_mocks.py | 120 +- server_tests/test_ollama_api.py | 1 + server_tests/test_ollama_integration.py | 98 +- server_tests/test_openai_api_base.py | 206 +- server_tests/test_openai_completions_api.py | 134 +- server_tests/test_openai_responses_api.py | 253 +- server_tests/test_plot_tool_schemas.py | 18 +- server_tests/test_polar_conversion.py | 2 +- server_tests/test_polar_grid.py | 26 +- server_tests/test_position.py | 2 +- server_tests/test_provider_connections.py | 23 +- server_tests/test_regression_pure.py | 5 +- server_tests/test_routes.py | 378 +- server_tests/test_search_tool_wiring_smoke.py | 4 +- server_tests/test_statistics_pure.py | 2 - server_tests/test_tool_argument_validator.py | 109 +- server_tests/test_tool_discovery_live.py | 24 +- server_tests/test_tool_search_service.py | 80 +- server_tests/test_tts_manager.py | 5 +- server_tests/test_workspace_management.py | 150 +- static/ai_model.py | 4 +- static/app_manager.py | 34 +- static/client/ai_interface.py | 295 +- static/client/browser.pyi | 55 - static/client/canvas.py | 207 +- static/client/canvas_event_handler.py | 38 +- static/client/cartesian_system_2axis.py | 14 +- .../client_tests/ai_result_formatter.py | 6 +- .../renderer_performance_tests.py | 5 +- static/client/client_tests/simple_mock.py | 4 +- .../test_action_trace_collector.py | 74 +- static/client/client_tests/test_angle.py | 136 +- .../client/client_tests/test_angle_manager.py | 64 +- .../client/client_tests/test_arc_manager.py | 18 +- .../test_area_expression_evaluator.py | 4 + .../client/client_tests/test_bar_manager.py | 2 - .../client/client_tests/test_bar_renderer.py | 6 +- static/client/client_tests/test_canvas.py | 661 +- static/client/client_tests/test_cartesian.py | 77 +- .../client_tests/test_chat_message_menu.py | 2 - static/client/client_tests/test_circle.py | 9 +- static/client/client_tests/test_circle_arc.py | 1 - .../client_tests/test_circle_manager.py | 12 +- .../test_closed_shape_colored_area.py | 2 - .../client_tests/test_colored_area_helpers.py | 1 - .../test_coordinate_system_manager.py | 6 +- .../test_custom_drawable_names.py | 33 +- static/client/client_tests/test_decagon.py | 1 - .../test_drawable_dependency_manager.py | 145 +- .../test_drawable_name_generator.py | 123 +- .../client_tests/test_drawable_renderers.py | 13 +- .../client_tests/test_drawables_container.py | 24 +- static/client/client_tests/test_ellipse.py | 13 +- .../client_tests/test_ellipse_manager.py | 8 +- .../client_tests/test_error_recovery.py | 4 +- .../client/client_tests/test_event_handler.py | 4 +- .../client_tests/test_expression_validator.py | 44 +- .../client/client_tests/test_font_helpers.py | 7 +- static/client/client_tests/test_function.py | 71 +- ...nction_bounded_colored_area_integration.py | 12 +- .../client_tests/test_function_calling.py | 77 +- .../client_tests/test_function_manager.py | 4 +- .../client_tests/test_function_renderables.py | 108 +- ...t_function_segment_bounded_colored_area.py | 57 +- .../test_functions_bounded_colored_area.py | 46 +- .../client_tests/test_generic_polygon.py | 1 - .../client_tests/test_geometry_utils.py | 1 - .../client_tests/test_graph_analyzer.py | 39 +- .../client/client_tests/test_graph_layout.py | 431 +- .../client/client_tests/test_graph_manager.py | 10 +- .../client/client_tests/test_graph_utils.py | 22 +- static/client/client_tests/test_heptagon.py | 1 - static/client/client_tests/test_hexagon.py | 1 - .../client_tests/test_image_attachment.py | 57 +- .../client/client_tests/test_intersections.py | 27 +- static/client/client_tests/test_label.py | 13 +- .../test_label_overlap_resolver.py | 2 - .../client_tests/test_linear_algebra_utils.py | 1 - .../client_tests/test_math_functions.py | 532 +- static/client/client_tests/test_nonagon.py | 1 - .../client_tests/test_numeric_solver.py | 11 +- static/client/client_tests/test_octagon.py | 1 - .../client_tests/test_optimized_renderers.py | 13 +- .../client/client_tests/test_path_elements.py | 1 - static/client/client_tests/test_pentagon.py | 1 - .../test_periodicity_detection.py | 26 +- .../client_tests/test_piecewise_function.py | 19 +- static/client/client_tests/test_point.py | 7 +- .../client/client_tests/test_point_manager.py | 12 +- static/client/client_tests/test_polar_grid.py | 4 +- .../test_polygon_canonicalizer.py | 9 +- .../client_tests/test_polygon_manager.py | 10 +- .../client/client_tests/test_quadrilateral.py | 2 +- static/client/client_tests/test_rectangle.py | 4 +- static/client/client_tests/test_region.py | 3 - .../client_tests/test_relation_inspector.py | 85 +- .../client_tests/test_renderer_edge_cases.py | 41 +- .../client_tests/test_renderer_logic.py | 2 +- .../client_tests/test_renderer_primitives.py | 38 +- .../test_result_processor_traced.py | 63 +- .../test_screen_offset_label_layout.py | 9 +- static/client/client_tests/test_segment.py | 14 +- .../client_tests/test_segment_manager.py | 1 - .../test_segments_bounded_colored_area.py | 63 +- .../client_tests/test_slash_commands.py | 49 +- .../test_statistics_distributions.py | 3 - .../client_tests/test_statistics_manager.py | 8 +- .../client_tests/test_tangent_manager.py | 4 +- static/client/client_tests/test_throttle.py | 3 +- .../client/client_tests/test_tool_call_log.py | 5 +- .../test_transformations_manager.py | 8 +- static/client/client_tests/test_transforms.py | 1 + static/client/client_tests/test_triangle.py | 5 +- .../client_tests/test_undo_redo_manager.py | 1 - static/client/client_tests/test_vector.py | 5 +- .../client_tests/test_vector_manager.py | 1 - .../client/client_tests/test_window_mocks.py | 5 +- .../client_tests/test_workspace_plots.py | 2 - static/client/client_tests/test_zoom.py | 3 +- static/client/client_tests/tests.py | 3 +- static/client/command_autocomplete.py | 21 +- static/client/coordinate_mapper.py | 63 +- static/client/drawables/angle.py | 112 +- static/client/drawables/attached_label.py | 2 - static/client/drawables/bar.py | 6 +- static/client/drawables/bars_plot.py | 2 - static/client/drawables/circle.py | 11 +- static/client/drawables/circle_arc.py | 7 +- .../drawables/closed_shape_colored_area.py | 9 +- static/client/drawables/colored_area.py | 12 +- static/client/drawables/continuous_plot.py | 3 - static/client/drawables/decagon.py | 1 - static/client/drawables/directed_graph.py | 4 +- static/client/drawables/discrete_plot.py | 2 - static/client/drawables/drawable.py | 1 + static/client/drawables/ellipse.py | 27 +- static/client/drawables/function.py | 62 +- .../function_segment_bounded_colored_area.py | 393 +- .../functions_bounded_colored_area.py | 101 +- static/client/drawables/generic_polygon.py | 1 - static/client/drawables/heptagon.py | 1 - static/client/drawables/hexagon.py | 1 - static/client/drawables/label.py | 4 +- static/client/drawables/label_render_mode.py | 3 - static/client/drawables/nonagon.py | 1 - static/client/drawables/octagon.py | 1 - .../client/drawables/parametric_function.py | 14 +- static/client/drawables/pentagon.py | 1 - static/client/drawables/piecewise_function.py | 57 +- .../drawables/piecewise_function_interval.py | 3 +- static/client/drawables/plot.py | 1 - static/client/drawables/point.py | 9 +- static/client/drawables/polygon.py | 4 +- static/client/drawables/position.py | 3 +- static/client/drawables/quadrilateral.py | 1 - static/client/drawables/rectangle.py | 39 +- static/client/drawables/segment.py | 11 +- .../segments_bounded_colored_area.py | 42 +- static/client/drawables/triangle.py | 36 +- static/client/drawables/undirected_graph.py | 4 +- static/client/drawables/vector.py | 4 +- static/client/drawables_aggregator.py | 34 +- static/client/expression_evaluator.py | 16 +- static/client/expression_validator.py | 376 +- static/client/function_registry.py | 49 +- static/client/geometry/__init__.py | 28 +- static/client/geometry/graph_state.py | 3 - static/client/geometry/path/__init__.py | 27 +- static/client/geometry/path/circular_arc.py | 11 +- static/client/geometry/path/composite_path.py | 7 +- static/client/geometry/path/elliptical_arc.py | 26 +- static/client/geometry/path/intersections.py | 70 +- static/client/geometry/path/line_segment.py | 1 - static/client/geometry/path/path_element.py | 1 - static/client/geometry/region.py | 67 +- static/client/main.py | 32 +- .../client/managers/action_trace_collector.py | 32 +- static/client/managers/angle_manager.py | 180 +- static/client/managers/arc_manager.py | 47 +- static/client/managers/bar_manager.py | 2 - static/client/managers/circle_manager.py | 9 +- .../client/managers/colored_area_manager.py | 25 +- .../client/managers/construction_manager.py | 73 +- static/client/managers/dependency_removal.py | 4 +- .../managers/drawable_dependency_manager.py | 224 +- static/client/managers/drawable_manager.py | 156 +- .../client/managers/drawable_manager_proxy.py | 1 + static/client/managers/drawables_container.py | 67 +- static/client/managers/edit_policy.py | 2 - static/client/managers/ellipse_manager.py | 19 +- static/client/managers/function_manager.py | 9 +- static/client/managers/graph_manager.py | 8 +- static/client/managers/label_manager.py | 1 - .../managers/parametric_function_manager.py | 4 +- .../managers/piecewise_function_manager.py | 12 +- static/client/managers/point_manager.py | 53 +- static/client/managers/polygon_manager.py | 4 +- static/client/managers/polygon_type.py | 1 - static/client/managers/segment_manager.py | 72 +- static/client/managers/statistics_manager.py | 43 +- static/client/managers/tangent_manager.py | 14 +- .../managers/transformations_manager.py | 19 +- static/client/managers/undo_redo_manager.py | 29 +- static/client/managers/vector_manager.py | 17 +- static/client/markdown_parser.py | 316 +- static/client/name_generator/__init__.py | 10 +- static/client/name_generator/arc.py | 20 +- static/client/name_generator/base.py | 2 +- static/client/name_generator/drawable.py | 25 +- static/client/name_generator/function.py | 30 +- static/client/name_generator/label.py | 1 - static/client/name_generator/point.py | 36 +- static/client/numeric_solver/__init__.py | 2 +- .../client/numeric_solver/expression_utils.py | 50 +- static/client/numeric_solver/solver.py | 17 +- static/client/polar_grid.py | 17 +- static/client/process_function_calls.py | 13 +- static/client/rendering/cached_render_plan.py | 69 +- .../rendering/canvas2d_primitive_adapter.py | 16 +- static/client/rendering/canvas2d_renderer.py | 42 +- static/client/rendering/factory.py | 9 +- static/client/rendering/helpers/__init__.py | 1 - .../rendering/helpers/angle_renderer.py | 17 +- .../client/rendering/helpers/area_builders.py | 1 - .../client/rendering/helpers/bar_renderer.py | 2 - .../rendering/helpers/cartesian_renderer.py | 163 +- .../rendering/helpers/circle_arc_renderer.py | 5 +- .../rendering/helpers/circle_renderer.py | 1 - .../helpers/colored_area_renderer.py | 1 - .../rendering/helpers/ellipse_renderer.py | 1 - .../client/rendering/helpers/font_helpers.py | 1 - .../rendering/helpers/function_renderer.py | 1 - .../helpers/label_overlap_resolver.py | 2 - .../rendering/helpers/label_renderer.py | 1 - .../helpers/parametric_function_renderer.py | 5 +- .../rendering/helpers/point_renderer.py | 1 - .../rendering/helpers/polar_renderer.py | 78 +- .../helpers/screen_offset_label_helper.py | 2 - .../helpers/screen_offset_label_layout.py | 15 +- .../rendering/helpers/segment_renderer.py | 1 - .../rendering/helpers/shape_decorator.py | 3 +- .../rendering/helpers/vector_renderer.py | 1 - .../rendering/helpers/world_label_helper.py | 6 +- static/client/rendering/interfaces.py | 26 +- static/client/rendering/primitives.py | 26 +- .../client/rendering/renderables/__init__.py | 1 - .../rendering/renderables/adaptive_sampler.py | 21 +- .../closed_shape_area_renderable.py | 1 - .../renderables/function_renderable.py | 108 +- .../function_segment_area_renderable.py | 13 +- .../renderables/functions_area_renderable.py | 22 +- .../renderables/segments_area_renderable.py | 5 +- static/client/rendering/style_manager.py | 14 - .../client/rendering/svg_primitive_adapter.py | 13 +- static/client/rendering/svg_renderer.py | 22 +- .../rendering/webgl_primitive_adapter.py | 1 - static/client/rendering/webgl_renderer.py | 20 +- static/client/result_processor.py | 166 +- static/client/slash_command_handler.py | 3 + static/client/test_runner.py | 313 +- static/client/tts_controller.py | 34 +- static/client/typing/browser/__init__.pyi | 214 + static/client/typing/browser/_dom.pyi | 70 + static/client/typing/browser/aio.pyi | 12 + static/client/typing/browser/ajax.pyi | 46 + .../client/utils/area_expression_evaluator.py | 47 +- .../client/utils/canonicalizers/__init__.py | 3 - static/client/utils/canonicalizers/common.py | 3 - .../utils/canonicalizers/quadrilateral.py | 26 +- .../client/utils/canonicalizers/triangle.py | 4 +- static/client/utils/computation_utils.py | 6 +- static/client/utils/geometry_utils.py | 87 +- static/client/utils/graph_analyzer.py | 52 +- static/client/utils/graph_layout.py | 143 +- static/client/utils/graph_utils.py | 12 +- static/client/utils/linear_algebra_utils.py | 4 +- static/client/utils/math_utils.py | 301 +- static/client/utils/polygon_canonicalizer.py | 3 - static/client/utils/polygon_subtypes.py | 3 +- static/client/utils/relation_inspector.py | 121 +- .../client/utils/statistics/distributions.py | 7 +- static/client/utils/statistics/regression.py | 25 +- static/client/utils/style_utils.py | 169 +- static/client/workspace_manager.py | 101 +- static/functions_definitions.py | 5558 ++++++++--------- static/log_manager.py | 24 +- static/mirror_client_modules.py | 1 - static/openai_api_base.py | 44 +- static/openai_completions_api.py | 54 +- static/openai_responses_api.py | 80 +- static/providers/__init__.py | 8 +- static/providers/anthropic_api.py | 99 +- static/providers/local/__init__.py | 73 +- static/providers/local/ollama_api.py | 24 +- static/providers/openrouter_api.py | 1 + static/routes.py | 477 +- static/tool_argument_validator.py | 38 +- static/tool_search_service.py | 42 +- static/tts_manager.py | 18 +- static/webdriver_manager.py | 28 +- static/workspace_manager.py | 17 +- 337 files changed, 12293 insertions(+), 10019 deletions(-) create mode 100644 server_tests/test_browser_typing_stubs.py delete mode 100644 static/client/browser.pyi create mode 100644 static/client/typing/browser/__init__.pyi create mode 100644 static/client/typing/browser/_dom.pyi create mode 100644 static/client/typing/browser/aio.pyi create mode 100644 static/client/typing/browser/ajax.pyi diff --git a/app.py b/app.py index 7a4bf742..2fe46a39 100644 --- a/app.py +++ b/app.py @@ -17,7 +17,7 @@ def signal_handler(sig: int, frame: FrameType | None) -> None: Cleans up WebDriver and Ollama resources and exits the application properly. """ - print('\nShutting down gracefully...') + print("\nShutting down gracefully...") # Clean up WebDriverManager if app.webdriver_manager is not None: try: @@ -28,6 +28,7 @@ def signal_handler(sig: int, frame: FrameType | None) -> None: # Clean up Ollama server if we started it try: from static.providers.local.ollama_api import OllamaAPI + OllamaAPI.stop_server() except Exception as e: print(f"Error stopping Ollama: {e}") @@ -42,54 +43,55 @@ def signal_handler(sig: int, frame: FrameType | None) -> None: # Register signal handler at module level for both run modes signal.signal(signal.SIGINT, signal_handler) -if __name__ == '__main__': +if __name__ == "__main__": """Main execution block. Starts Flask server in a daemon thread, initializes WebDriver for vision system, and maintains the main thread for graceful interrupt handling. """ # Parse command-line arguments - parser = argparse.ArgumentParser(description='MatHud Flask Application') + parser = argparse.ArgumentParser(description="MatHud Flask Application") parser.add_argument( - '-p', '--port', - type=int, - default=None, - help='Port to run the server on (default: 5000, or PORT env var)' + "-p", "--port", type=int, default=None, help="Port to run the server on (default: 5000, or PORT env var)" ) args = parser.parse_args() try: # Priority: CLI argument > environment variable > default (5000) - env_port = os.environ.get('PORT') + env_port = os.environ.get("PORT") port = args.port if args.port is not None else int(env_port or 5000) # Store port in app config for WebDriverManager to use - app.config['SERVER_PORT'] = port + app.config["SERVER_PORT"] = port # Check if we're running in a deployment environment is_deployed = args.port is None and env_port is not None - force_non_debug = os.environ.get('MATHUD_NON_DEBUG', '').lower() in ('1', 'true', 'yes') + force_non_debug = os.environ.get("MATHUD_NON_DEBUG", "").lower() in ("1", "true", "yes") # Enable debug mode for local development debug_mode = not (is_deployed or force_non_debug) if is_deployed: # For deployment: run Flask directly without threading - host = '0.0.0.0' # Bind to all interfaces for deployment + host = "0.0.0.0" # Bind to all interfaces for deployment print(f"Starting Flask app on {host}:{port} (deployment mode)") app.run(host=host, port=port, debug=False) else: # For local development: use threading approach with debug capability - host = '127.0.0.1' # Localhost for development + host = "127.0.0.1" # Localhost for development print(f"Starting Flask app on {host}:{port} (development mode, debug={debug_mode})") from threading import Thread - server = Thread(target=app.run, kwargs={ - 'host': host, - 'port': port, - 'debug': debug_mode, - 'use_reloader': False # Disable reloader in thread mode to avoid issues - }) + + server = Thread( + target=app.run, + kwargs={ + "host": host, + "port": port, + "debug": debug_mode, + "use_reloader": False, # Disable reloader in thread mode to avoid issues + }, + ) server.daemon = True # Make the server thread a daemon so it exits when main thread exits server.start() @@ -99,6 +101,7 @@ def signal_handler(sig: int, frame: FrameType | None) -> None: # Start Ollama server if installed (only in local development) try: from static.providers.local.ollama_api import OllamaAPI + if OllamaAPI.is_ollama_installed(): success, message = OllamaAPI.start_server(timeout=10) print(f"Ollama: {message}") @@ -110,8 +113,9 @@ def signal_handler(sig: int, frame: FrameType | None) -> None: # Initialize WebDriver (only in local development) if app.webdriver_manager is None: import requests + try: - requests.get(f'http://{host}:{port}/init_webdriver') + requests.get(f"http://{host}:{port}/init_webdriver") print("WebDriver initialized successfully") except Exception as e: print(f"Failed to initialize WebDriver: {str(e)}") diff --git a/cli/browser.py b/cli/browser.py index f8d1afde..ddc07a6e 100644 --- a/cli/browser.py +++ b/cli/browser.py @@ -88,9 +88,7 @@ def setup(self) -> None: options.add_argument("--remote-debugging-pipe") profiles_root = runtime_root / "profiles" profiles_root.mkdir(parents=True, exist_ok=True) - self._profile_dir = Path( - tempfile.mkdtemp(prefix="chrome-profile-", dir=str(profiles_root)) - ) + self._profile_dir = Path(tempfile.mkdtemp(prefix="chrome-profile-", dir=str(profiles_root))) options.add_argument(f"--user-data-dir={self._profile_dir}") if platform.machine() in ("aarch64", "arm64"): @@ -266,9 +264,7 @@ def wait_for_element( raise RuntimeError("Browser not initialized. Call setup() first.") try: - WebDriverWait(self.driver, timeout).until( - EC.presence_of_element_located((by, selector)) - ) + WebDriverWait(self.driver, timeout).until(EC.presence_of_element_located((by, selector))) return True except Exception: return False @@ -313,9 +309,7 @@ def get_canvas_state(self) -> dict[str, Any]: Returns: Canvas state dictionary. """ - result = self.execute_js( - "return window._canvas ? JSON.stringify(window._canvas.get_state()) : null" - ) + result = self.execute_js("return window._canvas ? JSON.stringify(window._canvas.get_state()) : null") if result: parsed: dict[str, Any] = json.loads(result) return parsed diff --git a/cli/config.py b/cli/config.py index e7b9f11d..4e0b57b7 100644 --- a/cli/config.py +++ b/cli/config.py @@ -43,6 +43,7 @@ # CLI output directory (for screenshots, etc.) CLI_OUTPUT_DIR = Path(__file__).parent / "output" + # Python interpreter path def get_python_path() -> Path: """Get the path to the Python interpreter in the virtual environment.""" diff --git a/cli/server.py b/cli/server.py index 86546296..bca952c2 100644 --- a/cli/server.py +++ b/cli/server.py @@ -262,8 +262,7 @@ def start( ) return ( False, - f"Port {self.port} is already in use. " - f"Choose another port or stop the existing listener.", + f"Port {self.port} is already in use. Choose another port or stop the existing listener.", ) else: owner_pid = self._find_listener_pid_on_port() @@ -275,8 +274,7 @@ def start( ) return ( False, - f"Port {self.port} is already in use. " - f"Choose another port or stop the existing listener.", + f"Port {self.port} is already in use. Choose another port or stop the existing listener.", ) # Set environment to disable auth for CLI operations @@ -451,6 +449,7 @@ def status(port: int, as_json: bool) -> None: if as_json: import json + click.echo(json.dumps(info)) else: if info["running"]: diff --git a/cli/tests.py b/cli/tests.py index b9fb96b6..2ca125b4 100644 --- a/cli/tests.py +++ b/cli/tests.py @@ -451,6 +451,10 @@ def all_cmd(port: int, with_auth: bool, start_server: bool, skip_lint: bool) -> raise SystemExit(1) lint_label = "Lint + " if not skip_lint else "" - click.echo(click.style(f"\nAll passed! ({lint_label}Server + {results.get('tests_run', 0)} client tests)", fg="green", bold=True)) + click.echo( + click.style( + f"\nAll passed! ({lint_label}Server + {results.get('tests_run', 0)} client tests)", fg="green", bold=True + ) + ) if results.get("screenshot"): click.echo(f"\nScreenshot saved to: {results['screenshot']}") diff --git a/diagrams/scripts/generate_arch.py b/diagrams/scripts/generate_arch.py index 8473f23c..d50c18e2 100644 --- a/diagrams/scripts/generate_arch.py +++ b/diagrams/scripts/generate_arch.py @@ -54,8 +54,6 @@ def __init__( self.svg_dir.mkdir(exist_ok=True) (self.svg_dir / "architecture").mkdir(exist_ok=True) - - def clean_generated_folders(self) -> None: """Carefully delete all content from generated_png and generated_svg folders.""" print("Cleaning generated folders before architecture diagram generation...") @@ -94,6 +92,7 @@ def clean_generated_folders(self) -> None: elif item.is_dir(): # Recursively delete directory contents import shutil + shutil.rmtree(item) folder_dirs += 1 total_deleted_dirs += 1 @@ -123,8 +122,6 @@ def get_output_dir(self, fmt: str) -> Path: else: return self.png_dir / "architecture" - - def generate_system_overview_diagram(self) -> None: """Generate overall MatHud system architecture diagram.""" try: @@ -142,23 +139,20 @@ def generate_system_overview_diagram(self) -> None: output_dir = self.get_output_dir(fmt) diagram_path = output_dir / "system_overview" - with Diagram("MatHud System Overview", - filename=str(diagram_path), - show=False, - direction="TB", - outformat=fmt, - graph_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "dpi": "150" - }, - node_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "width": "1.2", - "height": "1.2" - }): - + with Diagram( + "MatHud System Overview", + filename=str(diagram_path), + show=False, + direction="TB", + outformat=fmt, + graph_attr={"fontname": DIAGRAM_FONT, "fontsize": str(DIAGRAM_FONT_SIZE), "dpi": "150"}, + node_attr={ + "fontname": DIAGRAM_FONT, + "fontsize": str(DIAGRAM_FONT_SIZE), + "width": "1.2", + "height": "1.2", + }, + ): # User Interface Layer user = Client("User Browser") @@ -242,8 +236,8 @@ def generate_system_overview_diagram(self) -> None: print(f" + System overview diagram: {diagram_path}.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - post_process_svg_fonts(output_dir / f'system_overview.{fmt}') + if fmt == "svg": + post_process_svg_fonts(output_dir / f"system_overview.{fmt}") except Exception as e: print(f"System overview diagram failed: {e}") @@ -261,23 +255,20 @@ def generate_ai_integration_diagram(self) -> None: output_dir = self.get_output_dir(fmt) diagram_path = output_dir / "ai_integration" - with Diagram("MatHud AI Integration & Function Call Flow", - filename=str(diagram_path), - show=False, - direction="LR", - outformat=fmt, - graph_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "dpi": "150" - }, - node_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "width": "1.2", - "height": "1.2" - }): - + with Diagram( + "MatHud AI Integration & Function Call Flow", + filename=str(diagram_path), + show=False, + direction="LR", + outformat=fmt, + graph_attr={"fontname": DIAGRAM_FONT, "fontsize": str(DIAGRAM_FONT_SIZE), "dpi": "150"}, + node_attr={ + "fontname": DIAGRAM_FONT, + "fontsize": str(DIAGRAM_FONT_SIZE), + "width": "1.2", + "height": "1.2", + }, + ): # Input Sources with Cluster("User Input"): user_text = Client("User Message\n(Math Problems)") @@ -340,8 +331,8 @@ def generate_ai_integration_diagram(self) -> None: print(f" + AI integration diagram: {diagram_path}.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - post_process_svg_fonts(output_dir / f'ai_integration.{fmt}') + if fmt == "svg": + post_process_svg_fonts(output_dir / f"ai_integration.{fmt}") except Exception as e: print(f"AI integration diagram failed: {e}") @@ -358,23 +349,20 @@ def generate_webdriver_flow_diagram(self) -> None: output_dir = self.get_output_dir(fmt) diagram_path = output_dir / "webdriver_flow" - with Diagram("MatHud Vision System & WebDriver Flow", - filename=str(diagram_path), - show=False, - direction="TB", - outformat=fmt, - graph_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "dpi": "150" - }, - node_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "width": "1.2", - "height": "1.2" - }): - + with Diagram( + "MatHud Vision System & WebDriver Flow", + filename=str(diagram_path), + show=False, + direction="TB", + outformat=fmt, + graph_attr={"fontname": DIAGRAM_FONT, "fontsize": str(DIAGRAM_FONT_SIZE), "dpi": "150"}, + node_attr={ + "fontname": DIAGRAM_FONT, + "fontsize": str(DIAGRAM_FONT_SIZE), + "width": "1.2", + "height": "1.2", + }, + ): # Trigger with Cluster("Vision Request Trigger"): user_request = Client("User Enables Vision\n+ Sends Message") @@ -432,8 +420,8 @@ def generate_webdriver_flow_diagram(self) -> None: print(f" + WebDriver flow diagram: {diagram_path}.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - post_process_svg_fonts(output_dir / f'webdriver_flow.{fmt}') + if fmt == "svg": + post_process_svg_fonts(output_dir / f"webdriver_flow.{fmt}") except Exception as e: print(f"WebDriver flow diagram failed: {e}") @@ -453,23 +441,20 @@ def generate_data_flow_diagram(self) -> None: output_dir = self.get_output_dir(fmt) diagram_path = output_dir / "data_flow" - with Diagram("MatHud Data Flow Pipeline", - filename=str(diagram_path), - show=False, - direction="LR", - outformat=fmt, - graph_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "dpi": "150" - }, - node_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "width": "1.2", - "height": "1.2" - }): - + with Diagram( + "MatHud Data Flow Pipeline", + filename=str(diagram_path), + show=False, + direction="LR", + outformat=fmt, + graph_attr={"fontname": DIAGRAM_FONT, "fontsize": str(DIAGRAM_FONT_SIZE), "dpi": "150"}, + node_attr={ + "fontname": DIAGRAM_FONT, + "fontsize": str(DIAGRAM_FONT_SIZE), + "width": "1.2", + "height": "1.2", + }, + ): # Input Stage with Cluster("Input Stage"): user_input = Client("User Input") @@ -525,8 +510,8 @@ def generate_data_flow_diagram(self) -> None: print(f" + Data flow diagram: {diagram_path}.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - post_process_svg_fonts(output_dir / f'data_flow.{fmt}') + if fmt == "svg": + post_process_svg_fonts(output_dir / f"data_flow.{fmt}") except Exception as e: print(f"Data flow diagram failed: {e}") @@ -542,23 +527,20 @@ def generate_manager_architecture_diagram(self) -> None: output_dir = self.get_output_dir(fmt) diagram_path = output_dir / "manager_architecture" - with Diagram("MatHud Manager Pattern Architecture", - filename=str(diagram_path), - show=False, - direction="TB", - outformat=fmt, - graph_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "dpi": "150" - }, - node_attr={ - "fontname": DIAGRAM_FONT, - "fontsize": str(DIAGRAM_FONT_SIZE), - "width": "1.2", - "height": "1.2" - }): - + with Diagram( + "MatHud Manager Pattern Architecture", + filename=str(diagram_path), + show=False, + direction="TB", + outformat=fmt, + graph_attr={"fontname": DIAGRAM_FONT, "fontsize": str(DIAGRAM_FONT_SIZE), "dpi": "150"}, + node_attr={ + "fontname": DIAGRAM_FONT, + "fontsize": str(DIAGRAM_FONT_SIZE), + "width": "1.2", + "height": "1.2", + }, + ): # Central Coordinator with Cluster("Central Canvas System"): canvas = Python("Canvas\n(SVG Manipulation)") @@ -638,8 +620,8 @@ def generate_manager_architecture_diagram(self) -> None: print(f" + Manager architecture diagram: {diagram_path}.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - post_process_svg_fonts(output_dir / f'manager_architecture.{fmt}') + if fmt == "svg": + post_process_svg_fonts(output_dir / f"manager_architecture.{fmt}") except Exception as e: print(f"Manager architecture diagram failed: {e}") @@ -703,5 +685,5 @@ def main() -> None: generator.generate_all_architecture_diagrams(clean_first=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/diagrams/scripts/generate_brython_diagrams.py b/diagrams/scripts/generate_brython_diagrams.py index ca10c9e7..6536d6a7 100644 --- a/diagrams/scripts/generate_brython_diagrams.py +++ b/diagrams/scripts/generate_brython_diagrams.py @@ -27,12 +27,7 @@ from typing import List, Sequence, Tuple # Import shared utilities -from utils import ( - setup_graphviz_path, - setup_font_environment, - post_process_svg_fonts, - DIAGRAM_FONT -) +from utils import setup_graphviz_path, setup_font_environment, post_process_svg_fonts, DIAGRAM_FONT class BrythonDiagramGenerator: @@ -95,29 +90,58 @@ def __init__( # System component definitions self.drawable_classes: List[str] = [ - 'point.py', 'segment.py', 'vector.py', 'triangle.py', 'rectangle.py', - 'circle.py', 'ellipse.py', 'angle.py', 'function.py', 'colored_area.py', - 'functions_bounded_colored_area.py', 'function_segment_bounded_colored_area.py', - 'segments_bounded_colored_area.py', 'rotatable_polygon.py', 'drawable.py' + "point.py", + "segment.py", + "vector.py", + "triangle.py", + "rectangle.py", + "circle.py", + "ellipse.py", + "angle.py", + "function.py", + "colored_area.py", + "functions_bounded_colored_area.py", + "function_segment_bounded_colored_area.py", + "segments_bounded_colored_area.py", + "rotatable_polygon.py", + "drawable.py", ] self.manager_classes: List[str] = [ - 'drawable_manager.py', 'point_manager.py', 'segment_manager.py', - 'vector_manager.py', 'polygon_manager.py', - 'circle_manager.py', 'ellipse_manager.py', 'angle_manager.py', - 'function_manager.py', 'colored_area_manager.py', 'drawable_dependency_manager.py', - 'drawable_manager_proxy.py', 'transformations_manager.py', 'undo_redo_manager.py', - 'drawables_container.py' + "drawable_manager.py", + "point_manager.py", + "segment_manager.py", + "vector_manager.py", + "polygon_manager.py", + "circle_manager.py", + "ellipse_manager.py", + "angle_manager.py", + "function_manager.py", + "colored_area_manager.py", + "drawable_dependency_manager.py", + "drawable_manager_proxy.py", + "transformations_manager.py", + "undo_redo_manager.py", + "drawables_container.py", ] self.core_system_files: List[str] = [ - 'canvas.py', 'ai_interface.py', 'canvas_event_handler.py', - 'workspace_manager.py', 'result_processor.py', 'process_function_calls.py' + "canvas.py", + "ai_interface.py", + "canvas_event_handler.py", + "workspace_manager.py", + "result_processor.py", + "process_function_calls.py", ] self.utility_files: List[str] = [ - 'expression_evaluator.py', 'expression_validator.py', 'markdown_parser.py', - 'function_registry.py', 'result_validator.py', 'constants.py', 'geometry.py' + "expression_evaluator.py", + "expression_validator.py", + "markdown_parser.py", + "function_registry.py", + "result_validator.py", + "constants.py", + "geometry.py", ] def get_brython_output_dir(self, fmt: str, subdir: str = "") -> Path: @@ -214,32 +238,37 @@ def generate_core_system_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "core") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_core_classes', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1' + "pyreverse", + "-o", + fmt, + "-p", + "brython_core_classes", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", ] + core_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + Core system diagram generated: {output_dir}/classes_brython_core_classes.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_core_classes.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_core_classes.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating core system diagram: {e.stderr}") # Individual core component diagrams important_core_modules: List[Tuple[str, str]] = [ - ('canvas.py', 'brython_canvas_system'), - ('ai_interface.py', 'brython_ai_interface'), - ('canvas_event_handler.py', 'brython_event_handling'), - ('workspace_manager.py', 'brython_workspace_client'), - ('result_processor.py', 'brython_result_processor'), - ('process_function_calls.py', 'brython_function_execution') + ("canvas.py", "brython_canvas_system"), + ("ai_interface.py", "brython_ai_interface"), + ("canvas_event_handler.py", "brython_event_handling"), + ("workspace_manager.py", "brython_workspace_client"), + ("result_processor.py", "brython_result_processor"), + ("process_function_calls.py", "brython_function_execution"), ] for module_file, diagram_name in important_core_modules: @@ -248,20 +277,24 @@ def generate_core_system_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "core") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', diagram_name, - '--output-directory', str(output_dir), - '--show-associated', '1', - str(module_path) + "pyreverse", + "-o", + fmt, + "-p", + diagram_name, + "--output-directory", + str(output_dir), + "--show-associated", + "1", + str(module_path), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + {diagram_name} diagram generated: {output_dir}/classes_{diagram_name}.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_{diagram_name}.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_{diagram_name}.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating {diagram_name} diagram: {e.stderr}") @@ -277,67 +310,96 @@ def generate_drawable_system_diagrams(self) -> None: # Complete drawable hierarchy diagram drawable_files: List[str] = [ - str(drawables_dir / module) - for module in self.drawable_classes - if (drawables_dir / module).exists() + str(drawables_dir / module) for module in self.drawable_classes if (drawables_dir / module).exists() ] if drawable_files: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "drawables") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_drawable_hierarchy', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1', - '--show-builtin', '0' + "pyreverse", + "-o", + fmt, + "-p", + "brython_drawable_hierarchy", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", + "--show-builtin", + "0", ] + drawable_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + Drawable hierarchy diagram generated: {output_dir}/classes_brython_drawable_hierarchy.{fmt}") + print( + f" + Drawable hierarchy diagram generated: {output_dir}/classes_brython_drawable_hierarchy.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_drawable_hierarchy.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_drawable_hierarchy.{fmt}") except subprocess.CalledProcessError as e: print(f" Error generating drawable hierarchy diagram: {e.stderr}") # Specific drawable type diagrams drawable_categories: List[Tuple[List[str], str]] = [ - (['point.py', 'segment.py', 'vector.py', 'triangle.py', 'rectangle.py', 'circle.py', 'ellipse.py', 'angle.py'], 'brython_geometric_objects'), - (['function.py'], 'brython_function_plotting'), - (['colored_area.py', 'functions_bounded_colored_area.py', 'function_segment_bounded_colored_area.py', 'segments_bounded_colored_area.py'], 'brython_colored_areas'), - (['rotatable_polygon.py', 'drawable.py', 'position.py'], 'brython_base_drawable_system') + ( + [ + "point.py", + "segment.py", + "vector.py", + "triangle.py", + "rectangle.py", + "circle.py", + "ellipse.py", + "angle.py", + ], + "brython_geometric_objects", + ), + (["function.py"], "brython_function_plotting"), + ( + [ + "colored_area.py", + "functions_bounded_colored_area.py", + "function_segment_bounded_colored_area.py", + "segments_bounded_colored_area.py", + ], + "brython_colored_areas", + ), + (["rotatable_polygon.py", "drawable.py", "position.py"], "brython_base_drawable_system"), ] for category_files, diagram_name in drawable_categories: category_paths = [ - str(drawables_dir / module) - for module in category_files - if (drawables_dir / module).exists() + str(drawables_dir / module) for module in category_files if (drawables_dir / module).exists() ] if category_paths: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "drawables") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', diagram_name, - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1' + "pyreverse", + "-o", + fmt, + "-p", + diagram_name, + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", ] + category_paths try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + {diagram_name} diagram generated: {output_dir}/classes_{diagram_name}.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_{diagram_name}.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_{diagram_name}.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating {diagram_name} diagram: {e.stderr}") @@ -353,67 +415,95 @@ def generate_manager_system_diagrams(self) -> None: # Complete manager orchestration diagram manager_files: List[str] = [ - str(managers_dir / module) - for module in self.manager_classes - if (managers_dir / module).exists() + str(managers_dir / module) for module in self.manager_classes if (managers_dir / module).exists() ] if manager_files: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "managers") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_manager_orchestration', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1', - '--show-builtin', '0' + "pyreverse", + "-o", + fmt, + "-p", + "brython_manager_orchestration", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", + "--show-builtin", + "0", ] + manager_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + Manager orchestration diagram generated: {output_dir}/classes_brython_manager_orchestration.{fmt}") + print( + f" + Manager orchestration diagram generated: {output_dir}/classes_brython_manager_orchestration.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_manager_orchestration.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_manager_orchestration.{fmt}") except subprocess.CalledProcessError as e: print(f" Error generating manager orchestration diagram: {e.stderr}") # Specific manager category diagrams manager_categories: List[Tuple[List[str], str]] = [ - (['point_manager.py', 'segment_manager.py', 'vector_manager.py', 'polygon_manager.py', 'circle_manager.py', 'ellipse_manager.py', 'angle_manager.py'], 'brython_shape_managers'), - (['function_manager.py', 'colored_area_manager.py'], 'brython_specialized_managers'), - (['drawable_manager.py', 'drawable_dependency_manager.py', 'drawable_manager_proxy.py', 'drawables_container.py'], 'brython_core_managers'), - (['transformations_manager.py', 'undo_redo_manager.py'], 'brython_system_managers') + ( + [ + "point_manager.py", + "segment_manager.py", + "vector_manager.py", + "polygon_manager.py", + "circle_manager.py", + "ellipse_manager.py", + "angle_manager.py", + ], + "brython_shape_managers", + ), + (["function_manager.py", "colored_area_manager.py"], "brython_specialized_managers"), + ( + [ + "drawable_manager.py", + "drawable_dependency_manager.py", + "drawable_manager_proxy.py", + "drawables_container.py", + ], + "brython_core_managers", + ), + (["transformations_manager.py", "undo_redo_manager.py"], "brython_system_managers"), ] for category_files, diagram_name in manager_categories: category_paths = [ - str(managers_dir / module) - for module in category_files - if (managers_dir / module).exists() + str(managers_dir / module) for module in category_files if (managers_dir / module).exists() ] if category_paths: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "managers") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', diagram_name, - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1' + "pyreverse", + "-o", + fmt, + "-p", + diagram_name, + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", ] + category_paths try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + {diagram_name} diagram generated: {output_dir}/classes_{diagram_name}.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_{diagram_name}.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_{diagram_name}.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating {diagram_name} diagram: {e.stderr}") @@ -424,10 +514,10 @@ def generate_integration_diagrams(self) -> None: # AJAX and AI integration components integration_files: List[str] = [ - str(self.brython_source_dir / 'ai_interface.py'), - str(self.brython_source_dir / 'result_processor.py'), - str(self.brython_source_dir / 'process_function_calls.py'), - str(self.brython_source_dir / 'workspace_manager.py') + str(self.brython_source_dir / "ai_interface.py"), + str(self.brython_source_dir / "result_processor.py"), + str(self.brython_source_dir / "process_function_calls.py"), + str(self.brython_source_dir / "workspace_manager.py"), ] existing_integration_files: List[str] = [ @@ -438,53 +528,66 @@ def generate_integration_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "integration") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_ajax_communication', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1' + "pyreverse", + "-o", + fmt, + "-p", + "brython_ajax_communication", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", ] + existing_integration_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + AJAX communication diagram generated: {output_dir}/classes_brython_ajax_communication.{fmt}") + print( + f" + AJAX communication diagram generated: {output_dir}/classes_brython_ajax_communication.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_ajax_communication.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_ajax_communication.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating AJAX communication diagram: {e.stderr}") # Function execution pipeline execution_files: List[str] = [ - str(self.brython_source_dir / 'process_function_calls.py'), - str(self.brython_source_dir / 'result_processor.py'), - str(self.brython_source_dir / 'expression_evaluator.py'), - str(self.brython_source_dir / 'result_validator.py') + str(self.brython_source_dir / "process_function_calls.py"), + str(self.brython_source_dir / "result_processor.py"), + str(self.brython_source_dir / "expression_evaluator.py"), + str(self.brython_source_dir / "result_validator.py"), ] - existing_execution_files: List[str] = [ - file_path for file_path in execution_files if Path(file_path).exists() - ] + existing_execution_files: List[str] = [file_path for file_path in execution_files if Path(file_path).exists()] if existing_execution_files: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "integration") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_function_execution_pipeline', - '--output-directory', str(output_dir), - '--show-associated', '1' + "pyreverse", + "-o", + fmt, + "-p", + "brython_function_execution_pipeline", + "--output-directory", + str(output_dir), + "--show-associated", + "1", ] + existing_execution_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + Function execution pipeline diagram generated: {output_dir}/classes_brython_function_execution_pipeline.{fmt}") + print( + f" + Function execution pipeline diagram generated: {output_dir}/classes_brython_function_execution_pipeline.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_function_execution_pipeline.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count( + output_dir / f"classes_brython_function_execution_pipeline.{fmt}" + ) except subprocess.CalledProcessError as e: print(f"Error generating function execution pipeline diagram: {e.stderr}") @@ -495,9 +598,9 @@ def generate_utility_system_diagrams(self) -> None: # Expression and validation utilities validation_files = [ - str(self.brython_source_dir / 'expression_evaluator.py'), - str(self.brython_source_dir / 'expression_validator.py'), - str(self.brython_source_dir / 'result_validator.py') + str(self.brython_source_dir / "expression_evaluator.py"), + str(self.brython_source_dir / "expression_validator.py"), + str(self.brython_source_dir / "result_validator.py"), ] existing_validation_files = [f for f in validation_files if Path(f).exists()] @@ -506,27 +609,33 @@ def generate_utility_system_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "utilities") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_expression_system', - '--output-directory', str(output_dir), - '--show-associated', '1' + "pyreverse", + "-o", + fmt, + "-p", + "brython_expression_system", + "--output-directory", + str(output_dir), + "--show-associated", + "1", ] + existing_validation_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + Expression system diagram generated: {output_dir}/classes_brython_expression_system.{fmt}") + print( + f" + Expression system diagram generated: {output_dir}/classes_brython_expression_system.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_expression_system.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_expression_system.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating expression system diagram: {e.stderr}") # Content processing utilities content_files = [ - str(self.brython_source_dir / 'markdown_parser.py'), - str(self.brython_source_dir / 'function_registry.py') + str(self.brython_source_dir / "markdown_parser.py"), + str(self.brython_source_dir / "function_registry.py"), ] existing_content_files = [f for f in content_files if Path(f).exists()] @@ -535,19 +644,25 @@ def generate_utility_system_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "utilities") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_content_processing', - '--output-directory', str(output_dir), - '--show-associated', '1' + "pyreverse", + "-o", + fmt, + "-p", + "brython_content_processing", + "--output-directory", + str(output_dir), + "--show-associated", + "1", ] + existing_content_files try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + Content processing diagram generated: {output_dir}/classes_brython_content_processing.{fmt}") + print( + f" + Content processing diagram generated: {output_dir}/classes_brython_content_processing.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_content_processing.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_content_processing.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating content processing diagram: {e.stderr}") @@ -562,19 +677,22 @@ def generate_utility_system_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "utilities") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_utils', - '--output-directory', str(output_dir), - str(utils_dir) + "pyreverse", + "-o", + fmt, + "-p", + "brython_utils", + "--output-directory", + str(output_dir), + str(utils_dir), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + Utils system diagram generated: {output_dir}/classes_brython_utils.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_utils.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_utils.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating utils system diagram: {e.stderr}") @@ -583,19 +701,22 @@ def generate_utility_system_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "utilities") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_name_generator', - '--output-directory', str(output_dir), - str(name_gen_dir) + "pyreverse", + "-o", + fmt, + "-p", + "brython_name_generator", + "--output-directory", + str(output_dir), + str(name_gen_dir), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + Name generator diagram generated: {output_dir}/classes_brython_name_generator.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_name_generator.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_name_generator.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating name generator diagram: {e.stderr}") @@ -612,19 +733,22 @@ def generate_testing_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "testing") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_test_framework', - '--output-directory', str(output_dir), - str(tests_dir) + "pyreverse", + "-o", + fmt, + "-p", + "brython_test_framework", + "--output-directory", + str(output_dir), + str(tests_dir), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + Test framework diagram generated: {output_dir}/classes_brython_test_framework.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_test_framework.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_test_framework.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating test framework diagram: {e.stderr}") @@ -634,19 +758,22 @@ def generate_testing_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt, "testing") cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_test_runner', - '--output-directory', str(output_dir), - str(test_runner_file) + "pyreverse", + "-o", + fmt, + "-p", + "brython_test_runner", + "--output-directory", + str(output_dir), + str(test_runner_file), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + Test runner diagram generated: {output_dir}/classes_brython_test_runner.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'classes_brython_test_runner.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"classes_brython_test_runner.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating test runner diagram: {e.stderr}") @@ -659,33 +786,39 @@ def generate_package_structure_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt) cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'brython_complete_system', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1', - '-m', 'yes', # Show module names - str(self.brython_source_dir) + "pyreverse", + "-o", + fmt, + "-p", + "brython_complete_system", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", + "-m", + "yes", # Show module names + str(self.brython_source_dir), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) print(f" + Complete system diagram generated: {output_dir}/packages_brython_complete_system.{fmt}") - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'packages_brython_complete_system.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"packages_brython_complete_system.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating complete system diagram: {e.stderr}") # Individual package diagrams packages: List[Tuple[str, str]] = [ - ('drawables', 'brython_drawables_package'), - ('managers', 'brython_managers_package'), - ('utils', 'brython_utils_package'), - ('name_generator', 'brython_name_generator_package'), - ('client_tests', 'brython_tests_package') + ("drawables", "brython_drawables_package"), + ("managers", "brython_managers_package"), + ("utils", "brython_utils_package"), + ("name_generator", "brython_name_generator_package"), + ("client_tests", "brython_tests_package"), ] for package_name, diagram_name in packages: @@ -694,20 +827,26 @@ def generate_package_structure_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_brython_output_dir(fmt) cmd = [ - 'pyreverse', - '-o', fmt, - '-p', diagram_name, - '--output-directory', str(output_dir), - '-m', 'yes', - str(package_dir) + "pyreverse", + "-o", + fmt, + "-p", + diagram_name, + "--output-directory", + str(output_dir), + "-m", + "yes", + str(package_dir), ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) - print(f" + {package_name} package diagram generated: {output_dir}/packages_{diagram_name}.{fmt}") + print( + f" + {package_name} package diagram generated: {output_dir}/packages_{diagram_name}.{fmt}" + ) - if fmt == 'svg': - self._process_svg_font_and_count(output_dir / f'packages_{diagram_name}.{fmt}') + if fmt == "svg": + self._process_svg_font_and_count(output_dir / f"packages_{diagram_name}.{fmt}") except subprocess.CalledProcessError as e: print(f"Error generating {package_name} package diagram: {e.stderr}") @@ -720,24 +859,21 @@ def run(self) -> None: def main() -> None: """Main function with command line argument parsing.""" parser = argparse.ArgumentParser(description="Generate comprehensive Brython client-side diagrams") - parser.add_argument('--png-dir', default='../generated_png', - help='Output directory for PNG diagrams (default: ../generated_png)') - parser.add_argument('--svg-dir', default='../generated_svg', - help='Output directory for SVG diagrams (default: ../generated_svg)') - parser.add_argument('--format', default='png,svg', - help='Output formats (default: png,svg)') + parser.add_argument( + "--png-dir", default="../generated_png", help="Output directory for PNG diagrams (default: ../generated_png)" + ) + parser.add_argument( + "--svg-dir", default="../generated_svg", help="Output directory for SVG diagrams (default: ../generated_svg)" + ) + parser.add_argument("--format", default="png,svg", help="Output formats (default: png,svg)") args = parser.parse_args() # Parse formats - formats: List[str] = [fmt.strip() for fmt in args.format.split(',') if fmt.strip()] + formats: List[str] = [fmt.strip() for fmt in args.format.split(",") if fmt.strip()] # Create and run generator - generator = BrythonDiagramGenerator( - png_dir=args.png_dir, - svg_dir=args.svg_dir, - formats=formats - ) + generator = BrythonDiagramGenerator(png_dir=args.png_dir, svg_dir=args.svg_dir, formats=formats) generator.run() diff --git a/diagrams/scripts/generate_diagrams.py b/diagrams/scripts/generate_diagrams.py index b5c5e997..5f6b7c75 100644 --- a/diagrams/scripts/generate_diagrams.py +++ b/diagrams/scripts/generate_diagrams.py @@ -57,13 +57,11 @@ def __init__( # Setup font configuration for all diagrams setup_font_environment() - - def get_output_dir(self, fmt: str) -> Path: """Get the appropriate output directory for a format.""" - if fmt == 'png': + if fmt == "png": return self.png_dir - elif fmt == 'svg': + elif fmt == "svg": return self.svg_dir else: # For other formats like dot, use svg directory @@ -72,7 +70,7 @@ def get_output_dir(self, fmt: str) -> Path: def get_server_output_dir(self, fmt: str) -> Path: """Get the server-specific output directory for a format.""" base_dir = self.get_output_dir(fmt) - server_dir = base_dir / 'server' + server_dir = base_dir / "server" server_dir.mkdir(parents=True, exist_ok=True) return server_dir @@ -83,16 +81,12 @@ def _update_fonts_and_count(self, svg_file: Path) -> None: def check_dependencies(self) -> None: """Check if required tools are installed.""" - tools: Dict[str, str] = { - 'pyreverse': 'pylint', - 'dot': 'graphviz', - 'python': 'python' - } + tools: Dict[str, str] = {"pyreverse": "pylint", "dot": "graphviz", "python": "python"} missing: List[str] = [] for tool, package in tools.items(): # Use 'where' on Windows, 'which' on Unix-like systems - cmd = 'where' if sys.platform == 'win32' else 'which' + cmd = "where" if sys.platform == "win32" else "which" try: result = subprocess.run([cmd, tool], capture_output=True, text=True) if result.returncode != 0: @@ -113,33 +107,38 @@ def generate_class_diagrams(self) -> None: # List of all Python files with classes class_files: List[str] = [ - 'static/app_manager.py', - 'static/openai_api.py', - 'static/webdriver_manager.py', - 'static/workspace_manager.py', - 'static/ai_model.py', - 'static/log_manager.py', - 'static/tool_call_processor.py' + "static/app_manager.py", + "static/openai_api.py", + "static/webdriver_manager.py", + "static/workspace_manager.py", + "static/ai_model.py", + "static/log_manager.py", + "static/tool_call_processor.py", ] for fmt in self.formats: output_dir = self.get_server_output_dir(fmt) cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'MatHud_AllClasses', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1' + "pyreverse", + "-o", + fmt, + "-p", + "MatHud_AllClasses", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", ] + class_files # Add all class files explicitly try: - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Main class diagram generated: {output_dir}/classes_MatHud_AllClasses.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - self._update_fonts_and_count(output_dir / f'classes_MatHud_AllClasses.{fmt}') + if fmt == "svg": + self._update_fonts_and_count(output_dir / f"classes_MatHud_AllClasses.{fmt}") except subprocess.CalledProcessError as e: print(f" Error: Error generating main class diagram: {e.stderr}") @@ -151,25 +150,31 @@ def generate_package_diagrams(self) -> None: for fmt in self.formats: output_dir = self.get_server_output_dir(fmt) cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'MatHud_packages', - '--output-directory', str(output_dir), - '--show-associated', '1', - '--show-ancestors', '1', - '-m', 'yes', # Show module names - 'static/', # Main application code - 'app.py', # Entry point - 'run_server_tests.py' # Test code + "pyreverse", + "-o", + fmt, + "-p", + "MatHud_packages", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "--show-ancestors", + "1", + "-m", + "yes", # Show module names + "static/", # Main application code + "app.py", # Entry point + "run_server_tests.py", # Test code ] try: - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Package diagram generated: {output_dir}/packages_MatHud_packages.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - self._update_fonts_and_count(output_dir / f'packages_MatHud_packages.{fmt}') + if fmt == "svg": + self._update_fonts_and_count(output_dir / f"packages_MatHud_packages.{fmt}") except subprocess.CalledProcessError as e: print(f" Error: Error generating package diagram: {e.stderr}") @@ -179,34 +184,37 @@ def generate_module_specific_diagrams(self) -> None: print("Generating module-specific diagrams...") important_modules: List[Tuple[str, str]] = [ - ('static/app_manager.py', 'AppManager'), - ('static/openai_api.py', 'OpenAI_API'), - ('static/webdriver_manager.py', 'WebDriver'), - ('static/workspace_manager.py', 'Workspace') + ("static/app_manager.py", "AppManager"), + ("static/openai_api.py", "OpenAI_API"), + ("static/webdriver_manager.py", "WebDriver"), + ("static/workspace_manager.py", "Workspace"), # Note: routes.py is handled separately in generate_flask_routes_diagram() ] for module_path, name in important_modules: # Use absolute path for checking existence - abs_module_path = Path('../..') / module_path + abs_module_path = Path("../..") / module_path if abs_module_path.exists(): for fmt in self.formats: output_dir = self.get_server_output_dir(fmt) cmd = [ - 'pyreverse', - '-o', fmt, - '-p', name, - '--output-directory', str(output_dir), - module_path # Use relative path for pyreverse + "pyreverse", + "-o", + fmt, + "-p", + name, + "--output-directory", + str(output_dir), + module_path, # Use relative path for pyreverse ] try: - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + {name} diagram generated: {output_dir}/classes_{name}.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - self._update_fonts_and_count(output_dir / f'classes_{name}.{fmt}') + if fmt == "svg": + self._update_fonts_and_count(output_dir / f"classes_{name}.{fmt}") except subprocess.CalledProcessError as e: print(f" Error: Error generating {name} diagram: {e.stderr}") @@ -231,7 +239,7 @@ def _generate_custom_routes_diagram(self) -> None: from pathlib import Path # Read routes.py content - routes_file = Path('../../static/routes.py') + routes_file = Path("../../static/routes.py") if not routes_file.exists(): print(" Warning: routes.py not found for custom analysis") return @@ -251,25 +259,25 @@ def _generate_custom_routes_diagram(self) -> None: # Save SVG version if "svg" in self.formats: - svg_output_dir = self.get_server_output_dir('svg') - svg_output_file = svg_output_dir / 'flask_routes_custom.svg' + svg_output_dir = self.get_server_output_dir("svg") + svg_output_file = svg_output_dir / "flask_routes_custom.svg" svg_output_file.write_text(svg_content) print(f" + Custom Flask routes diagram (SVG): {svg_output_file}") # Convert to PNG if needed if "png" in self.formats: - png_output_dir = self.get_server_output_dir('png') - png_output_file = png_output_dir / 'flask_routes_custom.png' + png_output_dir = self.get_server_output_dir("png") + png_output_file = png_output_dir / "flask_routes_custom.png" self._convert_svg_to_png(svg_output_file, png_output_file) # If only PNG format requested, create SVG first then convert elif "png" in self.formats: # Create temporary SVG - temp_svg = Path('temp_flask_routes.svg') + temp_svg = Path("temp_flask_routes.svg") temp_svg.write_text(svg_content) - png_output_dir = self.get_server_output_dir('png') - png_output_file = png_output_dir / 'flask_routes_custom.png' + png_output_dir = self.get_server_output_dir("png") + png_output_file = png_output_dir / "flask_routes_custom.png" self._convert_svg_to_png(temp_svg, png_output_file) # Clean up temporary file @@ -283,6 +291,7 @@ def _convert_svg_to_png(self, svg_file: Path, png_file: Path) -> None: """Convert SVG file to PNG using cairosvg.""" try: import cairosvg + cairosvg.svg2png(url=str(svg_file), write_to=str(png_file)) print(f" + Converted to PNG: {png_file}") except ImportError: @@ -296,9 +305,9 @@ def _convert_svg_to_png_fallback(self, svg_file: Path, png_file: Path) -> None: """Fallback SVG to PNG conversion using dot command.""" try: # Use dot (Graphviz) to convert SVG to PNG - subprocess.run([ - 'dot', '-Tpng', str(svg_file), '-o', str(png_file) - ], capture_output=True, text=True, check=True) + subprocess.run( + ["dot", "-Tpng", str(svg_file), "-o", str(png_file)], capture_output=True, text=True, check=True + ) print(f" + Converted to PNG (via dot): {png_file}") except subprocess.CalledProcessError as e: print(f" Warning: Fallback conversion failed: {e}") @@ -321,31 +330,31 @@ def _create_routes_svg( - + Flask Routes - + ''' # Add title svg += f''' -MatHud Flask API Routes +MatHud Flask API Routes ''' # Add routes y_offset = svg_height - 80 for i, (path, methods, func_name) in enumerate(routes): - methods_clean = methods.replace("'", "").replace('"', '') if methods else "GET" + methods_clean = methods.replace("'", "").replace('"', "") if methods else "GET" # Route box box_y = y_offset - (i * 70) svg += f''' - - - -{methods_clean} + + + +{methods_clean} {path} -{func_name}() +{func_name}() ''' @@ -354,17 +363,17 @@ def _create_routes_svg( svg += f''' - -All Functions: + +All Functions: ''' for i, func in enumerate(functions): - svg += f'{func}()\n' + svg += f'{func}()\n' - svg += '\n' + svg += "\n" # Close SVG - svg += '\n' + svg += "\n" return svg @@ -374,21 +383,25 @@ def _generate_pyreverse_routes_diagram(self) -> None: for fmt in self.formats: output_dir = self.get_server_output_dir(fmt) cmd = [ - 'pyreverse', - '-o', fmt, - '-p', 'FlaskRoutes', - '--output-directory', str(output_dir), - '--show-associated', '1', - 'static/routes.py' + "pyreverse", + "-o", + fmt, + "-p", + "FlaskRoutes", + "--output-directory", + str(output_dir), + "--show-associated", + "1", + "static/routes.py", ] try: - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Pyreverse routes diagram: {output_dir}/classes_FlaskRoutes.{fmt}") # Post-process SVG files to use configured font - if fmt == 'svg': - self._update_fonts_and_count(output_dir / f'classes_FlaskRoutes.{fmt}') + if fmt == "svg": + self._update_fonts_and_count(output_dir / f"classes_FlaskRoutes.{fmt}") except subprocess.CalledProcessError: print(" Warning: Pyreverse routes minimal (functions only, no classes)") @@ -400,22 +413,26 @@ def _generate_function_call_diagram(self) -> None: """Generate function call relationships for routes.py.""" try: # Use pydeps to analyze function calls in routes.py specifically - output_dir = self.get_server_output_dir('svg') if "svg" in self.formats else self.get_server_output_dir('png') + output_dir = ( + self.get_server_output_dir("svg") if "svg" in self.formats else self.get_server_output_dir("png") + ) cmd = [ - 'pydeps', - '--show-deps', - '--max-bacon', '2', # Limited depth for function calls - '--no-show', - '-o', str(output_dir / 'routes_functions.svg'), - 'static/routes.py' + "pydeps", + "--show-deps", + "--max-bacon", + "2", # Limited depth for function calls + "--no-show", + "-o", + str(output_dir / "routes_functions.svg"), + "static/routes.py", ] - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Routes function calls: {output_dir}/routes_functions.svg") # Post-process SVG file to use configured font - self._update_fonts_and_count(output_dir / 'routes_functions.svg') + self._update_fonts_and_count(output_dir / "routes_functions.svg") except subprocess.CalledProcessError: print(" Warning: Routes function call analysis failed") @@ -428,14 +445,14 @@ def generate_function_analysis(self) -> None: # Files that may benefit from function analysis files_to_analyze: List[Tuple[str, str]] = [ - ('static/functions_definitions.py', 'FunctionDefinitions'), - ('app.py', 'AppMain'), - ('run_server_tests.py', 'server_tests') + ("static/functions_definitions.py", "FunctionDefinitions"), + ("app.py", "AppMain"), + ("run_server_tests.py", "server_tests"), ] for file_path, name in files_to_analyze: # Check if file exists - abs_file_path = Path('../..') / file_path + abs_file_path = Path("../..") / file_path if abs_file_path.exists(): self._generate_file_function_analysis(file_path, name) @@ -443,25 +460,30 @@ def _generate_file_function_analysis(self, file_path: str, name: str) -> None: """Generate function analysis for a specific file.""" try: # Use pydeps for enhanced function call visualization - output_dir = self.get_server_output_dir('svg') if "svg" in self.formats else self.get_server_output_dir('png') + output_dir = ( + self.get_server_output_dir("svg") if "svg" in self.formats else self.get_server_output_dir("png") + ) cmd = [ - 'pydeps', - '--show-deps', - '--max-bacon', '2', - '--cluster', - '--rankdir', 'TB', - '--no-show', - '-o', str(output_dir / f'functions_{name.lower()}.svg'), - file_path + "pydeps", + "--show-deps", + "--max-bacon", + "2", + "--cluster", + "--rankdir", + "TB", + "--no-show", + "-o", + str(output_dir / f"functions_{name.lower()}.svg"), + file_path, ] try: - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Function analysis for {name}: {output_dir}/functions_{name.lower()}.svg") # Post-process SVG file to use configured font - self._update_fonts_and_count(output_dir / f'functions_{name.lower()}.svg') + self._update_fonts_and_count(output_dir / f"functions_{name.lower()}.svg") except subprocess.CalledProcessError: print(f" Warning: Function analysis for {name} failed - file may have no dependencies") @@ -477,9 +499,7 @@ def generate_architecture_diagram(self) -> None: # Create architecture diagram generator with same settings arch_generator = ArchitectureDiagramGenerator( - png_dir=str(self.png_dir), - svg_dir=str(self.svg_dir), - formats=self.formats + png_dir=str(self.png_dir), svg_dir=str(self.svg_dir), formats=self.formats ) # Generate all architecture diagrams (no cleaning in integrated mode) @@ -498,45 +518,51 @@ def generate_dependency_graph(self) -> None: print("Generating dependency graph...") # Dependencies are typically SVG, so use SVG directory if available - output_dir = self.get_server_output_dir('svg') if "svg" in self.formats else self.get_server_output_dir('png') + output_dir = self.get_server_output_dir("svg") if "svg" in self.formats else self.get_server_output_dir("png") try: # Generate main project dependency graph cmd = [ - 'pydeps', - '--show-deps', - '--max-bacon', '4', # Increased depth - '--cluster', - '--rankdir', 'TB', - '--no-show', # Prevent automatic opening of the generated file - '--include-missing', # Show external dependencies - '-o', str(output_dir / 'dependencies_main.svg'), - 'app.py' # Start from main entry point + "pydeps", + "--show-deps", + "--max-bacon", + "4", # Increased depth + "--cluster", + "--rankdir", + "TB", + "--no-show", # Prevent automatic opening of the generated file + "--include-missing", # Show external dependencies + "-o", + str(output_dir / "dependencies_main.svg"), + "app.py", # Start from main entry point ] - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Main dependency graph generated: {output_dir}/dependencies_main.svg") # Post-process SVG file to use configured font - self._update_fonts_and_count(output_dir / 'dependencies_main.svg') + self._update_fonts_and_count(output_dir / "dependencies_main.svg") # Generate static module dependencies cmd = [ - 'pydeps', - '--show-deps', - '--max-bacon', '3', - '--cluster', - '--rankdir', 'LR', # Left-to-right for better readability - '--no-show', - '-o', str(output_dir / 'dependencies_static.svg'), - 'static/' + "pydeps", + "--show-deps", + "--max-bacon", + "3", + "--cluster", + "--rankdir", + "LR", # Left-to-right for better readability + "--no-show", + "-o", + str(output_dir / "dependencies_static.svg"), + "static/", ] - subprocess.run(cmd, check=True, capture_output=True, text=True, cwd='../..') + subprocess.run(cmd, check=True, capture_output=True, text=True, cwd="../..") print(f" + Static module dependencies generated: {output_dir}/dependencies_static.svg") # Post-process SVG file to use configured font - self._update_fonts_and_count(output_dir / 'dependencies_static.svg') + self._update_fonts_and_count(output_dir / "dependencies_static.svg") except subprocess.CalledProcessError as e: print(" Error: Error generating dependency graph") @@ -555,8 +581,8 @@ def generate_call_graph(self) -> None: from pycallgraph2.output import GraphvizOutput # Use PNG directory for call graph output - output_dir = self.get_server_output_dir('png') - output_file = output_dir / 'call_graph.png' + output_dir = self.get_server_output_dir("png") + output_file = output_dir / "call_graph.png" print(" Warning: Call graph generation is experimental") print(" This will trace app.py execution and may take time...") @@ -566,7 +592,9 @@ def generate_call_graph(self) -> None: # For now, just show the command to run manually print(" Note: To generate call graph manually:") print(" cd to project root, then run:") - print(" pycallgraph graphviz --output-file=diagrams/generated_png/server/call_graph.png -- python app.py") + print( + " pycallgraph graphviz --output-file=diagrams/generated_png/server/call_graph.png -- python app.py" + ) except ImportError: print(" Error: pycallgraph2 not found") @@ -582,9 +610,7 @@ def generate_brython_diagrams(self) -> None: # Create Brython diagram generator with same settings brython_generator = BrythonDiagramGenerator( - png_dir=str(self.png_dir), - svg_dir=str(self.svg_dir), - formats=self.formats + png_dir=str(self.png_dir), svg_dir=str(self.svg_dir), formats=self.formats ) # Generate all Brython diagrams @@ -632,23 +658,26 @@ def run(self, include_brython: bool = False) -> None: def main() -> None: - parser = argparse.ArgumentParser(description='Generate diagrams for MatHud project') - parser.add_argument('--png-dir', default='../generated_png', - help='Output directory for PNG diagrams (default: ../generated_png)') - parser.add_argument('--svg-dir', default='../generated_svg', - help='Output directory for SVG diagrams (default: ../generated_svg)') - parser.add_argument('--format', default='png,svg', - help='Output formats: png,svg,dot (default: png,svg)') + parser = argparse.ArgumentParser(description="Generate diagrams for MatHud project") + parser.add_argument( + "--png-dir", default="../generated_png", help="Output directory for PNG diagrams (default: ../generated_png)" + ) + parser.add_argument( + "--svg-dir", default="../generated_svg", help="Output directory for SVG diagrams (default: ../generated_svg)" + ) + parser.add_argument("--format", default="png,svg", help="Output formats: png,svg,dot (default: png,svg)") # Create mutually exclusive group for Brython options brython_group = parser.add_mutually_exclusive_group() - brython_group.add_argument('--include-brython', action='store_true', - help='Include comprehensive Brython client-side diagrams') - brython_group.add_argument('--no-brython', action='store_true', - help='Explicitly disable Brython diagrams (overrides default)') + brython_group.add_argument( + "--include-brython", action="store_true", help="Include comprehensive Brython client-side diagrams" + ) + brython_group.add_argument( + "--no-brython", action="store_true", help="Explicitly disable Brython diagrams (overrides default)" + ) args = parser.parse_args() - formats: List[str] = [f.strip() for f in args.format.split(',') if f.strip()] + formats: List[str] = [f.strip() for f in args.format.split(",") if f.strip()] # Determine if Brython should be included include_brython = args.include_brython and not args.no_brython @@ -657,5 +686,5 @@ def main() -> None: generator.run(include_brython=include_brython) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/diagrams/scripts/setup_diagram_tools.py b/diagrams/scripts/setup_diagram_tools.py index cf974e3f..e9d7bc54 100644 --- a/diagrams/scripts/setup_diagram_tools.py +++ b/diagrams/scripts/setup_diagram_tools.py @@ -21,11 +21,11 @@ class DiagramToolsSetup: def __init__(self) -> None: self.system: str = platform.system().lower() self.python_packages: List[str] = [ - 'pylint', - 'graphviz', - 'diagrams', - 'pydeps', - 'pycallgraph2', + "pylint", + "graphviz", + "diagrams", + "pydeps", + "pycallgraph2", ] def run_command(self, cmd: str, description: str) -> bool: @@ -41,24 +41,26 @@ def run_command(self, cmd: str, description: str) -> bool: def install_graphviz_system(self) -> bool: """Install system-level Graphviz based on the operating system.""" - if self.system == 'windows': + if self.system == "windows": print("For Windows, please manually install Graphviz:") print("1. Download from: https://graphviz.org/download/") print("2. Install and add to PATH") print("3. Or use: winget install graphviz") return True - elif self.system == 'darwin': # macOS - return self.run_command('brew install graphviz', 'Installing Graphviz via Homebrew') - elif self.system == 'linux': + elif self.system == "darwin": # macOS + return self.run_command("brew install graphviz", "Installing Graphviz via Homebrew") + elif self.system == "linux": # Try different package managers - if subprocess.run(['which', 'apt'], capture_output=True).returncode == 0: - return self.run_command('sudo apt update && sudo apt install -y graphviz', 'Installing Graphviz via apt') - elif subprocess.run(['which', 'yum'], capture_output=True).returncode == 0: - return self.run_command('sudo yum install -y graphviz', 'Installing Graphviz via yum') - elif subprocess.run(['which', 'dnf'], capture_output=True).returncode == 0: - return self.run_command('sudo dnf install -y graphviz', 'Installing Graphviz via dnf') - elif subprocess.run(['which', 'pacman'], capture_output=True).returncode == 0: - return self.run_command('sudo pacman -S graphviz', 'Installing Graphviz via pacman') + if subprocess.run(["which", "apt"], capture_output=True).returncode == 0: + return self.run_command( + "sudo apt update && sudo apt install -y graphviz", "Installing Graphviz via apt" + ) + elif subprocess.run(["which", "yum"], capture_output=True).returncode == 0: + return self.run_command("sudo yum install -y graphviz", "Installing Graphviz via yum") + elif subprocess.run(["which", "dnf"], capture_output=True).returncode == 0: + return self.run_command("sudo dnf install -y graphviz", "Installing Graphviz via dnf") + elif subprocess.run(["which", "pacman"], capture_output=True).returncode == 0: + return self.run_command("sudo pacman -S graphviz", "Installing Graphviz via pacman") else: print("Please manually install Graphviz for your Linux distribution") return False @@ -69,14 +71,11 @@ def install_python_packages(self) -> None: print("Installing Python packages...") # Upgrade pip first - self.run_command(f'{sys.executable} -m pip install --upgrade pip', 'Upgrading pip') + self.run_command(f"{sys.executable} -m pip install --upgrade pip", "Upgrading pip") # Install packages for package in self.python_packages: - success = self.run_command( - f'{sys.executable} -m pip install {package}', - f'Installing {package}' - ) + success = self.run_command(f"{sys.executable} -m pip install {package}", f"Installing {package}") if not success: print(f"Failed to install {package} - you may need to install it manually") @@ -85,9 +84,9 @@ def verify_installation(self) -> bool: print("\nVerifying installation...") tools_to_check = [ - ('pyreverse', 'pyreverse --help'), - ('dot', 'dot -V'), - ('python', f'{sys.executable} --version'), + ("pyreverse", "pyreverse --help"), + ("dot", "dot -V"), + ("python", f"{sys.executable} --version"), ] success_count = 0 @@ -103,8 +102,7 @@ def verify_installation(self) -> bool: python_success = 0 for package in self.python_packages: try: - subprocess.run([sys.executable, '-c', f'import {package}'], - check=True, capture_output=True) + subprocess.run([sys.executable, "-c", f"import {package}"], check=True, capture_output=True) print(f" Python package {package} is available") python_success += 1 except subprocess.CalledProcessError: @@ -164,5 +162,5 @@ def main() -> None: setup.setup() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/diagrams/scripts/utils.py b/diagrams/scripts/utils.py index 4522c115..c6143598 100644 --- a/diagrams/scripts/utils.py +++ b/diagrams/scripts/utils.py @@ -20,14 +20,14 @@ def setup_graphviz_path() -> None: try: # Check if dot is already available try: - subprocess.run(['dot', '-V'], check=True, capture_output=True) + subprocess.run(["dot", "-V"], check=True, capture_output=True) print(" + Graphviz dot command is already available") return except (subprocess.CalledProcessError, FileNotFoundError): pass # Only setup on Windows - if sys.platform != 'win32': + if sys.platform != "win32": return # Common Graphviz installation paths on Windows @@ -47,14 +47,14 @@ def setup_graphviz_path() -> None: if graphviz_bin: # Add to PATH for this session - current_path = os.environ.get('PATH', '') + current_path = os.environ.get("PATH", "") if graphviz_bin not in current_path: - os.environ['PATH'] = f"{graphviz_bin};{current_path}" + os.environ["PATH"] = f"{graphviz_bin};{current_path}" print(f" + Added Graphviz to PATH: {graphviz_bin}") # Verify it works try: - subprocess.run(['dot', '-V'], check=True, capture_output=True) + subprocess.run(["dot", "-V"], check=True, capture_output=True) print(" + Graphviz dot command is now available") except subprocess.CalledProcessError: print("Warning: Graphviz found but dot command still not working") @@ -71,11 +71,11 @@ def setup_graphviz_path() -> None: def setup_font_environment() -> None: """Setup environment variables to use configured font in Graphviz.""" # Set Graphviz font preferences - os.environ['FONTNAME'] = DIAGRAM_FONT - os.environ['FONTSIZE'] = DIAGRAM_FONT_SIZE_STR + os.environ["FONTNAME"] = DIAGRAM_FONT + os.environ["FONTSIZE"] = DIAGRAM_FONT_SIZE_STR # Some systems use different environment variables - os.environ['GRAPHVIZ_DOT_FONTNAME'] = DIAGRAM_FONT - os.environ['GRAPHVIZ_DOT_FONTSIZE'] = DIAGRAM_FONT_SIZE_STR + os.environ["GRAPHVIZ_DOT_FONTNAME"] = DIAGRAM_FONT + os.environ["GRAPHVIZ_DOT_FONTSIZE"] = DIAGRAM_FONT_SIZE_STR def post_process_svg_fonts(svg_file: Path, diagram_font: str = DIAGRAM_FONT) -> bool: @@ -84,7 +84,7 @@ def post_process_svg_fonts(svg_file: Path, diagram_font: str = DIAGRAM_FONT) -> if not svg_file.exists(): return False - content = svg_file.read_text(encoding='utf-8') + content = svg_file.read_text(encoding="utf-8") # Replace common serif fonts with configured font font_replacements: Tuple[Tuple[str, str], ...] = ( @@ -104,12 +104,12 @@ def post_process_svg_fonts(svg_file: Path, diagram_font: str = DIAGRAM_FONT) -> modified = True # Add default font if no font-family is specified in text elements - if 'font-family' not in content and ' None: file_details: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list) self.metrics: Dict[str, Any] = { - 'files': files, - 'lines': lines, - 'classes': 0, - 'methods': 0, - 'functions': 0, - 'test_functions': 0, - 'ai_functions': 0, - 'drawable_classes': 0, - 'manager_classes': 0, - 'imports': 0, - 'comments': 0, - 'docstrings': 0, - 'docstring_lines': 0, - 'reference_manual_lines': 0, - 'unique_python_imports': set(), - 'python_dependencies': 0, - 'javascript_libraries': 0, - 'test_files': 0, - 'file_details': file_details, + "files": files, + "lines": lines, + "classes": 0, + "methods": 0, + "functions": 0, + "test_functions": 0, + "ai_functions": 0, + "drawable_classes": 0, + "manager_classes": 0, + "imports": 0, + "comments": 0, + "docstrings": 0, + "docstring_lines": 0, + "reference_manual_lines": 0, + "unique_python_imports": set(), + "python_dependencies": 0, + "javascript_libraries": 0, + "test_files": 0, + "file_details": file_details, } # File extensions to analyze self.extensions: Dict[str, str] = { - '.py': 'Python', - '.html': 'HTML', - '.css': 'CSS', - '.txt': 'Text', - '.md': 'Markdown', + ".py": "Python", + ".html": "HTML", + ".css": "CSS", + ".txt": "Text", + ".md": "Markdown", # '.js': 'JavaScript', - '.json': 'JSON' + ".json": "JSON", } # Directories to exclude self.exclude_dirs: Set[str] = { - '__pycache__', '.git', 'venv', '.vscode', '.pytest_cache', - 'logs', 'workspaces', 'canvas_snapshots', 'generated_svg', 'generated_png' + "__pycache__", + ".git", + "venv", + ".vscode", + ".pytest_cache", + "logs", + "workspaces", + "canvas_snapshots", + "generated_svg", + "generated_png", } def analyze_project(self) -> None: @@ -94,13 +102,13 @@ def analyze_file(self, file_path: Path) -> None: return file_type = self.extensions[suffix] - self.metrics['files'][file_type] += 1 + self.metrics["files"][file_type] += 1 # Read file content try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: content = f.read() - lines = content.split('\n') + lines = content.split("\n") except UnicodeDecodeError: # Skip binary files return @@ -108,42 +116,42 @@ def analyze_file(self, file_path: Path) -> None: line_count = len(lines) # Handle Reference Manual separately (contains duplicated docstrings) - is_reference_manual = file_path.name == 'Reference Manual.txt' + is_reference_manual = file_path.name == "Reference Manual.txt" if is_reference_manual: # Track but don't count toward documentation totals - self.metrics['reference_manual_lines'] = line_count + self.metrics["reference_manual_lines"] = line_count else: - self.metrics['lines'][file_type] += line_count + self.metrics["lines"][file_type] += line_count # Store file details relative_path = file_path.relative_to(self.project_root) file_info: Dict[str, Any] = { - 'path': str(relative_path), - 'lines': line_count, - 'size': file_path.stat().st_size, - 'is_reference_manual': is_reference_manual + "path": str(relative_path), + "lines": line_count, + "size": file_path.stat().st_size, + "is_reference_manual": is_reference_manual, } # Analyze Python files in detail - if suffix == '.py': + if suffix == ".py": py_metrics = self.analyze_python_file(content, file_path) file_info.update(py_metrics) # Subtract docstring lines from Python code lines - docstring_lines = py_metrics.get('docstring_lines', 0) - self.metrics['lines'][file_type] -= docstring_lines - self.metrics['docstring_lines'] += docstring_lines - file_info['lines'] -= docstring_lines - file_info['docstring_lines'] = docstring_lines + docstring_lines = py_metrics.get("docstring_lines", 0) + self.metrics["lines"][file_type] -= docstring_lines + self.metrics["docstring_lines"] += docstring_lines + file_info["lines"] -= docstring_lines + file_info["docstring_lines"] = docstring_lines # Count test files from server_tests and client_tests directories relative_path = file_path.relative_to(self.project_root) - if any(part in str(relative_path).lower() for part in ['server_tests', 'client_tests']): - if file_path.suffix.lower() == '.py': # Only count Python test files - self.metrics['test_files'] += 1 - file_info['is_test_file'] = True + if any(part in str(relative_path).lower() for part in ["server_tests", "client_tests"]): + if file_path.suffix.lower() == ".py": # Only count Python test files + self.metrics["test_files"] += 1 + file_info["is_test_file"] = True - file_details = cast(DefaultDict[str, List[Dict[str, Any]]], self.metrics['file_details']) + file_details = cast(DefaultDict[str, List[Dict[str, Any]]], self.metrics["file_details"]) file_details[file_type].append(file_info) except Exception as e: @@ -151,19 +159,19 @@ def analyze_file(self, file_path: Path) -> None: def analyze_python_file(self, content: str, file_path: Path) -> Dict[str, Any]: """Detailed analysis of Python files.""" - lines = content.split('\n') + lines = content.split("\n") file_metrics: Dict[str, Any] = { - 'classes': 0, - 'methods': 0, - 'functions': 0, - 'test_functions': 0, - 'imports': 0, - 'comments': 0, - 'docstrings': 0, - 'docstring_lines': 0, - 'is_drawable': False, - 'is_manager': False, - 'is_test': False + "classes": 0, + "methods": 0, + "functions": 0, + "test_functions": 0, + "imports": 0, + "comments": 0, + "docstrings": 0, + "docstring_lines": 0, + "is_drawable": False, + "is_manager": False, + "is_test": False, } in_multiline_string = False @@ -179,18 +187,18 @@ def analyze_python_file(self, content: str, file_path: Path) -> Dict[str, Any]: # Check if it's a single-line docstring if stripped.endswith(string_delimiter) and len(stripped) > 6: # Single-line docstring - file_metrics['docstrings'] += 1 - self.metrics['docstrings'] += 1 - file_metrics['docstring_lines'] += 1 + file_metrics["docstrings"] += 1 + self.metrics["docstrings"] += 1 + file_metrics["docstring_lines"] += 1 else: # Multi-line docstring starts in_multiline_string = True - file_metrics['docstrings'] += 1 - self.metrics['docstrings'] += 1 - file_metrics['docstring_lines'] += 1 + file_metrics["docstrings"] += 1 + self.metrics["docstrings"] += 1 + file_metrics["docstring_lines"] += 1 else: # Inside multiline docstring - file_metrics['docstring_lines'] += 1 + file_metrics["docstring_lines"] += 1 if string_delimiter and string_delimiter in stripped: in_multiline_string = False continue @@ -199,66 +207,78 @@ def analyze_python_file(self, content: str, file_path: Path) -> Dict[str, Any]: continue # Class definitions - if re.match(r'^class\s+\w+', stripped): - file_metrics['classes'] += 1 - self.metrics['classes'] += 1 + if re.match(r"^class\s+\w+", stripped): + file_metrics["classes"] += 1 + self.metrics["classes"] += 1 # Check for specific class types - if 'drawable' in file_path.name.lower() or any( - keyword in stripped.lower() for keyword in ['drawable', 'point', 'segment', 'circle', 'triangle', 'rectangle', 'ellipse', 'vector', 'angle', 'function'] + if "drawable" in file_path.name.lower() or any( + keyword in stripped.lower() + for keyword in [ + "drawable", + "point", + "segment", + "circle", + "triangle", + "rectangle", + "ellipse", + "vector", + "angle", + "function", + ] ): - file_metrics['is_drawable'] = True - if 'drawable' in stripped.lower(): - self.metrics['drawable_classes'] += 1 + file_metrics["is_drawable"] = True + if "drawable" in stripped.lower(): + self.metrics["drawable_classes"] += 1 - if 'manager' in stripped.lower(): - file_metrics['is_manager'] = True - self.metrics['manager_classes'] += 1 + if "manager" in stripped.lower(): + file_metrics["is_manager"] = True + self.metrics["manager_classes"] += 1 # Method definitions (inside classes) - use original line for indentation - if re.match(r'^\s+def\s+\w+', line): - file_metrics['methods'] += 1 - self.metrics['methods'] += 1 + if re.match(r"^\s+def\s+\w+", line): + file_metrics["methods"] += 1 + self.metrics["methods"] += 1 # Test methods - if 'def test_' in stripped: - file_metrics['test_functions'] += 1 - self.metrics['test_functions'] += 1 + if "def test_" in stripped: + file_metrics["test_functions"] += 1 + self.metrics["test_functions"] += 1 # Function definitions (at module level) - elif re.match(r'^def\s+\w+', stripped): - file_metrics['functions'] += 1 - self.metrics['functions'] += 1 + elif re.match(r"^def\s+\w+", stripped): + file_metrics["functions"] += 1 + self.metrics["functions"] += 1 # Test functions - if 'def test_' in stripped: - file_metrics['test_functions'] += 1 - self.metrics['test_functions'] += 1 + if "def test_" in stripped: + file_metrics["test_functions"] += 1 + self.metrics["test_functions"] += 1 # Import statements - elif stripped.startswith('import ') or stripped.startswith('from '): - file_metrics['imports'] += 1 - self.metrics['imports'] += 1 + elif stripped.startswith("import ") or stripped.startswith("from "): + file_metrics["imports"] += 1 + self.metrics["imports"] += 1 # Track unique imports for dependency analysis import_module = self.extract_import_module(stripped) if import_module: - self.metrics['unique_python_imports'].add(import_module) + self.metrics["unique_python_imports"].add(import_module) # Comments - elif stripped.startswith('#'): - file_metrics['comments'] += 1 - self.metrics['comments'] += 1 + elif stripped.startswith("#"): + file_metrics["comments"] += 1 + self.metrics["comments"] += 1 # Check if it's a test file - if 'test' in file_path.name.lower(): - file_metrics['is_test'] = True + if "test" in file_path.name.lower(): + file_metrics["is_test"] = True # Special case: analyze functions_definitions.py for AI functions - if file_path.name == 'functions_definitions.py': + if file_path.name == "functions_definitions.py": ai_functions = self.count_ai_functions(content) - self.metrics['ai_functions'] = ai_functions - file_metrics['ai_functions'] = ai_functions + self.metrics["ai_functions"] = ai_functions + file_metrics["ai_functions"] = ai_functions return file_metrics @@ -266,15 +286,15 @@ def extract_import_module(self, import_line: str) -> Optional[str]: """Extract the main module name from an import statement.""" try: # Handle 'import module' and 'from module import ...' - if import_line.startswith('import '): - module = import_line[7:].split('.')[0].split(' as ')[0].split(',')[0].strip() - elif import_line.startswith('from '): - module = import_line[5:].split('.')[0].split(' import')[0].strip() + if import_line.startswith("import "): + module = import_line[7:].split(".")[0].split(" as ")[0].split(",")[0].strip() + elif import_line.startswith("from "): + module = import_line[5:].split(".")[0].split(" import")[0].strip() else: return None # Filter out relative imports and local modules - if module and not module.startswith('.') and module.isidentifier(): + if module and not module.startswith(".") and module.isidentifier(): return module return None except Exception: @@ -287,59 +307,66 @@ def analyze_dependencies(self) -> None: # Main requirements.txt requirements_files = [ - self.project_root / 'requirements.txt', - self.project_root / 'diagrams' / 'diagram_requirements.txt' + self.project_root / "requirements.txt", + self.project_root / "diagrams" / "diagram_requirements.txt", ] for requirements_file in requirements_files: if requirements_file.exists(): try: - with open(requirements_file, 'r', encoding='utf-8') as f: + with open(requirements_file, "r", encoding="utf-8") as f: content = f.read() - lines = [line.strip() for line in content.split('\n') if line.strip() and not line.strip().startswith('#')] + lines = [ + line.strip() + for line in content.split("\n") + if line.strip() and not line.strip().startswith("#") + ] # Count dependencies (remove version specs) for line in lines: - dep = line.split('==')[0].split('>=')[0].split('<=')[0].split('~=')[0].split('!=')[0].strip() + dep = ( + line.split("==")[0].split(">=")[0].split("<=")[0].split("~=")[0].split("!=")[0].strip() + ) if dep: deps.add(dep) except Exception: pass - self.metrics['python_dependencies'] = len(deps) + self.metrics["python_dependencies"] = len(deps) # Analyze index.html for JavaScript libraries - index_file = self.project_root / 'templates' / 'index.html' + index_file = self.project_root / "templates" / "index.html" if index_file.exists(): try: - with open(index_file, 'r', encoding='utf-8') as f: + with open(index_file, "r", encoding="utf-8") as f: content = f.read() # Count script tags and CDN libraries js_libs = set() # Look for script src tags import re + script_pattern = r']+src=["\']([^"\']+)["\']' matches = re.findall(script_pattern, content, re.IGNORECASE) for src in matches: - if 'http' in src or 'cdn' in src: + if "http" in src or "cdn" in src: # External library - lib_name = src.split('/')[-1].split('.')[0] + lib_name = src.split("/")[-1].split(".")[0] js_libs.add(lib_name) - elif '.js' in src: + elif ".js" in src: # Local library - lib_name = src.split('/')[-1].split('.')[0] + lib_name = src.split("/")[-1].split(".")[0] js_libs.add(lib_name) # Also check for specific known libraries mentioned in text - if 'brython' in content.lower(): - js_libs.add('brython') - if 'mathjax' in content.lower(): - js_libs.add('mathjax') - if 'nerdamer' in content.lower(): - js_libs.add('nerdamer') - - self.metrics['javascript_libraries'] = len(js_libs) + if "brython" in content.lower(): + js_libs.add("brython") + if "mathjax" in content.lower(): + js_libs.add("mathjax") + if "nerdamer" in content.lower(): + js_libs.add("nerdamer") + + self.metrics["javascript_libraries"] = len(js_libs) except Exception: pass @@ -358,19 +385,19 @@ def generate_reports(self) -> None: def print_summary(self) -> None: """Print summary metrics to console.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("MATHUD PROJECT METRICS SUMMARY") - print("="*60) + print("=" * 60) print("\nFILE STATISTICS:") - total_files = sum(self.metrics['files'].values()) - total_lines = sum(self.metrics['lines'].values()) + total_files = sum(self.metrics["files"].values()) + total_lines = sum(self.metrics["lines"].values()) print(f" Total Files: {total_files:,}") print(f" Total Lines: {total_lines:,}") print("\nFILES BY TYPE:") - for file_type, count in sorted(self.metrics['files'].items()): - lines = self.metrics['lines'][file_type] + for file_type, count in sorted(self.metrics["files"].items()): + lines = self.metrics["lines"][file_type] print(f" {file_type:>10}: {count:>3} files, {lines:>6,} lines") print("\nCODE STRUCTURE:") @@ -402,9 +429,9 @@ def print_summary(self) -> None: def save_overview_table(self) -> None: """Save a formatted overview table to file.""" - output_file = self.project_root / 'Documentation' / 'Project Overview Table.txt' + output_file = self.project_root / "Documentation" / "Project Overview Table.txt" - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: f.write("MatHud Project Overview Table\n") f.write("=" * 50 + "\n") f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") @@ -415,16 +442,22 @@ def save_overview_table(self) -> None: f.write(f"{'Metric':<25} {'Count':<10} {'Details':<15}\n") f.write("-" * 50 + "\n") - total_files = sum(self.metrics['files'].values()) - total_lines = sum(self.metrics['lines'].values()) + total_files = sum(self.metrics["files"].values()) + total_lines = sum(self.metrics["lines"].values()) f.write(f"{'Total Files':<25} {total_files:<10,} {'All types':<15}\n") f.write(f"{'Total Lines of Code':<25} {total_lines:<10,} {'All files':<15}\n") - f.write(f"{'Python Files':<25} {self.metrics['files']['Python']:<10} {self.metrics['lines']['Python']:,} lines\n") + f.write( + f"{'Python Files':<25} {self.metrics['files']['Python']:<10} {self.metrics['lines']['Python']:,} lines\n" + ) f.write(f"{'HTML Files':<25} {self.metrics['files']['HTML']:<10} {self.metrics['lines']['HTML']:,} lines\n") f.write(f"{'CSS Files':<25} {self.metrics['files']['CSS']:<10} {self.metrics['lines']['CSS']:,} lines\n") - f.write(f"{'JavaScript Files':<25} {self.metrics['files']['JavaScript']:<10} {self.metrics['lines']['JavaScript']:,} lines\n") - f.write(f"{'Documentation Files':<25} {self.metrics['files']['Text'] + self.metrics['files']['Markdown']:<10} {self.metrics['lines']['Text'] + self.metrics['lines']['Markdown']:,} lines\n") + f.write( + f"{'JavaScript Files':<25} {self.metrics['files']['JavaScript']:<10} {self.metrics['lines']['JavaScript']:,} lines\n" + ) + f.write( + f"{'Documentation Files':<25} {self.metrics['files']['Text'] + self.metrics['files']['Markdown']:<10} {self.metrics['lines']['Text'] + self.metrics['lines']['Markdown']:,} lines\n" + ) f.write("\n" + "-" * 50 + "\n") f.write("CODE ARCHITECTURE\n") @@ -438,10 +471,14 @@ def save_overview_table(self) -> None: f.write("\n" + "-" * 50 + "\n") f.write("DEPENDENCIES\n") f.write("-" * 50 + "\n") - f.write(f"{'Python Dependencies':<25} {self.metrics['python_dependencies']:<10} {'External packages':<15}\n") + f.write( + f"{'Python Dependencies':<25} {self.metrics['python_dependencies']:<10} {'External packages':<15}\n" + ) f.write(f"{'JavaScript Libraries':<25} {self.metrics['javascript_libraries']:<10} {'Frontend libs':<15}\n") f.write(f"{'Import Statements':<25} {self.metrics['imports']:<10} {'All imports':<15}\n") - f.write(f"{'Unique Python Imports':<25} {len(self.metrics['unique_python_imports']):<10} {'Distinct modules':<15}\n") + f.write( + f"{'Unique Python Imports':<25} {len(self.metrics['unique_python_imports']):<10} {'Distinct modules':<15}\n" + ) f.write("\n" + "-" * 50 + "\n") f.write("TESTING\n") @@ -473,22 +510,24 @@ def save_overview_table(self) -> None: f.write(f"• {self.metrics['test_functions']} test functions for quality assurance\n") f.write("• Flask backend with Brython frontend architecture\n") f.write("• Interactive SVG canvas with real-time mathematical visualization\n\n") - f.write(f"NOTE: Reference Manual.txt ({self.metrics['reference_manual_lines']:,} lines) contains duplicated docstring\n") + f.write( + f"NOTE: Reference Manual.txt ({self.metrics['reference_manual_lines']:,} lines) contains duplicated docstring\n" + ) f.write("content and is excluded from documentation totals to avoid double-counting.\n") print(f"\nOverview table saved to: {output_file}") def save_detailed_report(self) -> None: """Save detailed metrics report to file.""" - output_file = self.project_root / 'Documentation' / 'metrics' / 'detailed_project_metrics.txt' + output_file = self.project_root / "Documentation" / "metrics" / "detailed_project_metrics.txt" - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: f.write("MatHud Project - Detailed Metrics Report\n") f.write("=" * 60 + "\n") f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") # Detailed file breakdown by type - file_details = cast(DefaultDict[str, List[Dict[str, Any]]], self.metrics['file_details']) + file_details = cast(DefaultDict[str, List[Dict[str, Any]]], self.metrics["file_details"]) for file_type in sorted(file_details.keys()): files = file_details[file_type] if not files: @@ -498,35 +537,36 @@ def save_detailed_report(self) -> None: f.write("-" * 40 + "\n") # Sort files by line count (descending) - files.sort(key=lambda x: int(x['lines']), reverse=True) + files.sort(key=lambda x: int(x["lines"]), reverse=True) for file_info in files: f.write(f"{file_info['path']:<40} {file_info['lines']:>6} lines") - if file_type == 'Python': - if file_info.get('classes', 0) > 0: + if file_type == "Python": + if file_info.get("classes", 0) > 0: f.write(f" | {file_info['classes']} classes") - if file_info.get('methods', 0) > 0: + if file_info.get("methods", 0) > 0: f.write(f" | {file_info['methods']} methods") - if file_info.get('functions', 0) > 0: + if file_info.get("functions", 0) > 0: f.write(f" | {file_info['functions']} functions") - if file_info.get('test_functions', 0) > 0: + if file_info.get("test_functions", 0) > 0: f.write(f" | {file_info['test_functions']} tests") - if file_info.get('ai_functions', 0) > 0: + if file_info.get("ai_functions", 0) > 0: f.write(f" | {file_info['ai_functions']} AI functions") # Mark Reference Manual as duplicated content - if file_info.get('is_reference_manual', False): + if file_info.get("is_reference_manual", False): f.write(f" | {file_info['size']:,} bytes | EXCLUDED (duplicated docstrings)\n") else: f.write(f" | {file_info['size']:,} bytes\n") - total_lines = sum(f['lines'] for f in files) - total_size = sum(f['size'] for f in files) + total_lines = sum(f["lines"] for f in files) + total_size = sum(f["size"] for f in files) f.write(f"\nSubtotal: {total_lines:,} lines, {total_size:,} bytes\n") print(f"Detailed report saved to: {output_file}") + def main() -> None: """Main function to run the project analysis.""" # Get the project root (parent directory of Documentation) @@ -544,5 +584,6 @@ def main() -> None: print(" • Documentation/Project Overview Table.txt (summary)") print(" • Documentation/metrics/detailed_project_metrics.txt (detailed breakdown)") + if __name__ == "__main__": main() diff --git a/generate_diagrams_launcher.py b/generate_diagrams_launcher.py index 13db004d..f1f8e5eb 100644 --- a/generate_diagrams_launcher.py +++ b/generate_diagrams_launcher.py @@ -58,5 +58,5 @@ def main() -> None: os.chdir(original_cwd) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mypy.ini b/mypy.ini index 422eca17..993077ec 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,7 @@ ignore_missing_imports = True implicit_reexport = False incremental = True explicit_package_bases = True -files = app.py, static/app_manager.py, static/workspace_manager.py, static/log_manager.py, static/openai_api_base.py, static/openai_completions_api.py, static/openai_responses_api.py, static/routes.py, static/tool_call_processor.py, static/ai_model.py, static/webdriver_manager.py, static/functions_definitions.py, run_server_tests.py, server_tests/test_mocks.py, server_tests/test_routes.py, server_tests/test_workspace_management.py, static/client/constants.py, static/client/expression_validator.py, static/client/markdown_parser.py, static/client/main.py, static/client/ai_interface.py, static/client/canvas.py, static/client/canvas_event_handler.py, static/client/cartesian_system_2axis.py, static/client/coordinate_mapper.py, static/client/expression_evaluator.py, static/client/function_registry.py, static/client/process_function_calls.py, static/client/result_processor.py, static/client/result_validator.py, static/client/workspace_manager.py, static/client/utils/math_utils.py, static/client/utils/computation_utils.py, static/client/utils/geometry_utils.py, static/client/utils/style_utils.py, static/client/utils/linear_algebra_utils.py, static/client/name_generator/base.py, static/client/name_generator/drawable.py, static/client/name_generator/function.py, static/client/name_generator/point.py, static/client/managers/undo_redo_manager.py, static/client/managers/transformations_manager.py, static/client/managers/drawable_manager.py, static/client/managers/drawables_container.py, static/client/managers/drawable_manager_proxy.py, static/client/managers/drawable_dependency_manager.py, static/client/managers/point_manager.py, static/client/managers/segment_manager.py, static/client/managers/vector_manager.py, static/client/managers/circle_manager.py, static/client/managers/ellipse_manager.py, static/client/managers/function_manager.py, static/client/managers/colored_area_manager.py, static/client/managers/angle_manager.py, static/client/drawables/position.py, static/client/drawables/drawable.py, static/client/drawables/point.py, static/client/drawables/segment.py, static/client/drawables/vector.py, static/client/drawables/triangle.py, static/client/drawables/rectangle.py, static/client/drawables/circle.py, static/client/drawables/ellipse.py, static/client/drawables/function.py, static/client/drawables/angle.py, static/client/drawables/colored_area.py, static/client/drawables/functions_bounded_colored_area.py, static/client/drawables/segments_bounded_colored_area.py, static/client/drawables/function_segment_bounded_colored_area.py, static/client/test_runner.py, static/client/rendering/interfaces.py, static/client/rendering/primitives.py, static/client/rendering/renderables/function_renderable.py, static/client/rendering/renderables/functions_area_renderable.py, static/client/rendering/renderables/segments_area_renderable.py, static/client/rendering/renderables/function_segment_area_renderable.py, static/client/rendering/svg_renderer.py, static/client/client_tests/test_angle.py, static/client/client_tests/test_angle_manager.py, static/client/client_tests/test_canvas.py, static/client/client_tests/test_cartesian.py, static/client/client_tests/test_circle.py, static/client/client_tests/test_custom_drawable_names.py, static/client/client_tests/test_drawable_dependency_manager.py, static/client/client_tests/test_drawable_name_generator.py, static/client/client_tests/test_drawables_container.py, static/client/client_tests/test_ellipse.py, static/client/client_tests/test_event_handler.py, static/client/client_tests/test_expression_validator.py, static/client/client_tests/test_function.py, static/client/client_tests/test_function_bounded_colored_area_integration.py, static/client/client_tests/test_function_calling.py, static/client/client_tests/test_linear_algebra_utils.py, static/client/client_tests/test_function_segment_bounded_colored_area.py, static/client/client_tests/test_functions_bounded_colored_area.py, static/client/client_tests/test_math_functions.py, static/client/client_tests/test_point.py, static/client/client_tests/test_rectangle.py, static/client/client_tests/test_segment.py, static/client/client_tests/test_segments_bounded_colored_area.py, static/client/client_tests/test_throttle.py, static/client/client_tests/test_triangle.py, static/client/client_tests/test_vector.py, static/client/client_tests/test_window_mocks.py, static/client/client_tests/ai_result_formatter.py, static/client/client_tests/brython_io.py, static/client/client_tests/simple_mock.py, static/client/client_tests/tests.py, generate_diagrams_launcher.py, scripts/linear_algebra_expected_values.py, diagrams/scripts/utils.py, diagrams/scripts/generate_diagrams.py, diagrams/scripts/generate_arch.py, diagrams/scripts/generate_brython_diagrams.py, diagrams/scripts/setup_diagram_tools.py, static/client/client_tests/__init__.py, static/client/drawables/__init__.py, static/client/managers/__init__.py, static/client/name_generator/__init__.py, static/client/rendering/__init__.py, static/client/utils/__init__.py, server_tests/__init__.py, server_tests/python_path_setup.py, documentation/metrics/project_metrics_analyzer.py +mypy_path = static/client/typing +files = app.py, static/app_manager.py, static/workspace_manager.py, static/log_manager.py, static/openai_api_base.py, static/openai_completions_api.py, static/openai_responses_api.py, static/routes.py, static/tool_call_processor.py, static/ai_model.py, static/webdriver_manager.py, static/functions_definitions.py, run_server_tests.py, server_tests/test_mocks.py, server_tests/test_routes.py, server_tests/test_workspace_management.py, static/client/constants.py, static/client/expression_validator.py, static/client/markdown_parser.py, static/client/main.py, static/client/ai_interface.py, static/client/canvas.py, static/client/canvas_event_handler.py, static/client/cartesian_system_2axis.py, static/client/coordinate_mapper.py, static/client/expression_evaluator.py, static/client/function_registry.py, static/client/process_function_calls.py, static/client/result_processor.py, static/client/result_validator.py, static/client/workspace_manager.py, static/client/utils/math_utils.py, static/client/utils/computation_utils.py, static/client/utils/geometry_utils.py, static/client/utils/style_utils.py, static/client/utils/linear_algebra_utils.py, static/client/name_generator/base.py, static/client/name_generator/drawable.py, static/client/name_generator/function.py, static/client/name_generator/point.py, static/client/managers/undo_redo_manager.py, static/client/managers/transformations_manager.py, static/client/managers/drawable_manager.py, static/client/managers/drawables_container.py, static/client/managers/drawable_manager_proxy.py, static/client/managers/drawable_dependency_manager.py, static/client/managers/point_manager.py, static/client/managers/segment_manager.py, static/client/managers/vector_manager.py, static/client/managers/circle_manager.py, static/client/managers/ellipse_manager.py, static/client/managers/function_manager.py, static/client/managers/colored_area_manager.py, static/client/managers/angle_manager.py, static/client/drawables/position.py, static/client/drawables/drawable.py, static/client/drawables/point.py, static/client/drawables/segment.py, static/client/drawables/vector.py, static/client/drawables/triangle.py, static/client/drawables/rectangle.py, static/client/drawables/circle.py, static/client/drawables/ellipse.py, static/client/drawables/function.py, static/client/drawables/angle.py, static/client/drawables/colored_area.py, static/client/drawables/functions_bounded_colored_area.py, static/client/drawables/segments_bounded_colored_area.py, static/client/drawables/function_segment_bounded_colored_area.py, static/client/test_runner.py, static/client/rendering/interfaces.py, static/client/rendering/primitives.py, static/client/rendering/renderables/function_renderable.py, static/client/rendering/renderables/functions_area_renderable.py, static/client/rendering/renderables/segments_area_renderable.py, static/client/rendering/renderables/function_segment_area_renderable.py, static/client/rendering/svg_renderer.py, static/client/client_tests/test_angle.py, static/client/client_tests/test_angle_manager.py, static/client/client_tests/test_canvas.py, static/client/client_tests/test_cartesian.py, static/client/client_tests/test_circle.py, static/client/client_tests/test_custom_drawable_names.py, static/client/client_tests/test_drawable_dependency_manager.py, static/client/client_tests/test_drawable_name_generator.py, static/client/client_tests/test_drawables_container.py, static/client/client_tests/test_ellipse.py, static/client/client_tests/test_event_handler.py, static/client/client_tests/test_expression_validator.py, static/client/client_tests/test_function.py, static/client/client_tests/test_function_bounded_colored_area_integration.py, static/client/client_tests/test_function_calling.py, static/client/client_tests/test_linear_algebra_utils.py, static/client/client_tests/test_function_segment_bounded_colored_area.py, static/client/client_tests/test_functions_bounded_colored_area.py, static/client/client_tests/test_math_functions.py, static/client/client_tests/test_point.py, static/client/client_tests/test_rectangle.py, static/client/client_tests/test_segment.py, static/client/client_tests/test_segments_bounded_colored_area.py, static/client/client_tests/test_throttle.py, static/client/client_tests/test_triangle.py, static/client/client_tests/test_vector.py, static/client/client_tests/test_window_mocks.py, static/client/client_tests/ai_result_formatter.py, static/client/client_tests/brython_io.py, static/client/client_tests/simple_mock.py, static/client/client_tests/tests.py, generate_diagrams_launcher.py, scripts/linear_algebra_expected_values.py, diagrams/scripts/utils.py, diagrams/scripts/generate_diagrams.py, diagrams/scripts/generate_arch.py, diagrams/scripts/generate_brython_diagrams.py, diagrams/scripts/setup_diagram_tools.py, static/client/client_tests/__init__.py, static/client/drawables/__init__.py, static/client/managers/__init__.py, static/client/name_generator/__init__.py, static/client/rendering/__init__.py, static/client/utils/__init__.py, server_tests/__init__.py, server_tests/python_path_setup.py, documentation/metrics/project_metrics_analyzer.py, server_tests/test_browser_typing_stubs.py follow_imports = skip diff --git a/run_server_tests.py b/run_server_tests.py index ab0067da..b12cb4be 100644 --- a/run_server_tests.py +++ b/run_server_tests.py @@ -61,7 +61,7 @@ def run_tests() -> int: # Handle -k for keyword filtering elif arg == "-k" and i + 1 < len(sys.argv): - extra_args.extend(["-k", sys.argv[i+1]]) + extra_args.extend(["-k", sys.argv[i + 1]]) i += 2 continue @@ -87,7 +87,7 @@ def run_tests() -> int: # Set test environment - disable authentication by default for testing if not with_auth: - os.environ['REQUIRE_AUTH'] = 'false' + os.environ["REQUIRE_AUTH"] = "false" print("Test mode: authentication disabled for testing") else: print("Test mode: authentication enabled (--with-auth)") @@ -123,6 +123,7 @@ def run_tests() -> int: # Return the exit code return result.returncode + def show_help() -> None: """Show help information about this script and pytest options. @@ -142,5 +143,6 @@ def show_help() -> None: print(" -q, --quiet Decrease verbosity") print("\nFor more options: python run_server_tests.py -- --help") + if __name__ == "__main__": sys.exit(run_tests()) diff --git a/scripts/canvas_prompt_telemetry_report.py b/scripts/canvas_prompt_telemetry_report.py index dcb5bfd8..2b87c193 100644 --- a/scripts/canvas_prompt_telemetry_report.py +++ b/scripts/canvas_prompt_telemetry_report.py @@ -48,7 +48,7 @@ def _read_rows(log_file: Path, modes: Optional[set[str]]) -> List[Dict[str, Any] marker_idx = line.find(LOG_MARKER) if marker_idx < 0: continue - payload_text = line[marker_idx + len(LOG_MARKER):].strip() + payload_text = line[marker_idx + len(LOG_MARKER) :].strip() if not payload_text: continue try: diff --git a/scripts/linear_algebra_expected_values.py b/scripts/linear_algebra_expected_values.py index 69c6d96f..9d100bfb 100644 --- a/scripts/linear_algebra_expected_values.py +++ b/scripts/linear_algebra_expected_values.py @@ -15,19 +15,25 @@ def main() -> None: np.set_printoptions(linewidth=200, suppress=True) - a_matrix = np.array([ - [42, -17, 63, -5], - [-28, 91, -74, 60], - [39, -56, 81, -13], - [22, -48, 9, 100], - ], dtype=float) - - b_matrix = np.array([ - [15, -88, 71, 20], - [-93, 7, -44, 55], - [61, -36, 29, 90], - [-14, 66, -53, 77], - ], dtype=float) + a_matrix = np.array( + [ + [42, -17, 63, -5], + [-28, 91, -74, 60], + [39, -56, 81, -13], + [22, -48, 9, 100], + ], + dtype=float, + ) + + b_matrix = np.array( + [ + [15, -88, 71, 20], + [-93, 7, -44, 55], + [61, -36, 29, 90], + [-14, 66, -53, 77], + ], + dtype=float, + ) a_plus_b = a_matrix + b_matrix a_inverse = np.linalg.inv(a_matrix) @@ -62,4 +68,3 @@ def _format_float(value: Any) -> str: if __name__ == "__main__": main() - diff --git a/server_tests/__init__.py b/server_tests/__init__.py index 88b1b33b..2fcb166d 100644 --- a/server_tests/__init__.py +++ b/server_tests/__init__.py @@ -11,9 +11,6 @@ sys.path.append(PROJECT_ROOT) # Add the site-packages directory to Python path -SITE_PACKAGES_PATH: str = os.path.join( - PROJECT_ROOT, - 'static', 'client' -) +SITE_PACKAGES_PATH: str = os.path.join(PROJECT_ROOT, "static", "client") if SITE_PACKAGES_PATH not in sys.path: sys.path.append(SITE_PACKAGES_PATH) diff --git a/server_tests/client_renderer/__init__.py b/server_tests/client_renderer/__init__.py index 16339910..e7384f5e 100644 --- a/server_tests/client_renderer/__init__.py +++ b/server_tests/client_renderer/__init__.py @@ -5,6 +5,7 @@ import sys from types import ModuleType, SimpleNamespace + def _install_browser_stub() -> None: if "browser" in sys.modules: return @@ -80,8 +81,8 @@ def __init__(self) -> None: self.Math = SimpleNamespace() self.math = SimpleNamespace( format=lambda value: value, - sqrt=lambda value: value ** 0.5, - pow=lambda base, exp: base ** exp, + sqrt=lambda value: value**0.5, + pow=lambda base, exp: base**exp, det=lambda _matrix: 0.0, evaluate=lambda _expr, _vars=None: 0.0, ) @@ -96,6 +97,7 @@ def __init__(self) -> None: sys.modules["browser"] = browser + _install_browser_stub() __all__: list[str] = [] diff --git a/server_tests/client_renderer/renderer_fixtures.py b/server_tests/client_renderer/renderer_fixtures.py index b5e07f94..0774f6ed 100644 --- a/server_tests/client_renderer/renderer_fixtures.py +++ b/server_tests/client_renderer/renderer_fixtures.py @@ -89,7 +89,9 @@ class PrimitiveRecorder: def __init__(self) -> None: self.calls: List[Tuple[str, Tuple[Any, ...], Dict[str, Any]]] = [] - def fill_circle(self, center: Tuple[float, float], radius: float, fill: Any, stroke: Any = None, *, screen_space: bool = False) -> None: + def fill_circle( + self, center: Tuple[float, float], radius: float, fill: Any, stroke: Any = None, *, screen_space: bool = False + ) -> None: self.calls.append(("fill_circle", (center, radius, fill, stroke, screen_space), {})) def draw_text( @@ -119,7 +121,13 @@ class Offset: class CoordinateMapperStub: - def __init__(self, *, scale_factor: float = 1.0, origin: Tuple[float, float] = (0.0, 0.0), offset: Tuple[float, float] = (0.0, 0.0)) -> None: + def __init__( + self, + *, + scale_factor: float = 1.0, + origin: Tuple[float, float] = (0.0, 0.0), + offset: Tuple[float, float] = (0.0, 0.0), + ) -> None: self.scale_factor = scale_factor self.origin = SimpleNamespace(x=origin[0], y=origin[1]) self.offset = SimpleNamespace(x=offset[0], y=offset[1]) diff --git a/server_tests/client_renderer/test_canvas2d_primitive_adapter.py b/server_tests/client_renderer/test_canvas2d_primitive_adapter.py index 6539b7c3..87724ab6 100644 --- a/server_tests/client_renderer/test_canvas2d_primitive_adapter.py +++ b/server_tests/client_renderer/test_canvas2d_primitive_adapter.py @@ -114,7 +114,9 @@ def lineTo(self, x: float, y: float) -> None: def arc(self, x: float, y: float, radius: float, start: float, end: float, ccw: bool = False) -> None: self.operations.append(("arc", x, y, radius, start, end, ccw)) - def ellipse(self, x: float, y: float, rx: float, ry: float, rotation: float, start: float, end: float, ccw: bool = False) -> None: + def ellipse( + self, x: float, y: float, rx: float, ry: float, rotation: float, start: float, end: float, ccw: bool = False + ) -> None: self.operations.append(("ellipse", x, y, rx, ry, rotation, start, end, ccw)) def closePath(self) -> None: @@ -153,6 +155,7 @@ class TestCanvas2DPrimitiveAdapter(unittest.TestCase): def setUp(self) -> None: self.canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + self.adapter = Canvas2DPrimitiveAdapter(self.canvas_el) def test_stroke_line_draws_line(self) -> None: @@ -268,6 +271,7 @@ class TestCanvas2DPrimitiveAdapterStateManagement(unittest.TestCase): def test_stroke_state_cached_between_calls(self) -> None: canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + adapter = Canvas2DPrimitiveAdapter(canvas_el) stroke = StrokeStyle(color="#FF0000", width=2.0) @@ -287,6 +291,7 @@ def test_stroke_state_cached_between_calls(self) -> None: def test_different_strokes_change_state(self) -> None: canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + adapter = Canvas2DPrimitiveAdapter(canvas_el) stroke1 = StrokeStyle(color="#FF0000", width=1.0) @@ -302,6 +307,7 @@ def test_different_strokes_change_state(self) -> None: def test_fill_with_no_opacity_uses_default_alpha(self) -> None: canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + adapter = Canvas2DPrimitiveAdapter(canvas_el) fill = FillStyle(color="#0000FF", opacity=None) @@ -317,6 +323,7 @@ class TestCanvas2DPrimitiveAdapterEdgeCases(unittest.TestCase): def test_empty_polyline_does_not_crash(self) -> None: canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + adapter = Canvas2DPrimitiveAdapter(canvas_el) stroke = StrokeStyle(color="#000000", width=1.0) @@ -329,6 +336,7 @@ def test_empty_polyline_does_not_crash(self) -> None: def test_zero_radius_circle_handles_gracefully(self) -> None: canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + adapter = Canvas2DPrimitiveAdapter(canvas_el) fill = FillStyle(color="#FF0000") @@ -341,6 +349,7 @@ def test_zero_radius_circle_handles_gracefully(self) -> None: def test_negative_radius_circle_handles_gracefully(self) -> None: canvas_el = MockCanvasElement() from rendering.canvas2d_primitive_adapter import Canvas2DPrimitiveAdapter + adapter = Canvas2DPrimitiveAdapter(canvas_el) fill = FillStyle(color="#FF0000") @@ -356,4 +365,3 @@ def test_negative_radius_circle_handles_gracefully(self) -> None: "TestCanvas2DPrimitiveAdapterStateManagement", "TestCanvas2DPrimitiveAdapterEdgeCases", ] - diff --git a/server_tests/client_renderer/test_polar_renderer_plan.py b/server_tests/client_renderer/test_polar_renderer_plan.py index 5be2e9d1..ea6baf53 100644 --- a/server_tests/client_renderer/test_polar_renderer_plan.py +++ b/server_tests/client_renderer/test_polar_renderer_plan.py @@ -23,6 +23,7 @@ class Position: """Simple Position class for testing.""" + def __init__(self, x: float = 0, y: float = 0): self.x = x self.y = y @@ -259,28 +260,31 @@ class TestPolarRendererMethodAvailability(unittest.TestCase): def test_svg_renderer_has_render_polar_import(self) -> None: """Test that svg_renderer imports build_plan_for_polar.""" from rendering import svg_renderer + # Check that the import statement exists in the module import_source = svg_renderer.__file__ - with open(import_source, 'r') as f: + with open(import_source, "r") as f: content = f.read() self.assertIn("build_plan_for_polar", content) def test_canvas2d_renderer_has_render_polar_import(self) -> None: """Test that canvas2d_renderer imports build_plan_for_polar.""" from rendering import canvas2d_renderer + import_source = canvas2d_renderer.__file__ - with open(import_source, 'r') as f: + with open(import_source, "r") as f: content = f.read() self.assertIn("build_plan_for_polar", content) def test_webgl_renderer_has_render_polar_import(self) -> None: """Test that webgl_renderer imports build_plan_for_polar.""" from rendering import webgl_renderer + import_source = webgl_renderer.__file__ - with open(import_source, 'r') as f: + with open(import_source, "r") as f: content = f.read() self.assertIn("build_plan_for_polar", content) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/client_renderer/test_renderer_factory_plan.py b/server_tests/client_renderer/test_renderer_factory_plan.py index 89262422..e7a3d178 100644 --- a/server_tests/client_renderer/test_renderer_factory_plan.py +++ b/server_tests/client_renderer/test_renderer_factory_plan.py @@ -35,7 +35,9 @@ def webgl_renderer() -> object: result = factory.create_renderer() self.assertIs(result, sentinel) - self.assertEqual(attempts, ["canvas2d", "svg"], "Factory should skip failing constructors and stop at first success") + self.assertEqual( + attempts, ["canvas2d", "svg"], "Factory should skip failing constructors and stop at first success" + ) def test_preferred_renderer_short_circuits_fallback(self) -> None: calls: dict[str, int] = {"canvas2d": 0, "svg": 0, "webgl": 0} diff --git a/server_tests/client_renderer/test_webgl_primitive_adapter.py b/server_tests/client_renderer/test_webgl_primitive_adapter.py index e91a73e5..4f128153 100644 --- a/server_tests/client_renderer/test_webgl_primitive_adapter.py +++ b/server_tests/client_renderer/test_webgl_primitive_adapter.py @@ -18,7 +18,9 @@ def _draw_lines(self, points: List[Tuple[float, float]], color: Tuple[float, flo def _draw_line_strip(self, points: List[Tuple[float, float]], color: Tuple[float, float, float, float]) -> None: self.draw_calls.append(("draw_line_strip", points, color)) - def _draw_points(self, points: List[Tuple[float, float]], color: Tuple[float, float, float, float], size: float) -> None: + def _draw_points( + self, points: List[Tuple[float, float]], color: Tuple[float, float, float, float], size: float + ) -> None: self.draw_calls.append(("draw_points", points, color, size)) def _parse_color(self, color: str) -> Tuple[float, float, float, float]: @@ -35,6 +37,7 @@ class TestWebGLPrimitiveAdapter(unittest.TestCase): def setUp(self) -> None: self.renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + self.adapter = WebGLPrimitiveAdapter(self.renderer) def test_stroke_line_draws_lines(self) -> None: @@ -120,6 +123,7 @@ def test_stroke_arc_approximates_with_samples(self) -> None: def test_color_parsing_hex_format(self) -> None: from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + renderer = MockWebGLRenderer() adapter = WebGLPrimitiveAdapter(renderer) @@ -138,6 +142,7 @@ class TestWebGLPrimitiveAdapterSampling(unittest.TestCase): def test_circle_sampling_produces_closed_path(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) samples = adapter._sample_circle((50.0, 60.0), 20.0) @@ -152,6 +157,7 @@ def test_circle_sampling_produces_closed_path(self) -> None: def test_ellipse_sampling_produces_closed_path(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) samples = adapter._sample_ellipse((70.0, 80.0), 30.0, 20.0, 0.0) @@ -166,9 +172,11 @@ def test_ellipse_sampling_produces_closed_path(self) -> None: def test_arc_sampling_respects_angle_range(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) import math + samples = adapter._sample_arc((100.0, 110.0), 50.0, 0.0, math.pi / 2, True) self.assertGreater(len(samples), 2) @@ -178,6 +186,7 @@ class TestWebGLPrimitiveAdapterEdgeCases(unittest.TestCase): def test_empty_polyline_does_not_crash(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) stroke = StrokeStyle(color="#000000", width=1.0) @@ -190,6 +199,7 @@ def test_empty_polyline_does_not_crash(self) -> None: def test_single_point_polyline_does_not_crash(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) stroke = StrokeStyle(color="#000000", width=1.0) @@ -202,6 +212,7 @@ def test_single_point_polyline_does_not_crash(self) -> None: def test_zero_radius_circle_handles_gracefully(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) fill = FillStyle(color="#FF0000") @@ -214,6 +225,7 @@ def test_zero_radius_circle_handles_gracefully(self) -> None: def test_negative_radius_circle_handles_gracefully(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) fill = FillStyle(color="#FF0000") @@ -226,6 +238,7 @@ def test_negative_radius_circle_handles_gracefully(self) -> None: def test_fill_polygon_with_single_point_handles_gracefully(self) -> None: renderer = MockWebGLRenderer() from rendering.webgl_primitive_adapter import WebGLPrimitiveAdapter + adapter = WebGLPrimitiveAdapter(renderer) fill = FillStyle(color="#00FF00") @@ -241,4 +254,3 @@ def test_fill_polygon_with_single_point_handles_gracefully(self) -> None: "TestWebGLPrimitiveAdapterSampling", "TestWebGLPrimitiveAdapterEdgeCases", ] - diff --git a/server_tests/python_path_setup.py b/server_tests/python_path_setup.py index a31948a7..3dee0047 100644 --- a/server_tests/python_path_setup.py +++ b/server_tests/python_path_setup.py @@ -6,7 +6,8 @@ # Add the site-packages directory to Python path SITE_PACKAGES_PATH: str = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Go up one level from Tests - 'static', 'client' + "static", + "client", ) # Add the path if it's not already there diff --git a/server_tests/test_adaptive_sampler.py b/server_tests/test_adaptive_sampler.py index 091ebd79..9c650c9f 100644 --- a/server_tests/test_adaptive_sampler.py +++ b/server_tests/test_adaptive_sampler.py @@ -10,7 +10,7 @@ import unittest from typing import Tuple -sys.path.insert(0, 'static/client/rendering/renderables') +sys.path.insert(0, "static/client/rendering/renderables") from adaptive_sampler import ( AdaptiveSampler, @@ -67,13 +67,9 @@ def test_linear_covers_full_range_after_pan(self) -> None: ] for left, right in pan_positions: samples = get_samples(left, right, lambda x: x, scaled_transform) + self.assertAlmostEqual(samples[0], left, places=10, msg=f"First sample {samples[0]} != left bound {left}") self.assertAlmostEqual( - samples[0], left, places=10, - msg=f"First sample {samples[0]} != left bound {left}" - ) - self.assertAlmostEqual( - samples[-1], right, places=10, - msg=f"Last sample {samples[-1]} != right bound {right}" + samples[-1], right, places=10, msg=f"Last sample {samples[-1]} != right bound {right}" ) def test_linear_no_gaps_in_coverage(self) -> None: @@ -159,7 +155,7 @@ class TestAdaptiveSamplerMaxDepth(unittest.TestCase): def test_respects_max_depth(self) -> None: """Should not exceed 2^MAX_DEPTH + 1 samples.""" - max_possible = (2 ** MAX_DEPTH) + 1 + max_possible = (2**MAX_DEPTH) + 1 samples = get_samples(-10, 10, lambda x: math.sin(100 * x), scaled_transform) self.assertLessEqual(len(samples), max_possible) @@ -169,9 +165,10 @@ class TestAdaptiveSamplerInvalidValues(unittest.TestCase): def test_handles_nan(self) -> None: """Should handle NaN values gracefully.""" + def func_with_nan(x: float) -> float: if x == 0: - return float('nan') + return float("nan") return x samples = get_samples(-10, 10, func_with_nan, identity_transform) @@ -180,9 +177,10 @@ def func_with_nan(x: float) -> float: def test_handles_inf(self) -> None: """Should handle infinity values gracefully.""" + def func_with_inf(x: float) -> float: if abs(x) < 0.01: - return float('inf') + return float("inf") return 1 / x samples = get_samples(-10, 10, func_with_inf, identity_transform) @@ -191,6 +189,7 @@ def func_with_inf(x: float) -> float: def test_handles_exception(self) -> None: """Should handle function exceptions gracefully.""" + def func_with_exception(x: float) -> float: if x == 0: raise ValueError("Division by zero") @@ -241,8 +240,11 @@ def test_linear_vs_curved_ratio(self) -> None: linear_samples = get_samples(-10, 10, lambda x: x, scaled_transform) curved_samples = get_samples(-10, 10, lambda x: x * x, scaled_transform) - self.assertGreater(len(curved_samples), len(linear_samples), - f"Curved ({len(curved_samples)}) should use more samples than linear ({len(linear_samples)})") + self.assertGreater( + len(curved_samples), + len(linear_samples), + f"Curved ({len(curved_samples)}) should use more samples than linear ({len(linear_samples)})", + ) class TestAdaptiveSamplerBenchmarks(unittest.TestCase): @@ -256,15 +258,15 @@ def _time_adaptive(self, eval_func) -> float: """Time adaptive sample generation.""" start = time.perf_counter() for _ in range(self.ITERATIONS): - AdaptiveSampler.generate_samples( - self.LEFT, self.RIGHT, eval_func, scaled_transform - ) + AdaptiveSampler.generate_samples(self.LEFT, self.RIGHT, eval_func, scaled_transform) return (time.perf_counter() - start) * 1000 / self.ITERATIONS def test_linear_benchmark(self) -> None: """Benchmark linear function y=x.""" + def eval_func(x): return x + adaptive_ms = self._time_adaptive(eval_func) adaptive_count = len(get_samples(self.LEFT, self.RIGHT, eval_func, scaled_transform)) @@ -273,8 +275,10 @@ def eval_func(x): def test_quadratic_benchmark(self) -> None: """Benchmark quadratic function y=x^2.""" + def eval_func(x): return x * x + adaptive_ms = self._time_adaptive(eval_func) adaptive_count = len(get_samples(self.LEFT, self.RIGHT, eval_func, scaled_transform)) @@ -282,8 +286,10 @@ def eval_func(x): def test_sin_benchmark(self) -> None: """Benchmark sin function.""" + def eval_func(x): return math.sin(x) + adaptive_ms = self._time_adaptive(eval_func) adaptive_count = len(get_samples(self.LEFT, self.RIGHT, eval_func, scaled_transform)) @@ -291,8 +297,10 @@ def eval_func(x): def test_high_amplitude_sin_benchmark(self) -> None: """Benchmark high amplitude sin function.""" + def eval_func(x): return math.sin(x) * 100 + adaptive_ms = self._time_adaptive(eval_func) adaptive_count = len(get_samples(self.LEFT, self.RIGHT, eval_func, scaled_transform)) @@ -300,16 +308,16 @@ def eval_func(x): def test_high_frequency_sin_benchmark(self) -> None: """Benchmark high frequency sin(10x).""" + def eval_func(x): return 10 * math.sin(10 * x) + adaptive_ms = self._time_adaptive(eval_func) adaptive_count = len(get_samples(self.LEFT, self.RIGHT, eval_func, scaled_transform)) print(f"\n### High Freq Sin (y=10*sin(10x)): {adaptive_ms:.3f}ms, {adaptive_count} samples") - self.assertGreater(adaptive_count, 20, - f"High frequency sin should produce >20 samples, got {adaptive_count}") + self.assertGreater(adaptive_count, 20, f"High frequency sin should produce >20 samples, got {adaptive_count}") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - diff --git a/server_tests/test_ai_model.py b/server_tests/test_ai_model.py index ef1e99e7..6d958ff6 100644 --- a/server_tests/test_ai_model.py +++ b/server_tests/test_ai_model.py @@ -17,11 +17,7 @@ class TestAIModel(unittest.TestCase): def test_model_initialization(self) -> None: """Test direct model initialization with all parameters.""" - model = AIModel( - identifier="test-model", - has_vision=True, - is_reasoning_model=True - ) + model = AIModel(identifier="test-model", has_vision=True, is_reasoning_model=True) self.assertEqual(model.id, "test-model") self.assertTrue(model.has_vision) self.assertTrue(model.is_reasoning_model) @@ -120,37 +116,29 @@ def test_all_reasoning_models_identified(self) -> None: reasoning_models = ["gpt-5-chat-latest", "gpt-5.2-chat-latest", "gpt-5.2", "o3", "o4-mini"] for model_id in reasoning_models: model = AIModel.from_identifier(model_id) - self.assertTrue( - model.is_reasoning_model, - f"{model_id} should be a reasoning model" - ) + self.assertTrue(model.is_reasoning_model, f"{model_id} should be a reasoning model") def test_all_standard_models_identified(self) -> None: """Test that all standard models are correctly identified as non-reasoning.""" standard_models = [ - "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", - "gpt-4o", "gpt-4o-mini", "gpt-5-nano", "gpt-3.5-turbo" + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4.1-nano", + "gpt-4o", + "gpt-4o-mini", + "gpt-5-nano", + "gpt-3.5-turbo", ] for model_id in standard_models: model = AIModel.from_identifier(model_id) - self.assertFalse( - model.is_reasoning_model, - f"{model_id} should NOT be a reasoning model" - ) + self.assertFalse(model.is_reasoning_model, f"{model_id} should NOT be a reasoning model") def test_model_configs_completeness(self) -> None: """Test that MODEL_CONFIGS has both required keys for all models.""" for model_id, config in AIModel.MODEL_CONFIGS.items(): - self.assertIn( - "has_vision", config, - f"{model_id} missing 'has_vision' in config" - ) - self.assertIn( - "is_reasoning_model", config, - f"{model_id} missing 'is_reasoning_model' in config" - ) + self.assertIn("has_vision", config, f"{model_id} missing 'has_vision' in config") + self.assertIn("is_reasoning_model", config, f"{model_id} missing 'is_reasoning_model' in config") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - diff --git a/server_tests/test_browser_typing_stubs.py b/server_tests/test_browser_typing_stubs.py new file mode 100644 index 00000000..ed791c05 --- /dev/null +++ b/server_tests/test_browser_typing_stubs.py @@ -0,0 +1,1412 @@ +"""Acceptance tests for Brython browser module type stubs. + +Validates that the .pyi stubs under static/client/typing/browser/ +are syntactically valid, export the expected names, and pass MyPy +type-checking for common browser API usage patterns. +""" + +from __future__ import annotations + +import ast +import glob +import os +import tempfile +import unittest + +from mypy import api + + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +_ROOT: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +_STUBS_DIR: str = os.path.join(_ROOT, "static", "client", "typing", "browser") +_MYPY_PATH: str = os.path.join(_ROOT, "static", "client", "typing") + +_INIT_PYI: str = os.path.join(_STUBS_DIR, "__init__.pyi") +_DOM_PYI: str = os.path.join(_STUBS_DIR, "_dom.pyi") +_AJAX_PYI: str = os.path.join(_STUBS_DIR, "ajax.pyi") +_AIO_PYI: str = os.path.join(_STUBS_DIR, "aio.pyi") + +_ALL_STUB_FILES: list[str] = [_INIT_PYI, _DOM_PYI, _AJAX_PYI, _AIO_PYI] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_mypy_snippet(snippet: str) -> tuple[str, str, int]: + """Run mypy on a code snippet with browser stubs available. + + Sets ``MYPYPATH`` so mypy discovers the browser stub package. + """ + fd, path = tempfile.mkstemp(suffix=".py") + old_mypypath = os.environ.get("MYPYPATH") + try: + with os.fdopen(fd, "w") as f: + f.write(snippet) + os.environ["MYPYPATH"] = _MYPY_PATH + result: tuple[str, str, int] = api.run( + [ + "--python-version", + "3.11", + "--no-error-summary", + "--ignore-missing-imports", + "--follow-imports", + "skip", + "--explicit-package-bases", + path, + ] + ) + return result + finally: + if old_mypypath is None: + os.environ.pop("MYPYPATH", None) + else: + os.environ["MYPYPATH"] = old_mypypath + os.unlink(path) + + +def _parse_class_names(tree: ast.Module, class_name: str) -> set[str]: + """Extract method and attribute names from a class in an AST.""" + names: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + for item in node.body: + if isinstance(item, ast.FunctionDef): + names.add(item.name) + elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): + names.add(item.target.id) + return names + + +def _get_all_class_names(tree: ast.Module) -> set[str]: + """Extract all class names defined at module level in an AST.""" + names: set[str] = set() + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + names.add(node.name) + return names + + +def _get_all_function_names(tree: ast.Module) -> set[str]: + """Extract all function names defined at module level in an AST.""" + names: set[str] = set() + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef): + names.add(node.name) + return names + + +# --------------------------------------------------------------------------- +# Tests: Stub file validity +# --------------------------------------------------------------------------- + + +class TestStubFileValidity(unittest.TestCase): + """Verify that stub files exist, parse correctly, and contain no .py files.""" + + def test_stub_files_are_valid_python(self) -> None: + """Each .pyi file must parse as valid Python.""" + for pyi_path in _ALL_STUB_FILES: + with self.subTest(file=os.path.basename(pyi_path)): + with open(pyi_path) as f: + source = f.read() + try: + ast.parse(source) + except SyntaxError as exc: + self.fail(f"{pyi_path} has a syntax error: {exc}") + + def test_all_stub_files_exist(self) -> None: + """All expected stub files must exist on disk.""" + for pyi_path in _ALL_STUB_FILES: + with self.subTest(file=os.path.basename(pyi_path)): + self.assertTrue( + os.path.isfile(pyi_path), + f"Stub file does not exist: {pyi_path}", + ) + + def test_no_runtime_py_files_in_stubs(self) -> None: + """The stubs directory must contain only .pyi files, no .py files.""" + py_files = glob.glob(os.path.join(_STUBS_DIR, "*.py")) + self.assertEqual( + py_files, + [], + f"Found .py files in stubs directory (should be .pyi only): {py_files}", + ) + + def test_stubs_directory_is_a_package(self) -> None: + """The browser stubs directory must have an __init__.pyi file.""" + self.assertTrue( + os.path.isfile(_INIT_PYI), + "__init__.pyi must exist for the browser stub package", + ) + + def test_all_stubs_have_future_annotations(self) -> None: + """Each stub file must use ``from __future__ import annotations``.""" + for pyi_path in _ALL_STUB_FILES: + with self.subTest(file=os.path.basename(pyi_path)): + with open(pyi_path) as f: + tree = ast.parse(f.read()) + has_future = False + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ImportFrom) and node.module == "__future__": + for alias in node.names: + if alias.name == "annotations": + has_future = True + self.assertTrue( + has_future, + f"{os.path.basename(pyi_path)} missing 'from __future__ import annotations'", + ) + + def test_all_stubs_have_docstrings(self) -> None: + """Each stub file must have a module-level docstring.""" + for pyi_path in _ALL_STUB_FILES: + with self.subTest(file=os.path.basename(pyi_path)): + with open(pyi_path) as f: + tree = ast.parse(f.read()) + docstring = ast.get_docstring(tree) + self.assertIsNotNone( + docstring, + f"{os.path.basename(pyi_path)} missing module docstring", + ) + + def test_old_single_file_stub_removed(self) -> None: + """The obsolete single-file browser.pyi must no longer exist.""" + old_stub = os.path.join(_ROOT, "static", "client", "browser.pyi") + self.assertFalse( + os.path.exists(old_stub), + f"Obsolete stub still exists and should be deleted: {old_stub}", + ) + + +# --------------------------------------------------------------------------- +# Tests: Export completeness +# --------------------------------------------------------------------------- + + +class TestExportCompleteness(unittest.TestCase): + """Verify that stub packages export all names used by the codebase.""" + + def test_all_expected_names_exported(self) -> None: + """__init__.pyi must export all 7 module-level names plus DOMNode and ClassList.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + names: set[str] = set() + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + names.add(node.target.id) + elif isinstance(node, ast.ImportFrom): + for alias in node.names: + exported = alias.asname if alias.asname else alias.name + names.add(exported) + elif isinstance(node, ast.ClassDef): + names.add(node.name) + + expected = {"document", "window", "html", "svg", "console", "ajax", "aio", "DOMNode", "ClassList"} + missing = expected - names + self.assertFalse(missing, f"Missing exports in __init__.pyi: {missing}") + + def test_reexports_use_explicit_as_syntax(self) -> None: + """Re-exports must use ``X as X`` syntax for implicit_reexport=False compat.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + reexport_names = {"DOMNode", "ClassList", "ajax", "aio"} + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ImportFrom): + for alias in node.names: + if alias.name in reexport_names: + with self.subTest(name=alias.name): + self.assertEqual( + alias.asname, + alias.name, + f"Re-export '{alias.name}' must use 'as {alias.name}' " + f"for implicit_reexport=False compatibility", + ) + + def test_init_defines_public_classes(self) -> None: + """__init__.pyi must define Document, Window, HTMLFactory, SVGFactory, Console.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + classes = _get_all_class_names(tree) + expected = {"Document", "Window", "HTMLFactory", "SVGFactory", "Console"} + missing = expected - classes + self.assertFalse(missing, f"Missing public classes: {missing}") + + def test_init_defines_helper_classes(self) -> None: + """__init__.pyi must define private helper classes for typed Window sub-objects.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + classes = _get_all_class_names(tree) + expected_helpers = { + "_JSON", + "_LocalStorage", + "_Performance", + "_MathJS", + "_NerdamerExpr", + "_Nerdamer", + "_Date", + "_URL", + } + missing = expected_helpers - classes + self.assertFalse(missing, f"Missing helper classes: {missing}") + + def test_module_level_instances_annotated(self) -> None: + """Module-level singletons must be annotated with their class types.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + annotations: dict[str, str] = {} + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + if isinstance(node.annotation, ast.Name): + annotations[node.target.id] = node.annotation.id + + expected = { + "document": "Document", + "window": "Window", + "html": "HTMLFactory", + "svg": "SVGFactory", + "console": "Console", + } + for name, expected_type in expected.items(): + with self.subTest(name=name): + self.assertIn(name, annotations, f"Missing annotation for '{name}'") + self.assertEqual( + annotations[name], + expected_type, + f"'{name}' should be annotated as {expected_type}, got {annotations.get(name)}", + ) + + +# --------------------------------------------------------------------------- +# Tests: DOMNode structure +# --------------------------------------------------------------------------- + + +class TestDOMNodeStructure(unittest.TestCase): + """Verify DOMNode and ClassList class structures in _dom.pyi.""" + + def test_dom_node_has_required_methods(self) -> None: + """DOMNode class must have all expected methods.""" + with open(_DOM_PYI) as f: + tree = ast.parse(f.read()) + + methods = _parse_class_names(tree, "DOMNode") + expected_methods = { + "getBoundingClientRect", + "appendChild", + "removeChild", + "setAttribute", + "getAttribute", + "removeAttribute", + "insertBefore", + "cloneNode", + "bind", + "focus", + "blur", + "click", + "clear", + "remove", + "getContext", + "select", + "select_one", + "__getitem__", + "__setitem__", + "__le__", + "__contains__", + } + missing = expected_methods - methods + self.assertFalse(missing, f"DOMNode missing methods: {missing}") + + def test_dom_node_has_required_properties(self) -> None: + """DOMNode class must have all expected properties.""" + with open(_DOM_PYI) as f: + tree = ast.parse(f.read()) + + attrs = _parse_class_names(tree, "DOMNode") + expected_props = { + "innerHTML", + "text", + "value", + "disabled", + "checked", + "options", + "scrollTop", + "scrollHeight", + "style", + "classList", + "attrs", + "parentNode", + "children", + "firstChild", + "width", + "height", + "onload", + "result", + "responseType", + } + missing = expected_props - attrs + self.assertFalse(missing, f"DOMNode missing properties: {missing}") + + def test_classlist_has_required_methods(self) -> None: + """ClassList class must have add, remove, contains methods.""" + with open(_DOM_PYI) as f: + tree = ast.parse(f.read()) + + methods = _parse_class_names(tree, "ClassList") + expected = {"add", "remove", "contains"} + missing = expected - methods + self.assertFalse(missing, f"ClassList missing methods: {missing}") + + def test_dom_node_classlist_typed_as_classlist(self) -> None: + """DOMNode.classList must be annotated as ClassList type.""" + with open(_DOM_PYI) as f: + tree = ast.parse(f.read()) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "DOMNode": + for item in node.body: + if ( + isinstance(item, ast.AnnAssign) + and isinstance(item.target, ast.Name) + and item.target.id == "classList" + ): + self.assertIsInstance(item.annotation, ast.Name) + assert isinstance(item.annotation, ast.Name) + self.assertEqual(item.annotation.id, "ClassList") + return + self.fail("DOMNode.classList annotation not found") + + def test_dom_node_parent_typed_as_optional(self) -> None: + """DOMNode.parentNode must be typed as DOMNode | None.""" + with open(_DOM_PYI) as f: + source = f.read() + self.assertIn("parentNode: DOMNode | None", source) + + def test_dom_node_firstchild_typed_as_optional(self) -> None: + """DOMNode.firstChild must be typed as DOMNode | None.""" + with open(_DOM_PYI) as f: + source = f.read() + self.assertIn("firstChild: DOMNode | None", source) + + +# --------------------------------------------------------------------------- +# Tests: Ajax module structure +# --------------------------------------------------------------------------- + + +class TestAjaxModuleStructure(unittest.TestCase): + """Verify ajax.pyi class and function structures.""" + + def test_ajax_module_has_required_symbols(self) -> None: + """ajax.pyi must define AjaxRequest, Ajax, ajax, post.""" + with open(_AJAX_PYI) as f: + tree = ast.parse(f.read()) + + names: set[str] = set() + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + names.add(node.name) + elif isinstance(node, ast.FunctionDef): + names.add(node.name) + + expected = {"AjaxRequest", "Ajax", "ajax", "post"} + missing = expected - names + self.assertFalse(missing, f"ajax.pyi missing symbols: {missing}") + + def test_ajax_request_has_required_members(self) -> None: + """AjaxRequest class must have status, text, response, responseType, bind, open, set_header, send.""" + with open(_AJAX_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "AjaxRequest") + expected = {"status", "text", "response", "responseType", "bind", "open", "set_header", "send"} + missing = expected - members + self.assertFalse(missing, f"AjaxRequest missing members: {missing}") + + def test_ajax_class_has_same_interface_as_request(self) -> None: + """Ajax class must have the same interface as AjaxRequest.""" + with open(_AJAX_PYI) as f: + tree = ast.parse(f.read()) + + request_members = _parse_class_names(tree, "AjaxRequest") + ajax_members = _parse_class_names(tree, "Ajax") + missing = request_members - ajax_members + self.assertFalse( + missing, + f"Ajax class missing members present in AjaxRequest: {missing}", + ) + + def test_ajax_is_separate_class_not_alias(self) -> None: + """Ajax must be a separate ClassDef, not a type alias.""" + with open(_AJAX_PYI) as f: + tree = ast.parse(f.read()) + + ajax_classes = [ + node for node in ast.iter_child_nodes(tree) if isinstance(node, ast.ClassDef) and node.name == "Ajax" + ] + self.assertEqual( + len(ajax_classes), + 1, + "Ajax must be defined as a separate class (not a type alias)", + ) + + +# --------------------------------------------------------------------------- +# Tests: Aio module structure +# --------------------------------------------------------------------------- + + +class TestAioModuleStructure(unittest.TestCase): + """Verify aio.pyi function structures.""" + + def test_aio_module_has_required_symbols(self) -> None: + """aio.pyi must define run and sleep.""" + with open(_AIO_PYI) as f: + tree = ast.parse(f.read()) + + names = _get_all_function_names(tree) + expected = {"run", "sleep"} + missing = expected - names + self.assertFalse(missing, f"aio.pyi missing symbols: {missing}") + + def test_aio_has_no_classes(self) -> None: + """aio.pyi should only have functions, no classes.""" + with open(_AIO_PYI) as f: + tree = ast.parse(f.read()) + + classes = _get_all_class_names(tree) + self.assertEqual(classes, set(), f"aio.pyi should not define classes: {classes}") + + +# --------------------------------------------------------------------------- +# Tests: Window class structure +# --------------------------------------------------------------------------- + + +class TestWindowStructure(unittest.TestCase): + """Verify Window class structure in __init__.pyi.""" + + def test_window_has_typed_sub_objects(self) -> None: + """Window must have typed sub-object annotations.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Window") + expected = { + "JSON", + "localStorage", + "performance", + "math", + "nerdamer", + "Date", + "URL", + } + missing = expected - members + self.assertFalse(missing, f"Window missing typed sub-objects: {missing}") + + def test_window_has_direct_methods(self) -> None: + """Window must have setTimeout, clearTimeout, requestAnimationFrame.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Window") + expected = {"setTimeout", "clearTimeout", "requestAnimationFrame"} + missing = expected - members + self.assertFalse(missing, f"Window missing direct methods: {missing}") + + def test_window_has_constructor_attributes(self) -> None: + """Window must have constructor-like attributes (Audio, Float32Array, etc.).""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Window") + expected = {"Audio", "Float32Array", "FileReader", "MouseEvent", "MathJax", "Math"} + missing = expected - members + self.assertFalse(missing, f"Window missing constructor attributes: {missing}") + + def test_window_has_escape_hatches(self) -> None: + """Window must define __getattr__, __setattr__, __getitem__, __setitem__.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Window") + expected = {"__getattr__", "__setattr__", "__getitem__", "__setitem__"} + missing = expected - members + self.assertFalse(missing, f"Window missing escape hatches: {missing}") + + +# --------------------------------------------------------------------------- +# Tests: Document class structure +# --------------------------------------------------------------------------- + + +class TestDocumentStructure(unittest.TestCase): + """Verify Document class structure in __init__.pyi.""" + + def test_document_has_query_methods(self) -> None: + """Document must have getElementById, querySelector, select, select_one.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Document") + expected = {"getElementById", "querySelector", "select", "select_one"} + missing = expected - members + self.assertFalse(missing, f"Document missing query methods: {missing}") + + def test_document_has_operators(self) -> None: + """Document must define __getitem__, __contains__, __le__.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Document") + expected = {"__getitem__", "__contains__", "__le__"} + missing = expected - members + self.assertFalse(missing, f"Document missing operators: {missing}") + + def test_document_has_bind_method(self) -> None: + """Document must define bind for event handling.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Document") + self.assertIn("bind", members) + + def test_document_not_subclass_of_domnode(self) -> None: + """Document must NOT be a subclass of DOMNode to avoid Liskov conflicts.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "Document": + domnode_bases = [b for b in node.bases if isinstance(b, ast.Name) and b.id == "DOMNode"] + self.assertEqual( + len(domnode_bases), + 0, + "Document must not inherit from DOMNode (Liskov conflict)", + ) + return + self.fail("Document class not found") + + +# --------------------------------------------------------------------------- +# Tests: HTMLFactory and SVGFactory structure +# --------------------------------------------------------------------------- + + +class TestFactoryStructure(unittest.TestCase): + """Verify HTMLFactory and SVGFactory class structures.""" + + def test_html_factory_has_element_methods(self) -> None: + """HTMLFactory must define methods for all used HTML elements.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "HTMLFactory") + expected = { + "DIV", + "SPAN", + "BUTTON", + "INPUT", + "TEXTAREA", + "LABEL", + "SELECT", + "OPTION", + "H3", + "P", + "IMG", + "CANVAS", + "DETAILS", + "SUMMARY", + } + missing = expected - members + self.assertFalse(missing, f"HTMLFactory missing element methods: {missing}") + + def test_html_factory_has_getattr_fallback(self) -> None: + """HTMLFactory must define __getattr__ for unlisted elements.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "HTMLFactory") + self.assertIn("__getattr__", members) + + def test_svg_factory_has_element_methods(self) -> None: + """SVGFactory must define methods for all used SVG elements.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "SVGFactory") + expected = {"svg", "g", "line", "path", "circle", "ellipse", "polygon", "text"} + missing = expected - members + self.assertFalse(missing, f"SVGFactory missing element methods: {missing}") + + def test_svg_factory_has_getattr_fallback(self) -> None: + """SVGFactory must define __getattr__ for unlisted elements.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "SVGFactory") + self.assertIn("__getattr__", members) + + +# --------------------------------------------------------------------------- +# Tests: Console structure +# --------------------------------------------------------------------------- + + +class TestConsoleStructure(unittest.TestCase): + """Verify Console class structure.""" + + def test_console_has_logging_methods(self) -> None: + """Console must have log, error, warn, groupCollapsed, groupEnd.""" + with open(_INIT_PYI) as f: + tree = ast.parse(f.read()) + + members = _parse_class_names(tree, "Console") + expected = {"log", "error", "warn", "groupCollapsed", "groupEnd"} + missing = expected - members + self.assertFalse(missing, f"Console missing methods: {missing}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy resolution and type-checking +# --------------------------------------------------------------------------- + + +class TestMyPyResolution(unittest.TestCase): + """Verify MyPy can resolve browser imports and type-check operations.""" + + def test_mypy_resolves_browser_imports(self) -> None: + """MyPy must resolve all browser imports without errors.""" + snippet = "from browser import document, window, html, svg, ajax, aio, console\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + self.assertNotIn( + "Cannot find implementation or library stub", + stdout, + f"MyPy could not find browser stubs:\n{stdout}", + ) + + def test_mypy_resolves_domnode_import(self) -> None: + """MyPy must resolve DOMNode and ClassList from browser._dom.""" + snippet = "from browser import DOMNode, ClassList\nreveal_type(DOMNode)\nreveal_type(ClassList)\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_reveal_types(self) -> None: + """reveal_type() must produce expected stub types.""" + snippet = ( + "from browser import document, window, html, svg, console\n" + "reveal_type(document)\n" + "reveal_type(window)\n" + "reveal_type(html)\n" + "reveal_type(svg)\n" + "reveal_type(console)\n" + ) + stdout, _stderr, _exit_code = _run_mypy_snippet(snippet) + self.assertIn("Document", stdout) + self.assertIn("Window", stdout) + self.assertIn("HTMLFactory", stdout) + self.assertIn("SVGFactory", stdout) + self.assertIn("Console", stdout) + + def test_mypy_reveal_ajax_types(self) -> None: + """reveal_type() for ajax objects must show AjaxRequest/Ajax.""" + snippet = "from browser import ajax\nreq = ajax.ajax()\nreveal_type(req)\n" + stdout, _stderr, _exit_code = _run_mypy_snippet(snippet) + self.assertIn("AjaxRequest", stdout) + + def test_stubs_pass_mypy(self) -> None: + """MyPy must accept the stub files themselves without errors.""" + old_mypypath = os.environ.get("MYPYPATH") + try: + os.environ["MYPYPATH"] = _MYPY_PATH + result = api.run( + [ + "--python-version", + "3.11", + "--no-error-summary", + "--ignore-missing-imports", + "--explicit-package-bases", + *_ALL_STUB_FILES, + ] + ) + finally: + if old_mypypath is None: + os.environ.pop("MYPYPATH", None) + else: + os.environ["MYPYPATH"] = old_mypypath + stdout, stderr, exit_code = result + self.assertEqual(exit_code, 0, f"MyPy failed on stubs:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy DOM operations +# --------------------------------------------------------------------------- + + +class TestMyPyDOMOperations(unittest.TestCase): + """Verify MyPy type-checks common DOM operation patterns.""" + + def test_mypy_basic_dom_operations(self) -> None: + """Basic DOM operations must type-check without errors.""" + snippet = ( + "from browser import document, window, html\n" + "el = document.getElementById('x')\n" + "timer = window.setTimeout(lambda: None, 100)\n" + "div = html.DIV(Class='foo')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_document_subscript_access(self) -> None: + """document['id'] must type-check for element access by ID.""" + snippet = "from browser import document\nel = document['chat-input']\nel.value = 'hello'\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_document_contains(self) -> None: + """'id' in document must type-check for element existence checks.""" + snippet = "from browser import document\nif 'run-tests-button' in document:\n pass\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_dom_append_operator(self) -> None: + """document <= element (Brython DOM append) must type-check.""" + snippet = "from browser import document, html\ndiv = html.DIV()\ndocument <= div\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_element_append_operator(self) -> None: + """parent <= child (Brython element append) must type-check.""" + snippet = "from browser import html\nparent = html.DIV()\nchild = html.SPAN('text')\nparent <= child\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_dom_node_methods(self) -> None: + """DOMNode methods (appendChild, setAttribute, etc.) must type-check.""" + snippet = ( + "from browser import document, html\n" + "parent = html.DIV()\n" + "child = html.SPAN()\n" + "parent.appendChild(child)\n" + "parent.setAttribute('data-id', '42')\n" + "val = parent.getAttribute('data-id')\n" + "parent.removeAttribute('data-id')\n" + "clone = parent.cloneNode(True)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_dom_node_properties(self) -> None: + """DOMNode properties (innerHTML, scrollTop, etc.) must type-check.""" + snippet = ( + "from browser import html\n" + "el = html.DIV()\n" + "el.innerHTML = 'bold'\n" + "s: str = el.innerHTML\n" + "el.style.display = 'none'\n" + "top: float = el.scrollTop\n" + "height: float = el.scrollHeight\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_classlist_operations(self) -> None: + """ClassList operations (add, remove, contains) must type-check.""" + snippet = ( + "from browser import html\n" + "el = html.DIV()\n" + "el.classList.add('active')\n" + "el.classList.remove('hidden')\n" + "has: bool = el.classList.contains('active')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_event_binding(self) -> None: + """Element event binding must type-check.""" + snippet = ( + "from browser import document, html\n" + "el = html.BUTTON('Click me')\n" + "el.bind('click', lambda ev: None)\n" + "document.bind('keydown', lambda ev: None)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_dom_traversal(self) -> None: + """DOM node traversal (parentNode, firstChild) must type-check.""" + snippet = ( + "from browser import html, DOMNode\n" + "el = html.DIV()\n" + "parent = el.parentNode\n" + "if parent is not None:\n" + " parent.removeChild(el)\n" + "first = el.firstChild\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_select_methods(self) -> None: + """DOMNode.select() and select_one() must type-check.""" + snippet = ( + "from browser import html, DOMNode\n" + "el = html.DIV()\n" + "results: list[DOMNode] = el.select('.child')\n" + "one = el.select_one('.child')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy HTML/SVG factory operations +# --------------------------------------------------------------------------- + + +class TestMyPyFactoryOperations(unittest.TestCase): + """Verify MyPy type-checks HTML and SVG factory patterns.""" + + def test_mypy_html_factory_with_content(self) -> None: + """HTMLFactory methods with positional content must type-check.""" + snippet = ( + "from browser import html\n" + "btn = html.BUTTON('Click me', Class='btn-primary')\n" + "span = html.SPAN('text', Class='highlight')\n" + "p = html.P('paragraph')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_html_factory_keyword_only(self) -> None: + """HTMLFactory methods with keyword-only args must type-check.""" + snippet = ( + "from browser import html\n" + "inp = html.INPUT(id='user-input')\n" + "sel = html.SELECT(id='model-selector')\n" + "canvas = html.CANVAS(id='main-canvas')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_html_factory_returns_domnode(self) -> None: + """HTMLFactory methods must return DOMNode-compatible objects.""" + snippet = ( + "from browser import html, DOMNode\n" + "div = html.DIV()\n" + "div.appendChild(html.SPAN())\n" + "div.setAttribute('data-custom', 'val')\n" + "div.bind('click', lambda ev: None)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_html_factory_getattr_fallback(self) -> None: + """HTMLFactory.__getattr__ must allow arbitrary element creation.""" + snippet = "from browser import html\nel = html.ARTICLE(Class='content')\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_svg_factory_operations(self) -> None: + """SVG factory element creation must type-check.""" + snippet = ( + "from browser import svg\n" + "root = svg.svg(width='100', height='100')\n" + "group = svg.g(id='layer1')\n" + "ln = svg.line(x1='0', y1='0', x2='100', y2='100')\n" + "p = svg.path(d='M0 0 L10 10')\n" + "c = svg.circle(cx='50', cy='50', r='25')\n" + "e = svg.ellipse(cx='50', cy='50', rx='30', ry='20')\n" + "pg = svg.polygon(points='0,0 10,0 5,10')\n" + "t = svg.text(x='10', y='20')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_svg_factory_getattr_fallback(self) -> None: + """SVGFactory.__getattr__ must allow arbitrary SVG element creation.""" + snippet = "from browser import svg\nrect = svg.rect(x='0', y='0', width='100', height='50')\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy Window operations +# --------------------------------------------------------------------------- + + +class TestMyPyWindowOperations(unittest.TestCase): + """Verify MyPy type-checks Window object usage patterns.""" + + def test_mypy_window_json(self) -> None: + """window.JSON.stringify/parse must type-check.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "s: str = window.JSON.stringify({'key': 'value'})\n" + "obj: Any = window.JSON.parse(s)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_localstorage(self) -> None: + """window.localStorage methods must type-check.""" + snippet = ( + "from browser import window\n" + "val = window.localStorage.getItem('key')\n" + "window.localStorage.setItem('key', 'value')\n" + "window.localStorage.removeItem('key')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_performance(self) -> None: + """window.performance.now() must type-check and return float.""" + snippet = "from browser import window\nt: float = window.performance.now()\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_timeout(self) -> None: + """setTimeout/clearTimeout must type-check with correct types.""" + snippet = ( + "from browser import window\n" + "timer_id: int = window.setTimeout(lambda: None, 100)\n" + "window.clearTimeout(timer_id)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_request_animation_frame(self) -> None: + """requestAnimationFrame must type-check.""" + snippet = "from browser import window\nframe_id: int = window.requestAnimationFrame(lambda ts: None)\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_mathjs(self) -> None: + """window.math (math.js) operations must type-check.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "result: Any = window.math.evaluate('2 + 2')\n" + "formatted: str = window.math.format(result)\n" + "sq: Any = window.math.sqrt(4)\n" + "pw: Any = window.math.pow(2, 3)\n" + "d: Any = window.math.det([[1, 2], [3, 4]])\n" + "t: str = window.math.typeOf(result)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_nerdamer(self) -> None: + """window.nerdamer operations must type-check.""" + snippet = ( + "from browser import window\n" + "expr = window.nerdamer('x^2 + 1')\n" + "text: str = expr.text()\n" + "evald = expr.evaluate()\n" + "subbed = expr.sub('x', 2)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_dynamic_attribute_access(self) -> None: + """Window __getattr__/__setattr__ escape hatch must type-check.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "window.startMatHudTests = lambda: None\n" + "custom: Any = window.VISION_MODELS\n" + "window['custom_key'] = 42\n" + "val = window['custom_key']\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_date(self) -> None: + """window.Date.now() must type-check.""" + snippet = "from browser import window\nts: int = window.Date.now()\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_url(self) -> None: + """window.URL.createObjectURL/revokeObjectURL must type-check.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "blob: Any = None\n" + "url: str = window.URL.createObjectURL(blob)\n" + "window.URL.revokeObjectURL(url)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy AJAX operations +# --------------------------------------------------------------------------- + + +class TestMyPyAjaxOperations(unittest.TestCase): + """Verify MyPy type-checks AJAX usage patterns.""" + + def test_mypy_ajax_lowercase_constructor(self) -> None: + """ajax.ajax() must type-check and return an AjaxRequest.""" + snippet = ( + "from browser import ajax\n" + "req = ajax.ajax()\n" + "req.bind('complete', lambda r: None)\n" + "req.open('GET', '/api')\n" + "req.send()\n" + "status: int = req.status\n" + "body: str = req.text\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_ajax_uppercase_constructor(self) -> None: + """ajax.Ajax() must type-check as a separate constructor.""" + snippet = ( + "from browser import ajax\n" + "req = ajax.Ajax()\n" + "req.open('POST', '/api')\n" + "req.set_header('Content-Type', 'application/json')\n" + 'req.send(\'{"key": "value"}\')\n' + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_ajax_post_shortcut(self) -> None: + """ajax.post() shortcut must type-check.""" + snippet = "from browser import ajax\najax.post('/api', data='payload', oncomplete=lambda r: None)\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_ajax_with_timeout(self) -> None: + """ajax.ajax(timeout=...) must type-check.""" + snippet = "from browser import ajax\nreq = ajax.ajax(timeout=20000)\nreq.open('GET', '/api')\nreq.send()\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_ajax_both_constructors_type_check(self) -> None: + """Both ajax.ajax() and ajax.Ajax() must produce usable objects.""" + snippet = ( + "from browser import ajax\n" + "req1 = ajax.ajax()\n" + "req2 = ajax.Ajax()\n" + "req1.bind('complete', lambda r: None)\n" + "req2.bind('complete', lambda r: None)\n" + "req1.open('GET', '/api')\n" + "req2.open('POST', '/api')\n" + "req1.send()\n" + "req2.send('data')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy async I/O operations +# --------------------------------------------------------------------------- + + +class TestMyPyAioOperations(unittest.TestCase): + """Verify MyPy type-checks async I/O patterns.""" + + def test_mypy_aio_run_and_sleep(self) -> None: + """aio.run() and aio.sleep() must type-check.""" + snippet = "from browser import aio\nasync def main() -> None:\n await aio.sleep(1.0)\naio.run(main())\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_aio_run_accepts_coroutine(self) -> None: + """aio.run() must accept a coroutine argument.""" + snippet = ( + "from browser import aio\n" + "async def fetch_data() -> str:\n" + " await aio.sleep(0.5)\n" + " return 'data'\n" + "aio.run(fetch_data())\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy console operations +# --------------------------------------------------------------------------- + + +class TestMyPyConsoleOperations(unittest.TestCase): + """Verify MyPy type-checks console usage patterns.""" + + def test_mypy_console_logging(self) -> None: + """console.log/error/warn with various args must type-check.""" + snippet = ( + "from browser import console\n" + "console.log('message')\n" + "console.log('key', 42, [1, 2, 3])\n" + "console.error('error message')\n" + "console.warn('warning')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_console_grouping(self) -> None: + """console.groupCollapsed/groupEnd must type-check.""" + snippet = ( + "from browser import console\n" + "console.groupCollapsed('Debug info')\n" + "console.log('details')\n" + "console.groupEnd()\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: Integration patterns (real codebase usage) +# --------------------------------------------------------------------------- + + +class TestMyPyIntegrationPatterns(unittest.TestCase): + """Verify MyPy handles real codebase usage patterns from static/client/.""" + + def test_mypy_ajax_bind_open_send_pattern(self) -> None: + """Full AJAX request lifecycle pattern from workspace_manager.py.""" + snippet = ( + "from browser import ajax\n" + "def save_workspace(name: str, data: str) -> None:\n" + " req = ajax.Ajax()\n" + " req.bind('complete', lambda r: None)\n" + " req.open('POST', '/save_workspace')\n" + " req.set_header('Content-Type', 'application/json')\n" + " req.send(data)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_dom_element_creation_and_append(self) -> None: + """Element creation + append pattern from ai_interface.py.""" + snippet = ( + "from browser import document, html\n" + "def create_ui() -> None:\n" + " container = html.DIV(Class='container')\n" + " button = html.BUTTON('Send', Class='send-btn', id='send-button')\n" + " textarea = html.TEXTAREA(id='chat-input')\n" + " container <= button\n" + " container <= textarea\n" + " document <= container\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_local_storage_pattern(self) -> None: + """LocalStorage get/set pattern from tts_controller.py.""" + snippet = ( + "from browser import window\n" + "STORAGE_KEY = 'mathud.voice'\n" + "val = window.localStorage.getItem(STORAGE_KEY)\n" + "if val is not None:\n" + " current: str = val\n" + "window.localStorage.setItem(STORAGE_KEY, 'default')\n" + "window.localStorage.removeItem(STORAGE_KEY)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_svg_rendering_pattern(self) -> None: + """SVG rendering pattern from svg_primitive_adapter.py.""" + snippet = ( + "from browser import svg, document\n" + "def create_svg_scene() -> None:\n" + " root = svg.svg(id='canvas-svg', width='800', height='600')\n" + " group = svg.g(id='drawables')\n" + " ln = svg.line(x1='0', y1='0', x2='100', y2='100')\n" + " c = svg.circle(cx='50', cy='50', r='10')\n" + " group <= ln\n" + " group <= c\n" + " root <= group\n" + " frag = document.createDocumentFragment()\n" + " frag <= root\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_element_query_and_manipulation(self) -> None: + """Query + manipulate pattern from canvas_event_handler.py.""" + snippet = ( + "from browser import document\n" + "def setup_handlers() -> None:\n" + " if 'run-tests-button' in document:\n" + " btn = document['run-tests-button']\n" + " btn.bind('click', lambda ev: None)\n" + " el = document.getElementById('main-container')\n" + " if el is not None:\n" + " el.classList.add('active')\n" + " el.classList.remove('hidden')\n" + " has: bool = el.classList.contains('active')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_math_evaluation_pattern(self) -> None: + """Math evaluation pattern from expression_evaluator.py.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "def evaluate_expr(expr: str) -> str:\n" + " result: Any = window.math.evaluate(expr)\n" + " formatted: str = window.math.format(result)\n" + " type_name: str = window.math.typeOf(result)\n" + " return formatted\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_nerdamer_solve_pattern(self) -> None: + """Nerdamer symbolic math pattern from math_utils.py.""" + snippet = ( + "from browser import window\n" + "def solve_equation(eq: str) -> str:\n" + " expr = window.nerdamer(eq)\n" + " result = expr.evaluate()\n" + " text: str = result.text()\n" + " subbed = expr.sub('x', 5)\n" + " return text\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_async_data_fetch_pattern(self) -> None: + """Async fetch pattern from ai_interface.py.""" + snippet = ( + "from browser import aio, ajax\n" + "async def fetch_response() -> None:\n" + " await aio.sleep(0.1)\n" + " req = ajax.ajax()\n" + " req.bind('complete', lambda r: None)\n" + " req.open('GET', '/api/status')\n" + " req.send()\n" + "aio.run(fetch_response())\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_canvas_getcontext_pattern(self) -> None: + """Canvas getContext pattern from canvas2d_renderer.py.""" + snippet = ( + "from browser import html\n" + "from typing import Any\n" + "canvas = html.CANVAS(id='main-canvas')\n" + "canvas.attrs['width'] = '800'\n" + "canvas.attrs['height'] = '600'\n" + "ctx: Any = canvas.getContext('2d')\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_json_stringify_parse_round_trip(self) -> None: + """JSON round-trip pattern from linear_algebra_utils.py.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "data: dict[str, int] = {'x': 1, 'y': 2}\n" + "json_str: str = window.JSON.stringify(data)\n" + "parsed: Any = window.JSON.parse(json_str)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_performance_timing_pattern(self) -> None: + """Performance timing pattern from rendering code.""" + snippet = ( + "from browser import window\n" + "start: float = window.performance.now()\n" + "# ... do work ...\n" + "end: float = window.performance.now()\n" + "elapsed: float = end - start\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_window_function_registration_pattern(self) -> None: + """Dynamic window function registration pattern from main.py.""" + snippet = ( + "from browser import window\n" + "from typing import Any\n" + "def start_tests() -> dict[str, str]:\n" + " return {'status': 'started'}\n" + "def get_results() -> dict[str, Any]:\n" + " return {'tests_run': 0}\n" + "window.startMatHudTests = start_tests\n" + "window.getMatHudTestResults = get_results\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_document_exec_command(self) -> None: + """document.execCommand pattern from ai_interface.py.""" + snippet = "from browser import document\nresult: bool = document.execCommand('copy')\n" + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + def test_mypy_insert_before_pattern(self) -> None: + """insertBefore pattern from svg_primitive_adapter.py.""" + snippet = ( + "from browser import svg\n" + "surface = svg.svg()\n" + "group = svg.g()\n" + "ref = svg.g()\n" + "surface.insertBefore(group, ref)\n" + "surface.insertBefore(group, None)\n" + ) + stdout, stderr, exit_code = _run_mypy_snippet(snippet) + self.assertEqual(exit_code, 0, f"MyPy failed:\nstdout: {stdout}\nstderr: {stderr}") + + +# --------------------------------------------------------------------------- +# Tests: MyPy configuration integration +# --------------------------------------------------------------------------- + + +class TestMyPyConfiguration(unittest.TestCase): + """Verify mypy.ini is correctly configured for stub discovery.""" + + def test_mypy_ini_has_mypy_path(self) -> None: + """mypy.ini must include mypy_path pointing to the stubs directory.""" + mypy_ini_path = os.path.join(_ROOT, "mypy.ini") + with open(mypy_ini_path) as f: + content = f.read() + self.assertIn( + "mypy_path", + content, + "mypy.ini must contain 'mypy_path' setting", + ) + self.assertIn( + "static/client/typing", + content, + "mypy_path must reference 'static/client/typing'", + ) + + def test_mypy_ini_has_test_file(self) -> None: + """mypy.ini files list must include this test file.""" + mypy_ini_path = os.path.join(_ROOT, "mypy.ini") + with open(mypy_ini_path) as f: + content = f.read() + self.assertIn( + "server_tests/test_browser_typing_stubs.py", + content, + "mypy.ini files list must include test_browser_typing_stubs.py", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/server_tests/test_canvas_state_summarizer.py b/server_tests/test_canvas_state_summarizer.py index 2cdb55cb..101f8627 100644 --- a/server_tests/test_canvas_state_summarizer.py +++ b/server_tests/test_canvas_state_summarizer.py @@ -93,7 +93,7 @@ def test_large_scene_reduces_size(self) -> None: state["Segments"].append( { "name": f"s{i}", - "args": {"p1": f"P{i}", "p2": f"P{i+1}", "label": {"text": "", "visible": False}}, + "args": {"p1": f"P{i}", "p2": f"P{i + 1}", "label": {"text": "", "visible": False}}, "_p1_coords": [i, i], "_p2_coords": [i + 1, i + 1], } @@ -110,4 +110,3 @@ def test_large_scene_reduces_size(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/server_tests/test_cli/test_config.py b/server_tests/test_cli/test_config.py index a8a80da4..6e1e24c6 100644 --- a/server_tests/test_cli/test_config.py +++ b/server_tests/test_cli/test_config.py @@ -6,78 +6,90 @@ from pathlib import Path - class TestConfigConstants: """Test configuration constants.""" def test_project_root_exists(self) -> None: """PROJECT_ROOT should point to an existing directory.""" from cli.config import PROJECT_ROOT + assert PROJECT_ROOT.exists() assert PROJECT_ROOT.is_dir() def test_project_root_contains_app_py(self) -> None: """PROJECT_ROOT should contain app.py.""" from cli.config import PROJECT_ROOT + assert (PROJECT_ROOT / "app.py").exists() def test_default_port_is_valid(self) -> None: """DEFAULT_PORT should be a valid port number.""" from cli.config import DEFAULT_PORT + assert isinstance(DEFAULT_PORT, int) assert 1 <= DEFAULT_PORT <= 65535 def test_default_host_is_localhost(self) -> None: """DEFAULT_HOST should be localhost.""" from cli.config import DEFAULT_HOST + assert DEFAULT_HOST == "127.0.0.1" def test_pid_file_path(self) -> None: """PID_FILE should be in project root.""" from cli.config import PID_FILE, PROJECT_ROOT + assert PID_FILE.parent == PROJECT_ROOT assert PID_FILE.name == ".mathud_server.pid" def test_health_check_timeout_positive(self) -> None: """HEALTH_CHECK_TIMEOUT should be positive.""" from cli.config import HEALTH_CHECK_TIMEOUT + assert HEALTH_CHECK_TIMEOUT > 0 def test_health_check_retries_positive(self) -> None: """HEALTH_CHECK_RETRIES should be positive.""" from cli.config import HEALTH_CHECK_RETRIES + assert HEALTH_CHECK_RETRIES > 0 def test_browser_wait_timeout_positive(self) -> None: """BROWSER_WAIT_TIMEOUT should be positive.""" from cli.config import BROWSER_WAIT_TIMEOUT + assert BROWSER_WAIT_TIMEOUT > 0 def test_app_ready_timeout_positive(self) -> None: """APP_READY_TIMEOUT should be positive.""" from cli.config import APP_READY_TIMEOUT + assert APP_READY_TIMEOUT > 0 def test_client_test_timeout_positive(self) -> None: """CLIENT_TEST_TIMEOUT should be positive.""" from cli.config import CLIENT_TEST_TIMEOUT + assert CLIENT_TEST_TIMEOUT > 0 def test_viewport_dimensions_positive(self) -> None: """Viewport dimensions should be positive.""" from cli.config import DEFAULT_VIEWPORT_WIDTH, DEFAULT_VIEWPORT_HEIGHT + assert DEFAULT_VIEWPORT_WIDTH > 0 assert DEFAULT_VIEWPORT_HEIGHT > 0 def test_cli_output_dir_exists(self) -> None: """CLI_OUTPUT_DIR should exist.""" from cli.config import CLI_OUTPUT_DIR + assert CLI_OUTPUT_DIR.exists() assert CLI_OUTPUT_DIR.is_dir() def test_cli_output_dir_is_in_cli_module(self) -> None: """CLI_OUTPUT_DIR should be inside cli/ directory.""" from cli.config import CLI_OUTPUT_DIR + assert CLI_OUTPUT_DIR.name == "output" assert CLI_OUTPUT_DIR.parent.name == "cli" @@ -88,24 +100,28 @@ class TestGetPythonPath: def test_returns_path_object(self) -> None: """get_python_path should return a Path.""" from cli.config import get_python_path + result = get_python_path() assert isinstance(result, Path) def test_path_ends_with_python(self) -> None: """Path should end with python executable.""" from cli.config import get_python_path + result = get_python_path() assert "python" in result.name.lower() def test_path_is_in_venv(self) -> None: """Path should be in venv directory.""" from cli.config import get_python_path + result = get_python_path() assert "venv" in str(result) def test_windows_path_structure(self) -> None: """On Windows, path should use Scripts directory.""" from cli.config import get_python_path + if os.name == "nt": result = get_python_path() assert "Scripts" in str(result) @@ -113,6 +129,7 @@ def test_windows_path_structure(self) -> None: def test_unix_path_structure(self) -> None: """On Unix, path should use bin directory.""" from cli.config import get_python_path + if os.name != "nt": result = get_python_path() assert "bin" in str(result) @@ -124,11 +141,13 @@ class TestAppModule: def test_app_module_exists(self) -> None: """APP_MODULE should point to existing file.""" from cli.config import APP_MODULE + assert APP_MODULE.exists() assert APP_MODULE.is_file() def test_app_module_is_python(self) -> None: """APP_MODULE should be a Python file.""" from cli.config import APP_MODULE + assert APP_MODULE.suffix == ".py" assert APP_MODULE.name == "app.py" diff --git a/server_tests/test_cli/test_server.py b/server_tests/test_cli/test_server.py index e9b2f321..2d149efd 100644 --- a/server_tests/test_cli/test_server.py +++ b/server_tests/test_cli/test_server.py @@ -48,6 +48,7 @@ def test_server_running_returns_true(self) -> None: def test_server_not_running_returns_false(self) -> None: """is_server_running returns False when request fails.""" import requests as req_module + manager = ServerManager() with patch("cli.server.requests.get", side_effect=req_module.RequestException("Connection refused")): @@ -119,6 +120,7 @@ def test_legacy_pid_file_process_cmdline_unreadable_is_cleaned_up(self) -> None: mock_pid_file.read_text.return_value = "12345" import psutil + with patch("cli.server.PID_FILE", mock_pid_file): with patch("cli.server.psutil.pid_exists", return_value=True): with patch("cli.server.psutil.Process", side_effect=psutil.AccessDenied(pid=12345)): @@ -132,9 +134,7 @@ def test_valid_pid_file_json_with_matching_create_time_returns_pid(self) -> None mock_pid_file = MagicMock(spec=Path) mock_pid_file.exists.return_value = True - mock_pid_file.read_text.return_value = json.dumps( - {"pid": 12345, "create_time": 1000.0, "port": 5000} - ) + mock_pid_file.read_text.return_value = json.dumps({"pid": 12345, "create_time": 1000.0, "port": 5000}) mock_process = MagicMock() mock_process.create_time.return_value = 1000.2 @@ -164,9 +164,7 @@ def test_pid_file_with_reused_pid_is_cleaned_up(self) -> None: mock_pid_file = MagicMock(spec=Path) mock_pid_file.exists.return_value = True - mock_pid_file.read_text.return_value = json.dumps( - {"pid": 12345, "create_time": 1000.0, "port": 5000} - ) + mock_pid_file.read_text.return_value = json.dumps({"pid": 12345, "create_time": 1000.0, "port": 5000}) mock_process = MagicMock() mock_process.create_time.return_value = 2000.0 @@ -184,9 +182,7 @@ def test_pid_file_wrong_port_is_ignored(self) -> None: mock_pid_file = MagicMock(spec=Path) mock_pid_file.exists.return_value = True - mock_pid_file.read_text.return_value = json.dumps( - {"pid": 12345, "create_time": 1000.0, "port": 5001} - ) + mock_pid_file.read_text.return_value = json.dumps({"pid": 12345, "create_time": 1000.0, "port": 5001}) with patch("cli.server.PID_FILE", mock_pid_file): assert manager.get_pid() is None diff --git a/server_tests/test_coordinate_mapper.py b/server_tests/test_coordinate_mapper.py index 6cd19bbb..63a6ac52 100644 --- a/server_tests/test_coordinate_mapper.py +++ b/server_tests/test_coordinate_mapper.py @@ -24,7 +24,6 @@ class TestCoordinateMapper(unittest.TestCase): - def setUp(self) -> None: """Set up test fixtures with standard canvas size.""" self.canvas_width = 800 @@ -79,13 +78,7 @@ def test_screen_to_math_basic(self) -> None: def test_coordinate_conversion_roundtrip(self) -> None: """Test that math to screen to math conversion preserves values.""" - test_cases = [ - (0, 0), - (100, 50), - (-100, -50), - (3.14159, -2.71828), - (1000, -500) - ] + test_cases = [(0, 0), (100, 50), (-100, -50), (3.14159, -2.71828), (1000, -500)] for orig_x, orig_y in test_cases: with self.subTest(x=orig_x, y=orig_y): @@ -220,10 +213,10 @@ def test_get_visible_bounds(self) -> None: # With default settings, bounds should cover canvas converted to math coords # Canvas corners: (0,0) and (800,600) # Math corners: (-400,300) and (400,-300) - self.assertAlmostEqual(bounds['left'], -400, places=1) - self.assertAlmostEqual(bounds['right'], 400, places=1) - self.assertAlmostEqual(bounds['top'], 300, places=1) - self.assertAlmostEqual(bounds['bottom'], -300, places=1) + self.assertAlmostEqual(bounds["left"], -400, places=1) + self.assertAlmostEqual(bounds["right"], 400, places=1) + self.assertAlmostEqual(bounds["top"], 300, places=1) + self.assertAlmostEqual(bounds["bottom"], -300, places=1) def test_get_visible_bounds_with_zoom(self) -> None: """Test visible bounds calculation with zoom.""" @@ -232,10 +225,10 @@ def test_get_visible_bounds_with_zoom(self) -> None: bounds = self.mapper.get_visible_bounds() # Bounds should be halved - self.assertAlmostEqual(bounds['left'], -200, places=1) - self.assertAlmostEqual(bounds['right'], 200, places=1) - self.assertAlmostEqual(bounds['top'], 150, places=1) - self.assertAlmostEqual(bounds['bottom'], -150, places=1) + self.assertAlmostEqual(bounds["left"], -200, places=1) + self.assertAlmostEqual(bounds["right"], 200, places=1) + self.assertAlmostEqual(bounds["top"], 150, places=1) + self.assertAlmostEqual(bounds["bottom"], -150, places=1) def test_get_visible_width_height(self) -> None: """Test visible width and height calculation.""" @@ -284,12 +277,12 @@ def test_is_point_visible(self) -> None: """Test screen point visibility checking.""" # Points within canvas should be visible self.assertTrue(self.mapper.is_point_visible(400, 300)) # center - self.assertTrue(self.mapper.is_point_visible(0, 0)) # top-left + self.assertTrue(self.mapper.is_point_visible(0, 0)) # top-left self.assertTrue(self.mapper.is_point_visible(799, 599)) # bottom-right # Points outside canvas should not be visible - self.assertFalse(self.mapper.is_point_visible(-1, 300)) # left edge - self.assertFalse(self.mapper.is_point_visible(400, -1)) # top edge + self.assertFalse(self.mapper.is_point_visible(-1, 300)) # left edge + self.assertFalse(self.mapper.is_point_visible(400, -1)) # top edge self.assertFalse(self.mapper.is_point_visible(801, 300)) # right edge self.assertFalse(self.mapper.is_point_visible(400, 601)) # bottom edge @@ -353,7 +346,7 @@ def test_get_zoom_towards_point_displacement(self) -> None: # When zooming in, objects should move AWAY from zoom point to maintain relative position # Target (400,300) is left/below zoom point (500,200), so displacement should be left/up - self.assertLess(displacement.x, 0) # Move left (away from zoom point) + self.assertLess(displacement.x, 0) # Move left (away from zoom point) self.assertGreater(displacement.y, 0) # Move up (away from zoom point) def test_state_management(self) -> None: @@ -367,8 +360,16 @@ def test_state_management(self) -> None: state = self.mapper.get_state() # Verify state contains expected keys - expected_keys = ['canvas_width', 'canvas_height', 'scale_factor', 'offset', - 'origin', 'zoom_point', 'zoom_direction', 'zoom_step'] + expected_keys = [ + "canvas_width", + "canvas_height", + "scale_factor", + "offset", + "origin", + "zoom_point", + "zoom_direction", + "zoom_step", + ] for key in expected_keys: self.assertIn(key, state) @@ -393,10 +394,10 @@ def test_individual_boundary_methods(self) -> None: # Should match get_visible_bounds() results bounds = self.mapper.get_visible_bounds() - self.assertAlmostEqual(left, bounds['left'], places=6) - self.assertAlmostEqual(right, bounds['right'], places=6) - self.assertAlmostEqual(top, bounds['top'], places=6) - self.assertAlmostEqual(bottom, bounds['bottom'], places=6) + self.assertAlmostEqual(left, bounds["left"], places=6) + self.assertAlmostEqual(right, bounds["right"], places=6) + self.assertAlmostEqual(top, bounds["top"], places=6) + self.assertAlmostEqual(bottom, bounds["bottom"], places=6) # Test with transformations self.mapper.apply_zoom(2.0) @@ -448,12 +449,7 @@ def test_legacy_pattern_support_methods(self) -> None: def test_legacy_methods_consistency(self) -> None: """Test that legacy methods are consistent with core methods.""" - test_cases = [ - (0, 0), - (100, -50), - (-75, 125), - (3.14159, -2.71828) - ] + test_cases = [(0, 0), (100, -50), (-75, 125), (3.14159, -2.71828)] for math_x, math_y in test_cases: with self.subTest(x=math_x, y=math_y): @@ -473,6 +469,7 @@ def test_legacy_methods_consistency(self) -> None: def test_sync_from_canvas_mock(self) -> None: """Test synchronization with a mock Canvas object.""" + # Create a mock canvas object with coordinate properties class MockCanvas: def __init__(self) -> None: @@ -512,6 +509,7 @@ def __init__(self) -> None: def test_sync_from_canvas_minimal(self) -> None: """Test sync with minimal Canvas object (missing some properties).""" + class MinimalCanvas: def __init__(self) -> None: self.width = 600 @@ -537,6 +535,7 @@ def __init__(self) -> None: def test_from_canvas_factory_method(self) -> None: """Test factory method to create CoordinateMapper from Canvas.""" + # Create mock canvas class MockCanvas: def __init__(self) -> None: @@ -577,13 +576,13 @@ def test_coordinate_conversion_with_offset(self) -> None: # Screen center should map to different math coordinates math_x, math_y = self.mapper.screen_to_math(400, 300) - self.assertAlmostEqual(math_x, -50, places=6) # Shifted by offset - self.assertAlmostEqual(math_y, -30, places=6) # Shifted by offset (negative) + self.assertAlmostEqual(math_x, -50, places=6) # Shifted by offset + self.assertAlmostEqual(math_y, -30, places=6) # Shifted by offset (negative) # Test individual conversion methods with offset canvas_x = 400 # Canvas center x math_x = self.mapper.convert_canvas_x_to_math(canvas_x) - self.assertAlmostEqual(math_x, -50, places=6) # (400 - 50 - 400) / 1.0 + self.assertAlmostEqual(math_x, -50, places=6) # (400 - 50 - 400) / 1.0 math_y = 0 canvas_y = self.mapper.convert_math_y_to_canvas(math_y) @@ -601,7 +600,7 @@ def test_from_canvas_with_simple_mock_full_featured(self) -> None: cartesian2axis=cartesian_mock, zoom_point=Position(700, 350), zoom_direction=-1, - zoom_step=0.15 + zoom_step=0.15, ) # Create CoordinateMapper using from_canvas factory method @@ -613,7 +612,7 @@ def test_from_canvas_with_simple_mock_full_featured(self) -> None: self.assertEqual(mapper.scale_factor, 2.0) self.assertEqual(mapper.offset.x, 100) self.assertEqual(mapper.offset.y, -50) - self.assertEqual(mapper.origin.x, 600) # Falls back to width/2 without center + self.assertEqual(mapper.origin.x, 600) # Falls back to width/2 without center self.assertEqual(mapper.origin.y, 400) self.assertEqual(mapper.zoom_point.x, 700) self.assertEqual(mapper.zoom_point.y, 350) @@ -623,11 +622,7 @@ def test_from_canvas_with_simple_mock_full_featured(self) -> None: def test_from_canvas_with_simple_mock_minimal(self) -> None: """Test from_canvas with minimal SimpleMock canvas (missing optional properties).""" # Create minimal mock canvas with only required properties - canvas_mock = SimpleMock( - width=800, - height=600, - scale_factor=1.5 - ) + canvas_mock = SimpleMock(width=800, height=600, scale_factor=1.5) # Should handle missing properties gracefully mapper = CoordinateMapper.from_canvas(canvas_mock) @@ -640,8 +635,8 @@ def test_from_canvas_with_simple_mock_minimal(self) -> None: # Should use defaults for missing properties self.assertEqual(mapper.offset.x, 0) self.assertEqual(mapper.offset.y, 0) - self.assertEqual(mapper.origin.x, 400) # width / 2 - self.assertEqual(mapper.origin.y, 300) # height / 2 + self.assertEqual(mapper.origin.x, 400) # width / 2 + self.assertEqual(mapper.origin.y, 300) # height / 2 self.assertEqual(mapper.zoom_direction, 0) self.assertEqual(mapper.zoom_step, 0.1) @@ -654,7 +649,7 @@ def test_from_canvas_with_center_instead_of_cartesian(self) -> None: scale_factor=1.2, offset=Position(30, 20), center=Position(520, 380), # Using center instead of cartesian2axis - zoom_point=Position(500, 400) + zoom_point=Position(500, 400), ) mapper = CoordinateMapper.from_canvas(canvas_mock) @@ -682,7 +677,7 @@ def test_sync_from_canvas_with_simple_mock_updates_existing(self) -> None: offset=Position(-75, 40), cartesian2axis=cartesian_mock, zoom_direction=1, - zoom_step=0.2 + zoom_step=0.2, ) # Sync with canvas @@ -741,7 +736,7 @@ def test_sync_from_canvas_partial_properties(self) -> None: canvas_mock = SimpleMock( width=900, height=700, - scale_factor=2.5 + scale_factor=2.5, # Missing: offset, cartesian2axis, center, zoom properties ) @@ -767,11 +762,7 @@ def test_sync_from_canvas_coordinate_transformations_work(self) -> None: # cartesian2axis.origin should be ignored when center is absent. cartesian_mock = SimpleMock(origin=Position(600, 400)) canvas_mock = SimpleMock( - width=1000, - height=600, - scale_factor=2.0, - offset=Position(100, -50), - cartesian2axis=cartesian_mock + width=1000, height=600, scale_factor=2.0, offset=Position(100, -50), cartesian2axis=cartesian_mock ) mapper = CoordinateMapper(800, 600) # Different initial dimensions @@ -798,7 +789,7 @@ def test_sync_from_canvas_visible_bounds_consistency(self) -> None: height=600, scale_factor=1.5, offset=Position(50, 25), - center=Position(400, 300) # Using center instead of cartesian2axis + center=Position(400, 300), # Using center instead of cartesian2axis ) mapper = CoordinateMapper(1000, 800) # Different initial size @@ -812,23 +803,23 @@ def test_sync_from_canvas_visible_bounds_consistency(self) -> None: bottom = mapper.get_visible_bottom_bound() # Individual methods should match bounds dictionary - self.assertAlmostEqual(left, bounds['left'], places=6) - self.assertAlmostEqual(right, bounds['right'], places=6) - self.assertAlmostEqual(top, bounds['top'], places=6) - self.assertAlmostEqual(bottom, bounds['bottom'], places=6) + self.assertAlmostEqual(left, bounds["left"], places=6) + self.assertAlmostEqual(right, bounds["right"], places=6) + self.assertAlmostEqual(top, bounds["top"], places=6) + self.assertAlmostEqual(bottom, bounds["bottom"], places=6) # Width and height should be consistent width = mapper.get_visible_width() height = mapper.get_visible_height() - self.assertAlmostEqual(width, bounds['right'] - bounds['left'], places=6) - self.assertAlmostEqual(height, bounds['top'] - bounds['bottom'], places=6) + self.assertAlmostEqual(width, bounds["right"] - bounds["left"], places=6) + self.assertAlmostEqual(height, bounds["top"] - bounds["bottom"], places=6) def test_canvas_mock_attribute_error_handling(self) -> None: """Test handling of missing attributes in canvas mock gracefully.""" # Mock with minimal attributes canvas_mock = SimpleMock( width=640, - height=480 + height=480, # Missing scale_factor and other properties ) @@ -856,7 +847,7 @@ def test_factory_method_vs_sync_consistency(self) -> None: cartesian2axis=cartesian_mock, zoom_point=Position(600, 400), zoom_direction=-1, - zoom_step=0.12 + zoom_step=0.12, ) # Method 1: Use factory method @@ -888,5 +879,5 @@ def test_factory_method_vs_sync_consistency(self) -> None: self.assertAlmostEqual(screen1_y, screen2_y, places=6) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_coordinate_system_manager.py b/server_tests/test_coordinate_system_manager.py index e633fce2..e4475703 100644 --- a/server_tests/test_coordinate_system_manager.py +++ b/server_tests/test_coordinate_system_manager.py @@ -24,6 +24,7 @@ class Position: """Simple Position class for testing.""" + def __init__(self, x: float = 0, y: float = 0): self.x = x self.y = y @@ -31,6 +32,7 @@ def __init__(self, x: float = 0, y: float = 0): class MockPolarGrid: """Mock PolarGrid for testing.""" + def __init__(self, coordinate_mapper: Any): self.coordinate_mapper = coordinate_mapper self.class_name = "PolarGrid" @@ -435,5 +437,5 @@ def test_mode_switch_preserves_grid_instances(self) -> None: self.assertEqual(id(manager.polar_grid), polar_id) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_function_renderable_paths.py b/server_tests/test_function_renderable_paths.py index af7f976e..cb41a0fa 100644 --- a/server_tests/test_function_renderable_paths.py +++ b/server_tests/test_function_renderable_paths.py @@ -8,8 +8,8 @@ import unittest from typing import List, Tuple, Optional, Callable -sys.path.insert(0, 'static/client') -sys.path.insert(0, 'static/client/rendering/renderables') +sys.path.insert(0, "static/client") +sys.path.insert(0, "static/client/rendering/renderables") class MockFunction: @@ -134,7 +134,7 @@ def setUp(self): def one_over_x(x): if x == 0: - return float('inf') + return float("inf") return 1.0 / x self.func = MockFunction( @@ -208,6 +208,5 @@ def clamp_screen_y(sy: float, height: float) -> float: return sy -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - diff --git a/server_tests/test_image_attachment.py b/server_tests/test_image_attachment.py index 71db8c2d..887bbd8f 100644 --- a/server_tests/test_image_attachment.py +++ b/server_tests/test_image_attachment.py @@ -52,21 +52,13 @@ def test_extract_renderer_mode(self) -> None: def test_extract_canvas_image_from_vision_snapshot(self) -> None: """Test extracting canvas_image from nested vision_snapshot.""" - payload = { - "vision_snapshot": { - "canvas_image": TEST_IMAGE_DATA_URL - } - } + payload = {"vision_snapshot": {"canvas_image": TEST_IMAGE_DATA_URL}} _, canvas_image, _, _ = extract_vision_payload(payload) self.assertEqual(canvas_image, TEST_IMAGE_DATA_URL) def test_extract_svg_state_from_vision_snapshot(self) -> None: """Test extracting svg_state from nested vision_snapshot.""" - payload = { - "vision_snapshot": { - "svg_state": {"elements": ["rect", "circle"]} - } - } + payload = {"vision_snapshot": {"svg_state": {"elements": ["rect", "circle"]}}} svg_state, _, _, _ = extract_vision_payload(payload) self.assertEqual(svg_state, {"elements": ["rect", "circle"]}) @@ -75,10 +67,7 @@ def test_vision_snapshot_overrides_top_level(self) -> None: payload = { "svg_state": {"elements": ["old"]}, "renderer_mode": "canvas2d", - "vision_snapshot": { - "svg_state": {"elements": ["new"]}, - "renderer_mode": "webgl" - } + "vision_snapshot": {"svg_state": {"elements": ["new"]}, "renderer_mode": "webgl"}, } svg_state, _, renderer_mode, _ = extract_vision_payload(payload) self.assertEqual(svg_state, {"elements": ["new"]}) @@ -96,12 +85,7 @@ def test_invalid_types_ignored(self) -> None: def test_extract_attached_images(self) -> None: """Test extracting attached_images from payload.""" - payload = { - "attached_images": [ - "data:image/png;base64,img1", - "data:image/jpeg;base64,img2" - ] - } + payload = {"attached_images": ["data:image/png;base64,img1", "data:image/jpeg;base64,img2"]} _, _, _, attached_images = extract_vision_payload(payload) self.assertEqual(len(attached_images), 2) self.assertIn("data:image/png;base64,img1", attached_images) @@ -109,31 +93,20 @@ def test_extract_attached_images(self) -> None: def test_extract_attached_images_filters_non_strings(self) -> None: """Test that non-string items are filtered from attached_images.""" - payload = { - "attached_images": [ - "data:image/png;base64,valid", - 123, - None, - {"invalid": "object"} - ] - } + payload = {"attached_images": ["data:image/png;base64,valid", 123, None, {"invalid": "object"}]} _, _, _, attached_images = extract_vision_payload(payload) self.assertEqual(len(attached_images), 1) self.assertEqual(attached_images[0], "data:image/png;base64,valid") def test_extract_attached_images_empty_list(self) -> None: """Test extracting empty attached_images list.""" - payload = { - "attached_images": [] - } + payload = {"attached_images": []} _, _, _, attached_images = extract_vision_payload(payload) self.assertEqual(attached_images, []) def test_extract_attached_images_invalid_type(self) -> None: """Test that non-list attached_images is ignored.""" - payload = { - "attached_images": "not a list" - } + payload = {"attached_images": "not a list"} _, _, _, attached_images = extract_vision_payload(payload) self.assertIsNone(attached_images) @@ -143,24 +116,24 @@ class TestPrepareMessageContent(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_plain_text_prompt_returns_original(self, mock_openai: Mock) -> None: """Test plain text prompt returns unchanged.""" api = OpenAIAPIBase() result = api._prepare_message_content("plain text message") self.assertEqual(result, "plain text message") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_no_vision_no_images_returns_original(self, mock_openai: Mock) -> None: """Test JSON prompt with no vision and no images returns original.""" api = OpenAIAPIBase() @@ -168,15 +141,13 @@ def test_no_vision_no_images_returns_original(self, mock_openai: Mock) -> None: result = api._prepare_message_content(prompt) self.assertEqual(result, prompt) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_attached_images_without_vision(self, mock_openai: Mock) -> None: """Test that attached images are processed even without vision toggle.""" api = OpenAIAPIBase() - prompt = json.dumps({ - "user_message": "What is this image?", - "use_vision": False, - "attached_images": [TEST_IMAGE_DATA_URL] - }) + prompt = json.dumps( + {"user_message": "What is this image?", "use_vision": False, "attached_images": [TEST_IMAGE_DATA_URL]} + ) result = api._prepare_message_content(prompt) # Should return a list (enhanced prompt) even without vision @@ -187,20 +158,12 @@ def test_attached_images_without_vision(self, mock_openai: Mock) -> None: self.assertEqual(result[1]["type"], "image_url") self.assertEqual(result[1]["image_url"]["url"], TEST_IMAGE_DATA_URL) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_multiple_attached_images(self, mock_openai: Mock) -> None: """Test multiple attached images are all included.""" api = OpenAIAPIBase() - images = [ - "data:image/png;base64,img1", - "data:image/jpeg;base64,img2", - "data:image/png;base64,img3" - ] - prompt = json.dumps({ - "user_message": "Compare these images", - "use_vision": False, - "attached_images": images - }) + images = ["data:image/png;base64,img1", "data:image/jpeg;base64,img2", "data:image/png;base64,img3"] + prompt = json.dumps({"user_message": "Compare these images", "use_vision": False, "attached_images": images}) result = api._prepare_message_content(prompt) self.assertIsInstance(result, list) @@ -212,20 +175,22 @@ def test_multiple_attached_images(self, mock_openai: Mock) -> None: for img in images: self.assertIn(img, image_urls) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_invalid_attached_images_filtered(self, mock_openai: Mock) -> None: """Test that invalid image data URLs are filtered out.""" api = OpenAIAPIBase() - prompt = json.dumps({ - "user_message": "test", - "use_vision": False, - "attached_images": [ - "data:image/png;base64,valid", - "not-a-data-url", - 123, # Not a string - "data:text/plain;base64,notimage", # Wrong MIME type - ] - }) + prompt = json.dumps( + { + "user_message": "test", + "use_vision": False, + "attached_images": [ + "data:image/png;base64,valid", + "not-a-data-url", + 123, # Not a string + "data:text/plain;base64,notimage", # Wrong MIME type + ], + } + ) result = api._prepare_message_content(prompt) self.assertIsInstance(result, list) @@ -234,15 +199,11 @@ def test_invalid_attached_images_filtered(self, mock_openai: Mock) -> None: self.assertEqual(len(image_parts), 1) self.assertEqual(image_parts[0]["image_url"]["url"], "data:image/png;base64,valid") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_empty_attached_images_returns_original(self, mock_openai: Mock) -> None: """Test empty attached_images array without vision returns original.""" api = OpenAIAPIBase() - prompt = json.dumps({ - "user_message": "test", - "use_vision": False, - "attached_images": [] - }) + prompt = json.dumps({"user_message": "test", "use_vision": False, "attached_images": []}) result = api._prepare_message_content(prompt) # Empty images array with no vision should return original self.assertEqual(result, prompt) @@ -253,25 +214,23 @@ class TestCreateEnhancedPromptWithImage(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_attached_images_only(self, mock_openai: Mock) -> None: """Test creating enhanced prompt with only attached images.""" api = OpenAIAPIBase() images = [TEST_IMAGE_DATA_URL] result = api._create_enhanced_prompt_with_image( - user_message="Describe this", - attached_images=images, - include_canvas_snapshot=False + user_message="Describe this", attached_images=images, include_canvas_snapshot=False ) self.assertIsNotNone(result) @@ -280,25 +239,21 @@ def test_attached_images_only(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["text"], "Describe this") self.assertEqual(result[1]["type"], "image_url") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_no_images_returns_none(self, mock_openai: Mock) -> None: """Test returns None when no images available.""" api = OpenAIAPIBase() result = api._create_enhanced_prompt_with_image( - user_message="Hello", - attached_images=None, - include_canvas_snapshot=False + user_message="Hello", attached_images=None, include_canvas_snapshot=False ) self.assertIsNone(result) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_empty_images_returns_none(self, mock_openai: Mock) -> None: """Test returns None with empty images list.""" api = OpenAIAPIBase() result = api._create_enhanced_prompt_with_image( - user_message="Hello", - attached_images=[], - include_canvas_snapshot=False + user_message="Hello", attached_images=[], include_canvas_snapshot=False ) self.assertIsNone(result) @@ -308,24 +263,24 @@ class TestConvertContentForResponsesAPI(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_string_content_unchanged(self, mock_openai: Mock) -> None: """Test string content passes through unchanged.""" api = OpenAIResponsesAPI() result = api._convert_content_for_responses_api("plain text") self.assertEqual(result, "plain text") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_text_type(self, mock_openai: Mock) -> None: """Test 'text' type converted to 'input_text'.""" api = OpenAIResponsesAPI() @@ -336,21 +291,18 @@ def test_convert_text_type(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["type"], "input_text") self.assertEqual(result[0]["text"], "Hello") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_image_url_type(self, mock_openai: Mock) -> None: """Test 'image_url' type converted to 'input_image'.""" api = OpenAIResponsesAPI() - content = [{ - "type": "image_url", - "image_url": {"url": TEST_IMAGE_DATA_URL} - }] + content = [{"type": "image_url", "image_url": {"url": TEST_IMAGE_DATA_URL}}] result = api._convert_content_for_responses_api(content) self.assertEqual(len(result), 1) self.assertEqual(result[0]["type"], "input_image") self.assertEqual(result[0]["image_url"], TEST_IMAGE_DATA_URL) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_mixed_content(self, mock_openai: Mock) -> None: """Test converting mixed text and image content.""" api = OpenAIResponsesAPI() @@ -366,7 +318,7 @@ def test_convert_mixed_content(self, mock_openai: Mock) -> None: self.assertEqual(result[1], {"type": "input_image", "image_url": "data:image/png;base64,abc"}) self.assertEqual(result[2], {"type": "input_text", "text": "What do you see?"}) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_unknown_type_preserved(self, mock_openai: Mock) -> None: """Test unknown content types are preserved.""" api = OpenAIResponsesAPI() @@ -375,7 +327,7 @@ def test_unknown_type_preserved(self, mock_openai: Mock) -> None: self.assertEqual(result, content) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_non_dict_items_preserved(self, mock_openai: Mock) -> None: """Test non-dict items in list are preserved.""" api = OpenAIResponsesAPI() @@ -390,41 +342,40 @@ class TestRemoveImagesFromUserMessages(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_removes_image_content(self, mock_openai: Mock) -> None: """Test images are removed from user messages.""" api = OpenAIAPIBase() - api.messages.append({ - "role": "user", - "content": [ - {"type": "text", "text": "test message"}, - {"type": "image_url", "image_url": {"url": TEST_IMAGE_DATA_URL}} - ] - }) + api.messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": "test message"}, + {"type": "image_url", "image_url": {"url": TEST_IMAGE_DATA_URL}}, + ], + } + ) api._remove_images_from_user_messages() # Content should now be just the text string self.assertEqual(api.messages[-1]["content"], "test message") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_preserves_string_content(self, mock_openai: Mock) -> None: """Test string content is preserved unchanged.""" api = OpenAIAPIBase() - api.messages.append({ - "role": "user", - "content": "simple text" - }) + api.messages.append({"role": "user", "content": "simple text"}) api._remove_images_from_user_messages() @@ -436,23 +387,23 @@ class TestPreviousResponseId(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_initial_previous_response_id_is_none(self, mock_openai: Mock) -> None: """Test that _previous_response_id starts as None.""" api = OpenAIResponsesAPI() self.assertIsNone(api._previous_response_id) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_reset_conversation_clears_previous_response_id(self, mock_openai: Mock) -> None: """Test that reset_conversation clears the previous_response_id.""" api = OpenAIResponsesAPI() @@ -462,7 +413,7 @@ def test_reset_conversation_clears_previous_response_id(self, mock_openai: Mock) self.assertIsNone(api._previous_response_id) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_is_regular_message_turn_with_user_message(self, mock_openai: Mock) -> None: """Test _is_regular_message_turn returns True for user messages.""" api = OpenAIResponsesAPI() @@ -470,19 +421,15 @@ def test_is_regular_message_turn_with_user_message(self, mock_openai: Mock) -> N self.assertTrue(api._is_regular_message_turn()) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_is_regular_message_turn_with_tool_result(self, mock_openai: Mock) -> None: """Test _is_regular_message_turn returns False for tool results.""" api = OpenAIResponsesAPI() - api.messages.append({ - "role": "tool", - "tool_call_id": "call_123", - "content": "result" - }) + api.messages.append({"role": "tool", "tool_call_id": "call_123", "content": "result"}) self.assertFalse(api._is_regular_message_turn()) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_is_regular_message_turn_empty_messages(self, mock_openai: Mock) -> None: """Test _is_regular_message_turn returns False for empty messages.""" api = OpenAIResponsesAPI() @@ -490,7 +437,7 @@ def test_is_regular_message_turn_empty_messages(self, mock_openai: Mock) -> None self.assertFalse(api._is_regular_message_turn()) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_get_latest_user_message_for_input(self, mock_openai: Mock) -> None: """Test _get_latest_user_message_for_input extracts last user message.""" api = OpenAIResponsesAPI() @@ -505,17 +452,19 @@ def test_get_latest_user_message_for_input(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["role"], "user") self.assertEqual(result[0]["content"], "second question") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_get_latest_user_message_converts_content(self, mock_openai: Mock) -> None: """Test _get_latest_user_message_for_input converts content format.""" api = OpenAIResponsesAPI() - api.messages.append({ - "role": "user", - "content": [ - {"type": "text", "text": "Look at this"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}} - ] - }) + api.messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": "Look at this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ], + } + ) result = api._get_latest_user_message_for_input() @@ -525,7 +474,7 @@ def test_get_latest_user_message_converts_content(self, mock_openai: Mock) -> No self.assertEqual(content[0]["type"], "input_text") self.assertEqual(content[1]["type"], "input_image") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_get_latest_user_message_empty_returns_empty(self, mock_openai: Mock) -> None: """Test _get_latest_user_message_for_input returns empty for no messages.""" api = OpenAIResponsesAPI() @@ -536,5 +485,5 @@ def test_get_latest_user_message_empty_returns_empty(self, mock_openai: Mock) -> self.assertEqual(result, []) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_local_provider_base.py b/server_tests/test_local_provider_base.py index d0399adb..aa48f5ea 100644 --- a/server_tests/test_local_provider_base.py +++ b/server_tests/test_local_provider_base.py @@ -5,7 +5,6 @@ tool capability detection and model name normalization. """ - from static.providers.local import ( TOOL_CAPABLE_MODEL_FAMILIES, LocalProviderRegistry, @@ -165,6 +164,7 @@ def test_get_registered_providers_empty_initially(self) -> None: def test_get_provider_class_ollama(self) -> None: """Can retrieve Ollama provider class.""" from static.providers.local.ollama_api import OllamaAPI + provider_class = LocalProviderRegistry.get_provider_class("ollama") assert provider_class is OllamaAPI diff --git a/server_tests/test_markdown_parser.py b/server_tests/test_markdown_parser.py index 0b33d747..cd428f6e 100644 --- a/server_tests/test_markdown_parser.py +++ b/server_tests/test_markdown_parser.py @@ -257,7 +257,7 @@ def test_multiple_math_expressions(self) -> None: math_block_count = result.count('class="math-block"') self.assertEqual(math_inline_count, 2) # Two inline expressions - self.assertEqual(math_block_count, 1) # One block expression + self.assertEqual(math_block_count, 1) # One block expression def test_table_with_inline_formatting(self) -> None: """Test tables with inline formatting in cells.""" @@ -468,7 +468,7 @@ def test_mathematical_expressions_not_tables(self) -> None: "Vector notation: |v| represents the magnitude of vector v.", "Set builder notation: {x ∈ ℝ | x² > 4}.", "Conditional probability: P(A | B) = P(A ∩ B) / P(B).", - "Matrix determinant: |A| for matrix A." + "Matrix determinant: |A| for matrix A.", ] for example in math_examples: @@ -509,9 +509,9 @@ def test_table_alignment(self) -> None: result = self.parser.parse(aligned_table) self.assertIn("", result) - self.assertIn('text-align: left;', result) - self.assertIn('text-align: center;', result) - self.assertIn('text-align: right;', result) + self.assertIn("text-align: left;", result) + self.assertIn("text-align: center;", result) + self.assertIn("text-align: right;", result) def test_table_without_separator_not_table(self) -> None: """Test that lines with pipes but no separator row are NOT tables.""" @@ -523,7 +523,7 @@ def test_table_without_separator_not_table(self) -> None: "Just text | with pipes | scattered around", "| Start pipe only", "End pipe only |", - "Middle | pipe | here" + "Middle | pipe | here", ] for example in not_tables: @@ -537,15 +537,12 @@ def test_malformed_table_separators(self) -> None: """| Header | | Not a separator | | Data |""", # Second line is not a valid separator - """| Header | |====| | Data |""", # Equals instead of hyphens - """| Header | | abc | | Data |""", # Letters instead of hyphens - """| Header | | | | Data |""", # Empty separator @@ -562,19 +559,15 @@ def test_valid_table_separators(self) -> None: """| Header | |--------| | Data |""", # Basic separator - """| Left | Right | |:-----|------:| | A | B |""", # With alignment - """| A | B | C | |---|:-:|--:| | 1 | 2 | 3 |""", # Mixed alignment - """| Header | | --- | | Data |""", # Minimal separator - """|Header| |---| |Data|""", # No spaces around pipes @@ -652,16 +645,13 @@ def test_lines_not_starting_with_pipe_not_tables(self) -> None: "Math expression |x| in sentence", "Set notation {x | x > 0} explained", "Conditional probability P(A | B) formula", - # Even with what looks like separator lines after """Text with | pipes | in middle |-----|-----| More text | here | too""", - """Math |x| and |y| values |---|---| Not a table | still | not""", - # Mixed content """Regular text with | pipes | scattered | Header | Column | @@ -677,51 +667,52 @@ def test_lines_not_starting_with_pipe_not_tables(self) -> None: table_count = result.count("
") # For single line examples, should be 0 - if '\n' not in example: - self.assertEqual(table_count, 0, - f"Single line not starting with pipe was parsed as table: '{example}'") + if "\n" not in example: + self.assertEqual( + table_count, 0, f"Single line not starting with pipe was parsed as table: '{example}'" + ) # For multi-line examples, check that lines not starting with pipes don't create tables else: - lines = example.split('\n') - [line for line in lines if line.strip().startswith('|')] - non_table_lines = [line for line in lines if not line.strip().startswith('|')] + lines = example.split("\n") + [line for line in lines if line.strip().startswith("|")] + non_table_lines = [line for line in lines if not line.strip().startswith("|")] # Non-table lines should not be in table HTML for non_table_line in non_table_lines: - if '|' in non_table_line: + if "|" in non_table_line: # Make sure this text appears outside of table tags - pipe_content = non_table_line.split('|')[0].strip() + pipe_content = non_table_line.split("|")[0].strip() if pipe_content: # This content should appear in the result but not inside table tags - self.assertIn(pipe_content, result, - f"Content '{pipe_content}' should appear in result") + self.assertIn(pipe_content, result, f"Content '{pipe_content}' should appear in result") def test_edge_cases_pipes_and_tables(self) -> None: """Test edge cases with pipes and potential table confusion.""" edge_cases = [ # Single pipe in text ("Just a | pipe", False), - # Pipes at start/end ("|Starting pipe", False), ("Ending pipe|", False), - # Multiple pipes but no valid table structure ("| A | B | C |", False), # No separator - # Valid minimal table - ("""| A | + ( + """| A | |---| -| 1 |""", True), - +| 1 |""", + True, + ), # Mathematical expressions ("Function f(x) = |x - 1| + |x + 1|", False), ("Probability P(A|B) = 0.5", False), - # Empty table cells - ("""| A | B | + ( + """| A | B | |---|---| -| | |""", True), +| | |""", + True, + ), ] for content, should_be_table in edge_cases: @@ -735,5 +726,5 @@ def test_edge_cases_pipes_and_tables(self) -> None: self.assertNotIn("
", result, f"Should NOT be table: {content}") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_mocks.py b/server_tests/test_mocks.py index ef95a2a5..f79509c7 100644 --- a/server_tests/test_mocks.py +++ b/server_tests/test_mocks.py @@ -91,17 +91,20 @@ def __init__(self, width: float, height: float, draw_enabled: bool = True) -> No def get_drawables(self) -> List[Dict[str, Any]]: """Get all drawable objects on the canvas.""" - return cast(List[Dict[str, Any]], ( - self.points + - self.segments + - self.circles + - self.rectangles + - self.triangles + - self.ellipses + - self.functions + - self.vectors + - self.polygons - )) + return cast( + List[Dict[str, Any]], + ( + self.points + + self.segments + + self.circles + + self.rectangles + + self.triangles + + self.ellipses + + self.functions + + self.vectors + + self.polygons + ), + ) def get_drawables_by_class_name(self, class_name: str) -> List[Dict[str, Any]]: """Get drawables of a specific class.""" @@ -113,7 +116,7 @@ def get_drawables_by_class_name(self, class_name: str) -> List[Dict[str, Any]]: "Triangle": self.triangles, "Ellipse": self.ellipses, "Function": self.functions, - "Vector": self.vectors + "Vector": self.vectors, } return cast(List[Dict[str, Any]], class_map.get(class_name, [])) @@ -129,7 +132,7 @@ def get_canvas_state(self) -> CanvasStateDict: "Ellipses": self.ellipses, "Functions": self.functions, "Vectors": self.vectors, - "computations": self.computations + "computations": self.computations, } def clear(self) -> None: @@ -153,21 +156,15 @@ def create_point(self, x: float, y: float, name: str = "") -> PointDict: def create_segment(self, x1: float, y1: float, x2: float, y2: float, name: str = "") -> SegmentDict: """Create a segment on the canvas.""" - segment: SegmentDict = cast(SegmentDict, { - "point1": {"x": x1, "y": y1}, - "point2": {"x": x2, "y": y2}, - "name": name - }) + segment: SegmentDict = cast( + SegmentDict, {"point1": {"x": x1, "y": y1}, "point2": {"x": x2, "y": y2}, "name": name} + ) self.segments.append(segment) return segment def create_circle(self, x: float, y: float, radius: float, name: str = "") -> CircleDict: """Create a circle on the canvas.""" - circle: CircleDict = cast(CircleDict, { - "center": {"x": x, "y": y}, - "radius": radius, - "name": name - }) + circle: CircleDict = cast(CircleDict, {"center": {"x": x, "y": y}, "radius": radius, "name": name}) self.circles.append(circle) return circle @@ -198,24 +195,27 @@ def create_polygon( self.polygons.append(polygon) if polygon_type == "rectangle" and len(vertex_pairs) >= 3: - rectangle: RectangleDict = cast(RectangleDict, { - "point1": {"x": vertex_pairs[0][0], "y": vertex_pairs[0][1]}, - "point3": {"x": vertex_pairs[2][0], "y": vertex_pairs[2][1]}, - "name": name - }) + rectangle: RectangleDict = cast( + RectangleDict, + { + "point1": {"x": vertex_pairs[0][0], "y": vertex_pairs[0][1]}, + "point3": {"x": vertex_pairs[2][0], "y": vertex_pairs[2][1]}, + "name": name, + }, + ) self.rectangles.append(rectangle) return cast(Dict[str, Any], rectangle) return polygon - def create_triangle(self, x1: float, y1: float, x2: float, y2: float, x3: float, y3: float, name: str = "") -> TriangleDict: + def create_triangle( + self, x1: float, y1: float, x2: float, y2: float, x3: float, y3: float, name: str = "" + ) -> TriangleDict: """Create a triangle on the canvas.""" - triangle: TriangleDict = cast(TriangleDict, { - "point1": {"x": x1, "y": y1}, - "point2": {"x": x2, "y": y2}, - "point3": {"x": x3, "y": y3}, - "name": name - }) + triangle: TriangleDict = cast( + TriangleDict, + {"point1": {"x": x1, "y": y1}, "point2": {"x": x2, "y": y2}, "point3": {"x": x3, "y": y3}, "name": name}, + ) self.triangles.append(triangle) self.create_polygon( [(x1, y1), (x2, y2), (x3, y3)], @@ -224,44 +224,46 @@ def create_triangle(self, x1: float, y1: float, x2: float, y2: float, x3: float, ) return triangle - def create_ellipse(self, x: float, y: float, radius_x: float, radius_y: float, rotation_angle: float = 0, name: str = "") -> EllipseDict: + def create_ellipse( + self, x: float, y: float, radius_x: float, radius_y: float, rotation_angle: float = 0, name: str = "" + ) -> EllipseDict: """Create an ellipse on the canvas.""" - ellipse: EllipseDict = cast(EllipseDict, { - "center": {"x": x, "y": y}, - "radius_x": radius_x, - "radius_y": radius_y, - "rotation_angle": rotation_angle, - "name": name - }) + ellipse: EllipseDict = cast( + EllipseDict, + { + "center": {"x": x, "y": y}, + "radius_x": radius_x, + "radius_y": radius_y, + "rotation_angle": rotation_angle, + "name": name, + }, + ) self.ellipses.append(ellipse) return ellipse - def draw_function(self, function_string: str, name: str = "", left_bound: Optional[float] = None, right_bound: Optional[float] = None) -> FunctionDict: + def draw_function( + self, + function_string: str, + name: str = "", + left_bound: Optional[float] = None, + right_bound: Optional[float] = None, + ) -> FunctionDict: """Add a function to the canvas.""" - function: FunctionDict = cast(FunctionDict, { - "function_string": function_string, - "name": name, - "left_bound": left_bound, - "right_bound": right_bound - }) + function: FunctionDict = cast( + FunctionDict, + {"function_string": function_string, "name": name, "left_bound": left_bound, "right_bound": right_bound}, + ) self.functions.append(function) return function def create_vector(self, x1: float, y1: float, x2: float, y2: float, name: str = "") -> VectorDict: """Create a vector on the canvas.""" - vector: VectorDict = cast(VectorDict, { - "origin": {"x": x1, "y": y1}, - "tip": {"x": x2, "y": y2}, - "name": name - }) + vector: VectorDict = cast(VectorDict, {"origin": {"x": x1, "y": y1}, "tip": {"x": x2, "y": y2}, "name": name}) self.vectors.append(vector) return vector def add_computation(self, expression: str, result: Any) -> ComputationDict: """Add a computation to the canvas.""" - computation: ComputationDict = cast(ComputationDict, { - "expression": expression, - "result": result - }) + computation: ComputationDict = cast(ComputationDict, {"expression": expression, "result": result}) self.computations.append(computation) return computation diff --git a/server_tests/test_ollama_api.py b/server_tests/test_ollama_api.py index 0760449c..385cfb14 100644 --- a/server_tests/test_ollama_api.py +++ b/server_tests/test_ollama_api.py @@ -145,6 +145,7 @@ def test_default_base_url(self) -> None: with patch.dict("os.environ", {}, clear=True): # Remove the env var if it exists import os + os.environ.pop("OLLAMA_BASE_URL", None) instance = object.__new__(OllamaAPI) diff --git a/server_tests/test_ollama_integration.py b/server_tests/test_ollama_integration.py index 6351ba13..2d5d25dd 100644 --- a/server_tests/test_ollama_integration.py +++ b/server_tests/test_ollama_integration.py @@ -15,10 +15,7 @@ # Skip all tests in this module if Ollama is not running -pytestmark = pytest.mark.skipif( - not OllamaAPI.is_server_running(), - reason="Ollama server is not running" -) +pytestmark = pytest.mark.skipif(not OllamaAPI.is_server_running(), reason="Ollama server is not running") class TestOllamaServerIntegration: @@ -105,10 +102,12 @@ def ollama_api(self, tool_capable_model: str) -> OllamaAPI: def test_simple_chat_completion(self, ollama_api: OllamaAPI) -> None: """Can get a simple chat completion without tools.""" # Create a simple prompt - prompt = json.dumps({ - "user_message": "Say 'hello' and nothing else.", - "canvas_state": {}, - }) + prompt = json.dumps( + { + "user_message": "Say 'hello' and nothing else.", + "canvas_state": {}, + } + ) response = ollama_api.create_chat_completion(prompt) @@ -121,6 +120,7 @@ def test_tool_call_request(self, ollama_api: OllamaAPI) -> None: """Model can request tool calls.""" # Limit tools to just create_point for simpler testing from static.functions_definitions import FUNCTIONS + create_point_tool = None for tool in FUNCTIONS: if tool.get("function", {}).get("name") == "create_point": @@ -131,10 +131,12 @@ def test_tool_call_request(self, ollama_api: OllamaAPI) -> None: ollama_api.tools = [create_point_tool] # Create a prompt that should trigger a tool call - prompt = json.dumps({ - "user_message": "Create a point at coordinates (50, 100) named 'TestPoint'", - "canvas_state": {"points": [], "segments": []}, - }) + prompt = json.dumps( + { + "user_message": "Create a point at coordinates (50, 100) named 'TestPoint'", + "canvas_state": {"points": [], "segments": []}, + } + ) response = ollama_api.create_chat_completion(prompt) @@ -154,10 +156,12 @@ def test_conversation_history_maintained(self, ollama_api: OllamaAPI) -> None: The model's ability to utilize context varies by model quality. """ # First message - prompt1 = json.dumps({ - "user_message": "My favorite color is blue. Remember this.", - "canvas_state": {}, - }) + prompt1 = json.dumps( + { + "user_message": "My favorite color is blue. Remember this.", + "canvas_state": {}, + } + ) ollama_api.create_chat_completion(prompt1) # Check history structure: system + user + assistant @@ -171,10 +175,12 @@ def test_conversation_history_maintained(self, ollama_api: OllamaAPI) -> None: assert "blue" in first_user_content.lower() # Second message referencing the first - prompt2 = json.dumps({ - "user_message": "What is my favorite color?", - "canvas_state": {}, - }) + prompt2 = json.dumps( + { + "user_message": "What is my favorite color?", + "canvas_state": {}, + } + ) response = ollama_api.create_chat_completion(prompt2) # Verify history grew correctly: system + user + assistant + user + assistant @@ -193,10 +199,12 @@ def test_conversation_history_maintained(self, ollama_api: OllamaAPI) -> None: def test_reset_conversation(self, ollama_api: OllamaAPI) -> None: """Can reset conversation history.""" # Add some messages - prompt = json.dumps({ - "user_message": "Hello", - "canvas_state": {}, - }) + prompt = json.dumps( + { + "user_message": "Hello", + "canvas_state": {}, + } + ) ollama_api.create_chat_completion(prompt) initial_count = len(ollama_api.messages) @@ -232,10 +240,12 @@ def ollama_api(self, tool_capable_model: str) -> OllamaAPI: def test_streaming_response(self, ollama_api: OllamaAPI) -> None: """Can stream responses.""" - prompt = json.dumps({ - "user_message": "Count from 1 to 5.", - "canvas_state": {}, - }) + prompt = json.dumps( + { + "user_message": "Count from 1 to 5.", + "canvas_state": {}, + } + ) tokens: List[str] = [] final_event: Dict[str, Any] = {} @@ -272,10 +282,12 @@ def test_invalid_model_handling(self) -> None: model = AIModel.from_identifier("nonexistent-model:latest") api = OllamaAPI(model=model) - prompt = json.dumps({ - "user_message": "Hello", - "canvas_state": {}, - }) + prompt = json.dumps( + { + "user_message": "Hello", + "canvas_state": {}, + } + ) # Should handle error gracefully response = api.create_chat_completion(prompt) @@ -321,10 +333,12 @@ def test_tool_call_and_result_flow(self, ollama_api: OllamaAPI) -> None: ollama_api.tools = [simple_tool] # First request - might trigger tool call - prompt1 = json.dumps({ - "user_message": "What's on the canvas? Use get_current_canvas_state to check.", - "canvas_state": {"points": [{"name": "A", "x": 0, "y": 0}]}, - }) + prompt1 = json.dumps( + { + "user_message": "What's on the canvas? Use get_current_canvas_state to check.", + "canvas_state": {"points": [{"name": "A", "x": 0, "y": 0}]}, + } + ) response1 = ollama_api.create_chat_completion(prompt1) @@ -347,11 +361,13 @@ def test_tool_call_and_result_flow(self, ollama_api: OllamaAPI) -> None: tool_id = tool_call.id # Send tool result - prompt2 = json.dumps({ - "tool_call_results": json.dumps({ - tool_id: {"points": [{"name": "A", "x": 0, "y": 0}], "segments": []} - }), - }) + prompt2 = json.dumps( + { + "tool_call_results": json.dumps( + {tool_id: {"points": [{"name": "A", "x": 0, "y": 0}], "segments": []}} + ), + } + ) response2 = ollama_api.create_chat_completion(prompt2) diff --git a/server_tests/test_openai_api_base.py b/server_tests/test_openai_api_base.py index 2ad01349..8bb9d750 100644 --- a/server_tests/test_openai_api_base.py +++ b/server_tests/test_openai_api_base.py @@ -21,35 +21,35 @@ class TestOpenAIAPIBase(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" # Ensure OPENAI_API_KEY is set for tests - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' - self.original_summary_mode = os.environ.get('AI_CANVAS_SUMMARY_MODE') - self.original_hybrid_max = os.environ.get('AI_CANVAS_HYBRID_FULL_MAX_BYTES') - self.original_summary_telemetry = os.environ.get('AI_CANVAS_SUMMARY_TELEMETRY') - os.environ.pop('AI_CANVAS_SUMMARY_MODE', None) - os.environ.pop('AI_CANVAS_HYBRID_FULL_MAX_BYTES', None) - os.environ.pop('AI_CANVAS_SUMMARY_TELEMETRY', None) + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" + self.original_summary_mode = os.environ.get("AI_CANVAS_SUMMARY_MODE") + self.original_hybrid_max = os.environ.get("AI_CANVAS_HYBRID_FULL_MAX_BYTES") + self.original_summary_telemetry = os.environ.get("AI_CANVAS_SUMMARY_TELEMETRY") + os.environ.pop("AI_CANVAS_SUMMARY_MODE", None) + os.environ.pop("AI_CANVAS_HYBRID_FULL_MAX_BYTES", None) + os.environ.pop("AI_CANVAS_SUMMARY_TELEMETRY", None) def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) if self.original_summary_mode is None: - os.environ.pop('AI_CANVAS_SUMMARY_MODE', None) + os.environ.pop("AI_CANVAS_SUMMARY_MODE", None) else: - os.environ['AI_CANVAS_SUMMARY_MODE'] = self.original_summary_mode + os.environ["AI_CANVAS_SUMMARY_MODE"] = self.original_summary_mode if self.original_hybrid_max is None: - os.environ.pop('AI_CANVAS_HYBRID_FULL_MAX_BYTES', None) + os.environ.pop("AI_CANVAS_HYBRID_FULL_MAX_BYTES", None) else: - os.environ['AI_CANVAS_HYBRID_FULL_MAX_BYTES'] = self.original_hybrid_max + os.environ["AI_CANVAS_HYBRID_FULL_MAX_BYTES"] = self.original_hybrid_max if self.original_summary_telemetry is None: - os.environ.pop('AI_CANVAS_SUMMARY_TELEMETRY', None) + os.environ.pop("AI_CANVAS_SUMMARY_TELEMETRY", None) else: - os.environ['AI_CANVAS_SUMMARY_TELEMETRY'] = self.original_summary_telemetry + os.environ["AI_CANVAS_SUMMARY_TELEMETRY"] = self.original_summary_telemetry - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_initialization_default_model(self, mock_openai: Mock) -> None: """Test API initializes with default model.""" api = OpenAIAPIBase() @@ -57,14 +57,14 @@ def test_initialization_default_model(self, mock_openai: Mock) -> None: self.assertEqual(api.temperature, 0.2) self.assertEqual(api.max_tokens, 16000) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_initialization_custom_model(self, mock_openai: Mock) -> None: """Test API initializes with custom model.""" custom_model = AIModel.from_identifier("gpt-4o") api = OpenAIAPIBase(model=custom_model) self.assertEqual(api.model.id, "gpt-4o") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_initialization_messages(self, mock_openai: Mock) -> None: """Test API initializes with developer message.""" api = OpenAIAPIBase() @@ -72,14 +72,14 @@ def test_initialization_messages(self, mock_openai: Mock) -> None: self.assertEqual(api.messages[0]["role"], "developer") self.assertIn("educational graphing calculator", api.messages[0]["content"]) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_get_model(self, mock_openai: Mock) -> None: """Test get_model returns current model.""" api = OpenAIAPIBase() model = api.get_model() self.assertEqual(model.id, AIModel.DEFAULT_MODEL) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_reset_conversation(self, mock_openai: Mock) -> None: """Test reset_conversation clears messages except developer message.""" api = OpenAIAPIBase() @@ -93,7 +93,7 @@ def test_reset_conversation(self, mock_openai: Mock) -> None: self.assertEqual(len(api.messages), 1) self.assertEqual(api.messages[0]["role"], "developer") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model(self, mock_openai: Mock) -> None: """Test set_model changes the model.""" api = OpenAIAPIBase() @@ -101,7 +101,7 @@ def test_set_model(self, mock_openai: Mock) -> None: self.assertEqual(api.model.id, "gpt-4o") self.assertFalse(api.model.is_reasoning_model) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_to_reasoning(self, mock_openai: Mock) -> None: """Test set_model to a reasoning model.""" api = OpenAIAPIBase() @@ -109,7 +109,7 @@ def test_set_model_to_reasoning(self, mock_openai: Mock) -> None: self.assertEqual(api.model.id, "o3") self.assertTrue(api.model.is_reasoning_model) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_same_model_no_change(self, mock_openai: Mock) -> None: """Test set_model with same model doesn't change.""" api = OpenAIAPIBase() @@ -118,7 +118,7 @@ def test_set_model_same_model_no_change(self, mock_openai: Mock) -> None: # Model should be the same instance (no change) self.assertEqual(api.model.id, original_model.id) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_tool_message(self, mock_openai: Mock) -> None: """Test _create_tool_message creates correct format.""" api = OpenAIAPIBase() @@ -127,7 +127,7 @@ def test_create_tool_message(self, mock_openai: Mock) -> None: self.assertEqual(tool_msg["tool_call_id"], "call_123") self.assertEqual(tool_msg["content"], "result content") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_append_tool_messages(self, mock_openai: Mock) -> None: """Test _append_tool_messages adds placeholder messages.""" api = OpenAIAPIBase() @@ -145,7 +145,7 @@ def test_append_tool_messages(self, mock_openai: Mock) -> None: self.assertEqual(api.messages[-2]["tool_call_id"], "call_1") self.assertEqual(api.messages[-1]["tool_call_id"], "call_2") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_append_tool_messages_none(self, mock_openai: Mock) -> None: """Test _append_tool_messages with None does nothing.""" api = OpenAIAPIBase() @@ -153,23 +153,19 @@ def test_append_tool_messages_none(self, mock_openai: Mock) -> None: api._append_tool_messages(None) self.assertEqual(len(api.messages), initial_count) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_update_tool_messages_with_results(self, mock_openai: Mock) -> None: """Test _update_tool_messages_with_results updates last tool message.""" api = OpenAIAPIBase() # Add a tool message - api.messages.append({ - "role": "tool", - "tool_call_id": "call_123", - "content": "Awaiting result..." - }) + api.messages.append({"role": "tool", "tool_call_id": "call_123", "content": "Awaiting result..."}) results = {"status": "success", "data": "test"} api._update_tool_messages_with_results(json.dumps(results)) self.assertEqual(api.messages[-1]["content"], json.dumps(results)) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_parse_prompt_json_valid(self, mock_openai: Mock) -> None: """Test _parse_prompt_json with valid JSON.""" api = OpenAIAPIBase() @@ -178,33 +174,33 @@ def test_parse_prompt_json_valid(self, mock_openai: Mock) -> None: self.assertIsNotNone(result) self.assertEqual(result["user_message"], "test") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_parse_prompt_json_invalid(self, mock_openai: Mock) -> None: """Test _parse_prompt_json with invalid JSON returns None.""" api = OpenAIAPIBase() result = api._parse_prompt_json("not valid json") self.assertIsNone(result) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_parse_prompt_json_non_dict(self, mock_openai: Mock) -> None: """Test _parse_prompt_json with non-dict JSON returns None.""" api = OpenAIAPIBase() result = api._parse_prompt_json(json.dumps(["list", "not", "dict"])) self.assertIsNone(result) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_no_vision(self, mock_openai: Mock) -> None: """Test _prepare_message_content without vision returns original prompt.""" api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'off' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "off" prompt = json.dumps({"user_message": "test", "use_vision": False}) result = api._prepare_message_content(prompt) self.assertEqual(result, prompt) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_summary_only_removes_full_canvas_state(self, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'summary_only' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "summary_only" prompt = json.dumps( { "user_message": "test", @@ -223,11 +219,11 @@ def test_prepare_message_content_summary_only_removes_full_canvas_state(self, mo self.assertIn("state", summary) self.assertIn("metrics", summary) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_hybrid_keeps_small_full_state(self, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'hybrid' - os.environ['AI_CANVAS_HYBRID_FULL_MAX_BYTES'] = '999999' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "hybrid" + os.environ["AI_CANVAS_HYBRID_FULL_MAX_BYTES"] = "999999" prompt = json.dumps( { "user_message": "test", @@ -242,11 +238,11 @@ def test_prepare_message_content_hybrid_keeps_small_full_state(self, mock_openai self.assertNotIn("canvas_state_summary", parsed) self.assertEqual(parsed["canvas_state"], {"Points": [{"name": "A", "args": {"position": {"x": 1, "y": 2}}}]}) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_hybrid_drops_large_full_state(self, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'hybrid' - os.environ['AI_CANVAS_HYBRID_FULL_MAX_BYTES'] = '10' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "hybrid" + os.environ["AI_CANVAS_HYBRID_FULL_MAX_BYTES"] = "10" prompt = json.dumps( { "user_message": "test", @@ -262,10 +258,10 @@ def test_prepare_message_content_hybrid_drops_large_full_state(self, mock_openai self.assertFalse(parsed["canvas_state_summary"]["includes_full_state"]) self.assertIn("state", parsed["canvas_state_summary"]) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_default_mode_is_hybrid(self, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_HYBRID_FULL_MAX_BYTES'] = '10' + os.environ["AI_CANVAS_HYBRID_FULL_MAX_BYTES"] = "10" prompt = json.dumps( { "user_message": "test", @@ -279,11 +275,11 @@ def test_prepare_message_content_default_mode_is_hybrid(self, mock_openai: Mock) self.assertEqual(parsed["canvas_state_summary"]["mode"], "hybrid") self.assertNotIn("canvas_state", parsed) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_invalid_mode_falls_back_to_hybrid(self, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'garbage' - os.environ['AI_CANVAS_HYBRID_FULL_MAX_BYTES'] = '10' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "garbage" + os.environ["AI_CANVAS_HYBRID_FULL_MAX_BYTES"] = "10" prompt = json.dumps( { "user_message": "test", @@ -297,7 +293,7 @@ def test_prepare_message_content_invalid_mode_falls_back_to_hybrid(self, mock_op self.assertEqual(parsed["canvas_state_summary"]["mode"], "hybrid") self.assertNotIn("canvas_state", parsed) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_message_content_invalid_json(self, mock_openai: Mock) -> None: """Test _prepare_message_content with invalid JSON returns original.""" api = OpenAIAPIBase() @@ -305,12 +301,12 @@ def test_prepare_message_content_invalid_json(self, mock_openai: Mock) -> None: result = api._prepare_message_content(prompt) self.assertEqual(result, prompt) - @patch('static.openai_api_base.OpenAI') - @patch('static.openai_api_base._logger') + @patch("static.openai_api_base.OpenAI") + @patch("static.openai_api_base._logger") def test_prepare_message_content_emits_telemetry_when_enabled(self, mock_logger: Mock, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'summary_only' - os.environ['AI_CANVAS_SUMMARY_TELEMETRY'] = '1' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "summary_only" + os.environ["AI_CANVAS_SUMMARY_TELEMETRY"] = "1" prompt = json.dumps( { "user_message": "test", @@ -334,12 +330,12 @@ def test_prepare_message_content_emits_telemetry_when_enabled(self, mock_logger: self.assertIn("normalized_prompt_bytes", payload) self.assertIn("output_payload_bytes", payload) - @patch('static.openai_api_base.OpenAI') - @patch('static.openai_api_base._logger') + @patch("static.openai_api_base.OpenAI") + @patch("static.openai_api_base._logger") def test_prepare_message_content_skips_telemetry_when_disabled(self, mock_logger: Mock, mock_openai: Mock) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'summary_only' - os.environ['AI_CANVAS_SUMMARY_TELEMETRY'] = '0' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "summary_only" + os.environ["AI_CANVAS_SUMMARY_TELEMETRY"] = "0" prompt = json.dumps( { "user_message": "test", @@ -351,12 +347,14 @@ def test_prepare_message_content_skips_telemetry_when_disabled(self, mock_logger mock_logger.info.assert_not_called() - @patch('static.openai_api_base.OpenAI') - @patch('static.openai_api_base._logger') - def test_prepare_message_content_telemetry_multimodal_reports_output_payload_size(self, mock_logger: Mock, mock_openai: Mock) -> None: + @patch("static.openai_api_base.OpenAI") + @patch("static.openai_api_base._logger") + def test_prepare_message_content_telemetry_multimodal_reports_output_payload_size( + self, mock_logger: Mock, mock_openai: Mock + ) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'summary_only' - os.environ['AI_CANVAS_SUMMARY_TELEMETRY'] = '1' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "summary_only" + os.environ["AI_CANVAS_SUMMARY_TELEMETRY"] = "1" prompt = json.dumps( { "user_message": "test", @@ -376,13 +374,15 @@ def test_prepare_message_content_telemetry_multimodal_reports_output_payload_siz self.assertIn("normalized_prompt_bytes", payload) self.assertIn("output_payload_bytes", payload) - @patch('static.openai_api_base.OpenAI') - @patch('static.openai_api_base._logger') - def test_prepare_message_content_telemetry_hybrid_small_scene_infers_full_state(self, mock_logger: Mock, mock_openai: Mock) -> None: + @patch("static.openai_api_base.OpenAI") + @patch("static.openai_api_base._logger") + def test_prepare_message_content_telemetry_hybrid_small_scene_infers_full_state( + self, mock_logger: Mock, mock_openai: Mock + ) -> None: api = OpenAIAPIBase() - os.environ['AI_CANVAS_SUMMARY_MODE'] = 'hybrid' - os.environ['AI_CANVAS_HYBRID_FULL_MAX_BYTES'] = '999999' - os.environ['AI_CANVAS_SUMMARY_TELEMETRY'] = '1' + os.environ["AI_CANVAS_SUMMARY_MODE"] = "hybrid" + os.environ["AI_CANVAS_HYBRID_FULL_MAX_BYTES"] = "999999" + os.environ["AI_CANVAS_SUMMARY_TELEMETRY"] = "1" prompt = json.dumps( { "user_message": "test", @@ -401,34 +401,40 @@ def test_prepare_message_content_telemetry_hybrid_small_scene_infers_full_state( self.assertEqual(payload["mode"], "hybrid") self.assertTrue(payload["includes_full_state"]) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_error_response(self, mock_openai: Mock) -> None: """Test _create_error_response creates proper error structure.""" api = OpenAIAPIBase() error_resp = api._create_error_response() - self.assertEqual(error_resp.message.content, "I encountered an error processing your request. Please try again.") + self.assertEqual( + error_resp.message.content, "I encountered an error processing your request. Please try again." + ) self.assertEqual(error_resp.message.tool_calls, []) self.assertEqual(error_resp.finish_reason, "error") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_error_response_custom_message(self, mock_openai: Mock) -> None: """Test _create_error_response with custom message.""" api = OpenAIAPIBase() error_resp = api._create_error_response("Custom error message") self.assertEqual(error_resp.message.content, "Custom error message") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_remove_canvas_state_from_user_messages(self, mock_openai: Mock) -> None: """Test _remove_canvas_state_from_user_messages removes state payloads.""" api = OpenAIAPIBase() - api.messages.append({ - "role": "user", - "content": json.dumps({ - "canvas_state": {"shapes": []}, - "canvas_state_summary": {"state": {"Points": []}}, - "user_message": "test" - }) - }) + api.messages.append( + { + "role": "user", + "content": json.dumps( + { + "canvas_state": {"shapes": []}, + "canvas_state_summary": {"state": {"Points": []}}, + "user_message": "test", + } + ), + } + ) api._remove_canvas_state_from_user_messages() @@ -437,17 +443,19 @@ def test_remove_canvas_state_from_user_messages(self, mock_openai: Mock) -> None self.assertNotIn("canvas_state_summary", content) self.assertEqual(content["user_message"], "test") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_remove_images_from_user_messages(self, mock_openai: Mock) -> None: """Test _remove_images_from_user_messages removes image content.""" api = OpenAIAPIBase() - api.messages.append({ - "role": "user", - "content": [ - {"type": "text", "text": "test message"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} - ] - }) + api.messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": "test message"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, + ], + } + ) api._remove_images_from_user_messages() @@ -460,28 +468,28 @@ class TestOpenAIAPIBaseInitialization(unittest.TestCase): def test_api_key_from_environment(self) -> None: """Test API key is read from environment variable.""" - os.environ['OPENAI_API_KEY'] = 'test-env-key' + os.environ["OPENAI_API_KEY"] = "test-env-key" try: api_key = OpenAIAPIBase._initialize_api_key() - self.assertEqual(api_key, 'test-env-key') + self.assertEqual(api_key, "test-env-key") finally: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.load_dotenv') - @patch('static.openai_api_base.os.path.exists') + @patch("static.openai_api_base.load_dotenv") + @patch("static.openai_api_base.os.path.exists") def test_api_key_missing_returns_placeholder(self, mock_exists: Mock, mock_load_dotenv: Mock) -> None: """Test missing API key returns placeholder instead of crashing.""" # Mock .env file doesn't exist mock_exists.return_value = False # Remove API key from environment - original = os.environ.pop('OPENAI_API_KEY', None) + original = os.environ.pop("OPENAI_API_KEY", None) try: result = OpenAIAPIBase._initialize_api_key() self.assertEqual(result, "not-configured") finally: if original: - os.environ['OPENAI_API_KEY'] = original + os.environ["OPENAI_API_KEY"] = original -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_openai_completions_api.py b/server_tests/test_openai_completions_api.py index 4f22cd9e..f4a189b1 100644 --- a/server_tests/test_openai_completions_api.py +++ b/server_tests/test_openai_completions_api.py @@ -22,32 +22,29 @@ class TestOpenAIChatCompletionsAPI(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_initialization(self, mock_openai: Mock) -> None: """Test API initializes correctly.""" api = OpenAIChatCompletionsAPI() self.assertIsNotNone(api.client) self.assertIsNotNone(api.model) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_assistant_message_simple(self, mock_openai: Mock) -> None: """Test _create_assistant_message with simple response.""" api = OpenAIChatCompletionsAPI() - response_message = SimpleNamespace( - content="Hello, world!", - tool_calls=None - ) + response_message = SimpleNamespace(content="Hello, world!", tool_calls=None) result = api._create_assistant_message(response_message) @@ -55,22 +52,15 @@ def test_create_assistant_message_simple(self, mock_openai: Mock) -> None: self.assertEqual(result["content"], "Hello, world!") self.assertNotIn("tool_calls", result) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_assistant_message_with_tool_calls(self, mock_openai: Mock) -> None: """Test _create_assistant_message with tool calls.""" api = OpenAIChatCompletionsAPI() tool_call = SimpleNamespace( - id="call_123", - function=SimpleNamespace( - name="create_point", - arguments='{"x": 1, "y": 2}' - ) - ) - response_message = SimpleNamespace( - content="Creating a point...", - tool_calls=[tool_call] + id="call_123", function=SimpleNamespace(name="create_point", arguments='{"x": 1, "y": 2}') ) + response_message = SimpleNamespace(content="Creating a point...", tool_calls=[tool_call]) result = api._create_assistant_message(response_message) @@ -80,7 +70,7 @@ def test_create_assistant_message_with_tool_calls(self, mock_openai: Mock) -> No self.assertEqual(result["tool_calls"][0]["type"], "function") self.assertEqual(result["tool_calls"][0]["function"]["name"], "create_point") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_accumulate_tool_calls(self, mock_openai: Mock) -> None: """Test _accumulate_tool_calls accumulates streaming deltas.""" api = OpenAIChatCompletionsAPI() @@ -88,25 +78,19 @@ def test_accumulate_tool_calls(self, mock_openai: Mock) -> None: # First delta with id and function name delta1 = SimpleNamespace( - index=0, - id="call_123", - function=SimpleNamespace(name="create_point", arguments='{"x":') + index=0, id="call_123", function=SimpleNamespace(name="create_point", arguments='{"x":') ) api._accumulate_tool_calls([delta1], accumulator) # Second delta with more arguments - delta2 = SimpleNamespace( - index=0, - id=None, - function=SimpleNamespace(name=None, arguments=' 1}') - ) + delta2 = SimpleNamespace(index=0, id=None, function=SimpleNamespace(name=None, arguments=" 1}")) api._accumulate_tool_calls([delta2], accumulator) self.assertEqual(accumulator[0]["id"], "call_123") self.assertEqual(accumulator[0]["function"]["name"], "create_point") self.assertEqual(accumulator[0]["function"]["arguments"], '{"x": 1}') - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_accumulate_tool_calls_supports_dict_shape_and_invalid_entries(self, mock_openai: Mock) -> None: """Tool call accumulation supports dict deltas and ignores malformed entries.""" api = OpenAIChatCompletionsAPI() @@ -131,13 +115,13 @@ def test_accumulate_tool_calls_supports_dict_shape_and_invalid_entries(self, moc self.assertEqual(accumulator[1]["function"]["name"], "draw_line") self.assertEqual(accumulator[1]["function"]["arguments"], '{"x1": 0}') - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_normalize_tool_calls(self, mock_openai: Mock) -> None: """Test _normalize_tool_calls converts accumulator to list.""" api = OpenAIChatCompletionsAPI() accumulator = { 0: {"id": "call_1", "function": {"name": "func1", "arguments": "{}"}}, - 1: {"id": "call_2", "function": {"name": "func2", "arguments": "{}"}} + 1: {"id": "call_2", "function": {"name": "func2", "arguments": "{}"}}, } result = api._normalize_tool_calls(accumulator) @@ -146,13 +130,11 @@ def test_normalize_tool_calls(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["id"], "call_1") self.assertEqual(result[1]["id"], "call_2") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_tool_calls_for_response(self, mock_openai: Mock) -> None: """Test _prepare_tool_calls_for_response formats for JSON response.""" api = OpenAIChatCompletionsAPI() - tool_calls = [ - {"id": "call_1", "function": {"name": "create_point", "arguments": '{"x": 1, "y": 2}'}} - ] + tool_calls = [{"id": "call_1", "function": {"name": "create_point", "arguments": '{"x": 1, "y": 2}'}}] result = api._prepare_tool_calls_for_response(tool_calls) @@ -160,35 +142,31 @@ def test_prepare_tool_calls_for_response(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["function_name"], "create_point") self.assertEqual(result[0]["arguments"], {"x": 1, "y": 2}) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_tool_calls_invalid_json(self, mock_openai: Mock) -> None: """Test _prepare_tool_calls_for_response handles invalid JSON arguments.""" api = OpenAIChatCompletionsAPI() - tool_calls = [ - {"id": "call_1", "function": {"name": "test", "arguments": "invalid json"}} - ] + tool_calls = [{"id": "call_1", "function": {"name": "test", "arguments": "invalid json"}}] result = api._prepare_tool_calls_for_response(tool_calls) self.assertEqual(result[0]["function_name"], "test") self.assertEqual(result[0]["arguments"], {}) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_messages_for_request_with_tool_results(self, mock_openai: Mock) -> None: """When tool_call_results are present, they should be applied instead of adding a user message.""" api = OpenAIChatCompletionsAPI() api._update_tool_messages_with_results = MagicMock() initial_count = len(api.messages) - prompt = json.dumps( - {"user_message": "ignored", "tool_call_results": '{"create_point":{"x":1}}'} - ) + prompt = json.dumps({"user_message": "ignored", "tool_call_results": '{"create_point":{"x":1}}'}) api._prepare_messages_for_request(prompt) api._update_tool_messages_with_results.assert_called_once() self.assertEqual(len(api.messages), initial_count) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_messages_for_request_adds_user_message_without_tool_results(self, mock_openai: Mock) -> None: """Without tool_call_results, request prep should append user message content.""" api = OpenAIChatCompletionsAPI() @@ -201,21 +179,15 @@ def test_prepare_messages_for_request_adds_user_message_without_tool_results(sel self.assertEqual(api.messages[-1]["role"], "user") self.assertIn("Hello", str(api.messages[-1]["content"])) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_chat_completion_success(self, mock_openai: Mock) -> None: """Test create_chat_completion with successful response.""" mock_client = MagicMock() mock_openai.return_value = mock_client # Create mock response - mock_message = SimpleNamespace( - content="Test response", - tool_calls=None - ) - mock_choice = SimpleNamespace( - message=mock_message, - finish_reason="stop" - ) + mock_message = SimpleNamespace(content="Test response", tool_calls=None) + mock_choice = SimpleNamespace(message=mock_message, finish_reason="stop") mock_response = SimpleNamespace(choices=[mock_choice]) mock_client.chat.completions.create.return_value = mock_response @@ -226,7 +198,7 @@ def test_create_chat_completion_success(self, mock_openai: Mock) -> None: self.assertEqual(result.message.content, "Test response") self.assertEqual(result.finish_reason, "stop") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_chat_completion_error(self, mock_openai: Mock) -> None: """Test create_chat_completion handles API errors.""" mock_client = MagicMock() @@ -240,7 +212,7 @@ def test_create_chat_completion_error(self, mock_openai: Mock) -> None: self.assertIn("error", result.message.content.lower()) self.assertEqual(result.finish_reason, "error") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_chat_completion_stream_tokens(self, mock_openai: Mock) -> None: """Test create_chat_completion_stream yields token events.""" mock_client = MagicMock() @@ -248,18 +220,15 @@ def test_create_chat_completion_stream_tokens(self, mock_openai: Mock) -> None: # Create mock stream chunks chunks = [ - SimpleNamespace(choices=[SimpleNamespace( - delta=SimpleNamespace(content="Hello", tool_calls=None), - finish_reason=None - )]), - SimpleNamespace(choices=[SimpleNamespace( - delta=SimpleNamespace(content=" world", tool_calls=None), - finish_reason=None - )]), - SimpleNamespace(choices=[SimpleNamespace( - delta=SimpleNamespace(content="!", tool_calls=None), - finish_reason="stop" - )]), + SimpleNamespace( + choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello", tool_calls=None), finish_reason=None)] + ), + SimpleNamespace( + choices=[SimpleNamespace(delta=SimpleNamespace(content=" world", tool_calls=None), finish_reason=None)] + ), + SimpleNamespace( + choices=[SimpleNamespace(delta=SimpleNamespace(content="!", tool_calls=None), finish_reason="stop")] + ), ] mock_client.chat.completions.create.return_value = iter(chunks) @@ -281,7 +250,7 @@ def test_create_chat_completion_stream_tokens(self, mock_openai: Mock) -> None: self.assertEqual(final_events[0]["ai_message"], "Hello world!") self.assertEqual(final_events[0]["finish_reason"], "stop") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_chat_completion_stream_error(self, mock_openai: Mock) -> None: """Test create_chat_completion_stream handles errors.""" mock_client = MagicMock() @@ -297,14 +266,14 @@ def test_create_chat_completion_stream_error(self, mock_openai: Mock) -> None: self.assertEqual(len(final_events), 1) self.assertEqual(final_events[0]["finish_reason"], "error") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_extract_choice_from_chunk_handles_missing_choices(self, mock_openai: Mock) -> None: """_extract_choice_from_chunk should return None for malformed chunks.""" api = OpenAIChatCompletionsAPI() self.assertIsNone(api._extract_choice_from_chunk(SimpleNamespace())) self.assertIsNone(api._extract_choice_from_chunk(SimpleNamespace(choices=[]))) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_extract_content_piece_returns_empty_for_non_string(self, mock_openai: Mock) -> None: """_extract_content_piece should only return string content.""" api = OpenAIChatCompletionsAPI() @@ -313,7 +282,7 @@ def test_extract_content_piece_returns_empty_for_non_string(self, mock_openai: M self.assertEqual(api._extract_content_piece(SimpleNamespace(content=123)), "") self.assertEqual(api._extract_content_piece(SimpleNamespace(content="ok")), "ok") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_extract_tool_call_delta_index_defaults_to_zero(self, mock_openai: Mock) -> None: """_extract_tool_call_delta_index should normalize missing/invalid indexes.""" api = OpenAIChatCompletionsAPI() @@ -333,8 +302,8 @@ class TestOpenAIChatCompletionsAPIIntegration(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Check if API key is available for integration tests.""" - cls.api_key = os.environ.get('OPENAI_API_KEY') - if not cls.api_key or cls.api_key == 'test-api-key': + cls.api_key = os.environ.get("OPENAI_API_KEY") + if not cls.api_key or cls.api_key == "test-api-key": cls.skip_integration = True else: cls.skip_integration = False @@ -351,10 +320,7 @@ def test_integration_simple_completion(self) -> None: api.set_model("gpt-4o-mini") api.max_tokens = 10 - prompt = json.dumps({ - "user_message": "Say: OK", - "use_vision": False - }) + prompt = json.dumps({"user_message": "Say: OK", "use_vision": False}) result = api.create_chat_completion(prompt) @@ -367,10 +333,7 @@ def test_integration_stream_completion(self) -> None: api.set_model("gpt-4o-mini") api.max_tokens = 10 - prompt = json.dumps({ - "user_message": "Say: HI", - "use_vision": False - }) + prompt = json.dumps({"user_message": "Say: HI", "use_vision": False}) events = list(api.create_chat_completion_stream(prompt)) @@ -387,10 +350,7 @@ def test_integration_response_format(self) -> None: api.set_model("gpt-4o-mini") api.max_tokens = 5 - prompt = json.dumps({ - "user_message": "1", - "use_vision": False - }) + prompt = json.dumps({"user_message": "1", "use_vision": False}) events = list(api.create_chat_completion_stream(prompt)) final_event = [e for e in events if e.get("type") == "final"][0] @@ -405,5 +365,5 @@ def test_integration_response_format(self) -> None: self.assertIsInstance(final_event["ai_tool_calls"], list) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_openai_responses_api.py b/server_tests/test_openai_responses_api.py index 57313037..8eaabc78 100644 --- a/server_tests/test_openai_responses_api.py +++ b/server_tests/test_openai_responses_api.py @@ -22,24 +22,24 @@ class TestOpenAIResponsesAPI(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_initialization(self, mock_openai: Mock) -> None: """Test API initializes correctly.""" api = OpenAIResponsesAPI() self.assertIsNotNone(api.client) self.assertIsNotNone(api.model) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_messages_to_input_developer_to_system(self, mock_openai: Mock) -> None: """Test _convert_messages_to_input converts developer role to system.""" api = OpenAIResponsesAPI() @@ -50,20 +50,18 @@ def test_convert_messages_to_input_developer_to_system(self, mock_openai: Mock) self.assertEqual(len(result), 1) self.assertEqual(result[0]["role"], "system") # developer -> system - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_messages_to_input_converts_tool_calls_to_text(self, mock_openai: Mock) -> None: """Test _convert_messages_to_input converts tool call/result pairs to text messages.""" api = OpenAIResponsesAPI() - api.messages.append({ - "role": "assistant", - "content": "", - "tool_calls": [{"id": "call_1", "function": {"name": "create_point", "arguments": '{"x": 5}'}}] - }) - api.messages.append({ - "role": "tool", - "tool_call_id": "call_1", - "content": "Point created at (5, 0)" - }) + api.messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "call_1", "function": {"name": "create_point", "arguments": '{"x": 5}'}}], + } + ) + api.messages.append({"role": "tool", "tool_call_id": "call_1", "content": "Point created at (5, 0)"}) result = api._convert_messages_to_input() @@ -80,15 +78,17 @@ def test_convert_messages_to_input_converts_tool_calls_to_text(self, mock_openai self.assertEqual(len(user_msgs), 1) self.assertIn("Point created", user_msgs[0]["content"]) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_messages_skips_pending_tool_calls(self, mock_openai: Mock) -> None: """Test _convert_messages_to_input skips assistant messages with pending tool calls.""" api = OpenAIResponsesAPI() - api.messages.append({ - "role": "assistant", - "content": "", - "tool_calls": [{"id": "call_1", "function": {"name": "test", "arguments": "{}"}}] - }) + api.messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "call_1", "function": {"name": "test", "arguments": "{}"}}], + } + ) # No tool result message - pending result = api._convert_messages_to_input() @@ -97,7 +97,7 @@ def test_convert_messages_skips_pending_tool_calls(self, mock_openai: Mock) -> N self.assertEqual(len(result), 1) self.assertEqual(result[0]["role"], "system") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_convert_tools_for_responses_api(self, mock_openai: Mock) -> None: """Test _convert_tools_for_responses_api flattens tool format.""" api = OpenAIResponsesAPI() @@ -108,8 +108,8 @@ def test_convert_tools_for_responses_api(self, mock_openai: Mock) -> None: "function": { "name": "create_point", "description": "Create a point", - "parameters": {"type": "object", "properties": {"x": {"type": "number"}}} - } + "parameters": {"type": "object", "properties": {"x": {"type": "number"}}}, + }, } ] @@ -124,35 +124,25 @@ def test_convert_tools_for_responses_api(self, mock_openai: Mock) -> None: # Should NOT have nested "function" key self.assertNotIn("function", result[0]) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_handle_function_call_delta(self, mock_openai: Mock) -> None: """Test _handle_function_call_delta accumulates function call data.""" api = OpenAIResponsesAPI() accumulator: Dict[int, Dict[str, Any]] = {} # First event with call_id and name - event1 = SimpleNamespace( - output_index=0, - call_id="call_123", - name="create_point", - delta='{"x":' - ) + event1 = SimpleNamespace(output_index=0, call_id="call_123", name="create_point", delta='{"x":') api._handle_function_call_delta(event1, accumulator) # Second event with more arguments - event2 = SimpleNamespace( - output_index=0, - call_id=None, - name=None, - delta=' 5}' - ) + event2 = SimpleNamespace(output_index=0, call_id=None, name=None, delta=" 5}") api._handle_function_call_delta(event2, accumulator) self.assertEqual(accumulator[0]["id"], "call_123") self.assertEqual(accumulator[0]["function"]["name"], "create_point") self.assertEqual(accumulator[0]["function"]["arguments"], '{"x": 5}') - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_upsert_tool_call_entry_merges_chunks(self, mock_openai: Mock) -> None: """_upsert_tool_call_entry should merge ids/names/argument chunks.""" api = OpenAIResponsesAPI() @@ -170,7 +160,7 @@ def test_upsert_tool_call_entry_merges_chunks(self, mock_openai: Mock) -> None: index=2, call_id=None, name=None, - args_delta=' 1}', + args_delta=" 1}", ) self.assertIn(2, accumulator) @@ -178,7 +168,7 @@ def test_upsert_tool_call_entry_merges_chunks(self, mock_openai: Mock) -> None: self.assertEqual(accumulator[2]["function"]["name"], "create_point") self.assertEqual(accumulator[2]["function"]["arguments"], '{"x": 1}') - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_extract_tool_calls_preserves_existing_accumulator_entries(self, mock_openai: Mock) -> None: """_extract_tool_calls should not overwrite/append to existing accumulator entries.""" api = OpenAIResponsesAPI() @@ -201,7 +191,7 @@ def test_extract_tool_calls_preserves_existing_accumulator_entries(self, mock_op self.assertEqual(accumulator[0]["id"], "call_existing") self.assertEqual(accumulator[0]["function"]["arguments"], '{"x":') - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_tool_call_entry_if_missing_is_idempotent(self, mock_openai: Mock) -> None: """_create_tool_call_entry_if_missing should not replace existing entry.""" api = OpenAIResponsesAPI() @@ -226,7 +216,7 @@ def test_create_tool_call_entry_if_missing_is_idempotent(self, mock_openai: Mock self.assertEqual(accumulator[3]["function"]["name"], "draw") self.assertEqual(accumulator[3]["function"]["arguments"], "{}") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_extract_tool_calls(self, mock_openai: Mock) -> None: """Test _extract_tool_calls extracts function calls from response.""" api = OpenAIResponsesAPI() @@ -236,10 +226,7 @@ def test_extract_tool_calls(self, mock_openai: Mock) -> None: response = SimpleNamespace( output=[ SimpleNamespace( - type="function_call", - call_id="call_456", - name="create_circle", - arguments='{"radius": 10}' + type="function_call", call_id="call_456", name="create_circle", arguments='{"radius": 10}' ) ] ) @@ -250,13 +237,13 @@ def test_extract_tool_calls(self, mock_openai: Mock) -> None: self.assertEqual(accumulator[0]["function"]["name"], "create_circle") self.assertEqual(accumulator[0]["function"]["arguments"], '{"radius": 10}') - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_normalize_tool_calls(self, mock_openai: Mock) -> None: """Test _normalize_tool_calls converts accumulator to sorted list.""" api = OpenAIResponsesAPI() accumulator = { 1: {"id": "call_2", "function": {"name": "func2", "arguments": "{}"}}, - 0: {"id": "call_1", "function": {"name": "func1", "arguments": "{}"}} + 0: {"id": "call_1", "function": {"name": "func1", "arguments": "{}"}}, } result = api._normalize_tool_calls(accumulator) @@ -265,22 +252,15 @@ def test_normalize_tool_calls(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["id"], "call_1") # Index 0 first self.assertEqual(result[1]["id"], "call_2") # Index 1 second - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_assistant_message(self, mock_openai: Mock) -> None: """Test _create_assistant_message creates correct format.""" api = OpenAIResponsesAPI() tool_call = SimpleNamespace( - id="call_123", - function=SimpleNamespace( - name="test_func", - arguments='{"arg": "value"}' - ) - ) - response_message = SimpleNamespace( - content="Test content", - tool_calls=[tool_call] + id="call_123", function=SimpleNamespace(name="test_func", arguments='{"arg": "value"}') ) + response_message = SimpleNamespace(content="Test content", tool_calls=[tool_call]) result = api._create_assistant_message(response_message) @@ -290,13 +270,11 @@ def test_create_assistant_message(self, mock_openai: Mock) -> None: self.assertEqual(result["tool_calls"][0]["id"], "call_123") self.assertEqual(result["tool_calls"][0]["type"], "function") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_tool_calls_for_response(self, mock_openai: Mock) -> None: """Test _prepare_tool_calls_for_response formats for JSON response.""" api = OpenAIResponsesAPI() - tool_calls = [ - {"id": "call_1", "function": {"name": "create_point", "arguments": '{"x": 1, "y": 2}'}} - ] + tool_calls = [{"id": "call_1", "function": {"name": "create_point", "arguments": '{"x": 1, "y": 2}'}}] result = api._prepare_tool_calls_for_response(tool_calls) @@ -304,20 +282,18 @@ def test_prepare_tool_calls_for_response(self, mock_openai: Mock) -> None: self.assertEqual(result[0]["function_name"], "create_point") self.assertEqual(result[0]["arguments"], {"x": 1, "y": 2}) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_prepare_tool_calls_invalid_json(self, mock_openai: Mock) -> None: """Test _prepare_tool_calls_for_response handles invalid JSON.""" api = OpenAIResponsesAPI() - tool_calls = [ - {"id": "call_1", "function": {"name": "test", "arguments": "not json"}} - ] + tool_calls = [{"id": "call_1", "function": {"name": "test", "arguments": "not json"}}] result = api._prepare_tool_calls_for_response(tool_calls) self.assertEqual(result[0]["function_name"], "test") self.assertEqual(result[0]["arguments"], {}) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_reasoning_events(self, mock_openai: Mock) -> None: """Test create_response_stream yields reasoning events.""" mock_client = MagicMock() @@ -328,7 +304,7 @@ def test_create_response_stream_reasoning_events(self, mock_openai: Mock) -> Non SimpleNamespace(type="response.reasoning_text.delta", delta="Thinking..."), SimpleNamespace(type="response.reasoning_text.delta", delta=" about this"), SimpleNamespace(type="response.output_text.delta", delta="Here's the answer"), - SimpleNamespace(type="response.completed", response=SimpleNamespace(status="stop", output=[])) + SimpleNamespace(type="response.completed", response=SimpleNamespace(status="stop", output=[])), ] mock_client.responses.create.return_value = iter(events) @@ -352,7 +328,7 @@ def test_create_response_stream_reasoning_events(self, mock_openai: Mock) -> Non self.assertEqual(len(final_events), 1) self.assertEqual(final_events[0]["ai_message"], "Here's the answer") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_error(self, mock_openai: Mock) -> None: """Test create_response_stream handles errors gracefully.""" mock_client = MagicMock() @@ -368,7 +344,7 @@ def test_create_response_stream_error(self, mock_openai: Mock) -> None: self.assertEqual(len(final_events), 1) self.assertEqual(final_events[0]["finish_reason"], "error") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_normalizes_completed_status(self, mock_openai: Mock) -> None: """Test that 'completed' status is normalized to 'stop' finish_reason.""" mock_client = MagicMock() @@ -376,7 +352,7 @@ def test_create_response_stream_normalizes_completed_status(self, mock_openai: M events = [ SimpleNamespace(type="response.output_text.delta", delta="Done"), - SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])) + SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])), ] mock_client.responses.create.return_value = iter(events) @@ -388,7 +364,7 @@ def test_create_response_stream_normalizes_completed_status(self, mock_openai: M final_event = [e for e in result_events if e.get("type") == "final"][0] self.assertEqual(final_event["finish_reason"], "stop") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_sets_tool_calls_finish_reason(self, mock_openai: Mock) -> None: """Test that finish_reason is 'tool_calls' when there are tool calls.""" mock_client = MagicMock() @@ -400,7 +376,7 @@ def test_create_response_stream_sets_tool_calls_finish_reason(self, mock_openai: output_index=0, call_id="call_123", name="create_point", - delta='{"x": 5}' + delta='{"x": 5}', ), SimpleNamespace( type="response.completed", @@ -408,14 +384,11 @@ def test_create_response_stream_sets_tool_calls_finish_reason(self, mock_openai: status="requires_action", output=[ SimpleNamespace( - type="function_call", - call_id="call_123", - name="create_point", - arguments='{"x": 5}' + type="function_call", call_id="call_123", name="create_point", arguments='{"x": 5}' ) - ] - ) - ) + ], + ), + ), ] mock_client.responses.create.return_value = iter(events) @@ -428,7 +401,7 @@ def test_create_response_stream_sets_tool_calls_finish_reason(self, mock_openai: self.assertEqual(final_event["finish_reason"], "tool_calls") self.assertEqual(len(final_event["ai_tool_calls"]), 1) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_reasoning_placeholder_sent_once(self, mock_openai: Mock) -> None: """Test that reasoning placeholder is only sent once per stream.""" mock_client = MagicMock() @@ -437,22 +410,16 @@ def test_create_response_stream_reasoning_placeholder_sent_once(self, mock_opena # Multiple reasoning items without summaries events = [ SimpleNamespace( - type="response.output_item.added", - output_index=0, - item=SimpleNamespace(type="reasoning", summary=None) + type="response.output_item.added", output_index=0, item=SimpleNamespace(type="reasoning", summary=None) ), SimpleNamespace( - type="response.output_item.added", - output_index=1, - item=SimpleNamespace(type="reasoning", summary=None) + type="response.output_item.added", output_index=1, item=SimpleNamespace(type="reasoning", summary=None) ), SimpleNamespace( - type="response.output_item.added", - output_index=2, - item=SimpleNamespace(type="reasoning", summary=None) + type="response.output_item.added", output_index=2, item=SimpleNamespace(type="reasoning", summary=None) ), SimpleNamespace(type="response.output_text.delta", delta="Answer"), - SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])) + SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])), ] mock_client.responses.create.return_value = iter(events) @@ -463,14 +430,13 @@ def test_create_response_stream_reasoning_placeholder_sent_once(self, mock_opena # Count reasoning events with placeholder text placeholder_events = [ - e for e in result_events - if e.get("type") == "reasoning" and "Reasoning in progress" in e.get("text", "") + e for e in result_events if e.get("type") == "reasoning" and "Reasoning in progress" in e.get("text", "") ] # Should only have ONE placeholder despite multiple reasoning items self.assertEqual(len(placeholder_events), 1) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_fallback_without_reasoning_summary(self, mock_openai: Mock) -> None: """Test that API falls back gracefully when reasoning summary not supported.""" mock_client = MagicMock() @@ -480,10 +446,12 @@ def test_create_response_stream_fallback_without_reasoning_summary(self, mock_op def create_side_effect(*args, **kwargs): if kwargs.get("reasoning"): raise Exception("reasoning.summary is not supported") - return iter([ - SimpleNamespace(type="response.output_text.delta", delta="Hello"), - SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])) - ]) + return iter( + [ + SimpleNamespace(type="response.output_text.delta", delta="Hello"), + SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])), + ] + ) mock_client.responses.create.side_effect = create_side_effect @@ -498,7 +466,7 @@ def create_side_effect(*args, **kwargs): self.assertEqual(final_events[0]["ai_message"], "Hello") self.assertNotEqual(final_events[0]["finish_reason"], "error") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_yields_reasoning_summaries(self, mock_openai: Mock) -> None: """Test that reasoning summaries are properly yielded when available.""" mock_client = MagicMock() @@ -510,10 +478,10 @@ def test_create_response_stream_yields_reasoning_summaries(self, mock_openai: Mo SimpleNamespace( type="response.output_item.added", output_index=0, - item=SimpleNamespace(type="reasoning", summary=[summary_item]) + item=SimpleNamespace(type="reasoning", summary=[summary_item]), ), SimpleNamespace(type="response.output_text.delta", delta="The answer is 42"), - SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])) + SimpleNamespace(type="response.completed", response=SimpleNamespace(status="completed", output=[])), ] mock_client.responses.create.return_value = iter(events) @@ -527,7 +495,7 @@ def test_create_response_stream_yields_reasoning_summaries(self, mock_openai: Mo self.assertEqual(len(reasoning_events), 1) self.assertIn("analyze this problem", reasoning_events[0]["text"]) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_multiple_tool_calls(self, mock_openai: Mock) -> None: """Test handling multiple tool calls in a single response.""" mock_client = MagicMock() @@ -539,25 +507,29 @@ def test_create_response_stream_multiple_tool_calls(self, mock_openai: Mock) -> output_index=0, call_id="call_1", name="create_point", - delta='{"x": 1}' + delta='{"x": 1}', ), SimpleNamespace( type="response.function_call_arguments.delta", output_index=1, call_id="call_2", name="create_circle", - delta='{"r": 5}' + delta='{"r": 5}', ), SimpleNamespace( type="response.completed", response=SimpleNamespace( status="requires_action", output=[ - SimpleNamespace(type="function_call", call_id="call_1", name="create_point", arguments='{"x": 1}'), - SimpleNamespace(type="function_call", call_id="call_2", name="create_circle", arguments='{"r": 5}') - ] - ) - ) + SimpleNamespace( + type="function_call", call_id="call_1", name="create_point", arguments='{"x": 1}' + ), + SimpleNamespace( + type="function_call", call_id="call_2", name="create_circle", arguments='{"r": 5}' + ), + ], + ), + ), ] mock_client.responses.create.return_value = iter(events) @@ -571,7 +543,7 @@ def test_create_response_stream_multiple_tool_calls(self, mock_openai: Mock) -> self.assertEqual(final_event["ai_tool_calls"][0]["function_name"], "create_point") self.assertEqual(final_event["ai_tool_calls"][1]["function_name"], "create_circle") - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_create_response_stream_with_tool_calls(self, mock_openai: Mock) -> None: """Test create_response_stream handles function calls.""" mock_client = MagicMock() @@ -585,7 +557,7 @@ def test_create_response_stream_with_tool_calls(self, mock_openai: Mock) -> None output_index=0, call_id="call_123", name="create_point", - delta='{"x": 5, "y": 10}' + delta='{"x": 5, "y": 10}', ), SimpleNamespace( type="response.completed", @@ -593,14 +565,11 @@ def test_create_response_stream_with_tool_calls(self, mock_openai: Mock) -> None status="tool_calls", output=[ SimpleNamespace( - type="function_call", - call_id="call_123", - name="create_point", - arguments='{"x": 5, "y": 10}' + type="function_call", call_id="call_123", name="create_point", arguments='{"x": 5, "y": 10}' ) - ] - ) - ) + ], + ), + ), ] mock_client.responses.create.return_value = iter(events) @@ -627,8 +596,8 @@ class TestOpenAIResponsesAPIIntegration(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Check if API key is available for integration tests.""" - cls.api_key = os.environ.get('OPENAI_API_KEY') - if not cls.api_key or cls.api_key == 'test-api-key': + cls.api_key = os.environ.get("OPENAI_API_KEY") + if not cls.api_key or cls.api_key == "test-api-key": cls.skip_integration = True else: cls.skip_integration = False @@ -645,10 +614,7 @@ def test_integration_response_stream_format(self) -> None: api.set_model("o4-mini") api.max_tokens = 20 - prompt = json.dumps({ - "user_message": "Say: OK", - "use_vision": False - }) + prompt = json.dumps({"user_message": "Say: OK", "use_vision": False}) events = list(api.create_response_stream(prompt)) @@ -671,10 +637,7 @@ def test_integration_reasoning_tokens(self) -> None: api.set_model("o4-mini") api.max_tokens = 30 - prompt = json.dumps({ - "user_message": "2+2=?", - "use_vision": False - }) + prompt = json.dumps({"user_message": "2+2=?", "use_vision": False}) events = list(api.create_response_stream(prompt)) @@ -688,10 +651,7 @@ def test_integration_reasoning_tokens(self) -> None: # Some org/account states may return an incomplete response with no text # (for example, when reasoning summaries are not available yet). # In those cases, we still validate a well-formed final event. - total_content = "".join( - [e.get("text", "") for e in token_events] + - [final_events[0].get("ai_message", "")] - ) + total_content = "".join([e.get("text", "") for e in token_events] + [final_events[0].get("ai_message", "")]) finish_reason = final_events[0].get("finish_reason", "") self.assertIsInstance(finish_reason, str) if finish_reason == "error": @@ -703,10 +663,7 @@ def test_integration_event_types_are_valid(self) -> None: api.set_model("o4-mini") api.max_tokens = 10 - prompt = json.dumps({ - "user_message": "1", - "use_vision": False - }) + prompt = json.dumps({"user_message": "1", "use_vision": False}) events = list(api.create_response_stream(prompt)) @@ -721,17 +678,17 @@ class TestOpenAIResponsesAPIModelRouting(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" - self.original_api_key = os.environ.get('OPENAI_API_KEY') - os.environ['OPENAI_API_KEY'] = 'test-api-key' + self.original_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "test-api-key" def tearDown(self) -> None: """Clean up after tests.""" if self.original_api_key: - os.environ['OPENAI_API_KEY'] = self.original_api_key + os.environ["OPENAI_API_KEY"] = self.original_api_key else: - os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop("OPENAI_API_KEY", None) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_to_reasoning(self, mock_openai: Mock) -> None: """Test setting model to a reasoning model.""" api = OpenAIResponsesAPI() @@ -740,7 +697,7 @@ def test_set_model_to_reasoning(self, mock_openai: Mock) -> None: self.assertEqual(api.model.id, "o3") self.assertTrue(api.model.is_reasoning_model) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_to_o4_mini(self, mock_openai: Mock) -> None: """Test setting model to o4-mini.""" api = OpenAIResponsesAPI() @@ -750,7 +707,7 @@ def test_set_model_to_o4_mini(self, mock_openai: Mock) -> None: self.assertTrue(api.model.is_reasoning_model) self.assertTrue(api.model.has_vision) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_to_gpt5(self, mock_openai: Mock) -> None: """Test setting model to GPT-5-chat-latest.""" api = OpenAIResponsesAPI() @@ -760,7 +717,7 @@ def test_set_model_to_gpt5(self, mock_openai: Mock) -> None: self.assertTrue(api.model.is_reasoning_model) self.assertTrue(api.model.has_vision) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_to_gpt52_chat_latest(self, mock_openai: Mock) -> None: """Test setting model to GPT-5.2-chat-latest.""" api = OpenAIResponsesAPI() @@ -770,7 +727,7 @@ def test_set_model_to_gpt52_chat_latest(self, mock_openai: Mock) -> None: self.assertTrue(api.model.is_reasoning_model) self.assertTrue(api.model.has_vision) - @patch('static.openai_api_base.OpenAI') + @patch("static.openai_api_base.OpenAI") def test_set_model_to_gpt52_medium_reasoning(self, mock_openai: Mock) -> None: """Test setting model to GPT-5.2 with medium reasoning effort.""" api = OpenAIResponsesAPI() @@ -782,5 +739,5 @@ def test_set_model_to_gpt52_medium_reasoning(self, mock_openai: Mock) -> None: self.assertEqual(api.model.reasoning_effort, "medium") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_plot_tool_schemas.py b/server_tests/test_plot_tool_schemas.py index a704d805..52339d84 100644 --- a/server_tests/test_plot_tool_schemas.py +++ b/server_tests/test_plot_tool_schemas.py @@ -102,16 +102,24 @@ def test_plot_distribution_schema(self) -> None: self.assertEqual(plot_bounds.get("additionalProperties"), False) self.assertEqual(plot_bounds.get("required"), ["left_bound", "right_bound"]) pb_props = _require_dict(plot_bounds.get("properties"), "plot_bounds.properties") - self.assertEqual(_require_dict(pb_props.get("left_bound"), "plot_bounds.left_bound").get("type"), ["number", "null"]) - self.assertEqual(_require_dict(pb_props.get("right_bound"), "plot_bounds.right_bound").get("type"), ["number", "null"]) + self.assertEqual( + _require_dict(pb_props.get("left_bound"), "plot_bounds.left_bound").get("type"), ["number", "null"] + ) + self.assertEqual( + _require_dict(pb_props.get("right_bound"), "plot_bounds.right_bound").get("type"), ["number", "null"] + ) shade_bounds = _require_dict(props.get("shade_bounds"), "shade_bounds") self.assertEqual(shade_bounds.get("type"), ["object", "null"]) self.assertEqual(shade_bounds.get("additionalProperties"), False) self.assertEqual(shade_bounds.get("required"), ["left_bound", "right_bound"]) sb_props = _require_dict(shade_bounds.get("properties"), "shade_bounds.properties") - self.assertEqual(_require_dict(sb_props.get("left_bound"), "shade_bounds.left_bound").get("type"), ["number", "null"]) - self.assertEqual(_require_dict(sb_props.get("right_bound"), "shade_bounds.right_bound").get("type"), ["number", "null"]) + self.assertEqual( + _require_dict(sb_props.get("left_bound"), "shade_bounds.left_bound").get("type"), ["number", "null"] + ) + self.assertEqual( + _require_dict(sb_props.get("right_bound"), "shade_bounds.right_bound").get("type"), ["number", "null"] + ) def test_plot_bars_schema(self) -> None: tool = _find_tool("plot_bars") @@ -177,5 +185,3 @@ def test_delete_plot_schema(self) -> None: name = _require_dict(props.get("name"), "delete_plot.name") self.assertEqual(name.get("type"), "string") - - diff --git a/server_tests/test_polar_conversion.py b/server_tests/test_polar_conversion.py index a4d5246e..c458d571 100644 --- a/server_tests/test_polar_conversion.py +++ b/server_tests/test_polar_conversion.py @@ -274,5 +274,5 @@ def test_floating_point_precision(self) -> None: self.assertAlmostEqual(y, y_back, places=10) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_polar_grid.py b/server_tests/test_polar_grid.py index 1e74508e..73eb1e9a 100644 --- a/server_tests/test_polar_grid.py +++ b/server_tests/test_polar_grid.py @@ -26,6 +26,7 @@ class Position: """Simple Position class for testing.""" + def __init__(self, x: float = 0, y: float = 0): self.x = x self.y = y @@ -76,13 +77,10 @@ def max_radius_screen(self) -> float: if self.width is None or self.height is None: return 0 ox, oy = self.origin_screen - corners = [ - (0, 0), (self.width, 0), - (0, self.height), (self.width, self.height) - ] + corners = [(0, 0), (self.width, 0), (0, self.height), (self.width, self.height)] max_dist = 0 for cx, cy in corners: - dist = math.sqrt((cx - ox)**2 + (cy - oy)**2) + dist = math.sqrt((cx - ox) ** 2 + (cy - oy) ** 2) max_dist = max(max_dist, dist) return max_dist * 1.1 @@ -240,7 +238,7 @@ def test_max_radius_screen_calculation(self) -> None: max_radius = grid.max_radius_screen # Should be at least the diagonal from center to corner - expected_min = math.sqrt((400)**2 + (300)**2) + expected_min = math.sqrt((400) ** 2 + (300) ** 2) self.assertGreaterEqual(max_radius, expected_min) def test_max_radius_math_calculation(self) -> None: @@ -350,7 +348,7 @@ def test_get_radial_circles(self) -> None: if len(circles) >= 2: spacing = circles[1] - circles[0] for i in range(2, len(circles)): - self.assertAlmostEqual(circles[i] - circles[i-1], spacing, places=6) + self.assertAlmostEqual(circles[i] - circles[i - 1], spacing, places=6) class TestPolarGridOrigin(unittest.TestCase): @@ -371,10 +369,7 @@ def test_origin_screen_default(self) -> None: def test_origin_screen_with_offset(self) -> None: """Test origin screen position with pan offset.""" - mapper = create_mock_coordinate_mapper( - canvas_width=800, canvas_height=600, - offset_x=50, offset_y=-30 - ) + mapper = create_mock_coordinate_mapper(canvas_width=800, canvas_height=600, offset_x=50, offset_y=-30) grid = PolarGrid(mapper) grid.width = 800 grid.height = 600 @@ -392,8 +387,9 @@ class TestPolarGridState(unittest.TestCase): def test_get_state(self) -> None: """Test state serialization.""" mapper = create_mock_coordinate_mapper() - grid = PolarGrid(mapper, angular_divisions=8, radial_spacing=2.0, - show_angle_labels=False, show_radius_labels=True) + grid = PolarGrid( + mapper, angular_divisions=8, radial_spacing=2.0, show_angle_labels=False, show_radius_labels=True + ) state = grid.get_state() @@ -528,8 +524,8 @@ def test_very_large_dimensions(self) -> None: max_radius = grid.max_radius_screen self.assertGreater(max_radius, 0) - self.assertLess(max_radius, float('inf')) + self.assertLess(max_radius, float("inf")) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_position.py b/server_tests/test_position.py index c71728c1..4f8fa984 100644 --- a/server_tests/test_position.py +++ b/server_tests/test_position.py @@ -15,7 +15,7 @@ def test_init(self) -> None: def test_str(self) -> None: pos = Position(1, 2) - self.assertEqual(str(pos), 'Position: 1, 2') + self.assertEqual(str(pos), "Position: 1, 2") def test_get_state(self) -> None: pos = Position(1, 2) diff --git a/server_tests/test_provider_connections.py b/server_tests/test_provider_connections.py index 219d0700..6c82f952 100644 --- a/server_tests/test_provider_connections.py +++ b/server_tests/test_provider_connections.py @@ -43,16 +43,19 @@ def has_openrouter_key() -> bool: def is_error_response(content: str) -> bool: """Check if the response indicates an error.""" lower = content.lower() - return any(phrase in lower for phrase in [ - "encountered an error", - "credit balance", - "billing", - "insufficient", - "purchase credits", - "quota exceeded", - "rate limit", - "try again", - ]) + return any( + phrase in lower + for phrase in [ + "encountered an error", + "credit balance", + "billing", + "insufficient", + "purchase credits", + "quota exceeded", + "rate limit", + "try again", + ] + ) class TestOpenAIConnection(unittest.TestCase): diff --git a/server_tests/test_regression_pure.py b/server_tests/test_regression_pure.py index 2eef371b..c3b39cbe 100644 --- a/server_tests/test_regression_pure.py +++ b/server_tests/test_regression_pure.py @@ -404,7 +404,7 @@ def test_sinusoidal_fit(self) -> None: def test_sinusoidal_with_phase_shift(self) -> None: # y = sin(x + pi/4) x = [i * 0.5 for i in range(20)] - y = [math.sin(xi + math.pi/4) for xi in x] + y = [math.sin(xi + math.pi / 4) for xi in x] result = fit_sinusoidal(x, y) self.assertGreater(result["r_squared"], 0.9) @@ -421,8 +421,7 @@ class TestFitRegressionDispatcher(unittest.TestCase): """Tests for the main fit_regression dispatcher function.""" def test_supported_model_types(self) -> None: - expected = ("linear", "polynomial", "exponential", "logarithmic", - "power", "logistic", "sinusoidal") + expected = ("linear", "polynomial", "exponential", "logarithmic", "power", "logistic", "sinusoidal") self.assertEqual(SUPPORTED_MODEL_TYPES, expected) def test_dispatch_linear(self) -> None: diff --git a/server_tests/test_routes.py b/server_tests/test_routes.py index 7f281a39..e765f640 100644 --- a/server_tests/test_routes.py +++ b/server_tests/test_routes.py @@ -13,20 +13,17 @@ class TestRoutes(unittest.TestCase): - SAMPLE_PNG_BASE64 = ( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8" - "/w8AAwMB/aqVw0sAAAAASUVORK5CYII=" - ) + SAMPLE_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/w8AAwMB/aqVw0sAAAAASUVORK5CYII=" def setUp(self) -> None: """Set up test client before each test.""" # Set test environment variables to disable authentication - self.original_require_auth: Optional[str] = os.environ.get('REQUIRE_AUTH') - os.environ['REQUIRE_AUTH'] = 'false' + self.original_require_auth: Optional[str] = os.environ.get("REQUIRE_AUTH") + os.environ["REQUIRE_AUTH"] = "false" self.app: MatHudFlask = AppManager.create_app() self.client = self.app.test_client() - self.app.config['TESTING'] = True + self.app.config["TESTING"] = True # Ensure webdriver_manager is None at start self.app.webdriver_manager = None self._remove_canvas_snapshot() @@ -34,15 +31,15 @@ def setUp(self) -> None: def tearDown(self) -> None: """Clean up after each test.""" # Clean up webdriver if it exists - if hasattr(self.app, 'webdriver_manager') and self.app.webdriver_manager: + if hasattr(self.app, "webdriver_manager") and self.app.webdriver_manager: self.app.webdriver_manager = None self._remove_canvas_snapshot() # Restore original REQUIRE_AUTH environment variable if self.original_require_auth is not None: - os.environ['REQUIRE_AUTH'] = self.original_require_auth + os.environ["REQUIRE_AUTH"] = self.original_require_auth else: - os.environ.pop('REQUIRE_AUTH', None) + os.environ.pop("REQUIRE_AUTH", None) def _remove_canvas_snapshot(self) -> None: if os.path.exists(CANVAS_SNAPSHOT_PATH): @@ -51,66 +48,66 @@ def _remove_canvas_snapshot(self) -> None: except OSError: pass - @patch('static.webdriver_manager.WebDriverManager') + @patch("static.webdriver_manager.WebDriverManager") def test_init_webdriver_route(self, mock_webdriver_class: Mock) -> None: """Test the webdriver initialization route.""" # Create a mock instance mock_instance = Mock() mock_webdriver_class.return_value = mock_instance - response = self.client.get('/init_webdriver') + response = self.client.get("/init_webdriver") data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') - self.assertEqual(data['message'], 'WebDriver initialization successful') + self.assertEqual(data["status"], "success") + self.assertEqual(data["message"], "WebDriver initialization successful") mock_webdriver_class.assert_called_once() self.assertEqual(self.app.webdriver_manager, mock_instance) def test_index_route(self) -> None: """Test the index route returns HTML.""" - response = self.client.get('/') + response = self.client.get("/") self.assertEqual(response.status_code, 200) - self.assertIn('text/html', response.content_type) + self.assertIn("text/html", response.content_type) def test_workspace_operations(self) -> None: """Test workspace CRUD operations.""" # Test creating a workspace - test_state = {'test': 'data'} - response = self.client.post('/save_workspace', - json={'state': test_state, 'name': 'test_workspace'}) + test_state = {"test": "data"} + response = self.client.post("/save_workspace", json={"state": test_state, "name": "test_workspace"}) data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') + self.assertEqual(data["status"], "success") # Test listing workspaces - response = self.client.get('/list_workspaces') + response = self.client.get("/list_workspaces") data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') - self.assertIn('test_workspace', data['data']) + self.assertEqual(data["status"], "success") + self.assertIn("test_workspace", data["data"]) # Test loading workspace - response = self.client.get('/load_workspace?name=test_workspace') + response = self.client.get("/load_workspace?name=test_workspace") data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') - self.assertEqual(data['data']['state'], test_state) + self.assertEqual(data["status"], "success") + self.assertEqual(data["data"]["state"], test_state) # Test deleting workspace - response = self.client.get('/delete_workspace?name=test_workspace') + response = self.client.get("/delete_workspace?name=test_workspace") data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') + self.assertEqual(data["status"], "success") # Verify workspace is deleted - response = self.client.get('/list_workspaces') + response = self.client.get("/list_workspaces") data = json.loads(response.data) - self.assertNotIn('test_workspace', data['data']) + self.assertNotIn("test_workspace", data["data"]) - @patch('static.openai_completions_api.OpenAIChatCompletionsAPI.create_chat_completion') + @patch("static.openai_completions_api.OpenAIChatCompletionsAPI.create_chat_completion") def test_send_message(self, mock_chat: Mock) -> None: """Test the send_message route.""" + # Configure mock response class MockMessage: content = "Test response" @@ -122,19 +119,13 @@ class MockResponse: mock_chat.return_value = MockResponse() - test_message = { - 'message': json.dumps({ - 'user_message': 'test message', - 'use_vision': False - }), - 'svg_state': None - } - response = self.client.post('/send_message', json=test_message) + test_message = {"message": json.dumps({"user_message": "test message", "use_vision": False}), "svg_state": None} + response = self.client.post("/send_message", json=test_message) data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') - self.assertIn('ai_message', data['data']) - self.assertIn('ai_tool_calls', data['data']) + self.assertEqual(data["status"], "success") + self.assertIn("ai_message", data["data"]) + self.assertIn("ai_tool_calls", data["data"]) def test_save_canvas_snapshot_helper(self) -> None: data_url = f"data:image/png;base64,{self.SAMPLE_PNG_BASE64}" @@ -143,7 +134,7 @@ def test_save_canvas_snapshot_helper(self) -> None: self.assertTrue(os.path.exists(CANVAS_SNAPSHOT_PATH)) self.assertGreater(os.path.getsize(CANVAS_SNAPSHOT_PATH), 0) - @patch('static.openai_completions_api.OpenAIChatCompletionsAPI.create_chat_completion') + @patch("static.openai_completions_api.OpenAIChatCompletionsAPI.create_chat_completion") def test_send_message_uses_canvas_snapshot(self, mock_chat: Mock) -> None: class MockMessage: content = "Test response with vision" @@ -157,21 +148,15 @@ class MockResponse: canvas_data_url = f"data:image/png;base64,{self.SAMPLE_PNG_BASE64}" payload = { - 'message': json.dumps({ - 'user_message': 'vision request', - 'use_vision': True - }), - 'vision_snapshot': { - 'renderer_mode': 'canvas2d', - 'canvas_image': canvas_data_url - } + "message": json.dumps({"user_message": "vision request", "use_vision": True}), + "vision_snapshot": {"renderer_mode": "canvas2d", "canvas_image": canvas_data_url}, } - response = self.client.post('/send_message', json=payload) + response = self.client.post("/send_message", json=payload) data = json.loads(response.data) self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') + self.assertEqual(data["status"], "success") self.assertTrue(os.path.exists(CANVAS_SNAPSHOT_PATH)) self.assertGreater(os.path.getsize(CANVAS_SNAPSHOT_PATH), 0) self.assertIsNone(self.app.webdriver_manager) @@ -189,13 +174,13 @@ def test_new_conversation_route(self) -> None: self.assertGreater(len(self.app.responses_api.messages), 1) # Now, call the new_conversation route - response = self.client.post('/new_conversation') + response = self.client.post("/new_conversation") data = json.loads(response.data) # Check that the response is successful self.assertEqual(response.status_code, 200) - self.assertEqual(data['status'], 'success') - self.assertEqual(data['message'], 'New conversation started.') + self.assertEqual(data["status"], "success") + self.assertEqual(data["message"], "New conversation started.") # Check that the conversation history has been reset self.assertEqual(len(self.app.ai_api.messages), 1) @@ -221,7 +206,7 @@ def test_debug_canvas_state_comparison_returns_metrics(self) -> None: } } - response = self.client.post('/api/debug/canvas-state-comparison', json=payload) + response = self.client.post("/api/debug/canvas-state-comparison", json=payload) data = json.loads(response.data) self.assertEqual(response.status_code, 200) @@ -235,11 +220,12 @@ def test_debug_canvas_state_comparison_returns_metrics(self) -> None: self.assertNotIn("_p1_coords", summary_segments[0]) def test_debug_canvas_state_comparison_is_disabled_when_deployed(self) -> None: - with patch("static.routes.AppManager.requires_auth", return_value=False), patch( - "static.routes.AppManager.is_deployed", return_value=True + with ( + patch("static.routes.AppManager.requires_auth", return_value=False), + patch("static.routes.AppManager.is_deployed", return_value=True), ): response = self.client.post( - '/api/debug/canvas-state-comparison', + "/api/debug/canvas-state-comparison", json={"canvas_state": {"Points": []}}, ) @@ -250,23 +236,23 @@ def test_debug_canvas_state_comparison_is_disabled_when_deployed(self) -> None: def test_error_handling(self) -> None: """Test error handling in routes.""" # Test invalid workspace name - response = self.client.get('/load_workspace?name=nonexistent') + response = self.client.get("/load_workspace?name=nonexistent") data = json.loads(response.data) self.assertEqual(response.status_code, 404) - self.assertEqual(data['status'], 'error') + self.assertEqual(data["status"], "error") # Test missing workspace name - response = self.client.get('/delete_workspace') + response = self.client.get("/delete_workspace") data = json.loads(response.data) self.assertEqual(response.status_code, 400) - self.assertEqual(data['status'], 'error') + self.assertEqual(data["status"], "error") # Test missing message - response = self.client.post('/send_message', json={'invalid': 'format'}) + response = self.client.post("/send_message", json={"invalid": "format"}) data = json.loads(response.data) self.assertEqual(response.status_code, 400) # Changed from 500 to 400 for invalid request - self.assertEqual(data['status'], 'error') - self.assertIn('message', data['message'].lower()) # Error message should mention 'message' + self.assertEqual(data["status"], "error") + self.assertIn("message", data["message"].lower()) # Error message should mention 'message' class TestAPIRouting(unittest.TestCase): @@ -274,19 +260,19 @@ class TestAPIRouting(unittest.TestCase): def setUp(self) -> None: """Set up test client before each test.""" - self.original_require_auth: Optional[str] = os.environ.get('REQUIRE_AUTH') - os.environ['REQUIRE_AUTH'] = 'false' + self.original_require_auth: Optional[str] = os.environ.get("REQUIRE_AUTH") + os.environ["REQUIRE_AUTH"] = "false" self.app: MatHudFlask = AppManager.create_app() self.client = self.app.test_client() - self.app.config['TESTING'] = True + self.app.config["TESTING"] = True def tearDown(self) -> None: """Clean up after each test.""" if self.original_require_auth is not None: - os.environ['REQUIRE_AUTH'] = self.original_require_auth + os.environ["REQUIRE_AUTH"] = self.original_require_auth else: - os.environ.pop('REQUIRE_AUTH', None) + os.environ.pop("REQUIRE_AUTH", None) def test_app_has_both_apis(self) -> None: """Test that app manager initializes both API instances.""" @@ -320,53 +306,61 @@ def test_reasoning_model_identified(self) -> None: self.assertEqual(self.app.responses_api.get_model().id, "o3") self.assertTrue(self.app.ai_api.get_model().is_reasoning_model) - @patch.object(OpenAIChatCompletionsAPI, 'create_chat_completion_stream') + @patch.object(OpenAIChatCompletionsAPI, "create_chat_completion_stream") def test_stream_route_uses_completions_for_standard_model(self, mock_stream: Mock) -> None: """Test /send_message_stream uses Chat Completions for standard models.""" # Configure mock to return an iterator - mock_stream.return_value = iter([ - {"type": "token", "text": "Hello"}, - {"type": "final", "ai_message": "Hello", "ai_tool_calls": [], "finish_reason": "stop"} - ]) + mock_stream.return_value = iter( + [ + {"type": "token", "text": "Hello"}, + {"type": "final", "ai_message": "Hello", "ai_tool_calls": [], "finish_reason": "stop"}, + ] + ) test_message = { - 'message': json.dumps({ - 'user_message': 'test', - 'use_vision': False, - 'ai_model': 'gpt-4o-mini' # Standard model - }), - 'svg_state': None + "message": json.dumps( + { + "user_message": "test", + "use_vision": False, + "ai_model": "gpt-4o-mini", # Standard model + } + ), + "svg_state": None, } - self.client.post('/send_message_stream', json=test_message) + self.client.post("/send_message_stream", json=test_message) # Check that Chat Completions API was called (not Responses API) mock_stream.assert_called_once() - @patch.object(OpenAIResponsesAPI, 'create_response_stream') - @patch.object(OpenAIChatCompletionsAPI, 'set_model') + @patch.object(OpenAIResponsesAPI, "create_response_stream") + @patch.object(OpenAIChatCompletionsAPI, "set_model") def test_stream_route_uses_responses_for_reasoning_model(self, mock_set_model: Mock, mock_stream: Mock) -> None: """Test /send_message_stream uses Responses API for reasoning models.""" # Configure mock to return an iterator - mock_stream.return_value = iter([ - {"type": "reasoning", "text": "Thinking..."}, - {"type": "token", "text": "Answer"}, - {"type": "final", "ai_message": "Answer", "ai_tool_calls": [], "finish_reason": "stop"} - ]) + mock_stream.return_value = iter( + [ + {"type": "reasoning", "text": "Thinking..."}, + {"type": "token", "text": "Answer"}, + {"type": "final", "ai_message": "Answer", "ai_tool_calls": [], "finish_reason": "stop"}, + ] + ) # Pre-set the model to a reasoning model self.app.ai_api.model = self.app.ai_api.model.from_identifier("o3") test_message = { - 'message': json.dumps({ - 'user_message': 'test', - 'use_vision': False, - 'ai_model': 'o3' # Reasoning model - }), - 'svg_state': None + "message": json.dumps( + { + "user_message": "test", + "use_vision": False, + "ai_model": "o3", # Reasoning model + } + ), + "svg_state": None, } - self.client.post('/send_message_stream', json=test_message) + self.client.post("/send_message_stream", json=test_message) # Check that Responses API was called mock_stream.assert_called_once() @@ -392,44 +386,42 @@ class TestStreamingResponseFormat(unittest.TestCase): def setUp(self) -> None: """Set up test client before each test.""" - self.original_require_auth: Optional[str] = os.environ.get('REQUIRE_AUTH') - os.environ['REQUIRE_AUTH'] = 'false' + self.original_require_auth: Optional[str] = os.environ.get("REQUIRE_AUTH") + os.environ["REQUIRE_AUTH"] = "false" self.app: MatHudFlask = AppManager.create_app() self.client = self.app.test_client() - self.app.config['TESTING'] = True + self.app.config["TESTING"] = True def tearDown(self) -> None: """Clean up after each test.""" if self.original_require_auth is not None: - os.environ['REQUIRE_AUTH'] = self.original_require_auth + os.environ["REQUIRE_AUTH"] = self.original_require_auth else: - os.environ.pop('REQUIRE_AUTH', None) + os.environ.pop("REQUIRE_AUTH", None) - @patch.object(OpenAIChatCompletionsAPI, 'create_chat_completion_stream') + @patch.object(OpenAIChatCompletionsAPI, "create_chat_completion_stream") def test_stream_response_ndjson_format(self, mock_stream: Mock) -> None: """Test streaming response is in NDJSON format.""" - mock_stream.return_value = iter([ - {"type": "token", "text": "Hello"}, - {"type": "final", "ai_message": "Hello", "ai_tool_calls": [], "finish_reason": "stop"} - ]) + mock_stream.return_value = iter( + [ + {"type": "token", "text": "Hello"}, + {"type": "final", "ai_message": "Hello", "ai_tool_calls": [], "finish_reason": "stop"}, + ] + ) test_message = { - 'message': json.dumps({ - 'user_message': 'test', - 'use_vision': False, - 'ai_model': 'gpt-4o-mini' - }), - 'svg_state': None + "message": json.dumps({"user_message": "test", "use_vision": False, "ai_model": "gpt-4o-mini"}), + "svg_state": None, } - response = self.client.post('/send_message_stream', json=test_message) + response = self.client.post("/send_message_stream", json=test_message) # Check content type - self.assertEqual(response.content_type, 'application/x-ndjson') + self.assertEqual(response.content_type, "application/x-ndjson") # Parse NDJSON response - lines = response.data.decode('utf-8').strip().split('\n') + lines = response.data.decode("utf-8").strip().split("\n") events = [json.loads(line) for line in lines if line.strip()] # Should have token and final events @@ -441,32 +433,30 @@ def test_stream_response_ndjson_format(self, mock_stream: Mock) -> None: self.assertIn("ai_tool_calls", events[-1]) self.assertIn("finish_reason", events[-1]) - @patch.object(OpenAIResponsesAPI, 'create_response_stream') - @patch.object(OpenAIChatCompletionsAPI, 'set_model') + @patch.object(OpenAIResponsesAPI, "create_response_stream") + @patch.object(OpenAIChatCompletionsAPI, "set_model") def test_reasoning_stream_includes_reasoning_events(self, mock_set_model: Mock, mock_stream: Mock) -> None: """Test reasoning model stream includes reasoning events.""" - mock_stream.return_value = iter([ - {"type": "reasoning", "text": "Let me think..."}, - {"type": "token", "text": "The answer is"}, - {"type": "final", "ai_message": "The answer is", "ai_tool_calls": [], "finish_reason": "stop"} - ]) + mock_stream.return_value = iter( + [ + {"type": "reasoning", "text": "Let me think..."}, + {"type": "token", "text": "The answer is"}, + {"type": "final", "ai_message": "The answer is", "ai_tool_calls": [], "finish_reason": "stop"}, + ] + ) # Pre-set the model to a reasoning model self.app.ai_api.model = self.app.ai_api.model.from_identifier("o3") test_message = { - 'message': json.dumps({ - 'user_message': 'test', - 'use_vision': False, - 'ai_model': 'o3' - }), - 'svg_state': None + "message": json.dumps({"user_message": "test", "use_vision": False, "ai_model": "o3"}), + "svg_state": None, } - response = self.client.post('/send_message_stream', json=test_message) + response = self.client.post("/send_message_stream", json=test_message) # Parse NDJSON response - lines = response.data.decode('utf-8').strip().split('\n') + lines = response.data.decode("utf-8").strip().split("\n") events = [json.loads(line) for line in lines if line.strip()] # Should have reasoning event @@ -480,26 +470,26 @@ class TestInterceptSearchTools(unittest.TestCase): def setUp(self) -> None: """Set up test client before each test.""" - self.original_require_auth: Optional[str] = os.environ.get('REQUIRE_AUTH') - os.environ['REQUIRE_AUTH'] = 'false' + self.original_require_auth: Optional[str] = os.environ.get("REQUIRE_AUTH") + os.environ["REQUIRE_AUTH"] = "false" self.app: MatHudFlask = AppManager.create_app() - self.app.config['TESTING'] = True + self.app.config["TESTING"] = True def tearDown(self) -> None: """Clean up after each test.""" if self.original_require_auth is not None: - os.environ['REQUIRE_AUTH'] = self.original_require_auth + os.environ["REQUIRE_AUTH"] = self.original_require_auth else: - os.environ.pop('REQUIRE_AUTH', None) + os.environ.pop("REQUIRE_AUTH", None) def test_no_search_tools_returns_unchanged(self) -> None: """When no search_tools call, tool_calls should be returned unchanged.""" from static.routes import _intercept_search_tools tool_calls = [ - {'function_name': 'create_circle', 'arguments': {'x': 0, 'y': 0}}, - {'function_name': 'create_point', 'arguments': {'x': 10, 'y': 20}}, + {"function_name": "create_circle", "arguments": {"x": 0, "y": 0}}, + {"function_name": "create_point", "arguments": {"x": 10, "y": 20}}, ] result = _intercept_search_tools(self.app, tool_calls) @@ -515,7 +505,7 @@ def test_empty_tool_calls_returns_empty(self) -> None: self.assertEqual(result, []) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_filters_disallowed_tools(self, mock_service_class: Mock) -> None: """Tools not in search results should be filtered out.""" from static.routes import _intercept_search_tools @@ -523,26 +513,26 @@ def test_filters_disallowed_tools(self, mock_service_class: Mock) -> None: # Mock search_tools to return only create_circle mock_service = Mock() mock_service.search_tools.return_value = [ - {'function': {'name': 'create_circle'}}, + {"function": {"name": "create_circle"}}, ] mock_service_class.return_value = mock_service tool_calls = [ - {'function_name': 'search_tools', 'arguments': {'query': 'draw circle'}}, - {'function_name': 'create_circle', 'arguments': {'x': 0, 'y': 0}}, - {'function_name': 'delete_all', 'arguments': {}}, # Not in search results + {"function_name": "search_tools", "arguments": {"query": "draw circle"}}, + {"function_name": "create_circle", "arguments": {"x": 0, "y": 0}}, + {"function_name": "delete_all", "arguments": {}}, # Not in search results ] result = _intercept_search_tools(self.app, tool_calls) # search_tools and create_circle should be allowed (essentials include search_tools) # delete_all should be filtered out - names = [c.get('function_name') for c in result] - self.assertIn('search_tools', names) - self.assertIn('create_circle', names) - self.assertNotIn('delete_all', names) + names = [c.get("function_name") for c in result] + self.assertIn("search_tools", names) + self.assertIn("create_circle", names) + self.assertNotIn("delete_all", names) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_essential_tools_always_allowed(self, mock_service_class: Mock) -> None: """Essential tools should always be allowed even if not in search results.""" from static.routes import _intercept_search_tools @@ -550,29 +540,29 @@ def test_essential_tools_always_allowed(self, mock_service_class: Mock) -> None: # Mock search_tools to return only create_circle (no essentials) mock_service = Mock() mock_service.search_tools.return_value = [ - {'function': {'name': 'create_circle'}}, + {"function": {"name": "create_circle"}}, ] mock_service_class.return_value = mock_service tool_calls = [ - {'function_name': 'search_tools', 'arguments': {'query': 'draw'}}, - {'function_name': 'undo', 'arguments': {}}, # Essential tool - {'function_name': 'create_circle', 'arguments': {}}, + {"function_name": "search_tools", "arguments": {"query": "draw"}}, + {"function_name": "undo", "arguments": {}}, # Essential tool + {"function_name": "create_circle", "arguments": {}}, ] result = _intercept_search_tools(self.app, tool_calls) - names = [c.get('function_name') for c in result] - self.assertIn('undo', names) # Essential should be allowed - self.assertIn('create_circle', names) + names = [c.get("function_name") for c in result] + self.assertIn("undo", names) # Essential should be allowed + self.assertIn("create_circle", names) def test_empty_query_returns_unchanged(self) -> None: """search_tools with empty query should return tool_calls unchanged.""" from static.routes import _intercept_search_tools tool_calls = [ - {'function_name': 'search_tools', 'arguments': {'query': ''}}, - {'function_name': 'create_circle', 'arguments': {}}, + {"function_name": "search_tools", "arguments": {"query": ""}}, + {"function_name": "create_circle", "arguments": {}}, ] result = _intercept_search_tools(self.app, tool_calls) @@ -580,7 +570,7 @@ def test_empty_query_returns_unchanged(self) -> None: # Should return unchanged because query is empty self.assertEqual(result, tool_calls) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_service_error_returns_unchanged(self, mock_service_class: Mock) -> None: """On ToolSearchService error, original tool_calls should be returned.""" from static.routes import _intercept_search_tools @@ -590,9 +580,9 @@ def test_service_error_returns_unchanged(self, mock_service_class: Mock) -> None mock_service_class.return_value = mock_service tool_calls = [ - {'function_name': 'search_tools', 'arguments': {'query': 'draw'}}, - {'function_name': 'create_circle', 'arguments': {}}, - {'function_name': 'delete_all', 'arguments': {}}, + {"function_name": "search_tools", "arguments": {"query": "draw"}}, + {"function_name": "create_circle", "arguments": {}}, + {"function_name": "delete_all", "arguments": {}}, ] result = _intercept_search_tools(self.app, tool_calls) @@ -600,76 +590,74 @@ def test_service_error_returns_unchanged(self, mock_service_class: Mock) -> None # On error, should return original calls unchanged self.assertEqual(result, tool_calls) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_handles_json_string_arguments(self, mock_service_class: Mock) -> None: """Should handle arguments as JSON string (from some API responses).""" from static.routes import _intercept_search_tools mock_service = Mock() mock_service.search_tools.return_value = [ - {'function': {'name': 'create_point'}}, + {"function": {"name": "create_point"}}, ] mock_service_class.return_value = mock_service tool_calls: list[dict[str, Any]] = [ { - 'function_name': 'search_tools', - 'arguments': '{"query": "point", "max_results": 5}' # JSON string + "function_name": "search_tools", + "arguments": '{"query": "point", "max_results": 5}', # JSON string }, - {'function_name': 'create_point', 'arguments': {}}, + {"function_name": "create_point", "arguments": {}}, ] result = _intercept_search_tools(self.app, tool_calls) - names = [c.get('function_name') for c in result] - self.assertIn('search_tools', names) - self.assertIn('create_point', names) + names = [c.get("function_name") for c in result] + self.assertIn("search_tools", names) + self.assertIn("create_point", names) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_handles_alternative_function_key(self, mock_service_class: Mock) -> None: """Should handle tool calls with 'function' key instead of 'function_name'.""" from static.routes import _intercept_search_tools mock_service = Mock() mock_service.search_tools.return_value = [ - {'function': {'name': 'create_circle'}}, + {"function": {"name": "create_circle"}}, ] mock_service_class.return_value = mock_service # Some API responses use 'function' key with nested 'name' tool_calls: List[Dict[str, Any]] = [ - {'function': {'name': 'search_tools'}, 'arguments': {'query': 'circle'}}, - {'function': {'name': 'create_circle'}, 'arguments': {}}, - {'function': {'name': 'delete_all'}, 'arguments': {}}, + {"function": {"name": "search_tools"}, "arguments": {"query": "circle"}}, + {"function": {"name": "create_circle"}, "arguments": {}}, + {"function": {"name": "delete_all"}, "arguments": {}}, ] result = _intercept_search_tools(self.app, tool_calls) # Extract names using the same logic as the function - names = [ - c.get('function_name') or c.get('function', {}).get('name') - for c in result - ] - self.assertIn('search_tools', names) - self.assertIn('create_circle', names) - self.assertNotIn('delete_all', names) + names = [c.get("function_name") or c.get("function", {}).get("name") for c in result] + self.assertIn("search_tools", names) + self.assertIn("create_circle", names) + self.assertNotIn("delete_all", names) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_injects_tools_into_both_apis(self, mock_service_class: Mock) -> None: """Should inject tools into both ai_api and responses_api.""" from static.routes import _intercept_search_tools mock_service = Mock() - returned_tools = [{'function': {'name': 'create_circle'}}] + returned_tools = [{"function": {"name": "create_circle"}}] mock_service.search_tools.return_value = returned_tools mock_service_class.return_value = mock_service # Spy on inject_tools - with patch.object(self.app.ai_api, 'inject_tools') as mock_ai_inject, \ - patch.object(self.app.responses_api, 'inject_tools') as mock_resp_inject: - + with ( + patch.object(self.app.ai_api, "inject_tools") as mock_ai_inject, + patch.object(self.app.responses_api, "inject_tools") as mock_resp_inject, + ): tool_calls = [ - {'function_name': 'search_tools', 'arguments': {'query': 'circle'}}, + {"function_name": "search_tools", "arguments": {"query": "circle"}}, ] _intercept_search_tools(self.app, tool_calls) @@ -677,13 +665,13 @@ def test_injects_tools_into_both_apis(self, mock_service_class: Mock) -> None: mock_ai_inject.assert_called_once_with(returned_tools, include_essentials=True) mock_resp_inject.assert_called_once_with(returned_tools, include_essentials=True) - @patch('static.tool_search_service.ToolSearchService') + @patch("static.tool_search_service.ToolSearchService") def test_injects_tools_into_active_provider_when_distinct(self, mock_service_class: Mock) -> None: """Should inject into the active non-OpenAI provider as well.""" from static.routes import _intercept_search_tools mock_service = Mock() - returned_tools = [{'function': {'name': 'create_circle'}}] + returned_tools = [{"function": {"name": "create_circle"}}] mock_service.search_tools.return_value = returned_tools mock_service_class.return_value = mock_service @@ -691,10 +679,12 @@ def test_injects_tools_into_active_provider_when_distinct(self, mock_service_cla provider.client = Mock() provider.get_model.return_value = self.app.ai_api.get_model() - with patch.object(self.app.ai_api, 'inject_tools') as mock_ai_inject, \ - patch.object(self.app.responses_api, 'inject_tools') as mock_resp_inject: + with ( + patch.object(self.app.ai_api, "inject_tools") as mock_ai_inject, + patch.object(self.app.responses_api, "inject_tools") as mock_resp_inject, + ): tool_calls = [ - {'function_name': 'search_tools', 'arguments': {'query': 'circle'}}, + {"function_name": "search_tools", "arguments": {"query": "circle"}}, ] _intercept_search_tools(self.app, tool_calls, provider=provider) @@ -747,5 +737,5 @@ def test_extract_injectable_tools_ignores_non_search_payload(self) -> None: self.assertIsNone(tools) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_search_tool_wiring_smoke.py b/server_tests/test_search_tool_wiring_smoke.py index 0f0692cc..c6dd1719 100644 --- a/server_tests/test_search_tool_wiring_smoke.py +++ b/server_tests/test_search_tool_wiring_smoke.py @@ -127,8 +127,8 @@ def __init__(self, name: str, arguments: str) -> None: class MockMessage: content = "I will draw that." tool_calls = [ - MockToolCall("search_tools", '{\"query\": \"plot x^2\"}'), - MockToolCall("draw_function", '{\"expression\": \"x^2\"}'), + MockToolCall("search_tools", '{"query": "plot x^2"}'), + MockToolCall("draw_function", '{"expression": "x^2"}'), ] class MockChoice: diff --git a/server_tests/test_statistics_pure.py b/server_tests/test_statistics_pure.py index 5befd00d..5e23f7ab 100644 --- a/server_tests/test_statistics_pure.py +++ b/server_tests/test_statistics_pure.py @@ -235,5 +235,3 @@ def test_distributions_pdf_matches_expected_peak(self) -> None: expr = normal_pdf_expression(mean, sigma) self.assertTrue(expr) self.assertAlmostEqual(expected, 0.19947114020071635) - - diff --git a/server_tests/test_tool_argument_validator.py b/server_tests/test_tool_argument_validator.py index 89774dc9..f69d2553 100644 --- a/server_tests/test_tool_argument_validator.py +++ b/server_tests/test_tool_argument_validator.py @@ -100,9 +100,7 @@ class TestSuccessfulValidation(unittest.TestCase): def test_create_point_valid(self) -> None: """create_point with correct types passes validation.""" - result = ToolArgumentValidator.validate( - "create_point", {"x": 5, "y": 10, "color": None, "name": None} - ) + result = ToolArgumentValidator.validate("create_point", {"x": 5, "y": 10, "color": None, "name": None}) self.assertTrue(result["valid"]) self.assertEqual(result["errors"], []) self.assertEqual(result["arguments"]["x"], 5) @@ -110,9 +108,7 @@ def test_create_point_valid(self) -> None: def test_create_point_with_string_values(self) -> None: """create_point with non-null optional string fields.""" - result = ToolArgumentValidator.validate( - "create_point", {"x": 3.5, "y": -2.0, "color": "red", "name": "A"} - ) + result = ToolArgumentValidator.validate("create_point", {"x": 3.5, "y": -2.0, "color": "red", "name": "A"}) self.assertTrue(result["valid"]) self.assertEqual(result["arguments"]["color"], "red") self.assertEqual(result["arguments"]["name"], "A") @@ -379,23 +375,17 @@ def test_draw_piecewise_function_valid(self) -> None: def test_set_coordinate_system_valid(self) -> None: """set_coordinate_system with valid enum.""" - result = ToolArgumentValidator.validate( - "set_coordinate_system", {"mode": "polar"} - ) + result = ToolArgumentValidator.validate("set_coordinate_system", {"mode": "polar"}) self.assertTrue(result["valid"]) def test_set_grid_visible_valid(self) -> None: """set_grid_visible with boolean value.""" - result = ToolArgumentValidator.validate( - "set_grid_visible", {"visible": True} - ) + result = ToolArgumentValidator.validate("set_grid_visible", {"visible": True}) self.assertTrue(result["valid"]) def test_int_accepted_for_number_field(self) -> None: """Python int should be accepted for JSON Schema 'number' type.""" - result = ToolArgumentValidator.validate( - "create_point", {"x": 5, "y": 10, "color": None, "name": None} - ) + result = ToolArgumentValidator.validate("create_point", {"x": 5, "y": 10, "color": None, "name": None}) self.assertTrue(result["valid"]) # int should pass through as-is (not converted to float) self.assertIs(type(result["arguments"]["x"]), int) @@ -458,9 +448,7 @@ def test_number_instead_of_string(self) -> None: def test_string_instead_of_boolean(self) -> None: """String for a boolean field should fail.""" - result = ToolArgumentValidator.validate( - "set_grid_visible", {"visible": "yes"} - ) + result = ToolArgumentValidator.validate("set_grid_visible", {"visible": "yes"}) self.assertFalse(result["valid"]) self.assertTrue(any("'visible'" in e for e in result["errors"])) @@ -497,9 +485,7 @@ def test_string_instead_of_object(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("'distribution_params'" in e for e in result["errors"]) - ) + self.assertTrue(any("'distribution_params'" in e for e in result["errors"])) def test_float_instead_of_integer(self) -> None: """Float for an integer field should fail.""" @@ -622,9 +608,7 @@ def test_extra_key_in_nested_object(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("'z'" in e or "'vertices[0].z'" in e for e in result["errors"]) - ) + self.assertTrue(any("'z'" in e or "'vertices[0].z'" in e for e in result["errors"])) def test_allowed_keys_listed_in_error(self) -> None: """Error message for unknown key should list allowed keys.""" @@ -657,19 +641,13 @@ def test_invalid_enum_zoom_range_axis(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("'range_axis'" in e and "'z'" in e for e in result["errors"]) - ) + self.assertTrue(any("'range_axis'" in e and "'z'" in e for e in result["errors"])) def test_invalid_enum_coordinate_system(self) -> None: """Invalid enum value for set_coordinate_system should fail.""" - result = ToolArgumentValidator.validate( - "set_coordinate_system", {"mode": "spherical"} - ) + result = ToolArgumentValidator.validate("set_coordinate_system", {"mode": "spherical"}) self.assertFalse(result["valid"]) - self.assertTrue( - any("'mode'" in e and "'spherical'" in e for e in result["errors"]) - ) + self.assertTrue(any("'mode'" in e and "'spherical'" in e for e in result["errors"])) def test_valid_enum_value(self) -> None: """Valid enum value should pass.""" @@ -701,9 +679,7 @@ def test_invalid_polygon_type_enum(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("'polygon_type'" in e and "'circle'" in e for e in result["errors"]) - ) + self.assertTrue(any("'polygon_type'" in e and "'circle'" in e for e in result["errors"])) def test_null_in_nullable_enum(self) -> None: """None is a valid enum value for nullable enum fields.""" @@ -736,9 +712,7 @@ def test_invalid_numeric_integrate_method(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("'method'" in e and "'euler'" in e for e in result["errors"]) - ) + self.assertTrue(any("'method'" in e and "'euler'" in e for e in result["errors"])) # =================================================================== @@ -943,9 +917,7 @@ def test_generate_graph_invalid_vertex_type(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("vertices[0]" in e for e in result["errors"]) - ) + self.assertTrue(any("vertices[0]" in e for e in result["errors"])) def test_plot_distribution_invalid_nested_type(self) -> None: """Invalid type in distribution_params.sigma should fail.""" @@ -965,9 +937,7 @@ def test_plot_distribution_invalid_nested_type(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("sigma" in e for e in result["errors"]) - ) + self.assertTrue(any("sigma" in e for e in result["errors"])) def test_draw_piecewise_deep_nesting(self) -> None: """Invalid type in piecewise function piece should fail.""" @@ -1056,9 +1026,7 @@ def test_anyof_invalid_value(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("did not match any" in e for e in result["errors"]) - ) + self.assertTrue(any("did not match any" in e for e in result["errors"])) def test_anyof_invalid_array_element(self) -> None: """anyOf value with array containing strings should fail.""" @@ -1088,9 +1056,7 @@ def test_placement_box_missing_required(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("height" in e for e in result["errors"]) - ) + self.assertTrue(any("height" in e for e in result["errors"])) def test_nested_object_unknown_key(self) -> None: """Extra key in nested placement_box should be reported.""" @@ -1139,9 +1105,7 @@ def test_min_items_violation_polygon(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("at least 3" in e and "got 2" in e for e in result["errors"]) - ) + self.assertTrue(any("at least 3" in e and "got 2" in e for e in result["errors"])) def test_min_items_exactly_met(self) -> None: """create_polygon with exactly 3 vertices should pass.""" @@ -1168,9 +1132,7 @@ def test_min_items_piecewise_empty_array(self) -> None: {"pieces": [], "name": None, "color": None}, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("at least 1" in e for e in result["errors"]) - ) + self.assertTrue(any("at least 1" in e for e in result["errors"])) def test_min_items_linear_algebra_empty_objects(self) -> None: """evaluate_linear_algebra_expression with empty objects should fail.""" @@ -1179,9 +1141,7 @@ def test_min_items_linear_algebra_empty_objects(self) -> None: {"objects": [], "expression": "A"}, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("at least 1" in e for e in result["errors"]) - ) + self.assertTrue(any("at least 1" in e for e in result["errors"])) def test_max_length_violation(self) -> None: """create_label with text exceeding maxLength should fail.""" @@ -1198,9 +1158,7 @@ def test_max_length_violation(self) -> None: }, ) self.assertFalse(result["valid"]) - self.assertTrue( - any("at most 160" in e and "got 200" in e for e in result["errors"]) - ) + self.assertTrue(any("at most 160" in e and "got 200" in e for e in result["errors"])) def test_max_length_at_limit(self) -> None: """create_label with text exactly at maxLength should pass.""" @@ -1251,9 +1209,7 @@ def test_none_arguments(self) -> None: def test_unknown_function_passes_through(self) -> None: """Unknown function name should pass with valid=True and log warning.""" with self.assertLogs("static.tool_argument_validator", level="WARNING") as cm: - result = ToolArgumentValidator.validate( - "totally_unknown_function", {"foo": "bar"} - ) + result = ToolArgumentValidator.validate("totally_unknown_function", {"foo": "bar"}) self.assertTrue(result["valid"]) self.assertEqual(result["errors"], []) self.assertTrue(any("no schema found" in msg for msg in cm.output)) @@ -1287,10 +1243,7 @@ def test_large_argument_dict(self) -> None: def test_deeply_nested_valid_graph(self) -> None: """generate_graph with many vertices and edges should work.""" - vertices = [ - {"name": f"V{i}", "x": i * 10, "y": i * 5, "color": None, "label": None} - for i in range(20) - ] + vertices = [{"name": f"V{i}", "x": i * 10, "y": i * 5, "color": None, "label": None} for i in range(20)] edges = [ { "source": i, @@ -1321,9 +1274,7 @@ def test_deeply_nested_valid_graph(self) -> None: def test_error_value_truncation(self) -> None: """Long string values in errors should be truncated.""" long_string = "x" * 200 - result = ToolArgumentValidator.validate( - "set_coordinate_system", {"mode": long_string} - ) + result = ToolArgumentValidator.validate("set_coordinate_system", {"mode": long_string}) self.assertFalse(result["valid"]) # Error message should contain truncated value, not the full 200 chars error_text = result["errors"][0] @@ -1331,9 +1282,7 @@ def test_error_value_truncation(self) -> None: def test_validation_result_structure(self) -> None: """ValidationResult should contain valid, arguments, and errors keys.""" - result = ToolArgumentValidator.validate( - "create_point", {"x": 5, "y": 10, "color": None, "name": None} - ) + result = ToolArgumentValidator.validate("create_point", {"x": 5, "y": 10, "color": None, "name": None}) self.assertIn("valid", result) self.assertIn("arguments", result) self.assertIn("errors", result) @@ -1351,9 +1300,7 @@ def test_invalid_returns_original_args(self) -> None: def test_valid_returns_canonical_args(self) -> None: """When validation passes, canonical arguments should be returned.""" - result = ToolArgumentValidator.validate( - "create_point", {"x": "5", "y": 10, "color": None, "name": None} - ) + result = ToolArgumentValidator.validate("create_point", {"x": "5", "y": 10, "color": None, "name": None}) self.assertTrue(result["valid"]) self.assertEqual(result["arguments"]["x"], 5.0) @@ -1367,9 +1314,7 @@ def test_negative_numbers_valid(self) -> None: def test_zero_valid(self) -> None: """Zero should be valid for number fields.""" - result = ToolArgumentValidator.validate( - "create_point", {"x": 0, "y": 0, "color": None, "name": None} - ) + result = ToolArgumentValidator.validate("create_point", {"x": 0, "y": 0, "color": None, "name": None}) self.assertTrue(result["valid"]) diff --git a/server_tests/test_tool_discovery_live.py b/server_tests/test_tool_discovery_live.py index 554b60a7..29c096f5 100644 --- a/server_tests/test_tool_discovery_live.py +++ b/server_tests/test_tool_discovery_live.py @@ -44,11 +44,7 @@ def _load_dataset() -> Dict[str, Any]: def _tool_name_set() -> set[str]: - return { - f.get("function", {}).get("name", "") - for f in FUNCTIONS - if f.get("function", {}).get("name") - } + return {f.get("function", {}).get("name", "") for f in FUNCTIONS if f.get("function", {}).get("name")} def _tool_hash(tool_names: set[str]) -> str: @@ -224,12 +220,10 @@ def test_live_tool_discovery_benchmark() -> None: actual_hash = _tool_hash(all_tools) assert len(all_tools) == expected_count, ( - "Tool count mismatch; refresh dataset. " - f"expected={expected_count}, actual={len(all_tools)}" + f"Tool count mismatch; refresh dataset. expected={expected_count}, actual={len(all_tools)}" ) assert actual_hash == expected_hash, ( - "Tool hash mismatch; refresh dataset. " - f"expected={expected_hash}, actual={actual_hash}" + f"Tool hash mismatch; refresh dataset. expected={expected_hash}, actual={actual_hash}" ) model = _resolve_model() @@ -294,11 +288,7 @@ def test_live_tool_discovery_benchmark() -> None: if csv_path is not None and resume: completed_case_ids = _load_existing_case_ids(csv_path) if completed_case_ids: - selected_cases = [ - c - for c in selected_cases - if str(c.get("id", "")) not in completed_case_ids - ] + selected_cases = [c for c in selected_cases if str(c.get("id", "")) not in completed_case_ids] if tool_limit > 0: selected_cases = selected_cases[:tool_limit] @@ -420,8 +410,7 @@ def test_live_tool_discovery_benchmark() -> None: _write_csv(Path("/tmp/tool_discovery_results.csv"), rows) assert positive_evaluated > 0, ( - "No positive cases were evaluated (all may have been blocked). " - "Inspect TOOL_DISCOVERY_CSV output for details." + "No positive cases were evaluated (all may have been blocked). Inspect TOOL_DISCOVERY_CSV output for details." ) assert blocked_rate <= blocked_max, ( f"Blocked rate too high: {blocked_rate:.3f} > {blocked_max:.3f}. " @@ -440,6 +429,5 @@ def test_live_tool_discovery_benchmark() -> None: f"Sample failing cases: {failed_case_ids[:12]}" ) assert confusion_hard_miss_rate <= confusion_hard_miss_max, ( - "Confusion hard-miss rate too high: " - f"{confusion_hard_miss_rate:.3f} > {confusion_hard_miss_max:.3f}" + f"Confusion hard-miss rate too high: {confusion_hard_miss_rate:.3f} > {confusion_hard_miss_max:.3f}" ) diff --git a/server_tests/test_tool_search_service.py b/server_tests/test_tool_search_service.py index aa68b283..f62ffb01 100644 --- a/server_tests/test_tool_search_service.py +++ b/server_tests/test_tool_search_service.py @@ -149,9 +149,7 @@ def service(self, mock_client: MagicMock) -> ToolSearchService: """Create a ToolSearchService with mocked client.""" return ToolSearchService(client=mock_client) - def _setup_mock_response( - self, mock_client: MagicMock, content: str - ) -> None: + def _setup_mock_response(self, mock_client: MagicMock, content: str) -> None: """Configure mock client to return specific content.""" mock_message = MagicMock() mock_message.content = content @@ -161,13 +159,9 @@ def _setup_mock_response( mock_response.choices = [mock_choice] mock_client.chat.completions.create.return_value = mock_response - def test_search_returns_tools( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_returns_tools(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should return matching tool definitions.""" - self._setup_mock_response( - mock_client, '["create_circle", "create_point"]' - ) + self._setup_mock_response(mock_client, '["create_circle", "create_point"]') result = service.search_tools("draw a circle") @@ -176,9 +170,7 @@ def test_search_returns_tools( assert "create_circle" in names assert "create_point" in names - def test_search_respects_max_results( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_respects_max_results(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should limit results to max_results.""" self._setup_mock_response( mock_client, @@ -189,26 +181,20 @@ def test_search_respects_max_results( assert len(result) == 2 - def test_search_empty_query_returns_empty( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_empty_query_returns_empty(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should return empty list for empty query.""" result = service.search_tools("") assert result == [] # Should not call the API mock_client.chat.completions.create.assert_not_called() - def test_search_whitespace_query_returns_empty( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_whitespace_query_returns_empty(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should return empty list for whitespace-only query.""" result = service.search_tools(" ") assert result == [] mock_client.chat.completions.create.assert_not_called() - def test_search_clamps_max_results_low( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_clamps_max_results_low(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should clamp max_results to minimum of 1.""" self._setup_mock_response(mock_client, '["create_circle"]') @@ -217,13 +203,9 @@ def test_search_clamps_max_results_low( # Should still work with at least 1 result assert len(result) <= 1 - def test_search_clamps_max_results_high( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_clamps_max_results_high(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should clamp max_results to maximum of 20.""" - self._setup_mock_response( - mock_client, '["create_circle", "create_point"]' - ) + self._setup_mock_response(mock_client, '["create_circle", "create_point"]') # Passing 100 should be clamped to 20 service.search_tools("draw", max_results=100) @@ -234,9 +216,7 @@ def test_search_clamps_max_results_high( prompt = messages[0]["content"] if messages else "" assert "up to 20 tool names" in prompt - def test_search_filters_unknown_tools( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_filters_unknown_tools(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should filter out unknown tool names from response.""" self._setup_mock_response( mock_client, @@ -250,9 +230,7 @@ def test_search_filters_unknown_tools( names = [t["function"]["name"] for t in result] assert "nonexistent_tool" not in names - def test_search_handles_api_error( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_handles_api_error(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should use fallback ranking on API error.""" mock_client.chat.completions.create.side_effect = Exception("API Error") @@ -262,9 +240,7 @@ def test_search_handles_api_error( names = [t["function"]["name"] for t in result] assert "create_circle" in names - def test_search_handles_empty_response( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_handles_empty_response(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should handle empty API response.""" self._setup_mock_response(mock_client, "") @@ -272,9 +248,7 @@ def test_search_handles_empty_response( assert result == [] - def test_search_uses_correct_model( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_uses_correct_model(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should use the specified model.""" self._setup_mock_response(mock_client, '["create_circle"]') model = AIModel.from_identifier("gpt-4.1") @@ -284,9 +258,7 @@ def test_search_uses_correct_model( call_args = mock_client.chat.completions.create.call_args assert call_args.kwargs.get("model") == "gpt-4.1" - def test_search_uses_default_model_when_none( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_search_uses_default_model_when_none(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should use gpt-4.1-mini when no model specified.""" self._setup_mock_response(mock_client, '["create_circle"]') @@ -310,9 +282,7 @@ def service(self, mock_client: MagicMock) -> ToolSearchService: """Create a ToolSearchService with mocked client.""" return ToolSearchService(client=mock_client) - def _setup_mock_response( - self, mock_client: MagicMock, content: str - ) -> None: + def _setup_mock_response(self, mock_client: MagicMock, content: str) -> None: """Configure mock client to return specific content.""" mock_message = MagicMock() mock_message.content = content @@ -322,9 +292,7 @@ def _setup_mock_response( mock_response.choices = [mock_choice] mock_client.chat.completions.create.return_value = mock_response - def test_formatted_returns_dict( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_formatted_returns_dict(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools_formatted should return a dict.""" self._setup_mock_response(mock_client, '["create_circle"]') @@ -332,9 +300,7 @@ def test_formatted_returns_dict( assert isinstance(result, dict) - def test_formatted_contains_required_keys( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_formatted_contains_required_keys(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools_formatted result should have tools, count, and query.""" self._setup_mock_response(mock_client, '["create_circle"]') @@ -344,22 +310,16 @@ def test_formatted_contains_required_keys( assert "count" in result assert "query" in result - def test_formatted_count_matches_tools( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_formatted_count_matches_tools(self, service: ToolSearchService, mock_client: MagicMock) -> None: """count should match the number of tools returned.""" - self._setup_mock_response( - mock_client, '["create_circle", "create_point"]' - ) + self._setup_mock_response(mock_client, '["create_circle", "create_point"]') result = service.search_tools_formatted("draw shapes") assert result["count"] == len(result["tools"]) assert result["count"] == 2 - def test_formatted_preserves_query( - self, service: ToolSearchService, mock_client: MagicMock - ) -> None: + def test_formatted_preserves_query(self, service: ToolSearchService, mock_client: MagicMock) -> None: """query field should contain the original query.""" self._setup_mock_response(mock_client, '["create_circle"]') query = "draw a beautiful circle" diff --git a/server_tests/test_tts_manager.py b/server_tests/test_tts_manager.py index cbe4aabb..98155e5b 100644 --- a/server_tests/test_tts_manager.py +++ b/server_tests/test_tts_manager.py @@ -62,7 +62,7 @@ def test_generate_speech_invalid_voice_uses_default(self) -> None: manager = TTSManager() # Mock the pipeline to avoid actual TTS - with patch.object(manager, '_get_pipeline') as mock_pipeline: + with patch.object(manager, "_get_pipeline") as mock_pipeline: mock_pipeline.return_value = (False, "Test: Kokoro not installed") success, result = manager.generate_speech("test", voice="invalid_voice") @@ -109,6 +109,7 @@ def setUpClass(cls) -> None: """Check if Kokoro is available.""" try: import kokoro # noqa: F401 + cls.kokoro_available = True except ImportError: cls.kokoro_available = False @@ -147,5 +148,5 @@ def test_generate_speech_produces_wav(self) -> None: self.assertEqual(result[:4], b"RIFF") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server_tests/test_workspace_management.py b/server_tests/test_workspace_management.py index 10a40f51..77466eee 100644 --- a/server_tests/test_workspace_management.py +++ b/server_tests/test_workspace_management.py @@ -50,13 +50,15 @@ def test_save_workspace_without_name(self) -> None: self.canvas.create_segment(100, 100, 200, 200, "AB") test_current = "test_current_workspace" - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), test_current, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), test_current, TEST_DIR + ) self.assertTrue(success, "Save workspace should return True on success") workspace_path = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{test_current}.json") self.assertTrue(os.path.exists(workspace_path)) - with open(workspace_path, 'r') as f: + with open(workspace_path, "r") as f: data = json.load(f) self.assertEqual(data["metadata"]["name"], test_current) self.assertIn("Points", data["state"]) @@ -69,13 +71,15 @@ def test_save_workspace_with_name(self) -> None: self.canvas.create_circle(0, 0, 100, "C1") workspace_name = "test_circle_workspace" - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR + ) self.assertTrue(success, "Save workspace should return True on success") workspace_path = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{workspace_name}.json") self.assertTrue(os.path.exists(workspace_path)) - with open(workspace_path, 'r') as f: + with open(workspace_path, "r") as f: data = json.load(f) self.assertEqual(data["metadata"]["name"], workspace_name) self.assertEqual(data["metadata"]["schema_version"], CURRENT_WORKSPACE_SCHEMA_VERSION) @@ -87,15 +91,19 @@ def test_save_workspace_failure(self) -> None: success = self.workspace_manager.save_workspace(None, test_dir=TEST_DIR) self.assertFalse(success, "Save workspace should return False when state is None") - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), "test/invalid/name", TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), "test/invalid/name", TEST_DIR + ) self.assertFalse(success, "Save workspace should return False with invalid name") - if os.name != 'nt' and os.getuid() != 0: # Skip on Windows and root + if os.name != "nt" and os.getuid() != 0: # Skip on Windows and root test_dir = os.path.join(WORKSPACES_DIR, TEST_DIR) original_mode = os.stat(test_dir).st_mode try: os.chmod(test_dir, 0o444) # Read-only - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), "test_readonly", TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), "test_readonly", TEST_DIR + ) self.assertFalse(success, "Save workspace should return False with read-only directory") finally: os.chmod(test_dir, original_mode) @@ -108,7 +116,9 @@ def test_load_workspace(self) -> None: name="ABC", ) workspace_name = "test_triangle_workspace" - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR + ) self.assertTrue(success, "Initial save should succeed") self.canvas.clear() @@ -118,10 +128,7 @@ def test_load_workspace(self) -> None: state_dict = cast(Dict[str, Any], state) self.assertIn("Polygons", state_dict) - triangles = [ - polygon for polygon in state_dict["Polygons"] - if polygon.get("polygon_type") == "triangle" - ] + triangles = [polygon for polygon in state_dict["Polygons"] if polygon.get("polygon_type") == "triangle"] self.assertEqual(len(triangles), 1) triangle = triangles[0] self.assertEqual(triangle.get("name"), "ABC") @@ -136,8 +143,8 @@ def test_load_workspace_invalid_json(self) -> None: workspace_name = "test_invalid_json_workspace" workspace_path = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{workspace_name}.json") - with open(workspace_path, 'w') as f: - f.write("{\"name\": \"test\", \"state\": {this_is_not_valid_json}") + with open(workspace_path, "w") as f: + f.write('{"name": "test", "state": {this_is_not_valid_json}') with self.assertRaises((json.JSONDecodeError, ValueError)): self.workspace_manager.load_workspace(workspace_name, TEST_DIR) @@ -150,9 +157,9 @@ def test_load_workspace_incorrect_schema(self) -> None: # Case 1: Valid JSON, but missing the top-level 'state' key. malformed_data = { "metadata": {"name": workspace_name, "timestamp": "sometime"}, - "unexpected_top_level_key": {"Points": []} + "unexpected_top_level_key": {"Points": []}, } - with open(workspace_path, 'w') as f: + with open(workspace_path, "w") as f: json.dump(malformed_data, f) with self.assertRaisesRegex(ValueError, "Error loading workspace: .*state"): @@ -161,18 +168,17 @@ def test_load_workspace_incorrect_schema(self) -> None: # Case 2: 'state' key exists, but 'Points' (a drawable type) is not a list. workspace_name_2 = "test_points_not_list" workspace_path_2 = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{workspace_name_2}.json") - malformed_data_2 = { - "metadata": {"name": workspace_name_2}, - "state": {"Points": "this should be a list"} - } - with open(workspace_path_2, 'w') as f: + malformed_data_2 = {"metadata": {"name": workspace_name_2}, "state": {"Points": "this should be a list"}} + with open(workspace_path_2, "w") as f: json.dump(malformed_data_2, f) loaded_state_2 = self.workspace_manager.load_workspace(workspace_name_2, TEST_DIR) loaded_state_2_dict = cast(Dict[str, Any], loaded_state_2) self.assertIsInstance(loaded_state_2_dict, dict, "Loaded state should be a dictionary.") self.assertIn("Points", loaded_state_2_dict) - self.assertNotIsInstance(loaded_state_2_dict["Points"], list, "Points should not be a list in this malformed case.") + self.assertNotIsInstance( + loaded_state_2_dict["Points"], list, "Points should not be a list in this malformed case." + ) self.assertEqual(loaded_state_2_dict["Points"], "this should be a list") def test_load_workspace_legacy_state_only_payload(self) -> None: @@ -180,7 +186,7 @@ def test_load_workspace_legacy_state_only_payload(self) -> None: workspace_name = "test_legacy_state_only_workspace" workspace_path = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{workspace_name}.json") legacy_state = {"Points": [{"x": 1, "y": 2, "name": "A"}], "Segments": []} - with open(workspace_path, 'w') as f: + with open(workspace_path, "w") as f: json.dump(legacy_state, f) loaded_state = self.workspace_manager.load_workspace(workspace_name, TEST_DIR) @@ -200,7 +206,7 @@ def test_load_workspace_rejects_future_schema_version(self) -> None: }, "state": {"Points": []}, } - with open(workspace_path, 'w') as f: + with open(workspace_path, "w") as f: json.dump(future_data, f) with self.assertRaisesRegex(ValueError, "Unsupported workspace schema_version"): @@ -218,7 +224,7 @@ def test_load_workspace_accepts_string_schema_version(self) -> None: }, "state": {"Points": [{"x": 0, "y": 0, "name": "A"}]}, } - with open(workspace_path, 'w') as f: + with open(workspace_path, "w") as f: json.dump(data, f) loaded_state = self.workspace_manager.load_workspace(workspace_name, TEST_DIR) @@ -230,7 +236,9 @@ def test_list_workspaces(self) -> None: """Test listing all workspaces.""" workspace_names = ["test_ws1", "test_ws2", "test_ws3"] for name in workspace_names: - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), name, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), name, TEST_DIR + ) self.assertTrue(success, f"Failed to save workspace {name}") workspaces = self.workspace_manager.list_workspaces(TEST_DIR) @@ -240,16 +248,18 @@ def test_list_workspaces(self) -> None: def test_list_workspaces_with_non_workspace_files(self) -> None: """Test listing workspaces when non-workspace files are present.""" - self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), "valid_ws_for_list_test", TEST_DIR) + self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), "valid_ws_for_list_test", TEST_DIR + ) test_dir_path = os.path.join(WORKSPACES_DIR, TEST_DIR) - with open(os.path.join(test_dir_path, "notes.txt"), 'w') as f: + with open(os.path.join(test_dir_path, "notes.txt"), "w") as f: f.write("some notes") - with open(os.path.join(test_dir_path, "image.jpg"), 'w') as f: + with open(os.path.join(test_dir_path, "image.jpg"), "w") as f: f.write("fake image data") - with open(os.path.join(test_dir_path, "invalid_structure.json"), 'w') as f: - json.dump({"metadata": {"name": "invalid_structure"}}, f) # Missing 'state' key - with open(os.path.join(test_dir_path, ".hiddenfile"), 'w') as f: + with open(os.path.join(test_dir_path, "invalid_structure.json"), "w") as f: + json.dump({"metadata": {"name": "invalid_structure"}}, f) # Missing 'state' key + with open(os.path.join(test_dir_path, ".hiddenfile"), "w") as f: f.write("hidden") workspaces = self.workspace_manager.list_workspaces(TEST_DIR) @@ -273,7 +283,9 @@ def test_workspace_with_computations(self) -> None: self.canvas.add_computation("sin(pi/2)", 1) workspace_name = "test_computations" - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR + ) self.assertTrue(success, "Failed to save workspace with computations") self.canvas.clear() @@ -305,11 +317,13 @@ def test_save_complex_workspace(self) -> None: self.canvas.add_computation("area", 10000) workspace_name = "test_complex" - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR + ) self.assertTrue(success, "Failed to save complex workspace") workspace_path = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{workspace_name}.json") - with open(workspace_path, 'r') as f: + with open(workspace_path, "r") as f: data = json.load(f) state = data["state"] self.assertIn("Points", state) @@ -328,9 +342,9 @@ def test_save_and_load_preserves_state_integrity(self) -> None: # 1. Create a complex canvas state pointA_coords = (10, 20) pointB_coords = (30, 40) - pointC_coords = (50, 50) # For Circle + pointC_coords = (50, 50) # For Circle pointD_coords = (5, 15) # For Vector - pointE_coords = (25, 35) # For Vector + pointE_coords = (25, 35) # For Vector self.canvas.create_point(pointA_coords[0], pointA_coords[1], "A") self.canvas.create_point(pointB_coords[0], pointB_coords[1], "B") @@ -348,15 +362,20 @@ def test_save_and_load_preserves_state_integrity(self) -> None: workspace_name = "test_integrity_workspace_canvas_methods" # 2. Save the workspace - save_success = self.workspace_manager.save_workspace(cast(WorkspaceState, original_state), workspace_name, TEST_DIR) + save_success = self.workspace_manager.save_workspace( + cast(WorkspaceState, original_state), workspace_name, TEST_DIR + ) self.assertTrue(save_success, "Saving the integrity test workspace should succeed.") # 3. Load the workspace loaded_state = self.workspace_manager.load_workspace(workspace_name, TEST_DIR) # 4. Deeply compare the loaded state with the original state. - self._assert_states_equal_after_sorting(original_state, loaded_state, - "Loaded workspace state (from canvas methods) does not match the original state.") + self._assert_states_equal_after_sorting( + original_state, + loaded_state, + "Loaded workspace state (from canvas methods) does not match the original state.", + ) def _assert_states_equal_after_sorting(self, state1: CanvasStateDict, state2: WorkspaceState, message: str) -> None: """Helper to compare two canvas states after deep copying and sorting their lists.""" @@ -415,28 +434,24 @@ def test_save_and_load_mock_state_preserves_integrity_with_vectors(self) -> None {"name": "M_B", "args": {"position": {"x": 130, "y": 140}}}, {"name": "M_C", "args": {"position": {"x": 150, "y": 160}}}, {"name": "M_D", "args": {"position": {"x": 105, "y": 115}}}, - {"name": "M_E", "args": {"position": {"x": 125, "y": 135}}} - ], - "Segments": [ - {"name": "M_AB", "args": {"p1": "M_A", "p2": "M_B"}} - ], - "Circles": [ - {"name": "M_CircleC", "args": {"center": "M_C", "radius": 20}} - ], - "Vectors": [ - {"name": "M_DE", "args": {"origin": "M_D", "tip": "M_E"}} + {"name": "M_E", "args": {"position": {"x": 125, "y": 135}}}, ], + "Segments": [{"name": "M_AB", "args": {"p1": "M_A", "p2": "M_B"}}], + "Circles": [{"name": "M_CircleC", "args": {"center": "M_C", "radius": 20}}], + "Vectors": [{"name": "M_DE", "args": {"origin": "M_D", "tip": "M_E"}}], "Functions": [ {"name": "M_Func1", "args": {"function_string": "sin(x)", "left_bound": -10, "right_bound": 10}} ], "computations": [ {"expression": "10+20", "result": 30}, - {"expression": "cos(0)", "result": 1.0} # MockCanvas computation results are floats - ] + {"expression": "cos(0)", "result": 1.0}, # MockCanvas computation results are floats + ], } workspace_name = "test_integrity_mock_state_vectors" - save_success = self.workspace_manager.save_workspace(cast(WorkspaceState, mock_canvas_state), workspace_name, TEST_DIR) + save_success = self.workspace_manager.save_workspace( + cast(WorkspaceState, mock_canvas_state), workspace_name, TEST_DIR + ) self.assertTrue(save_success, "Saving the mock state integrity test workspace should succeed.") loaded_state = self.workspace_manager.load_workspace(workspace_name, TEST_DIR) @@ -450,18 +465,21 @@ def test_save_and_load_mock_state_preserves_integrity_with_vectors(self) -> None # Quick checks for key elements loaded_points_list = cast(list, loaded_state_dict.get("Points", [])) - loaded_points_dict = {cast(Dict[str, Any], p)['name']: p for p in loaded_points_list} + loaded_points_dict = {cast(Dict[str, Any], p)["name"]: p for p in loaded_points_list} self.assertIn("M_E", loaded_points_dict) self.assertEqual(cast(Dict[str, Any], loaded_points_dict["M_E"])["args"]["position"]["x"], 125) loaded_vectors_list = cast(list, loaded_state_dict.get("Vectors", [])) - loaded_vectors_dict = {cast(Dict[str, Any], v)['name']: v for v in loaded_vectors_list} + loaded_vectors_dict = {cast(Dict[str, Any], v)["name"]: v for v in loaded_vectors_list} self.assertIn("M_DE", loaded_vectors_dict) self.assertEqual(cast(Dict[str, Any], loaded_vectors_dict["M_DE"])["args"]["origin"], "M_D") # Use the helper method for full comparison - self._assert_states_equal_after_sorting(cast(CanvasStateDict, mock_canvas_state), loaded_state_dict, - "Loaded mock workspace state with vectors does not match the original mock state.") + self._assert_states_equal_after_sorting( + cast(CanvasStateDict, mock_canvas_state), + loaded_state_dict, + "Loaded mock workspace state with vectors does not match the original mock state.", + ) def test_save_and_load_empty_workspace(self) -> None: """Test saving and loading an empty workspace.""" @@ -473,7 +491,9 @@ def test_save_and_load_empty_workspace(self) -> None: workspace_name = "test_empty_workspace" # 2. Save the empty workspace - save_success = self.workspace_manager.save_workspace(cast(WorkspaceState, original_empty_state), workspace_name, TEST_DIR) + save_success = self.workspace_manager.save_workspace( + cast(WorkspaceState, original_empty_state), workspace_name, TEST_DIR + ) self.assertTrue(save_success, "Saving an empty workspace should succeed.") # 3. Load the empty workspace @@ -485,14 +505,19 @@ def test_save_and_load_empty_workspace(self) -> None: del loaded_state_dict["metadata"] # Note: Relies on MockCanvas.get_canvas_state() returning a consistent empty structure. - self.assertEqual(original_empty_state, loaded_state_dict, - "Loaded empty workspace state does not match the original empty state.") + self.assertEqual( + original_empty_state, + loaded_state_dict, + "Loaded empty workspace state does not match the original empty state.", + ) def test_delete_workspace(self) -> None: """Test deleting a workspace.""" self.canvas.create_point(100, 100, "A") workspace_name = "test_delete_ws" - success = self.workspace_manager.save_workspace(cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR) + success = self.workspace_manager.save_workspace( + cast(WorkspaceState, self.canvas.get_canvas_state()), workspace_name, TEST_DIR + ) self.assertTrue(success, "Initial save should succeed") workspace_path = os.path.join(WORKSPACES_DIR, TEST_DIR, f"{workspace_name}.json") @@ -519,5 +544,6 @@ def test_delete_workspace_with_invalid_name(self) -> None: success = self.workspace_manager.delete_workspace("test/invalid/name", TEST_DIR) self.assertFalse(success, "Delete workspace should return False for invalid name") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/static/ai_model.py b/static/ai_model.py index 7da2c912..b3b3cf7d 100644 --- a/static/ai_model.py +++ b/static/ai_model.py @@ -437,9 +437,7 @@ def _format_display_name(model_name: str) -> str: if i > 0: prev_char = base[i - 1] # Add space between letter and digit - if (prev_char.isalpha() and char.isdigit()) or ( - prev_char.isdigit() and char.isalpha() - ): + if (prev_char.isalpha() and char.isdigit()) or (prev_char.isdigit() and char.isalpha()): # But not for decimal points in version numbers if not (prev_char.isdigit() and char == "."): formatted += " " diff --git a/static/app_manager.py b/static/app_manager.py index e704ff67..47acfda7 100644 --- a/static/app_manager.py +++ b/static/app_manager.py @@ -86,7 +86,7 @@ def is_deployed() -> bool: Returns: bool: True if deployed (PORT environment variable is set), False for local development """ - return os.environ.get('PORT') is not None + return os.environ.get("PORT") is not None @staticmethod def _load_env() -> None: @@ -105,7 +105,7 @@ def requires_auth() -> bool: """ AppManager._load_env() # Require auth if deployed OR if explicitly enabled via REQUIRE_AUTH - return AppManager.is_deployed() or os.getenv('REQUIRE_AUTH', '').lower() in ('true', '1', 'yes') + return AppManager.is_deployed() or os.getenv("REQUIRE_AUTH", "").lower() in ("true", "1", "yes") @staticmethod def get_auth_pin() -> Optional[str]: @@ -121,7 +121,7 @@ def get_auth_pin() -> Optional[str]: def make_response( data: JsonValue | None = None, message: Optional[str] = None, - status: str = 'success', + status: str = "success", code: int = 200, ) -> Tuple[Response, int]: """Create a consistent JSON response format. @@ -136,9 +136,9 @@ def make_response( tuple: (Flask JSON response, HTTP status code) """ response: ApiResponseDict = { - 'status': status, - 'message': message, - 'data': data, + "status": status, + "message": message, + "data": data, } return jsonify(response), code @@ -154,29 +154,29 @@ def create_app() -> MatHudFlask: Returns: Flask: Configured Flask application instance """ - app = MatHudFlask(__name__, template_folder='../templates', static_folder='../static') + app = MatHudFlask(__name__, template_folder="../templates", static_folder="../static") # Load environment variables from project .env and parent .env (API keys) AppManager._load_env() # Configure session management for authentication using modern CacheLib backend - app.secret_key = os.getenv('SECRET_KEY', secrets.token_hex(32)) + app.secret_key = os.getenv("SECRET_KEY", secrets.token_hex(32)) # Create session directory if it doesn't exist - session_dir = os.path.join(os.getcwd(), 'flask_session') + session_dir = os.path.join(os.getcwd(), "flask_session") os.makedirs(session_dir, exist_ok=True) # Modern Flask-Session configuration using CacheLib - app.config['SESSION_TYPE'] = 'cachelib' - app.config['SESSION_CACHELIB'] = FileSystemCache(cache_dir=session_dir) - app.config['SESSION_PERMANENT'] = False - app.config['SESSION_KEY_PREFIX'] = 'mathud:' + app.config["SESSION_TYPE"] = "cachelib" + app.config["SESSION_CACHELIB"] = FileSystemCache(cache_dir=session_dir) + app.config["SESSION_PERMANENT"] = False + app.config["SESSION_KEY_PREFIX"] = "mathud:" # Security settings for deployed environments if AppManager.is_deployed(): - app.config['SESSION_COOKIE_SECURE'] = True - app.config['SESSION_COOKIE_HTTPONLY'] = True - app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' + app.config["SESSION_COOKIE_SECURE"] = True + app.config["SESSION_COOKIE_HTTPONLY"] = True + app.config["SESSION_COOKIE_SAMESITE"] = "Lax" # Initialize Flask-Session FlaskSession(app) @@ -204,6 +204,7 @@ def create_app() -> MatHudFlask: # Import and register routes from static.routes import register_routes + register_routes(app) return app @@ -213,6 +214,7 @@ def _initialize_tts() -> None: """Initialize TTS manager and log availability status.""" try: from static.tts_manager import get_tts_manager + manager = get_tts_manager() if manager.is_available(): print("TTS: Kokoro initialized successfully") diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index 286c1a3d..02547618 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -90,7 +90,9 @@ def __init__(self, canvas: "Canvas") -> None: self._stop_requested = False self._tests_running = False self._stop_tests_requested = False - self.available_functions: Dict[str, Any] = FunctionRegistry.get_available_functions(canvas, self.workspace_manager, self) + self.available_functions: Dict[str, Any] = FunctionRegistry.get_available_functions( + canvas, self.workspace_manager, self + ) self.undoable_functions: tuple[str, ...] = FunctionRegistry.get_undoable_functions() self.markdown_parser: MarkdownParser = MarkdownParser() # Slash command handler for local commands @@ -111,9 +113,9 @@ def __init__(self, canvas: "Canvas") -> None: self._needs_continuation_separator: bool = False # Add newline before next text after tool calls # Tool call log state self._tool_call_log_entries: list[dict[str, Any]] = [] - self._tool_call_log_element: Optional[Any] = None #
element - self._tool_call_log_summary: Optional[Any] = None # element - self._tool_call_log_content: Optional[Any] = None # content container div + self._tool_call_log_element: Optional[Any] = None #
element + self._tool_call_log_summary: Optional[Any] = None # element + self._tool_call_log_content: Optional[Any] = None # content container div # Timeout state self._response_timeout_id: Optional[int] = None # Chat message menu state @@ -140,16 +142,14 @@ def _safe_json_to_js(data: Any) -> Any: except Exception as exc: return window.JSON.parse(json.dumps({"error": str(exc)})) - window.getActionTraces = lambda: _safe_json_to_js( - self._trace_collector.export_traces_json() - ) - window.getLastActionTrace = lambda: _safe_json_to_js( - self._trace_collector.get_last_trace_json() - ) + window.getActionTraces = lambda: _safe_json_to_js(self._trace_collector.export_traces_json()) + window.getLastActionTrace = lambda: _safe_json_to_js(self._trace_collector.get_last_trace_json()) window.clearActionTraces = lambda: self._trace_collector.clear() window.replayLastTrace = lambda: _safe_json_to_js( self._trace_collector.replay_last_trace( - self.available_functions, self.undoable_functions, self.canvas, + self.available_functions, + self.undoable_functions, + self.canvas, ) ) @@ -157,6 +157,7 @@ def run_tests(self) -> Dict[str, Any]: """Run unit tests for the AIInterface class and return results to the AI as the function result.""" try: from test_runner import TestRunner + test_runner = TestRunner(self.canvas, self.available_functions, self.undoable_functions) # Run tests and get formatted results in one step @@ -164,23 +165,26 @@ def run_tests(self) -> Dict[str, Any]: return cast(Dict[str, Any], test_runner.format_results_for_ai(results)) except ImportError as e: print(f"Test runner not available: {e}") - return cast(Dict[str, Any], { - "tests_run": 0, - "failures": 0, - "errors": 1, - "failing_tests": [], - "error_tests": [{"test": "Test Runner Import", "error": f"Could not import test runner: {e}"}] - }) + return cast( + Dict[str, Any], + { + "tests_run": 0, + "failures": 0, + "errors": 1, + "failing_tests": [], + "error_tests": [{"test": "Test Runner Import", "error": f"Could not import test runner: {e}"}], + }, + ) def compare_canvas_state(self) -> None: """Send current canvas state to debug endpoint and log full-vs-summary output.""" try: payload = json.dumps({"canvas_state": self.canvas.get_canvas_state()}) req = ajax.ajax(timeout=20000) - req.bind('complete', self._on_compare_canvas_state_complete) - req.bind('error', self._on_compare_canvas_state_error) - req.open('POST', '/api/debug/canvas-state-comparison', True) - req.set_header('content-type', 'application/json') + req.bind("complete", self._on_compare_canvas_state_complete) + req.bind("error", self._on_compare_canvas_state_error) + req.open("POST", "/api/debug/canvas-state-comparison", True) + req.set_header("content-type", "application/json") req.send(payload) console.log("[MatHud] Requested canvas-state comparison...") except Exception as e: @@ -202,11 +206,11 @@ def _on_compare_canvas_state_complete(self, req: Any) -> None: console.log("=== Canvas State Comparison ===") console.log( "Full state:", - f"{metrics.get('full_bytes', 0)} bytes (~{metrics.get('full_estimated_tokens', 0)} tokens)" + f"{metrics.get('full_bytes', 0)} bytes (~{metrics.get('full_estimated_tokens', 0)} tokens)", ) console.log( "Summary:", - f"{metrics.get('summary_bytes', 0)} bytes (~{metrics.get('summary_estimated_tokens', 0)} tokens)" + f"{metrics.get('summary_bytes', 0)} bytes (~{metrics.get('summary_estimated_tokens', 0)} tokens)", ) console.log("Reduction:", f"{metrics.get('reduction_pct', 0.0)}%") console.log("Full state object:") @@ -236,6 +240,7 @@ async def run_tests_async( """ try: from test_runner import TestRunner + test_runner = TestRunner(self.canvas, self.available_functions, self.undoable_functions) # Run tests asynchronously and get formatted results @@ -243,13 +248,16 @@ async def run_tests_async( return cast(Dict[str, Any], test_runner.format_results_for_ai(results)) except ImportError as e: print(f"Test runner not available: {e}") - return cast(Dict[str, Any], { - "tests_run": 0, - "failures": 0, - "errors": 1, - "failing_tests": [], - "error_tests": [{"test": "Test Runner Import", "error": f"Could not import test runner: {e}"}] - }) + return cast( + Dict[str, Any], + { + "tests_run": 0, + "failures": 0, + "errors": 1, + "failing_tests": [], + "error_tests": [{"test": "Test Runner Import", "error": f"Could not import test runner: {e}"}], + }, + ) def initialize_autocomplete(self) -> None: """Initialize the command autocomplete popup. @@ -401,6 +409,7 @@ def make_remove_handler(index: int) -> Any: def handler(event: Any) -> None: event.stopPropagation() self._remove_attached_image(index) + return handler remove_btn.bind("click", make_remove_handler(idx)) @@ -457,10 +466,12 @@ def _store_results_in_canvas_state(self, call_results: Dict[str, Any]) -> None: for key, value in call_results.items(): # Skip storing workspace management functions and test results in computations - if key.startswith("list_workspaces") or \ - key.startswith("save_workspace") or \ - key.startswith("load_workspace") or \ - key.startswith("run_tests"): + if ( + key.startswith("list_workspaces") + or key.startswith("save_workspace") + or key.startswith("load_workspace") + or key.startswith("run_tests") + ): continue if not ProcessFunctionCalls.is_successful_result(value): @@ -486,7 +497,7 @@ def _render_math(self) -> None: """Trigger MathJax rendering for newly added content.""" try: # Check if MathJax is available - if hasattr(window, 'MathJax') and hasattr(window.MathJax, 'typesetPromise'): + if hasattr(window, "MathJax") and hasattr(window.MathJax, "typesetPromise"): # Re-render math in the chat history window.MathJax.typesetPromise([document["chat-history"]]) except Exception: @@ -795,29 +806,29 @@ def _strip_markdown_for_tts(self, text: str) -> str: result = text # Remove code blocks - result = re.sub(r'```[\s\S]*?```', '', result) - result = re.sub(r'`[^`]+`', '', result) + result = re.sub(r"```[\s\S]*?```", "", result) + result = re.sub(r"`[^`]+`", "", result) # Remove headers - result = re.sub(r'^#{1,6}\s+', '', result, flags=re.MULTILINE) + result = re.sub(r"^#{1,6}\s+", "", result, flags=re.MULTILINE) # Remove bold/italic - result = re.sub(r'\*\*([^*]+)\*\*', r'\1', result) - result = re.sub(r'\*([^*]+)\*', r'\1', result) - result = re.sub(r'__([^_]+)__', r'\1', result) - result = re.sub(r'_([^_]+)_', r'\1', result) + result = re.sub(r"\*\*([^*]+)\*\*", r"\1", result) + result = re.sub(r"\*([^*]+)\*", r"\1", result) + result = re.sub(r"__([^_]+)__", r"\1", result) + result = re.sub(r"_([^_]+)_", r"\1", result) # Remove links, keep text - result = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', result) + result = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", result) # Remove images - result = re.sub(r'!\[[^\]]*\]\([^)]+\)', '', result) + result = re.sub(r"!\[[^\]]*\]\([^)]+\)", "", result) # Remove horizontal rules - result = re.sub(r'^[-*_]{3,}$', '', result, flags=re.MULTILINE) + result = re.sub(r"^[-*_]{3,}$", "", result, flags=re.MULTILINE) # Clean up extra whitespace - result = re.sub(r'\n{3,}', '\n\n', result) + result = re.sub(r"\n{3,}", "\n\n", result) result = result.strip() return result @@ -969,6 +980,7 @@ def _create_message_element( def make_image_click_handler(url: str) -> Any: def handler(event: Any) -> None: self._show_image_modal(url) + return handler img.bind("click", make_image_click_handler(data_url)) @@ -985,10 +997,10 @@ def handler(event: Any) -> None: print(f"Error creating message element: {e}") # Fall back to simple paragraph if sender == "AI": - content = message.replace('\n', '
') - return html.P(f'{sender}: {content}', innerHTML=True) + content = message.replace("\n", "
") + return html.P(f"{sender}: {content}", innerHTML=True) else: - return html.P(f'{sender}: {message}') + return html.P(f"{sender}: {message}") def _print_ai_message_in_chat(self, ai_message: str) -> None: """Print an AI message to the chat history with markdown support and scroll to bottom.""" @@ -1193,9 +1205,7 @@ def _add_tool_call_entries(self, tool_calls: list[dict[str, Any]], call_results: error_message = result_value if is_error else "" # Full untruncated args for the expanded view - args_full = ", ".join( - f"{k}: {v}" for k, v in args.items() if k != "canvas" - ) + args_full = ", ".join(f"{k}: {v}" for k, v in args.items() if k != "canvas") # Format result for display (truncate if too long) result_display = "" @@ -1346,6 +1356,7 @@ def _finalize_stream_message(self, final_message: Optional[str] = None) -> None: if self._reasoning_summary is not None and self._request_start_time is not None: try: from browser import window + elapsed_ms = window.Date.now() - self._request_start_time elapsed_seconds = int(elapsed_ms / 1000) self._reasoning_summary.text = f"Thought for {elapsed_seconds} seconds" @@ -1428,7 +1439,12 @@ def _remove_empty_response_container(self) -> None: has_tool_call_log = bool(self._tool_call_log_entries) # Only remove if there's NO actual text content anywhere and no tool call log - if self._stream_message_container is not None and not has_buffer_text and not has_element_text and not has_tool_call_log: + if ( + self._stream_message_container is not None + and not has_buffer_text + and not has_element_text + and not has_tool_call_log + ): history = document["chat-history"] try: history.removeChild(self._stream_message_container) @@ -1452,10 +1468,10 @@ def _on_stream_final(self, event_obj: Any) -> None: try: event = self._normalize_stream_event(event_obj) - finish_reason = event.get('finish_reason', 'stop') - ai_tool_calls = event.get('ai_tool_calls', []) - ai_message = event.get('ai_message', '') - error_details = event.get('error_details', '') + finish_reason = event.get("finish_reason", "stop") + ai_tool_calls = event.get("ai_tool_calls", []) + ai_message = event.get("ai_message", "") + error_details = event.get("error_details", "") # Log error details to console for debugging if finish_reason == "error": @@ -1482,7 +1498,10 @@ def _on_stream_final(self, event_obj: Any) -> None: traced_calls: list[Dict[str, Any]] = [] try: call_results, traced_calls = ProcessFunctionCalls.get_results_traced( - ai_tool_calls, self.available_functions, self.undoable_functions, self.canvas, + ai_tool_calls, + self.available_functions, + self.undoable_functions, + self.canvas, ) self._store_results_in_canvas_state(call_results) self._add_tool_call_entries(ai_tool_calls, call_results) @@ -1493,7 +1512,10 @@ def _on_stream_final(self, event_obj: Any) -> None: state_after = self.canvas.get_canvas_state() total_ms = window.performance.now() - t0 trace = self._trace_collector.build_trace( - state_before, state_after, traced_calls, total_ms, + state_before, + state_after, + traced_calls, + total_ms, ) self._trace_collector.store(trace) except Exception: @@ -1506,7 +1528,10 @@ def _on_stream_final(self, event_obj: Any) -> None: state_after = self.canvas.get_canvas_state() total_ms = window.performance.now() - t0 trace = self._trace_collector.build_trace( - state_before, state_after, traced_calls, total_ms, + state_before, + state_after, + traced_calls, + total_ms, ) self._trace_collector.store(trace) trace_summary = self._trace_collector.build_compact_summary(trace) @@ -1517,8 +1542,10 @@ def _on_stream_final(self, event_obj: Any) -> None: if self._stream_buffer.strip(): self._needs_continuation_separator = True self._send_prompt_to_ai( - None, json.dumps(call_results), - canvas_state=state_after, action_trace=trace_summary, + None, + json.dumps(call_results), + canvas_state=state_after, + action_trace=trace_summary, ) except Exception as e: # Always capture trace even on partial failure @@ -1526,7 +1553,10 @@ def _on_stream_final(self, event_obj: Any) -> None: state_after = self.canvas.get_canvas_state() total_ms = window.performance.now() - t0 trace = self._trace_collector.build_trace( - state_before, state_after, traced_calls, total_ms, + state_before, + state_after, + traced_calls, + total_ms, ) self._trace_collector.store(trace) except Exception: @@ -1584,10 +1614,7 @@ def _restore_user_message_on_error(self) -> None: chat_input.value = self._last_user_message # Apply visual error feedback chat_input.classList.add("error-flash") - window.setTimeout( - lambda: chat_input.classList.remove("error-flash"), - 2000 - ) + window.setTimeout(lambda: chat_input.classList.remove("error-flash"), 2000) except Exception as e: print(f"Error restoring user message: {e}") @@ -1603,7 +1630,17 @@ def _normalize_stream_event(self, event_obj: Any) -> Dict[str, Any]: except Exception: pass result = {} - for key in ["type", "text", "ai_message", "ai_tool_calls", "finish_reason", "error_details", "level", "message", "source"]: + for key in [ + "type", + "text", + "ai_message", + "ai_tool_calls", + "finish_reason", + "error_details", + "level", + "message", + "source", + ]: try: result[key] = getattr(event_obj, key) except Exception: @@ -1643,7 +1680,7 @@ def _print_system_message_in_chat(self, message: str) -> None: sender_label = html.SPAN("System: ", Class="chat-sender system") # Check if message is long and needs expandable display - line_count = message.count('\n') + line_count = message.count("\n") is_long_message = len(message) > 800 or line_count > 20 if is_long_message: @@ -1687,11 +1724,11 @@ def _create_expandable_content(self, message: str) -> Any: A DOM element with expandable content """ # Create preview (first ~500 chars or 10 lines) - lines = message.split('\n') + lines = message.split("\n") if len(lines) > 10: - preview_text = '\n'.join(lines[:10]) + '\n...' + preview_text = "\n".join(lines[:10]) + "\n..." elif len(message) > 500: - preview_text = message[:500] + '...' + preview_text = message[:500] + "..." else: preview_text = message @@ -1739,12 +1776,13 @@ def _escape_html(self, text: str) -> str: Returns: Escaped text safe for HTML """ - return (text - .replace("&", "&") + return ( + text.replace("&", "&") .replace("<", "<") .replace(">", ">") .replace('"', """) - .replace("'", "'")) + .replace("'", "'") + ) def _debug_log_ai_response(self, ai_message: str, ai_function_calls: Any, finish_reason: str) -> None: """Log debug information about the AI response.""" @@ -1791,10 +1829,7 @@ def _start_response_timeout(self, use_reasoning_timeout: bool = False) -> None: # Cancel any existing timeout first self._cancel_response_timeout() timeout_ms = self.REASONING_TIMEOUT_MS if use_reasoning_timeout else self.AI_RESPONSE_TIMEOUT_MS - self._response_timeout_id = window.setTimeout( - self._on_response_timeout, - timeout_ms - ) + self._response_timeout_id = window.setTimeout(self._on_response_timeout, timeout_ms) except Exception as e: print(f"Error starting response timeout: {e}") @@ -1824,7 +1859,7 @@ def _on_response_timeout(self) -> None: def _abort_current_stream(self) -> None: """Abort the current streaming connection if one is active.""" try: - if hasattr(window, 'abortCurrentStream'): + if hasattr(window, "abortCurrentStream"): window.abortCurrentStream() except Exception as e: print(f"Error aborting stream: {e}") @@ -1861,34 +1896,45 @@ def _process_ai_response(self, ai_message: str, tool_calls: Any, finish_reason: if finish_reason == "stop" or finish_reason == "error": self._print_ai_message_in_chat(ai_message) self._enable_send_controls() - else: # finish_reason == "tool_calls" or "function_call" + else: # finish_reason == "tool_calls" or "function_call" state_before = self.canvas.get_canvas_state() t0 = window.performance.now() traced_calls: list[Dict[str, Any]] = [] try: call_results, traced_calls = ProcessFunctionCalls.get_results_traced( - tool_calls, self.available_functions, self.undoable_functions, self.canvas, + tool_calls, + self.available_functions, + self.undoable_functions, + self.canvas, ) self._store_results_in_canvas_state(call_results) state_after = self.canvas.get_canvas_state() total_ms = window.performance.now() - t0 trace = self._trace_collector.build_trace( - state_before, state_after, traced_calls, total_ms, + state_before, + state_after, + traced_calls, + total_ms, ) self._trace_collector.store(trace) trace_summary = self._trace_collector.build_compact_summary(trace) self._send_prompt_to_ai( - None, json.dumps(call_results), - canvas_state=state_after, action_trace=trace_summary, + None, + json.dumps(call_results), + canvas_state=state_after, + action_trace=trace_summary, ) except Exception as e: try: state_after = self.canvas.get_canvas_state() total_ms = window.performance.now() - t0 trace = self._trace_collector.build_trace( - state_before, state_after, traced_calls, total_ms, + state_before, + state_after, + traced_calls, + total_ms, ) self._trace_collector.store(trace) except Exception: @@ -1907,17 +1953,17 @@ def _on_complete(self, request: Any) -> None: try: if request.status == 200 or request.status == 0: # Extract data from the proper response structure - response_data = request.json.get('data') + response_data = request.json.get("data") if not response_data: - error_msg = request.json.get('message', 'Invalid response format') + error_msg = request.json.get("message", "Invalid response format") print(f"Error: {error_msg}") document["ai-response"].text = error_msg self._enable_send_controls() return - ai_message = response_data.get('ai_message') - ai_function_calls = response_data.get('ai_tool_calls') - finish_reason = response_data.get('finish_reason') + ai_message = response_data.get("ai_message") + ai_function_calls = response_data.get("ai_tool_calls") + finish_reason = response_data.get("finish_reason") # Parse the AI's response and create / delete drawables as needed self._process_ai_response(ai_message, ai_function_calls, finish_reason) @@ -1935,13 +1981,13 @@ def _create_request_payload( action_trace: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Create the JSON payload for the request, optionally including SVG and Canvas2D state.""" - payload: Dict[str, Any] = {'message': prompt} + payload: Dict[str, Any] = {"message": prompt} if action_trace is not None: - payload['action_trace'] = action_trace + payload["action_trace"] = action_trace vision_enabled = self._is_vision_enabled(prompt) renderer_mode = getattr(self.canvas, "renderer_mode", None) if isinstance(renderer_mode, str): - payload['renderer_mode'] = renderer_mode + payload["renderer_mode"] = renderer_mode svg_state_payload: Optional[Dict[str, Any]] = None if include_svg: @@ -1951,15 +1997,12 @@ def _create_request_payload( container = document["math-container"] rect = container.getBoundingClientRect() svg_state_payload = { - 'content': svg_content, - 'dimensions': { - 'width': rect.width, - 'height': rect.height - }, - 'viewBox': svg_element.getAttribute("viewBox"), - 'transform': svg_element.getAttribute("transform") + "content": svg_content, + "dimensions": {"width": rect.width, "height": rect.height}, + "viewBox": svg_element.getAttribute("viewBox"), + "transform": svg_element.getAttribute("transform"), } - payload['svg_state'] = svg_state_payload + payload["svg_state"] = svg_state_payload except Exception as exc: print(f"Failed to collect SVG state: {exc}") @@ -1968,17 +2011,17 @@ def _create_request_payload( snapshot: Dict[str, Any] = {} if isinstance(renderer_mode, str): - snapshot['renderer_mode'] = renderer_mode + snapshot["renderer_mode"] = renderer_mode if svg_state_payload: - snapshot['svg_state'] = svg_state_payload + snapshot["svg_state"] = svg_state_payload if renderer_mode == "canvas2d": canvas_image = self._capture_canvas2d_snapshot() if canvas_image: - snapshot['canvas_image'] = canvas_image + snapshot["canvas_image"] = canvas_image if snapshot: - payload['vision_snapshot'] = snapshot + payload["vision_snapshot"] = snapshot return payload @@ -2011,10 +2054,10 @@ def _capture_canvas2d_snapshot(self) -> Optional[str]: def _make_request(self, payload: Dict[str, Any]) -> None: """Send an AJAX request with the given payload.""" req = ajax.ajax() - req.bind('complete', self._on_complete) - req.bind('error', self._on_error) - req.open('POST', '/send_message', True) - req.set_header('content-type', 'application/json') + req.bind("complete", self._on_complete) + req.bind("error", self._on_error) + req.open("POST", "/send_message", True) + req.set_header("content-type", "application/json") req.send(json.dumps(payload)) def _start_streaming_request(self, payload: Dict[str, Any]) -> None: @@ -2031,7 +2074,7 @@ def _start_streaming_request(self, payload: Dict[str, Any]) -> None: self._on_stream_final, self._on_stream_error, self._on_stream_reasoning, - self._on_stream_log + self._on_stream_log, ) except Exception as e: print(f"Falling back to non-streaming request due to error: {e}") @@ -2065,13 +2108,15 @@ def _send_prompt_to_ai_stream( "user_message": user_message, "tool_call_results": tool_call_results, "use_vision": use_vision, - "ai_model": document["ai-model-selector"].value + "ai_model": document["ai-model-selector"].value, } # Include attached images if provided (works independently of vision toggle) if attached_images: prompt_json["attached_images"] = attached_images prompt = json.dumps(prompt_json) - print(f'Prompt for AI (stream): {prompt[:500]}...' if len(prompt) > 500 else f'Prompt for AI (stream): {prompt}') + print( + f"Prompt for AI (stream): {prompt[:500]}..." if len(prompt) > 500 else f"Prompt for AI (stream): {prompt}" + ) # For new user messages, reset all state including containers and buffers # For tool call results, preserve everything to keep intermediary text visible @@ -2116,7 +2161,7 @@ def _send_prompt_to_ai( "user_message": user_message, "tool_call_results": tool_call_results, "use_vision": use_vision, - "ai_model": document["ai-model-selector"].value + "ai_model": document["ai-model-selector"].value, } # Include attached images if provided (works independently of vision toggle) @@ -2215,12 +2260,10 @@ async def _execute_tests_async(self) -> None: self._print_user_message_in_chat("Run tests (direct execution)") # Run tests asynchronously with stop callback - results = await self.run_tests_async( - should_stop=lambda: self._stop_tests_requested - ) + results = await self.run_tests_async(should_stop=lambda: self._stop_tests_requested) # Check if tests were stopped - was_stopped = results.get('stopped', False) + was_stopped = results.get("stopped", False) if was_stopped: summary = ( @@ -2238,14 +2281,14 @@ async def _execute_tests_async(self) -> None: f"- **Errors:** {results.get('errors', 0)}\n" ) - if results.get('failing_tests'): + if results.get("failing_tests"): summary += "\n#### Failures:\n" - for fail in results['failing_tests']: + for fail in results["failing_tests"]: summary += f"- **{fail['test']}**: {fail['error']}\n" - if results.get('error_tests'): + if results.get("error_tests"): summary += "\n#### Errors:\n" - for err in results['error_tests']: + for err in results["error_tests"]: summary += f"- **{err['test']}**: {err['error']}\n" self._print_ai_message_in_chat(summary) @@ -2277,7 +2320,7 @@ def interact_with_ai(self, event: Any) -> None: if user_message or has_images: # Buffer message for recovery on error before clearing self._last_user_message = user_message - document["chat-input"].value = '' + document["chat-input"].value = "" self.send_user_message(user_message) def start_new_conversation(self, event: Any) -> None: @@ -2293,6 +2336,6 @@ def start_new_conversation(self, event: Any) -> None: # 4. Call the backend to reset the AI conversation state req = ajax.ajax() - req.open('POST', '/new_conversation', True) - req.set_header('content-type', 'application/json') + req.open("POST", "/new_conversation", True) + req.set_header("content-type", "application/json") req.send() diff --git a/static/client/browser.pyi b/static/client/browser.pyi deleted file mode 100644 index 73e1f964..00000000 --- a/static/client/browser.pyi +++ /dev/null @@ -1,55 +0,0 @@ -"""Type stubs for Brython browser module.""" - -from typing import Any, Callable, Optional - -class DOMNode: - """Base DOM node type.""" - def getBoundingClientRect(self) -> Any: ... - def appendChild(self, child: Any) -> Any: ... - def removeChild(self, child: Any) -> Any: ... - def setAttribute(self, name: str, value: str) -> None: ... - def getAttribute(self, name: str) -> Optional[str]: ... - def __getitem__(self, key: str) -> Any: ... - def __setitem__(self, key: str, value: Any) -> None: ... - -class Document: - """Document object.""" - def __getitem__(self, key: str) -> DOMNode: ... - -document: Document - -class HTML: - """HTML element factory.""" - def __call__(self, tag: str, **kwargs: Any) -> DOMNode: ... - -html: HTML - -class Ajax: - """AJAX request handler.""" - def __init__(self, url: str, **kwargs: Any) -> None: ... - def send(self, *args: Any) -> None: ... - def bind(self, event: str, handler: Callable[..., None]) -> None: ... - -ajax: Any # Module-level alias - -class Window: - """Browser window object.""" - def __getitem__(self, key: str) -> Any: ... - def __setitem__(self, key: str, value: Any) -> None: ... - -window: Window - -class Console: - """Console object.""" - def log(self, *args: Any) -> None: ... - def error(self, *args: Any) -> None: ... - def warn(self, *args: Any) -> None: ... - -console: Console - -class SVG: - """SVG element factory.""" - def __call__(self, tag: str, **kwargs: Any) -> DOMNode: ... - -svg: SVG - diff --git a/static/client/canvas.py b/static/client/canvas.py index 2b69648c..901c6b4e 100644 --- a/static/client/canvas.py +++ b/static/client/canvas.py @@ -89,7 +89,10 @@ class Canvas: zoom_direction (int): Current zoom direction (-1=in, 1=out, 0=none) zoom_step (float): Zoom increment per step """ - def __init__(self, width: float, height: float, draw_enabled: bool = True, renderer: Optional[RendererProtocol] = None) -> None: + + def __init__( + self, width: float, height: float, draw_enabled: bool = True, renderer: Optional[RendererProtocol] = None + ) -> None: """Initialize the mathematical canvas with specified dimensions. Sets up the coordinate system, managers, and initial state for mathematical visualization. @@ -284,9 +287,7 @@ def _invalidate_drawable_zoom_cache(self, drawable: Any, apply_zoom: bool) -> No if apply_zoom and hasattr(drawable, "_invalidate_cache_on_zoom"): drawable._invalidate_cache_on_zoom() - def _render_drawable_with_renderer( - self, renderer: Optional[RendererProtocol], drawable: Any - ) -> None: + def _render_drawable_with_renderer(self, renderer: Optional[RendererProtocol], drawable: Any) -> None: if renderer is not None: try: renderer.render(drawable, self.coordinate_mapper) @@ -318,9 +319,7 @@ def _is_drawable_visible(self, drawable: "Drawable") -> bool: def _safe_drawable_class_name(self, drawable: Any) -> str: try: return str( - drawable.get_class_name() - if hasattr(drawable, "get_class_name") - else drawable.__class__.__name__ + drawable.get_class_name() if hasattr(drawable, "get_class_name") else drawable.__class__.__name__ ) except Exception: return str(drawable.__class__.__name__) @@ -785,9 +784,7 @@ def _append_plot_prefixes(self, plots_state: Any, prefixes: List[str]) -> None: if isinstance(name, str) and name: prefixes.append(f"{name}_bar_") - def _filter_bars_excluding_prefixes( - self, bars_state: List[Any], prefixes: List[str] - ) -> List[Any]: + def _filter_bars_excluding_prefixes(self, bars_state: List[Any], prefixes: List[str]) -> List[Any]: kept: List[Any] = [] for item in bars_state: if not isinstance(item, dict): @@ -948,7 +945,9 @@ def create_segment( label_visible=label_visible, ) - def delete_segment(self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False) -> bool: + def delete_segment( + self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False + ) -> bool: """Delete a segment by its endpoint coordinates""" return bool(self.drawable_manager.delete_segment(x1, y1, x2, y2, delete_children, delete_parents)) @@ -981,7 +980,9 @@ def any_segment_part_visible_in_canvas_area(self, x1: float, y1: float, x2: floa intersect_left = MathUtils.segments_intersect(x1, y1, x2, y2, 0, self.height, 0, 0) point1_visible: bool = self.is_point_within_canvas_visible_area(x1, y1) point2_visible: bool = self.is_point_within_canvas_visible_area(x2, y2) - return bool(intersect_top or intersect_right or intersect_bottom or intersect_left or point1_visible or point2_visible) + return bool( + intersect_top or intersect_right or intersect_bottom or intersect_left or point1_visible or point2_visible + ) def get_vector(self, x1: float, y1: float, x2: float, y2: float) -> Optional["Drawable"]: """Get a vector by its origin and tip coordinates""" @@ -1040,18 +1041,21 @@ def plot_distribution( fill_opacity: Optional[float] = None, bar_count: Optional[float] = None, ) -> Dict[str, Any]: - return cast(Dict[str, Any], self.drawable_manager.plot_distribution( - name=name, - representation=representation, - distribution_type=distribution_type, - distribution_params=distribution_params, - plot_bounds=plot_bounds, - shade_bounds=shade_bounds, - curve_color=curve_color, - fill_color=fill_color, - fill_opacity=fill_opacity, - bar_count=bar_count, - )) + return cast( + Dict[str, Any], + self.drawable_manager.plot_distribution( + name=name, + representation=representation, + distribution_type=distribution_type, + distribution_params=distribution_params, + plot_bounds=plot_bounds, + shade_bounds=shade_bounds, + curve_color=curve_color, + fill_color=fill_color, + fill_opacity=fill_opacity, + bar_count=bar_count, + ), + ) def plot_bars( self, @@ -1068,19 +1072,22 @@ def plot_bars( x_start: Optional[float] = None, y_base: Optional[float] = None, ) -> Dict[str, Any]: - return cast(Dict[str, Any], self.drawable_manager.plot_bars( - name=name, - values=values or [], - labels_below=labels_below or [], - labels_above=labels_above, - bar_spacing=bar_spacing, - bar_width=bar_width, - stroke_color=stroke_color, - fill_color=fill_color, - fill_opacity=fill_opacity, - x_start=x_start, - y_base=y_base, - )) + return cast( + Dict[str, Any], + self.drawable_manager.plot_bars( + name=name, + values=values or [], + labels_below=labels_below or [], + labels_above=labels_above, + bar_spacing=bar_spacing, + bar_width=bar_width, + stroke_color=stroke_color, + fill_color=fill_color, + fill_opacity=fill_opacity, + x_start=x_start, + y_base=y_base, + ), + ) def delete_plot(self, name: str) -> bool: return bool(self.drawable_manager.delete_plot(name)) @@ -1098,17 +1105,20 @@ def fit_regression( show_points: Optional[bool] = None, point_color: Optional[str] = None, ) -> Dict[str, Any]: - return cast(Dict[str, Any], self.drawable_manager.fit_regression( - name=name, - x_data=x_data if x_data is not None else [], - y_data=y_data if y_data is not None else [], - model_type=model_type, - degree=degree, - plot_bounds=plot_bounds, - curve_color=curve_color, - show_points=show_points, - point_color=point_color, - )) + return cast( + Dict[str, Any], + self.drawable_manager.fit_regression( + name=name, + x_data=x_data if x_data is not None else [], + y_data=y_data if y_data is not None else [], + model_type=model_type, + degree=degree, + plot_bounds=plot_bounds, + curve_color=curve_color, + show_points=show_points, + point_color=point_color, + ), + ) # ------------------- Graph Methods ------------------- def create_graph(self, graph_state: "GraphState") -> "Drawable": @@ -1461,9 +1471,7 @@ def create_tangent_line( Returns: The created Segment drawable """ - return self.drawable_manager.create_tangent_line( - curve_name, parameter, name=name, length=length, color=color - ) + return self.drawable_manager.create_tangent_line(curve_name, parameter, name=name, length=length, color=color) def create_normal_line( self, @@ -1485,9 +1493,7 @@ def create_normal_line( Returns: The created Segment drawable """ - return self.drawable_manager.create_normal_line( - curve_name, parameter, name=name, length=length, color=color - ) + return self.drawable_manager.create_normal_line(curve_name, parameter, name=name, length=length, color=color) # ------------------- Construction Methods ------------------- @@ -1514,9 +1520,7 @@ def create_perpendicular_bisector( color: Optional[str] = None, ) -> "Drawable": """Create the perpendicular bisector of a segment.""" - return self.drawable_manager.create_perpendicular_bisector( - segment_name, length=length, name=name, color=color - ) + return self.drawable_manager.create_perpendicular_bisector(segment_name, length=length, name=name, color=color) def create_perpendicular_from_point( self, @@ -1527,9 +1531,8 @@ def create_perpendicular_from_point( color: Optional[str] = None, ) -> Dict[str, Any]: """Drop a perpendicular from a point to a segment.""" - return self.drawable_manager.create_perpendicular_from_point( - point_name, segment_name, name=name, color=color - ) + result: Dict[str, Any] = self.drawable_manager.create_perpendicular_from_point(point_name, segment_name, name=name, color=color) + return result def create_angle_bisector( self, @@ -1544,8 +1547,7 @@ def create_angle_bisector( ) -> "Drawable": """Create a segment along the bisector of an angle.""" return self.drawable_manager.create_angle_bisector( - vertex_name, p1_name, p2_name, - angle_name=angle_name, length=length, name=name, color=color + vertex_name, p1_name, p2_name, angle_name=angle_name, length=length, name=name, color=color ) def create_parallel_line( @@ -1575,8 +1577,11 @@ def create_circumcircle( """Create the circumscribed circle of a triangle or three points.""" return self.drawable_manager.create_circumcircle( triangle_name=triangle_name, - p1_name=p1_name, p2_name=p2_name, p3_name=p3_name, - name=name, color=color, + p1_name=p1_name, + p2_name=p2_name, + p3_name=p3_name, + name=name, + color=color, ) def create_incircle( @@ -1588,7 +1593,9 @@ def create_incircle( ) -> "Drawable": """Create the inscribed circle of a triangle.""" return self.drawable_manager.create_incircle( - triangle_name, name=name, color=color, + triangle_name, + name=name, + color=color, ) def translate_object(self, name: str, x_offset: float, y_offset: float) -> bool: @@ -1619,13 +1626,16 @@ def reflect_object( segment_name: Optional[str] = None, ) -> bool: """Reflect a drawable across an axis, line, or segment.""" - return bool(self.transformations_manager.reflect_object( - name, axis, - line_a=float(line_a) if line_a is not None else 0, - line_b=float(line_b) if line_b is not None else 0, - line_c=float(line_c) if line_c is not None else 0, - segment_name=str(segment_name) if segment_name else "", - )) + return bool( + self.transformations_manager.reflect_object( + name, + axis, + line_a=float(line_a) if line_a is not None else 0, + line_b=float(line_b) if line_b is not None else 0, + line_c=float(line_c) if line_c is not None else 0, + segment_name=str(segment_name) if segment_name else "", + ) + ) def scale_object( self, @@ -1669,9 +1679,7 @@ def zoom(self, center_x: float, center_y: float, range_val: float, range_axis: s range_val: Half-size for the specified axis range_axis: 'x' or 'y' - which axis the range applies to """ - left, right, top, bottom = self._compute_zoom_bounds( - center_x, center_y, range_val, range_axis - ) + left, right, top, bottom = self._compute_zoom_bounds(center_x, center_y, range_val, range_axis) self.coordinate_mapper.set_visible_bounds(left, right, top, bottom) self._invalidate_cartesian_cache_on_zoom() self.draw(apply_zoom=True) @@ -1709,7 +1717,7 @@ def find_largest_connected_shape(self, shape: "Drawable") -> tuple[Optional["Dra return None, None # If the shape is a rectangle, don't check for parent shapes - if shape.get_class_name() == 'Rectangle': + if shape.get_class_name() == "Rectangle": return None, None rectangles = self.drawable_manager.drawables.Rectangles @@ -1718,7 +1726,7 @@ def find_largest_connected_shape(self, shape: "Drawable") -> tuple[Optional["Dra return rectangle_parent, rectangle_parent.get_class_name() # Only check triangles if no rectangle was found and the shape isn't a triangle - if shape.get_class_name() == 'Triangle': + if shape.get_class_name() == "Triangle": return None, None triangles = self.drawable_manager.drawables.Triangles @@ -1776,9 +1784,19 @@ def _collect_shape_segments(self, shape: "Drawable") -> List["Drawable"]: segments.append(shape.segment4) return segments - def create_colored_area(self, drawable1_name: str, drawable2_name: Optional[str] = None, left_bound: Optional[float] = None, right_bound: Optional[float] = None, color: str = default_area_fill_color, opacity: float = default_area_opacity) -> "Drawable": + def create_colored_area( + self, + drawable1_name: str, + drawable2_name: Optional[str] = None, + left_bound: Optional[float] = None, + right_bound: Optional[float] = None, + color: str = default_area_fill_color, + opacity: float = default_area_opacity, + ) -> "Drawable": """Creates a vertical bounded colored area between two functions, two segments, or a function and a segment""" - return self.drawable_manager.create_colored_area(drawable1_name, drawable2_name, left_bound, right_bound, color, opacity) + return self.drawable_manager.create_colored_area( + drawable1_name, drawable2_name, left_bound, right_bound, color, opacity + ) def create_region_colored_area( self, @@ -1864,16 +1882,33 @@ def name_generator(self) -> Any: # NameGenerator # ------------------- Angle Methods ------------------- - def create_angle(self, vx: float, vy: float, p1x: float, p1y: float, p2x: float, p2y: float, color: Optional[str] = None, angle_name: Optional[str] = None, is_reflex: bool = False, extra_graphics: bool = True) -> Optional["Drawable"]: + def create_angle( + self, + vx: float, + vy: float, + p1x: float, + p1y: float, + p2x: float, + p2y: float, + color: Optional[str] = None, + angle_name: Optional[str] = None, + is_reflex: bool = False, + extra_graphics: bool = True, + ) -> Optional["Drawable"]: """Create an angle defined by three points via AngleManager.""" angle_manager = self._get_angle_manager() if angle_manager: return angle_manager.create_angle( - vx, vy, p1x, p1y, p2x, p2y, + vx, + vy, + p1x, + p1y, + p2x, + p2y, color=color, angle_name=angle_name, is_reflex=is_reflex, - extra_graphics=extra_graphics + extra_graphics=extra_graphics, ) return None @@ -1888,9 +1923,7 @@ def update_angle(self, name: str, new_color: Optional[str] = None) -> bool: """Update editable angle properties via AngleManager.""" angle_manager = self._get_angle_manager() if angle_manager: - return bool(angle_manager.update_angle( - name, new_color=new_color - )) + return bool(angle_manager.update_angle(name, new_color=new_color)) return False def _get_angle_manager(self) -> Any: @@ -2124,9 +2157,7 @@ def get_renderer_mode(self) -> str: def _resolve_renderer_mode(self, renderer: Optional[RendererProtocol]) -> str: if renderer is None: return "none" - mode_from_name = self._resolve_renderer_mode_from_name( - renderer.__class__.__name__.lower() - ) + mode_from_name = self._resolve_renderer_mode_from_name(renderer.__class__.__name__.lower()) if mode_from_name is not None: return mode_from_name module = getattr(renderer, "__module__", "") diff --git a/static/client/canvas_event_handler.py b/static/client/canvas_event_handler.py index 06b57927..1f4fad3e 100644 --- a/static/client/canvas_event_handler.py +++ b/static/client/canvas_event_handler.py @@ -32,12 +32,7 @@ import time from browser import document, window -from constants import ( - double_click_threshold_s, - zoom_in_scale_factor, - zoom_out_scale_factor, - mousemove_throttle_ms -) +from constants import double_click_threshold_s, zoom_in_scale_factor, zoom_out_scale_factor, mousemove_throttle_ms from drawables_aggregator import Position if TYPE_CHECKING: @@ -57,6 +52,7 @@ def throttle(wait_ms: float) -> Callable[[Callable[..., Any]], Callable[..., Any Returns: A decorator function that will throttle the decorated function """ + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: last_call: Optional[float] = None queued: Optional[Any] = None @@ -83,10 +79,7 @@ def throttled(*args: Any, **kwargs: Any) -> None: else: # Schedule to run at next interval remaining_time: float = wait_ms - elapsed - queued = window.setTimeout( - lambda: throttled(*args, **kwargs), - remaining_time - ) + queued = window.setTimeout(lambda: throttled(*args, **kwargs), remaining_time) except Exception as e: print(f"Error in throttle: {str(e)}") @@ -111,6 +104,7 @@ class CanvasEventHandler: initial_pinch_distance (float): Initial distance between two fingers for pinch-to-zoom last_pinch_distance (float): Last recorded distance for pinch gesture """ + def __init__(self, canvas: "Canvas", ai_interface: "AIInterface") -> None: """Initialize event handler with canvas and AI interface integration. @@ -225,7 +219,7 @@ def _settle_callback() -> None: def _update_zoom_point(self, event: Any) -> None: """Update the zoom point based on mouse position.""" try: - svg_canvas: Any = document['math-svg'] + svg_canvas: Any = document["math-svg"] rect: Any = svg_canvas.getBoundingClientRect() # Save the current zoom point and update it to the mouse position self.canvas.zoom_point = Position(event.clientX - rect.left, event.clientY - rect.top) @@ -264,6 +258,7 @@ def _apply_zoom_with_anchor(self, zoom_factor: float) -> None: if zp is None: # Fallback: zoom about canvas center from drawables_aggregator import Position + zp = Position(self.canvas.width / 2, self.canvas.height / 2) self.canvas.zoom_point = zp @@ -311,7 +306,7 @@ def _get_decimal_places_for_fraction(self, value: float) -> int: """Calculate decimal places for fractional values.""" try: decimal_part: str = format(value, ".10f").split(".")[1] - leading_zeros: int = len(decimal_part) - len(decimal_part.lstrip('0')) + leading_zeros: int = len(decimal_part) - len(decimal_part.lstrip("0")) return leading_zeros + 2 except Exception as e: print(f"Error calculating decimal places for fraction: {str(e)}") @@ -347,7 +342,10 @@ def handle_mousedown(self, event: Any) -> None: def _is_double_click(self, current_timestamp: float) -> bool: """Determine if this is a double click based on timing.""" try: - return self.last_click_timestamp is not None and (current_timestamp - self.last_click_timestamp) < double_click_threshold_s + return ( + self.last_click_timestamp is not None + and (current_timestamp - self.last_click_timestamp) < double_click_threshold_s + ) except Exception as e: print(f"Error detecting double click: {str(e)}") return False @@ -369,8 +367,8 @@ def _calculate_click_coordinates(self, event: Any) -> str: scale_factor: float = self.canvas.scale_factor origin: Position = self.canvas.cartesian2axis.origin - x: float = (canvas_x - origin.x) * 1/scale_factor - y: float = (origin.y - canvas_y) * 1/scale_factor + x: float = (canvas_x - origin.x) * 1 / scale_factor + y: float = (origin.y - canvas_y) * 1 / scale_factor decimal_places_x: int = self.get_decimal_places(x) decimal_places_y: int = self.get_decimal_places(y) @@ -546,10 +544,7 @@ def _handle_touch_pan(self, touch: Any) -> None: """Handle single finger panning.""" try: # Create a mock event object similar to mouse event - mock_event: Any = type('obj', (object,), { - 'clientX': touch.clientX, - 'clientY': touch.clientY - }) + mock_event: Any = type("obj", (object,), {"clientX": touch.clientX, "clientY": touch.clientY}) self._update_mouse_position(mock_event) self._update_canvas_position(mock_event) @@ -609,10 +604,7 @@ def _handle_double_tap(self, touch: Any) -> None: """Handle double tap action - capture coordinates.""" try: # Create a mock event object similar to mouse event - mock_event: Any = type('obj', (object,), { - 'clientX': touch.clientX, - 'clientY': touch.clientY - }) + mock_event: Any = type("obj", (object,), {"clientX": touch.clientX, "clientY": touch.clientY}) coordinates: str = self._calculate_click_coordinates(mock_event) self._add_coordinates_to_chat(coordinates) diff --git a/static/client/cartesian_system_2axis.py b/static/client/cartesian_system_2axis.py index 8e9a12f8..90da4ad4 100644 --- a/static/client/cartesian_system_2axis.py +++ b/static/client/cartesian_system_2axis.py @@ -99,7 +99,7 @@ def reset(self) -> None: def get_class_name(self) -> str: """Return the class name 'Cartesian2Axis'.""" - return 'Cartesian2Axis' + return "Cartesian2Axis" @property def origin(self) -> Position: @@ -306,19 +306,15 @@ def set_state(self, state: Dict[str, Any]) -> None: if "visible" in state: self.visible = bool(state["visible"]) if "current_tick_spacing" in state: - self.current_tick_spacing = self._safe_float( - state["current_tick_spacing"], self.default_tick_spacing - ) + self.current_tick_spacing = self._safe_float(state["current_tick_spacing"], self.default_tick_spacing) if "default_tick_spacing" in state: - self.default_tick_spacing = self._safe_float( - state["default_tick_spacing"], 100.0 - ) + self.default_tick_spacing = self._safe_float(state["default_tick_spacing"], 100.0) def _get_axis_origin(self, axis: str) -> float: """Get the origin position for the specified axis""" origin: Position = self.origin - return cast(float, origin.x if axis == 'x' else origin.y) + return cast(float, origin.x if axis == "x" else origin.y) def _get_axis_boundary(self, axis: str) -> float: """Get the boundary (width/height) for the specified axis""" - return self.width if axis == 'x' else self.height + return self.width if axis == "x" else self.height diff --git a/static/client/client_tests/ai_result_formatter.py b/static/client/client_tests/ai_result_formatter.py index 18768fbf..a861885a 100644 --- a/static/client/client_tests/ai_result_formatter.py +++ b/static/client/client_tests/ai_result_formatter.py @@ -49,11 +49,7 @@ def _format_error(self, err: ErrorTuple) -> str: def get_failures_and_errors(self) -> Dict[str, Any]: """Return error details together with summary metadata.""" - output: Any = ( - self.output_stream.get_output() - if hasattr(self.output_stream, "get_output") - else None - ) + output: Any = self.output_stream.get_output() if hasattr(self.output_stream, "get_output") else None return { "failures": self.failures_details, "errors": self.errors_details, diff --git a/static/client/client_tests/renderer_performance_tests.py b/static/client/client_tests/renderer_performance_tests.py index 6d43926f..902ec472 100644 --- a/static/client/client_tests/renderer_performance_tests.py +++ b/static/client/client_tests/renderer_performance_tests.py @@ -269,9 +269,7 @@ def add_phase(name: str, total_key: str, count_key: str) -> None: count_int = summary["count"] avg_ms = summary.get("avg_ms") if avg_ms is None: - console.log( - f"[RendererPerf] {phase_name}: total={total_ms:.2f} ms over {count_int} events" - ) + console.log(f"[RendererPerf] {phase_name}: total={total_ms:.2f} ms over {count_int} events") else: console.log( f"[RendererPerf] {phase_name}: avg={avg_ms:.2f} ms " @@ -333,4 +331,3 @@ def test_renderer_performance(self) -> None: self.assertGreater(plan_apply_count, 0, "Optimized renderer did not apply any plans") self.assertEqual(plan_miss_count, 0, "Plan misses detected in optimized renderer") - diff --git a/static/client/client_tests/simple_mock.py b/static/client/client_tests/simple_mock.py index b5d9527e..3e1e1eeb 100644 --- a/static/client/client_tests/simple_mock.py +++ b/static/client/client_tests/simple_mock.py @@ -82,9 +82,7 @@ def assert_called_once_with(self, *args: Any, **kwargs: Any) -> None: raise AssertionError(f"Expected one call, got {len(self.calls)}") call_args, call_kwargs = self.calls[0] if call_args != args or call_kwargs != kwargs: - raise AssertionError( - f"Expected call with ({args}, {kwargs}), got ({call_args}, {call_kwargs})" - ) + raise AssertionError(f"Expected call with ({args}, {kwargs}), got ({call_args}, {call_kwargs})") def assert_called_once(self) -> None: """Assert the mock was called exactly once (any arguments).""" diff --git a/static/client/client_tests/test_action_trace_collector.py b/static/client/client_tests/test_action_trace_collector.py index 0b76306a..944a878f 100644 --- a/static/client/client_tests/test_action_trace_collector.py +++ b/static/client/client_tests/test_action_trace_collector.py @@ -71,14 +71,16 @@ def setUp(self) -> None: def test_structure(self) -> None: before: Dict[str, Any] = {"Point": {}} after: Dict[str, Any] = {"Point": {"A": {"x": 1, "y": 2}}} - calls: List[Dict[str, Any]] = [{ - "seq": 0, - "function_name": "create_point", - "arguments": {"x": 1, "y": 2}, - "result": "Success", - "is_error": False, - "duration_ms": 1.5, - }] + calls: List[Dict[str, Any]] = [ + { + "seq": 0, + "function_name": "create_point", + "arguments": {"x": 1, "y": 2}, + "result": "Success", + "is_error": False, + "duration_ms": 1.5, + } + ] trace = self.collector.build_trace(before, after, calls, 2.0) self.assertIn("trace_id", trace) @@ -161,14 +163,16 @@ def test_truncated_results(self) -> None: trace: Dict[str, Any] = { "trace_id": "t1", "timestamp": "2024-01-01T00:00:00Z", - "tool_calls": [{ - "seq": 0, - "function_name": "eval", - "arguments": {}, - "result": long_result, - "is_error": False, - "duration_ms": 1.0, - }], + "tool_calls": [ + { + "seq": 0, + "function_name": "eval", + "arguments": {}, + "result": long_result, + "is_error": False, + "duration_ms": 1.0, + } + ], "state_delta": {"added": [], "removed": [], "modified": []}, "total_duration_ms": 1.0, "canvas_state_before": {}, @@ -185,14 +189,16 @@ def test_serializable(self) -> None: trace: Dict[str, Any] = { "trace_id": "t1", "timestamp": "2024-01-01T00:00:00Z", - "tool_calls": [{ - "seq": 0, - "function_name": "create_point", - "arguments": {"x": 1}, - "result": "OK", - "is_error": False, - "duration_ms": 0.5, - }], + "tool_calls": [ + { + "seq": 0, + "function_name": "create_point", + "arguments": {"x": 1}, + "result": "OK", + "is_error": False, + "duration_ms": 0.5, + } + ], "state_delta": {"added": ["A"], "removed": [], "modified": []}, "total_duration_ms": 0.5, "canvas_state_before": {}, @@ -214,10 +220,22 @@ def test_summary_structure(self) -> None: "trace_id": "t1", "timestamp": "2024-01-01T00:00:00Z", "tool_calls": [ - {"seq": 0, "function_name": "f1", "arguments": {}, "result": "ok", - "is_error": False, "duration_ms": 1.0}, - {"seq": 1, "function_name": "f2", "arguments": {}, "result": "Error: bad", - "is_error": True, "duration_ms": 0.5}, + { + "seq": 0, + "function_name": "f1", + "arguments": {}, + "result": "ok", + "is_error": False, + "duration_ms": 1.0, + }, + { + "seq": 1, + "function_name": "f2", + "arguments": {}, + "result": "Error: bad", + "is_error": True, + "duration_ms": 0.5, + }, ], "state_delta": {"added": ["A"], "removed": [], "modified": []}, "total_duration_ms": 1.5, diff --git a/static/client/client_tests/test_angle.py b/static/client/client_tests/test_angle.py index 3ad11bc7..57ca09b9 100644 --- a/static/client/client_tests/test_angle.py +++ b/static/client/client_tests/test_angle.py @@ -16,6 +16,7 @@ from rendering.cached_render_plan import build_plan_for_drawable, _capture_map_state from rendering.style_manager import get_renderer_style + class TestAngle(unittest.TestCase): def setUp(self) -> None: # Create a real CoordinateMapper instance @@ -24,7 +25,11 @@ def setUp(self) -> None: # Setup for Canvas first, as DrawableNameGenerator needs it self.canvas = SimpleMock( # drawable_manager will be set after it's created - create_svg_element = lambda tag_name, attributes, text_content=None: {"tag": tag_name, "attrs": attributes, "text": text_content}, + create_svg_element=lambda tag_name, attributes, text_content=None: { + "tag": tag_name, + "attrs": attributes, + "text": text_content, + }, draw_enabled=True, draw=SimpleMock(return_value=None), # Add minimal coordinate_mapper properties needed by the system @@ -37,7 +42,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) + offset=Position(0, 0), ) # Sync canvas state with coordinate mapper @@ -56,9 +61,9 @@ def add_segment_mock(segment: Any) -> None: self.drawable_manager_segments[segment.name] = segment self.drawable_manager = SimpleMock( - get_segment_by_name = get_segment_by_name_mock, - add_segment = add_segment_mock, - name_generator = self.name_generator + get_segment_by_name=get_segment_by_name_mock, + add_segment=add_segment_mock, + name_generator=self.name_generator, ) # Now that drawable_manager is created, assign it to canvas mock @@ -76,8 +81,12 @@ def add_segment_mock(segment: Any) -> None: self.s_AC = SimpleMock(name="AC", point1=self.A, point2=self.C, canvas=self.canvas) self.s_AD = SimpleMock(name="AD", point1=self.A, point2=self.D, canvas=self.canvas) self.s_AE = SimpleMock(name="AE", point1=self.A, point2=self.E, canvas=self.canvas) - self.s_BD = SimpleMock(name="BD", point1=self.B, point2=self.D, canvas=self.canvas) # Used for no common vertex test - self.s_CD = SimpleMock(name="CD", point1=self.C, point2=self.D, canvas=self.canvas) # Used for no common vertex test + self.s_BD = SimpleMock( + name="BD", point1=self.B, point2=self.D, canvas=self.canvas + ) # Used for no common vertex test + self.s_CD = SimpleMock( + name="CD", point1=self.C, point2=self.D, canvas=self.canvas + ) # Used for no common vertex test # Add segments to the mock drawable_manager self.drawable_manager.add_segment(self.s_AB) @@ -88,14 +97,14 @@ def add_segment_mock(segment: Any) -> None: self.drawable_manager.add_segment(self.s_CD) def test_initialization_valid_90_degrees(self) -> None: - angle = Angle(self.s_AB, self.s_AC) # is_reflex=False by default + angle = Angle(self.s_AB, self.s_AC) # is_reflex=False by default self.assertIsNotNone(angle) self.assertFalse(angle.is_reflex) self.assertIs(angle.vertex_point, self.A) self.assertIs(angle.arm1_point, self.B) self.assertIs(angle.arm2_point, self.C) self.assertAlmostEqual(angle.raw_angle_degrees, 90.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 90.0, places=5) # Small angle + self.assertAlmostEqual(angle.angle_degrees, 90.0, places=5) # Small angle self.assertEqual(angle.name, "angle_BAC") def test_arc_orientation_matches_renderer_flags(self) -> None: @@ -108,29 +117,29 @@ def test_arc_orientation_matches_renderer_flags(self) -> None: params = angle._calculate_arc_parameters(vx, vy, p1x, p1y, p2x, p2y, arc_radius=15) # 90-degree small (counterclockwise math-space) should produce sweep_flag='0' in SVG (y-down), large_arc_flag='0' self.assertIsNotNone(params) - self.assertEqual(params['final_sweep_flag'], '0') - self.assertEqual(params['final_large_arc_flag'], '0') + self.assertEqual(params["final_sweep_flag"], "0") + self.assertEqual(params["final_large_arc_flag"], "0") def test_initialization_valid_45_degrees(self) -> None: - angle = Angle(self.s_AB, self.s_AD) # is_reflex=False by default + angle = Angle(self.s_AB, self.s_AD) # is_reflex=False by default self.assertIsNotNone(angle) self.assertFalse(angle.is_reflex) self.assertIs(angle.vertex_point, self.A) self.assertIs(angle.arm1_point, self.B) self.assertIs(angle.arm2_point, self.D) self.assertAlmostEqual(angle.raw_angle_degrees, 45.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 45.0, places=5) # Small angle + self.assertAlmostEqual(angle.angle_degrees, 45.0, places=5) # Small angle self.assertEqual(angle.name, "angle_BAD") def test_initialization_valid_180_degrees(self) -> None: - angle = Angle(self.s_AB, self.s_AE) # is_reflex=False by default + angle = Angle(self.s_AB, self.s_AE) # is_reflex=False by default self.assertIsNotNone(angle) self.assertFalse(angle.is_reflex) self.assertIs(angle.vertex_point, self.A) self.assertIs(angle.arm1_point, self.B) self.assertIs(angle.arm2_point, self.E) self.assertAlmostEqual(angle.raw_angle_degrees, 180.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 180.0, places=5) # Small angle (180 is its own small/reflex) + self.assertAlmostEqual(angle.angle_degrees, 180.0, places=5) # Small angle (180 is its own small/reflex) self.assertEqual(angle.name, "angle_BAE") def test_initialization_invalid_no_common_vertex(self) -> None: @@ -144,7 +153,7 @@ def test_initialization_invalid_collinear_overlapping_same_segment(self) -> None Angle(self.s_AB, s_AB_copy) def test_initialization_invalid_one_segment_is_point_at_vertex(self) -> None: - p_at_vertex = SimpleMock(name="A_vtx_copy", x=self.A.x, y=self.A.y) # Point identical to A + p_at_vertex = SimpleMock(name="A_vtx_copy", x=self.A.x, y=self.A.y) # Point identical to A s_degenerate = SimpleMock(name="S_Deg", point1=self.A, point2=p_at_vertex, canvas=self.canvas) with self.assertRaisesRegex(ValueError, "segments do not form a valid angle"): Angle(self.s_AB, s_degenerate) @@ -152,39 +161,39 @@ def test_initialization_invalid_one_segment_is_point_at_vertex(self) -> None: def test_angle_calculation_270_degrees_or_minus_90(self) -> None: # Angle from s_AC (A-C) to s_AB (A-B) -> Vertex A, Arms C, B # Raw CCW angle from AC to AB is 270 degrees. - angle = Angle(self.s_AC, self.s_AB) # is_reflex=False by default + angle = Angle(self.s_AC, self.s_AB) # is_reflex=False by default self.assertFalse(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 270.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 90.0, places=5) # Small angle (360 - 270) + self.assertAlmostEqual(angle.angle_degrees, 90.0, places=5) # Small angle (360 - 270) # Name is based on sorted arm points (B, C) around vertex A self.assertEqual(angle.name, "angle_BAC") - def test_initialization_reflex_AB_AC(self) -> None: # raw 90 deg + def test_initialization_reflex_AB_AC(self) -> None: # raw 90 deg angle = Angle(self.s_AB, self.s_AC, is_reflex=True) self.assertTrue(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 90.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 270.0, places=5) # Reflex of 90 + self.assertAlmostEqual(angle.angle_degrees, 270.0, places=5) # Reflex of 90 self.assertEqual(angle.name, "angle_BAC_reflex") - def test_initialization_reflex_AC_AB(self) -> None: # raw 270 deg + def test_initialization_reflex_AC_AB(self) -> None: # raw 270 deg angle = Angle(self.s_AC, self.s_AB, is_reflex=True) self.assertTrue(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 270.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 270.0, places=5) # Reflex of 270 is 270 + self.assertAlmostEqual(angle.angle_degrees, 270.0, places=5) # Reflex of 270 is 270 self.assertEqual(angle.name, "angle_BAC_reflex") - def test_initialization_reflex_AB_AD(self) -> None: # raw 45 deg + def test_initialization_reflex_AB_AD(self) -> None: # raw 45 deg angle = Angle(self.s_AB, self.s_AD, is_reflex=True) self.assertTrue(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 45.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 315.0, places=5) # Reflex of 45 + self.assertAlmostEqual(angle.angle_degrees, 315.0, places=5) # Reflex of 45 self.assertEqual(angle.name, "angle_BAD_reflex") - def test_initialization_reflex_AB_AE(self) -> None: # raw 180 deg + def test_initialization_reflex_AB_AE(self) -> None: # raw 180 deg angle = Angle(self.s_AB, self.s_AE, is_reflex=True) self.assertTrue(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 180.0, places=5) - self.assertAlmostEqual(angle.angle_degrees, 180.0, places=5) # Reflex of 180 is 180 + self.assertAlmostEqual(angle.angle_degrees, 180.0, places=5) # Reflex of 180 is 180 self.assertEqual(angle.name, "angle_BAE_reflex") def test_initialization_zero_angle_non_reflex(self) -> None: @@ -194,7 +203,7 @@ def test_initialization_zero_angle_non_reflex(self) -> None: self.assertFalse(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 0.0, places=5) self.assertAlmostEqual(angle.angle_degrees, 0.0, places=5) - self.assertEqual(angle.name, "angle_BAF") # Name might vary based on B vs F sorting + self.assertEqual(angle.name, "angle_BAF") # Name might vary based on B vs F sorting def test_initialization_zero_angle_reflex(self) -> None: s_AF = SimpleMock(name="AF", point1=self.A, point2=SimpleMock(name="F", x=5, y=0), canvas=self.canvas) @@ -203,13 +212,13 @@ def test_initialization_zero_angle_reflex(self) -> None: self.assertTrue(angle.is_reflex) self.assertAlmostEqual(angle.raw_angle_degrees, 0.0, places=5) self.assertAlmostEqual(angle.angle_degrees, 360.0, places=5) - self.assertEqual(angle.name, "angle_BAF_reflex") # Name might vary based on B vs F sorting + self.assertEqual(angle.name, "angle_BAF_reflex") # Name might vary based on B vs F sorting def test_angle_calculation_zero_length_arm(self) -> None: A_copy = SimpleMock(name="A_copy", x=self.A.x, y=self.A.y) s_zero_arm = SimpleMock(name="S_Zero", point1=self.A, point2=A_copy, canvas=self.canvas) with self.assertRaisesRegex(ValueError, "segments do not form a valid angle"): - Angle(s_zero_arm, self.s_AC) + Angle(s_zero_arm, self.s_AC) def test_initialization_with_none_segment(self) -> None: with self.assertRaisesRegex(ValueError, "segments do not form a valid angle"): @@ -221,7 +230,7 @@ def test_initialization_with_none_segment(self) -> None: def test_get_class_name(self) -> None: angle = Angle(self.s_AB, self.s_AC) - self.assertEqual(angle.get_class_name(), 'Angle') + self.assertEqual(angle.get_class_name(), "Angle") # test_canvas_property removed: Angle is canvas-free; segments manage their own canvas @@ -235,12 +244,7 @@ def test_get_state_and_from_state(self) -> None: expected_state_non_reflex = { "name": "angle_BAD", "type": "angle", - "args": { - "segment1_name": "AB", - "segment2_name": "AD", - "color": "blue", - "is_reflex": False - } + "args": {"segment1_name": "AB", "segment2_name": "AD", "color": "blue", "is_reflex": False}, } self.assertEqual(state_non_reflex, expected_state_non_reflex) @@ -259,7 +263,7 @@ def test_get_state_and_from_state(self) -> None: dependency_manager=dep_mgr, point_manager=SimpleMock(), segment_manager=SimpleMock(), - drawable_manager_proxy=self.drawable_manager + drawable_manager_proxy=self.drawable_manager, ) angle_manager.load_angles([state_non_reflex]) self.assertTrue(len(angle_bucket) > 0) @@ -281,16 +285,11 @@ def test_get_state_and_from_state(self) -> None: expected_state_reflex = { "name": "angle_BAC_reflex", "type": "angle", - "args": { - "segment1_name": "AB", - "segment2_name": "AC", - "color": "blue", - "is_reflex": True - } + "args": {"segment1_name": "AB", "segment2_name": "AC", "color": "blue", "is_reflex": True}, } self.assertEqual(state_reflex, expected_state_reflex) - self.drawable_manager.add_segment(self.s_AC) # Ensure AC is also added + self.drawable_manager.add_segment(self.s_AC) # Ensure AC is also added angle_bucket2 = [] drawables_container2 = SimpleMock(add=lambda a: angle_bucket2.append(a)) @@ -301,7 +300,7 @@ def test_get_state_and_from_state(self) -> None: dependency_manager=dep_mgr, point_manager=SimpleMock(), segment_manager=SimpleMock(), - drawable_manager_proxy=self.drawable_manager + drawable_manager_proxy=self.drawable_manager, ) angle_manager2.load_angles([state_reflex]) self.assertTrue(len(angle_bucket2) > 0) @@ -318,7 +317,7 @@ def test_from_state_segment_not_found(self) -> None: state = { "name": "ghost_angle", "type": "angle", - "args": {"segment1_name": "NonExistentS1", "segment2_name": "AD", "color": "blue"} + "args": {"segment1_name": "NonExistentS1", "segment2_name": "AD", "color": "blue"}, } self.drawable_manager.add_segment(self.s_AD) @@ -331,7 +330,7 @@ def test_from_state_segment_not_found(self) -> None: dependency_manager=SimpleMock(register_dependency=SimpleMock(return_value=None)), point_manager=SimpleMock(), segment_manager=SimpleMock(), - drawable_manager_proxy=self.drawable_manager + drawable_manager_proxy=self.drawable_manager, ) angle_manager3.load_angles([state]) # No angle should have been added @@ -344,10 +343,10 @@ def test_update_points_based_on_segments(self) -> None: original_C_y = self.C.y original_C_x = self.C.x - self.C.x = 10 # C was (0,10), now (10,0) which is B's location + self.C.x = 10 # C was (0,10), now (10,0) which is B's location self.C.y = 0 - if hasattr(angle, 'update_points_based_on_segments'): + if hasattr(angle, "update_points_based_on_segments"): result = angle.update_points_based_on_segments() self.assertFalse(result) self.assertIsNone(angle.angle_degrees) @@ -363,10 +362,10 @@ def test_update_makes_angle_invalid(self) -> None: original_B_x = self.B.x original_B_y = self.B.y - self.B.x = self.A.x # Move B to A's location + self.B.x = self.A.x # Move B to A's location self.B.y = self.A.y - if hasattr(angle, 'update_points_based_on_segments'): + if hasattr(angle, "update_points_based_on_segments"): result = angle.update_points_based_on_segments() self.assertFalse(result) self.assertIsNone(angle.angle_degrees) @@ -446,8 +445,12 @@ def test_angle_deletion_on_point_deletion(self) -> None: # Setup mock managers for dependency tracking dependency_manager_mock = SimpleMock() dependency_manager_mock.dependencies = {} # Simple storage for dependencies - dependency_manager_mock.register_dependency = lambda child, parent: self._add_dependency(dependency_manager_mock.dependencies, child, parent) - dependency_manager_mock.get_children = lambda parent: self._get_dependencies(dependency_manager_mock.dependencies, parent, 'children') + dependency_manager_mock.register_dependency = lambda child, parent: self._add_dependency( + dependency_manager_mock.dependencies, child, parent + ) + dependency_manager_mock.get_children = lambda parent: self._get_dependencies( + dependency_manager_mock.dependencies, parent, "children" + ) dependency_manager_mock.remove_drawable = lambda drawable: None angle_manager_mock = SimpleMock() @@ -461,9 +464,9 @@ def test_angle_deletion_on_point_deletion(self) -> None: # Register dependencies (simulating what AngleManager.create_angle would do) dependency_manager_mock.register_dependency(angle, self.s_AB) # angle depends on segment AB dependency_manager_mock.register_dependency(angle, self.s_AC) # angle depends on segment AC - dependency_manager_mock.register_dependency(angle, self.A) # angle depends on vertex point A - dependency_manager_mock.register_dependency(angle, self.B) # angle depends on arm point B - dependency_manager_mock.register_dependency(angle, self.C) # angle depends on arm point C + dependency_manager_mock.register_dependency(angle, self.A) # angle depends on vertex point A + dependency_manager_mock.register_dependency(angle, self.B) # angle depends on arm point B + dependency_manager_mock.register_dependency(angle, self.C) # angle depends on arm point C # Verify angle exists self.assertEqual(len(angles_list), 1) @@ -478,7 +481,7 @@ def test_angle_deletion_on_point_deletion(self) -> None: # Simulate the deletion process that should happen when point A is deleted for child in point_a_children: - if hasattr(child, 'get_class_name') and child.get_class_name() == 'Angle': + if hasattr(child, "get_class_name") and child.get_class_name() == "Angle": angles_list.remove(child) # Verify angle was deleted @@ -490,8 +493,12 @@ def test_angle_deletion_on_segment_deletion(self) -> None: # Setup mock managers for dependency tracking dependency_manager_mock = SimpleMock() dependency_manager_mock.dependencies = {} # Simple storage for dependencies - dependency_manager_mock.register_dependency = lambda child, parent: self._add_dependency(dependency_manager_mock.dependencies, child, parent) - dependency_manager_mock.get_children = lambda parent: self._get_dependencies(dependency_manager_mock.dependencies, parent, 'children') + dependency_manager_mock.register_dependency = lambda child, parent: self._add_dependency( + dependency_manager_mock.dependencies, child, parent + ) + dependency_manager_mock.get_children = lambda parent: self._get_dependencies( + dependency_manager_mock.dependencies, parent, "children" + ) dependency_manager_mock.remove_drawable = lambda drawable: None angle_manager_mock = SimpleMock() @@ -505,9 +512,9 @@ def test_angle_deletion_on_segment_deletion(self) -> None: # Register dependencies (simulating what AngleManager.create_angle would do) dependency_manager_mock.register_dependency(angle, self.s_AB) # angle depends on segment AB dependency_manager_mock.register_dependency(angle, self.s_AC) # angle depends on segment AC - dependency_manager_mock.register_dependency(angle, self.A) # angle depends on vertex point A - dependency_manager_mock.register_dependency(angle, self.B) # angle depends on arm point B - dependency_manager_mock.register_dependency(angle, self.C) # angle depends on arm point C + dependency_manager_mock.register_dependency(angle, self.A) # angle depends on vertex point A + dependency_manager_mock.register_dependency(angle, self.B) # angle depends on arm point B + dependency_manager_mock.register_dependency(angle, self.C) # angle depends on arm point C # Verify angle exists self.assertEqual(len(angles_list), 1) @@ -522,7 +529,7 @@ def test_angle_deletion_on_segment_deletion(self) -> None: # Simulate the deletion process that should happen when segment AB is deleted for child in segment_ab_children: - if hasattr(child, 'get_class_name') and child.get_class_name() == 'Angle': + if hasattr(child, "get_class_name") and child.get_class_name() == "Angle": angles_list.remove(child) # Verify angle was deleted @@ -630,7 +637,6 @@ def _add_dependency(self, dependencies_dict: Dict[Any, list[Any]], child: Any, p def _get_dependencies(self, dependencies_dict: Dict[Any, list[Any]], parent: Any, dep_type: str) -> list[Any]: """Helper method to get dependencies.""" - if dep_type == 'children': + if dep_type == "children": return dependencies_dict.get(parent, []) return [] - diff --git a/static/client/client_tests/test_angle_manager.py b/static/client/client_tests/test_angle_manager.py index 9274e841..2240601c 100644 --- a/static/client/client_tests/test_angle_manager.py +++ b/static/client/client_tests/test_angle_manager.py @@ -6,6 +6,7 @@ from drawables_aggregator import Position from typing import Any + class TestAngleManager(unittest.TestCase): def setUp(self) -> None: # Create a real CoordinateMapper instance @@ -34,17 +35,17 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) + offset=Position(0, 0), ) # Sync canvas state with coordinate mapper self.coordinate_mapper.sync_from_canvas(self.canvas_mock) self.drawables_container_mock: Any = SimpleMock( name="DrawablesContainerMock", - Angles=[], # Holds created Angle instances - add=MagicMock(side_effect=lambda x: self.drawables_container_mock.Angles.append(x)) + Angles=[], # Holds created Angle instances + add=MagicMock(side_effect=lambda x: self.drawables_container_mock.Angles.append(x)), ) - self.name_generator_mock = SimpleMock(name="NameGeneratorMock") # Basic mock for now + self.name_generator_mock = SimpleMock(name="NameGeneratorMock") # Basic mock for now self.dependency_manager_mock = SimpleMock( name="DependencyManagerMock", @@ -53,12 +54,12 @@ def setUp(self) -> None: ) # Mock Point Class Behavior (used by point_manager) - self.MockPoint = SimpleMock # Use SimpleMock directly + self.MockPoint = SimpleMock # Use SimpleMock directly # We will define the behavior in the create_point mock below self.point_manager_mock = SimpleMock( name="PointManagerMock", - create_point=lambda x, y, name=None, extra_graphics=True, label=None, color=None, size=None, display_type=None, is_visible=True, is_fixed=False: + create_point=lambda x, y, name=None, extra_graphics=True, label=None, color=None, size=None, display_type=None, is_visible=True, is_fixed=False: ( SimpleMock( name=name or f"P({x},{y})", x=x, @@ -68,33 +69,35 @@ def setUp(self) -> None: size=size, display_type=display_type, is_visible=is_visible, - is_fixed=is_fixed + is_fixed=is_fixed, ) + ), ) # Mock Segment Class Behavior (used by segment_manager) - self.MockSegment = SimpleMock # Use SimpleMock directly + self.MockSegment = SimpleMock # Use SimpleMock directly # We will define the behavior in the create_segment mock below self.segment_manager_mock = SimpleMock( name="SegmentManagerMock", - create_segment=lambda x1, y1, x2, y2, name=None, extra_graphics=True, label=None, color=None, thickness=None, is_visible=True, has_direction=False: + create_segment=lambda x1, y1, x2, y2, name=None, extra_graphics=True, label=None, color=None, thickness=None, is_visible=True, has_direction=False: ( SimpleMock( name=name or f"S_({x1},{y1})-({x2},{y2})", - point1=self.point_manager_mock.create_point(x1,y1), - point2=self.point_manager_mock.create_point(x2,y2), + point1=self.point_manager_mock.create_point(x1, y1), + point2=self.point_manager_mock.create_point(x2, y2), # Ensure all attributes that might be accessed are present label=label, color=color, thickness=thickness, is_visible=is_visible, - has_direction=has_direction + has_direction=has_direction, ) + ), ) self.drawable_manager_proxy_mock = SimpleMock( name="DrawableManagerProxyMock", - create_drawables_from_new_connections=MagicMock() + create_drawables_from_new_connections=MagicMock(), # Add get_segment_by_name if AngleManager needs it directly from proxy ) @@ -108,21 +111,20 @@ def setUp(self) -> None: dependency_manager=self.dependency_manager_mock, point_manager=self.point_manager_mock, segment_manager=self.segment_manager_mock, - drawable_manager_proxy=self.drawable_manager_proxy_mock + drawable_manager_proxy=self.drawable_manager_proxy_mock, ) # Helper points for tests self.A = self.point_manager_mock.create_point(0, 0, name="A") self.B = self.point_manager_mock.create_point(10, 0, name="B") self.C = self.point_manager_mock.create_point(0, 10, name="C") - self.D = self.point_manager_mock.create_point(-10, 0, name="D") # For a different angle + self.D = self.point_manager_mock.create_point(-10, 0, name="D") # For a different angle # Helper segments for tests self.seg_AB = self.segment_manager_mock.create_segment(self.A.x, self.A.y, self.B.x, self.B.y, name="AB") self.seg_AC = self.segment_manager_mock.create_segment(self.A.x, self.A.y, self.C.x, self.C.y, name="AC") self.seg_AD = self.segment_manager_mock.create_segment(self.A.x, self.A.y, self.D.x, self.D.y, name="AD") - def test_initialization(self) -> None: self.assertIsNotNone(self.angle_manager.canvas) self.assertIsNotNone(self.angle_manager.drawables) @@ -156,16 +158,16 @@ def test_get_angle_by_segments_found(self) -> None: def test_get_angle_by_segments_not_found(self) -> None: # Segments that don't form a known angle in the container # Create a new point E for the other segment to avoid confusion - E = self.point_manager_mock.create_point(20,20, name="E") + E = self.point_manager_mock.create_point(20, 20, name="E") other_segment = self.segment_manager_mock.create_segment(self.B.x, self.B.y, E.x, E.y, name="BE") found_angle = self.angle_manager.get_angle_by_segments(self.seg_AB, other_segment) self.assertIsNone(found_angle) def test_get_angle_by_points_found(self) -> None: # Setup points that an Angle would have derived - A_obj = SimpleMock(name="A_obj", x=0,y=0) - B_obj = SimpleMock(name="B_obj", x=10,y=0) - C_obj = SimpleMock(name="C_obj", x=0,y=10) + A_obj = SimpleMock(name="A_obj", x=0, y=0) + B_obj = SimpleMock(name="B_obj", x=10, y=0) + C_obj = SimpleMock(name="C_obj", x=0, y=10) mock_angle = SimpleMock( name="AngleByPoints", @@ -175,7 +177,7 @@ def test_get_angle_by_points_found(self) -> None: # segment1 and segment2 would be composed of these points segment1=SimpleMock(point1=A_obj, point2=B_obj), segment2=SimpleMock(point1=A_obj, point2=C_obj), - is_reflex=False + is_reflex=False, ) self.drawables_container_mock.Angles.append(mock_angle) @@ -188,10 +190,10 @@ def test_get_angle_by_points_found(self) -> None: self.assertIs(found_angle_reverse, mock_angle) def test_get_angle_by_points_not_found_wrong_vertex(self) -> None: - A_obj = SimpleMock(name="A_obj", x=0,y=0) - B_obj = SimpleMock(name="B_obj", x=10,y=0) - C_obj = SimpleMock(name="C_obj", x=0,y=10) - wrong_V_obj = SimpleMock(name="WrongV_obj", x=1,y=1) # Changed from wrong_v_obj + A_obj = SimpleMock(name="A_obj", x=0, y=0) + B_obj = SimpleMock(name="B_obj", x=10, y=0) + C_obj = SimpleMock(name="C_obj", x=0, y=10) + wrong_V_obj = SimpleMock(name="WrongV_obj", x=1, y=1) # Changed from wrong_v_obj mock_angle = SimpleMock(name="AngleByPoints", vertex_point=A_obj, arm1_point=B_obj, arm2_point=C_obj) self.drawables_container_mock.Angles.append(mock_angle) @@ -200,10 +202,10 @@ def test_get_angle_by_points_not_found_wrong_vertex(self) -> None: self.assertIsNone(found_angle) def test_get_angle_by_points_not_found_wrong_arm(self) -> None: - A_obj = SimpleMock(name="A_obj", x=0,y=0) - B_obj = SimpleMock(name="B_obj", x=10,y=0) - C_obj = SimpleMock(name="C_obj", x=0,y=10) - wrong_D_obj = SimpleMock(name="WrongD_obj", x=-1,y=-1) # Changed from wrong_a3_obj + A_obj = SimpleMock(name="A_obj", x=0, y=0) + B_obj = SimpleMock(name="B_obj", x=10, y=0) + C_obj = SimpleMock(name="C_obj", x=0, y=10) + wrong_D_obj = SimpleMock(name="WrongD_obj", x=-1, y=-1) # Changed from wrong_a3_obj mock_angle = SimpleMock(name="AngleByPoints", vertex_point=A_obj, arm1_point=B_obj, arm2_point=C_obj) self.drawables_container_mock.Angles.append(mock_angle) @@ -212,8 +214,8 @@ def test_get_angle_by_points_not_found_wrong_arm(self) -> None: self.assertIsNone(found_angle) def test_get_angle_by_points_input_points_none(self) -> None: - A_obj = SimpleMock(name="A_obj", x=0,y=0) # Changed from v_obj - B_obj = SimpleMock(name="B_obj", x=10,y=0) # Changed from a1_obj + A_obj = SimpleMock(name="A_obj", x=0, y=0) # Changed from v_obj + B_obj = SimpleMock(name="B_obj", x=10, y=0) # Changed from a1_obj self.assertIsNone(self.angle_manager.get_angle_by_points(None, B_obj, B_obj)) self.assertIsNone(self.angle_manager.get_angle_by_points(A_obj, None, B_obj)) self.assertIsNone(self.angle_manager.get_angle_by_points(A_obj, B_obj, None)) diff --git a/static/client/client_tests/test_arc_manager.py b/static/client/client_tests/test_arc_manager.py index ed4e577e..a4d416b8 100644 --- a/static/client/client_tests/test_arc_manager.py +++ b/static/client/client_tests/test_arc_manager.py @@ -27,7 +27,9 @@ def setUp(self) -> None: ) self.points_by_name: Dict[str, Point] = {} - def create_point(x: float, y: float, name: Optional[str] = None, extra_graphics: bool = True, **kwargs: Any) -> Point: + def create_point( + x: float, y: float, name: Optional[str] = None, extra_graphics: bool = True, **kwargs: Any + ) -> Point: assigned_name = name or f"P_{len(self.points_by_name)}" point = Point(x, y, assigned_name) self.points_by_name[assigned_name] = point @@ -60,7 +62,9 @@ def get_circle_by_name(name: str) -> Optional[Circle]: self.name_generator = SimpleMock( extract_point_names_from_arc_name=lambda arc_name: (None, None), - generate_arc_name=lambda proposed, p1, p2, major, existing: proposed if proposed else f"{'ArcMaj' if major else 'ArcMin'}_{p1}{p2}", + generate_arc_name=lambda proposed, p1, p2, major, existing: ( + proposed if proposed else f"{'ArcMaj' if major else 'ArcMin'}_{p1}{p2}" + ), ) self.arc_manager = ArcManager( @@ -114,12 +118,8 @@ def test_create_circle_arc_snaps_new_points_to_circle(self) -> None: ) self.assertIsNotNone(arc) - self.assertTrue( - math.isclose(math.hypot(arc.point1.x, arc.point1.y), 5.0, rel_tol=1e-9, abs_tol=1e-9) - ) - self.assertTrue( - math.isclose(math.hypot(arc.point2.x, arc.point2.y), 5.0, rel_tol=1e-9, abs_tol=1e-9) - ) + self.assertTrue(math.isclose(math.hypot(arc.point1.x, arc.point1.y), 5.0, rel_tol=1e-9, abs_tol=1e-9)) + self.assertTrue(math.isclose(math.hypot(arc.point2.x, arc.point2.y), 5.0, rel_tol=1e-9, abs_tol=1e-9)) def test_create_circle_arc_projects_existing_points(self) -> None: existing_point_a = self.point_manager.create_point(1, 0, name="A") @@ -253,7 +253,6 @@ def test_update_circle_arc_requires_editable_property(self) -> None: with self.assertRaises(ValueError): self.arc_manager.update_circle_arc("arc_AB") - def test_handle_circle_removed_deletes_arcs(self) -> None: arc = self.arc_manager.create_circle_arc( point1_x=5, @@ -290,4 +289,3 @@ def test_load_circle_arcs_restores_from_state(self) -> None: self.arc_manager.load_circle_arcs(state) self.assertEqual(len(self.drawables.CircleArcs), 1) self.assertEqual(self.drawables.CircleArcs[0].name, "arc_AB") - diff --git a/static/client/client_tests/test_area_expression_evaluator.py b/static/client/client_tests/test_area_expression_evaluator.py index 07559626..4ba38b01 100644 --- a/static/client/client_tests/test_area_expression_evaluator.py +++ b/static/client/client_tests/test_area_expression_evaluator.py @@ -23,6 +23,7 @@ # Mock Classes # ============================================================================= + class MockPoint: """Mock Point for testing.""" @@ -151,6 +152,7 @@ def get_segments(self) -> List[MockSegment]: def _make_polygon_mock(class_name: str): """Factory to create mock polygon classes with get_segments.""" + class MockPolygon: def __init__(self, vertices: List[tuple], name: str = "") -> None: self._vertices = vertices @@ -204,6 +206,7 @@ def __init__(self) -> None: # Area Calculation Tests # ============================================================================= + class TestAreaCalculation(unittest.TestCase): """Test area calculation with various shapes and operations.""" @@ -1455,6 +1458,7 @@ def _regular_polygon_vertices(self, n: int, radius: float = 5, cx: float = 0, cy # Region Generation Tests # ============================================================================= + class TestRegionGeneration(unittest.TestCase): """Test that all drawable types can be converted to regions.""" diff --git a/static/client/client_tests/test_bar_manager.py b/static/client/client_tests/test_bar_manager.py index d51d1043..986215e5 100644 --- a/static/client/client_tests/test_bar_manager.py +++ b/static/client/client_tests/test_bar_manager.py @@ -272,5 +272,3 @@ def fake_archive() -> None: calls.clear() self.assertTrue(self.bar_manager.delete_bar("ToDelete2", archive=False, redraw=False)) self.assertEqual(calls, []) - - diff --git a/static/client/client_tests/test_bar_renderer.py b/static/client/client_tests/test_bar_renderer.py index 9550788e..c20f7855 100644 --- a/static/client/client_tests/test_bar_renderer.py +++ b/static/client/client_tests/test_bar_renderer.py @@ -15,9 +15,7 @@ def fill_polygon(self, points, fill, stroke=None, **kwargs): self.calls.append(("fill_polygon", (points, fill, stroke), dict(kwargs))) def draw_text(self, text, position, font, color, alignment, style_overrides=None, **kwargs): - self.calls.append( - ("draw_text", (text, position, font, color, alignment, style_overrides), dict(kwargs)) - ) + self.calls.append(("draw_text", (text, position, font, color, alignment, style_overrides), dict(kwargs))) # The remaining primitives are not needed for these tests. def stroke_line(self, *_args, **_kwargs): @@ -161,5 +159,3 @@ def test_guard_clauses_skip_invalid_geometry(self) -> None: bar2 = Bar(name="B2", x_left=0.0, x_right=1.0, y_bottom=2.0, y_top=2.0) shared.render_bar_helper(self.primitives, bar2, self.mapper, self.style) self.assertEqual(self.primitives.calls, []) - - diff --git a/static/client/client_tests/test_canvas.py b/static/client/client_tests/test_canvas.py index 341a929e..8f4bebf1 100644 --- a/static/client/client_tests/test_canvas.py +++ b/static/client/client_tests/test_canvas.py @@ -20,18 +20,39 @@ class TestCanvas(unittest.TestCase): def setUp(self) -> None: self.canvas = Canvas(500, 500, draw_enabled=False) - self.mock_cartesian2axis = SimpleMock(draw=SimpleMock(return_value=None), reset=SimpleMock(return_value=None), - get_state=SimpleMock(return_value={'Cartesian_System_Visibility': 'cartesian_state'}), - origin=Position(0, 0)) + self.mock_cartesian2axis = SimpleMock( + draw=SimpleMock(return_value=None), + reset=SimpleMock(return_value=None), + get_state=SimpleMock(return_value={"Cartesian_System_Visibility": "cartesian_state"}), + origin=Position(0, 0), + ) self.canvas.cartesian2axis = self.mock_cartesian2axis - self.mock_point1 = SimpleMock(canvas=None, get_class_name=SimpleMock(return_value='Point'), reset=SimpleMock(return_value=None), - get_state=SimpleMock(return_value='point1_state'), - x=10, y=10, name='A') - self.mock_point2 = SimpleMock(canvas=None, get_class_name=SimpleMock(return_value='Point'), reset=SimpleMock(return_value=None), - get_state=SimpleMock(return_value='point2_state'), - x=20, y=20, name='B') - self.mock_segment1 = SimpleMock(canvas=None, get_class_name=SimpleMock(return_value='Segment'), reset=SimpleMock(return_value=None), - get_state=SimpleMock(return_value='segment1_state'), point1=self.mock_point1, point2=self.mock_point2) + self.mock_point1 = SimpleMock( + canvas=None, + get_class_name=SimpleMock(return_value="Point"), + reset=SimpleMock(return_value=None), + get_state=SimpleMock(return_value="point1_state"), + x=10, + y=10, + name="A", + ) + self.mock_point2 = SimpleMock( + canvas=None, + get_class_name=SimpleMock(return_value="Point"), + reset=SimpleMock(return_value=None), + get_state=SimpleMock(return_value="point2_state"), + x=20, + y=20, + name="B", + ) + self.mock_segment1 = SimpleMock( + canvas=None, + get_class_name=SimpleMock(return_value="Segment"), + reset=SimpleMock(return_value=None), + get_state=SimpleMock(return_value="segment1_state"), + point1=self.mock_point1, + point2=self.mock_point2, + ) # Add backward-compatible drawables property for tests # This allows existing tests to work with the new DrawablesContainer @@ -45,7 +66,7 @@ def set_drawables(canvas_self: Canvas, value: Dict[str, List[Any]]) -> None: if hasattr(canvas_self.drawable_manager.drawables, "rebuild_renderables"): canvas_self.drawable_manager.drawables.rebuild_renderables() - setattr(Canvas, 'drawables', property(get_drawables, set_drawables)) + setattr(Canvas, "drawables", property(get_drawables, set_drawables)) def _create_rectangle( self, @@ -117,10 +138,10 @@ def test_init(self) -> None: def test_add_drawable(self) -> None: self.canvas.add_drawable(self.mock_point1) self.canvas.add_drawable(self.mock_segment1) - self.assertIn('Point', self.canvas.drawable_manager.drawables._drawables) - self.assertIn('Segment', self.canvas.drawable_manager.drawables._drawables) - self.assertIn(self.mock_point1, self.canvas.drawable_manager.drawables._drawables['Point']) - self.assertIn(self.mock_segment1, self.canvas.drawable_manager.drawables._drawables['Segment']) + self.assertIn("Point", self.canvas.drawable_manager.drawables._drawables) + self.assertIn("Segment", self.canvas.drawable_manager.drawables._drawables) + self.assertIn(self.mock_point1, self.canvas.drawable_manager.drawables._drawables["Point"]) + self.assertIn(self.mock_segment1, self.canvas.drawable_manager.drawables._drawables["Segment"]) def test_get_drawables(self) -> None: self.canvas.add_drawable(self.mock_point1) @@ -138,17 +159,25 @@ def test_get_drawables_state(self) -> None: self.canvas.add_drawable(self.mock_point1) self.canvas.add_drawable(self.mock_segment1) state = self.canvas.get_drawables_state() - self.assertEqual(state, {'Points': ['point1_state'], 'Segments': ['segment1_state']}) + self.assertEqual(state, {"Points": ["point1_state"], "Segments": ["segment1_state"]}) def test_get_cartesian2axis_state(self) -> None: state = self.canvas.get_cartesian2axis_state() - self.assertEqual(state, {'Cartesian_System_Visibility': 'cartesian_state'}) + self.assertEqual(state, {"Cartesian_System_Visibility": "cartesian_state"}) def test_get_canvas_state(self) -> None: self.canvas.add_drawable(self.mock_point1) self.canvas.add_drawable(self.mock_segment1) state = self.canvas.get_canvas_state() - self.assertEqual(state, {'Points': ['point1_state'], 'Segments': ['segment1_state'], 'Cartesian_System_Visibility': 'cartesian_state', 'coordinate_system': {'mode': 'cartesian'}}) + self.assertEqual( + state, + { + "Points": ["point1_state"], + "Segments": ["segment1_state"], + "Cartesian_System_Visibility": "cartesian_state", + "coordinate_system": {"mode": "cartesian"}, + }, + ) def test_get_canvas_state_filtered_no_filters_matches_full_state(self) -> None: self.canvas.add_drawable(self.mock_point1) @@ -228,10 +257,7 @@ def test_undo_redo(self) -> None: # Store initial state properties initial_point_pos = (point.x, point.y) - initial_segment_points = ( - (segment.point1.x, segment.point1.y), - (segment.point2.x, segment.point2.y) - ) + initial_segment_points = ((segment.point1.x, segment.point1.y), (segment.point2.x, segment.point2.y)) # Make some changes self.canvas.delete_point_by_name("A") @@ -250,18 +276,9 @@ def test_undo_redo(self) -> None: self.assertIsNotNone(restored_point) self.assertIsNotNone(restored_segment) - self.assertEqual( - (restored_point.x, restored_point.y), - initial_point_pos - ) - self.assertEqual( - (restored_segment.point1.x, restored_segment.point1.y), - initial_segment_points[0] - ) - self.assertEqual( - (restored_segment.point2.x, restored_segment.point2.y), - initial_segment_points[1] - ) + self.assertEqual((restored_point.x, restored_point.y), initial_point_pos) + self.assertEqual((restored_segment.point1.x, restored_segment.point1.y), initial_segment_points[0]) + self.assertEqual((restored_segment.point2.x, restored_segment.point2.y), initial_segment_points[1]) # Redo changes self.canvas.redo() @@ -274,8 +291,8 @@ def test_get_canvas_state_and_reconstruct_preserves_integrity(self) -> None: pointA = self.canvas.create_point(10, 20, "A") pointB = self.canvas.create_point(30, 40, "B") pointC = self.canvas.create_point(50, 50, "C") - pointD = self.canvas.create_point(5, 15, "D") # New point for vector - pointE = self.canvas.create_point(25, 35, "E") # New point for vector + pointD = self.canvas.create_point(5, 15, "D") # New point for vector + pointE = self.canvas.create_point(25, 35, "E") # New point for vector segmentAB = self.canvas.create_segment(pointA.x, pointA.y, pointB.x, pointB.y) @@ -292,64 +309,61 @@ def test_get_canvas_state_and_reconstruct_preserves_integrity(self) -> None: # Reconstruct Points, Segments, Circles (with else: self.fail) if "Points" in original_state: for point_data in original_state["Points"]: - pos = point_data['args']['position'] - self.canvas.create_point(pos['x'], pos['y'], point_data['name']) + pos = point_data["args"]["position"] + self.canvas.create_point(pos["x"], pos["y"], point_data["name"]) if "Segments" in original_state: for seg_data in original_state["Segments"]: - p1_name = seg_data['args']['p1'] - p2_name = seg_data['args']['p2'] + p1_name = seg_data["args"]["p1"] + p2_name = seg_data["args"]["p2"] p1 = self.canvas.get_point_by_name(p1_name) p2 = self.canvas.get_point_by_name(p2_name) if p1 and p2: - self.canvas.create_segment(p1.x, p1.y, - p2.x, p2.y, - seg_data['name']) + self.canvas.create_segment(p1.x, p1.y, p2.x, p2.y, seg_data["name"]) else: self.fail(f"Could not find points {p1_name} or {p2_name} for segment {seg_data['name']}") if "Circles" in original_state: for circ_data in original_state["Circles"]: - center_name = circ_data['args']['center'] - radius = circ_data['args']['radius'] + center_name = circ_data["args"]["center"] + radius = circ_data["args"]["radius"] center_point = self.canvas.get_point_by_name(center_name) if center_point: - self.canvas.create_circle(center_point.x, center_point.y, - radius, circ_data['name']) + self.canvas.create_circle(center_point.x, center_point.y, radius, circ_data["name"]) else: self.fail(f"Could not find center point {center_name} for circle {circ_data['name']}") # Reconstruct Vectors (with else: self.fail) if "Vectors" in original_state: for vec_data in original_state["Vectors"]: - origin_name = vec_data['args']['origin'] - tip_name = vec_data['args']['tip'] + origin_name = vec_data["args"]["origin"] + tip_name = vec_data["args"]["tip"] origin_point = self.canvas.get_point_by_name(origin_name) tip_point = self.canvas.get_point_by_name(tip_name) if origin_point and tip_point: - self.canvas.create_vector(origin_point.x, origin_point.y, - tip_point.x, tip_point.y, - vec_data['name']) + self.canvas.create_vector( + origin_point.x, origin_point.y, tip_point.x, tip_point.y, vec_data["name"] + ) else: self.fail(f"Could not find points {origin_name} or {tip_name} for vector {vec_data['name']}") # Reconstruct Functions and Computations (with else: self.fail for computations) if "Functions" in original_state: for func_data in original_state["Functions"]: - args = func_data['args'] + args = func_data["args"] self.canvas.draw_function( args["function_string"], name=func_data.get("name", ""), left_bound=args.get("left_bound"), - right_bound=args.get("right_bound") + right_bound=args.get("right_bound"), ) if "computations" in original_state: - for comp_data in original_state["computations"]: + for comp_data in original_state["computations"]: expression = comp_data.get("expression") result = comp_data.get("result") if expression is not None and result is not None: - self.canvas.add_computation(expression, result) + self.canvas.add_computation(expression, result) else: self.fail(f"Computation data missing expression or result: {comp_data}") @@ -359,7 +373,9 @@ def test_get_canvas_state_and_reconstruct_preserves_integrity(self) -> None: self.assertEqual(len(original_state.get("Points", [])), len(reconstructed_state.get("Points", []))) self.assertEqual(len(original_state.get("Segments", [])), len(reconstructed_state.get("Segments", []))) self.assertEqual(len(original_state.get("Circles", [])), len(reconstructed_state.get("Circles", []))) - self.assertEqual(len(original_state.get("Vectors", [])), len(reconstructed_state.get("Vectors", []))) # New assert for vectors + self.assertEqual( + len(original_state.get("Vectors", [])), len(reconstructed_state.get("Vectors", [])) + ) # New assert for vectors self.assertEqual(len(original_state.get("Functions", [])), len(reconstructed_state.get("Functions", []))) self.assertEqual(len(original_state.get("computations", [])), len(reconstructed_state.get("computations", []))) @@ -368,7 +384,7 @@ def test_get_canvas_state_and_reconstruct_preserves_integrity(self) -> None: self.assertIsNotNone(rec_pointA, "Point A should be reconstructed") self.assertEqual(rec_pointA.x, 10) - rec_pointB = self.canvas.get_point_by_name("B") # Get B for segment + rec_pointB = self.canvas.get_point_by_name("B") # Get B for segment self.assertIsNotNone(rec_pointB, "Point B should be reconstructed") rec_segmentAB = self.canvas.get_segment_by_points(rec_pointA, rec_pointB) self.assertIsNotNone(rec_segmentAB, "Segment AB should be reconstructed") @@ -376,17 +392,17 @@ def test_get_canvas_state_and_reconstruct_preserves_integrity(self) -> None: self.assertEqual(rec_segmentAB.point1.name, "A") self.assertEqual(rec_segmentAB.point2.name, "B") - rec_pointC = self.canvas.get_point_by_name("C") # Get C for circle + rec_pointC = self.canvas.get_point_by_name("C") # Get C for circle self.assertIsNotNone(rec_pointC, "Point C should be reconstructed") # Assuming the radius was 25 during creation rec_circleC = self.canvas.get_circle(rec_pointC.x, rec_pointC.y, 25) self.assertIsNotNone(rec_circleC, "Circle C should be reconstructed") - self.assertEqual(rec_circleC.center.name, "C") # Changed from center_point to center + self.assertEqual(rec_circleC.center.name, "C") # Changed from center_point to center self.assertEqual(rec_circleC.radius, 25) - rec_pointD = self.canvas.get_point_by_name("D") # Get D for vector + rec_pointD = self.canvas.get_point_by_name("D") # Get D for vector self.assertIsNotNone(rec_pointD, "Point D should be reconstructed") - rec_pointE = self.canvas.get_point_by_name("E") # Get E for vector + rec_pointE = self.canvas.get_point_by_name("E") # Get E for vector self.assertIsNotNone(rec_pointE, "Point E should be reconstructed") rec_vectorDE = self.canvas.get_vector(rec_pointD.x, rec_pointD.y, rec_pointE.x, rec_pointE.y) self.assertIsNotNone(rec_vectorDE, "Vector DE should be reconstructed") @@ -400,13 +416,17 @@ def test_get_canvas_state_and_reconstruct_preserves_integrity(self) -> None: # Cleaned up: Removed duplicated detailed state comparison for brevity # Detailed comparison of sorted lists for each type for key in ["Points", "Segments", "Circles", "Vectors", "Functions"]: - original_items = sorted(original_state.get(key, []), key=lambda x: str(x['name'])) - reconstructed_items = sorted(reconstructed_state.get(key, []), key=lambda x: str(x['name'])) + original_items = sorted(original_state.get(key, []), key=lambda x: str(x["name"])) + reconstructed_items = sorted(reconstructed_state.get(key, []), key=lambda x: str(x["name"])) self.assertListEqual(original_items, reconstructed_items, f"{key} state mismatch after sorting") - original_computations = sorted(original_state.get("computations", []), key=lambda x: str(x['expression'])) - reconstructed_computations = sorted(reconstructed_state.get("computations", []), key=lambda x: str(x['expression'])) - self.assertListEqual(original_computations, reconstructed_computations, "Computations state mismatch after sorting") + original_computations = sorted(original_state.get("computations", []), key=lambda x: str(x["expression"])) + reconstructed_computations = sorted( + reconstructed_state.get("computations", []), key=lambda x: str(x["expression"]) + ) + self.assertListEqual( + original_computations, reconstructed_computations, "Computations state mismatch after sorting" + ) # Point tests def test_create_point_existing(self) -> None: @@ -419,16 +439,16 @@ def test_create_point_new_unnamed(self) -> None: p2 = self.canvas.create_point(30, 30) self.assertIsNotNone(p1) self.assertIsNotNone(p2) - self.assertEqual(p1.name, 'A') - self.assertEqual(p2.name, 'B') - self.assertIn(p1, self.canvas.drawable_manager.drawables._drawables['Point']) - self.assertIn(p2, self.canvas.drawable_manager.drawables._drawables['Point']) + self.assertEqual(p1.name, "A") + self.assertEqual(p2.name, "B") + self.assertIn(p1, self.canvas.drawable_manager.drawables._drawables["Point"]) + self.assertIn(p2, self.canvas.drawable_manager.drawables._drawables["Point"]) def test_create_point_new_named(self) -> None: - point = self.canvas.create_point(30, 30, 'C') + point = self.canvas.create_point(30, 30, "C") self.assertIsNotNone(point) - self.assertEqual(point.name, 'C') - self.assertIn(point, self.canvas.drawable_manager.drawables._drawables['Point']) + self.assertEqual(point.name, "C") + self.assertIn(point, self.canvas.drawable_manager.drawables._drawables["Point"]) def test_get_point(self) -> None: point_created = self.canvas.create_point(10, 10) @@ -439,9 +459,9 @@ def test_get_point(self) -> None: def test_get_point_by_name(self) -> None: point_created = self.canvas.create_point(10, 10) - point_retrieved = self.canvas.get_point_by_name('A') + point_retrieved = self.canvas.get_point_by_name("A") self.assertEqual(point_created, point_retrieved) - point = self.canvas.get_point_by_name('Z') + point = self.canvas.get_point_by_name("Z") self.assertIsNone(point) def test_delete_point(self) -> None: @@ -451,13 +471,13 @@ def test_delete_point(self) -> None: def test_delete_point_by_name(self) -> None: p1 = self.canvas.create_point(10, 10) - self.canvas.delete_point_by_name('A') + self.canvas.delete_point_by_name("A") self.assertNotIn(p1, self.canvas.get_drawables_by_class_name(p1.get_class_name())) def test_delete_point_by_nonexistent_name(self) -> None: """Test deleting a point with a non-existent name returns False and doesn't cause errors.""" # Test deleting a point that doesn't exist - result = self.canvas.delete_point_by_name('NonexistentPoint') + result = self.canvas.delete_point_by_name("NonexistentPoint") self.assertFalse(result, "delete_point_by_name should return False for non-existent names") # Create a point and verify normal deletion still works @@ -465,7 +485,7 @@ def test_delete_point_by_nonexistent_name(self) -> None: self.assertIsNotNone(self.canvas.get_point_by_name("A")) # Verify the non-existent deletion didn't affect existing points - self.assertIn(p1, self.canvas.get_drawables_by_class_name('Point')) + self.assertIn(p1, self.canvas.get_drawables_by_class_name("Point")) # Delete the point and verify it was removed self.assertTrue(self.canvas.delete_point_by_name("A")) @@ -536,7 +556,7 @@ def test_delete_segment(self) -> None: s = self.canvas.create_segment(10, 10, 20, 20) self.canvas.delete_segment(10, 10, 20, 20) self.assertNotIn(s, self.canvas.get_drawables_by_class_name(s.get_class_name())) - points = self.canvas.get_drawables_by_class_name('Point') + points = self.canvas.get_drawables_by_class_name("Point") self.assertIn(self.canvas.get_point(10, 10), points) self.assertIn(self.canvas.get_point(20, 20), points) @@ -552,14 +572,14 @@ def test_delete_segment_by_name(self) -> None: self.assertNotIn(segment, self.canvas.get_drawables_by_class_name(segment.get_class_name())) # Verify points still exist - points = self.canvas.get_drawables_by_class_name('Point') + points = self.canvas.get_drawables_by_class_name("Point") self.assertIn(self.canvas.get_point(10, 10), points) self.assertIn(self.canvas.get_point(20, 20), points) def test_delete_segment_by_nonexistent_name(self) -> None: """Test deleting a segment with a non-existent name returns False and doesn't cause errors.""" # Test deleting a segment that doesn't exist - result = self.canvas.delete_segment_by_name('NonexistentSegment') + result = self.canvas.delete_segment_by_name("NonexistentSegment") self.assertFalse(result, "delete_segment_by_name should return False for non-existent names") # Create a segment and verify normal deletion still works @@ -567,7 +587,7 @@ def test_delete_segment_by_nonexistent_name(self) -> None: self.assertIsNotNone(self.canvas.get_segment_by_name(segment.name)) # Verify the non-existent deletion didn't affect existing segments - self.assertIn(segment, self.canvas.get_drawables_by_class_name('Segment')) + self.assertIn(segment, self.canvas.get_drawables_by_class_name("Segment")) # Delete the segment and verify it was removed self.assertTrue(self.canvas.delete_segment_by_name(segment.name)) @@ -581,7 +601,9 @@ def test_any_segment_part_visible_in_canvas_area(self) -> None: self.assertTrue(visible) visible = self.canvas.any_segment_part_visible_in_canvas_area(center.x, height + 100, width + 100, center.y) self.assertTrue(visible) - not_visible = self.canvas.any_segment_part_visible_in_canvas_area(width + 10, height + 10, height + 100, width + 100) + not_visible = self.canvas.any_segment_part_visible_in_canvas_area( + width + 10, height + 10, height + 100, width + 100 + ) self.assertFalse(not_visible) def test_delete_segments_depending_on_point(self) -> None: @@ -606,12 +628,24 @@ def test_delete_segment_children(self) -> None: ad = self.canvas.create_segment(0, 0, 30, 0) self.canvas.create_point(10, 0) self.canvas.create_point(20, 0) - self.assertEqual(len(self.canvas.drawable_manager.dependency_manager.get_children(ad)), 5, "Parent segment should have 5 children.") + self.assertEqual( + len(self.canvas.drawable_manager.dependency_manager.get_children(ad)), + 5, + "Parent segment should have 5 children.", + ) # Execute: Delete AD segment's children - self.canvas.drawable_manager.segment_manager._delete_segment_dependencies(0, 0, 30, 0, delete_children=True, delete_parents=False) + self.canvas.drawable_manager.segment_manager._delete_segment_dependencies( + 0, 0, 30, 0, delete_children=True, delete_parents=False + ) # Verify: All contained child segments should be deleted - self.assertEqual(len(self.canvas.drawable_manager.dependency_manager.get_children(ad)), 0, "Parent segment should have no children.") - self.assertEqual(len(self.canvas.get_drawables_by_class_name('Segment')), 1, "There should be one segment left.") + self.assertEqual( + len(self.canvas.drawable_manager.dependency_manager.get_children(ad)), + 0, + "Parent segment should have no children.", + ) + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Segment")), 1, "There should be one segment left." + ) def test_dependency_graph_stays_consistent_after_point_delete_recreate(self) -> None: segment = self.canvas.create_segment(0, 0, 10, 0, name="AB") @@ -653,17 +687,27 @@ def test_dependency_graph_stays_consistent_after_ellipse_delete_recreate(self) - self._assert_dependency_graph_consistent() def test_delete_segment_parents(self) -> None: - s = self.canvas.create_segment(50, 0, 60, 0) # distinct segment + s = self.canvas.create_segment(50, 0, 60, 0) # distinct segment self.canvas.create_segment(0, 0, 30, 0) # root segment self.canvas.create_point(10, 0) - self.assertEqual(len(self.canvas.get_drawables_by_class_name('Segment')), 4, "There should be 4 segments in total.") + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Segment")), 4, "There should be 4 segments in total." + ) self.canvas.create_point(20, 0) - self.assertEqual(len(self.canvas.get_drawables_by_class_name('Segment')), 7, "There should be 7 segments in total.") + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Segment")), 7, "There should be 7 segments in total." + ) # Delete segment's parents - self.canvas.drawable_manager.segment_manager._delete_segment_dependencies(0, 0, 10, 0, delete_children=False, delete_parents=True) + self.canvas.drawable_manager.segment_manager._delete_segment_dependencies( + 0, 0, 10, 0, delete_children=False, delete_parents=True + ) # Verify: Segment parents should be deleted - self.assertEqual(len(self.canvas.get_drawables_by_class_name('Segment')), 1, "There should be one segment left - deleting the parent also deletes its children.") - self.assertIn(s, self.canvas.get_drawables_by_class_name('Segment'), "Distinct segment should not be deleted.") + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Segment")), + 1, + "There should be one segment left - deleting the parent also deletes its children.", + ) + self.assertIn(s, self.canvas.get_drawables_by_class_name("Segment"), "Distinct segment should not be deleted.") def test_are_points_connected(self) -> None: s1 = self.canvas.create_segment(10, 10, 20, 20) @@ -673,7 +717,9 @@ def test_are_points_connected(self) -> None: p3 = self.canvas.get_point(30, 30) points = [p1.name, p2.name, p3.name] # Check if the points are not connected - are_not_connected = GeometryUtils.is_fully_connected_graph(points, self.canvas.drawable_manager.drawables.Segments) + are_not_connected = GeometryUtils.is_fully_connected_graph( + points, self.canvas.drawable_manager.drawables.Segments + ) self.assertFalse(are_not_connected, "All points should not be connected.") # Create a segment connecting the points s3 = self.canvas.create_segment(30, 30, 10, 10) @@ -713,7 +759,7 @@ def test_delete_vector(self) -> None: self.canvas.delete_vector(10, 10, 20, 20) self.assertNotIn(vector, self.canvas.get_drawables_by_class_name(vector.get_class_name())) # Check points are still present after deleting the vector - points = self.canvas.get_drawables_by_class_name('Point') + points = self.canvas.get_drawables_by_class_name("Point") self.assertIn(self.canvas.get_point(10, 10), points) self.assertIn(self.canvas.get_point(20, 20), points) @@ -728,7 +774,7 @@ def test_delete_vector_by_nonexistent_coordinates(self) -> None: self.assertIsNotNone(self.canvas.get_vector(10, 10, 20, 20)) # Verify the non-existent deletion didn't affect existing vectors - self.assertIn(vector, self.canvas.get_drawables_by_class_name('Vector')) + self.assertIn(vector, self.canvas.get_drawables_by_class_name("Vector")) # Delete the vector and verify it was removed self.assertTrue(self.canvas.delete_vector(10, 10, 20, 20)) @@ -754,7 +800,7 @@ def test_create_triangle_new(self) -> None: polygon_type=PolygonType.TRIANGLE, ) self.assertIsNotNone(new_triangle) - self.assertIn(new_triangle, self.canvas.get_drawables_by_class_name('Triangle')) + self.assertIn(new_triangle, self.canvas.get_drawables_by_class_name("Triangle")) def test_get_triangle(self) -> None: vertices = [(10, 10), (20, 20), (30, 30)] @@ -776,7 +822,7 @@ def test_delete_triangle(self) -> None: self.canvas.delete_polygon(polygon_type=PolygonType.TRIANGLE, vertices=vertices) self.assertNotIn(triangle, self.canvas.get_drawables_by_class_name(triangle.get_class_name())) # Verifying segments and points still exist after deleting the triangle - points = self.canvas.get_drawables_by_class_name('Point') + points = self.canvas.get_drawables_by_class_name("Point") self.assertTrue( all( p in points @@ -800,20 +846,14 @@ def test_delete_triangle_by_nonexistent_coordinates(self) -> None: # Create a triangle and verify normal deletion still works vertices = [(10, 10), (20, 20), (30, 30)] triangle = self.canvas.create_polygon(vertices, polygon_type=PolygonType.TRIANGLE) - self.assertIsNotNone( - self.canvas.get_polygon_by_vertices(vertices, polygon_type=PolygonType.TRIANGLE) - ) + self.assertIsNotNone(self.canvas.get_polygon_by_vertices(vertices, polygon_type=PolygonType.TRIANGLE)) # Verify the non-existent deletion didn't affect existing triangles - self.assertIn(triangle, self.canvas.get_drawables_by_class_name('Triangle')) + self.assertIn(triangle, self.canvas.get_drawables_by_class_name("Triangle")) # Delete the triangle and verify it was removed - self.assertTrue( - self.canvas.delete_polygon(polygon_type=PolygonType.TRIANGLE, vertices=vertices) - ) - self.assertIsNone( - self.canvas.get_polygon_by_vertices(vertices, polygon_type=PolygonType.TRIANGLE) - ) + self.assertTrue(self.canvas.delete_polygon(polygon_type=PolygonType.TRIANGLE, vertices=vertices)) + self.assertIsNone(self.canvas.get_polygon_by_vertices(vertices, polygon_type=PolygonType.TRIANGLE)) def test_create_triangle_from_connected_segments(self) -> None: # Setup: Create segments that form a triangle @@ -821,12 +861,20 @@ def test_create_triangle_from_connected_segments(self) -> None: self.canvas.create_segment(20, 20, 10, 30, name="BC", extra_graphics=False) self.canvas.create_segment(10, 30, 10, 10, name="CA", extra_graphics=False) # Verify initial conditions - self.assertEqual(len(self.canvas.get_drawables_by_class_name("Segment")), 3, "Canvas should initially have 3 segments.") - self.assertEqual(len(self.canvas.get_drawables_by_class_name("Triangle")), 0, "Canvas should initially have no triangles.") + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Segment")), 3, "Canvas should initially have 3 segments." + ) + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Triangle")), 0, "Canvas should initially have no triangles." + ) # Execute: Attempt to create new triangles from connected segments self.canvas.drawable_manager.create_new_triangles_from_connected_segments() # Verify: A triangle should be created from the three segments - self.assertEqual(len(self.canvas.get_drawables_by_class_name("Triangle")), 1, "Canvas should have 1 triangle after operation.") + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Triangle")), + 1, + "Canvas should have 1 triangle after operation.", + ) def test_no_triangle_from_unconnected_segments(self) -> None: # Setup: Create segments that do not form a triangle @@ -836,7 +884,11 @@ def test_no_triangle_from_unconnected_segments(self) -> None: # Execute: Attempt to create new triangles from unconnected segments self.canvas.drawable_manager.create_new_triangles_from_connected_segments() # Verify: No triangle should be created as the segments are not connected - self.assertEqual(len(self.canvas.get_drawables_by_class_name("Triangle")), 0, "Canvas should have no triangles as segments are unconnected.") + self.assertEqual( + len(self.canvas.get_drawables_by_class_name("Triangle")), + 0, + "Canvas should have no triangles as segments are unconnected.", + ) def test_triangle_rendering_uses_primitives_only(self) -> None: class RecordingRenderer: @@ -958,7 +1010,12 @@ def test_update_polygon_color(self) -> None: ) self.canvas.update_polygon(polygon.name, polygon_type=PolygonType.RECTANGLE, new_color="#ff00aa") self.assertEqual(polygon.color, "#ff00aa") - self.assertTrue(all(edge.color == "#ff00aa" for edge in [polygon.segment1, polygon.segment2, polygon.segment3, polygon.segment4])) + self.assertTrue( + all( + edge.color == "#ff00aa" + for edge in [polygon.segment1, polygon.segment2, polygon.segment3, polygon.segment4] + ) + ) def test_delete_polygon_by_vertices(self) -> None: vertices = [ @@ -1066,10 +1123,7 @@ def test_create_polygon_quadrilateral_rhombus_subtype(self) -> None: ) self.assertEqual(polygon.get_class_name(), "Quadrilateral") segments = polygon.get_segments() - lengths = [ - ((s.point2.x - s.point1.x) ** 2 + (s.point2.y - s.point1.y) ** 2) ** 0.5 - for s in segments - ] + lengths = [((s.point2.x - s.point1.x) ** 2 + (s.point2.y - s.point1.y) ** 2) ** 0.5 for s in segments] for length in lengths[1:]: self.assertAlmostEqual(lengths[0], length, places=4) @@ -1077,7 +1131,9 @@ def test_create_polygon_quadrilateral_rhombus_subtype(self) -> None: def test_create_rectangle_new(self) -> None: new_rectangle = self._create_rectangle(10, 10, 40, 40) self.assertIsNotNone(new_rectangle, "Failed to create a new rectangle.") - self.assertIn(new_rectangle, self.canvas.get_drawables_by_class_name('Rectangle'), "New rectangle should be in canvas.") + self.assertIn( + new_rectangle, self.canvas.get_drawables_by_class_name("Rectangle"), "New rectangle should be in canvas." + ) def test_get_rectangle_by_diagonal_points(self) -> None: rectangle_created = self._create_rectangle(10, 10, 30, 30) @@ -1087,19 +1143,27 @@ def test_get_rectangle_by_diagonal_points(self) -> None: ) rectangle_retrieved = self.canvas.get_polygon_by_vertices(target_vertices, polygon_type=PolygonType.RECTANGLE) self.assertIsNotNone(rectangle_retrieved, "Rectangle should exist.") - self.assertEqual(rectangle_created, rectangle_retrieved, "Retrieved rectangle should be the same as the created one.") + self.assertEqual( + rectangle_created, rectangle_retrieved, "Retrieved rectangle should be the same as the created one." + ) missing_vertices = canonicalize_rectangle( [(100, 100), (200, 200)], construction_mode="diagonal", ) - non_existent_rectangle = self.canvas.get_polygon_by_vertices(missing_vertices, polygon_type=PolygonType.RECTANGLE) + non_existent_rectangle = self.canvas.get_polygon_by_vertices( + missing_vertices, polygon_type=PolygonType.RECTANGLE + ) self.assertIsNone(non_existent_rectangle, "Should not retrieve a non-existent rectangle.") def test_get_rectangle_by_name(self) -> None: rectangle_created = self._create_rectangle(10, 10, 20, 20) - rectangle_retrieved = self.canvas.get_polygon_by_name(rectangle_created.name, polygon_type=PolygonType.RECTANGLE) + rectangle_retrieved = self.canvas.get_polygon_by_name( + rectangle_created.name, polygon_type=PolygonType.RECTANGLE + ) self.assertIsNotNone(rectangle_retrieved, "Rectangle should exist.") - self.assertEqual(rectangle_created, rectangle_retrieved, "Retrieved rectangle should be the same as the created one.") + self.assertEqual( + rectangle_created, rectangle_retrieved, "Retrieved rectangle should be the same as the created one." + ) non_existent_rectangle = self.canvas.get_polygon_by_name("NonExistent", polygon_type=PolygonType.RECTANGLE) self.assertIsNone(non_existent_rectangle, "Should not retrieve a non-existent rectangle by name.") @@ -1122,11 +1186,11 @@ def test_delete_rectangle(self) -> None: ) def test_delete_rectangles_by_nonexistent_name(self) -> None: - result = self.canvas.delete_polygon(polygon_type=PolygonType.RECTANGLE, name='NonexistentRectangle') + result = self.canvas.delete_polygon(polygon_type=PolygonType.RECTANGLE, name="NonexistentRectangle") self.assertFalse(result, "delete_polygon should return False for non-existent rectangle names") rectangle = self._create_rectangle(10, 10, 30, 30) self.assertIsNotNone(self.canvas.get_polygon_by_name(rectangle.name, polygon_type=PolygonType.RECTANGLE)) - self.assertIn(rectangle, self.canvas.get_drawables_by_class_name('Rectangle')) + self.assertIn(rectangle, self.canvas.get_drawables_by_class_name("Rectangle")) self.assertTrue(self.canvas.delete_polygon(polygon_type=PolygonType.RECTANGLE, name=rectangle.name)) self.assertIsNone(self.canvas.get_polygon_by_name(rectangle.name, polygon_type=PolygonType.RECTANGLE)) @@ -1143,7 +1207,7 @@ def test_create_circle_new(self) -> None: # Directly test creating a new circle new_circle = self.canvas.create_circle(center_x, center_y, radius) self.assertIsNotNone(new_circle, "Failed to create a new circle.") - self.assertIn(new_circle, self.canvas.get_drawables_by_class_name('Circle'), "New circle should be in canvas.") + self.assertIn(new_circle, self.canvas.get_drawables_by_class_name("Circle"), "New circle should be in canvas.") def test_create_circle_with_color(self) -> None: center_x, center_y, radius = 110, 110, 45 @@ -1236,12 +1300,12 @@ def test_delete_circle(self) -> None: # Test deleting the circle self.canvas.delete_circle(circle.name) self.assertIsNone(self.canvas.get_circle(center_x, center_y, radius), "Circle should be deleted.") - self.assertIsNone(self.canvas.get_circle_by_name('A(35)'), "Circle should be deleted.") + self.assertIsNone(self.canvas.get_circle_by_name("A(35)"), "Circle should be deleted.") def test_delete_circle_by_nonexistent_name(self) -> None: """Test deleting a circle with a non-existent name returns False and doesn't cause errors.""" # Test deleting a circle that doesn't exist - result = self.canvas.delete_circle('NonexistentCircle') + result = self.canvas.delete_circle("NonexistentCircle") self.assertFalse(result, "delete_circle should return False for non-existent names") # Create a circle and verify normal deletion still works @@ -1249,7 +1313,7 @@ def test_delete_circle_by_nonexistent_name(self) -> None: self.assertIsNotNone(self.canvas.get_circle_by_name(circle.name)) # Verify the non-existent deletion didn't affect existing circles - self.assertIn(circle, self.canvas.get_drawables_by_class_name('Circle')) + self.assertIn(circle, self.canvas.get_drawables_by_class_name("Circle")) # Delete the circle and verify it was removed self.assertTrue(self.canvas.delete_circle(circle.name)) @@ -1260,7 +1324,7 @@ def test_delete_circles_depending_on_point(self) -> None: circle = self.canvas.create_circle(center_x, center_y, radius) self.assertIsNotNone(circle) self.assertIn(circle, self.canvas.get_drawables_by_class_name(circle.get_class_name())) - self.canvas.drawable_manager.point_manager._delete_point_dependencies(center_x, center_y) # Keep original args + self.canvas.drawable_manager.point_manager._delete_point_dependencies(center_x, center_y) # Keep original args self.assertNotIn(circle, self.canvas.get_drawables_by_class_name(circle.get_class_name())) # Ellipse tests @@ -1269,24 +1333,28 @@ def test_create_ellipse_new(self) -> None: # Directly test creating a new ellipse new_ellipse = self.canvas.create_ellipse(center_x, center_y, radius_x, radius_y) self.assertIsNotNone(new_ellipse, "Failed to create a new ellipse.") - self.assertIn(new_ellipse, self.canvas.get_drawables_by_class_name(new_ellipse.get_class_name()), "New ellipse should be in canvas.") + self.assertIn( + new_ellipse, + self.canvas.get_drawables_by_class_name(new_ellipse.get_class_name()), + "New ellipse should be in canvas.", + ) def test_create_ellipse_with_rotation(self) -> None: center_x, center_y = 100, 100 radius_x, radius_y = 50, 30 rotation_angle = 45 - new_ellipse = self.canvas.create_ellipse( - center_x, center_y, radius_x, radius_y, - rotation_angle=rotation_angle - ) + new_ellipse = self.canvas.create_ellipse(center_x, center_y, radius_x, radius_y, rotation_angle=rotation_angle) self.assertIsNotNone(new_ellipse, "Failed to create a new rotated ellipse.") - self.assertEqual(new_ellipse.rotation_angle, rotation_angle, - "Ellipse rotation angle does not match specified angle.") - self.assertIn(new_ellipse, - self.canvas.get_drawables_by_class_name(new_ellipse.get_class_name()), - "New rotated ellipse should be in canvas.") + self.assertEqual( + new_ellipse.rotation_angle, rotation_angle, "Ellipse rotation angle does not match specified angle." + ) + self.assertIn( + new_ellipse, + self.canvas.get_drawables_by_class_name(new_ellipse.get_class_name()), + "New rotated ellipse should be in canvas.", + ) def test_create_ellipse_existing(self) -> None: center_x, center_y, radius_x, radius_y = 130, 130, 40, 25 @@ -1334,7 +1402,7 @@ def test_delete_ellipse(self) -> None: def test_delete_ellipse_by_nonexistent_name(self) -> None: """Test deleting an ellipse with a non-existent name returns False and doesn't cause errors.""" # Test deleting an ellipse that doesn't exist - result = self.canvas.delete_ellipse('NonexistentEllipse') + result = self.canvas.delete_ellipse("NonexistentEllipse") self.assertFalse(result, "delete_ellipse should return False for non-existent names") # Create an ellipse and verify normal deletion still works @@ -1342,7 +1410,7 @@ def test_delete_ellipse_by_nonexistent_name(self) -> None: self.assertIsNotNone(self.canvas.get_ellipse_by_name(ellipse.name)) # Verify the non-existent deletion didn't affect existing ellipses - self.assertIn(ellipse, self.canvas.get_drawables_by_class_name('Ellipse')) + self.assertIn(ellipse, self.canvas.get_drawables_by_class_name("Ellipse")) # Delete the ellipse and verify it was removed self.assertTrue(self.canvas.delete_ellipse(ellipse.name)) @@ -1363,7 +1431,9 @@ def test_draw_function_new(self) -> None: # Directly test drawing a new math function f = self.canvas.draw_function(function_string, name, left_bound, right_bound) self.assertIsNotNone(f, "Failed to draw a new math function.") - self.assertIn(f, self.canvas.get_drawables_by_class_name(f.get_class_name()), "New function should be in canvas.") + self.assertIn( + f, self.canvas.get_drawables_by_class_name(f.get_class_name()), "New function should be in canvas." + ) self.assertEqual(f.function_string, "x^2", "Function string should match.") self.assertEqual(f.left_bound, left_bound, "Left bound should match.") self.assertEqual(f.right_bound, right_bound, "Right bound should match.") @@ -1385,7 +1455,11 @@ def test_delete_math_function(self) -> None: f = self.canvas.draw_function(function_string, name) # Ensure the function exists before deletion self.assertIsNotNone(f, "Math function should exist before deletion.") - self.assertIn(f, self.canvas.get_drawables_by_class_name(f.get_class_name()), "Math function should exist before deletion.") + self.assertIn( + f, + self.canvas.get_drawables_by_class_name(f.get_class_name()), + "Math function should exist before deletion.", + ) # Test deleting the math function self.canvas.delete_function(f.name) deleted_function = self.canvas.get_function(f.name) @@ -1394,7 +1468,7 @@ def test_delete_math_function(self) -> None: def test_delete_function_by_nonexistent_name(self) -> None: """Test deleting a function with a non-existent name returns False and doesn't cause errors.""" # Test deleting a function that doesn't exist - result = self.canvas.delete_function('NonexistentFunction') + result = self.canvas.delete_function("NonexistentFunction") self.assertFalse(result, "delete_function should return False for non-existent names") # Create a function and verify normal deletion still works @@ -1402,7 +1476,7 @@ def test_delete_function_by_nonexistent_name(self) -> None: self.assertIsNotNone(self.canvas.get_function(func.name)) # Verify the non-existent deletion didn't affect existing functions - self.assertIn(func, self.canvas.get_drawables_by_class_name('Function')) + self.assertIn(func, self.canvas.get_drawables_by_class_name("Function")) # Delete the function and verify it was removed self.assertTrue(self.canvas.delete_function(func.name)) @@ -1448,8 +1522,8 @@ def test_find_segment_children(self) -> None: self.assertEqual(len(s3_children), 2, "Segment s3 should have 2 children.") def test_get_segment_parents(self) -> None: - original_segment = self.canvas.create_segment(0, 0, 30, 0) # AB: A(0,0), B(30,0) - point_c = self.canvas.create_point(10, 0) # C(10,0) on AB + original_segment = self.canvas.create_segment(0, 0, 30, 0) # AB: A(0,0), B(30,0) + point_c = self.canvas.create_point(10, 0) # C(10,0) on AB dm = self.canvas.drawable_manager.dependency_manager children = list(dm.get_children(original_segment)) @@ -1457,8 +1531,8 @@ def test_get_segment_parents(self) -> None: self.assertEqual(len(children), 2, "Original segment should have 2 children after split.") # Retrieve child segments by their expected coordinates for reliable checking - child_ac = self.canvas.get_segment_by_coordinates(0, 0, 10, 0) # A(0,0) to C(10,0) - child_cb = self.canvas.get_segment_by_coordinates(10, 0, 30, 0) # C(10,0) to B(30,0) + child_ac = self.canvas.get_segment_by_coordinates(0, 0, 10, 0) # A(0,0) to C(10,0) + child_cb = self.canvas.get_segment_by_coordinates(10, 0, 30, 0) # C(10,0) to B(30,0) self.assertIsNotNone(child_ac, "Child segment AC should exist.") self.assertIsNotNone(child_cb, "Child segment CB should exist.") @@ -1492,7 +1566,11 @@ def test_add_segment_to_parents(self) -> None: def test_split_segment(self) -> None: segment = self.canvas.create_segment(0, 0, 30, 0) self.canvas.drawable_manager.segment_manager._split_segments_with_point(20, 0) - self.assertEqual(len(self.canvas.drawable_manager.dependency_manager.get_children(segment)), 2, "Segment should have 2 children.") + self.assertEqual( + len(self.canvas.drawable_manager.dependency_manager.get_children(segment)), + 2, + "Segment should have 2 children.", + ) segments = self.canvas.get_drawables_by_class_name("Segment") points = self.canvas.get_drawables_by_class_name("Point") @@ -1524,7 +1602,11 @@ def test_split_segment_with_specific_coordinates(self) -> None: self.canvas.drawable_manager.segment_manager._split_segments_with_point(point_c_x, point_c_y) # Verify that the original segment has two children - self.assertEqual(len(self.canvas.drawable_manager.dependency_manager.get_children(segment_ab)), 2, "Segment AB should have 2 children after splitting") + self.assertEqual( + len(self.canvas.drawable_manager.dependency_manager.get_children(segment_ab)), + 2, + "Segment AB should have 2 children after splitting", + ) # Get the child segments (AC and CB) ac_segment = self.canvas.get_segment_by_coordinates(x1, y1, point_c_x, point_c_y) @@ -1547,11 +1629,7 @@ def test_split_segment_with_specific_coordinates(self) -> None: b_pos = segment_ab.point2 c_pos = SimpleNamespace(x=point_c_x, y=point_c_y) - is_point_on_segment = MathUtils.is_point_on_segment( - c_pos.x, c_pos.y, - a_pos.x, a_pos.y, - b_pos.x, b_pos.y - ) + is_point_on_segment = MathUtils.is_point_on_segment(c_pos.x, c_pos.y, a_pos.x, a_pos.y, b_pos.x, b_pos.y) self.assertTrue(is_point_on_segment, "Point C should be on segment AB") @@ -1572,10 +1650,14 @@ def test_split_segment_comprehensive(self) -> None: self.canvas.drawable_manager.segment_manager._split_segments_with_point(10, 0) # Original segment AB should have two direct children after splitting - self.assertEqual(len(self.canvas.drawable_manager.dependency_manager.get_children(ab)), 2, "Original segment AB should have 2 children after adding point C") + self.assertEqual( + len(self.canvas.drawable_manager.dependency_manager.get_children(ab)), + 2, + "Original segment AB should have 2 children after adding point C", + ) # Get the child segments (AC and CB) - segments = self.canvas.get_drawables_by_class_name('Segment') + segments = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(segments), 3, "Should have 3 segments total: AB (parent), AC, CB") # Find the child segments @@ -1601,7 +1683,11 @@ def test_split_segment_comprehensive(self) -> None: self.assertIsNotNone(db_segment, "Segment DB should exist") # The original child segment CB should now have two children of its own - self.assertEqual(len(self.canvas.drawable_manager.dependency_manager.get_children(cb_segment)), 2, "Segment CB should have 2 children (CD and DB)") + self.assertEqual( + len(self.canvas.drawable_manager.dependency_manager.get_children(cb_segment)), + 2, + "Segment CB should have 2 children (CD and DB)", + ) # But the original parent AB should still have CD and DB as grandchildren # Check that AB is in the parent chain of both CD and DB @@ -1613,12 +1699,18 @@ def test_split_segment_comprehensive(self) -> None: self.assertTrue(any(p == cb_segment for p in db_parents), "CB should be a parent of DB") # Verify the total number of segments - segments_after = self.canvas.get_drawables_by_class_name('Segment') - self.assertEqual(len(segments_after), 6, "Should have 6 segments total after splitting twice: AB (parent), AC, CB (children), CD, DB (grandchildren), and AD (additional segment between existing points)") + segments_after = self.canvas.get_drawables_by_class_name("Segment") + self.assertEqual( + len(segments_after), + 6, + "Should have 6 segments total after splitting twice: AB (parent), AC, CB (children), CD, DB (grandchildren), and AD (additional segment between existing points)", + ) # Check that the additional segment AD exists ad_segment = self.canvas.get_segment_by_coordinates(0, 0, 25, 0) - self.assertIsNotNone(ad_segment, "Segment AD should also exist as all possible segments between points are created") + self.assertIsNotNone( + ad_segment, "Segment AD should also exist as all possible segments between points are created" + ) # Delete point C and verify consequences c_point = self.canvas.get_point(10, 0) @@ -1626,10 +1718,12 @@ def test_split_segment_comprehensive(self) -> None: # After deleting C, we should lose AC and CD, but AB, AD, and DB should still exist # AB's children should be updated - segments_after_delete = self.canvas.get_drawables_by_class_name('Segment') + segments_after_delete = self.canvas.get_drawables_by_class_name("Segment") # We should have 3 segments left: AB, AD, and DB - self.assertEqual(len(segments_after_delete), 3, "Should have 3 segments left after deleting C: AB (parent), AD, and DB") + self.assertEqual( + len(segments_after_delete), 3, "Should have 3 segments left after deleting C: AB (parent), AD, and DB" + ) # AB should still exist ab_after = self.canvas.get_segment_by_name("AB") @@ -1648,7 +1742,7 @@ def test_segment_preservation_with_one_point(self) -> None: c_point = self.canvas.create_point(25, 0, name="C", extra_graphics=True) # 3. Check total number of segments - all_segments = self.canvas.get_drawables_by_class_name('Segment') + all_segments = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(all_segments), 3, "Should have 3 segments total: AB, AC, CB") # 4. Get child segments AC and CB @@ -1668,7 +1762,7 @@ def test_segment_preservation_with_one_point(self) -> None: self.assertIn(cb_segment, children_ab, "Segment CB should be a child of AB") # 6. Delete point C - self.canvas.delete_point(25, 0) # Deleting point C by coordinates + self.canvas.delete_point(25, 0) # Deleting point C by coordinates # 7. Verify AB still exists ab_after_delete = self.canvas.get_segment_by_name("AB") @@ -1682,19 +1776,23 @@ def test_segment_preservation_with_one_point(self) -> None: self.assertIsNone(cb_after_delete, "Segment CB should be deleted after point C is deleted") # 9. Verify AB has no segment children - children_ab_after_delete = self.canvas.drawable_manager.dependency_manager.get_children(ab_after_delete) # Use ab_after_delete to be safe - self.assertEqual(len(children_ab_after_delete), 0, "Segment AB should have 0 segment children after point C is deleted") + children_ab_after_delete = self.canvas.drawable_manager.dependency_manager.get_children( + ab_after_delete + ) # Use ab_after_delete to be safe + self.assertEqual( + len(children_ab_after_delete), 0, "Segment AB should have 0 segment children after point C is deleted" + ) # 10. Verify points A and B still exist, but C is gone - pointA_after_delete = self.canvas.get_point(0,0) - pointB_after_delete = self.canvas.get_point(100,0) - pointC_after_delete = self.canvas.get_point(25,0) + pointA_after_delete = self.canvas.get_point(0, 0) + pointB_after_delete = self.canvas.get_point(100, 0) + pointC_after_delete = self.canvas.get_point(25, 0) self.assertIsNotNone(pointA_after_delete, "Point A should still exist") self.assertIsNotNone(pointB_after_delete, "Point B should still exist") self.assertIsNone(pointC_after_delete, "Point C should be deleted") # 11. Verify total number of segments is 1 (only AB) - all_segments_after_delete = self.canvas.get_drawables_by_class_name('Segment') + all_segments_after_delete = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(all_segments_after_delete), 1, "Should have 1 segment (AB) left after deleting point C") self.assertIn(ab_after_delete, all_segments_after_delete, "The remaining segment should be AB") @@ -1709,7 +1807,7 @@ def test_segment_preservation_with_multiple_points(self) -> None: # 3. Initial Checks (after adding C and D) # Points: A(0,0), C(25,0), D(50,0), B(100,0) - all_segments_initial = self.canvas.get_drawables_by_class_name('Segment') + all_segments_initial = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(all_segments_initial), 6, "Should have 6 segments total: AB, AC, AD, CD, CB, DB") ac_initial = self.canvas.get_segment_by_coordinates(0, 0, 25, 0) @@ -1727,14 +1825,16 @@ def test_segment_preservation_with_multiple_points(self) -> None: children_ab_initial = self.canvas.drawable_manager.dependency_manager.get_children(ab) self.assertEqual(len(children_ab_initial), 5, "AB should have 5 children initially (AC, AD, CD, CB, DB)") for seg in [ac_initial, ad_initial, cd_initial, cb_initial, db_initial]: - self.assertIn(seg, children_ab_initial, f"{seg.name if seg else 'A segment'} should be a child of AB initially") + self.assertIn( + seg, children_ab_initial, f"{seg.name if seg else 'A segment'} should be a child of AB initially" + ) # 4. Delete point C - self.canvas.delete_point(25, 0) # Deleting point C + self.canvas.delete_point(25, 0) # Deleting point C # 5. Checks after deleting C # Remaining points: A(0,0), D(50,0), B(100,0) - all_segments_after_c_delete = self.canvas.get_drawables_by_class_name('Segment') + all_segments_after_c_delete = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(all_segments_after_c_delete), 3, "Should have 3 segments after deleting C (AB, AD, DB)") ab_after_c = self.canvas.get_segment_by_name("AB") @@ -1756,11 +1856,11 @@ def test_segment_preservation_with_multiple_points(self) -> None: self.assertIn(db_after_c, children_ab_after_c_delete, "DB should be a child of AB after C deletion") # 6. Delete point D - self.canvas.delete_point(50, 0) # Deleting point D + self.canvas.delete_point(50, 0) # Deleting point D # 7. Final Checks (after deleting C and D) # Remaining points: A(0,0), B(100,0) - all_segments_final = self.canvas.get_drawables_by_class_name('Segment') + all_segments_final = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(all_segments_final), 1, "Should have 1 segment (AB) after C and D deletion") ab_final = self.canvas.get_segment_by_name("AB") @@ -1901,7 +2001,7 @@ def test_translate_triangle(self) -> None: original_points = [ (triangle.segment1.point1.x, triangle.segment1.point1.y), (triangle.segment1.point2.x, triangle.segment1.point2.y), - (triangle.segment2.point2.x, triangle.segment2.point2.y) + (triangle.segment2.point2.x, triangle.segment2.point2.y), ] # Translate triangle @@ -1912,25 +2012,31 @@ def test_translate_triangle(self) -> None: translated_points = [ (triangle.segment1.point1.x, triangle.segment1.point1.y), (triangle.segment1.point2.x, triangle.segment1.point2.y), - (triangle.segment2.point2.x, triangle.segment2.point2.y) + (triangle.segment2.point2.x, triangle.segment2.point2.y), ] # Verify each point was translated correctly for i in range(3): - self.assertEqual(translated_points[i][0], original_points[i][0] + x_offset, - f"Point {i+1} x-coordinate not translated correctly") - self.assertEqual(translated_points[i][1], original_points[i][1] + y_offset, - f"Point {i+1} y-coordinate not translated correctly") + self.assertEqual( + translated_points[i][0], + original_points[i][0] + x_offset, + f"Point {i + 1} x-coordinate not translated correctly", + ) + self.assertEqual( + translated_points[i][1], + original_points[i][1] + y_offset, + f"Point {i + 1} y-coordinate not translated correctly", + ) def test_translate_function(self) -> None: def replace_x(func_str: str, x_offset: float) -> str: protected_funcs = sorted(ExpressionValidator.ALLOWED_FUNCTIONS, key=len, reverse=True) - func_pattern = '|'.join(map(re.escape, protected_funcs)) - pattern = rf'\b(x)\b|({func_pattern})' + func_pattern = "|".join(map(re.escape, protected_funcs)) + pattern = rf"\b(x)\b|({func_pattern})" def replace_match(match: re.Match[str]) -> str: if match.group(1): - return f'(x - {x_offset})' + return f"(x - {x_offset})" if match.group(2): return match.group(2) return match.group(0) @@ -1939,14 +2045,14 @@ def replace_match(match: re.Match[str]) -> str: # Test basic function translations test_cases = [ - ("x + 1", 2, 3), # Linear function with positive offset - ("sin(x)", -1, 2), # Trigonometric function with negative x offset - ("x^2", -3, 0), # Quadratic function with negative x offset - ("1/x", -2, -1), # Rational function with negative offsets - ("sqrt(x)", 1, -1), # Root function with negative y offset - ("2*x + 3", -2, 2), # Linear function with negative x offset - ("exp(x)", -1, 0), # Exponential function with negative x offset - ("log(x)", -2, -3), # Logarithmic function with negative offsets + ("x + 1", 2, 3), # Linear function with positive offset + ("sin(x)", -1, 2), # Trigonometric function with negative x offset + ("x^2", -3, 0), # Quadratic function with negative x offset + ("1/x", -2, -1), # Rational function with negative offsets + ("sqrt(x)", 1, -1), # Root function with negative y offset + ("2*x + 3", -2, 2), # Linear function with negative x offset + ("exp(x)", -1, 0), # Exponential function with negative x offset + ("log(x)", -2, -3), # Logarithmic function with negative offsets ] for func_str, x_offset, y_offset in test_cases: @@ -1961,10 +2067,16 @@ def replace_match(match: re.Match[str]) -> str: # Verify bounds were translated correctly for x_offset if x_offset != 0: - self.assertEqual(f.left_bound, original_left + x_offset, - f"Left bound incorrect for {func_str} with x_offset {x_offset}") - self.assertEqual(f.right_bound, original_right + x_offset, - f"Right bound incorrect for {func_str} with x_offset {x_offset}") + self.assertEqual( + f.left_bound, + original_left + x_offset, + f"Left bound incorrect for {func_str} with x_offset {x_offset}", + ) + self.assertEqual( + f.right_bound, + original_right + x_offset, + f"Right bound incorrect for {func_str} with x_offset {x_offset}", + ) # Get expected function string after translation expected = func_str @@ -1980,7 +2092,7 @@ def replace_match(match: re.Match[str]) -> str: ExpressionValidator.fix_math_expression(expected), f"Function string incorrect for {func_str} with offsets ({x_offset}, {y_offset})\n" f"Expected: {ExpressionValidator.fix_math_expression(expected)}\n" - f"Actual: {actual}" + f"Actual: {actual}", ) # Test function evaluation at sample points @@ -1991,11 +2103,11 @@ def replace_match(match: re.Match[str]) -> str: for x in test_points: try: # Ensure x is a valid input for the original function - if func_str.startswith('sqrt') and (x - x_offset) < 0: + if func_str.startswith("sqrt") and (x - x_offset) < 0: continue # Skip negative inputs for sqrt - if func_str.startswith('log') and (x - x_offset) <= 0: + if func_str.startswith("log") and (x - x_offset) <= 0: continue # Skip non-positive inputs for log - if '1/x' in func_str and (x - x_offset) == 0: + if "1/x" in func_str and (x - x_offset) == 0: continue # Skip zero for division # Calculate expected result @@ -2010,14 +2122,16 @@ def replace_match(match: re.Match[str]) -> str: expected_result, places=10, msg=f"Failed for function {func_str} at x={x} with offsets ({x_offset}, {y_offset})\n" - f"Expected: {expected_result}\n" - f"Actual: {actual_result}" + f"Expected: {expected_result}\n" + f"Actual: {actual_result}", ) except (ValueError, ZeroDivisionError): # Skip points where function is undefined continue except Exception as e: - self.fail(f"Error evaluating {func_str} at x={x} with offsets ({x_offset}, {y_offset}): {str(e)}") + self.fail( + f"Error evaluating {func_str} at x={x} with offsets ({x_offset}, {y_offset}): {str(e)}" + ) except Exception as e: self.fail(f"Error setting up function evaluation for {func_str}: {str(e)}") @@ -2061,17 +2175,11 @@ def test_segment_rotation(self) -> None: def test_segment_multiple_rotations(self) -> None: """Test multiple rotations on a segment""" segment = self.canvas.create_segment(0, 0, 30, 0, name="AB") - initial_points = { - (round(p.x, 6), round(p.y, 6)) - for p in [segment.point1, segment.point2] - } + initial_points = {(round(p.x, 6), round(p.y, 6)) for p in [segment.point1, segment.point2]} # Full rotation should return to original position segment.rotate(360) - final_points = { - (round(p.x, 6), round(p.y, 6)) - for p in [segment.point1, segment.point2] - } + final_points = {(round(p.x, 6), round(p.y, 6)) for p in [segment.point1, segment.point2]} self.assertEqual(initial_points, final_points) def test_triangle_rotation(self) -> None: @@ -2081,15 +2189,13 @@ def test_triangle_rotation(self) -> None: polygon_type=PolygonType.TRIANGLE, ) initial_points = { - (p.x, p.y) - for p in [triangle.segment1.point1, triangle.segment1.point2, triangle.segment2.point2] + (p.x, p.y) for p in [triangle.segment1.point1, triangle.segment1.point2, triangle.segment2.point2] } triangle.rotate(120) rotated_points = { - (p.x, p.y) - for p in [triangle.segment1.point1, triangle.segment1.point2, triangle.segment2.point2] + (p.x, p.y) for p in [triangle.segment1.point1, triangle.segment1.point2, triangle.segment2.point2] } self.assertNotEqual(initial_points, rotated_points) @@ -2117,16 +2223,24 @@ def test_rectangle_rotation(self) -> None: rectangle = self._create_rectangle(0, 0, 30, 20) initial_points = { (p.x, p.y) - for p in [rectangle.segment1.point1, rectangle.segment1.point2, - rectangle.segment2.point2, rectangle.segment3.point2] + for p in [ + rectangle.segment1.point1, + rectangle.segment1.point2, + rectangle.segment2.point2, + rectangle.segment3.point2, + ] } rectangle.rotate(45) rotated_points = { (p.x, p.y) - for p in [rectangle.segment1.point1, rectangle.segment1.point2, - rectangle.segment2.point2, rectangle.segment3.point2] + for p in [ + rectangle.segment1.point1, + rectangle.segment1.point2, + rectangle.segment2.point2, + rectangle.segment3.point2, + ] } self.assertNotEqual(initial_points, rotated_points) @@ -2141,31 +2255,16 @@ def test_rectangle_multiple_rotations(self) -> None: p4 = rectangle.segment3.point2 # Store original positions - original_points = [ - (p1.x, p1.y), - (p2.x, p2.y), - (p3.x, p3.y), - (p4.x, p4.y) - ] + original_points = [(p1.x, p1.y), (p2.x, p2.y), (p3.x, p3.y), (p4.x, p4.y)] # Test cumulative rotations rectangle.rotate(45) # Store new positions after first rotation - new_points_after_first_rotation = [ - (p1.x, p1.y), - (p2.x, p2.y), - (p3.x, p3.y), - (p4.x, p4.y) - ] + new_points_after_first_rotation = [(p1.x, p1.y), (p2.x, p2.y), (p3.x, p3.y), (p4.x, p4.y)] rectangle.rotate(45) # Store new positions after second rotation - new_points_after_second_rotation = [ - (p1.x, p1.y), - (p2.x, p2.y), - (p3.x, p3.y), - (p4.x, p4.y) - ] + new_points_after_second_rotation = [(p1.x, p1.y), (p2.x, p2.y), (p3.x, p3.y), (p4.x, p4.y)] # Verify that the points have changed after rotations self.assertNotEqual(original_points, new_points_after_first_rotation) @@ -2178,8 +2277,7 @@ def test_ellipse_rotation(self) -> None: ellipse.rotate(60) self.assertEqual(ellipse.rotation_angle, 60) - self.assertEqual(ellipse.ellipse_formula, - MathUtils.get_ellipse_formula(100, 100, 50, 30, 60)) + self.assertEqual(ellipse.ellipse_formula, MathUtils.get_ellipse_formula(100, 100, 50, 30, 60)) def test_ellipse_formula_after_rotation(self) -> None: """Test that ellipse formula updates correctly after rotation""" @@ -2189,8 +2287,7 @@ def test_ellipse_formula_after_rotation(self) -> None: ellipse.rotate(45) self.assertNotEqual(ellipse.ellipse_formula, initial_formula) - self.assertEqual(ellipse.ellipse_formula, - MathUtils.get_ellipse_formula(100, 100, 50, 30, 45)) + self.assertEqual(ellipse.ellipse_formula, MathUtils.get_ellipse_formula(100, 100, 50, 30, 45)) def test_non_rotatable_objects(self) -> None: """Test that appropriate objects cannot be rotated""" @@ -2238,29 +2335,29 @@ def test_computations(self) -> None: # Test adding computations self.canvas.add_computation( expression="|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|", - result=300 + result=300, ) # Test computation appears in canvas state state = self.canvas.get_canvas_state() - self.assertIn('computations', state) - self.assertEqual(len(state['computations']), 1) - self.assertEqual(state['computations'][0]['expression'], "|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|") - self.assertEqual(state['computations'][0]['result'], 300) + self.assertIn("computations", state) + self.assertEqual(len(state["computations"]), 1) + self.assertEqual( + state["computations"][0]["expression"], + "|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|", + ) + self.assertEqual(state["computations"][0]["result"], 300) # Test duplicate prevention - try to add the same computation again self.canvas.add_computation( expression="|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|", - result=300 + result=300, ) # Verify the duplicate was not added self.assertEqual(len(self.canvas.computations), 1) # Add a different computation - self.canvas.add_computation( - expression="300/(2*pi)", - result=47.75 - ) + self.canvas.add_computation(expression="300/(2*pi)", result=47.75) # Verify both unique computations exist self.assertEqual(len(self.canvas.computations), 2) @@ -2271,14 +2368,20 @@ def test_computations(self) -> None: # Verify computations are preserved after undo self.assertEqual(len(self.canvas.computations), 2) - self.assertEqual(self.canvas.computations[0]['expression'], "|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|") - self.assertEqual(self.canvas.computations[1]['expression'], "300/(2*pi)") + self.assertEqual( + self.canvas.computations[0]["expression"], + "|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|", + ) + self.assertEqual(self.canvas.computations[1]["expression"], "300/(2*pi)") # Test redo also preserves computations self.canvas.redo() # Redo the point creation self.assertEqual(len(self.canvas.computations), 2) - self.assertEqual(self.canvas.computations[0]['expression'], "|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|") - self.assertEqual(self.canvas.computations[1]['expression'], "300/(2*pi)") + self.assertEqual( + self.canvas.computations[0]["expression"], + "|sqrt((100-0)^2 + (0-0)^2)| + |sqrt((50-100)^2 + (86-0)^2)| + |sqrt((0-50)^2 + (0-86)^2)|", + ) + self.assertEqual(self.canvas.computations[1]["expression"], "300/(2*pi)") def test_delete_point_on_segment_preserves_parents(self) -> None: """Test that deleting a point on a segment preserves parent segments.""" @@ -2292,7 +2395,7 @@ def test_delete_point_on_segment_preserves_parents(self) -> None: e = self.canvas.create_point(20, 0, name="E", extra_graphics=True) # At this point we should have segments: AB, AD, DB, AE, EB - segments = self.canvas.get_drawables_by_class_name('Segment') + segments = self.canvas.get_drawables_by_class_name("Segment") self.assertEqual(len(segments), 6, "Should have 6 segments: AB, AD, DB, AE, EB, DE") # Get the segment AB to check later @@ -2303,7 +2406,7 @@ def test_delete_point_on_segment_preserves_parents(self) -> None: self.canvas.delete_point_by_name("D") # After deletion, AB should still exist - segments_after_deletion = self.canvas.get_drawables_by_class_name('Segment') + segments_after_deletion = self.canvas.get_drawables_by_class_name("Segment") self.assertIn(ab_after_points, segments_after_deletion, "Segment AB should still exist after deleting point D") # Segments AD, DB, DE should be deleted @@ -2332,7 +2435,7 @@ def test_delete_point_on_triangle_segment_preserves_parents(self) -> None: ) # Get the segments of the triangle - segments_before = self.canvas.get_drawables_by_class_name('Segment') + segments_before = self.canvas.get_drawables_by_class_name("Segment") segment_ab = self.canvas.get_segment_by_coordinates(0, 0, 30, 0) segment_bc = self.canvas.get_segment_by_coordinates(30, 0, 15, 20) @@ -2373,7 +2476,7 @@ def test_delete_point_on_triangle_segment_preserves_parents(self) -> None: ca_segment_after = self.canvas.get_segment_by_name("CA") # Verify the triangle itself is still there - triangles_after = self.canvas.get_drawables_by_class_name('Triangle') + triangles_after = self.canvas.get_drawables_by_class_name("Triangle") self.assertEqual(len(triangles_after), 1, "Triangle should still exist after deleting point D") self.assertEqual(triangles_after[0].name, "ABC", "Triangle ABC should still exist after deleting D") @@ -2526,7 +2629,7 @@ def test_delete_colored_area_by_existing_name(self) -> None: def test_delete_colored_area_by_nonexistent_name(self) -> None: """Test deleting a colored area with a non-existent name returns False and doesn't cause errors.""" # Test deleting a colored area that doesn't exist - result = self.canvas.delete_colored_area('NonexistentColoredArea') + result = self.canvas.delete_colored_area("NonexistentColoredArea") self.assertFalse(result, "delete_colored_area should return False for non-existent names") # Create a colored area and verify normal deletion still works @@ -2695,7 +2798,7 @@ def test_angle_deletion_on_point_deletion(self) -> None: # Verify angle was created self.assertIsNotNone(angle) - angles = self.canvas.get_drawables_by_class_name('Angle') + angles = self.canvas.get_drawables_by_class_name("Angle") self.assertEqual(len(angles), 1) self.assertIn(angle, angles) @@ -2703,7 +2806,7 @@ def test_angle_deletion_on_point_deletion(self) -> None: self.canvas.delete_point(0, 0) # Verify angle was automatically deleted - angles_after = self.canvas.get_drawables_by_class_name('Angle') + angles_after = self.canvas.get_drawables_by_class_name("Angle") self.assertEqual(len(angles_after), 0) self.assertNotIn(angle, angles_after) @@ -2712,7 +2815,7 @@ def test_vector_creation_registers_dependency(self) -> None: vector = self.canvas.create_vector(10, 10, 20, 20) self.assertIsNotNone(vector) parents = self.canvas.dependency_manager.get_parents(vector) - segment_parents = [p for p in parents if hasattr(p, 'get_class_name') and p.get_class_name() == 'Segment'] + segment_parents = [p for p in parents if hasattr(p, "get_class_name") and p.get_class_name() == "Segment"] self.assertEqual(len(segment_parents), 1, "Vector should have exactly 1 segment parent") self.assertIs(segment_parents[0], vector.segment) @@ -2727,13 +2830,13 @@ def test_point_deletion_cascades_to_pentagon(self) -> None: ] pentagon = self.canvas.create_polygon(vertices, polygon_type=PolygonType.PENTAGON) self.assertIsNotNone(pentagon) - self.assertIn(pentagon, self.canvas.get_drawables_by_class_name('Pentagon')) + self.assertIn(pentagon, self.canvas.get_drawables_by_class_name("Pentagon")) # Delete the first vertex self.canvas.delete_point(0.0, 0.0) # Pentagon should be removed - self.assertNotIn(pentagon, self.canvas.get_drawables_by_class_name('Pentagon')) + self.assertNotIn(pentagon, self.canvas.get_drawables_by_class_name("Pentagon")) def test_segment_deletion_cascades_to_pentagon(self) -> None: """Deleting an edge segment of a pentagon should remove the pentagon.""" @@ -2754,7 +2857,7 @@ def test_segment_deletion_cascades_to_pentagon(self) -> None: self.canvas.delete_segment(seg.point1.x, seg.point1.y, seg.point2.x, seg.point2.y) # Pentagon should be removed - self.assertNotIn(pentagon, self.canvas.get_drawables_by_class_name('Pentagon')) + self.assertNotIn(pentagon, self.canvas.get_drawables_by_class_name("Pentagon")) def test_point_deletion_cascades_to_hexagon(self) -> None: """Deleting a vertex point of a hexagon should remove the hexagon.""" @@ -2768,11 +2871,11 @@ def test_point_deletion_cascades_to_hexagon(self) -> None: ] hexagon = self.canvas.create_polygon(vertices, polygon_type=PolygonType.HEXAGON) self.assertIsNotNone(hexagon) - self.assertIn(hexagon, self.canvas.get_drawables_by_class_name('Hexagon')) + self.assertIn(hexagon, self.canvas.get_drawables_by_class_name("Hexagon")) self.canvas.delete_point(0.0, 0.0) - self.assertNotIn(hexagon, self.canvas.get_drawables_by_class_name('Hexagon')) + self.assertNotIn(hexagon, self.canvas.get_drawables_by_class_name("Hexagon")) def test_segment_deletion_cascades_to_hexagon(self) -> None: """Deleting an edge segment of a hexagon should remove the hexagon.""" @@ -2792,7 +2895,7 @@ def test_segment_deletion_cascades_to_hexagon(self) -> None: seg = edge_segments[0] self.canvas.delete_segment(seg.point1.x, seg.point1.y, seg.point2.x, seg.point2.y) - self.assertNotIn(hexagon, self.canvas.get_drawables_by_class_name('Hexagon')) + self.assertNotIn(hexagon, self.canvas.get_drawables_by_class_name("Hexagon")) def test_point_deletion_cascades_to_quadrilateral(self) -> None: """Deleting a vertex of a quadrilateral should remove it (exercises segment1..4 attrs).""" @@ -2804,11 +2907,11 @@ def test_point_deletion_cascades_to_quadrilateral(self) -> None: ] quad = self.canvas.create_polygon(vertices, polygon_type=PolygonType.QUADRILATERAL) self.assertIsNotNone(quad) - self.assertIn(quad, self.canvas.get_drawables_by_class_name('Quadrilateral')) + self.assertIn(quad, self.canvas.get_drawables_by_class_name("Quadrilateral")) self.canvas.delete_point(0.0, 0.0) - self.assertNotIn(quad, self.canvas.get_drawables_by_class_name('Quadrilateral')) + self.assertNotIn(quad, self.canvas.get_drawables_by_class_name("Quadrilateral")) def test_segment_deletion_cascades_to_quadrilateral(self) -> None: """Deleting an edge of a quadrilateral should remove it (exercises segment1..4 attrs).""" @@ -2826,7 +2929,7 @@ def test_segment_deletion_cascades_to_quadrilateral(self) -> None: seg = edge_segments[0] self.canvas.delete_segment(seg.point1.x, seg.point1.y, seg.point2.x, seg.point2.y) - self.assertNotIn(quad, self.canvas.get_drawables_by_class_name('Quadrilateral')) + self.assertNotIn(quad, self.canvas.get_drawables_by_class_name("Quadrilateral")) class TestCanvasHelperMethods(unittest.TestCase): @@ -2892,7 +2995,7 @@ def test_angle_deletion_on_segment_deletion(self) -> None: # Verify angle was created self.assertIsNotNone(angle) - angles = self.canvas.get_drawables_by_class_name('Angle') + angles = self.canvas.get_drawables_by_class_name("Angle") self.assertEqual(len(angles), 1) self.assertIn(angle, angles) @@ -2900,6 +3003,6 @@ def test_angle_deletion_on_segment_deletion(self) -> None: self.canvas.delete_segment(0, 0, 10, 0) # Delete segment AB # Verify angle was automatically deleted - angles_after = self.canvas.get_drawables_by_class_name('Angle') + angles_after = self.canvas.get_drawables_by_class_name("Angle") self.assertEqual(len(angles_after), 0) self.assertNotIn(angle, angles_after) diff --git a/static/client/client_tests/test_cartesian.py b/static/client/client_tests/test_cartesian.py index 40bdf10c..36a3a243 100644 --- a/static/client/client_tests/test_cartesian.py +++ b/static/client/client_tests/test_cartesian.py @@ -23,7 +23,7 @@ def setUp(self) -> None: zoom_direction=0, offset=Position(0, 0), # Set to (0,0) for simpler tests zoom_point=Position(0, 0), - zoom_step=0.1 + zoom_step=0.1, ) # Sync canvas state with coordinate mapper @@ -50,7 +50,7 @@ def test_get_visible_bounds(self) -> None: self.assertEqual(left_bound, -400.0) # -self.origin.x / self.canvas.scale_factor self.assertEqual(right_bound, 400.0) # (self.width - self.origin.x) / self.canvas.scale_factor - self.assertEqual(top_bound, 300.0) # self.origin.y / self.canvas.scale_factor + self.assertEqual(top_bound, 300.0) # self.origin.y / self.canvas.scale_factor self.assertEqual(bottom_bound, -300.0) # (self.origin.y - self.height) / self.canvas.scale_factor # Test bounds with different scale factor - bounds are now dynamic via CoordinateMapper @@ -168,11 +168,11 @@ def test_state_retrieval(self) -> None: def test_get_axis_helpers(self) -> None: # Test the axis helper methods we added during refactoring - self.assertEqual(self.cartesian_system._get_axis_origin('x'), self.cartesian_system.origin.x) - self.assertEqual(self.cartesian_system._get_axis_origin('y'), self.cartesian_system.origin.y) + self.assertEqual(self.cartesian_system._get_axis_origin("x"), self.cartesian_system.origin.x) + self.assertEqual(self.cartesian_system._get_axis_origin("y"), self.cartesian_system.origin.y) - self.assertEqual(self.cartesian_system._get_axis_boundary('x'), self.cartesian_system.width) - self.assertEqual(self.cartesian_system._get_axis_boundary('y'), self.cartesian_system.height) + self.assertEqual(self.cartesian_system._get_axis_boundary("x"), self.cartesian_system.width) + self.assertEqual(self.cartesian_system._get_axis_boundary("y"), self.cartesian_system.height) def test_should_continue_drawing(self) -> None: # Test the boundary condition method for drawing @@ -356,6 +356,7 @@ def test_relative_dimensions(self) -> None: def _count_grid_lines(self, origin: float, dimension_px: float, display_tick: float) -> int: import math + if display_tick <= 0: return 0 start_n = int(math.ceil(-origin / display_tick)) @@ -369,6 +370,7 @@ def _count_grid_lines(self, origin: float, dimension_px: float, display_tick: fl def _calculate_adaptive_tick_spacing(self, width: float, scale_factor: float, max_ticks: int = 10) -> float: import math + relative_width = width / scale_factor ideal_spacing = relative_width / max_ticks if ideal_spacing <= 0: @@ -405,12 +407,16 @@ def test_grid_line_count_constant_at_various_origins(self) -> None: count_y = self._count_grid_lines(oy, height_px, display_tick) self.assertAlmostEqual( - count_x, base_count_x, delta=2, - msg=f"X grid line count {count_x} differs from base {base_count_x} at ox={ox}" + count_x, + base_count_x, + delta=2, + msg=f"X grid line count {count_x} differs from base {base_count_x} at ox={ox}", ) self.assertAlmostEqual( - count_y, base_count_y, delta=2, - msg=f"Y grid line count {count_y} differs from base {base_count_y} at oy={oy}" + count_y, + base_count_y, + delta=2, + msg=f"Y grid line count {count_y} differs from base {base_count_y} at oy={oy}", ) def test_grid_line_count_bounded_across_zoom_levels(self) -> None: @@ -428,22 +434,10 @@ def test_grid_line_count_bounded_across_zoom_levels(self) -> None: count_x = self._count_grid_lines(ox, width_px, display_tick) count_y = self._count_grid_lines(oy, height_px, display_tick) - self.assertGreaterEqual( - count_x, 4, - msg=f"Too few X grid lines ({count_x}) at scale {scale_factor}" - ) - self.assertLessEqual( - count_x, 25, - msg=f"Too many X grid lines ({count_x}) at scale {scale_factor}" - ) - self.assertGreaterEqual( - count_y, 3, - msg=f"Too few Y grid lines ({count_y}) at scale {scale_factor}" - ) - self.assertLessEqual( - count_y, 20, - msg=f"Too many Y grid lines ({count_y}) at scale {scale_factor}" - ) + self.assertGreaterEqual(count_x, 4, msg=f"Too few X grid lines ({count_x}) at scale {scale_factor}") + self.assertLessEqual(count_x, 25, msg=f"Too many X grid lines ({count_x}) at scale {scale_factor}") + self.assertGreaterEqual(count_y, 3, msg=f"Too few Y grid lines ({count_y}) at scale {scale_factor}") + self.assertLessEqual(count_y, 20, msg=f"Too many Y grid lines ({count_y}) at scale {scale_factor}") def test_grid_line_count_constant_at_extreme_distances(self) -> None: width_px = 800 @@ -461,9 +455,12 @@ def test_grid_line_count_constant_at_extreme_distances(self) -> None: base_count_y = self._count_grid_lines(base_oy, height_px, display_tick) extreme_offsets = [ - 1e6, -1e6, - 1e9, -1e9, - 1e12, -1e12, + 1e6, + -1e6, + 1e9, + -1e9, + 1e12, + -1e12, ] for offset in extreme_offsets: @@ -473,12 +470,16 @@ def test_grid_line_count_constant_at_extreme_distances(self) -> None: count_y = self._count_grid_lines(oy, height_px, display_tick) self.assertAlmostEqual( - count_x, base_count_x, delta=2, - msg=f"scale={scale_factor}, offset={offset}: X count {count_x} vs base {base_count_x}" + count_x, + base_count_x, + delta=2, + msg=f"scale={scale_factor}, offset={offset}: X count {count_x} vs base {base_count_x}", ) self.assertAlmostEqual( - count_y, base_count_y, delta=2, - msg=f"scale={scale_factor}, offset={offset}: Y count {count_y} vs base {base_count_y}" + count_y, + base_count_y, + delta=2, + msg=f"scale={scale_factor}, offset={offset}: Y count {count_y} vs base {base_count_y}", ) def test_grid_line_count_combined_zoom_and_distance(self) -> None: @@ -521,11 +522,5 @@ def test_grid_line_count_combined_zoom_and_distance(self) -> None: x_range = max(x_counts) - min(x_counts) y_range = max(y_counts) - min(y_counts) - self.assertLessEqual( - x_range, 2, - msg=f"At scale {scale_factor}, X counts vary too much: {counts}" - ) - self.assertLessEqual( - y_range, 2, - msg=f"At scale {scale_factor}, Y counts vary too much: {counts}" - ) + self.assertLessEqual(x_range, 2, msg=f"At scale {scale_factor}, X counts vary too much: {counts}") + self.assertLessEqual(y_range, 2, msg=f"At scale {scale_factor}, Y counts vary too much: {counts}") diff --git a/static/client/client_tests/test_chat_message_menu.py b/static/client/client_tests/test_chat_message_menu.py index 5be252bd..3c2f2172 100644 --- a/static/client/client_tests/test_chat_message_menu.py +++ b/static/client/client_tests/test_chat_message_menu.py @@ -102,5 +102,3 @@ def test_copy_message_text_uses_raw_source(self) -> None: pass copy_mock.assert_called_once_with(raw_text) - - diff --git a/static/client/client_tests/test_circle.py b/static/client/client_tests/test_circle.py index 66a4e2f6..3063bfb1 100644 --- a/static/client/client_tests/test_circle.py +++ b/static/client/client_tests/test_circle.py @@ -22,7 +22,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -38,7 +38,7 @@ def test_init(self) -> None: self.assertEqual(self.circle.color, "blue") def test_get_class_name(self) -> None: - self.assertEqual(self.circle.get_class_name(), 'Circle') + self.assertEqual(self.circle.get_class_name(), "Circle") def test_calculate_circle_algebraic_formula(self) -> None: formula = self.circle._calculate_circle_algebraic_formula() @@ -46,7 +46,10 @@ def test_calculate_circle_algebraic_formula(self) -> None: def test_get_state(self) -> None: state = self.circle.get_state() - expected_state = {"name": self.circle.name, "args": {"center": self.center.name, "radius": self.radius, "circle_formula": self.circle.circle_formula}} + expected_state = { + "name": self.circle.name, + "args": {"center": self.center.name, "radius": self.radius, "circle_formula": self.circle.circle_formula}, + } self.assertEqual(state, expected_state) def test_deepcopy(self) -> None: diff --git a/static/client/client_tests/test_circle_arc.py b/static/client/client_tests/test_circle_arc.py index 1fe100e2..966d4c80 100644 --- a/static/client/client_tests/test_circle_arc.py +++ b/static/client/client_tests/test_circle_arc.py @@ -96,4 +96,3 @@ def test_set_use_major_arc_switches_flag(self) -> None: self.assertTrue(arc.use_major_arc) arc.set_use_major_arc(False) self.assertFalse(arc.use_major_arc) - diff --git a/static/client/client_tests/test_circle_manager.py b/static/client/client_tests/test_circle_manager.py index 221b47de..acb61ead 100644 --- a/static/client/client_tests/test_circle_manager.py +++ b/static/client/client_tests/test_circle_manager.py @@ -86,9 +86,7 @@ def test_update_circle_requires_complete_center_pair(self) -> None: def test_update_circle_rejects_non_solitary_center(self) -> None: other_parent = object() circle = self._add_circle() - self.dependency_manager.get_parents = ( - lambda obj: {circle, other_parent} if obj is circle.center else set() - ) + self.dependency_manager.get_parents = lambda obj: {circle, other_parent} if obj is circle.center else set() self.dependency_manager.get_children = lambda obj: set() with self.assertRaises(ValueError): @@ -97,12 +95,8 @@ def test_update_circle_rejects_non_solitary_center(self) -> None: def test_update_circle_rejects_center_with_other_child(self) -> None: other_child = object() circle = self._add_circle() - self.dependency_manager.get_parents = ( - lambda obj: {circle} if obj is circle.center else set() - ) - self.dependency_manager.get_children = ( - lambda obj: {circle, other_child} if obj is circle.center else set() - ) + self.dependency_manager.get_parents = lambda obj: {circle} if obj is circle.center else set() + self.dependency_manager.get_children = lambda obj: {circle, other_child} if obj is circle.center else set() with self.assertRaises(ValueError): self.circle_manager.update_circle("CircleA", new_center_x=5.0, new_center_y=6.0) diff --git a/static/client/client_tests/test_closed_shape_colored_area.py b/static/client/client_tests/test_closed_shape_colored_area.py index 454fe65b..759f5bd2 100644 --- a/static/client/client_tests/test_closed_shape_colored_area.py +++ b/static/client/client_tests/test_closed_shape_colored_area.py @@ -173,5 +173,3 @@ def test_deepcopy_creates_independent_copy(self) -> None: def test_invalid_shape_type_raises(self) -> None: with self.assertRaises(ValueError): ClosedShapeColoredArea(shape_type="unsupported") - - diff --git a/static/client/client_tests/test_colored_area_helpers.py b/static/client/client_tests/test_colored_area_helpers.py index 7110241a..90f36053 100644 --- a/static/client/client_tests/test_colored_area_helpers.py +++ b/static/client/client_tests/test_colored_area_helpers.py @@ -173,4 +173,3 @@ def test_invalid_opacity_uses_default(self) -> None: "TestFilterValidPoints", "TestRenderColoredAreaHelper", ] - diff --git a/static/client/client_tests/test_coordinate_system_manager.py b/static/client/client_tests/test_coordinate_system_manager.py index d9108cf4..4840ecf6 100644 --- a/static/client/client_tests/test_coordinate_system_manager.py +++ b/static/client/client_tests/test_coordinate_system_manager.py @@ -14,6 +14,7 @@ def setUp(self) -> None: self.cartesian_grid = Cartesian2Axis(coordinate_mapper=self.coordinate_mapper) draw_called = [False] + def mock_draw(): draw_called[0] = True @@ -29,7 +30,7 @@ def mock_draw(): zoom_point=Position(0, 0), zoom_step=0.1, draw=mock_draw, - _draw_called=draw_called + _draw_called=draw_called, ) self.coordinate_mapper.sync_from_canvas(self.canvas) @@ -168,6 +169,7 @@ def setUp(self) -> None: self.cartesian_grid = Cartesian2Axis(coordinate_mapper=self.coordinate_mapper) draw_called = [False] + def mock_draw(): draw_called[0] = True @@ -183,7 +185,7 @@ def mock_draw(): zoom_point=Position(0, 0), zoom_step=0.1, draw=mock_draw, - _draw_called=draw_called + _draw_called=draw_called, ) self.coordinate_mapper.sync_from_canvas(self.canvas) diff --git a/static/client/client_tests/test_custom_drawable_names.py b/static/client/client_tests/test_custom_drawable_names.py index f6b82e1d..f5275a85 100644 --- a/static/client/client_tests/test_custom_drawable_names.py +++ b/static/client/client_tests/test_custom_drawable_names.py @@ -13,9 +13,12 @@ class TestCustomDrawableNames(unittest.TestCase): def setUp(self) -> None: self.canvas = Canvas(500, 500, draw_enabled=False) - self.mock_cartesian2axis = SimpleMock(draw=SimpleMock(return_value=None), reset=SimpleMock(return_value=None), - get_state=SimpleMock(return_value={'Cartesian_System_Visibility': 'cartesian_state'}), - origin=Position(0, 0)) + self.mock_cartesian2axis = SimpleMock( + draw=SimpleMock(return_value=None), + reset=SimpleMock(return_value=None), + get_state=SimpleMock(return_value={"Cartesian_System_Visibility": "cartesian_state"}), + origin=Position(0, 0), + ) self.canvas.cartesian2axis = self.mock_cartesian2axis def tearDown(self) -> None: @@ -108,7 +111,7 @@ def test_rectangle_basic_naming(self) -> None: rectangle.segment1.point1.name, rectangle.segment1.point2.name, rectangle.segment2.point2.name, - rectangle.segment3.point2.name + rectangle.segment3.point2.name, ] # Check that the points use the first four letters of "Rectangle" self.assertIn("R", points) @@ -125,7 +128,7 @@ def test_rectangle_basic_naming(self) -> None: rectangle2.segment1.point1.name, rectangle2.segment1.point2.name, rectangle2.segment2.point2.name, - rectangle2.segment3.point2.name + rectangle2.segment3.point2.name, ] # Check that the points use the next available letters self.assertIn("A", points2) @@ -144,7 +147,7 @@ def test_rectangle_apostrophe_naming(self) -> None: rectangle.segment1.point1.name, rectangle.segment1.point2.name, rectangle.segment2.point2.name, - rectangle.segment3.point2.name + rectangle.segment3.point2.name, ] # Check that the points use the letters with their apostrophes self.assertIn("W'", points) @@ -207,12 +210,12 @@ def test_vector_apostrophe_naming(self) -> None: def test_name_fallback_sequence(self) -> None: # Create 26 points to use up all letters for i in range(26): - point = self.canvas.create_point(i*10, i*10) - expected_letter = chr(ord('A') + i) # A, B, C, ... + point = self.canvas.create_point(i * 10, i * 10) + expected_letter = chr(ord("A") + i) # A, B, C, ... self.assertEqual(point.name, expected_letter) # Check we have all letters A-Z - expected_names = [letter for letter in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'] - actual_names = sorted(self.canvas.name_generator.get_drawable_names('Point')) + expected_names = [letter for letter in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"] + actual_names = sorted(self.canvas.name_generator.get_drawable_names("Point")) self.assertEqual(actual_names, expected_names) # Try to create a point with custom name - should use first letter with apostrophe @@ -220,21 +223,21 @@ def test_name_fallback_sequence(self) -> None: self.assertEqual(point.name, "C'") expected_names.append("C'") expected_names.sort() - actual_names = sorted(self.canvas.name_generator.get_drawable_names('Point')) + actual_names = sorted(self.canvas.name_generator.get_drawable_names("Point")) self.assertEqual(actual_names, expected_names) # Create 25 more points without names - should get A'-Z' (except C' which is already used) for i in range(25): # 25 because C' is already used - point = self.canvas.create_point(i*10 + 400, i*10 + 400) + point = self.canvas.create_point(i * 10 + 400, i * 10 + 400) # Check we have all letters A'-Z' - expected_names = expected_names + [letter + "'" for letter in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' if letter != 'C'] + expected_names = expected_names + [letter + "'" for letter in "ABCDEFGHIJKLMNOPQRSTUVWXYZ" if letter != "C"] expected_names.sort() - actual_names = sorted(self.canvas.name_generator.get_drawable_names('Point')) + actual_names = sorted(self.canvas.name_generator.get_drawable_names("Point")) self.assertEqual(actual_names, expected_names) # Try to create another point - should use first letter with two apostrophes point2 = self.canvas.create_point(310, 310) # No custom name self.assertEqual(point2.name, "A''") expected_names.append("A''") - self.assertEqual(sorted(self.canvas.name_generator.get_drawable_names('Point')), sorted(expected_names)) + self.assertEqual(sorted(self.canvas.name_generator.get_drawable_names("Point")), sorted(expected_names)) diff --git a/static/client/client_tests/test_decagon.py b/static/client/client_tests/test_decagon.py index 8d075df2..d9f2a69e 100644 --- a/static/client/client_tests/test_decagon.py +++ b/static/client/client_tests/test_decagon.py @@ -130,4 +130,3 @@ def test_is_irregular_helper(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_drawable_dependency_manager.py b/static/client/client_tests/test_drawable_dependency_manager.py index 9c0c69e7..2a18a391 100644 --- a/static/client/client_tests/test_drawable_dependency_manager.py +++ b/static/client/client_tests/test_drawable_dependency_manager.py @@ -20,9 +20,9 @@ def _create_mock_point(self, name: str, x: float = 0, y: float = 0) -> SimpleMoc x=x, y=y, canvas=None, - get_class_name=SimpleMock(return_value='Point'), + get_class_name=SimpleMock(return_value="Point"), __str__=SimpleMock(return_value=f"Point({name})"), - __repr__=SimpleMock(return_value=f"Point({name})") + __repr__=SimpleMock(return_value=f"Point({name})"), ) def _create_mock_segment(self, name: str, point1: SimpleMock, point2: SimpleMock) -> SimpleMock: @@ -32,9 +32,9 @@ def _create_mock_segment(self, name: str, point1: SimpleMock, point2: SimpleMock point1=point1, point2=point2, canvas=None, - get_class_name=SimpleMock(return_value='Segment'), + get_class_name=SimpleMock(return_value="Segment"), __str__=SimpleMock(return_value=f"Segment({name})"), - __repr__=SimpleMock(return_value=f"Segment({name})") + __repr__=SimpleMock(return_value=f"Segment({name})"), ) def _create_mock_drawable(self, name: str, class_name: str = "MockDrawable") -> SimpleMock: @@ -45,17 +45,13 @@ def _create_mock_drawable(self, name: str, class_name: str = "MockDrawable") -> canvas=None, get_class_name=SimpleMock(return_value=class_name), __str__=SimpleMock(return_value=f"{class_name}({name})"), - __repr__=SimpleMock(return_value=f"{class_name}({name})") + __repr__=SimpleMock(return_value=f"{class_name}({name})"), ) def setUp(self) -> None: """Set up test environment before each test""" # Create a mock drawable manager - self.mock_drawable_manager = SimpleMock( - drawables=SimpleMock( - Segments=[] - ) - ) + self.mock_drawable_manager = SimpleMock(drawables=SimpleMock(Segments=[])) self.manager = DrawableDependencyManager(drawable_manager_proxy=self.mock_drawable_manager) # Create mock drawables using private factory methods @@ -81,9 +77,7 @@ def _assert_internal_graph_invariants(self) -> None: self.assertIn(parent_id, self.manager._parents.get(child_id, set())) for drawable_id in self.manager._object_lookup: - self.assertTrue( - drawable_id in self.manager._parents or drawable_id in self.manager._children - ) + self.assertTrue(drawable_id in self.manager._parents or drawable_id in self.manager._children) def test_register_dependency(self) -> None: """Test registering dependencies between drawables""" @@ -182,8 +176,9 @@ def test_resolve_dependency_order(self) -> None: for segment in [self.segment1, self.segment2, self.segment3]: if point in self.manager.get_parents(segment): segment_index = ordered.index(segment) - self.assertLess(point_index, segment_index, - f"Point {point.name} should come before its segment {segment.name}") + self.assertLess( + point_index, segment_index, f"Point {point.name} should come before its segment {segment.name}" + ) def test_circular_dependencies(self) -> None: """Test handling of circular dependencies""" @@ -292,7 +287,9 @@ def test_analyze_drawable_for_dependencies(self) -> None: func_seg_area.func = function func_seg_area.segment = self.segment1 func_seg_area_dependencies = self.manager.analyze_drawable_for_dependencies(func_seg_area) - self.assertEqual(len(func_seg_area_dependencies), 2, "FunctionSegmentBoundedColoredArea should have 2 dependencies") + self.assertEqual( + len(func_seg_area_dependencies), 2, "FunctionSegmentBoundedColoredArea should have 2 dependencies" + ) self.assertIn(function, func_seg_area_dependencies, "function should be a dependency of area") self.assertIn(self.segment1, func_seg_area_dependencies, "segment should be a dependency of area") @@ -311,7 +308,6 @@ def test_analyze_drawable_for_dependencies(self) -> None: dependencies = self.manager.analyze_drawable_for_dependencies(obj_without_method) self.assertEqual(len(dependencies), 0, "Object without get_class_name should return empty dependencies list") - def test_drawable_types_completeness(self) -> None: """Test that analyze_drawable_for_dependencies has cases for all drawable types""" # Use _type_hierarchy as the source of truth for drawable types @@ -323,32 +319,35 @@ def test_drawable_types_completeness(self) -> None: # Use direct functional testing since inspect.getsource() in Brython # only returns method signatures, not the full method body test_cases = [ - ('Point', self._create_mock_drawable("TestPoint", "Point")), - ('Segment', self._create_mock_drawable("TestSegment", "Segment")), - ('Vector', self._create_mock_drawable("TestVector", "Vector")), - ('Triangle', self._create_mock_drawable("TestTriangle", "Triangle")), - ('Rectangle', self._create_mock_drawable("TestRectangle", "Rectangle")), - ('Quadrilateral', self._create_mock_drawable("TestQuadrilateral", "Quadrilateral")), - ('Pentagon', self._create_mock_drawable("TestPentagon", "Pentagon")), - ('Hexagon', self._create_mock_drawable("TestHexagon", "Hexagon")), - ('Heptagon', self._create_mock_drawable("TestHeptagon", "Heptagon")), - ('Octagon', self._create_mock_drawable("TestOctagon", "Octagon")), - ('Nonagon', self._create_mock_drawable("TestNonagon", "Nonagon")), - ('Decagon', self._create_mock_drawable("TestDecagon", "Decagon")), - ('GenericPolygon', self._create_mock_drawable("TestGenericPolygon", "GenericPolygon")), - ('Circle', self._create_mock_drawable("TestCircle", "Circle")), - ('CircleArc', self._create_mock_drawable("TestCircleArc", "CircleArc")), - ('Ellipse', self._create_mock_drawable("TestEllipse", "Ellipse")), - ('Function', self._create_mock_drawable("TestFunction", "Function")), - ('Angle', self._create_mock_drawable("TestAngle", "Angle")), - ('ColoredArea', self._create_mock_drawable("TestColoredArea", "ColoredArea")), - ('SegmentsBoundedColoredArea', self._create_mock_drawable("TestSBCA", "SegmentsBoundedColoredArea")), - ('FunctionSegmentBoundedColoredArea', self._create_mock_drawable("TestFSBCA", "FunctionSegmentBoundedColoredArea")), - ('FunctionsBoundedColoredArea', self._create_mock_drawable("TestFBCA", "FunctionsBoundedColoredArea")), - ('Graph', self._create_mock_drawable("TestGraph", "Graph")), - ('DirectedGraph', self._create_mock_drawable("TestDirectedGraph", "DirectedGraph")), - ('UndirectedGraph', self._create_mock_drawable("TestUndirectedGraph", "UndirectedGraph")), - ('Tree', self._create_mock_drawable("TestTree", "Tree")), + ("Point", self._create_mock_drawable("TestPoint", "Point")), + ("Segment", self._create_mock_drawable("TestSegment", "Segment")), + ("Vector", self._create_mock_drawable("TestVector", "Vector")), + ("Triangle", self._create_mock_drawable("TestTriangle", "Triangle")), + ("Rectangle", self._create_mock_drawable("TestRectangle", "Rectangle")), + ("Quadrilateral", self._create_mock_drawable("TestQuadrilateral", "Quadrilateral")), + ("Pentagon", self._create_mock_drawable("TestPentagon", "Pentagon")), + ("Hexagon", self._create_mock_drawable("TestHexagon", "Hexagon")), + ("Heptagon", self._create_mock_drawable("TestHeptagon", "Heptagon")), + ("Octagon", self._create_mock_drawable("TestOctagon", "Octagon")), + ("Nonagon", self._create_mock_drawable("TestNonagon", "Nonagon")), + ("Decagon", self._create_mock_drawable("TestDecagon", "Decagon")), + ("GenericPolygon", self._create_mock_drawable("TestGenericPolygon", "GenericPolygon")), + ("Circle", self._create_mock_drawable("TestCircle", "Circle")), + ("CircleArc", self._create_mock_drawable("TestCircleArc", "CircleArc")), + ("Ellipse", self._create_mock_drawable("TestEllipse", "Ellipse")), + ("Function", self._create_mock_drawable("TestFunction", "Function")), + ("Angle", self._create_mock_drawable("TestAngle", "Angle")), + ("ColoredArea", self._create_mock_drawable("TestColoredArea", "ColoredArea")), + ("SegmentsBoundedColoredArea", self._create_mock_drawable("TestSBCA", "SegmentsBoundedColoredArea")), + ( + "FunctionSegmentBoundedColoredArea", + self._create_mock_drawable("TestFSBCA", "FunctionSegmentBoundedColoredArea"), + ), + ("FunctionsBoundedColoredArea", self._create_mock_drawable("TestFBCA", "FunctionsBoundedColoredArea")), + ("Graph", self._create_mock_drawable("TestGraph", "Graph")), + ("DirectedGraph", self._create_mock_drawable("TestDirectedGraph", "DirectedGraph")), + ("UndirectedGraph", self._create_mock_drawable("TestUndirectedGraph", "UndirectedGraph")), + ("Tree", self._create_mock_drawable("TestTree", "Tree")), ] # Test each drawable type by calling the method and checking it doesn't raise an exception @@ -364,7 +363,7 @@ def test_drawable_types_completeness(self) -> None: # Also check for any ColoredArea types that might be handled by the endswith logic for class_name in drawable_types: - if class_name.endswith('ColoredArea') and class_name not in handled_classes: + if class_name.endswith("ColoredArea") and class_name not in handled_classes: # Test this ColoredArea type try: test_obj = self._create_mock_drawable(f"Test{class_name}", class_name) @@ -374,8 +373,6 @@ def test_drawable_types_completeness(self) -> None: except Exception as test_e: print(f"DEBUG: {class_name} (ColoredArea type) failed: {test_e}") - - # Check for missing implementations missing_implementations = drawable_types - handled_classes @@ -385,8 +382,11 @@ def test_drawable_types_completeness(self) -> None: print(f"Missing implementations: {sorted(missing_implementations)}") # Assert that all types are handled - self.assertEqual(len(missing_implementations), 0, - f"Missing analyze_drawable_for_dependencies cases for: {', '.join(missing_implementations)}") + self.assertEqual( + len(missing_implementations), + 0, + f"Missing analyze_drawable_for_dependencies cases for: {', '.join(missing_implementations)}", + ) def test_error_handling_none_values(self) -> None: """Test handling of None values in various methods""" @@ -401,7 +401,9 @@ def test_error_handling_none_values(self) -> None: # Test get_all_parents and get_all_children with None self.assertEqual(len(self.manager.get_all_parents(None)), 0, "get_all_parents should return empty set for None") - self.assertEqual(len(self.manager.get_all_children(None)), 0, "get_all_children should return empty set for None") + self.assertEqual( + len(self.manager.get_all_children(None)), 0, "get_all_children should return empty set for None" + ) # Test remove_drawable with None try: @@ -432,10 +434,16 @@ def test_edge_cases(self) -> None: self.fail(f"remove_drawable failed with non-existent drawable: {e}") # Test get_parents and get_children for non-existent drawable - self.assertEqual(len(self.manager.get_parents(non_existent)), 0, - "get_parents should return empty set for non-existent drawable") - self.assertEqual(len(self.manager.get_children(non_existent)), 0, - "get_children should return empty set for non-existent drawable") + self.assertEqual( + len(self.manager.get_parents(non_existent)), + 0, + "get_parents should return empty set for non-existent drawable", + ) + self.assertEqual( + len(self.manager.get_children(non_existent)), + 0, + "get_children should return empty set for non-existent drawable", + ) def test_verify_get_class_name_method(self) -> None: """Test verification of get_class_name method""" @@ -448,6 +456,7 @@ def test_verify_get_class_name_method(self) -> None: class NoMethodDrawable: def __init__(self) -> None: self.name = "NoMethod" + no_method = NoMethodDrawable() self.manager._verify_get_class_name_method(no_method, "Test") # Should log warning but not raise error @@ -461,6 +470,7 @@ class BadMethodDrawable: def __init__(self) -> None: self.name = "BadMethod" self.get_class_name = "not a method" + bad_method = BadMethodDrawable() self.manager._verify_get_class_name_method(bad_method, "Test") # Should log warning but not raise error @@ -569,10 +579,8 @@ def test_debug_logging_mode_preserves_dependency_behavior(self) -> None: def test_get_parents_and_children(self) -> None: """Test getting direct parents and children""" # Test empty sets - self.assertEqual(len(self.manager.get_parents(self.segment1)), 0, - "New segment should have no parents") - self.assertEqual(len(self.manager.get_children(self.point1)), 0, - "New point should have no children") + self.assertEqual(len(self.manager.get_parents(self.segment1)), 0, "New segment should have no parents") + self.assertEqual(len(self.manager.get_children(self.point1)), 0, "New point should have no children") # Set up multiple dependencies self.manager.register_dependency(child=self.segment1, parent=self.point1) @@ -592,10 +600,8 @@ def test_get_parents_and_children(self) -> None: self.assertIn(self.segment2, children, "Segment2 should be a child of Point1") # Test with None values - self.assertEqual(len(self.manager.get_parents(None)), 0, - "None should return empty parents set") - self.assertEqual(len(self.manager.get_children(None)), 0, - "None should return empty children set") + self.assertEqual(len(self.manager.get_parents(None)), 0, "None should return empty parents set") + self.assertEqual(len(self.manager.get_children(None)), 0, "None should return empty children set") def test_should_skip_point_point_dependency(self) -> None: """Test the point-point dependency skip logic""" @@ -730,6 +736,7 @@ def test_analyze_drawable_for_dependencies_comprehensive(self) -> None: class BadDrawable: def __init__(self) -> None: self.name = "Bad" + bad = BadDrawable() dependencies = self.manager.analyze_drawable_for_dependencies(bad) self.assertEqual(len(dependencies), 0, "Drawable without get_class_name should have no dependencies") @@ -824,7 +831,9 @@ def test_analyze_pentagon_dependencies(self) -> None: def test_analyze_generic_polygon_dependencies(self) -> None: """GenericPolygon dependency analysis uses _segments iterable.""" segments = [ - self._create_mock_segment(f"GS{i}", self._create_mock_point(f"GP{i}"), self._create_mock_point(f"GP{i+1}")) + self._create_mock_segment( + f"GS{i}", self._create_mock_point(f"GP{i}"), self._create_mock_point(f"GP{i + 1}") + ) for i in range(7) ] generic = self._create_mock_drawable("GP1", "GenericPolygon") @@ -849,8 +858,16 @@ def test_analyze_vector_registers_segment_dependency(self) -> None: def test_type_hierarchy_includes_all_polygon_types(self) -> None: """All polygon class names must appear in _type_hierarchy.""" required = { - "Triangle", "Rectangle", "Quadrilateral", "Pentagon", "Hexagon", - "Heptagon", "Octagon", "Nonagon", "Decagon", "GenericPolygon", + "Triangle", + "Rectangle", + "Quadrilateral", + "Pentagon", + "Hexagon", + "Heptagon", + "Octagon", + "Nonagon", + "Decagon", + "GenericPolygon", } hierarchy_keys = set(self.manager._type_hierarchy.keys()) missing = required - hierarchy_keys diff --git a/static/client/client_tests/test_drawable_name_generator.py b/static/client/client_tests/test_drawable_name_generator.py index a0b1b8f1..55cad40f 100644 --- a/static/client/client_tests/test_drawable_name_generator.py +++ b/static/client/client_tests/test_drawable_name_generator.py @@ -17,12 +17,10 @@ def setUp(self) -> None: def test_get_drawable_names(self) -> None: # Here, get_drawables_by_class_name is expected to be a callable that returns a list of mocks when called - set_drawables = SimpleMock( - return_value=[SimpleMock(name='Point1'), SimpleMock(name='Point2')] - ) + set_drawables = SimpleMock(return_value=[SimpleMock(name="Point1"), SimpleMock(name="Point2")]) setattr(self.canvas, "get_drawables_by_class_name", set_drawables) - result = self.generator.get_drawable_names('Point') - self.assertEqual(result, ['Point1', 'Point2']) + result = self.generator.get_drawable_names("Point") + self.assertEqual(result, ["Point1", "Point2"]) def test_filter_string(self) -> None: # Test with a string that contains letters, digits, apostrophes, and parentheses @@ -48,13 +46,13 @@ def test_print_names(self) -> None: # Define predictable returns for each class name mock_returns = { - 'Point': [SimpleMock(name='Point1'), SimpleMock(name='Point2')], - 'Segment': [SimpleMock(name='Segment1'), SimpleMock(name='Segment2')], - 'Triangle': [SimpleMock(name='Triangle1'), SimpleMock(name='Triangle2')], - 'Rectangle': [SimpleMock(name='Rectangle1'), SimpleMock(name='Rectangle2')], - 'Circle': [SimpleMock(name='Circle1'), SimpleMock(name='Circle2')], - 'Ellipse': [SimpleMock(name='Ellipse1'), SimpleMock(name='Ellipse2')], - 'Function': [SimpleMock(name='Function1'), SimpleMock(name='Function2')] + "Point": [SimpleMock(name="Point1"), SimpleMock(name="Point2")], + "Segment": [SimpleMock(name="Segment1"), SimpleMock(name="Segment2")], + "Triangle": [SimpleMock(name="Triangle1"), SimpleMock(name="Triangle2")], + "Rectangle": [SimpleMock(name="Rectangle1"), SimpleMock(name="Rectangle2")], + "Circle": [SimpleMock(name="Circle1"), SimpleMock(name="Circle2")], + "Ellipse": [SimpleMock(name="Ellipse1"), SimpleMock(name="Ellipse2")], + "Function": [SimpleMock(name="Function1"), SimpleMock(name="Function2")], } # Set up the canvas.get_drawables_by_class_name to return the appropriate mock objects @@ -71,7 +69,7 @@ def mock_get_drawables(class_name: str) -> List[SimpleMock]: printed_lines = [] def mock_print(*args: object, **kwargs: object) -> None: - line = ' '.join(str(arg) for arg in args) + line = " ".join(str(arg) for arg in args) printed_lines.append(line) # Replace print with our mock @@ -92,14 +90,17 @@ def mock_print(*args: object, **kwargs: object) -> None: "Rectangle names: ['Rectangle1', 'Rectangle2']", "Circle names: ['Circle1', 'Circle2']", "Ellipse names: ['Ellipse1', 'Ellipse2']", - "Function names: ['Function1', 'Function2']" + "Function names: ['Function1', 'Function2']", ] # Compare line by line for easier debugging - self.assertEqual(len(printed_lines), len(expected_lines), - f"Expected {len(expected_lines)} lines but got {len(printed_lines)}") + self.assertEqual( + len(printed_lines), + len(expected_lines), + f"Expected {len(expected_lines)} lines but got {len(printed_lines)}", + ) for i, (actual, expected) in enumerate(zip(printed_lines, expected_lines)): - self.assertEqual(actual, expected, f"Line {i+1} doesn't match: {actual} != {expected}") + self.assertEqual(actual, expected, f"Line {i + 1} doesn't match: {actual} != {expected}") def test_split_point_names_basic(self) -> None: result = self.generator.split_point_names("A'B'CD", 4) @@ -149,28 +150,28 @@ def test_generate_point_name(self) -> None: setattr( self.canvas, "get_drawables_by_class_name", - SimpleMock(return_value=[SimpleMock(name='A')]), + SimpleMock(return_value=[SimpleMock(name="A")]), ) result = self.generator.generate_point_name(None) - self.assertEqual(result, 'B') + self.assertEqual(result, "B") def test_generate_point_name_with_preferred_name(self) -> None: setattr( self.canvas, "get_drawables_by_class_name", - SimpleMock(return_value=[SimpleMock(name='A')]), + SimpleMock(return_value=[SimpleMock(name="A")]), ) - result = self.generator.generate_point_name('B') - self.assertEqual(result, 'B') + result = self.generator.generate_point_name("B") + self.assertEqual(result, "B") def test_generate_point_name_with_used_preferred_name(self) -> None: setattr( self.canvas, "get_drawables_by_class_name", - SimpleMock(return_value=[SimpleMock(name='A'), SimpleMock(name='B')]), + SimpleMock(return_value=[SimpleMock(name="A"), SimpleMock(name="B")]), ) - result = self.generator.generate_point_name('B') - self.assertNotEqual(result, 'B') + result = self.generator.generate_point_name("B") + self.assertNotEqual(result, "B") def test_generate_point_name_with_complex_preferred_name(self) -> None: # When we pass "AB'C" as preferred_name, and 'A' is already used, @@ -179,7 +180,7 @@ def test_generate_point_name_with_complex_preferred_name(self) -> None: setattr( self.canvas, "get_drawables_by_class_name", - SimpleMock(return_value=[SimpleMock(name='A'), SimpleMock(name="B'")]), + SimpleMock(return_value=[SimpleMock(name="A"), SimpleMock(name="B'")]), ) # Reset the dictionary for a clean test @@ -192,20 +193,20 @@ def test_generate_point_name_with_complex_preferred_name(self) -> None: def test_increment_function_name(self) -> None: # Test with a function name that ends with a number - result = self.generator._increment_function_name('f4') - self.assertEqual(result, 'f5') + result = self.generator._increment_function_name("f4") + self.assertEqual(result, "f5") # Test with a function name that does not end with a number - result = self.generator._increment_function_name('f') - self.assertEqual(result, 'f1') + result = self.generator._increment_function_name("f") + self.assertEqual(result, "f1") # Test with a function name that ends with a large number - result = self.generator._increment_function_name('f99') - self.assertEqual(result, 'f100') + result = self.generator._increment_function_name("f99") + self.assertEqual(result, "f100") # Test with a function name that ends with a number and has other numbers in it - result = self.generator._increment_function_name('f4f4') - self.assertEqual(result, 'f4f5') + result = self.generator._increment_function_name("f4f4") + self.assertEqual(result, "f4f5") # Test with a function name that does not end with a number and has other numbers in it - result = self.generator._increment_function_name('f4f') - self.assertEqual(result, 'f4f1') + result = self.generator._increment_function_name("f4f") + self.assertEqual(result, "f4f1") def test_generate_unique_function_name(self) -> None: self.canvas.get_drawables_by_class_name = SimpleMock(return_value=[]) @@ -220,46 +221,40 @@ def test_generate_unique_function_name(self) -> None: self.assertEqual(result, "v1") def test_generate_function_name(self) -> None: - self.canvas.get_drawables_by_class_name = SimpleMock( - return_value=[SimpleMock(name='f'), SimpleMock(name='f1')] - ) + self.canvas.get_drawables_by_class_name = SimpleMock(return_value=[SimpleMock(name="f"), SimpleMock(name="f1")]) result = self.generator.generate_function_name(None) - self.assertEqual(result, 'g') + self.assertEqual(result, "g") def test_generate_function_name_with_preferred_name(self) -> None: - self.canvas.get_drawables_by_class_name = SimpleMock( - return_value=[SimpleMock(name='f1')] - ) - result = self.generator.generate_function_name('f2') - self.assertEqual(result, 'f2') + self.canvas.get_drawables_by_class_name = SimpleMock(return_value=[SimpleMock(name="f1")]) + result = self.generator.generate_function_name("f2") + self.assertEqual(result, "f2") def test_generate_function_name_with_used_preferred_name(self) -> None: self.canvas.get_drawables_by_class_name = SimpleMock( - return_value=[SimpleMock(name='f1'), SimpleMock(name='f2')] + return_value=[SimpleMock(name="f1"), SimpleMock(name="f2")] ) - result = self.generator.generate_function_name('f2') - self.assertEqual(result, 'f3') + result = self.generator.generate_function_name("f2") + self.assertEqual(result, "f3") def test_generate_function_name_with_preferred_name_and_parentheses(self) -> None: - self.canvas.get_drawables_by_class_name = SimpleMock( - return_value=[SimpleMock(name='f1')] - ) - result = self.generator.generate_function_name('f2(x)') - self.assertEqual(result, 'f2') + self.canvas.get_drawables_by_class_name = SimpleMock(return_value=[SimpleMock(name="f1")]) + result = self.generator.generate_function_name("f2(x)") + self.assertEqual(result, "f2") def test_generate_function_name_with_used_preferred_name_and_parentheses(self) -> None: self.canvas.get_drawables_by_class_name = SimpleMock( - return_value=[SimpleMock(name='f1'), SimpleMock(name='f2')] + return_value=[SimpleMock(name="f1"), SimpleMock(name="f2")] ) - result = self.generator.generate_function_name('f2(x)') - self.assertEqual(result, 'f3') + result = self.generator.generate_function_name("f2(x)") + self.assertEqual(result, "f3") def test_generate_function_name_with_complex_expression(self) -> None: self.canvas.get_drawables_by_class_name = SimpleMock( - return_value=[SimpleMock(name='f1'), SimpleMock(name='f2')] + return_value=[SimpleMock(name="f1"), SimpleMock(name="f2")] ) - result = self.generator.generate_function_name('g(x) = sin(x)') - self.assertEqual(result, 'g') + result = self.generator.generate_function_name("g(x) = sin(x)") + self.assertEqual(result, "g") def test_generate_angle_name_from_segments_valid(self) -> None: # Standard case: AB, AC -> angle_BAC (assuming B, C sorted) @@ -306,7 +301,7 @@ def test_generate_angle_name_invalid_no_common_vertex(self) -> None: def test_generate_angle_name_collinear_same_segment_implicitly(self) -> None: # e.g. AB, BA - this implies 2 points, not 3 unique points for an angle name = self.generator.generate_angle_name_from_segments("AB", "BA") - self.assertIsNone(name) # Should result in 2 unique points, not 3 + self.assertIsNone(name) # Should result in 2 unique points, not 3 def test_generate_angle_name_identical_segments(self) -> None: result = self.generator.generate_angle_name_from_segments("AB", "AB") @@ -319,17 +314,17 @@ def test_generate_angle_name_malformed_segment_names(self) -> None: # Case 1: First segment name is a single letter "A". # split_point_names("A") -> ["A", ""], which is invalid. - result = self.generator.generate_angle_name_from_segments("A", "BC") # BC is valid + result = self.generator.generate_angle_name_from_segments("A", "BC") # BC is valid self.assertIsNone(result, "Should be None for segment 'A'.") # Case 2: Second segment name is a single letter "A". - result = self.generator.generate_angle_name_from_segments("BC", "A") # BC is valid - self.assertIsNone(result, "Should be None for segment 'A' (second arg)." ) + result = self.generator.generate_angle_name_from_segments("BC", "A") # BC is valid + self.assertIsNone(result, "Should be None for segment 'A' (second arg).") # Case 3: Segment name is a single letter with an apostrophe "A'". # split_point_names("A'") -> ["A'", ""], which is invalid. result = self.generator.generate_angle_name_from_segments("A'", "BC") - self.assertIsNone(result, "Should be None for segment 'A\''.") + self.assertIsNone(result, "Should be None for segment 'A''.") def test_generate_angle_name_empty_or_none_segment_names(self) -> None: name = self.generator.generate_angle_name_from_segments("", "AB") diff --git a/static/client/client_tests/test_drawable_renderers.py b/static/client/client_tests/test_drawable_renderers.py index 2b01f230..eacc4071 100644 --- a/static/client/client_tests/test_drawable_renderers.py +++ b/static/client/client_tests/test_drawable_renderers.py @@ -51,8 +51,12 @@ def fill_polygon(self, points, fill, stroke=None, **kwargs): def fill_joined_area(self, forward, reverse, fill): self._record("fill_joined_area", forward, reverse, fill) - def stroke_arc(self, center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class=None, **kwargs): - self._record("stroke_arc", center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class, **kwargs) + def stroke_arc( + self, center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class=None, **kwargs + ): + self._record( + "stroke_arc", center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class, **kwargs + ) def draw_text(self, text, position, font, color, alignment, style_overrides=None, **kwargs): self._record("draw_text", text, position, font, color, alignment, style_overrides, **kwargs) @@ -222,6 +226,8 @@ def test_circle_arc_major_sweep_goes_clockwise(self) -> None: self.assertTrue(sweep_clockwise) delta = abs(end_angle - start_angle) self.assertTrue(math.isclose(delta, 3 * math.pi / 2, rel_tol=1e-6)) + + class TestEllipseRenderer(unittest.TestCase): def setUp(self) -> None: self.mapper = CoordinateMapper(640, 480) @@ -308,6 +314,7 @@ def test_label_respects_font_size(self) -> None: def test_label_zoom_adjusted_font_size(self) -> None: from drawables.position import Position + self.mapper.apply_zoom(0.5, Position(320, 240)) label = Label(2, 3, "Zoom Test", font_size=16, reference_scale_factor=1.0) @@ -458,6 +465,7 @@ def test_segment_with_same_endpoints_does_not_crash(self) -> None: def test_zero_radius_circle_does_not_crash(self) -> None: from drawables.circle import Circle + center = Point(0, 0, name="O") circle = Circle(center, radius=0) @@ -493,4 +501,3 @@ def test_label_with_empty_text_does_not_crash(self) -> None: "TestSegmentLabelRenderer", "TestRendererEdgeCases", ] - diff --git a/static/client/client_tests/test_drawables_container.py b/static/client/client_tests/test_drawables_container.py index b7e6abe0..2513b246 100644 --- a/static/client/client_tests/test_drawables_container.py +++ b/static/client/client_tests/test_drawables_container.py @@ -10,12 +10,18 @@ def setUp(self) -> None: """Set up test fixtures before each test.""" self.container = DrawablesContainer() # Create mock drawables of different types - self.point = SimpleMock(get_class_name=SimpleMock(return_value="Point"), - get_state=SimpleMock(return_value={"name": "A", "coords": [1, 2]})) - self.segment = SimpleMock(get_class_name=SimpleMock(return_value="Segment"), - get_state=SimpleMock(return_value={"name": "AB", "points": ["A", "B"]})) - self.circle = SimpleMock(get_class_name=SimpleMock(return_value="Circle"), - get_state=SimpleMock(return_value={"name": "c1", "center": "A", "radius": 5})) + self.point = SimpleMock( + get_class_name=SimpleMock(return_value="Point"), + get_state=SimpleMock(return_value={"name": "A", "coords": [1, 2]}), + ) + self.segment = SimpleMock( + get_class_name=SimpleMock(return_value="Segment"), + get_state=SimpleMock(return_value={"name": "AB", "points": ["A", "B"]}), + ) + self.circle = SimpleMock( + get_class_name=SimpleMock(return_value="Circle"), + get_state=SimpleMock(return_value={"name": "c1", "center": "A", "radius": 5}), + ) def test_init(self) -> None: """Test initialization of the container.""" @@ -31,8 +37,10 @@ def test_add(self) -> None: self.assertEqual(len(self.container.get_all()), 2, "Container should have 2 drawables") # Test adding multiple drawables of the same type - point2 = SimpleMock(get_class_name=SimpleMock(return_value="Point"), - get_state=SimpleMock(return_value={"name": "B", "coords": [3, 4]})) + point2 = SimpleMock( + get_class_name=SimpleMock(return_value="Point"), + get_state=SimpleMock(return_value={"name": "B", "coords": [3, 4]}), + ) self.container.add(point2) self.assertEqual(len(self.container.get_all()), 3, "Container should have 3 drawables") self.assertEqual(len(self.container.get_by_class_name("Point")), 2, "Container should have 2 Points") diff --git a/static/client/client_tests/test_ellipse.py b/static/client/client_tests/test_ellipse.py index 2998a7d1..05d2bb5f 100644 --- a/static/client/client_tests/test_ellipse.py +++ b/static/client/client_tests/test_ellipse.py @@ -22,7 +22,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -32,7 +32,9 @@ def setUp(self) -> None: self.radius_x = 5 self.radius_y = 3 self.rotation_angle = 45 - self.ellipse = Ellipse(self.center, self.radius_x, self.radius_y, color="red", rotation_angle=self.rotation_angle) + self.ellipse = Ellipse( + self.center, self.radius_x, self.radius_y, color="red", rotation_angle=self.rotation_angle + ) def test_init(self) -> None: self.assertEqual(self.ellipse.center, self.center) @@ -42,7 +44,7 @@ def test_init(self) -> None: self.assertEqual(self.ellipse.color, "red") def test_get_class_name(self) -> None: - self.assertEqual(self.ellipse.get_class_name(), 'Ellipse') + self.assertEqual(self.ellipse.get_class_name(), "Ellipse") def test_calculate_ellipse_algebraic_formula(self) -> None: formula = self.ellipse._calculate_ellipse_algebraic_formula() @@ -57,8 +59,8 @@ def test_get_state(self) -> None: "radius_x": self.radius_x, "radius_y": self.radius_y, "rotation_angle": self.rotation_angle, - "ellipse_formula": self.ellipse.ellipse_formula - } + "ellipse_formula": self.ellipse.ellipse_formula, + }, } self.assertEqual(state, expected_state) @@ -97,4 +99,3 @@ def test_translate_ellipse_in_math_space(self) -> None: x, y = self.coordinate_mapper.math_to_screen(self.ellipse.center.x, self.ellipse.center.y) self.assertEqual((x, y), (255, 247)) self.assertNotEqual(original_formula, self.ellipse.ellipse_formula) - diff --git a/static/client/client_tests/test_ellipse_manager.py b/static/client/client_tests/test_ellipse_manager.py index 0ce282b2..1a9db169 100644 --- a/static/client/client_tests/test_ellipse_manager.py +++ b/static/client/client_tests/test_ellipse_manager.py @@ -106,9 +106,7 @@ def test_update_ellipse_rejects_center_with_other_parent(self) -> None: ellipse = self._add_ellipse() other_parent = object() - self.dependency_manager.get_parents = ( - lambda obj: {ellipse, other_parent} if obj is ellipse.center else set() - ) + self.dependency_manager.get_parents = lambda obj: {ellipse, other_parent} if obj is ellipse.center else set() self.dependency_manager.get_children = lambda obj: set() with self.assertRaises(ValueError): @@ -118,9 +116,7 @@ def test_update_ellipse_rejects_when_not_solitary(self) -> None: ellipse = self._add_ellipse() other_parent = object() - self.dependency_manager.get_parents = ( - lambda obj: {other_parent} if obj is ellipse else set() - ) + self.dependency_manager.get_parents = lambda obj: {other_parent} if obj is ellipse else set() self.dependency_manager.get_children = lambda obj: set() with self.assertRaises(ValueError): diff --git a/static/client/client_tests/test_error_recovery.py b/static/client/client_tests/test_error_recovery.py index e1a11845..d12f0138 100644 --- a/static/client/client_tests/test_error_recovery.py +++ b/static/client/client_tests/test_error_recovery.py @@ -3,6 +3,7 @@ When an AI request fails (e.g., TEST_ERROR_TRIGGER_12345), the user's message should be restored to the input field so they can edit and retry. """ + from __future__ import annotations import unittest @@ -17,6 +18,7 @@ class TestErrorRecovery(unittest.TestCase): def _create_ai_interface(self) -> Any: """Create an AIInterface instance without running __init__.""" from ai_interface import AIInterface + ai = AIInterface.__new__(AIInterface) # Initialize minimal state needed for error recovery ai._last_user_message = "" @@ -172,5 +174,3 @@ def mock_restore() -> None: self.assertTrue(restore_called[0]) # Verify buffer was NOT cleared (restore should happen before any clearing) self.assertEqual(ai._last_user_message, original_message) - - diff --git a/static/client/client_tests/test_event_handler.py b/static/client/client_tests/test_event_handler.py index 5846aa8c..616cf1d7 100644 --- a/static/client/client_tests/test_event_handler.py +++ b/static/client/client_tests/test_event_handler.py @@ -49,11 +49,12 @@ def setUp(self) -> None: "math-svg": self.mock_svg_element, "chat-input": self.mock_chat_input, "send-button": self.mock_send_button, - "new-conversation-button": self.mock_new_conversation_button + "new-conversation-button": self.mock_new_conversation_button, } # Replace the actual document with our mock import canvas_event_handler + self.original_document = canvas_event_handler.document canvas_event_handler.document = self.mock_document @@ -63,6 +64,7 @@ def setUp(self) -> None: def tearDown(self) -> None: """Restore original document.""" import canvas_event_handler + canvas_event_handler.document = self.original_document def _create_mock_touch(self, client_x: float, client_y: float) -> SimpleMock: diff --git a/static/client/client_tests/test_expression_validator.py b/static/client/client_tests/test_expression_validator.py index 7516c7d9..5d110596 100644 --- a/static/client/client_tests/test_expression_validator.py +++ b/static/client/client_tests/test_expression_validator.py @@ -43,7 +43,7 @@ def test_validate_expression_tree_valid(self) -> None: "stdev([1, 2, 3, x])", "variance([1, 2, 3, x])", "random()", - "randint(1, 10)" + "randint(1, 10)", ] for expr in valid_expressions: with self.subTest(expr=expr): @@ -54,7 +54,10 @@ def test_validate_expression_tree_valid(self) -> None: def test_validate_expression_tree_invalid(self) -> None: invalid_expressions = [ - "", " ", "\t", "\n", # Empty or whitespace-only strings + "", + " ", + "\t", + "\n", # Empty or whitespace-only strings "import os", # Import statements "from os import system", # From...import statements "__import__('os')", # Disallowed function call @@ -75,7 +78,7 @@ def test_validate_expression_tree_invalid(self) -> None: "sinhh(x)", "cossh(x)", "tanhh(x)", - "expp(1)" + "expp(1)", ] for expr in invalid_expressions: with self.subTest(expr=expr): @@ -87,7 +90,7 @@ def test_evaluate_expression(self) -> None: expressions = { "sin(pi/2)": math.sin(math.pi / 2), "sqrt(16)": 4, - "sqrt(-4)": '2i', + "sqrt(-4)": "2i", "log(e)": 1.0, "cos(pi)": math.cos(math.pi), "tan(pi/4)": math.tan(math.pi / 4), @@ -106,18 +109,18 @@ def test_evaluate_expression(self) -> None: "exp(1)": math.e, "abs(-5)": 5.0, "pow(2, 3)": 8, - "bin(10)": '0b1010', + "bin(10)": "0b1010", "det([[1, 2], [3, 4]])": -2.0, "arrangements(6, 3)": math.perm(6, 3), "permutations(5, 2)": math.perm(5, 2), "permutations(5)": math.factorial(5), "combinations(6, 3)": math.comb(6, 3), - "x": x + "x": x, } for expr, expected in expressions.items(): with self.subTest(expr=expr): evaluation_expr = expr - if '!' in expr: + if "!" in expr: evaluation_expr = ExpressionValidator.fix_math_expression(expr, python_compatible=True) result = ExpressionValidator.evaluate_expression(evaluation_expr, x=x) if isinstance(expected, float): @@ -136,8 +139,10 @@ def test_degree_to_radian_conversion(self) -> None: for expr, expected in expressions_and_expected.items(): with self.subTest(expr=expr): fixed_expr = ExpressionValidator.fix_math_expression(expr, python_compatible=False) - self.assertAlmostEqual(eval(fixed_expr, {"sin": math.sin, "cos": math.cos, "tan": math.tan, "pi": math.pi}), \ - eval(expected, {"sin": math.sin, "cos": math.cos, "tan": math.tan, "pi": math.pi})) + self.assertAlmostEqual( + eval(fixed_expr, {"sin": math.sin, "cos": math.cos, "tan": math.tan, "pi": math.pi}), + eval(expected, {"sin": math.sin, "cos": math.cos, "tan": math.tan, "pi": math.pi}), + ) def test_fix_math_expression_python_compatibility(self) -> None: expressions_and_fixes = { @@ -189,7 +194,7 @@ def test_parse_function_string_returns_number(self) -> None: "exp(1)", "abs(-pi)", "pow(2, 3)", - "det([[1, 2], [3, 4]])" + "det([[1, 2], [3, 4]])", ] for element in function_elements: with self.subTest(element=element): @@ -197,7 +202,9 @@ def test_parse_function_string_returns_number(self) -> None: # Test with a range of values for x in range(-10, 11): result = f(x) - self.assertIsInstance(result, (int, float), f"Result of expression '{element}' for x={x} is not a number: {result}") + self.assertIsInstance( + result, (int, float), f"Result of expression '{element}' for x={x} is not a number: {result}" + ) def test_parse_function_string(self) -> None: expressions: Dict[str, Callable[[float], float]] = { @@ -208,7 +215,9 @@ def test_parse_function_string(self) -> None: "sqrt(x)": math.sqrt, "exp(x)": math.exp, "tan(x)": math.tan, - "sin(pi/4) + cos(π/3) - tan(sqrt(16)) * log(e) / log10(100) + log2(8) * factorial(3) + asin(0.5) - acos(0.5) + atan(1) + sinh(1) - cosh(1) + tanh(0) + exp(1) - abs(-pi) + pow(2, 3)^2": lambda x: 82.09880526150872 + "sin(pi/4) + cos(π/3) - tan(sqrt(16)) * log(e) / log10(100) + log2(8) * factorial(3) + asin(0.5) - acos(0.5) + atan(1) + sinh(1) - cosh(1) + tanh(0) + exp(1) - abs(-pi) + pow(2, 3)^2": lambda x: ( + 82.09880526150872 + ), } for use_mathjs in [False, True]: for expr, expected_func in expressions.items(): @@ -227,12 +236,13 @@ def test_fix_math_expression_factorials(self) -> None: ("(3!)!", "factorial((factorial(3)))"), ("((2+3)!)!", "factorial((factorial((2+3))))"), ("sin(x!) + (cos(y)!)", "sin(factorial(x))+(factorial(cos(y)))"), - ("((1+2)*(3+4))!", "factorial(((1+2)*(3+4)))") + ("((1+2)*(3+4))!", "factorial(((1+2)*(3+4)))"), ] for expression, expected in cases: for python_compatible in (True, False): - fixed_expression = ExpressionValidator.fix_math_expression(expression, python_compatible=python_compatible) - self.assertNotIn('!', fixed_expression) - self.assertEqual(fixed_expression.replace(' ', ''), expected) - + fixed_expression = ExpressionValidator.fix_math_expression( + expression, python_compatible=python_compatible + ) + self.assertNotIn("!", fixed_expression) + self.assertEqual(fixed_expression.replace(" ", ""), expected) diff --git a/static/client/client_tests/test_font_helpers.py b/static/client/client_tests/test_font_helpers.py index 33f8fb35..841b42fa 100644 --- a/static/client/client_tests/test_font_helpers.py +++ b/static/client/client_tests/test_font_helpers.py @@ -31,11 +31,11 @@ def test_none_candidate_uses_fallback(self) -> None: self.assertEqual(result, 12.0) def test_nan_candidate_uses_fallback(self) -> None: - result = _coerce_font_size(float('nan'), 10) + result = _coerce_font_size(float("nan"), 10) self.assertEqual(result, 10.0) def test_infinity_candidate_uses_fallback(self) -> None: - result = _coerce_font_size(float('inf'), 10) + result = _coerce_font_size(float("inf"), 10) self.assertEqual(result, 10.0) def test_invalid_string_uses_fallback(self) -> None: @@ -80,6 +80,7 @@ def test_zoomed_out_scales_font(self) -> None: def test_extreme_zoom_out_returns_minimum(self) -> None: from constants import label_min_screen_font_px + label = SimpleNamespace(reference_scale_factor=1.0) mapper = SimpleNamespace(scale_factor=0.01) result = _compute_zoom_adjusted_font_size(16.0, label, mapper) @@ -89,6 +90,7 @@ def test_extreme_zoom_out_returns_minimum(self) -> None: def test_vanish_threshold(self) -> None: from constants import label_vanish_threshold_px + label = SimpleNamespace(reference_scale_factor=1.0) mapper = SimpleNamespace(scale_factor=0.001) result = _compute_zoom_adjusted_font_size(2.0, label, mapper) @@ -126,4 +128,3 @@ def test_negative_reference_scale_uses_default(self) -> None: __all__ = ["TestCoerceFontSize", "TestComputeZoomAdjustedFontSize"] - diff --git a/static/client/client_tests/test_function.py b/static/client/client_tests/test_function.py index 455b1d7a..da255c04 100644 --- a/static/client/client_tests/test_function.py +++ b/static/client/client_tests/test_function.py @@ -20,7 +20,7 @@ def setUp(self) -> None: get_visible_right_bound=SimpleMock(return_value=10), get_visible_top_bound=SimpleMock(return_value=10), get_visible_bottom_bound=SimpleMock(return_value=-10), - height=500 + height=500, ), is_point_within_canvas_visible_area=SimpleMock(return_value=True), # Add coordinate_mapper properties @@ -31,7 +31,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) + offset=Position(0, 0), ) # Sync canvas state with coordinate mapper @@ -41,7 +41,9 @@ def setUp(self) -> None: self.right_bound = 9 self.function_string = "x*2" self.name = "DoubleX" - self.function = Function(self.function_string, self.name, left_bound=self.left_bound, right_bound=self.right_bound) + self.function = Function( + self.function_string, self.name, left_bound=self.left_bound, right_bound=self.right_bound + ) def test_initialize(self) -> None: # Test that the function is correctly initialized @@ -54,7 +56,7 @@ def test_init_with_invalid_function_explicit(self) -> None: _ = Function("sin(/0)", "InvalidFunction") def test_get_class_name(self) -> None: - self.assertEqual(self.function.get_class_name(), 'Function') + self.assertEqual(self.function.get_class_name(), "Function") def test_generate_values(self) -> None: # Test the generation of function values within canvas bounds using FunctionRenderable @@ -77,7 +79,11 @@ def test_generate_values(self) -> None: for point_tuple in path: # Convert tuple (x, y) to screen coordinates for bounds checking math_x, math_y = self.canvas.coordinate_mapper.screen_to_math(point_tuple[0], point_tuple[1]) - self.assertTrue(self.canvas.cartesian2axis.get_visible_left_bound() <= math_x <= self.canvas.cartesian2axis.get_visible_right_bound()) + self.assertTrue( + self.canvas.cartesian2axis.get_visible_left_bound() + <= math_x + <= self.canvas.cartesian2axis.get_visible_right_bound() + ) # Count points within and outside y bounds if bottom_bound <= math_y <= top_bound: @@ -86,12 +92,18 @@ def test_generate_values(self) -> None: points_outside_bounds += 1 # Ensure majority of points are within bounds - self.assertGreater(points_within_bounds, points_outside_bounds, - "Majority of points should be within y bounds") + self.assertGreater(points_within_bounds, points_outside_bounds, "Majority of points should be within y bounds") def test_get_state(self) -> None: state = self.function.get_state() - expected_state = {"name": self.name, "args": {"function_string": self.function_string, "left_bound": self.left_bound, "right_bound": self.right_bound}} + expected_state = { + "name": self.name, + "args": { + "function_string": self.function_string, + "left_bound": self.left_bound, + "right_bound": self.right_bound, + }, + } self.assertEqual(state, expected_state) def test_deepcopy(self) -> None: @@ -142,7 +154,7 @@ def test_zoom_via_canvas_draw_mechanism(self) -> None: # 1. Check if renderable has invalidate_cache method # 2. Call it if it exists # 3. Call build_screen_paths() - if hasattr(renderable, 'invalidate_cache'): + if hasattr(renderable, "invalidate_cache"): renderable.invalidate_cache() # Build paths again after cache invalidation @@ -166,8 +178,11 @@ def test_high_amplitude_gets_more_samples(self) -> None: high_amp_points = sum(len(path) for path in high_amp_paths) # High amplitude function should have at least as many points - self.assertGreaterEqual(high_amp_points, low_amp_points, - f"Expected high amplitude function ({high_amp_points} points) to have at least as many points as low amplitude ({low_amp_points})") + self.assertGreaterEqual( + high_amp_points, + low_amp_points, + f"Expected high amplitude function ({high_amp_points} points) to have at least as many points as low amplitude ({low_amp_points})", + ) def test_discontinuity_handling(self) -> None: # Test function with discontinuity using FunctionRenderable @@ -186,12 +201,18 @@ def test_discontinuity_handling(self) -> None: # Find gaps in x coordinates that indicate discontinuity has_discontinuity = False for i in range(1, len(flat_points)): - math_x1, math_y1 = self.canvas.coordinate_mapper.screen_to_math(flat_points[i-1][0], flat_points[i-1][1]) + math_x1, math_y1 = self.canvas.coordinate_mapper.screen_to_math( + flat_points[i - 1][0], flat_points[i - 1][1] + ) math_x2, math_y2 = self.canvas.coordinate_mapper.screen_to_math(flat_points[i][0], flat_points[i][1]) # Check for either a large x gap or a transition through bounds - if (abs(math_x2 - math_x1) > discontinuous_function.step * 2) or \ - (abs(math_y2 - math_y1) > (self.canvas.cartesian2axis.get_visible_top_bound() - - self.canvas.cartesian2axis.get_visible_bottom_bound())): + if (abs(math_x2 - math_x1) > discontinuous_function.step * 2) or ( + abs(math_y2 - math_y1) + > ( + self.canvas.cartesian2axis.get_visible_top_bound() + - self.canvas.cartesian2axis.get_visible_bottom_bound() + ) + ): has_discontinuity = True break @@ -291,8 +312,8 @@ def test_performance_limits(self) -> None: for path in paths: if len(path) > 1: # Only check paths with at least 2 points for i in range(1, len(path)): - dx = abs(path[i][0] - path[i-1][0]) # x coordinates - dy = abs(path[i][1] - path[i-1][1]) # y coordinates + dx = abs(path[i][0] - path[i - 1][0]) # x coordinates + dy = abs(path[i][1] - path[i - 1][1]) # y coordinates self.assertGreater(dx + dy, 0, "Points should not be duplicates") def test_high_frequency_trig_functions(self) -> None: @@ -320,10 +341,16 @@ def test_high_frequency_trig_functions(self) -> None: total_points = sum(len(path) for path in paths) # Check bounds: at least min_points, at most canvas width + 1 (endpoint-inclusive) - self.assertGreaterEqual(total_points, min_points, - f"{function_string} ({description}): {total_points} points, expected at least {min_points}") - self.assertLessEqual(total_points, canvas_width + 1, - f"{function_string} ({description}): {total_points} points, exceeds canvas width + 1") + self.assertGreaterEqual( + total_points, + min_points, + f"{function_string} ({description}): {total_points} points, expected at least {min_points}", + ) + self.assertLessEqual( + total_points, + canvas_width + 1, + f"{function_string} ({description}): {total_points} points, exceeds canvas width + 1", + ) except Exception as e: self.fail(f"Failed to handle {function_string}: {str(e)}") @@ -349,6 +376,7 @@ class TestFunctionUndefinedAt(unittest.TestCase): def test_function_with_single_undefined_point(self) -> None: """Test that a function with a single undefined point returns NaN at that point.""" import math + f = Function("2", "constant_with_hole", undefined_at=[0]) self.assertAlmostEqual(f.function(-5), 2.0) @@ -359,6 +387,7 @@ def test_function_with_single_undefined_point(self) -> None: def test_function_with_multiple_undefined_points(self) -> None: """Test that a function with multiple undefined points returns NaN at those points.""" import math + f = Function("x^2", "parabola_with_holes", undefined_at=[-1, 0, 1]) self.assertAlmostEqual(f.function(-5), 25.0) diff --git a/static/client/client_tests/test_function_bounded_colored_area_integration.py b/static/client/client_tests/test_function_bounded_colored_area_integration.py index bcff7bbb..22c5b3a0 100644 --- a/static/client/client_tests/test_function_bounded_colored_area_integration.py +++ b/static/client/client_tests/test_function_bounded_colored_area_integration.py @@ -21,15 +21,15 @@ def _build_area( left: float, right: float, ) -> Tuple[FunctionsBoundedColoredArea, FunctionsBoundedAreaRenderable]: - f1 = Function(f1_str, name='f1', left_bound=left, right_bound=right) - f2 = Function(f2_str, name='f2', left_bound=left, right_bound=right) + f1 = Function(f1_str, name="f1", left_bound=left, right_bound=right) + f2 = Function(f2_str, name="f2", left_bound=left, right_bound=right) area = FunctionsBoundedColoredArea(f1, f2, left_bound=left, right_bound=right, num_sample_points=50) renderable = FunctionsBoundedAreaRenderable(area, self.mapper) return area, renderable def test_area_respects_function_bounds_simple(self) -> None: left, right = -300, 300 - _, renderable = self._build_area('50 * sin(x / 50)', '100 * sin(x / 30)', left, right) + _, renderable = self._build_area("50 * sin(x / 50)", "100 * sin(x / 30)", left, right) closed = renderable.build_screen_area() self.assertIsNotNone(closed) self.assertGreater(len(closed.forward_points), 0) @@ -44,6 +44,7 @@ def test_area_respects_function_bounds_simple(self) -> None: left_screen_x = fxs[0] left_math_x, _ = self.mapper.screen_to_math(left_screen_x, 0) import math + y1_left = 50 * math.sin(left_math_x / 50.0) y2_left = 100 * math.sin(left_math_x / 30.0) _, y1s = self.mapper.math_to_screen(left_math_x, y1_left) @@ -63,7 +64,7 @@ def test_area_respects_function_bounds_simple(self) -> None: def test_area_respects_function_bounds_with_asymptotes(self) -> None: # f3 with tan introduces vertical asymptotes; ensure pairing still aligns ends left, right = -300, 300 - _, renderable = self._build_area('50 * sin(x / 50)', '100 * sin(x / 50) + 50 * tan(x / 100)', left, right) + _, renderable = self._build_area("50 * sin(x / 50)", "100 * sin(x / 50) + 50 * tan(x / 100)", left, right) closed = renderable.build_screen_area() self.assertIsNotNone(closed) self.assertGreater(len(closed.forward_points), 0) @@ -76,6 +77,7 @@ def test_area_respects_function_bounds_with_asymptotes(self) -> None: left_screen_x = fxs[0] left_math_x, _ = self.mapper.screen_to_math(left_screen_x, 0) import math + y1_left = 50 * math.sin(left_math_x / 50.0) y2_left = 100 * math.sin(left_math_x / 50.0) + 50 * math.tan(left_math_x / 100.0) # If tangent is near asymptote, skip assertion for y2 at that end @@ -84,5 +86,3 @@ def test_area_respects_function_bounds_with_asymptotes(self) -> None: self.assertAlmostEqual(closed.reverse_points[-1][1], y2s, places=5) _, y1s = self.mapper.math_to_screen(left_math_x, y1_left) self.assertAlmostEqual(closed.forward_points[0][1], y1s, places=5) - - diff --git a/static/client/client_tests/test_function_calling.py b/static/client/client_tests/test_function_calling.py index 77db0960..cd81ccb8 100644 --- a/static/client/client_tests/test_function_calling.py +++ b/static/client/client_tests/test_function_calling.py @@ -19,10 +19,12 @@ class TestProcessFunctionCalls(unittest.TestCase): def setUp(self) -> None: # Setup the mock canvas and its functions as described self.canvas = Canvas(500, 500, draw_enabled=False) # Assuming a basic mock or actual Canvas class - self.mock_cartesian2axis = SimpleMock(draw=SimpleMock(return_value=None), - reset=SimpleMock(return_value=None), - get_state=SimpleMock(return_value={'Cartesian_System_Visibility': 'cartesian_state'}), - origin=Position(0, 0)) # Assuming Position is defined elsewhere + self.mock_cartesian2axis = SimpleMock( + draw=SimpleMock(return_value=None), + reset=SimpleMock(return_value=None), + get_state=SimpleMock(return_value={"Cartesian_System_Visibility": "cartesian_state"}), + origin=Position(0, 0), + ) # Assuming Position is defined elsewhere self.canvas.cartesian2axis = self.mock_cartesian2axis # Mocking a function in canvas.drawables['Function'] @@ -41,7 +43,7 @@ def test_evaluate_numeric_expression(self) -> None: def test_evaluate_expression_with_variables(self) -> None: expression = "x - 4 + y * 5" - variables = {'x': 7, 'y': 65} + variables = {"x": 7, "y": 65} result = ProcessFunctionCalls.evaluate_expression(expression, variables=variables, canvas=self.canvas) self.assertEqual(result, 328) # Expected result for "x - 4 + y * 5" with x = 7 and y = 65 @@ -73,8 +75,10 @@ def test_evaluate_function_expression(self) -> None: self.assertEqual(result, 25) # Expected result for "Quadratic(5)" def test_get_results1(self) -> None: - available_functions = {'evaluate_expression': ProcessFunctionCalls.evaluate_expression} - calls = [{'function_name': 'evaluate_expression', 'arguments': {'expression': 'Quadratic(5)', 'canvas': self.canvas}}] + available_functions = {"evaluate_expression": ProcessFunctionCalls.evaluate_expression} + calls = [ + {"function_name": "evaluate_expression", "arguments": {"expression": "Quadratic(5)", "canvas": self.canvas}} + ] undoable_functions = () # Example, assuming no undoable functions for simplicity results: Dict[str, Any] = ProcessFunctionCalls.get_results( calls, @@ -83,17 +87,22 @@ def test_get_results1(self) -> None: self.canvas, ) self.assertTrue(len(results) > 0) - self.assertIn('Quadratic(5)', results) # Check if the result for "Quadratic(5)" is available - self.assertEqual(results['Quadratic(5)'], 25) # Expected result for "Quadratic(5)" + self.assertIn("Quadratic(5)", results) # Check if the result for "Quadratic(5)" is available + self.assertEqual(results["Quadratic(5)"], 25) # Expected result for "Quadratic(5)" def test_get_results2(self) -> None: - available_functions = {'evaluate_expression': ProcessFunctionCalls.evaluate_expression} - calls = [{'function_name': 'evaluate_expression', 'arguments': {'expression': 'x + y', 'variables': {'x': 5, 'y': 1}, 'canvas': self.canvas}}] + available_functions = {"evaluate_expression": ProcessFunctionCalls.evaluate_expression} + calls = [ + { + "function_name": "evaluate_expression", + "arguments": {"expression": "x + y", "variables": {"x": 5, "y": 1}, "canvas": self.canvas}, + } + ] undoable_functions = () # Example, assuming no undoable functions for simplicity results = ProcessFunctionCalls.get_results(calls, available_functions, undoable_functions, self.canvas) self.assertTrue(len(results) > 0) - self.assertIn('x+y for x:5, y:1', results) - self.assertEqual(results['x+y for x:5, y:1'], 6) + self.assertIn("x+y for x:5, y:1", results) + self.assertEqual(results["x+y for x:5, y:1"], 6) def test_get_current_canvas_state_tool_returns_envelope(self) -> None: workspace_manager = WorkspaceManager(self.canvas) @@ -122,14 +131,16 @@ def test_get_current_canvas_state_tool_supports_filters(self) -> None: workspace_manager = WorkspaceManager(self.canvas) available_functions = FunctionRegistry.get_available_functions(self.canvas, workspace_manager) undoable_functions = FunctionRegistry.get_undoable_functions() - calls = [{ - "function_name": "get_current_canvas_state", - "arguments": { - "drawable_types": ["point"], - "object_names": ["a"], - "include_computations": False, - }, - }] + calls = [ + { + "function_name": "get_current_canvas_state", + "arguments": { + "drawable_types": ["point"], + "object_names": ["a"], + "include_computations": False, + }, + } + ] results = ProcessFunctionCalls.get_results(calls, available_functions, undoable_functions, self.canvas) @@ -226,9 +237,7 @@ def fake_evaluate(*_: Any) -> Dict[str, Any]: LinearAlgebraUtils.evaluate_expression = staticmethod(fake_evaluate) with self.assertRaises(ValueError): - ProcessFunctionCalls.evaluate_linear_algebra_expression([ - {"name": "A", "value": [[1.0]]} - ], "A") + ProcessFunctionCalls.evaluate_linear_algebra_expression([{"name": "A", "value": [[1.0]]}], "A") def test_evaluate_linear_algebra_expression_supports_grouped_operations(self) -> None: captured_calls: List[Any] = [] @@ -476,8 +485,24 @@ def fake_evaluate(objects: List[Dict[str, Any]], expression: str) -> LinearAlgeb "function_name": "evaluate_linear_algebra_expression", "arguments": { "objects": [ - {"name": "Ainv", "value": [[0.08978748025755, -0.0431349377372, -0.1110128961747, 0.01593866015249], [0.00113046897795, 0.01621907946636, 0.01479939945209, -0.00775100230215], [-0.04488430734323, 0.03425951076629, 0.07993324152089, -0.01240860042922], [-0.01517103288635, 0.01419148847707, 0.02433255715856, 0.00388978770005]]}, - {"name": "Binv", "value": [[-0.01852789402109, -0.00012003596207, 0.01726042942533, -0.01527634792136], [0.03015604564334, -0.02073178307553, -0.02950658907737, 0.0414638983539], [0.05327575815057, -0.02568902469341, -0.04039437128784, 0.05172564429912], [0.007453579912, 0.00006621910084, 0.00062570406241, 0.01027237643632]]}, + { + "name": "Ainv", + "value": [ + [0.08978748025755, -0.0431349377372, -0.1110128961747, 0.01593866015249], + [0.00113046897795, 0.01621907946636, 0.01479939945209, -0.00775100230215], + [-0.04488430734323, 0.03425951076629, 0.07993324152089, -0.01240860042922], + [-0.01517103288635, 0.01419148847707, 0.02433255715856, 0.00388978770005], + ], + }, + { + "name": "Binv", + "value": [ + [-0.01852789402109, -0.00012003596207, 0.01726042942533, -0.01527634792136], + [0.03015604564334, -0.02073178307553, -0.02950658907737, 0.0414638983539], + [0.05327575815057, -0.02568902469341, -0.04039437128784, 0.05172564429912], + [0.007453579912, 0.00006621910084, 0.00062570406241, 0.01027237643632], + ], + }, ], "expression": "Ainv * Binv", }, diff --git a/static/client/client_tests/test_function_manager.py b/static/client/client_tests/test_function_manager.py index 95f0e9fd..f4a664a7 100644 --- a/static/client/client_tests/test_function_manager.py +++ b/static/client/client_tests/test_function_manager.py @@ -19,9 +19,7 @@ def setUp(self) -> None: archive=SimpleMock(), ), ) - self.canvas.drawable_manager = SimpleMock( - delete_colored_areas_for_function=SimpleMock() - ) + self.canvas.drawable_manager = SimpleMock(delete_colored_areas_for_function=SimpleMock()) self.drawables = DrawablesContainer() self.name_generator = SimpleMock(name="NameGeneratorMock") self.dependency_manager = SimpleMock(name="DependencyManagerMock") diff --git a/static/client/client_tests/test_function_renderables.py b/static/client/client_tests/test_function_renderables.py index a9db7a65..0289f08d 100644 --- a/static/client/client_tests/test_function_renderables.py +++ b/static/client/client_tests/test_function_renderables.py @@ -146,7 +146,11 @@ def test_sin_peaks_have_reasonable_smoothness(self) -> None: self.assertGreater(total_angles_checked, 50) violation_rate = violations / total_angles_checked if total_angles_checked > 0 else 0 - self.assertLess(violation_rate, 0.30, f"Found {violations} angles below 30 degrees out of {total_angles_checked} ({violation_rate:.1%})") + self.assertLess( + violation_rate, + 0.30, + f"Found {violations} angles below 30 degrees out of {total_angles_checked} ({violation_rate:.1%})", + ) class TestFunctionsBoundedAreaRenderable(unittest.TestCase): @@ -160,9 +164,7 @@ def test_build_screen_area_with_two_functions(self) -> None: f1 = Function("x^2", name="f") f2 = Function("x", name="g") - area_model = FunctionsBoundedColoredArea( - f1, f2, left_bound=-1, right_bound=1, color="green", opacity=0.4 - ) + area_model = FunctionsBoundedColoredArea(f1, f2, left_bound=-1, right_bound=1, color="green", opacity=0.4) renderable = FunctionsBoundedAreaRenderable(area_model, self.mapper) result = renderable.build_screen_area() @@ -178,9 +180,7 @@ def test_area_with_x_bounds(self) -> None: f1 = Function("2*x", name="f") f2 = Function("x", name="g") - area_model = FunctionsBoundedColoredArea( - f1, f2, left_bound=0, right_bound=3 - ) + area_model = FunctionsBoundedColoredArea(f1, f2, left_bound=0, right_bound=3) renderable = FunctionsBoundedAreaRenderable(area_model, self.mapper) result = renderable.build_screen_area() @@ -204,9 +204,7 @@ def test_build_screen_area_from_segments(self) -> None: seg1 = Segment(p1, p2) seg2 = Segment(p4, p3) - area_model = SegmentsBoundedColoredArea( - seg1, seg2, color="red", opacity=0.5 - ) + area_model = SegmentsBoundedColoredArea(seg1, seg2, color="red", opacity=0.5) renderable = SegmentsBoundedAreaRenderable(area_model, self.mapper) result = renderable.build_screen_area() @@ -393,8 +391,7 @@ def test_complex_sin_tan_combination(self) -> None: def test_complex_sin_tan_with_wide_bounds(self) -> None: # Same function with wider bounds to hit multiple asymptotes - func = Function("100*sin(x/50) + 50*tan(x/100)", name="complex_wide", - left_bound=-500, right_bound=500) + func = Function("100*sin(x/50) + 50*tan(x/100)", name="complex_wide", left_bound=-500, right_bound=500) renderable = FunctionRenderable(func, self.mapper) result = renderable.build_screen_paths() @@ -407,14 +404,8 @@ def test_complex_sin_tan_with_wide_bounds(self) -> None: first_y = path[0][1] last_y = path[-1][1] # Points should be within reasonable range or at boundaries - self.assertTrue( - -10000 < first_y < 10000, - f"First y={first_y} out of reasonable range" - ) - self.assertTrue( - -10000 < last_y < 10000, - f"Last y={last_y} out of reasonable range" - ) + self.assertTrue(-10000 < first_y < 10000, f"First y={first_y} out of reasonable range") + self.assertTrue(-10000 < last_y < 10000, f"Last y={last_y} out of reasonable range") def test_no_diagonal_lines_across_asymptotes(self) -> None: # Regression test: ensure no diagonal lines connecting different branches @@ -432,10 +423,9 @@ def test_no_diagonal_lines_across_asymptotes(self) -> None: # Check that consecutive points don't have huge x jumps # (which would indicate crossing an asymptote incorrectly) for i in range(1, len(path)): - x_diff = abs(path[i][0] - path[i-1][0]) + x_diff = abs(path[i][0] - path[i - 1][0]) # X difference should be reasonable (not jumping across screen) - self.assertLess(x_diff, self.width / 2, - f"Large x jump detected: {x_diff}, possible diagonal line bug") + self.assertLess(x_diff, self.width / 2, f"Large x jump detected: {x_diff}, possible diagonal line bug") def test_no_path_spans_both_screen_halves(self) -> None: # Critical: A single path for 1/x should NOT have endpoints at both top and bottom @@ -460,9 +450,11 @@ def test_no_path_spans_both_screen_halves(self) -> None: near_top = min_y < tolerance near_bottom = max_y > self.height - tolerance - self.assertFalse(near_top and near_bottom, + self.assertFalse( + near_top and near_bottom, f"Path {idx} spans both screen halves (y: {min_y:.0f} to {max_y:.0f}), " - f"indicates diagonal line crossing asymptote") + f"indicates diagonal line crossing asymptote", + ) def test_path_does_not_cross_asymptote_x(self) -> None: # For 1/x, asymptote is at x=0. No path should have points on both sides. @@ -485,9 +477,11 @@ def test_path_does_not_cross_asymptote_x(self) -> None: has_left = any(left_of_asymptote) has_right = any(right_of_asymptote) - self.assertFalse(has_left and has_right, + self.assertFalse( + has_left and has_right, f"Path {idx} crosses asymptote at x={asymptote_screen_x:.0f}. " - f"x range: [{min(x_values):.0f}, {max(x_values):.0f}]") + f"x range: [{min(x_values):.0f}, {max(x_values):.0f}]", + ) def test_no_large_y_jump_in_path(self) -> None: # Consecutive points should not have extreme y jumps (indicates wrong branch extension) @@ -503,19 +497,21 @@ def test_no_large_y_jump_in_path(self) -> None: continue for i in range(1, len(path)): - y1 = path[i-1][1] + y1 = path[i - 1][1] y2 = path[i][1] y_diff = abs(y2 - y1) - self.assertLess(y_diff, max_y_jump, + self.assertLess( + y_diff, + max_y_jump, f"Path {idx} has large y-jump at point {i}: {y_diff:.0f}px " - f"(from y={y1:.0f} to y={y2:.0f}), indicates invalid extension") + f"(from y={y1:.0f} to y={y2:.0f}), indicates invalid extension", + ) def test_no_large_y_jump_sin_tan_combo(self) -> None: # Test the complex function that had diagonal line bugs # With adaptive sampling, this function produces many small paths near asymptotes - func = Function("100*sin(x/50)+50*tan(x/100)", name="combo", - left_bound=100, right_bound=200) + func = Function("100*sin(x/50)+50*tan(x/100)", name="combo", left_bound=100, right_bound=200) renderable = FunctionRenderable(func, self.mapper) result = renderable.build_screen_paths() @@ -546,8 +542,7 @@ def test_paths_extend_to_boundaries_near_asymptotes(self) -> None: x_values = [pt[0] for pt in path] all_left = all(x < asymptote_x for x in x_values) all_right = all(x > asymptote_x for x in x_values) - self.assertTrue(all_left or all_right, - "Path should not cross asymptote") + self.assertTrue(all_left or all_right, "Path should not cross asymptote") def test_tan_paths_reach_vertical_bounds(self) -> None: # tan(x) should produce multiple separate paths due to asymptotes @@ -557,14 +552,12 @@ def test_tan_paths_reach_vertical_bounds(self) -> None: result = renderable.build_screen_paths() # tan(x) has asymptotes at pi/2 + n*pi, should produce multiple paths - self.assertGreater(len(result.paths), 1, - "tan(x) should have multiple paths due to asymptotes") + self.assertGreater(len(result.paths), 1, "tan(x) should have multiple paths due to asymptotes") def test_sin_tan_combo_no_crossing_artifacts(self) -> None: # Specific test for the complex function that had diagonal line bugs # With adaptive sampling, this produces many paths due to asymptotes - func = Function("100*sin(x/50) + 50*tan(x/100)", name="combo", - left_bound=-1000, right_bound=1000) + func = Function("100*sin(x/50) + 50*tan(x/100)", name="combo", left_bound=-1000, right_bound=1000) renderable = FunctionRenderable(func, self.mapper) result = renderable.build_screen_paths() @@ -607,10 +600,8 @@ def test_function_rendered_at_exact_bounds(self) -> None: last_math_x, _ = self.mapper.screen_to_math(path[-1][0], path[-1][1]) # First point should be near left bound, last near right bound - self.assertLess(abs(first_math_x - (-10)), 1.0, - f"First point not at left bound: {first_math_x}") - self.assertLess(abs(last_math_x - 10), 1.0, - f"Last point not at right bound: {last_math_x}") + self.assertLess(abs(first_math_x - (-10)), 1.0, f"First point not at left bound: {first_math_x}") + self.assertLess(abs(last_math_x - 10), 1.0, f"Last point not at right bound: {last_math_x}") def test_no_extension_past_screen_half_boundary(self) -> None: # Extensions should not cross from top half to bottom half of screen @@ -625,15 +616,14 @@ def test_no_extension_past_screen_half_boundary(self) -> None: continue # Check that path doesn't have huge y jumps within consecutive points for i in range(1, len(path)): - y1 = path[i-1][1] + y1 = path[i - 1][1] y2 = path[i][1] # If one point is in top half and other in bottom half, # they shouldn't be far apart (which would indicate a diagonal line bug) if (y1 < mid_y and y2 > mid_y) or (y1 > mid_y and y2 < mid_y): y_diff = abs(y2 - y1) # Small crossing is OK (near center), large crossing is a bug - self.assertLess(y_diff, self.height * 0.8, - f"Large y jump across screen center: {y_diff}") + self.assertLess(y_diff, self.height * 0.8, f"Large y jump across screen center: {y_diff}") def test_extrapolation_continues_path_direction(self) -> None: # When extrapolation is used, it should follow the path's direction @@ -661,8 +651,7 @@ def test_extrapolation_continues_path_direction(self) -> None: # Directions should be roughly similar (not reversed) if abs(dx1) > 1 and abs(dx2) > 1: # X direction should be consistent - self.assertEqual(dx1 > 0, dx2 > 0, - "X direction reversal at path start") + self.assertEqual(dx1 > 0, dx2 > 0, "X direction reversal at path start") def test_asymptote_paths_extend_to_correct_boundary(self) -> None: # For asymptotes, paths going up should extend to y=0 (top), @@ -690,8 +679,7 @@ def test_asymptote_paths_extend_to_correct_boundary(self) -> None: # If going down, last point should be near bottom (y=height) if going_up: # End should be at or near top (allow small margin for floating point) - self.assertLess(last_y, self.height / 2 + 5, - f"Path going up but ends at y={last_y}, not near top") + self.assertLess(last_y, self.height / 2 + 5, f"Path going up but ends at y={last_y}, not near top") def test_valid_extension_check_same_screen_half(self) -> None: # Test the _is_valid_extension logic @@ -701,23 +689,26 @@ def test_valid_extension_check_same_screen_half(self) -> None: # Test cases: (original_y, extension_y, should_be_valid) test_cases = [ # Same half - should be valid - (100, 50, True), # Both in top half + (100, 50, True), # Both in top half (400, 450, True), # Both in bottom half # Different half - should be invalid (100, 400, False), # Top to bottom (400, 100, False), # Bottom to top # At boundary - valid if correct boundary - (100, 0, True), # Top half to top boundary + (100, 0, True), # Top half to top boundary (400, 480, True), # Bottom half to bottom boundary # At wrong boundary - invalid - (100, 480, False), # Top half to bottom boundary - (400, 0, False), # Bottom half to top boundary + (100, 480, False), # Top half to bottom boundary + (400, 0, False), # Bottom half to top boundary ] for orig_y, ext_y, expected_valid in test_cases: result = renderable._is_valid_extension(orig_y, ext_y, self.height) - self.assertEqual(result, expected_valid, - f"_is_valid_extension({orig_y}, {ext_y}, {self.height}) = {result}, expected {expected_valid}") + self.assertEqual( + result, + expected_valid, + f"_is_valid_extension({orig_y}, {ext_y}, {self.height}) = {result}, expected {expected_valid}", + ) class TestRenderableEdgeCases(unittest.TestCase): @@ -813,8 +804,7 @@ def test_sqrt_function_domain_edge(self) -> None: for path in result.paths: for x, y in path: math_x, _ = self.mapper.screen_to_math(x, y) - self.assertGreaterEqual(math_x, -0.1, - f"sqrt(x) point at x={math_x} which is invalid domain") + self.assertGreaterEqual(math_x, -0.1, f"sqrt(x) point at x={math_x} which is invalid domain") def test_log_function_domain_edge(self) -> None: # log(x) is undefined for x <= 0 @@ -832,8 +822,7 @@ def test_multiple_asymptotes_tan(self) -> None: result = renderable.build_screen_paths() # Should produce multiple separate paths - self.assertGreater(len(result.paths), 1, - "tan(x) over [-10,10] should have multiple paths due to asymptotes") + self.assertGreater(len(result.paths), 1, "tan(x) over [-10,10] should have multiple paths due to asymptotes") def test_steep_exponential(self) -> None: # e^x grows very steeply @@ -875,4 +864,3 @@ def test_piecewise_like_abs(self) -> None: "TestBoundaryExtension", "TestRenderableEdgeCases", ] - diff --git a/static/client/client_tests/test_function_segment_bounded_colored_area.py b/static/client/client_tests/test_function_segment_bounded_colored_area.py index 2f3c780d..0e24364d 100644 --- a/static/client/client_tests/test_function_segment_bounded_colored_area.py +++ b/static/client/client_tests/test_function_segment_bounded_colored_area.py @@ -26,7 +26,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -37,19 +37,11 @@ def setUp(self) -> None: name="f1", function=lambda x: x**2, # Quadratic function y = x^2 left_bound=-5, - right_bound=5 + right_bound=5, ) # Create mock segment (math coordinates on x/y) - self.segment = SimpleMock( - name="AB", - point1=SimpleMock( - x=-150, y=50 - ), - point2=SimpleMock( - x=150, y=-50 - ) - ) + self.segment = SimpleMock(name="AB", point1=SimpleMock(x=-150, y=50), point2=SimpleMock(x=150, y=-50)) def test_init(self) -> None: """Test initialization of FunctionSegmentBoundedColoredArea.""" @@ -62,7 +54,7 @@ def test_init(self) -> None: def test_get_class_name(self) -> None: """Test class name retrieval.""" area = FunctionSegmentBoundedColoredArea(self.func, self.segment) - self.assertEqual(area.get_class_name(), 'FunctionSegmentBoundedColoredArea') + self.assertEqual(area.get_class_name(), "FunctionSegmentBoundedColoredArea") def test_generate_name(self) -> None: """Test name generation.""" @@ -109,7 +101,7 @@ def test_calculate_function_y_value_handles_exceptions(self) -> None: # Create a function that throws an exception bad_func = SimpleMock( name="bad_func", - function=lambda x: 1/0 # This will cause ZeroDivisionError + function=lambda x: 1 / 0, # This will cause ZeroDivisionError ) area = FunctionSegmentBoundedColoredArea(bad_func, self.segment) @@ -145,16 +137,10 @@ def test_uses_segment(self) -> None: area = FunctionSegmentBoundedColoredArea(self.func, self.segment) # Create matching segment - matching_segment = SimpleMock( - point1=SimpleMock(x=-150, y=50), - point2=SimpleMock(x=150, y=-50) - ) + matching_segment = SimpleMock(point1=SimpleMock(x=-150, y=50), point2=SimpleMock(x=150, y=-50)) # Create non-matching segment - different_segment = SimpleMock( - point1=SimpleMock(x=150, y=250), - point2=SimpleMock(x=450, y=350) - ) + different_segment = SimpleMock(point1=SimpleMock(x=150, y=250), point2=SimpleMock(x=450, y=350)) self.assertTrue(area.uses_segment(matching_segment)) self.assertFalse(area.uses_segment(different_segment)) @@ -164,10 +150,7 @@ def test_get_state(self) -> None: area = FunctionSegmentBoundedColoredArea(self.func, self.segment) state = area.get_state() - expected_args = { - "func": "f1", - "segment": "AB" - } + expected_args = {"func": "f1", "segment": "AB"} # Check that the args contain the expected function and segment names self.assertEqual(state["args"]["func"], expected_args["func"]) self.assertEqual(state["args"]["segment"], expected_args["segment"]) @@ -212,8 +195,8 @@ def test_function_with_domain_restrictions(self) -> None: restricted_func = SimpleMock( name="restricted", function=lambda x: x**2, - left_bound=-2, # Restricted domain - right_bound=2 + left_bound=-2, # Restricted domain + right_bound=2, ) area = FunctionSegmentBoundedColoredArea(restricted_func, self.segment) @@ -230,7 +213,7 @@ def test_function_evaluation_error_handling(self) -> None: # Create function that throws ZeroDivisionError error_func = SimpleMock( name="error_func", - function=lambda x: 1/0 if x == 0 else 1/x # Division by zero exception + function=lambda x: 1 / 0 if x == 0 else 1 / x, # Division by zero exception ) area = FunctionSegmentBoundedColoredArea(error_func, self.segment) @@ -248,12 +231,8 @@ def test_segment_bounds_with_swapped_points(self) -> None: # Create segment with points swapped (larger x first) swapped_segment = SimpleMock( name="BA", # Reverse order - point1=SimpleMock( - x=150, y=-50 - ), - point2=SimpleMock( - x=-150, y=50 - ) + point1=SimpleMock(x=150, y=-50), + point2=SimpleMock(x=-150, y=50), ) area = FunctionSegmentBoundedColoredArea(self.func, swapped_segment) @@ -331,15 +310,7 @@ def test_segment_points_follow_mapper_transformations(self) -> None: def test_edge_case_single_point_segment(self) -> None: """Test edge case where segment endpoints are the same.""" # Create segment where both points are identical - single_point_segment = SimpleMock( - name="AA", - point1=SimpleMock( - x=0, y=0 - ), - point2=SimpleMock( - x=0, y=0 - ) - ) + single_point_segment = SimpleMock(name="AA", point1=SimpleMock(x=0, y=0), point2=SimpleMock(x=0, y=0)) area = FunctionSegmentBoundedColoredArea(self.func, single_point_segment) diff --git a/static/client/client_tests/test_functions_bounded_colored_area.py b/static/client/client_tests/test_functions_bounded_colored_area.py index f7eb1d0f..7fdbbfb6 100644 --- a/static/client/client_tests/test_functions_bounded_colored_area.py +++ b/static/client/client_tests/test_functions_bounded_colored_area.py @@ -30,7 +30,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -41,13 +41,13 @@ def setUp(self) -> None: name="f1", function=lambda x: x, # Linear function y = x left_bound=-5, - right_bound=5 + right_bound=5, ) self.func2 = SimpleMock( name="f2", function=lambda x: x**2, # Quadratic function y = x^2 left_bound=-3, - right_bound=3 + right_bound=3, ) def test_init(self) -> None: @@ -61,7 +61,7 @@ def test_init(self) -> None: def test_get_class_name(self) -> None: """Test class name retrieval.""" area = FunctionsBoundedColoredArea(self.func1, self.func2) - self.assertEqual(area.get_class_name(), 'FunctionsBoundedColoredArea') + self.assertEqual(area.get_class_name(), "FunctionsBoundedColoredArea") def test_get_bounds_model_default_then_clipped_in_renderer(self) -> None: """Model provides math-space defaults; renderer handles viewport clipping.""" @@ -136,7 +136,7 @@ def test_get_state(self) -> None: "right_bound": 2, "color": "lightblue", "opacity": 0.3, - "num_sample_points": 100 + "num_sample_points": 100, } self.assertEqual(state["args"], expected_args) @@ -153,17 +153,18 @@ def test_asymptote_detection_with_tangent_function(self) -> None: """Test asymptote detection for tangent function.""" # Create a tangent function with known asymptotes import math + tangent_func = SimpleMock( name="f3", # Special name that triggers asymptote detection - function=lambda x: math.tan(x/100), + function=lambda x: math.tan(x / 100), left_bound=-500, - right_bound=500 + right_bound=500, ) area = FunctionsBoundedColoredArea(tangent_func, self.func2) # Test asymptote detection at known asymptote positions - asym_x = 100 * (math.pi/2) # First asymptote + asym_x = 100 * (math.pi / 2) # First asymptote dx = 1.0 # Should detect asymptote when very close @@ -178,10 +179,7 @@ def test_asymptote_handling_during_path_generation(self) -> None: """Test that asymptotes are properly handled during path generation.""" # Create function with division by zero at x=0 asymptote_func = SimpleMock( - name="asymptote_func", - function=lambda x: 1/x if x != 0 else float('inf'), - left_bound=-5, - right_bound=5 + name="asymptote_func", function=lambda x: 1 / x if x != 0 else float("inf"), left_bound=-5, right_bound=5 ) area = FunctionsBoundedColoredArea(asymptote_func, None) @@ -217,18 +215,17 @@ def test_function_evaluation_with_nan_and_infinity(self) -> None: # Function that returns various problematic values problematic_func = SimpleMock( name="problematic", - function=lambda x: { - 2: None, - 3: "not_a_number" - }.get(x, x) # Return x for other values, None and string for specific cases + function=lambda x: {2: None, 3: "not_a_number"}.get( + x, x + ), # Return x for other values, None and string for specific cases ) area = FunctionsBoundedColoredArea(problematic_func, self.func2) # Test cases that should return None test_cases = [ - (2, None), # None -> None - (3, None), # String -> None (not int/float) + (2, None), # None -> None + (3, None), # String -> None (not int/float) ] for x_input, expected in test_cases: @@ -252,7 +249,7 @@ def test_coordinate_conversion_integration(self) -> None: self.assertIsNotNone(result_none, "None function (x-axis) should return a result") # Test that different functions return different results - result1 = area._get_function_y_at_x(5, 2.0) # y = 5 + result1 = area._get_function_y_at_x(5, 2.0) # y = 5 result2 = area._get_function_y_at_x(10, 2.0) # y = 10 self.assertNotEqual(result1, result2, "Different constant functions should return different canvas coordinates") @@ -303,15 +300,13 @@ def test_path_generation_with_large_values(self) -> None: name="large_func", function=lambda x: x**10, # Very large values for |x| > 1 left_bound=-2, - right_bound=2 + right_bound=2, ) area = FunctionsBoundedColoredArea(large_value_func, None) # Should handle large values without crashing (math-only behavior now) - result = area._get_function_y_at_x_with_asymptote_handling( - large_value_func, 1.5, 0.1 - ) + result = area._get_function_y_at_x_with_asymptote_handling(large_value_func, 1.5, 0.1) # Should return a clipped value, not crash self.assertIsNotNone(result) @@ -335,9 +330,9 @@ def test_draw_method_with_no_valid_points(self) -> None: # Function that always returns None invalid_func = SimpleMock( name="invalid_func", - function=lambda x: float('nan'), # Always invalid + function=lambda x: float("nan"), # Always invalid left_bound=-1, - right_bound=1 + right_bound=1, ) area = FunctionsBoundedColoredArea(invalid_func, None) @@ -355,6 +350,7 @@ def test_reverse_path_generation(self) -> None: # Mock coordinate_mapper for predictable results call_count = [0] # Use list to allow modification in nested function + def mock_math_to_screen(x: float, y: float) -> tuple[float, float]: call_count[0] += 1 return (x * 10 + 250, 250 - y * 10) # Simple linear transformation diff --git a/static/client/client_tests/test_generic_polygon.py b/static/client/client_tests/test_generic_polygon.py index a90fef8f..efad814c 100644 --- a/static/client/client_tests/test_generic_polygon.py +++ b/static/client/client_tests/test_generic_polygon.py @@ -150,4 +150,3 @@ def test_large_polygon_20_sides(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_geometry_utils.py b/static/client/client_tests/test_geometry_utils.py index d35b397f..d4741301 100644 --- a/static/client/client_tests/test_geometry_utils.py +++ b/static/client/client_tests/test_geometry_utils.py @@ -653,4 +653,3 @@ def test_square_hull(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_graph_analyzer.py b/static/client/client_tests/test_graph_analyzer.py index a68b1138..de9c817e 100644 --- a/static/client/client_tests/test_graph_analyzer.py +++ b/static/client/client_tests/test_graph_analyzer.py @@ -436,10 +436,7 @@ class TestAnalyzeGraphTreeOperations(unittest.TestCase): def _make_tree_state(self, name: str, vertex_names: List[str], edge_pairs: List[tuple], root: str) -> TreeState: vertices = [GraphVertexDescriptor(v) for v in vertex_names] - edges = [ - GraphEdgeDescriptor(f"e{i}", src, tgt, name=f"{src}{tgt}") - for i, (src, tgt) in enumerate(edge_pairs) - ] + edges = [GraphEdgeDescriptor(f"e{i}", src, tgt, name=f"{src}{tgt}") for i, (src, tgt) in enumerate(edge_pairs)] return TreeState(name, vertices, edges, root=root) def test_levels_simple_tree(self) -> None: @@ -534,10 +531,7 @@ class TestAnalyzeGraphTreeTransforms(unittest.TestCase): def _make_tree_state(self, name: str, vertex_names: List[str], edge_pairs: List[tuple], root: str) -> TreeState: vertices = [GraphVertexDescriptor(v) for v in vertex_names] - edges = [ - GraphEdgeDescriptor(f"e{i}", src, tgt) - for i, (src, tgt) in enumerate(edge_pairs) - ] + edges = [GraphEdgeDescriptor(f"e{i}", src, tgt) for i, (src, tgt) in enumerate(edge_pairs)] return TreeState(name, vertices, edges, root=root) def test_reroot_simple(self) -> None: @@ -681,10 +675,16 @@ def test_k4_bridges(self) -> None: vertices = [GraphVertexDescriptor(v) for v in ["A", "B", "C", "D"]] edges = [ GraphEdgeDescriptor(f"e{i}", src, tgt) - for i, (src, tgt) in enumerate([ - ("A", "B"), ("A", "C"), ("A", "D"), - ("B", "C"), ("B", "D"), ("C", "D"), - ]) + for i, (src, tgt) in enumerate( + [ + ("A", "B"), + ("A", "C"), + ("A", "D"), + ("B", "C"), + ("B", "D"), + ("C", "D"), + ] + ) ] state = GraphState("K4", vertices, edges, directed=False) result = GraphAnalyzer.analyze(state, "bridges", {}) @@ -696,10 +696,16 @@ def test_k4_not_bipartite(self) -> None: vertices = [GraphVertexDescriptor(v) for v in ["A", "B", "C", "D"]] edges = [ GraphEdgeDescriptor(f"e{i}", src, tgt) - for i, (src, tgt) in enumerate([ - ("A", "B"), ("A", "C"), ("A", "D"), - ("B", "C"), ("B", "D"), ("C", "D"), - ]) + for i, (src, tgt) in enumerate( + [ + ("A", "B"), + ("A", "C"), + ("A", "D"), + ("B", "C"), + ("B", "D"), + ("C", "D"), + ] + ) ] state = GraphState("K4", vertices, edges, directed=False) result = GraphAnalyzer.analyze(state, "bipartite", {}) @@ -873,4 +879,3 @@ def test_point_on_hull_boundary(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_graph_layout.py b/static/client/client_tests/test_graph_layout.py index 092d9de1..7bb779af 100644 --- a/static/client/client_tests/test_graph_layout.py +++ b/static/client/client_tests/test_graph_layout.py @@ -19,7 +19,6 @@ class TestGraphLayout(unittest.TestCase): - def setUp(self) -> None: self.box = {"x": 0.0, "y": 0.0, "width": 100.0, "height": 100.0} @@ -79,7 +78,7 @@ def test_grid_layout_two_nodes(self) -> None: bx, by = positions["B"] # Nodes should be separated - dist = math.sqrt((ax - bx)**2 + (ay - by)**2) + dist = math.sqrt((ax - bx) ** 2 + (ay - by) ** 2) self.assertGreater(dist, 5.0) def test_grid_layout_path(self) -> None: @@ -121,8 +120,12 @@ def test_grid_layout_square_cycle(self) -> None: def test_grid_layout_k4(self) -> None: """K4 (complete graph on 4) is planar and should work.""" edges = [ - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), - Edge("B", "C"), Edge("B", "D"), Edge("C", "D"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("B", "C"), + Edge("B", "D"), + Edge("C", "D"), ] positions = _grid_layout(["A", "B", "C", "D"], edges, self.box) self.assertEqual(len(positions), 4) @@ -131,9 +134,15 @@ def test_grid_layout_k5_non_planar(self) -> None: """K5 is non-planar but should still produce valid layout (via fallback).""" vertices = ["A", "B", "C", "D", "E"] edges = [ - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), Edge("A", "E"), - Edge("B", "C"), Edge("B", "D"), Edge("B", "E"), - Edge("C", "D"), Edge("C", "E"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("A", "E"), + Edge("B", "C"), + Edge("B", "D"), + Edge("B", "E"), + Edge("C", "D"), + Edge("C", "E"), Edge("D", "E"), ] positions = _grid_layout(vertices, edges, self.box) @@ -163,8 +172,8 @@ def test_grid_layout_no_vertex_overlap(self) -> None: pos_list = list(positions.values()) for i, p1 in enumerate(pos_list): - for p2 in pos_list[i + 1:]: - dist = math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2) + for p2 in pos_list[i + 1 :]: + dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) self.assertGreater(dist, 0.01, "Vertices overlap") def test_grid_layout_fits_box(self) -> None: @@ -234,9 +243,15 @@ def test_is_planar_k5_detection(self) -> None: """K5 should be detected as non-planar (too many edges).""" vertices = ["A", "B", "C", "D", "E"] edges = [ - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), Edge("A", "E"), - Edge("B", "C"), Edge("B", "D"), Edge("B", "E"), - Edge("C", "D"), Edge("C", "E"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("A", "E"), + Edge("B", "C"), + Edge("B", "D"), + Edge("B", "E"), + Edge("C", "D"), + Edge("C", "E"), Edge("D", "E"), ] is_planar, embedding = _is_planar(vertices, edges) @@ -260,9 +275,15 @@ def test_grid_layout_two_squares_bridge_no_crossings(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # Square 1: A-B-C-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), # Square 2: E-F-G-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), # Bridge: D-E Edge("D", "E"), ] @@ -276,15 +297,23 @@ def test_grid_layout_two_squares_with_caps_no_crossings(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] edges = [ # Square 1: A-B-C-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), # Square 2: E-F-G-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), # Bridge: D-E Edge("D", "E"), # Cap I connects to B and C - Edge("I", "B"), Edge("I", "C"), + Edge("I", "B"), + Edge("I", "C"), # Cap J connects to F and G - Edge("J", "F"), Edge("J", "G"), + Edge("J", "F"), + Edge("J", "G"), ] positions = _grid_layout(vertices, edges, self.box) @@ -297,12 +326,18 @@ def test_grid_layout_two_k4_bridge_no_crossings(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # K4 #1: A-B-C-D (all connected) - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), - Edge("B", "C"), Edge("B", "D"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("B", "C"), + Edge("B", "D"), Edge("C", "D"), # K4 #2: E-F-G-H (all connected) - Edge("E", "F"), Edge("E", "G"), Edge("E", "H"), - Edge("F", "G"), Edge("F", "H"), + Edge("E", "F"), + Edge("E", "G"), + Edge("E", "H"), + Edge("F", "G"), + Edge("F", "H"), Edge("G", "H"), # Bridge: D-E Edge("D", "E"), @@ -331,9 +366,15 @@ def test_grid_layout_two_squares_bridge_no_overlaps(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # Square 1: A-B-C-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), # Square 2: E-F-G-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), # Bridge: D-E Edge("D", "E"), ] @@ -347,15 +388,23 @@ def test_grid_layout_two_squares_with_caps_no_overlaps(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] edges = [ # Square 1: A-B-C-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), # Square 2: E-F-G-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), # Bridge: D-E Edge("D", "E"), # Cap I connects to B and C - Edge("I", "B"), Edge("I", "C"), + Edge("I", "B"), + Edge("I", "C"), # Cap J connects to F and G - Edge("J", "F"), Edge("J", "G"), + Edge("J", "F"), + Edge("J", "G"), ] positions = _grid_layout(vertices, edges, self.box) @@ -367,12 +416,18 @@ def test_grid_layout_two_k4_bridge_no_overlaps(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # K4 #1: A-B-C-D (all connected) - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), - Edge("B", "C"), Edge("B", "D"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("B", "C"), + Edge("B", "D"), Edge("C", "D"), # K4 #2: E-F-G-H (all connected) - Edge("E", "F"), Edge("E", "G"), Edge("E", "H"), - Edge("F", "G"), Edge("F", "H"), + Edge("E", "F"), + Edge("E", "G"), + Edge("E", "H"), + Edge("F", "G"), + Edge("F", "H"), Edge("G", "H"), # Bridge: D-E Edge("D", "E"), @@ -387,9 +442,15 @@ def test_grid_layout_k5_no_overlaps(self) -> None: vertices = ["A", "B", "C", "D", "E"] # K5: every vertex connects to every other edges = [ - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), Edge("A", "E"), - Edge("B", "C"), Edge("B", "D"), Edge("B", "E"), - Edge("C", "D"), Edge("C", "E"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("A", "E"), + Edge("B", "C"), + Edge("B", "D"), + Edge("B", "E"), + Edge("C", "D"), + Edge("C", "E"), Edge("D", "E"), ] positions = _grid_layout(vertices, edges, self.box) @@ -403,7 +464,10 @@ def test_grid_layout_line_graph_no_overlaps(self) -> None: # A - B - C - D - E (a path that could be collinear) vertices = ["A", "B", "C", "D", "E"] edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "E"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "E"), ] positions = _grid_layout(vertices, edges, self.box) @@ -415,8 +479,11 @@ def test_grid_layout_star_graph_no_overlaps(self) -> None: # Center connects to 5 outer vertices vertices = ["Center", "A", "B", "C", "D", "E"] edges = [ - Edge("Center", "A"), Edge("Center", "B"), Edge("Center", "C"), - Edge("Center", "D"), Edge("Center", "E"), + Edge("Center", "A"), + Edge("Center", "B"), + Edge("Center", "C"), + Edge("Center", "D"), + Edge("Center", "E"), ] positions = _grid_layout(vertices, edges, self.box) @@ -428,8 +495,11 @@ def test_grid_layout_triangle_with_extensions_no_overlaps(self) -> None: # Triangle A-B-C with D connected to A and B (could create D on edge A-B) vertices = ["A", "B", "C", "D"] edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "A"), - Edge("D", "A"), Edge("D", "B"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "A"), + Edge("D", "A"), + Edge("D", "B"), ] positions = _grid_layout(vertices, edges, self.box) @@ -448,17 +518,22 @@ def test_grid_layout_simple_square_orthogonality(self) -> None: orthogonal, total = GraphUtils.count_orthogonal_edges(edges, positions) # All 4 edges should be orthogonal - self.assertEqual(orthogonal, total, - f"Simple square: {orthogonal}/{total} edges orthogonal (expected all)") + self.assertEqual(orthogonal, total, f"Simple square: {orthogonal}/{total} edges orthogonal (expected all)") def test_grid_layout_two_squares_bridge_orthogonality(self) -> None: """Two square cycles connected by bridge should have all edges orthogonal.""" vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # Square 1: A-B-C-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), # Square 2: E-F-G-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), # Bridge: D-E Edge("D", "E"), ] @@ -466,23 +541,32 @@ def test_grid_layout_two_squares_bridge_orthogonality(self) -> None: orthogonal, total = GraphUtils.count_orthogonal_edges(edges, positions) # All 9 edges should be orthogonal - self.assertEqual(orthogonal, total, - f"Two squares with bridge: {orthogonal}/{total} edges orthogonal (expected all)") + self.assertEqual( + orthogonal, total, f"Two squares with bridge: {orthogonal}/{total} edges orthogonal (expected all)" + ) def test_grid_layout_two_squares_with_caps_orthogonality(self) -> None: """Two squares with caps - most edges should be orthogonal.""" vertices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] edges = [ # Square 1: A-B-C-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), # Square 2: E-F-G-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), # Bridge: D-E Edge("D", "E"), # Cap I connects to B and C - Edge("I", "B"), Edge("I", "C"), + Edge("I", "B"), + Edge("I", "C"), # Cap J connects to F and G - Edge("J", "F"), Edge("J", "G"), + Edge("J", "F"), + Edge("J", "G"), ] positions = _grid_layout(vertices, edges, self.box) @@ -490,8 +574,11 @@ def test_grid_layout_two_squares_with_caps_orthogonality(self) -> None: # At minimum, the cycle edges (8) + bridge (1) should be orthogonal = 9 # The cap edges (4) may or may not be orthogonal depending on layout min_orthogonal = 9 # The two squares + bridge - self.assertGreaterEqual(orthogonal, min_orthogonal, - f"Two squares with caps: {orthogonal}/{total} orthogonal (expected at least {min_orthogonal})") + self.assertGreaterEqual( + orthogonal, + min_orthogonal, + f"Two squares with caps: {orthogonal}/{total} orthogonal (expected at least {min_orthogonal})", + ) def test_grid_layout_two_k4_bridge_orthogonality(self) -> None: """Two K4 (complete squares) connected by bridge - at least perimeter edges orthogonal. @@ -502,11 +589,19 @@ def test_grid_layout_two_k4_bridge_orthogonality(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # K4 #1: A-B-C-D with diagonals A-C, B-D - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), - Edge("A", "C"), Edge("B", "D"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), + Edge("A", "C"), + Edge("B", "D"), # K4 #2: E-F-G-H with diagonals E-G, F-H - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), - Edge("E", "G"), Edge("F", "H"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), + Edge("E", "G"), + Edge("F", "H"), # Bridge: D-E Edge("D", "E"), ] @@ -517,8 +612,11 @@ def test_grid_layout_two_k4_bridge_orthogonality(self) -> None: # K4's diagonals create constraints that may force some perimeter edges non-orthogonal # Expect at least 5 orthogonal (realistic for this complex structure) min_orthogonal = 5 - self.assertGreaterEqual(orthogonal, min_orthogonal, - f"Two K4 with bridge: {orthogonal}/{total} orthogonal (expected at least {min_orthogonal})") + self.assertGreaterEqual( + orthogonal, + min_orthogonal, + f"Two K4 with bridge: {orthogonal}/{total} orthogonal (expected at least {min_orthogonal})", + ) # ------------------------------------------------------------------ # Edge length uniformity tests @@ -532,15 +630,22 @@ def test_grid_layout_simple_square_edge_lengths(self) -> None: same_count, total, _ = GraphUtils.count_edges_with_same_length(edges, positions) # All 4 edges should have the same length (it's a square) - self.assertEqual(same_count, total, - f"Simple square: {same_count}/{total} edges have same length (expected all)") + self.assertEqual( + same_count, total, f"Simple square: {same_count}/{total} edges have same length (expected all)" + ) def test_grid_layout_two_squares_bridge_edge_lengths(self) -> None: """Two squares with bridge - square edges should have uniform length.""" vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), Edge("D", "E"), ] positions = _grid_layout(vertices, edges, self.box) @@ -548,8 +653,9 @@ def test_grid_layout_two_squares_bridge_edge_lengths(self) -> None: same_count, total, _ = GraphUtils.count_edges_with_same_length(edges, positions, tolerance=0.15) # The 8 square edges should have the same length # The bridge may be different, so expect at least 8 edges with same length - self.assertGreaterEqual(same_count, 8, - f"Two squares bridge: {same_count}/{total} edges have same length (expected at least 8)") + self.assertGreaterEqual( + same_count, 8, f"Two squares bridge: {same_count}/{total} edges have same length (expected at least 8)" + ) def test_grid_layout_triangle_edge_lengths(self) -> None: """Triangle should have reasonably uniform edge lengths.""" @@ -559,8 +665,9 @@ def test_grid_layout_triangle_edge_lengths(self) -> None: uniformity = GraphUtils.edge_length_uniformity_ratio(edges, positions, tolerance=0.2) # At least 2/3 of edges should have similar length - self.assertGreaterEqual(uniformity, 0.66, - f"Triangle: {uniformity*100:.0f}% edge length uniformity (expected at least 66%)") + self.assertGreaterEqual( + uniformity, 0.66, f"Triangle: {uniformity * 100:.0f}% edge length uniformity (expected at least 66%)" + ) def test_grid_layout_line_graph_edge_lengths(self) -> None: """Line graph (path) should have uniform edge lengths.""" @@ -570,54 +677,74 @@ def test_grid_layout_line_graph_edge_lengths(self) -> None: same_count, total, _ = GraphUtils.count_edges_with_same_length(edges, positions, tolerance=0.15) # All edges in a path should have similar length - self.assertGreaterEqual(same_count, total - 1, - f"Line graph: {same_count}/{total} edges have same length (expected at least {total-1})") + self.assertGreaterEqual( + same_count, + total - 1, + f"Line graph: {same_count}/{total} edges have same length (expected at least {total - 1})", + ) def test_grid_layout_star_graph_edge_lengths(self) -> None: """Star graph should have uniform edge lengths.""" vertices = ["center", "A", "B", "C", "D"] edges = [ - Edge("center", "A"), Edge("center", "B"), - Edge("center", "C"), Edge("center", "D"), + Edge("center", "A"), + Edge("center", "B"), + Edge("center", "C"), + Edge("center", "D"), ] positions = _grid_layout(vertices, edges, self.box) uniformity = GraphUtils.edge_length_uniformity_ratio(edges, positions, tolerance=0.2) # All edges from center should have similar length in a good star layout - self.assertGreaterEqual(uniformity, 0.75, - f"Star graph: {uniformity*100:.0f}% edge length uniformity (expected at least 75%)") + self.assertGreaterEqual( + uniformity, 0.75, f"Star graph: {uniformity * 100:.0f}% edge length uniformity (expected at least 75%)" + ) def test_grid_layout_hexagon_edge_lengths(self) -> None: """Hexagon cycle should have uniform edge lengths.""" vertices = ["A", "B", "C", "D", "E", "F"] edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), - Edge("D", "E"), Edge("E", "F"), Edge("F", "A"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "E"), + Edge("E", "F"), + Edge("F", "A"), ] positions = _grid_layout(vertices, edges, self.box) same_count, total, _ = GraphUtils.count_edges_with_same_length(edges, positions, tolerance=0.2) # Most edges in a cycle should have similar length - self.assertGreaterEqual(same_count, 4, - f"Hexagon: {same_count}/{total} edges have same length (expected at least 4)") + self.assertGreaterEqual( + same_count, 4, f"Hexagon: {same_count}/{total} edges have same length (expected at least 4)" + ) def test_grid_layout_two_squares_with_caps_edge_lengths(self) -> None: """Two squares with caps - cycle edges should have similar lengths.""" vertices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), Edge("D", "E"), - Edge("I", "B"), Edge("I", "C"), - Edge("J", "F"), Edge("J", "G"), + Edge("I", "B"), + Edge("I", "C"), + Edge("J", "F"), + Edge("J", "G"), ] positions = _grid_layout(vertices, edges, self.box) same_count, total, _ = GraphUtils.count_edges_with_same_length(edges, positions, tolerance=0.2) # Caps create triangular connections which may have different lengths # Expect at least 4 edges (one square) to have similar lengths - self.assertGreaterEqual(same_count, 4, - f"Two squares with caps: {same_count}/{total} edges have same length (expected at least 4)") + self.assertGreaterEqual( + same_count, 4, f"Two squares with caps: {same_count}/{total} edges have same length (expected at least 4)" + ) def test_grid_layout_edge_length_variance_simple_square(self) -> None: """Simple square should have low edge length variance.""" @@ -630,10 +757,11 @@ def test_grid_layout_edge_length_variance_simple_square(self) -> None: # Normalize by average length squared for relative comparison lengths = GraphUtils.get_edge_lengths(edges, positions) avg_length = sum(lengths) / len(lengths) if lengths else 1.0 - relative_variance = variance / (avg_length ** 2) if avg_length > 0 else 0.0 + relative_variance = variance / (avg_length**2) if avg_length > 0 else 0.0 - self.assertLess(relative_variance, 0.05, - f"Simple square: relative variance {relative_variance:.4f} (expected < 0.05)") + self.assertLess( + relative_variance, 0.05, f"Simple square: relative variance {relative_variance:.4f} (expected < 0.05)" + ) # ------------------------------------------------------------------ # Tree layout @@ -791,11 +919,11 @@ def test_force_layout_connected_closer_than_unconnected(self) -> None: cx, cy = positions["C"] # Distance A-B (connected) - dist_ab = math.sqrt((ax - bx)**2 + (ay - by)**2) + dist_ab = math.sqrt((ax - bx) ** 2 + (ay - by) ** 2) # Distance A-C (not connected) - dist_ac = math.sqrt((ax - cx)**2 + (ay - cy)**2) + dist_ac = math.sqrt((ax - cx) ** 2 + (ay - cy) ** 2) # Distance B-C (not connected) - dist_bc = math.sqrt((bx - cx)**2 + (by - cy)**2) + dist_bc = math.sqrt((bx - cx) ** 2 + (by - cy) ** 2) # Connected pair should be closer than at least one unconnected pair self.assertLess(dist_ab, max(dist_ac, dist_bc)) @@ -818,7 +946,7 @@ def test_force_layout_clusters_separate(self) -> None: c2x, c2y = (cx + dx) / 2, (cy + dy) / 2 # Cluster centers should be separated - cluster_dist = math.sqrt((c1x - c2x)**2 + (c1y - c2y)**2) + cluster_dist = math.sqrt((c1x - c2x) ** 2 + (c1y - c2y) ** 2) self.assertGreater(cluster_dist, 20.0) def test_force_layout_two_complete_squares_with_bridge(self) -> None: @@ -832,11 +960,19 @@ def test_force_layout_two_complete_squares_with_bridge(self) -> None: vertices = ["A", "B", "C", "D", "E", "F", "G", "H"] edges = [ # Cluster 1: complete square - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), - Edge("A", "C"), Edge("B", "D"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), + Edge("A", "C"), + Edge("B", "D"), # Cluster 2: complete square - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), - Edge("E", "G"), Edge("F", "H"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), + Edge("E", "G"), + Edge("F", "H"), # Bridge Edge("D", "E"), ] @@ -853,15 +989,15 @@ def test_force_layout_two_complete_squares_with_bridge(self) -> None: self.assertLess(y, 570.0, f"{vid} y too large: {y}") # Cluster 1 center (A,B,C,D) - c1_x = sum(positions[v][0] for v in ["A","B","C","D"]) / 4 - c1_y = sum(positions[v][1] for v in ["A","B","C","D"]) / 4 + c1_x = sum(positions[v][0] for v in ["A", "B", "C", "D"]) / 4 + c1_y = sum(positions[v][1] for v in ["A", "B", "C", "D"]) / 4 # Cluster 2 center (E,F,G,H) - c2_x = sum(positions[v][0] for v in ["E","F","G","H"]) / 4 - c2_y = sum(positions[v][1] for v in ["E","F","G","H"]) / 4 + c2_x = sum(positions[v][0] for v in ["E", "F", "G", "H"]) / 4 + c2_y = sum(positions[v][1] for v in ["E", "F", "G", "H"]) / 4 # Cluster centers should be well separated - cluster_dist = math.sqrt((c1_x - c2_x)**2 + (c1_y - c2_y)**2) + cluster_dist = math.sqrt((c1_x - c2_x) ** 2 + (c1_y - c2_y) ** 2) self.assertGreater(cluster_dist, 100.0, f"Cluster separation too small: {cluster_dist}") # Intra-cluster distances should be smaller than inter-cluster distances @@ -869,10 +1005,10 @@ def avg_dist(v_list: list[str]) -> float: total = 0.0 count = 0 for i, v1 in enumerate(v_list): - for v2 in v_list[i+1:]: + for v2 in v_list[i + 1 :]: x1, y1 = positions[v1] x2, y2 = positions[v2] - total += math.sqrt((x1-x2)**2 + (y1-y2)**2) + total += math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) count += 1 return total / count if count > 0 else 0 @@ -965,9 +1101,12 @@ def test_infer_root_fallback_for_cycle(self) -> None: def test_infer_root_binary_tree(self) -> None: """Infer root for a binary tree structure.""" edges = [ - Edge("v0", "v1"), Edge("v0", "v2"), - Edge("v1", "v3"), Edge("v1", "v4"), - Edge("v2", "v5"), Edge("v2", "v6"), + Edge("v0", "v1"), + Edge("v0", "v2"), + Edge("v1", "v3"), + Edge("v1", "v4"), + Edge("v2", "v5"), + Edge("v2", "v6"), ] root = _infer_root(["v0", "v1", "v2", "v3", "v4", "v5", "v6"], edges) self.assertEqual(root, "v0") @@ -1018,9 +1157,12 @@ def test_is_tree_structure_simple_tree(self) -> None: def test_is_tree_structure_binary_tree(self) -> None: """Detect binary tree structure.""" edges = [ - Edge("v0", "v1"), Edge("v0", "v2"), - Edge("v1", "v3"), Edge("v1", "v4"), - Edge("v2", "v5"), Edge("v2", "v6"), + Edge("v0", "v1"), + Edge("v0", "v2"), + Edge("v1", "v3"), + Edge("v1", "v4"), + Edge("v2", "v5"), + Edge("v2", "v6"), ] self.assertTrue(_is_tree_structure(["v0", "v1", "v2", "v3", "v4", "v5", "v6"], edges)) @@ -1076,21 +1218,13 @@ def _assert_all_vertices_in_box( """Assert all vertex positions are within the bounding box.""" for vid, pos in positions.items(): x, y = pos[0], pos[1] - self.assertGreaterEqual( - x, box["x"], - f"{msg} Vertex {vid} x={x} is less than box x={box['x']}" - ) + self.assertGreaterEqual(x, box["x"], f"{msg} Vertex {vid} x={x} is less than box x={box['x']}") self.assertLessEqual( - x, box["x"] + box["width"], - f"{msg} Vertex {vid} x={x} exceeds box right={box['x'] + box['width']}" - ) - self.assertGreaterEqual( - y, box["y"], - f"{msg} Vertex {vid} y={y} is less than box y={box['y']}" + x, box["x"] + box["width"], f"{msg} Vertex {vid} x={x} exceeds box right={box['x'] + box['width']}" ) + self.assertGreaterEqual(y, box["y"], f"{msg} Vertex {vid} y={y} is less than box y={box['y']}") self.assertLessEqual( - y, box["y"] + box["height"], - f"{msg} Vertex {vid} y={y} exceeds box top={box['y'] + box['height']}" + y, box["y"] + box["height"], f"{msg} Vertex {vid} y={y} exceeds box top={box['y'] + box['height']}" ) # ------------------------------------------------------------------ @@ -1177,8 +1311,14 @@ def test_grid_layout_visibility_offset_box(self) -> None: """Grid layout vertices should be within offset box.""" box = {"x": -500.0, "y": -300.0, "width": 1000.0, "height": 600.0} edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), Edge("D", "E"), ] positions = layout_vertices( @@ -1209,8 +1349,11 @@ def test_tree_layout_visibility_offset_box(self) -> None: """Tree layout vertices should be within offset box.""" box = {"x": -300.0, "y": -250.0, "width": 600.0, "height": 500.0} edges = [ - Edge("R", "A"), Edge("R", "B"), Edge("R", "C"), - Edge("A", "D"), Edge("A", "E"), + Edge("R", "A"), + Edge("R", "B"), + Edge("R", "C"), + Edge("A", "D"), + Edge("A", "E"), Edge("B", "F"), ] positions = layout_vertices( @@ -1262,9 +1405,15 @@ def test_grid_layout_visibility_k5_nonplanar(self) -> None: box = {"x": -400.0, "y": -300.0, "width": 800.0, "height": 600.0} vertices = ["A", "B", "C", "D", "E"] edges = [ - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), Edge("A", "E"), - Edge("B", "C"), Edge("B", "D"), Edge("B", "E"), - Edge("C", "D"), Edge("C", "E"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("A", "E"), + Edge("B", "C"), + Edge("B", "D"), + Edge("B", "E"), + Edge("C", "D"), + Edge("C", "E"), Edge("D", "E"), ] positions = layout_vertices( @@ -1281,8 +1430,14 @@ def test_grid_layout_visibility_two_squares_bridge(self) -> None: """Two squares connected by bridge should be within box.""" box = {"x": -500.0, "y": -400.0, "width": 1000.0, "height": 800.0} edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "A"), - Edge("E", "F"), Edge("F", "G"), Edge("G", "H"), Edge("H", "E"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "D"), + Edge("D", "A"), + Edge("E", "F"), + Edge("F", "G"), + Edge("G", "H"), + Edge("H", "E"), Edge("D", "E"), ] positions = layout_vertices( @@ -1300,8 +1455,10 @@ def test_tree_layout_visibility_deep_tree(self) -> None: box = {"x": -300.0, "y": -400.0, "width": 600.0, "height": 800.0} # Create a tree with depth 5 edges = [ - Edge("L0", "L1a"), Edge("L0", "L1b"), - Edge("L1a", "L2a"), Edge("L1a", "L2b"), + Edge("L0", "L1a"), + Edge("L0", "L1b"), + Edge("L1a", "L2a"), + Edge("L1a", "L2b"), Edge("L1b", "L2c"), Edge("L2a", "L3a"), Edge("L3a", "L4a"), @@ -1339,8 +1496,10 @@ def test_hierarchical_layout_visibility(self) -> None: """Hierarchical layout vertices should be within box.""" box = {"x": -200.0, "y": -300.0, "width": 400.0, "height": 600.0} edges = [ - Edge("CEO", "VP1"), Edge("CEO", "VP2"), - Edge("VP1", "M1"), Edge("VP1", "M2"), + Edge("CEO", "VP1"), + Edge("CEO", "VP2"), + Edge("VP1", "M1"), + Edge("VP1", "M2"), Edge("VP2", "M3"), ] positions = layout_vertices( @@ -1360,9 +1519,12 @@ def test_binary_tree_3_layers_placement_box(self) -> None: box = {"x": -500.0, "y": -350.0, "width": 300.0, "height": 325.0} vertices = ["R", "L1", "L2", "L1A", "L1B", "L2A", "L2B"] edges = [ - Edge("R", "L1"), Edge("R", "L2"), - Edge("L1", "L1A"), Edge("L1", "L1B"), - Edge("L2", "L2A"), Edge("L2", "L2B"), + Edge("R", "L1"), + Edge("R", "L2"), + Edge("L1", "L1A"), + Edge("L1", "L1B"), + Edge("L2", "L2A"), + Edge("L2", "L2B"), ] positions = layout_vertices( vertices, @@ -1384,4 +1546,3 @@ def test_binary_tree_3_layers_placement_box(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_graph_manager.py b/static/client/client_tests/test_graph_manager.py index 9fe662c5..559cc3ac 100644 --- a/static/client/client_tests/test_graph_manager.py +++ b/static/client/client_tests/test_graph_manager.py @@ -44,6 +44,7 @@ def setUp(self) -> None: ) self.points_created: List[Point] = [] + def create_point(x: float, y: float, name: str = "", color: str = None, extra_graphics: bool = True) -> Point: p = Point(x, y, name=name if name else f"P{len(self.points_created)}") self.points_created.append(p) @@ -57,7 +58,10 @@ def create_point(x: float, y: float, name: str = "", color: str = None, extra_gr ) self.segments_created: List[Segment] = [] - def create_segment_from_points(p1: Point, p2: Point, name: str = "", color: str = None, label_text: str = "", label_visible: bool = False) -> Segment: + + def create_segment_from_points( + p1: Point, p2: Point, name: str = "", color: str = None, label_text: str = "", label_visible: bool = False + ) -> Segment: seg = Segment(p1, p2, color=color or "#000000") self.segments_created.append(seg) self.drawables.add(seg) @@ -68,7 +72,9 @@ def create_segment_from_points(p1: Point, p2: Point, name: str = "", color: str create_segment_from_points=create_segment_from_points, delete_segment=SimpleMock(return_value=True), create_segment=lambda *args, **kwargs: create_segment_from_points( - Point(args[0], args[1]), Point(args[2], args[3]), **{k: v for k, v in kwargs.items() if k != 'extra_graphics'} + Point(args[0], args[1]), + Point(args[2], args[3]), + **{k: v for k, v in kwargs.items() if k != "extra_graphics"}, ), ) diff --git a/static/client/client_tests/test_graph_utils.py b/static/client/client_tests/test_graph_utils.py index 3b6f2e61..fe06fe8c 100644 --- a/static/client/client_tests/test_graph_utils.py +++ b/static/client/client_tests/test_graph_utils.py @@ -12,7 +12,6 @@ class TestGraph(unittest.TestCase): - # ------------------------------------------------------------------ # Edge # ------------------------------------------------------------------ @@ -339,8 +338,12 @@ def test_is_simple_cycle_rejects_t_structure(self) -> None: def test_is_simple_cycle_rejects_disconnected(self) -> None: edges = [ - Edge("A", "B"), Edge("B", "C"), Edge("C", "A"), - Edge("D", "E"), Edge("E", "F"), Edge("F", "D"), + Edge("A", "B"), + Edge("B", "C"), + Edge("C", "A"), + Edge("D", "E"), + Edge("E", "F"), + Edge("F", "D"), ] self.assertFalse(GraphUtils.is_simple_cycle(edges)) @@ -896,8 +899,12 @@ def test_count_edge_crossings_crossing_diagonals(self) -> None: def test_count_edge_crossings_k4_complete(self) -> None: """K4 with vertices at corners has 1 crossing (diagonals).""" edges = [ - Edge("A", "B"), Edge("A", "C"), Edge("A", "D"), - Edge("B", "C"), Edge("B", "D"), Edge("C", "D"), + Edge("A", "B"), + Edge("A", "C"), + Edge("A", "D"), + Edge("B", "C"), + Edge("B", "D"), + Edge("C", "D"), ] positions = {"A": (0.0, 0.0), "B": (1.0, 0.0), "C": (1.0, 1.0), "D": (0.0, 1.0)} # A-C and B-D cross @@ -1220,10 +1227,7 @@ def test_edge_length_uniformity_ratio_half(self) -> None: """Half same length gives ratio 0.5.""" edges = [Edge("A", "B"), Edge("B", "C"), Edge("C", "D"), Edge("D", "E")] # A-B: 1, B-C: 1, C-D: 2, D-E: 2 - positions = { - "A": (0.0, 0.0), "B": (1.0, 0.0), "C": (2.0, 0.0), - "D": (4.0, 0.0), "E": (6.0, 0.0) - } + positions = {"A": (0.0, 0.0), "B": (1.0, 0.0), "C": (2.0, 0.0), "D": (4.0, 0.0), "E": (6.0, 0.0)} ratio = GraphUtils.edge_length_uniformity_ratio(edges, positions, tolerance=0.1) self.assertAlmostEqual(ratio, 0.5, places=5) diff --git a/static/client/client_tests/test_heptagon.py b/static/client/client_tests/test_heptagon.py index 51266a16..fa09ecd3 100644 --- a/static/client/client_tests/test_heptagon.py +++ b/static/client/client_tests/test_heptagon.py @@ -130,4 +130,3 @@ def test_is_irregular_helper(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_hexagon.py b/static/client/client_tests/test_hexagon.py index 75a414a3..c7cb574d 100644 --- a/static/client/client_tests/test_hexagon.py +++ b/static/client/client_tests/test_hexagon.py @@ -123,4 +123,3 @@ def test_is_irregular_helper(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_image_attachment.py b/static/client/client_tests/test_image_attachment.py index d9d1d135..ac7c11cc 100644 --- a/static/client/client_tests/test_image_attachment.py +++ b/static/client/client_tests/test_image_attachment.py @@ -101,11 +101,7 @@ def test_append_image(self) -> None: def test_append_multiple_images(self) -> None: """Test appending multiple images.""" - images = [ - "data:image/png;base64,img1", - "data:image/jpeg;base64,img2", - "data:image/png;base64,img3" - ] + images = ["data:image/png;base64,img1", "data:image/jpeg;base64,img2", "data:image/png;base64,img3"] for img in images: self.ai._attached_images.append(img) self.assertEqual(len(self.ai._attached_images), 3) @@ -116,7 +112,7 @@ def test_remove_image_by_index(self) -> None: self.ai._attached_images = [ "data:image/png;base64,img1", "data:image/png;base64,img2", - "data:image/png;base64,img3" + "data:image/png;base64,img3", ] self.ai._attached_images.pop(1) self.assertEqual(len(self.ai._attached_images), 2) @@ -125,10 +121,7 @@ def test_remove_image_by_index(self) -> None: def test_clear_images(self) -> None: """Test clearing all images.""" - self.ai._attached_images = [ - "data:image/png;base64,img1", - "data:image/png;base64,img2" - ] + self.ai._attached_images = ["data:image/png;base64,img1", "data:image/png;base64,img2"] self.ai._attached_images = [] self.assertEqual(len(self.ai._attached_images), 0) @@ -180,17 +173,14 @@ class TestPayloadGeneration(unittest.TestCase): def test_payload_includes_attached_images(self) -> None: """Test that payload includes attached_images array.""" - attached_images = [ - "data:image/png;base64,img1", - "data:image/jpeg;base64,img2" - ] + attached_images = ["data:image/png;base64,img1", "data:image/jpeg;base64,img2"] # Simulate payload creation payload = { "canvas_state": {}, "user_message": "What do you see?", "use_vision": False, - "attached_images": attached_images + "attached_images": attached_images, } self.assertIn("attached_images", payload) @@ -230,7 +220,7 @@ def test_payload_with_vision_and_images(self) -> None: "canvas_state": {}, "user_message": "Compare canvas and image", "use_vision": True, # Vision enabled - "attached_images": attached_images # Also has attached images + "attached_images": attached_images, # Also has attached images } self.assertTrue(payload["use_vision"]) @@ -405,6 +395,7 @@ def test_modal_display_style_hidden(self) -> None: def test_backdrop_click_detection(self) -> None: """Test backdrop click is detected by element id.""" + class MockEvent: def __init__(self, target_id: str) -> None: self.target = MagicMock() @@ -426,41 +417,23 @@ class TestMessageElementWithImages(unittest.TestCase): def test_message_with_images_parameter(self) -> None: """Test message creation accepts images parameter.""" + # Simulate _create_message_element signature def create_message_element( - sender: str, - message: str, - message_type: str = "normal", - images: Optional[List[str]] = None + sender: str, message: str, message_type: str = "normal", images: Optional[List[str]] = None ) -> Dict[str, Any]: - return { - "sender": sender, - "message": message, - "type": message_type, - "images": images - } - - result = create_message_element( - "User", - "Hello", - images=["data:image/png;base64,img1"] - ) + return {"sender": sender, "message": message, "type": message_type, "images": images} + + result = create_message_element("User", "Hello", images=["data:image/png;base64,img1"]) self.assertEqual(result["images"], ["data:image/png;base64,img1"]) def test_message_without_images(self) -> None: """Test message creation without images.""" + def create_message_element( - sender: str, - message: str, - message_type: str = "normal", - images: Optional[List[str]] = None + sender: str, message: str, message_type: str = "normal", images: Optional[List[str]] = None ) -> Dict[str, Any]: - return { - "sender": sender, - "message": message, - "type": message_type, - "images": images - } + return {"sender": sender, "message": message, "type": message_type, "images": images} result = create_message_element("User", "Hello") self.assertIsNone(result["images"]) diff --git a/static/client/client_tests/test_intersections.py b/static/client/client_tests/test_intersections.py index 228567e9..c30bcb8c 100644 --- a/static/client/client_tests/test_intersections.py +++ b/static/client/client_tests/test_intersections.py @@ -20,13 +20,7 @@ class TestIntersections(unittest.TestCase): - - def _assert_point_near( - self, - actual: tuple[float, float], - expected: tuple[float, float], - places: int = 5 - ) -> None: + def _assert_point_near(self, actual: tuple[float, float], expected: tuple[float, float], places: int = 5) -> None: self.assertAlmostEqual(actual[0], expected[0], places=places) self.assertAlmostEqual(actual[1], expected[1], places=places) @@ -148,7 +142,7 @@ def test_circle_ellipse_intersection(self) -> None: result = circle_ellipse_intersection(circle, ellipse) self.assertGreaterEqual(len(result), 2) for point in result: - dist_to_circle = math.sqrt(point[0]**2 + point[1]**2) + dist_to_circle = math.sqrt(point[0] ** 2 + point[1] ** 2) self.assertAlmostEqual(dist_to_circle, 1.5, places=1) def test_ellipse_ellipse_intersection(self) -> None: @@ -172,26 +166,17 @@ def test_element_element_dispatch(self) -> None: self.assertEqual(len(result), 2) def test_path_path_intersections(self) -> None: - square = CompositePath.from_points([ - (0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0), (0.0, 0.0) - ]) - diagonal = CompositePath([ - LineSegment((-1.0, 1.0), (3.0, 1.0)) - ]) + square = CompositePath.from_points([(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0), (0.0, 0.0)]) + diagonal = CompositePath([LineSegment((-1.0, 1.0), (3.0, 1.0))]) result = path_path_intersections(square, diagonal) self.assertEqual(len(result), 2) def test_path_path_no_intersections(self) -> None: - square1 = CompositePath.from_points([ - (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0), (0.0, 0.0) - ]) - square2 = CompositePath.from_points([ - (5.0, 5.0), (6.0, 5.0), (6.0, 6.0), (5.0, 6.0), (5.0, 5.0) - ]) + square1 = CompositePath.from_points([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0), (0.0, 0.0)]) + square2 = CompositePath.from_points([(5.0, 5.0), (6.0, 5.0), (6.0, 6.0), (5.0, 6.0), (5.0, 5.0)]) result = path_path_intersections(square1, square2) self.assertEqual(len(result), 0) if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_label.py b/static/client/client_tests/test_label.py index f15a2362..99f1e697 100644 --- a/static/client/client_tests/test_label.py +++ b/static/client/client_tests/test_label.py @@ -76,7 +76,9 @@ def record_archive() -> None: container = DrawablesContainer() name_generator = SimpleMock( - generate_label_name=lambda preferred: (preferred.strip() if isinstance(preferred, str) and preferred.strip() else "demo_label") + generate_label_name=lambda preferred: ( + preferred.strip() if isinstance(preferred, str) and preferred.strip() else "demo_label" + ) ) dependency_manager = SimpleMock() proxy = SimpleMock() @@ -96,9 +98,7 @@ def record_archive() -> None: self.assertEqual(len(undo_calls), 2) def test_label_name_generator_produces_letter_sequence(self) -> None: - manager, _, _ = self._make_label_manager( - name_generator_factory=lambda canvas: DrawableNameGenerator(canvas) - ) + manager, _, _ = self._make_label_manager(name_generator_factory=lambda canvas: DrawableNameGenerator(canvas)) label_one = manager.create_label(0.0, 0.0, "auto") label_two = manager.create_label(1.0, 1.0, "auto2") label_three = manager.create_label(2.0, 2.0, "auto3") @@ -108,9 +108,7 @@ def test_label_name_generator_produces_letter_sequence(self) -> None: self.assertEqual(label_three.name, "label_C") def test_label_name_generator_handles_duplicate_preferred_names(self) -> None: - manager, _, _ = self._make_label_manager( - name_generator_factory=lambda canvas: DrawableNameGenerator(canvas) - ) + manager, _, _ = self._make_label_manager(name_generator_factory=lambda canvas: DrawableNameGenerator(canvas)) custom_one = manager.create_label(0.0, 0.0, "auto", name="CustomLabel") custom_two = manager.create_label(1.0, 1.0, "auto", name="CustomLabel") @@ -355,4 +353,3 @@ def test_label_get_state_includes_render_mode(self) -> None: render_mode = args.get("render_mode") self.assertIsInstance(render_mode, dict) self.assertEqual(render_mode.get("kind"), "world") - diff --git a/static/client/client_tests/test_label_overlap_resolver.py b/static/client/client_tests/test_label_overlap_resolver.py index 939d6a17..4a335f38 100644 --- a/static/client/client_tests/test_label_overlap_resolver.py +++ b/static/client/client_tests/test_label_overlap_resolver.py @@ -23,5 +23,3 @@ def test_block_height_affects_collision(self) -> None: dy_again = resolver.get_or_place_dy("line", (0.0, 10.0, 0.0, 10.0), step=10.0) self.assertEqual(dy_again, dy) - - diff --git a/static/client/client_tests/test_linear_algebra_utils.py b/static/client/client_tests/test_linear_algebra_utils.py index 41352862..cf1f1be0 100644 --- a/static/client/client_tests/test_linear_algebra_utils.py +++ b/static/client/client_tests/test_linear_algebra_utils.py @@ -422,4 +422,3 @@ def test_diag_expression_uses_allowlist_function(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_math_functions.py b/static/client/client_tests/test_math_functions.py index 5435f1c3..488ccb03 100644 --- a/static/client/client_tests/test_math_functions.py +++ b/static/client/client_tests/test_math_functions.py @@ -11,41 +11,41 @@ class TestMathFunctions(unittest.TestCase): def setUp(self) -> None: # Mock points for use in some tests (math-space coordinates only) - self.point1 = SimpleMock(x=0, y=0, name='A') - self.point2 = SimpleMock(x=1, y=1, name='B') + self.point1 = SimpleMock(x=0, y=0, name="A") + self.point2 = SimpleMock(x=1, y=1, name="B") # Mock segment using mocked points self.segment = SimpleMock(point1=self.point1, point2=self.point2) def test_format_number_for_cartesian(self) -> None: test_cases = [ - (123456789, 6, '1.2e+8'), - (0.000123456789, 6, '0.00012'), - (123456, 6, '123456'), - (123.456, 6, '123.456'), - (0, 6, '0'), - (-123456789, 6, '-1.2e+8'), - (-0.000123456789, 6, '-0.00012'), - (-123456, 6, '-123456'), - (-123.456, 6, '-123.456'), - (1.23456789, 6, '1.23457'), - (0.000000123456789, 6, '1.2e-7'), - (123456.789, 6, '123457'), - (123.456789, 6, '123.457'), - (0.00000000000001, 6, '1e-14'), - (-1.23456789, 6, '-1.23457'), - (-0.000000123456789, 6, '-1.2e-7'), - (-123456.789, 6, '-123457'), - (-123.456789, 6, '-123.457'), - (-0.00000000000001, 6, '-1e-14'), - (123456789, 3, '1.2e+8'), - (0.000123456789, 3, '1.2e-4'), - (123456, 3, '1.2e+5'), - (123.456, 3, '123'), - (1.23456789, 3, '1.23'), - (0.000000123456789, 3, '1.2e-7'), - (123456.789, 3, '1.2e+5'), - (123.456789, 3, '123'), - (0.00000000000001, 3, '1e-14'), + (123456789, 6, "1.2e+8"), + (0.000123456789, 6, "0.00012"), + (123456, 6, "123456"), + (123.456, 6, "123.456"), + (0, 6, "0"), + (-123456789, 6, "-1.2e+8"), + (-0.000123456789, 6, "-0.00012"), + (-123456, 6, "-123456"), + (-123.456, 6, "-123.456"), + (1.23456789, 6, "1.23457"), + (0.000000123456789, 6, "1.2e-7"), + (123456.789, 6, "123457"), + (123.456789, 6, "123.457"), + (0.00000000000001, 6, "1e-14"), + (-1.23456789, 6, "-1.23457"), + (-0.000000123456789, 6, "-1.2e-7"), + (-123456.789, 6, "-123457"), + (-123.456789, 6, "-123.457"), + (-0.00000000000001, 6, "-1e-14"), + (123456789, 3, "1.2e+8"), + (0.000123456789, 3, "1.2e-4"), + (123456, 3, "1.2e+5"), + (123.456, 3, "123"), + (1.23456789, 3, "1.23"), + (0.000000123456789, 3, "1.2e-7"), + (123456.789, 3, "1.2e+5"), + (123.456789, 3, "123"), + (0.00000000000001, 3, "1e-14"), ] for i, (input, max_digits, expected) in enumerate(test_cases): with self.subTest(i=i): @@ -61,9 +61,9 @@ def test_segment_matches_coordinates(self) -> None: self.assertFalse(MathUtils.segment_matches_coordinates(self.segment, 2, 2, 3, 3)) # Incorrect coordinates def test_segment_matches_point_names(self) -> None: - self.assertTrue(MathUtils.segment_matches_point_names(self.segment, 'A', 'B')) - self.assertTrue(MathUtils.segment_matches_point_names(self.segment, 'B', 'A')) # Reverse order - self.assertFalse(MathUtils.segment_matches_point_names(self.segment, 'C', 'D')) # Incorrect names + self.assertTrue(MathUtils.segment_matches_point_names(self.segment, "A", "B")) + self.assertTrue(MathUtils.segment_matches_point_names(self.segment, "B", "A")) # Reverse order + self.assertFalse(MathUtils.segment_matches_point_names(self.segment, "C", "D")) # Incorrect names def test_segment_endpoints_helpers(self) -> None: segment = SimpleMock( @@ -354,6 +354,7 @@ def test_sample_ellipse_arc_additional(self) -> None: MathUtils.sample_ellipse_arc(0.0, 0.0, 2.0, 0.0, 0.0, math.pi), [], ) + def test_segment_endpoints_additional_cases(self) -> None: degenerate_segment = SimpleMock( point1=SimpleMock(x=-1.0, y=-1.0), @@ -499,31 +500,31 @@ def test_is_point_on_segment(self) -> None: # Test case: Point on simple horizontal segment self.assertTrue( MathUtils.is_point_on_segment(5, 0, 0, 0, 10, 0), - "Point (5,0) should be detected as being on segment from (0,0) to (10,0)" + "Point (5,0) should be detected as being on segment from (0,0) to (10,0)", ) # Test case: Point on simple vertical segment self.assertTrue( MathUtils.is_point_on_segment(0, 5, 0, 0, 0, 10), - "Point (0,5) should be detected as being on segment from (0,0) to (0,10)" + "Point (0,5) should be detected as being on segment from (0,0) to (0,10)", ) # Test case: Point slightly off segment self.assertFalse( MathUtils.is_point_on_segment(5, 5.1, 0, 0, 10, 10), - "Point (5,5.1) should be detected as NOT being on segment from (0,0) to (10,10)" + "Point (5,5.1) should be detected as NOT being on segment from (0,0) to (10,10)", ) # Test case: Point outside bounding box of segment self.assertFalse( MathUtils.is_point_on_segment(15, 15, 0, 0, 10, 10), - "Point (15,15) should be detected as NOT being on segment from (0,0) to (10,10)" + "Point (15,15) should be detected as NOT being on segment from (0,0) to (10,10)", ) # Test case: Using the specific coordinates from the user's example self.assertTrue( MathUtils.is_point_on_segment(100.0, 45.332, -122.0, -69.0, 311.0, 154.0), - "Point (100.0, 45.332) should be detected as being on segment from (-122.0, -69.0) to (311.0, 154.0)" + "Point (100.0, 45.332) should be detected as being on segment from (-122.0, -69.0) to (311.0, 154.0)", ) # Test case: Additional real-world examples on a longer segment @@ -533,36 +534,40 @@ def test_is_point_on_segment(self) -> None: # Point C at y = 100 self.assertTrue( - MathUtils.is_point_on_segment(-113.39, 100.0, segment_start_x, segment_start_y, segment_end_x, segment_end_y), - "Point C (-113.39, 100.0) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)" + MathUtils.is_point_on_segment( + -113.39, 100.0, segment_start_x, segment_start_y, segment_end_x, segment_end_y + ), + "Point C (-113.39, 100.0) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)", ) # Point D at y = -24 self.assertTrue( MathUtils.is_point_on_segment(58.4, -24.0, segment_start_x, segment_start_y, segment_end_x, segment_end_y), - "Point D (58.4, -24.0) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)" + "Point D (58.4, -24.0) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)", ) # Point E at x = 3 self.assertTrue( MathUtils.is_point_on_segment(3.0, 15.99, segment_start_x, segment_start_y, segment_end_x, segment_end_y), - "Point E (3.0, 15.99) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)" + "Point E (3.0, 15.99) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)", ) # Point F at x = -199 self.assertTrue( - MathUtils.is_point_on_segment(-199.0, 161.8, segment_start_x, segment_start_y, segment_end_x, segment_end_y), - "Point F (-199.0, 161.8) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)" + MathUtils.is_point_on_segment( + -199.0, 161.8, segment_start_x, segment_start_y, segment_end_x, segment_end_y + ), + "Point F (-199.0, 161.8) should be detected as being on segment from (-245.0, 195.0) to (323.0, -215.0)", ) # Test case: Calculate a point that's exactly on the segment using linear interpolation t = 0.51 - point_x = -122.0 * (1-t) + 311.0 * t - point_y = -69.0 * (1-t) + 154.0 * t + point_x = -122.0 * (1 - t) + 311.0 * t + point_y = -69.0 * (1 - t) + 154.0 * t self.assertTrue( MathUtils.is_point_on_segment(point_x, point_y, -122.0, -69.0, 311.0, 154.0), - f"Interpolated point ({point_x}, {point_y}) should be detected as being on segment from (-122.0, -69.0) to (311.0, 154.0)" + f"Interpolated point ({point_x}, {point_y}) should be detected as being on segment from (-122.0, -69.0) to (311.0, 154.0)", ) def test_get_triangle_area(self) -> None: @@ -577,45 +582,67 @@ def test_get_rectangle_area(self) -> None: self.assertEqual(MathUtils.get_rectangle_area(p1, p2), 6) def test_cross_product(self) -> None: - self.assertEqual(MathUtils.cross_product(Position(0, 0), Position(1, 0), Position(0, 1)), 1) # "Perpendicular vectors" - self.assertEqual(MathUtils.cross_product(Position(0, 0), Position(1, 1), Position(1, 1)), 0) # "Zero vector test" - self.assertEqual(MathUtils.cross_product(Position(1, 2), Position(-1, -1), Position(2, -3)), 13) # "Negative values test" - self.assertEqual(MathUtils.cross_product(Position(0, 0), Position(1, 0), Position(2, 0)), 0) # "Collinear vectors" + self.assertEqual( + MathUtils.cross_product(Position(0, 0), Position(1, 0), Position(0, 1)), 1 + ) # "Perpendicular vectors" + self.assertEqual( + MathUtils.cross_product(Position(0, 0), Position(1, 1), Position(1, 1)), 0 + ) # "Zero vector test" + self.assertEqual( + MathUtils.cross_product(Position(1, 2), Position(-1, -1), Position(2, -3)), 13 + ) # "Negative values test" + self.assertEqual( + MathUtils.cross_product(Position(0, 0), Position(1, 0), Position(2, 0)), 0 + ) # "Collinear vectors" def test_dot_product(self) -> None: - self.assertEqual(MathUtils.dot_product(Position(0, 0), Position(1, 0), Position(1, 0)), 1) # "Parallel vectors" - self.assertEqual(MathUtils.dot_product(Position(0, 0), Position(0, 0), Position(0, 1)), 0) # "Zero vector test" - self.assertEqual(MathUtils.dot_product(Position(1, 2), Position(-1, -1), Position(2, -3)), 13) # "Negative values test" - self.assertEqual(MathUtils.dot_product(Position(0, 0), Position(1, 0), Position(0, 1)), 0) # "Perpendicular vectors" + self.assertEqual(MathUtils.dot_product(Position(0, 0), Position(1, 0), Position(1, 0)), 1) # "Parallel vectors" + self.assertEqual(MathUtils.dot_product(Position(0, 0), Position(0, 0), Position(0, 1)), 0) # "Zero vector test" + self.assertEqual( + MathUtils.dot_product(Position(1, 2), Position(-1, -1), Position(2, -3)), 13 + ) # "Negative values test" + self.assertEqual( + MathUtils.dot_product(Position(0, 0), Position(1, 0), Position(0, 1)), 0 + ) # "Perpendicular vectors" def test_is_right_angle(self) -> None: - self.assertEqual(MathUtils.is_right_angle(Position(0, 0), Position(1, 0), Position(0, 1)), True) # "Right angle" - self.assertEqual(MathUtils.is_right_angle(Position(0, 0), Position(1, 1), Position(1, 0)), False) # "Not right angle" - self.assertEqual(MathUtils.is_right_angle(Position(0, 0), Position(1, 0), Position(1, 1)), False) # "Almost right angle but not quite" + self.assertEqual( + MathUtils.is_right_angle(Position(0, 0), Position(1, 0), Position(0, 1)), True + ) # "Right angle" + self.assertEqual( + MathUtils.is_right_angle(Position(0, 0), Position(1, 1), Position(1, 0)), False + ) # "Not right angle" + self.assertEqual( + MathUtils.is_right_angle(Position(0, 0), Position(1, 0), Position(1, 1)), False + ) # "Almost right angle but not quite" def test_calculate_angle_degrees(self) -> None: # Vertex at origin for simplicity in these tests - v = (0,0) + v = (0, 0) # Test cases: (arm1_coords, arm2_coords, expected_degrees) test_cases = [ - ((1,0), (1,0), None), # Arm2 coincident with Arm1 (relative to vertex, leads to zero length vector for arm2 if not careful, or zero angle) - # Actually, MathUtils.calculate_angle_degrees has zero-length arm check based on v1x, v1y etc. - # If arm1=(1,0) and arm2=(1,0), v1=(1,0), v2=(1,0). angle1=0, angle2=0. diff=0. result=0. - # This case is more for are_points_valid_for_angle_geometry which checks p1 vs p2. - # For calculate_angle_degrees itself, if p1 and p2 are same *and distinct from vertex*, it's 0 deg. - ((1,0), (2,0), 0.0), # Collinear, same direction from vertex - ((1,0), (1,1), 45.0), # 45 degrees - ((1,0), (0,1), 90.0), # 90 degrees - ((1,0), (-1,1), 135.0), # 135 degrees - ((1,0), (-1,0), 180.0), # 180 degrees - ((1,0), (-1,-1), 225.0), # 225 degrees - ((1,0), (0,-1), 270.0), # 270 degrees - ((1,0), (1,-1), 315.0), # 315 degrees + ( + (1, 0), + (1, 0), + None, + ), # Arm2 coincident with Arm1 (relative to vertex, leads to zero length vector for arm2 if not careful, or zero angle) + # Actually, MathUtils.calculate_angle_degrees has zero-length arm check based on v1x, v1y etc. + # If arm1=(1,0) and arm2=(1,0), v1=(1,0), v2=(1,0). angle1=0, angle2=0. diff=0. result=0. + # This case is more for are_points_valid_for_angle_geometry which checks p1 vs p2. + # For calculate_angle_degrees itself, if p1 and p2 are same *and distinct from vertex*, it's 0 deg. + ((1, 0), (2, 0), 0.0), # Collinear, same direction from vertex + ((1, 0), (1, 1), 45.0), # 45 degrees + ((1, 0), (0, 1), 90.0), # 90 degrees + ((1, 0), (-1, 1), 135.0), # 135 degrees + ((1, 0), (-1, 0), 180.0), # 180 degrees + ((1, 0), (-1, -1), 225.0), # 225 degrees + ((1, 0), (0, -1), 270.0), # 270 degrees + ((1, 0), (1, -1), 315.0), # 315 degrees # Test None returns for zero-length arms from vertex - ((0,0), (1,1), None), # Arm1 is at vertex - ((1,1), (0,0), None), # Arm2 is at vertex + ((0, 0), (1, 1), None), # Arm1 is at vertex + ((1, 1), (0, 0), None), # Arm2 is at vertex # Test order of arms (p1, p2 vs p2, p1) - ((0,1), (1,0), 270.0), # P1=(0,1), P2=(1,0) -> angle from +Y to +X is 270 deg CCW + ((0, 1), (1, 0), 270.0), # P1=(0,1), P2=(1,0) -> angle from +Y to +X is 270 deg CCW ] for i, (p1_coords, p2_coords, expected) in enumerate(test_cases): @@ -624,13 +651,13 @@ def test_calculate_angle_degrees(self) -> None: if expected is None: self.assertIsNone(result) else: - self.assertIsNotNone(result) # Make sure it's not None before almostEqual + self.assertIsNotNone(result) # Make sure it's not None before almostEqual self.assertAlmostEqual(result, expected, places=5) # Test with non-origin vertex - v_offset = (5,5) - p1_offset = (6,5) # (1,0) relative to v_offset - p2_offset = (5,6) # (0,1) relative to v_offset + v_offset = (5, 5) + p1_offset = (6, 5) # (1,0) relative to v_offset + p2_offset = (5, 6) # (0,1) relative to v_offset self.assertAlmostEqual(MathUtils.calculate_angle_degrees(v_offset, p1_offset, p2_offset), 90.0, places=5) def test_are_points_valid_for_angle_geometry(self) -> None: @@ -638,26 +665,31 @@ def test_are_points_valid_for_angle_geometry(self) -> None: v = (0.0, 0.0) p1 = (1.0, 0.0) p2 = (0.0, 1.0) - p3 = (1.0, 0.0) # Same as p1 + p3 = (1.0, 0.0) # Same as p1 p4_close_to_v = (MathUtils.EPSILON / 2, MathUtils.EPSILON / 2) p5_close_to_p1 = (p1[0] + MathUtils.EPSILON / 2, p1[1] + MathUtils.EPSILON / 2) test_cases = [ - (v, p1, p2, True), # Valid case - (v, v, p2, False), # Vertex == Arm1 - (v, p1, v, False), # Vertex == Arm2 - (v, p1, p1, False), # Arm1 == Arm2 (p1 used twice for arm2) - (v, p1, p3, False), # Arm1 == Arm2 (p3 is same as p1) - (v, v, v, False), # All three coincident at vertex - (p1, p1, p1, False), # All three coincident at p1 + (v, p1, p2, True), # Valid case + (v, v, p2, False), # Vertex == Arm1 + (v, p1, v, False), # Vertex == Arm2 + (v, p1, p1, False), # Arm1 == Arm2 (p1 used twice for arm2) + (v, p1, p3, False), # Arm1 == Arm2 (p3 is same as p1) + (v, v, v, False), # All three coincident at vertex + (p1, p1, p1, False), # All three coincident at p1 # Epsilon tests - (v, p4_close_to_v, p2, False), # Arm1 too close to Vertex - (v, p1, p4_close_to_v, False), # Arm2 too close to Vertex - (v, p1, p5_close_to_p1, False), # Arm2 too close to Arm1 - ((0,0), (1,0), (1.0000000001, 0.0000000001), False) # arm2 very close to arm1 (within typical float precision but potentially outside strict epsilon for p1 vs p2) - # The are_points_valid uses direct comparison with EPSILON for each pair. - # If MathUtils.EPSILON = 1e-9, (1.0, 0.0) vs (1.0000000001, 0.0000000001) - # dx = 1e-10, dy = 1e-10. Both are < EPSILON. So this should be False. + (v, p4_close_to_v, p2, False), # Arm1 too close to Vertex + (v, p1, p4_close_to_v, False), # Arm2 too close to Vertex + (v, p1, p5_close_to_p1, False), # Arm2 too close to Arm1 + ( + (0, 0), + (1, 0), + (1.0000000001, 0.0000000001), + False, + ), # arm2 very close to arm1 (within typical float precision but potentially outside strict epsilon for p1 vs p2) + # The are_points_valid uses direct comparison with EPSILON for each pair. + # If MathUtils.EPSILON = 1e-9, (1.0, 0.0) vs (1.0000000001, 0.0000000001) + # dx = 1e-10, dy = 1e-10. Both are < EPSILON. So this should be False. ] for i, (vc, ac1, ac2, expected) in enumerate(test_cases): @@ -835,36 +867,37 @@ def test_variance(self) -> None: def test_check_div_by_zero(self) -> None: # Test cases that should raise ZeroDivisionError zero_division_cases = [ - "1/0", # Simple division by zero - "1/(3-3)", # Division by parenthesized zero - "1/(2*0)", # Direct multiplication by zero in denominator - "1/(0*x)", # Variable expression evaluating to zero - "10/(x-2)", # Variable expression evaluating to zero with variables - "1/(3*0+1-1)", # Complex expression evaluating to zero - "1/(-0)", # Negative zero - "1/(0.0)", # Zero as float - "1/(0e0)", # Zero in scientific notation + "1/0", # Simple division by zero + "1/(3-3)", # Division by parenthesized zero + "1/(2*0)", # Direct multiplication by zero in denominator + "1/(0*x)", # Variable expression evaluating to zero + "10/(x-2)", # Variable expression evaluating to zero with variables + "1/(3*0+1-1)", # Complex expression evaluating to zero + "1/(-0)", # Negative zero + "1/(0.0)", # Zero as float + "1/(0e0)", # Zero in scientific notation ] # Nested parentheses cases nested_zero_division_cases = [ - "1/2/(1-1)", # Chained division with zero - "1/(2/(1-1))", # Nested division with zero - "1/9*(3-3)", # Multiplication after division resulting in zero - "1/(9*(3-3))", # Division by parenthesized multiplication resulting in zero - "2/((1-1)*5)", # Division by zero with extra parentheses - "1/((2-2)*3*(4+1))", # Multiple terms evaluating to zero - "2/(1/(1-1))", # Division by infinity (division by zero in denominator) - "1/((3-3)/(4-4))", # Multiple zeros in nested divisions - "1/9*3*(1-1)", # Multiple operations after division resulting in zero - "1/3*2*(5-5)*4", # Zero product in denominator with multiple terms + "1/2/(1-1)", # Chained division with zero + "1/(2/(1-1))", # Nested division with zero + "1/9*(3-3)", # Multiplication after division resulting in zero + "1/(9*(3-3))", # Division by parenthesized multiplication resulting in zero + "2/((1-1)*5)", # Division by zero with extra parentheses + "1/((2-2)*3*(4+1))", # Multiple terms evaluating to zero + "2/(1/(1-1))", # Division by infinity (division by zero in denominator) + "1/((3-3)/(4-4))", # Multiple zeros in nested divisions + "1/9*3*(1-1)", # Multiple operations after division resulting in zero + "1/3*2*(5-5)*4", # Zero product in denominator with multiple terms ] # Test all zero division cases for expr in zero_division_cases: result = MathUtils.evaluate(expr) - self.assertTrue(isinstance(result, str) and "Error" in result, - f"Expected error for expression: {expr}, got {result}") + self.assertTrue( + isinstance(result, str) and "Error" in result, f"Expected error for expression: {expr}, got {result}" + ) # Test nested zero division cases # Note: The result of 0.0 for these cases is not typical and might be due to JavaScript's handling. @@ -875,86 +908,97 @@ def test_check_div_by_zero(self) -> None: elif expr == "1/((3-3)/(4-4))": # JavaScript returns nan for this case self.assertEqual(str(result).lower(), "nan", f"Expected nan for expression: {expr}, got {result}") else: - self.assertTrue(isinstance(result, str) and "Error" in result, - f"Expected error for nested expression: {expr}, got {result}") + self.assertTrue( + isinstance(result, str) and "Error" in result, + f"Expected error for nested expression: {expr}, got {result}", + ) # Test with variables result = MathUtils.evaluate("10/(x-2)", {"x": 2}) - self.assertTrue(isinstance(result, str) and "Error" in result, - f"Expected error for expression with x=2, got {result}") + self.assertTrue( + isinstance(result, str) and "Error" in result, f"Expected error for expression with x=2, got {result}" + ) # Test cases that should NOT raise ZeroDivisionError valid_division_cases = [ - "1/2", # Simple valid division - "1/(3-2)", # Valid division with parentheses - "1/2/3", # Chained valid division - "1/(2/3)", # Nested valid division - "1/9*(3-2)", # Valid multiplication after division - "1/(9*(3-2))", # Valid division with parenthesized multiplication - "2/((1+1)*5)", # Valid division with extra parentheses - "1/(2*1)", # Valid multiplication in denominator - "1/(x+1)", # Valid variable expression - "10/(x+2)", # Valid variable expression with variables - "1/(3*2+1)", # Valid complex expression - "1/((2+2)*3*(4+1))", # Valid multiple terms - "2/(1/(1+1))", # Valid nested division - "1/((3-2)/(4-3))", # Valid nested divisions - "1/9*3*(2-1)", # Valid multiple operations - "1/3*2*(5+5)*4", # Valid product in denominator - "1/3+4/5", # Multiple separate divisions - "1/3 + 4/5", # Divisions with whitespace - "1 / 3 * 2 * (5+5) * 4" # Complex expression with whitespace + "1/2", # Simple valid division + "1/(3-2)", # Valid division with parentheses + "1/2/3", # Chained valid division + "1/(2/3)", # Nested valid division + "1/9*(3-2)", # Valid multiplication after division + "1/(9*(3-2))", # Valid division with parenthesized multiplication + "2/((1+1)*5)", # Valid division with extra parentheses + "1/(2*1)", # Valid multiplication in denominator + "1/(x+1)", # Valid variable expression + "10/(x+2)", # Valid variable expression with variables + "1/(3*2+1)", # Valid complex expression + "1/((2+2)*3*(4+1))", # Valid multiple terms + "2/(1/(1+1))", # Valid nested division + "1/((3-2)/(4-3))", # Valid nested divisions + "1/9*3*(2-1)", # Valid multiple operations + "1/3*2*(5+5)*4", # Valid product in denominator + "1/3+4/5", # Multiple separate divisions + "1/3 + 4/5", # Divisions with whitespace + "1 / 3 * 2 * (5+5) * 4", # Complex expression with whitespace ] # Test all valid division cases for expr in valid_division_cases: result = MathUtils.evaluate(expr, {"x": 5}) # Using x=5 for variable cases - self.assertFalse(isinstance(result, str) and "Error" in result, - f"Unexpected error for valid expression: {expr}, got {result}") - self.assertIsInstance(result, (int, float, str), - f"Result should be numeric or string for expression: {expr}") + self.assertFalse( + isinstance(result, str) and "Error" in result, + f"Unexpected error for valid expression: {expr}, got {result}", + ) + self.assertIsInstance( + result, (int, float, str), f"Result should be numeric or string for expression: {expr}" + ) # Test with different variable values result = MathUtils.evaluate("1/(x+1)", {"x": -1}) # Should raise error - self.assertTrue(isinstance(result, str) and "Error" in result, - f"Expected error for expression with x=-1, got {result}") + self.assertTrue( + isinstance(result, str) and "Error" in result, f"Expected error for expression with x=-1, got {result}" + ) # Test edge cases edge_cases = [ - ("1/1e-100", False), # Very small but non-zero denominator + ("1/1e-100", False), # Very small but non-zero denominator ("1/(1-0.999999999)", False), # Nearly zero but not quite - ("1/(-0)", True), # Negative zero - ("1/(0.0)", True), # Zero as float - ("1/(0e0)", True), # Zero in scientific notation + ("1/(-0)", True), # Negative zero + ("1/(0.0)", True), # Zero as float + ("1/(0e0)", True), # Zero in scientific notation ] for expr, should_raise in edge_cases: result = MathUtils.evaluate(expr) if should_raise: - self.assertTrue(isinstance(result, str) and "Error" in result, - f"Expected error for edge case: {expr}, got {result}") + self.assertTrue( + isinstance(result, str) and "Error" in result, f"Expected error for edge case: {expr}, got {result}" + ) else: - self.assertFalse(isinstance(result, str) and "Error" in result, - f"Unexpected error for edge case: {expr}, got {result}") - self.assertIsInstance(result, (int, float, str), - f"Result should be numeric or string for edge case: {expr}") + self.assertFalse( + isinstance(result, str) and "Error" in result, + f"Unexpected error for edge case: {expr}, got {result}", + ) + self.assertIsInstance( + result, (int, float, str), f"Result should be numeric or string for edge case: {expr}" + ) def test_limit(self) -> None: - result = MathUtils.limit('sin(x) / x', 'x', 0) + result = MathUtils.limit("sin(x) / x", "x", 0) result = float(result) # convert result to float self.assertEqual(result, 1.0) def test_derivative(self) -> None: - result = MathUtils.derivative('x^2', 'x') + result = MathUtils.derivative("x^2", "x") self.assertEqual(result, "2*x") def test_integral_indefinite(self) -> None: - result = MathUtils.integral('x^2', 'x') + result = MathUtils.integral("x^2", "x") result = MathUtils.simplify(result) # simplify the result self.assertEqual(result, "0.3333333333333333*x^3") def test_integral(self) -> None: - result = MathUtils.integral('x^2', 'x', 0, 1) + result = MathUtils.integral("x^2", "x", 0, 1) result = float(result) # convert result to float self.assertAlmostEqual(result, 0.333, places=3) @@ -972,15 +1016,15 @@ def test_numeric_integrate_rejects_too_many_steps(self) -> None: MathUtils.numeric_integrate("x", "x", 0, 1, "midpoint", 10001) def test_simplify(self) -> None: - result = MathUtils.simplify('x^2 + 2*x + 1') + result = MathUtils.simplify("x^2 + 2*x + 1") self.assertEqual(result, "(1+x)^2") def test_expand(self) -> None: - result = MathUtils.expand('(x + 1)^2') + result = MathUtils.expand("(x + 1)^2") self.assertEqual(result, "1+2*x+x^2") def test_factor(self) -> None: - result = MathUtils.factor('x^2 - 1') + result = MathUtils.factor("x^2 - 1") self.assertEqual(result, "(-1+x)*(1+x)") def test_get_equation_type_with_linear_equation(self) -> None: @@ -1090,13 +1134,13 @@ def test_determine_max_number_of_solutions_other_non_linear(self) -> None: self.assertEqual(result, 0, "Other non-linear equations should indicate complex or uncertain scenarios.") def test_solve1(self) -> None: - result = MathUtils.solve('x^2 - 4', 'x') + result = MathUtils.solve("x^2 - 4", "x") result = json.loads(result) # parse result from JSON string to list result = [float(r) for r in result] # convert results to floats self.assertEqual(result, [2.0, -2.0]) def test_solve2(self) -> None: - result = MathUtils.solve('0.4 * x + 37.2 = -0.9 * x - 8', 'x') + result = MathUtils.solve("0.4 * x + 37.2 = -0.9 * x - 8", "x") result = json.loads(result) # Parse result from JSON string to list # Assuming the result is always a list with a single item for this test case solution = float(result[0]) # Convert the first (and only) result to float @@ -1132,13 +1176,13 @@ def test_solve_linear_quadratic_two_real_solutions(self) -> None: self.assertEqual(result, {"x1": 0.0, "y1": 1.0, "x2": -1.0, "y2": 0.0}) def test_solve_system_of_equations_linear(self) -> None: - result = MathUtils.solve_system_of_equations(['x + y = 4', 'x - y = 2']) + result = MathUtils.solve_system_of_equations(["x + y = 4", "x - y = 2"]) result = dict(item.split(" = ") for item in result.split(", ")) # parse result from string to dictionary result = {k: float(v) for k, v in result.items()} # convert results to floats self.assertEqual(result, {"x": 3.0, "y": 1.0}) def test_solve_system_of_equations_quadratic_linear(self) -> None: - result = MathUtils.solve_system_of_equations(['x^2 = y', '-x + 2 = y']) + result = MathUtils.solve_system_of_equations(["x^2 = y", "-x + 2 = y"]) result = dict(item.split(" = ") for item in result.split(", ")) # parse result from string to dictionary result = {k: float(v) for k, v in result.items()} # convert results to floats self.assertEqual(result, {"x1": 1.0, "y1": 1.0, "x2": -2.0, "y2": 4.0}) @@ -1177,7 +1221,7 @@ def test_calculate_vertical_asymptotes(self) -> None: # (-10 + π/2)/π ≤ n ≤ (10 + π/2)/π # -3.02 ≤ n ≤ 3.66 # Therefore n goes from -2 to 3 inclusive - expected = sorted([round((-math.pi/2 + n*math.pi), 6) for n in range(-2, 4)]) + expected = sorted([round((-math.pi / 2 + n * math.pi), 6) for n in range(-2, 4)]) actual = sorted([round(x, 6) for x in result]) self.assertEqual(actual, expected, "tan(x) should have correct asymptotes in [-10, 10]") @@ -1187,7 +1231,7 @@ def test_calculate_vertical_asymptotes(self) -> None: # (-5 + π/2)/π ≤ n ≤ (5 + π/2)/π # -1.41 ≤ n ≤ 2.07 # Therefore n goes from -1 to 2 inclusive - expected = sorted([round((-math.pi/2 + n*math.pi), 6) for n in range(-1, 3)]) + expected = sorted([round((-math.pi / 2 + n * math.pi), 6) for n in range(-1, 3)]) actual = sorted([round(x, 6) for x in result]) self.assertEqual(actual, expected, "tan(x) should have correct asymptotes in [-5, 5]") @@ -1197,7 +1241,7 @@ def test_calculate_vertical_asymptotes(self) -> None: # (-3 + π/2)/π ≤ n ≤ (3 + π/2)/π # -0.77 ≤ n ≤ 1.43 # Therefore n goes from 0 to 1 inclusive - expected = sorted([round((-math.pi/2 + n*math.pi), 6) for n in range(0, 2)]) + expected = sorted([round((-math.pi / 2 + n * math.pi), 6) for n in range(0, 2)]) actual = sorted([round(x, 6) for x in result]) self.assertEqual(actual, expected, "tan(x) should have correct asymptotes in [-3, 3]") @@ -1247,9 +1291,11 @@ def test_calculate_asymptotes(self) -> None: # (-5 + π/2)/π ≤ n ≤ (5 + π/2)/π # -1.41 ≤ n ≤ 2.07 # Therefore n goes from -1 to 2 inclusive (all values that give asymptotes within [-5, 5]) - expected_vert = sorted([round((-math.pi/2 + n*math.pi), 6) for n in range(-1, 3)]) + expected_vert = sorted([round((-math.pi / 2 + n * math.pi), 6) for n in range(-1, 3)]) actual_vert = sorted([round(x, 6) for x in vert]) - self.assertEqual(actual_vert, expected_vert, "tan(x) should have vertical asymptotes at x = π/2 + nπ within bounds") + self.assertEqual( + actual_vert, expected_vert, "tan(x) should have vertical asymptotes at x = π/2 + nπ within bounds" + ) self.assertEqual(horiz, [], "tan(x) should have no horizontal asymptotes") self.assertEqual(disc, [], "tan(x) should have no point discontinuities") @@ -1276,7 +1322,9 @@ def test_calculate_point_discontinuities(self) -> None: # Test floor function with bounds result = MathUtils.calculate_point_discontinuities("floor(x)", -2, 2) - self.assertEqual(result, [-2, -1, 0, 1, 2], "Floor function should have discontinuities at integers within bounds") + self.assertEqual( + result, [-2, -1, 0, 1, 2], "Floor function should have discontinuities at integers within bounds" + ) # Test ceil function with bounds result = MathUtils.calculate_point_discontinuities("ceil(x)", -1.5, 1.5) @@ -1334,18 +1382,15 @@ def test_function_vertical_asymptote_path_breaking(self) -> None: coordinate_mapper.sync_from_canvas(mock_canvas) # Test function with vertical asymptotes: tan(x) has asymptotes at π/2 + nπ - function = Function( - function_string="tan(x)", - name="test_tan", - left_bound=-5, - right_bound=5 - ) + function = Function(function_string="tan(x)", name="test_tan", left_bound=-5, right_bound=5) # Generate paths via FunctionRenderable paths = FunctionRenderable(function, coordinate_mapper).build_screen_paths().paths # Test that we have multiple paths (indicating path breaks at asymptotes) - self.assertGreater(len(paths), 1, "tan(x) should generate multiple separate paths due to vertical asymptotes") + self.assertGreater( + len(paths), 1, "tan(x) should generate multiple separate paths due to vertical asymptotes" + ) # Test that no path spans across a vertical asymptote for path in paths: @@ -1356,24 +1401,25 @@ def test_function_vertical_asymptote_path_breaking(self) -> None: path_max_x = max(x_coords) # Check that no vertical asymptote lies within this path's x range (exclusive) - asymptotes_in_path = [asym for asym in function.vertical_asymptotes - if path_min_x < asym < path_max_x] + asymptotes_in_path = [ + asym for asym in function.vertical_asymptotes if path_min_x < asym < path_max_x + ] - self.assertEqual(len(asymptotes_in_path), 0, - f"Path from x={path_min_x:.3f} to x={path_max_x:.3f} should not span across vertical asymptote(s) {asymptotes_in_path}") + self.assertEqual( + len(asymptotes_in_path), + 0, + f"Path from x={path_min_x:.3f} to x={path_max_x:.3f} should not span across vertical asymptote(s) {asymptotes_in_path}", + ) # Test with another function: 1/x has asymptote at x=0 - function2 = Function( - function_string="1/x", - name="test_reciprocal", - left_bound=-2, - right_bound=2 - ) + function2 = Function(function_string="1/x", name="test_reciprocal", left_bound=-2, right_bound=2) paths2 = FunctionRenderable(function2, coordinate_mapper).build_screen_paths().paths # Should have exactly 2 paths (one for x < 0, one for x > 0) - self.assertGreaterEqual(len(paths2), 2, "1/x should generate at least 2 separate paths due to vertical asymptote at x=0") + self.assertGreaterEqual( + len(paths2), 2, "1/x should generate at least 2 separate paths due to vertical asymptote at x=0" + ) # Verify no path spans across x=0 for path in paths2: @@ -1384,8 +1430,10 @@ def test_function_vertical_asymptote_path_breaking(self) -> None: # Path should not cross x=0 crosses_zero = path_min_x < 0 < path_max_x - self.assertFalse(crosses_zero, - f"Path from x={path_min_x:.3f} to x={path_max_x:.3f} should not cross the vertical asymptote at x=0") + self.assertFalse( + crosses_zero, + f"Path from x={path_min_x:.3f} to x={path_max_x:.3f} should not cross the vertical asymptote at x=0", + ) print("Vertical asymptote path breaking test passed successfully") @@ -1422,12 +1470,7 @@ def test_function_path_continuity(self) -> None: coordinate_mapper.sync_from_canvas(mock_canvas) # Test a continuous function: sin(x) should have one continuous path - function_sin = Function( - function_string="sin(x)", - name="test_sin", - left_bound=-10, - right_bound=10 - ) + function_sin = Function(function_string="sin(x)", name="test_sin", left_bound=-10, right_bound=10) paths_sin = FunctionRenderable(function_sin, coordinate_mapper).build_screen_paths().paths @@ -1442,21 +1485,18 @@ def test_function_path_continuity(self) -> None: # Check continuity within the path max_gap = 0 for i in range(1, len(path)): - x1, _ = mock_canvas.coordinate_mapper.screen_to_math(path[i-1][0], path[i-1][1]) + x1, _ = mock_canvas.coordinate_mapper.screen_to_math(path[i - 1][0], path[i - 1][1]) x2, _ = mock_canvas.coordinate_mapper.screen_to_math(path[i][0], path[i][1]) gap = abs(x2 - x1) max_gap = max(max_gap, gap) # The maximum gap between consecutive points shouldn't be too large - self.assertLess(max_gap, 1.0, f"sin(x) should have continuous points with max gap < 1.0, found {max_gap}") + self.assertLess( + max_gap, 1.0, f"sin(x) should have continuous points with max gap < 1.0, found {max_gap}" + ) # Test a quadratic function: x^2 should also be one continuous path - function_quad = Function( - function_string="x^2", - name="test_quad", - left_bound=-5, - right_bound=5 - ) + function_quad = Function(function_string="x^2", name="test_quad", left_bound=-5, right_bound=5) paths_quad = FunctionRenderable(function_quad, coordinate_mapper).build_screen_paths().paths @@ -1468,7 +1508,7 @@ def test_function_path_continuity(self) -> None: function_string="sin(x/10) + cos(x/15)", # Two different frequencies, no asymptotes name="test_moderate", left_bound=-20, - right_bound=20 + right_bound=20, ) paths_moderate = FunctionRenderable(function_moderate, coordinate_mapper).build_screen_paths().paths @@ -1480,23 +1520,26 @@ def test_function_path_continuity(self) -> None: function_string="10 * sin(x / 20)", # Simpler version to test basic functionality name="test_complex", left_bound=-50, # Even safer range - right_bound=50 + right_bound=50, ) paths_complex = FunctionRenderable(function_complex, coordinate_mapper).build_screen_paths().paths # This simplified function should definitely generate paths - self.assertGreater(len(paths_complex), 0, - f"Complex function should generate at least one path. " - f"Function: {function_complex.function_string}, " - f"Generated {len(paths_complex)} paths") + self.assertGreater( + len(paths_complex), + 0, + f"Complex function should generate at least one path. " + f"Function: {function_complex.function_string}, " + f"Generated {len(paths_complex)} paths", + ) # Test a simpler case to ensure basic functionality function_simple = Function( function_string="sin(x/10)", # Simple sine function name="test_simple", left_bound=-10, - right_bound=10 + right_bound=10, ) paths_simple = FunctionRenderable(function_simple, coordinate_mapper).build_screen_paths().paths @@ -1509,7 +1552,7 @@ def test_function_path_continuity(self) -> None: function_string="100 * sin(x / 50) + 50 * tan(x / 100)", name="test_original", left_bound=-30, # Very small, safe range - right_bound=30 + right_bound=30, ) paths_original = FunctionRenderable(function_original, coordinate_mapper).build_screen_paths().paths @@ -1519,7 +1562,9 @@ def test_function_path_continuity(self) -> None: total_points_orig = sum(len(path) for path in paths_original) self.assertGreater(total_points_orig, 5, "Original function should generate some points") else: - print(f"WARNING: Original complex function generated 0 paths - asymptotes: {function_original.vertical_asymptotes[:3] if hasattr(function_original, 'vertical_asymptotes') else 'None'}") + print( + f"WARNING: Original complex function generated 0 paths - asymptotes: {function_original.vertical_asymptotes[:3] if hasattr(function_original, 'vertical_asymptotes') else 'None'}" + ) except Exception as e: print(f"WARNING: Original complex function failed: {e}") @@ -1535,12 +1580,18 @@ def test_function_path_continuity(self) -> None: # Check for reasonable continuity in the longest path max_gap = 0 for i in range(1, len(longest_path)): - x1, _ = mock_canvas.coordinate_mapper.screen_to_math(longest_path[i-1][0], longest_path[i-1][1]) + x1, _ = mock_canvas.coordinate_mapper.screen_to_math( + longest_path[i - 1][0], longest_path[i - 1][1] + ) x2, _ = mock_canvas.coordinate_mapper.screen_to_math(longest_path[i][0], longest_path[i][1]) gap = abs(x2 - x1) max_gap = max(max_gap, gap) - self.assertLess(max_gap, 20.0, f"Complex function should have reasonably continuous points, max gap was {max_gap}") + self.assertLess( + max_gap, + 20.0, + f"Complex function should have reasonably continuous points, max gap was {max_gap}", + ) print("Function path continuity test passed successfully") @@ -1554,7 +1605,7 @@ def test_find_diagonal_points_standard_order(self) -> None: SimpleMock(name="A", x=0, y=1), SimpleMock(name="B", x=1, y=1), SimpleMock(name="C", x=1, y=0), - SimpleMock(name="D", x=0, y=0) + SimpleMock(name="D", x=0, y=0), ] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect1") self.assertIsNotNone(p_diag1, "p_diag1 should not be None") @@ -1566,15 +1617,18 @@ def test_find_diagonal_points_standard_order(self) -> None: expected_pairs = [("A", "C"), ("B", "D")] # Sort the names in the actual pair to make comparison order-independent # And check if this sorted pair is one of the sorted expected pairs - self.assertIn(actual_pair, [tuple(sorted(p)) for p in expected_pairs], - f"Expected diagonal pair like AC or BD, got {actual_pair}") + self.assertIn( + actual_pair, + [tuple(sorted(p)) for p in expected_pairs], + f"Expected diagonal pair like AC or BD, got {actual_pair}", + ) def test_find_diagonal_points_shuffled_order(self) -> None: points = [ SimpleMock(name="D", x=0, y=0), SimpleMock(name="B", x=1, y=1), SimpleMock(name="A", x=0, y=1), - SimpleMock(name="C", x=1, y=0) + SimpleMock(name="C", x=1, y=0), ] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect2") self.assertIsNotNone(p_diag1, "p_diag1 should not be None") @@ -1583,16 +1637,19 @@ def test_find_diagonal_points_shuffled_order(self) -> None: self.assertNotEqual(p_diag1.y, p_diag2.y) actual_pair = tuple(sorted((p_diag1.name, p_diag2.name))) - expected_pairs = [("A", "C"), ("B", "D")] # Same expected pairs - self.assertIn(actual_pair, [tuple(sorted(p)) for p in expected_pairs], - f"Expected diagonal pair like AC or BD, got {actual_pair}") + expected_pairs = [("A", "C"), ("B", "D")] # Same expected pairs + self.assertIn( + actual_pair, + [tuple(sorted(p)) for p in expected_pairs], + f"Expected diagonal pair like AC or BD, got {actual_pair}", + ) def test_find_diagonal_points_collinear_fail_case(self) -> None: points = [ SimpleMock(name="A", x=0, y=0), SimpleMock(name="B", x=1, y=0), SimpleMock(name="C", x=2, y=0), - SimpleMock(name="D", x=3, y=0) + SimpleMock(name="D", x=3, y=0), ] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect3_Collinear") self.assertIsNone(p_diag1) @@ -1603,7 +1660,7 @@ def test_find_diagonal_points_L_shape_fail_case(self) -> None: SimpleMock(name="A", x=0, y=1), SimpleMock(name="B", x=1, y=1), SimpleMock(name="C", x=1, y=0), - SimpleMock(name="D", x=2, y=0) + SimpleMock(name="D", x=2, y=0), ] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect4_L-shape") self.assertIsNotNone(p_diag1) @@ -1612,10 +1669,7 @@ def test_find_diagonal_points_L_shape_fail_case(self) -> None: self.assertEqual(p_diag2.name, "C") def test_find_diagonal_points_less_than_4_points(self) -> None: - points = [ - SimpleMock(name="A", x=0, y=0), - SimpleMock(name="B", x=1, y=1) - ] + points = [SimpleMock(name="A", x=0, y=0), SimpleMock(name="B", x=1, y=1)] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect5_TooFew") self.assertIsNone(p_diag1) self.assertIsNone(p_diag2) @@ -1625,7 +1679,7 @@ def test_find_diagonal_points_degenerate_rectangle_one_point_repeated(self) -> N SimpleMock(name="A1", x=0, y=1), SimpleMock(name="B", x=1, y=1), SimpleMock(name="C", x=1, y=0), - SimpleMock(name="A2", x=0, y=1) + SimpleMock(name="A2", x=0, y=1), ] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect6_Degenerate") self.assertIsNotNone(p_diag1) @@ -1638,7 +1692,7 @@ def test_find_diagonal_points_another_order(self) -> None: SimpleMock(name="A", x=0, y=0), SimpleMock(name="C", x=1, y=1), SimpleMock(name="B", x=0, y=1), - SimpleMock(name="D", x=1, y=0) + SimpleMock(name="D", x=1, y=0), ] p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, "Rect7") self.assertIsNotNone(p_diag1) @@ -1784,7 +1838,7 @@ def test_next_prime_sequence(self) -> None: if i == 0: self.assertEqual(MathUtils.next_prime(0), exp) else: - self.assertEqual(MathUtils.next_prime(expected[i-1] + 1), exp) + self.assertEqual(MathUtils.next_prime(expected[i - 1] + 1), exp) # ========== prev_prime tests ========== def test_prev_prime_basic(self) -> None: @@ -1819,14 +1873,14 @@ def test_totient_primes(self) -> None: """Test totient for prime numbers (should be p-1).""" primes = [2, 3, 5, 7, 11, 13, 17, 19, 23] for p in primes: - self.assertEqual(MathUtils.totient(p), p - 1, f"totient({p}) should be {p-1}") + self.assertEqual(MathUtils.totient(p), p - 1, f"totient({p}) should be {p - 1}") def test_totient_prime_powers(self) -> None: """Test totient for prime powers (p^k -> p^(k-1) * (p-1)).""" - self.assertEqual(MathUtils.totient(4), 2) # 2^2 -> 2^1 * 1 = 2 - self.assertEqual(MathUtils.totient(8), 4) # 2^3 -> 2^2 * 1 = 4 - self.assertEqual(MathUtils.totient(9), 6) # 3^2 -> 3^1 * 2 = 6 - self.assertEqual(MathUtils.totient(27), 18) # 3^3 -> 3^2 * 2 = 18 + self.assertEqual(MathUtils.totient(4), 2) # 2^2 -> 2^1 * 1 = 2 + self.assertEqual(MathUtils.totient(8), 4) # 2^3 -> 2^2 * 1 = 4 + self.assertEqual(MathUtils.totient(9), 6) # 3^2 -> 3^1 * 2 = 6 + self.assertEqual(MathUtils.totient(27), 18) # 3^3 -> 3^2 * 2 = 18 def test_totient_100(self) -> None: """Test totient(100).""" @@ -2090,13 +2144,13 @@ def test_geometric_sum_infinite_convergent(self) -> None: def test_geometric_sum_infinite_third(self) -> None: """Test geometric_sum_infinite with r=1/3.""" # a/(1-r) = 1/(1-1/3) = 1.5 - self.assertEqual(MathUtils.geometric_sum_infinite(1, 1/3), 1.5) + self.assertEqual(MathUtils.geometric_sum_infinite(1, 1 / 3), 1.5) def test_geometric_sum_infinite_negative_ratio(self) -> None: """Test geometric_sum_infinite with negative ratio.""" # a/(1-r) = 1/(1-(-0.5)) = 1/1.5 = 2/3 result = MathUtils.geometric_sum_infinite(1, -0.5) - self.assertAlmostEqual(result, 2/3, places=10) + self.assertAlmostEqual(result, 2 / 3, places=10) def test_geometric_sum_infinite_divergent(self) -> None: """Test geometric_sum_infinite raises for divergent series.""" diff --git a/static/client/client_tests/test_nonagon.py b/static/client/client_tests/test_nonagon.py index 616d31f6..232924e1 100644 --- a/static/client/client_tests/test_nonagon.py +++ b/static/client/client_tests/test_nonagon.py @@ -130,4 +130,3 @@ def test_is_irregular_helper(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_numeric_solver.py b/static/client/client_tests/test_numeric_solver.py index 81b0d196..31c5d71f 100644 --- a/static/client/client_tests/test_numeric_solver.py +++ b/static/client/client_tests/test_numeric_solver.py @@ -176,11 +176,7 @@ def test_solve_numeric_three_variables(self) -> None: """Test solving a 3-variable linear system.""" from numeric_solver import solve_numeric - result_json = solve_numeric([ - "x + y + z = 6", - "x - y + z = 2", - "x + y - z = 0" - ]) + result_json = solve_numeric(["x + y + z = 6", "x - y + z = 2", "x + y - z = 0"]) result = json.loads(result_json) self.assertGreater(len(result["solutions"]), 0) @@ -207,10 +203,7 @@ def test_solve_numeric_with_initial_guesses(self) -> None: from numeric_solver import solve_numeric # Provide a guess close to pi/6 for sin(x) = 0.5 - result_json = solve_numeric( - ["sin(x) = 0.5"], - initial_guesses=[[0.5]] - ) + result_json = solve_numeric(["sin(x) = 0.5"], initial_guesses=[[0.5]]) result = json.loads(result_json) self.assertGreater(len(result["solutions"]), 0) diff --git a/static/client/client_tests/test_octagon.py b/static/client/client_tests/test_octagon.py index baadeef1..8dd33ed4 100644 --- a/static/client/client_tests/test_octagon.py +++ b/static/client/client_tests/test_octagon.py @@ -130,4 +130,3 @@ def test_is_irregular_helper(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_optimized_renderers.py b/static/client/client_tests/test_optimized_renderers.py index f99bc4dc..223dd7d8 100644 --- a/static/client/client_tests/test_optimized_renderers.py +++ b/static/client/client_tests/test_optimized_renderers.py @@ -68,7 +68,13 @@ def _record(self, op: str, *args, **kwargs) -> None: self.operations.append((op, args, kwargs)) def stroke_line(self, start, end, stroke, *, include_width=True): - self._record("stroke_line", _normalize_point(start), _normalize_point(end), _serialize_stroke(stroke), include_width=include_width) + self._record( + "stroke_line", + _normalize_point(start), + _normalize_point(end), + _serialize_stroke(stroke), + include_width=include_width, + ) def stroke_polyline(self, points, stroke): normalized = tuple(_normalize_point(pt) for pt in points) @@ -115,7 +121,9 @@ def fill_joined_area(self, forward, reverse, fill): reverse_norm = tuple(_normalize_point(pt) for pt in reverse) self._record("fill_joined_area", forward_norm, reverse_norm, _serialize_fill(fill)) - def stroke_arc(self, center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class=None, **kwargs): + def stroke_arc( + self, center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class=None, **kwargs + ): self._record( "stroke_arc", _normalize_point(center), @@ -285,4 +293,3 @@ def test_circle_arc_plan_scales_with_zoom(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_path_elements.py b/static/client/client_tests/test_path_elements.py index 58884abe..d98e9d01 100644 --- a/static/client/client_tests/test_path_elements.py +++ b/static/client/client_tests/test_path_elements.py @@ -13,7 +13,6 @@ class TestPathElements(unittest.TestCase): - def test_line_segment_creation(self) -> None: seg = LineSegment((0.0, 0.0), (3.0, 4.0)) self.assertEqual(seg.start_point(), (0.0, 0.0)) diff --git a/static/client/client_tests/test_pentagon.py b/static/client/client_tests/test_pentagon.py index 2333c525..d48d2fc0 100644 --- a/static/client/client_tests/test_pentagon.py +++ b/static/client/client_tests/test_pentagon.py @@ -123,4 +123,3 @@ def test_is_irregular_helper(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_periodicity_detection.py b/static/client/client_tests/test_periodicity_detection.py index 11b8f591..a8ced378 100644 --- a/static/client/client_tests/test_periodicity_detection.py +++ b/static/client/client_tests/test_periodicity_detection.py @@ -50,6 +50,7 @@ def test_high_frequency_sin(self): def test_range_hint_scales_test_range(self): """range_hint should scale test_range for long-period functions.""" + # sin(x/50) has period 2*pi*50 ~ 314 # With default test_range=20, it looks linear # With range_hint=600, it should be detected @@ -60,19 +61,18 @@ def long_period_sin(x): is_periodic_default, _ = MathUtils.detect_function_periodicity(long_period_sin) # With range_hint - should detect - is_periodic_hint, period = MathUtils.detect_function_periodicity( - long_period_sin, range_hint=600 - ) + is_periodic_hint, period = MathUtils.detect_function_periodicity(long_period_sin, range_hint=600) self.assertTrue(is_periodic_hint) self.assertIsNotNone(period) def test_tan_with_asymptotes(self): """tan(x) has asymptotes - should still detect periodicity or handle gracefully.""" + def safe_tan(x): try: return math.tan(x) except: - return float('inf') + return float("inf") # Should not crash, may or may not detect as periodic due to asymptotes is_periodic, period = MathUtils.detect_function_periodicity(safe_tan) @@ -81,20 +81,22 @@ def safe_tan(x): def test_combined_sin_function(self): """100*sin(x/50) + 50*tan(x/100) with range_hint should be detected.""" + def combo(x): try: return 100 * math.sin(x / 50) + 50 * math.tan(x / 100) except: - return float('inf') + return float("inf") is_periodic, period = MathUtils.detect_function_periodicity(combo, range_hint=600) self.assertTrue(is_periodic) def test_one_over_x_not_periodic(self): """1/x has curvature that may trigger periodicity detection.""" + def one_over_x(x): if x == 0: - return float('inf') + return float("inf") return 1 / x # 1/x has curvature on both sides of the asymptote which may look @@ -110,6 +112,7 @@ class TestPeriodicityEdgeCases(unittest.TestCase): def test_function_with_exceptions(self): """Function that throws exceptions should be handled gracefully.""" + def bad_func(x): if x > 0: raise ValueError("test error") @@ -121,16 +124,18 @@ def bad_func(x): def test_function_returning_nan(self): """Function returning NaN should be handled gracefully.""" + def nan_func(x): - return float('nan') + return float("nan") is_periodic, period = MathUtils.detect_function_periodicity(nan_func) self.assertFalse(is_periodic) def test_function_returning_inf(self): """Function returning infinity should be handled gracefully.""" + def inf_func(x): - return float('inf') + return float("inf") is_periodic, period = MathUtils.detect_function_periodicity(inf_func) self.assertFalse(is_periodic) @@ -138,9 +143,7 @@ def inf_func(x): def test_range_hint_capped_at_1000(self): """range_hint should be capped at 1000 to avoid excessive computation.""" # With very large range_hint, test_range should be capped - is_periodic, period = MathUtils.detect_function_periodicity( - math.sin, range_hint=10000 - ) + is_periodic, period = MathUtils.detect_function_periodicity(math.sin, range_hint=10000) # Should still work (capped at 1000) self.assertTrue(is_periodic) @@ -149,4 +152,3 @@ def test_range_hint_capped_at_1000(self): "TestPeriodicityDetection", "TestPeriodicityEdgeCases", ] - diff --git a/static/client/client_tests/test_piecewise_function.py b/static/client/client_tests/test_piecewise_function.py index 99c1392a..aa9b249a 100644 --- a/static/client/client_tests/test_piecewise_function.py +++ b/static/client/client_tests/test_piecewise_function.py @@ -196,6 +196,7 @@ def test_deepcopy(self) -> None: pf = PiecewiseFunction(pieces, name="f") import copy + pf_copy = copy.deepcopy(pf) self.assertEqual(pf_copy.name, "f") @@ -374,7 +375,14 @@ def test_interval_with_multiple_undefined_points(self) -> None: def test_piecewise_function_with_undefined_at(self) -> None: pieces = [ - {"expression": "2", "left": None, "right": None, "left_inclusive": True, "right_inclusive": True, "undefined_at": [0]}, + { + "expression": "2", + "left": None, + "right": None, + "left_inclusive": True, + "right_inclusive": True, + "undefined_at": [0], + }, ] pf = PiecewiseFunction(pieces, name="f_with_hole") @@ -443,7 +451,13 @@ def test_empty_pieces_raises_error(self) -> None: def test_invalid_expression_raises_error(self) -> None: pieces = [ - {"expression": "invalid_function(x)", "left": None, "right": None, "left_inclusive": True, "right_inclusive": True}, + { + "expression": "invalid_function(x)", + "left": None, + "right": None, + "left_inclusive": True, + "right_inclusive": True, + }, ] with self.assertRaises(ValueError): PiecewiseFunction(pieces, name="f") @@ -475,4 +489,3 @@ def test_gap_in_intervals_returns_nan(self) -> None: "TestPiecewiseFunctionUndefinedAt", "TestPiecewiseFunctionEdgeCases", ] - diff --git a/static/client/client_tests/test_point.py b/static/client/client_tests/test_point.py index 4a7a046b..e3c9622b 100644 --- a/static/client/client_tests/test_point.py +++ b/static/client/client_tests/test_point.py @@ -22,7 +22,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -44,10 +44,10 @@ def test_init(self) -> None: self.assertEqual(self.point.color, "red") def test_get_class_name(self) -> None: - self.assertEqual(self.point.get_class_name(), 'Point') + self.assertEqual(self.point.get_class_name(), "Point") def test_str(self) -> None: - self.assertEqual(str(self.point), '1,2') + self.assertEqual(str(self.point), "1,2") def test_get_state(self) -> None: expected_state = {"name": "p1", "args": {"position": {"x": 1, "y": 2}}} @@ -90,4 +90,3 @@ def test_draw(self) -> None: # This test would check if draw calls create_svg_element with expected arguments # Might require a more complex setup or mocking to verify SVG output pass - diff --git a/static/client/client_tests/test_point_manager.py b/static/client/client_tests/test_point_manager.py index 8e8dc3ee..d4d8a279 100644 --- a/static/client/client_tests/test_point_manager.py +++ b/static/client/client_tests/test_point_manager.py @@ -38,9 +38,7 @@ def generate_point_name(preferred: str | None) -> str: self.drawable_manager_proxy = SimpleMock( name="DrawableManagerProxyMock", - segment_manager=SimpleMock( - name="SegmentManagerMock", _split_segments_with_point=MagicMock() - ), + segment_manager=SimpleMock(name="SegmentManagerMock", _split_segments_with_point=MagicMock()), create_drawables_from_new_connections=MagicMock(), delete_ellipse=MagicMock(), delete_circle=MagicMock(), @@ -59,9 +57,7 @@ def generate_point_name(preferred: str | None) -> str: def test_update_point_allows_solitary_edits(self) -> None: point = self.point_manager.create_point(0, 0, "A", extra_graphics=False) - result = self.point_manager.update_point( - "A", new_name="B", new_x=5.0, new_y=6.0, new_color="#123456" - ) + result = self.point_manager.update_point("A", new_name="B", new_x=5.0, new_y=6.0, new_color="#123456") self.assertTrue(result) self.assertEqual(point.name, "B") @@ -169,9 +165,7 @@ def test_update_point_rejects_combined_edit_when_dependent(self) -> None: self.dependency_manager.register_dependency(dependent_segment, point) with self.assertRaises(ValueError): - self.point_manager.update_point( - "A", new_name="B", new_x=3.0, new_y=4.0, new_color="#00ff00" - ) + self.point_manager.update_point("A", new_name="B", new_x=3.0, new_y=4.0, new_color="#00ff00") def test_update_point_rename_conflict_preserves_original_name(self) -> None: self.point_manager.create_point(0, 0, "A", extra_graphics=False) diff --git a/static/client/client_tests/test_polar_grid.py b/static/client/client_tests/test_polar_grid.py index 46b86fbe..0b1efafa 100644 --- a/static/client/client_tests/test_polar_grid.py +++ b/static/client/client_tests/test_polar_grid.py @@ -20,7 +20,7 @@ def setUp(self) -> None: zoom_direction=0, offset=Position(0, 0), zoom_point=Position(0, 0), - zoom_step=0.1 + zoom_step=0.1, ) self.coordinate_mapper.sync_from_canvas(self.canvas) @@ -100,7 +100,7 @@ def test_get_radial_circles(self) -> None: if len(circles) >= 2: spacing = circles[1] - circles[0] for i in range(2, len(circles)): - self.assertAlmostEqual(circles[i] - circles[i-1], spacing, places=6) + self.assertAlmostEqual(circles[i] - circles[i - 1], spacing, places=6) def test_get_state(self) -> None: self.polar_grid.angular_divisions = 8 diff --git a/static/client/client_tests/test_polygon_canonicalizer.py b/static/client/client_tests/test_polygon_canonicalizer.py index 7609646a..60e0d648 100644 --- a/static/client/client_tests/test_polygon_canonicalizer.py +++ b/static/client/client_tests/test_polygon_canonicalizer.py @@ -104,12 +104,16 @@ def test_rectangle_vertices_prioritize_first_diagonal(self) -> None: ] result = canonicalize_rectangle(vertices, tolerance=0.2) for corner in result: - if math.isclose(corner[0], vertices[0][0], abs_tol=0.1) and math.isclose(corner[1], vertices[0][1], abs_tol=0.1): + if math.isclose(corner[0], vertices[0][0], abs_tol=0.1) and math.isclose( + corner[1], vertices[0][1], abs_tol=0.1 + ): break else: self.fail("Canonical rectangle did not preserve proximity to first vertex along the diagonal.") for corner in result: - if math.isclose(corner[0], vertices[2][0], abs_tol=0.1) and math.isclose(corner[1], vertices[2][1], abs_tol=0.1): + if math.isclose(corner[0], vertices[2][0], abs_tol=0.1) and math.isclose( + corner[1], vertices[2][1], abs_tol=0.1 + ): break else: self.fail("Canonical rectangle did not preserve proximity to third vertex along the diagonal.") @@ -514,4 +518,3 @@ def test_quadrilateral_invalid_subtype_raises(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_polygon_manager.py b/static/client/client_tests/test_polygon_manager.py index 99c15e30..12f727cd 100644 --- a/static/client/client_tests/test_polygon_manager.py +++ b/static/client/client_tests/test_polygon_manager.py @@ -119,42 +119,49 @@ def test_create_square(self) -> None: def test_create_pentagon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 5), math.sin(2 * math.pi * i / 5)) for i in range(5)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="pentagon") self.assertIsInstance(polygon, Pentagon) def test_create_hexagon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 6), math.sin(2 * math.pi * i / 6)) for i in range(6)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="hexagon") self.assertIsInstance(polygon, Hexagon) def test_create_heptagon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 7), math.sin(2 * math.pi * i / 7)) for i in range(7)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="heptagon") self.assertIsInstance(polygon, Heptagon) def test_create_octagon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 8), math.sin(2 * math.pi * i / 8)) for i in range(8)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="octagon") self.assertIsInstance(polygon, Octagon) def test_create_nonagon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 9), math.sin(2 * math.pi * i / 9)) for i in range(9)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="nonagon") self.assertIsInstance(polygon, Nonagon) def test_create_decagon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 10), math.sin(2 * math.pi * i / 10)) for i in range(10)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="decagon") self.assertIsInstance(polygon, Decagon) def test_create_generic_polygon(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 12), math.sin(2 * math.pi * i / 12)) for i in range(12)] polygon = self.polygon_manager.create_polygon(vertices, polygon_type="generic") self.assertIsInstance(polygon, GenericPolygon) @@ -171,12 +178,14 @@ def test_infer_quadrilateral_from_vertex_count(self) -> None: def test_infer_pentagon_from_vertex_count(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 5), math.sin(2 * math.pi * i / 5)) for i in range(5)] polygon = self.polygon_manager.create_polygon(vertices) self.assertIsInstance(polygon, Pentagon) def test_infer_generic_from_large_vertex_count(self) -> None: import math + vertices = [(math.cos(2 * math.pi * i / 15), math.sin(2 * math.pi * i / 15)) for i in range(15)] polygon = self.polygon_manager.create_polygon(vertices) self.assertIsInstance(polygon, GenericPolygon) @@ -263,4 +272,3 @@ def test_update_polygon_requires_property(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_quadrilateral.py b/static/client/client_tests/test_quadrilateral.py index 4f56c3d1..c201fc34 100644 --- a/static/client/client_tests/test_quadrilateral.py +++ b/static/client/client_tests/test_quadrilateral.py @@ -57,6 +57,7 @@ def test_rectangle_flags(self) -> None: def test_rhombus_flags(self) -> None: import math + h = math.sqrt(3) points = [ _make_point("A", 0.0, 0.0), @@ -237,4 +238,3 @@ def test_get_state_types_irregular(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_rectangle.py b/static/client/client_tests/test_rectangle.py index e5ecb81e..59150835 100644 --- a/static/client/client_tests/test_rectangle.py +++ b/static/client/client_tests/test_rectangle.py @@ -23,7 +23,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -68,7 +68,7 @@ def test_rectangle_is_not_renderable_by_default(self) -> None: self.assertFalse(self.rectangle.is_renderable) def test_get_class_name(self) -> None: - self.assertEqual(self.rectangle.get_class_name(), 'Rectangle') + self.assertEqual(self.rectangle.get_class_name(), "Rectangle") def test_get_state(self) -> None: state = self.rectangle.get_state() diff --git a/static/client/client_tests/test_region.py b/static/client/client_tests/test_region.py index d0dff582..52bade19 100644 --- a/static/client/client_tests/test_region.py +++ b/static/client/client_tests/test_region.py @@ -12,7 +12,6 @@ class TestRegion(unittest.TestCase): - def test_region_from_square_points(self) -> None: points = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)] region = Region.from_points(points) @@ -188,7 +187,6 @@ def test_ellipse_area_exact(self) -> None: class TestAreaUtilities(unittest.TestCase): - def test_polygon_area_square(self) -> None: points = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)] area = GeometryUtils.polygon_area(points) @@ -228,4 +226,3 @@ def test_line_segment_area_contribution(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_relation_inspector.py b/static/client/client_tests/test_relation_inspector.py index dcce8dd4..2e6d7c98 100644 --- a/static/client/client_tests/test_relation_inspector.py +++ b/static/client/client_tests/test_relation_inspector.py @@ -33,8 +33,8 @@ def _inspect( # Parallel # ------------------------------------------------------------------ -class TestParallel(_RelationTestBase): +class TestParallel(_RelationTestBase): def test_horizontal_parallel(self) -> None: s1 = self.canvas.create_segment(0, 0, 4, 0, name="AB") s2 = self.canvas.create_segment(0, 2, 4, 2, name="CD") @@ -87,8 +87,8 @@ def test_symmetry(self) -> None: # Perpendicular # ------------------------------------------------------------------ -class TestPerpendicular(_RelationTestBase): +class TestPerpendicular(_RelationTestBase): def test_axes_cross(self) -> None: s1 = self.canvas.create_segment(0, 0, 4, 0, name="AB") s2 = self.canvas.create_segment(5, 0, 5, 4, name="CD") @@ -126,8 +126,8 @@ def test_zero_length_error(self) -> None: # Collinear # ------------------------------------------------------------------ -class TestCollinear(_RelationTestBase): +class TestCollinear(_RelationTestBase): def test_three_on_x_axis(self) -> None: self.canvas.create_point(0, 0, name="A") self.canvas.create_point(3, 0, name="B") @@ -148,7 +148,8 @@ def test_four_collinear(self) -> None: self.canvas.create_point(2, 2, name="C") self.canvas.create_point(5, 5, name="D") res = self._inspect( - "collinear", ["A", "B", "C", "D"], + "collinear", + ["A", "B", "C", "D"], ["point", "point", "point", "point"], ) self.assertTrue(res["result"]) @@ -181,8 +182,8 @@ def test_too_few_points(self) -> None: # Concyclic # ------------------------------------------------------------------ -class TestConcyclic(_RelationTestBase): +class TestConcyclic(_RelationTestBase): def test_four_on_unit_circle(self) -> None: r = 5.0 self.canvas.create_point(r, 0, name="A") @@ -190,7 +191,8 @@ def test_four_on_unit_circle(self) -> None: self.canvas.create_point(-r, 0, name="C") self.canvas.create_point(0, -r, name="D") res = self._inspect( - "concyclic", ["A", "B", "C", "D"], + "concyclic", + ["A", "B", "C", "D"], ["point", "point", "point", "point"], ) self.assertTrue(res["result"]) @@ -201,7 +203,8 @@ def test_not_concyclic(self) -> None: self.canvas.create_point(-5, 0, name="C") self.canvas.create_point(1, 1, name="D") res = self._inspect( - "concyclic", ["A", "B", "C", "D"], + "concyclic", + ["A", "B", "C", "D"], ["point", "point", "point", "point"], ) self.assertFalse(res["result"]) @@ -212,7 +215,8 @@ def test_collinear_first_three(self) -> None: self.canvas.create_point(2, 0, name="C") self.canvas.create_point(0, 1, name="D") res = self._inspect( - "concyclic", ["A", "B", "C", "D"], + "concyclic", + ["A", "B", "C", "D"], ["point", "point", "point", "point"], ) self.assertFalse(res["result"]) @@ -222,8 +226,8 @@ def test_collinear_first_three(self) -> None: # Equal Length # ------------------------------------------------------------------ -class TestEqualLength(_RelationTestBase): +class TestEqualLength(_RelationTestBase): def test_equal_segments(self) -> None: s1 = self.canvas.create_segment(0, 0, 3, 4, name="AB") s2 = self.canvas.create_segment(10, 10, 13, 14, name="CD") @@ -250,24 +254,32 @@ def test_near_threshold(self) -> None: # Similar Triangles # ------------------------------------------------------------------ -class TestSimilarTriangles(_RelationTestBase): +class TestSimilarTriangles(_RelationTestBase): def test_scaled_copy(self) -> None: t1 = self.canvas.create_polygon( - [(0, 0), (3, 0), (0, 4)], polygon_type=PolygonType.TRIANGLE, name="ABC", + [(0, 0), (3, 0), (0, 4)], + polygon_type=PolygonType.TRIANGLE, + name="ABC", ) t2 = self.canvas.create_polygon( - [(10, 10), (16, 10), (10, 18)], polygon_type=PolygonType.TRIANGLE, name="DEF", + [(10, 10), (16, 10), (10, 18)], + polygon_type=PolygonType.TRIANGLE, + name="DEF", ) res = self._inspect("similar", [t1.name, t2.name], ["triangle", "triangle"]) self.assertTrue(res["result"]) def test_not_similar(self) -> None: t1 = self.canvas.create_polygon( - [(0, 0), (3, 0), (0, 4)], polygon_type=PolygonType.TRIANGLE, name="ABC", + [(0, 0), (3, 0), (0, 4)], + polygon_type=PolygonType.TRIANGLE, + name="ABC", ) t2 = self.canvas.create_polygon( - [(10, 10), (20, 10), (10, 11)], polygon_type=PolygonType.TRIANGLE, name="DEF", + [(10, 10), (20, 10), (10, 11)], + polygon_type=PolygonType.TRIANGLE, + name="DEF", ) res = self._inspect("similar", [t1.name, t2.name], ["triangle", "triangle"]) self.assertFalse(res["result"]) @@ -277,24 +289,32 @@ def test_not_similar(self) -> None: # Congruent Triangles # ------------------------------------------------------------------ -class TestCongruentTriangles(_RelationTestBase): +class TestCongruentTriangles(_RelationTestBase): def test_same_shape(self) -> None: t1 = self.canvas.create_polygon( - [(0, 0), (3, 0), (0, 4)], polygon_type=PolygonType.TRIANGLE, name="ABC", + [(0, 0), (3, 0), (0, 4)], + polygon_type=PolygonType.TRIANGLE, + name="ABC", ) t2 = self.canvas.create_polygon( - [(10, 10), (13, 10), (10, 14)], polygon_type=PolygonType.TRIANGLE, name="DEF", + [(10, 10), (13, 10), (10, 14)], + polygon_type=PolygonType.TRIANGLE, + name="DEF", ) res = self._inspect("congruent", [t1.name, t2.name], ["triangle", "triangle"]) self.assertTrue(res["result"]) def test_similar_not_congruent(self) -> None: t1 = self.canvas.create_polygon( - [(0, 0), (3, 0), (0, 4)], polygon_type=PolygonType.TRIANGLE, name="ABC", + [(0, 0), (3, 0), (0, 4)], + polygon_type=PolygonType.TRIANGLE, + name="ABC", ) t2 = self.canvas.create_polygon( - [(10, 10), (16, 10), (10, 18)], polygon_type=PolygonType.TRIANGLE, name="DEF", + [(10, 10), (16, 10), (10, 18)], + polygon_type=PolygonType.TRIANGLE, + name="DEF", ) res = self._inspect("congruent", [t1.name, t2.name], ["triangle", "triangle"]) self.assertFalse(res["result"]) @@ -304,8 +324,8 @@ def test_similar_not_congruent(self) -> None: # Tangent # ------------------------------------------------------------------ -class TestTangent(_RelationTestBase): +class TestTangent(_RelationTestBase): def test_segment_tangent_to_circle(self) -> None: c1 = self.canvas.create_circle(0, 0, 5) # Horizontal line y = 5 is tangent to circle centered at origin with radius 5 @@ -360,14 +380,15 @@ def test_segment_far_from_tangent_point(self) -> None: # Concurrent # ------------------------------------------------------------------ -class TestConcurrent(_RelationTestBase): +class TestConcurrent(_RelationTestBase): def test_three_lines_through_origin(self) -> None: s1 = self.canvas.create_segment(-5, 0, 5, 0, name="AB") s2 = self.canvas.create_segment(0, -5, 0, 5, name="CD") s3 = self.canvas.create_segment(-5, -5, 5, 5, name="EF") res = self._inspect( - "concurrent", [s1.name, s2.name, s3.name], + "concurrent", + [s1.name, s2.name, s3.name], ["segment", "segment", "segment"], ) self.assertTrue(res["result"]) @@ -380,7 +401,8 @@ def test_non_concurrent(self) -> None: s2 = self.canvas.create_segment(5, 0, 5, 4, name="CD") s3 = self.canvas.create_segment(1, 1, 5, 2, name="EF") res = self._inspect( - "concurrent", [s1.name, s2.name, s3.name], + "concurrent", + [s1.name, s2.name, s3.name], ["segment", "segment", "segment"], ) self.assertFalse(res["result"]) @@ -390,7 +412,8 @@ def test_parallel_pair_among_three(self) -> None: s2 = self.canvas.create_segment(0, 2, 4, 2, name="CD") s3 = self.canvas.create_segment(5, 0, 5, 4, name="EF") res = self._inspect( - "concurrent", [s1.name, s2.name, s3.name], + "concurrent", + [s1.name, s2.name, s3.name], ["segment", "segment", "segment"], ) self.assertFalse(res["result"]) @@ -400,8 +423,8 @@ def test_parallel_pair_among_three(self) -> None: # Point on Line # ------------------------------------------------------------------ -class TestPointOnLine(_RelationTestBase): +class TestPointOnLine(_RelationTestBase): def test_on_extended_line(self) -> None: s1 = self.canvas.create_segment(0, 0, 2, 2, name="AB") self.canvas.create_point(5, 5, name="P") @@ -432,8 +455,8 @@ def test_reversed_order(self) -> None: # Point on Circle # ------------------------------------------------------------------ -class TestPointOnCircle(_RelationTestBase): +class TestPointOnCircle(_RelationTestBase): def test_on_circle(self) -> None: c1 = self.canvas.create_circle(0, 0, 5) self.canvas.create_point(3, 4, name="P") @@ -457,8 +480,8 @@ def test_reversed_order(self) -> None: # Auto Inspect # ------------------------------------------------------------------ -class TestAutoInspect(_RelationTestBase): +class TestAutoInspect(_RelationTestBase): def test_two_parallel_equal_segments(self) -> None: s1 = self.canvas.create_segment(0, 0, 4, 0, name="AB") s2 = self.canvas.create_segment(0, 3, 4, 3, name="CD") @@ -469,11 +492,7 @@ def test_two_parallel_equal_segments(self) -> None: self.assertIn("perpendicular", checks) self.assertIn("equal_length", checks) # Both parallel and equal length should be true - true_ops = [ - r["operation"] - for r in res["details"]["results"] - if r.get("result") is True - ] + true_ops = [r["operation"] for r in res["details"]["results"] if r.get("result") is True] self.assertIn("parallel", true_ops) self.assertIn("equal_length", true_ops) @@ -492,8 +511,8 @@ def test_symmetry_ab_ba(self) -> None: # Error cases # ------------------------------------------------------------------ -class TestInspectErrors(_RelationTestBase): +class TestInspectErrors(_RelationTestBase): def test_wrong_object_count(self) -> None: s1 = self.canvas.create_segment(0, 0, 4, 0, name="AB") res = self._inspect("parallel", [s1.name], ["segment"]) diff --git a/static/client/client_tests/test_renderer_edge_cases.py b/static/client/client_tests/test_renderer_edge_cases.py index 8b22ff59..756d3dfb 100644 --- a/static/client/client_tests/test_renderer_edge_cases.py +++ b/static/client/client_tests/test_renderer_edge_cases.py @@ -46,8 +46,12 @@ def fill_polygon(self, points, fill, stroke=None, **kwargs): def fill_joined_area(self, forward, reverse, fill): self._record("fill_joined_area", forward, reverse, fill) - def stroke_arc(self, center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class=None, **kwargs): - self._record("stroke_arc", center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class, **kwargs) + def stroke_arc( + self, center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class=None, **kwargs + ): + self._record( + "stroke_arc", center, radius, start_angle_rad, end_angle_rad, sweep_clockwise, stroke, css_class, **kwargs + ) def draw_text(self, text, position, font, color, alignment, style_overrides=None, **kwargs): self._record("draw_text", text, position, font, color, alignment, style_overrides, **kwargs) @@ -72,7 +76,7 @@ def setUp(self) -> None: self.style = {"point_radius": 4, "point_color": "#000"} def test_point_with_nan_coordinates_does_not_crash(self) -> None: - point = Point(float('nan'), 5.0, name="P") + point = Point(float("nan"), 5.0, name="P") try: shared.render_point_helper(self.primitives, point, self.mapper, self.style) @@ -80,7 +84,7 @@ def test_point_with_nan_coordinates_does_not_crash(self) -> None: self.fail(f"Point with NaN coordinate raised exception: {e}") def test_point_with_infinity_coordinates_does_not_crash(self) -> None: - point = Point(float('inf'), 10.0, name="P") + point = Point(float("inf"), 10.0, name="P") try: shared.render_point_helper(self.primitives, point, self.mapper, self.style) @@ -88,7 +92,7 @@ def test_point_with_infinity_coordinates_does_not_crash(self) -> None: self.fail(f"Point with Infinity coordinate raised exception: {e}") def test_point_with_negative_infinity_does_not_crash(self) -> None: - point = Point(5.0, float('-inf'), name="P") + point = Point(5.0, float("-inf"), name="P") try: shared.render_point_helper(self.primitives, point, self.mapper, self.style) @@ -130,7 +134,7 @@ def setUp(self) -> None: def test_segment_with_nan_endpoint_does_not_crash(self) -> None: p1 = Point(0, 0, name="A") - p2 = Point(float('nan'), 5, name="B") + p2 = Point(float("nan"), 5, name="B") segment = Segment(p1, p2) try: @@ -140,7 +144,7 @@ def test_segment_with_nan_endpoint_does_not_crash(self) -> None: def test_segment_with_infinity_endpoint_does_not_crash(self) -> None: p1 = Point(0, 0, name="A") - p2 = Point(float('inf'), float('inf'), name="B") + p2 = Point(float("inf"), float("inf"), name="B") segment = Segment(p1, p2) try: @@ -176,7 +180,7 @@ def setUp(self) -> None: def test_circle_with_nan_radius_does_not_crash(self) -> None: center = Point(5, 5, name="O") - circle = Circle(center, radius=float('nan')) + circle = Circle(center, radius=float("nan")) try: shared.render_circle_helper(self.primitives, circle, self.mapper, self.style) @@ -185,7 +189,7 @@ def test_circle_with_nan_radius_does_not_crash(self) -> None: def test_circle_with_infinity_radius_does_not_crash(self) -> None: center = Point(5, 5, name="O") - circle = Circle(center, radius=float('inf')) + circle = Circle(center, radius=float("inf")) try: shared.render_circle_helper(self.primitives, circle, self.mapper, self.style) @@ -211,7 +215,7 @@ def test_circle_with_negative_radius_does_not_crash(self) -> None: self.fail(f"Circle with negative radius raised exception: {e}") def test_circle_with_nan_center_does_not_crash(self) -> None: - center = Point(float('nan'), float('nan'), name="O") + center = Point(float("nan"), float("nan"), name="O") circle = Circle(center, radius=10) try: @@ -241,7 +245,7 @@ def test_zero_length_vector_does_not_crash(self) -> None: def test_vector_with_nan_endpoint_does_not_crash(self) -> None: p1 = Point(0, 0, name="A") - p2 = Point(float('nan'), 0, name="B") + p2 = Point(float("nan"), 0, name="B") vector = Vector(p1, p2) try: @@ -251,7 +255,7 @@ def test_vector_with_nan_endpoint_does_not_crash(self) -> None: def test_vector_with_infinity_does_not_crash(self) -> None: p1 = Point(0, 0, name="A") - p2 = Point(float('inf'), 0, name="B") + p2 = Point(float("inf"), 0, name="B") vector = Vector(p1, p2) try: @@ -286,7 +290,7 @@ def test_ellipse_with_zero_radius_y_does_not_crash(self) -> None: def test_ellipse_with_nan_radii_does_not_crash(self) -> None: center = Point(5, 5, name="O") - ellipse = Ellipse(center, radius_x=float('nan'), radius_y=float('nan')) + ellipse = Ellipse(center, radius_x=float("nan"), radius_y=float("nan")) try: shared.render_ellipse_helper(self.primitives, ellipse, self.mapper, self.style) @@ -304,7 +308,7 @@ def test_ellipse_with_negative_radii_does_not_crash(self) -> None: def test_ellipse_with_nan_rotation_does_not_crash(self) -> None: center = Point(5, 5, name="O") - ellipse = Ellipse(center, radius_x=10, radius_y=5, rotation_angle=float('nan')) + ellipse = Ellipse(center, radius_x=10, radius_y=5, rotation_angle=float("nan")) try: shared.render_ellipse_helper(self.primitives, ellipse, self.mapper, self.style) @@ -343,7 +347,7 @@ def test_label_with_very_long_text_does_not_crash(self) -> None: self.fail(f"Label with long text raised exception: {e}") def test_label_with_nan_font_size_does_not_crash(self) -> None: - label = Label(5, 5, "Test", font_size=float('nan')) + label = Label(5, 5, "Test", font_size=float("nan")) try: shared.render_label_helper(self.primitives, label, self.mapper, self.style) @@ -367,7 +371,7 @@ def test_label_with_negative_font_size_does_not_crash(self) -> None: self.fail(f"Label with negative font size raised exception: {e}") def test_label_with_nan_position_does_not_crash(self) -> None: - label = Label(float('nan'), float('nan'), "Test") + label = Label(float("nan"), float("nan"), "Test") try: shared.render_label_helper(self.primitives, label, self.mapper, self.style) @@ -401,8 +405,8 @@ def test_cartesian_with_zero_dimensions_does_not_crash(self) -> None: def test_cartesian_with_nan_dimensions_does_not_crash(self) -> None: cartesian = SimpleNamespace( - width=float('nan'), - height=float('nan'), + width=float("nan"), + height=float("nan"), current_tick_spacing=50, default_tick_spacing=50, ) @@ -448,4 +452,3 @@ def test_cartesian_with_negative_tick_spacing_does_not_crash(self) -> None: "TestLabelEdgeCases", "TestCartesianEdgeCases", ] - diff --git a/static/client/client_tests/test_renderer_logic.py b/static/client/client_tests/test_renderer_logic.py index a76bdefa..d2e3ed54 100644 --- a/static/client/client_tests/test_renderer_logic.py +++ b/static/client/client_tests/test_renderer_logic.py @@ -102,7 +102,7 @@ def test_render_point_helper_records_metadata(self) -> None: self.assertIsNotNone(metadata) label_meta = metadata["point_label"] self.assertEqual(label_meta["math_position"], (point.x, point.y)) - self.assertEqual(label_meta["screen_offset"], (float(style["point_radius"]), float(-style["point_radius"])) ) + self.assertEqual(label_meta["screen_offset"], (float(style["point_radius"]), float(-style["point_radius"]))) def test_canvas2d_font_cache_quantizes_similar_sizes(self) -> None: canvas_el = SimpleNamespace(getContext=lambda _kind: SimpleNamespace()) diff --git a/static/client/client_tests/test_renderer_primitives.py b/static/client/client_tests/test_renderer_primitives.py index b09157ae..82867a2a 100644 --- a/static/client/client_tests/test_renderer_primitives.py +++ b/static/client/client_tests/test_renderer_primitives.py @@ -230,7 +230,9 @@ def lineTo(self, x: float, y: float) -> None: def arc(self, x: float, y: float, radius: float, start: float, end: float, anticlockwise: bool = False) -> None: self.log.append(("arc", x, y, radius, start, end, anticlockwise)) - def ellipse(self, x: float, y: float, radius_x: float, radius_y: float, rotation: float, start: float, end: float) -> None: + def ellipse( + self, x: float, y: float, radius_x: float, radius_y: float, rotation: float, start: float, end: float + ) -> None: self.log.append(("ellipse", x, y, radius_x, radius_y, rotation, start, end)) def closePath(self) -> None: @@ -399,7 +401,9 @@ def stroke_circle(self, center: Point2D, radius: float, stroke: Any) -> None: def fill_circle(self, center: Point2D, radius: float, fill: Any, stroke: Optional[Any] = None) -> None: raise AssertionError("Unexpected fill_circle call") - def stroke_ellipse(self, center: Point2D, radius_x: float, radius_y: float, rotation_rad: float, stroke: Any) -> None: + def stroke_ellipse( + self, center: Point2D, radius_x: float, radius_y: float, rotation_rad: float, stroke: Any + ) -> None: raise AssertionError("Unexpected stroke_ellipse call") def fill_polygon(self, points: List[Point2D], fill: Any, stroke: Optional[Any] = None) -> None: @@ -503,10 +507,12 @@ def test_svg_primitives_render_shapes(self) -> None: renderer = SvgRenderer() - with unittest.mock.patch("rendering.svg_renderer.svg", mock_svg), \ - unittest.mock.patch("rendering.svg_renderer.document", mock_document), \ - unittest.mock.patch("rendering.svg_primitive_adapter.svg", mock_svg), \ - unittest.mock.patch("rendering.svg_primitive_adapter.document", mock_document): + with ( + unittest.mock.patch("rendering.svg_renderer.svg", mock_svg), + unittest.mock.patch("rendering.svg_renderer.document", mock_document), + unittest.mock.patch("rendering.svg_primitive_adapter.svg", mock_svg), + unittest.mock.patch("rendering.svg_primitive_adapter.document", mock_document), + ): renderer._shared_primitives = SvgPrimitiveAdapter("math-svg") renderer._render_point(self.point_a, self.mapper) renderer._render_segment(self.segment_ab, self.mapper) @@ -585,10 +591,12 @@ def test_canvas_and_svg_primitives_have_matching_core_signals(self) -> None: mock_svg, mock_document = reset_svg_environment(svg_log) svg_renderer = SvgRenderer() - with unittest.mock.patch("rendering.svg_renderer.svg", mock_svg), \ - unittest.mock.patch("rendering.svg_renderer.document", mock_document), \ - unittest.mock.patch("rendering.svg_primitive_adapter.svg", mock_svg), \ - unittest.mock.patch("rendering.svg_primitive_adapter.document", mock_document): + with ( + unittest.mock.patch("rendering.svg_renderer.svg", mock_svg), + unittest.mock.patch("rendering.svg_renderer.document", mock_document), + unittest.mock.patch("rendering.svg_primitive_adapter.svg", mock_svg), + unittest.mock.patch("rendering.svg_primitive_adapter.document", mock_document), + ): svg_renderer._shared_primitives = SvgPrimitiveAdapter("math-svg") svg_renderer._render_point(self.point_a, self.mapper) svg_renderer._render_segment(self.segment_ab, self.mapper) @@ -601,8 +609,9 @@ def test_canvas_and_svg_primitives_have_matching_core_signals(self) -> None: svg_has_point_label = any(text and text.startswith("A(") for text in svg_text_entries) svg_has_function_label = "f" in svg_text_entries svg_tree = serialize_svg_tree(mock_document.surface) - svg_has_angle_arc = svg_tree_contains(svg_tree, "path", "class", "angle-arc") or \ - svg_tree_contains(svg_tree, "path", "stroke", self.angle_abc.color) + svg_has_angle_arc = svg_tree_contains(svg_tree, "path", "class", "angle-arc") or svg_tree_contains( + svg_tree, "path", "stroke", self.angle_abc.color + ) canvas_renderer = Canvas2DRenderer() mock_canvas = MockCanvasElement() @@ -615,7 +624,9 @@ def test_canvas_and_svg_primitives_have_matching_core_signals(self) -> None: canvas_renderer._render_function(self.function, self.mapper) text_operations = [entry for entry in ctx.log if entry[0] == "fillText"] - canvas_has_point_label = any(isinstance(entry[1], str) and entry[1].startswith("A(") for entry in text_operations) + canvas_has_point_label = any( + isinstance(entry[1], str) and entry[1].startswith("A(") for entry in text_operations + ) canvas_has_function_label = any(entry[1] == "f" for entry in text_operations) # Isolate the angle signal on a fresh canvas log so point/circle arcs # cannot mask angle rendering regressions. @@ -708,4 +719,3 @@ def test_webgl_vector_helper_uses_polygon_fallback(self) -> None: self.assertTrue(line_ops, "vector helper should draw the segment body") tip_calls = [entry for entry in mock_renderer.log if entry[0] == "line_strip"] self.assertTrue(tip_calls, "vector helper should approximate arrowhead with line strip fallback") - diff --git a/static/client/client_tests/test_result_processor_traced.py b/static/client/client_tests/test_result_processor_traced.py index 363ea40c..a1f66ed4 100644 --- a/static/client/client_tests/test_result_processor_traced.py +++ b/static/client/client_tests/test_result_processor_traced.py @@ -30,7 +30,10 @@ def test_returns_results_and_trace(self) -> None: } calls = [{"function_name": "evaluate_expression", "arguments": {"expression": "3 + 7", "canvas": self.canvas}}] results, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) # Results should contain the evaluation self.assertTrue(len(results) > 0) @@ -46,7 +49,10 @@ def test_records_timing(self) -> None: } calls = [{"function_name": "evaluate_expression", "arguments": {"expression": "2 * 3", "canvas": self.canvas}}] _, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertGreaterEqual(traced[0]["duration_ms"], 0) @@ -57,7 +63,10 @@ def always_fail(**kwargs: Any) -> str: available_functions: Dict[str, Any] = {"fail_func": always_fail} calls = [{"function_name": "fail_func", "arguments": {}}] results, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertTrue(traced[0]["is_error"]) self.assertIn("Error", str(results.get("fail_func", ""))) @@ -72,10 +81,16 @@ def test_same_results_as_get_results(self) -> None: {"function_name": "evaluate_expression", "arguments": {"expression": "10 * 2", "canvas": self.canvas}}, ] results_normal = ProcessFunctionCalls.get_results( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) results_traced, _ = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertEqual(results_normal, results_traced) @@ -86,7 +101,10 @@ def test_sanitized_arguments(self) -> None: } calls = [{"function_name": "evaluate_expression", "arguments": {"expression": "1+1", "canvas": self.canvas}}] _, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertNotIn("canvas", traced[0]["arguments"]) self.assertIn("expression", traced[0]["arguments"]) @@ -96,7 +114,10 @@ def test_missing_function(self) -> None: available_functions: Dict[str, Any] = {} calls = [{"function_name": "nonexistent", "arguments": {}}] results, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertTrue(traced[0]["is_error"]) self.assertIn("nonexistent", results) @@ -108,7 +129,10 @@ def test_result_value_captured(self) -> None: } calls = [{"function_name": "evaluate_expression", "arguments": {"expression": "3 + 7", "canvas": self.canvas}}] _, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertEqual(traced[0]["result"], 10) @@ -117,11 +141,21 @@ def test_result_value_for_expression_with_variables(self) -> None: available_functions: Dict[str, Any] = { "evaluate_expression": ProcessFunctionCalls.evaluate_expression, } - calls = [{"function_name": "evaluate_expression", "arguments": { - "expression": "x + y", "variables": {"x": 5, "y": 3}, "canvas": self.canvas, - }}] + calls = [ + { + "function_name": "evaluate_expression", + "arguments": { + "expression": "x + y", + "variables": {"x": 5, "y": 3}, + "canvas": self.canvas, + }, + } + ] _, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertEqual(traced[0]["result"], 8) @@ -136,7 +170,10 @@ def test_multiple_calls_sequential_order(self) -> None: {"function_name": "evaluate_expression", "arguments": {"expression": "3+3", "canvas": self.canvas}}, ] _, traced = ProcessFunctionCalls.get_results_traced( - calls, available_functions, (), self.canvas, + calls, + available_functions, + (), + self.canvas, ) self.assertEqual(len(traced), 3) self.assertEqual([t["seq"] for t in traced], [0, 1, 2]) diff --git a/static/client/client_tests/test_screen_offset_label_layout.py b/static/client/client_tests/test_screen_offset_label_layout.py index eb0aa52f..0c0bb940 100644 --- a/static/client/client_tests/test_screen_offset_label_layout.py +++ b/static/client/client_tests/test_screen_offset_label_layout.py @@ -2,7 +2,12 @@ import unittest -from rendering.helpers.screen_offset_label_layout import LabelBlock, make_label_text_call, solve_dy, solve_dy_with_hide_for_text_calls +from rendering.helpers.screen_offset_label_layout import ( + LabelBlock, + make_label_text_call, + solve_dy, + solve_dy_with_hide_for_text_calls, +) class TestScreenOffsetLabelLayout(unittest.TestCase): @@ -224,5 +229,3 @@ def __init__(self, size: float) -> None: dy, hidden = solve_dy_with_hide_for_text_calls([call_a, call_b], max_abs_dy_factor=3.0) # They may still overlap in rect-space, but proximity rule should not hide. self.assertNotIn("B", hidden) - - diff --git a/static/client/client_tests/test_segment.py b/static/client/client_tests/test_segment.py index 52f2269d..d1813506 100644 --- a/static/client/client_tests/test_segment.py +++ b/static/client/client_tests/test_segment.py @@ -23,7 +23,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -51,7 +51,7 @@ def test_init(self) -> None: self.assertEqual(self.segment.color, "blue") def test_get_class_name(self) -> None: - self.assertEqual(self.segment.get_class_name(), 'Segment') + self.assertEqual(self.segment.get_class_name(), "Segment") def test_calculate_line_algebraic_formula(self) -> None: line_formula = self.segment._calculate_line_algebraic_formula() @@ -60,13 +60,16 @@ def test_calculate_line_algebraic_formula(self) -> None: def test_visibility_via_canvas(self) -> None: from canvas import Canvas # not used directly; we mimic Canvas._is_drawable_visible logic + # Compute visibility using canvas-level predicate # Endpoint-in-viewport or intersects viewport x1, y1 = self.coordinate_mapper.math_to_screen(self.segment.point1.x, self.segment.point1.y) x2, y2 = self.coordinate_mapper.math_to_screen(self.segment.point2.x, self.segment.point2.y) - in_view = self.canvas.is_point_within_canvas_visible_area(x1, y1) or \ - self.canvas.is_point_within_canvas_visible_area(x2, y2) or \ - self.canvas.any_segment_part_visible_in_canvas_area(x1, y1, x2, y2) + in_view = ( + self.canvas.is_point_within_canvas_visible_area(x1, y1) + or self.canvas.is_point_within_canvas_visible_area(x2, y2) + or self.canvas.any_segment_part_visible_in_canvas_area(x1, y1, x2, y2) + ) self.assertTrue(in_view) def test_get_state(self) -> None: @@ -179,4 +182,3 @@ def test_same_name_different_coords_produces_different_state(self) -> None: # But coordinates differ, so render cache should invalidate self.assertNotEqual(state1["_p1_coords"], state2["_p1_coords"]) self.assertNotEqual(state1["_p2_coords"], state2["_p2_coords"]) - diff --git a/static/client/client_tests/test_segment_manager.py b/static/client/client_tests/test_segment_manager.py index 368a3ac6..f880e5ad 100644 --- a/static/client/client_tests/test_segment_manager.py +++ b/static/client/client_tests/test_segment_manager.py @@ -175,4 +175,3 @@ def test_create_segment_from_points_removes_stale_segment(self) -> None: if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_segments_bounded_colored_area.py b/static/client/client_tests/test_segments_bounded_colored_area.py index 4391bd9a..36663b51 100644 --- a/static/client/client_tests/test_segments_bounded_colored_area.py +++ b/static/client/client_tests/test_segments_bounded_colored_area.py @@ -29,7 +29,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -39,13 +39,13 @@ def setUp(self) -> None: self.segment1 = SimpleMock( name="AB", point1=SimpleMock(x=100, y=200), # Canvas coordinates - point2=SimpleMock(x=300, y=250) # Canvas coordinates + point2=SimpleMock(x=300, y=250), # Canvas coordinates ) self.segment2 = SimpleMock( name="CD", point1=SimpleMock(x=150, y=180), # Canvas coordinates - point2=SimpleMock(x=280, y=220) # Canvas coordinates + point2=SimpleMock(x=280, y=220), # Canvas coordinates ) def test_init_with_two_segments(self) -> None: @@ -65,7 +65,7 @@ def test_init_with_segment_and_x_axis(self) -> None: def test_get_class_name(self) -> None: """Test class name retrieval.""" area = SegmentsBoundedColoredArea(self.segment1, self.segment2) - self.assertEqual(area.get_class_name(), 'SegmentsBoundedColoredArea') + self.assertEqual(area.get_class_name(), "SegmentsBoundedColoredArea") def test_generate_name_with_two_segments(self) -> None: """Test name generation with two segments.""" @@ -84,10 +84,7 @@ def test_uses_segment_with_matching_first_segment(self) -> None: area = SegmentsBoundedColoredArea(self.segment1, self.segment2) # Create matching segment - matching_segment = SimpleMock( - point1=SimpleMock(x=100, y=200), - point2=SimpleMock(x=300, y=250) - ) + matching_segment = SimpleMock(point1=SimpleMock(x=100, y=200), point2=SimpleMock(x=300, y=250)) self.assertTrue(area.uses_segment(matching_segment)) @@ -96,10 +93,7 @@ def test_uses_segment_with_matching_second_segment(self) -> None: area = SegmentsBoundedColoredArea(self.segment1, self.segment2) # Create matching segment - matching_segment = SimpleMock( - point1=SimpleMock(x=150, y=180), - point2=SimpleMock(x=280, y=220) - ) + matching_segment = SimpleMock(point1=SimpleMock(x=150, y=180), point2=SimpleMock(x=280, y=220)) self.assertTrue(area.uses_segment(matching_segment)) @@ -108,10 +102,7 @@ def test_uses_segment_with_non_matching_segment(self) -> None: area = SegmentsBoundedColoredArea(self.segment1, self.segment2) # Create non-matching segment - different_segment = SimpleMock( - point1=SimpleMock(x=400, y=400), - point2=SimpleMock(x=500, y=500) - ) + different_segment = SimpleMock(point1=SimpleMock(x=400, y=400), point2=SimpleMock(x=500, y=500)) self.assertFalse(area.uses_segment(different_segment)) @@ -120,16 +111,10 @@ def test_uses_segment_with_only_first_segment(self) -> None: area = SegmentsBoundedColoredArea(self.segment1, None) # Create matching segment - matching_segment = SimpleMock( - point1=SimpleMock(x=100, y=200), - point2=SimpleMock(x=300, y=250) - ) + matching_segment = SimpleMock(point1=SimpleMock(x=100, y=200), point2=SimpleMock(x=300, y=250)) # Create non-matching segment - different_segment = SimpleMock( - point1=SimpleMock(x=400, y=400), - point2=SimpleMock(x=500, y=500) - ) + different_segment = SimpleMock(point1=SimpleMock(x=400, y=400), point2=SimpleMock(x=500, y=500)) self.assertTrue(area.uses_segment(matching_segment)) self.assertFalse(area.uses_segment(different_segment)) @@ -148,10 +133,7 @@ def test_get_state_with_two_segments(self) -> None: area = SegmentsBoundedColoredArea(self.segment1, self.segment2) state = area.get_state() - expected_args = { - "segment1": "AB", - "segment2": "CD" - } + expected_args = {"segment1": "AB", "segment2": "CD"} self.assertEqual(state["args"]["segment1"], expected_args["segment1"]) self.assertEqual(state["args"]["segment2"], expected_args["segment2"]) @@ -160,10 +142,7 @@ def test_get_state_with_segment_and_x_axis(self) -> None: area = SegmentsBoundedColoredArea(self.segment1, None) state = area.get_state() - expected_args = { - "segment1": "AB", - "segment2": "x_axis" - } + expected_args = {"segment1": "AB", "segment2": "x_axis"} self.assertEqual(state["args"]["segment1"], expected_args["segment1"]) self.assertEqual(state["args"]["segment2"], expected_args["segment2"]) @@ -199,13 +178,13 @@ def test_no_overlap_segments(self) -> None: segment1 = SimpleMock( name="AB", point1=SimpleMock(x=100, y=200), # x range: 100-150 - point2=SimpleMock(x=150, y=250) + point2=SimpleMock(x=150, y=250), ) segment2 = SimpleMock( name="CD", point1=SimpleMock(x=300, y=180), # x range: 300-400 (no overlap with 100-150) - point2=SimpleMock(x=400, y=220) + point2=SimpleMock(x=400, y=220), ) area = SegmentsBoundedColoredArea(segment1, segment2) @@ -221,13 +200,13 @@ def test_exactly_touching_segments(self) -> None: segment1 = SimpleMock( name="AB", point1=SimpleMock(x=100, y=200), # x range: 100-200 - point2=SimpleMock(x=200, y=250) + point2=SimpleMock(x=200, y=250), ) segment2 = SimpleMock( name="CD", point1=SimpleMock(x=200, y=180), # x range: 200-300 (touches at x=200) - point2=SimpleMock(x=300, y=220) + point2=SimpleMock(x=300, y=220), ) area = SegmentsBoundedColoredArea(segment1, segment2) @@ -243,13 +222,13 @@ def test_vertical_segment_interpolation(self) -> None: normal_segment = SimpleMock( name="AB", point1=SimpleMock(x=100, y=200), # x range: 100-300 - point2=SimpleMock(x=300, y=250) + point2=SimpleMock(x=300, y=250), ) vertical_segment = SimpleMock( name="CD", point1=SimpleMock(x=200, y=100), # Vertical line at x=200 - point2=SimpleMock(x=200, y=300) + point2=SimpleMock(x=200, y=300), ) area = SegmentsBoundedColoredArea(normal_segment, vertical_segment) @@ -275,13 +254,13 @@ def test_segment_with_negative_coordinates(self) -> None: negative_segment1 = SimpleMock( name="AB", point1=SimpleMock(x=-200, y=-100), # x range: -200 to -100 - point2=SimpleMock(x=-100, y=-50) + point2=SimpleMock(x=-100, y=-50), ) negative_segment2 = SimpleMock( name="CD", point1=SimpleMock(x=-150, y=-80), # x range: -150 to -50 (overlap: -150 to -100) - point2=SimpleMock(x=-50, y=-20) + point2=SimpleMock(x=-50, y=-20), ) area = SegmentsBoundedColoredArea(negative_segment1, negative_segment2) @@ -297,13 +276,13 @@ def test_segment_crossing_zero_coordinates(self) -> None: crossing_segment1 = SimpleMock( name="AB", point1=SimpleMock(x=-100, y=-50), # Crosses zero - point2=SimpleMock(x=100, y=50) + point2=SimpleMock(x=100, y=50), ) crossing_segment2 = SimpleMock( name="CD", point1=SimpleMock(x=-50, y=100), # Also crosses zero - point2=SimpleMock(x=150, y=-100) + point2=SimpleMock(x=150, y=-100), ) area = SegmentsBoundedColoredArea(crossing_segment1, crossing_segment2) diff --git a/static/client/client_tests/test_slash_commands.py b/static/client/client_tests/test_slash_commands.py index ff8abf16..471589f4 100644 --- a/static/client/client_tests/test_slash_commands.py +++ b/static/client/client_tests/test_slash_commands.py @@ -31,9 +31,7 @@ def __init__(self) -> None: self.coordinate_mapper.right_bound = 10 self.coordinate_mapper.top_bound = 10 self.coordinate_mapper.bottom_bound = -10 - self.coordinate_mapper.get_visible_bounds = lambda: { - "left": -10, "right": 10, "top": 10, "bottom": -10 - } + self.coordinate_mapper.get_visible_bounds = lambda: {"left": -10, "right": 10, "top": 10, "bottom": -10} # Mock cartesian2axis self.cartesian2axis = MagicMock() @@ -451,7 +449,7 @@ def test_cmd_import_invalid_json(self) -> None: # Error message may contain "Invalid JSON" or other JSON parsing errors self.assertTrue( "Invalid JSON" in result.message or "JSON" in result.message or "Error" in result.message, - f"Expected JSON error message, got: {result.message}" + f"Expected JSON error message, got: {result.message}", ) def test_cmd_list(self) -> None: @@ -688,18 +686,13 @@ def test_initial_state(self) -> None: def test_filter_empty_prefix(self) -> None: """Test filtering with empty prefix shows all commands.""" self.autocomplete.filter("") - self.assertEqual( - len(self.autocomplete.filtered_commands), - len(self.command_handler.get_commands_list()) - ) + self.assertEqual(len(self.autocomplete.filtered_commands), len(self.command_handler.get_commands_list())) def test_filter_matching_prefix(self) -> None: """Test filtering with matching prefix.""" self.autocomplete.filter("he") # Should match /help - self.assertTrue( - any("/help" in cmd for cmd, _ in self.autocomplete.filtered_commands) - ) + self.assertTrue(any("/help" in cmd for cmd, _ in self.autocomplete.filtered_commands)) def test_filter_no_match(self) -> None: """Test filtering with non-matching prefix.""" @@ -709,9 +702,7 @@ def test_filter_no_match(self) -> None: def test_filter_case_insensitive(self) -> None: """Test filtering is case-insensitive.""" self.autocomplete.filter("HE") - self.assertTrue( - any("/help" in cmd for cmd, _ in self.autocomplete.filtered_commands) - ) + self.assertTrue(any("/help" in cmd for cmd, _ in self.autocomplete.filtered_commands)) def test_select_next_wraps(self) -> None: """Test select_next wraps around.""" @@ -845,7 +836,7 @@ def test_long_message_detection_by_length(self) -> None: """Test that messages over 800 chars are considered long.""" # Create a message over 800 characters long_message = "x" * 801 - line_count = long_message.count('\n') + line_count = long_message.count("\n") is_long = len(long_message) > 800 or line_count > 20 self.assertTrue(is_long) @@ -853,49 +844,49 @@ def test_long_message_detection_by_lines(self) -> None: """Test that messages with >20 newlines are considered long.""" # Create a message with 22 lines (21 newlines) to trigger > 20 check long_message = "\n".join(["line"] * 22) - line_count = long_message.count('\n') + line_count = long_message.count("\n") is_long = len(long_message) > 800 or line_count > 20 self.assertTrue(is_long) def test_short_message_not_expandable(self) -> None: """Test that short messages are not considered long.""" short_message = "This is a short message" - line_count = short_message.count('\n') + line_count = short_message.count("\n") is_long = len(short_message) > 800 or line_count > 20 self.assertFalse(is_long) def test_preview_truncation_by_lines(self) -> None: """Test preview is truncated to 10 lines.""" lines = ["line " + str(i) for i in range(20)] - message = '\n'.join(lines) + message = "\n".join(lines) # Simulate preview creation logic - msg_lines = message.split('\n') + msg_lines = message.split("\n") if len(msg_lines) > 10: - preview_text = '\n'.join(msg_lines[:10]) + '\n...' + preview_text = "\n".join(msg_lines[:10]) + "\n..." else: preview_text = message - preview_lines = preview_text.split('\n') + preview_lines = preview_text.split("\n") # Should have 10 lines plus the "..." line self.assertEqual(len(preview_lines), 11) - self.assertTrue(preview_text.endswith('...')) + self.assertTrue(preview_text.endswith("...")) def test_preview_truncation_by_chars(self) -> None: """Test preview is truncated to 500 chars when few lines.""" message = "a" * 600 # Long but single line # Simulate preview creation logic - lines = message.split('\n') + lines = message.split("\n") if len(lines) > 10: - preview_text = '\n'.join(lines[:10]) + '\n...' + preview_text = "\n".join(lines[:10]) + "\n..." elif len(message) > 500: - preview_text = message[:500] + '...' + preview_text = message[:500] + "..." else: preview_text = message self.assertEqual(len(preview_text), 503) # 500 + "..." - self.assertTrue(preview_text.endswith('...')) + self.assertTrue(preview_text.endswith("...")) class TestExportCommandOutput(unittest.TestCase): @@ -917,6 +908,7 @@ def setUp(self) -> None: def test_export_returns_valid_json(self) -> None: """Test /export returns valid JSON string.""" import json + result = self.handler.execute("/export") self.assertTrue(result.success) # The message should be valid JSON @@ -929,6 +921,7 @@ def test_export_returns_valid_json(self) -> None: def test_export_data_matches_message(self) -> None: """Test /export data field matches message content.""" import json + result = self.handler.execute("/export") self.assertTrue(result.success) parsed_message = json.loads(result.message) @@ -972,9 +965,7 @@ def test_status_shows_bounds(self) -> None: self.assertTrue(result.success) # Should contain bound information msg_lower = result.message.lower() - self.assertTrue( - "bound" in msg_lower or "left" in msg_lower or "right" in msg_lower - ) + self.assertTrue("bound" in msg_lower or "left" in msg_lower or "right" in msg_lower) def test_status_shows_coordinate_system(self) -> None: """Test /status shows coordinate system mode.""" diff --git a/static/client/client_tests/test_statistics_distributions.py b/static/client/client_tests/test_statistics_distributions.py index a1905122..4d8693d8 100644 --- a/static/client/client_tests/test_statistics_distributions.py +++ b/static/client/client_tests/test_statistics_distributions.py @@ -28,6 +28,3 @@ def test_default_normal_bounds(self) -> None: left2, right2 = default_normal_bounds(2.0, 0.5, k=2.0) self.assertAlmostEqual(left2, 1.0) self.assertAlmostEqual(right2, 3.0) - - - diff --git a/static/client/client_tests/test_statistics_manager.py b/static/client/client_tests/test_statistics_manager.py index 44ce1984..4da5ee3c 100644 --- a/static/client/client_tests/test_statistics_manager.py +++ b/static/client/client_tests/test_statistics_manager.py @@ -272,7 +272,9 @@ def test_plot_bars_emits_observability_start_and_end_logs(self) -> None: y_base=0.0, ) - self.assertTrue(any("operation': 'plot_bars'" in msg and "stage': 'start'" in msg for msg in logger_spy.messages)) + self.assertTrue( + any("operation': 'plot_bars'" in msg and "stage': 'start'" in msg for msg in logger_spy.messages) + ) self.assertTrue(any("operation': 'plot_bars'" in msg and "stage': 'end'" in msg for msg in logger_spy.messages)) self.assertTrue(any("elapsed_ms" in msg for msg in logger_spy.messages)) @@ -294,7 +296,9 @@ def test_plot_distribution_failure_emits_observability_failure_log(self) -> None bar_count=None, ) - self.assertTrue(any("operation': 'plot_distribution'" in msg and "stage': 'failure'" in msg for msg in logger_spy.messages)) + self.assertTrue( + any("operation': 'plot_distribution'" in msg and "stage': 'failure'" in msg for msg in logger_spy.messages) + ) def test_log_operation_debug_no_logger_is_noop(self) -> None: stats = self.canvas.drawable_manager.statistics_manager diff --git a/static/client/client_tests/test_tangent_manager.py b/static/client/client_tests/test_tangent_manager.py index 473d1376..55104ad8 100644 --- a/static/client/client_tests/test_tangent_manager.py +++ b/static/client/client_tests/test_tangent_manager.py @@ -245,9 +245,11 @@ class TestMathUtilsTangentFunctions(unittest.TestCase): def test_numerical_derivative_at_polynomial(self) -> None: """Test numerical derivative of x^2 at x=3.""" + # y = x^2, y' = 2x, y'(3) = 6 def func(x): return x**2 + deriv = MathUtils.numerical_derivative_at(func, 3.0) self.assertIsNotNone(deriv) self.assertAlmostEqual(deriv, 6.0, places=4) @@ -288,7 +290,7 @@ def test_tangent_line_endpoints_length(self) -> None: slope = 2.0 p1, p2 = MathUtils.tangent_line_endpoints(slope, point, length) - actual_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) + actual_length = math.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2) self.assertAlmostEqual(actual_length, 6.0, places=5) def test_normal_slope_from_horizontal(self) -> None: diff --git a/static/client/client_tests/test_throttle.py b/static/client/client_tests/test_throttle.py index aa15f7e8..50f2345f 100644 --- a/static/client/client_tests/test_throttle.py +++ b/static/client/client_tests/test_throttle.py @@ -23,7 +23,7 @@ def now() -> int: self.mock_window = SimpleMock( setTimeout=SimpleMock(return_value=123), # Return a mock timer ID clearTimeout=SimpleMock(), - performance=self.mock_performance + performance=self.mock_performance, ) # Save original window references @@ -125,6 +125,7 @@ def test_throttle_respects_wait_time(self) -> None: def test_throttle_handles_errors(self) -> None: """Test that throttle function handles errors gracefully.""" + def failing_func() -> None: raise Exception("Test error") diff --git a/static/client/client_tests/test_tool_call_log.py b/static/client/client_tests/test_tool_call_log.py index 1e6abbd9..0eff4ce7 100644 --- a/static/client/client_tests/test_tool_call_log.py +++ b/static/client/client_tests/test_tool_call_log.py @@ -412,5 +412,6 @@ def test_container_preserved_with_tool_log(self) -> None: # because tool call log entries exist ai._remove_empty_response_container() - self.assertIs(ai._stream_message_container, container, - "Container should NOT be removed when tool call log has entries") + self.assertIs( + ai._stream_message_container, container, "Container should NOT be removed when tool call log has entries" + ) diff --git a/static/client/client_tests/test_transformations_manager.py b/static/client/client_tests/test_transformations_manager.py index cbf73d64..ad836dcc 100644 --- a/static/client/client_tests/test_transformations_manager.py +++ b/static/client/client_tests/test_transformations_manager.py @@ -31,7 +31,9 @@ def get_children(self, drawable: object) -> set: return {self._lookup[child_id] for child_id in child_ids if child_id in self._lookup} -def _build_canvas(primary_drawable: object, segments: List[Segment]) -> Tuple[SimpleMock, _IdentityDependencyManager, SimpleMock]: +def _build_canvas( + primary_drawable: object, segments: List[Segment] +) -> Tuple[SimpleMock, _IdentityDependencyManager, SimpleMock]: renderer = SimpleMock() renderer.invalidate_drawable_cache = SimpleMock() @@ -116,9 +118,7 @@ def test_translate_polygon_invalidates_closed_area_children(self) -> None: ) canvas, dependency_manager, renderer = _build_canvas(rectangle, [s1, s2, s3, s4]) - dependency_manager.register_many( - (segment, area) for segment in (s1, s2, s3, s4) - ) + dependency_manager.register_many((segment, area) for segment in (s1, s2, s3, s4)) initial_snapshot = area.get_state()["args"]["geometry_snapshot"] diff --git a/static/client/client_tests/test_transforms.py b/static/client/client_tests/test_transforms.py index 0a2d1f07..e5f6b253 100644 --- a/static/client/client_tests/test_transforms.py +++ b/static/client/client_tests/test_transforms.py @@ -16,6 +16,7 @@ # Shared helpers # --------------------------------------------------------------------------- + class _IdentityDependencyManager: """Minimal dependency manager for manager-level tests.""" diff --git a/static/client/client_tests/test_triangle.py b/static/client/client_tests/test_triangle.py index b1d282f6..f8380ea4 100644 --- a/static/client/client_tests/test_triangle.py +++ b/static/client/client_tests/test_triangle.py @@ -23,7 +23,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -83,7 +83,7 @@ def test_identical_points_not_forming_triangle(self) -> None: self.assertFalse(self.triangle._segments_form_triangle(self.segment1, segment4, segment5)) def test_get_class_name(self) -> None: - self.assertEqual(self.triangle.get_class_name(), 'Triangle') + self.assertEqual(self.triangle.get_class_name(), "Triangle") def test_get_state(self) -> None: state = self.triangle.get_state() @@ -130,4 +130,3 @@ def test_translate_triangle_in_math_space(self) -> None: self.assertEqual((s1p1x, s1p1y), (251, 248)) self.assertEqual((s1p2x, s1p2y), (255, 248)) self.assertEqual((s2p2x, s2p2y), (251, 245)) - diff --git a/static/client/client_tests/test_undo_redo_manager.py b/static/client/client_tests/test_undo_redo_manager.py index 9f5b6583..3dbdc584 100644 --- a/static/client/client_tests/test_undo_redo_manager.py +++ b/static/client/client_tests/test_undo_redo_manager.py @@ -71,4 +71,3 @@ def test_capture_state_returns_deep_copy(self) -> None: self.canvas.computations[0]["result"] = 99 self.assertEqual(snapshot["computations"][0]["result"], 1) - diff --git a/static/client/client_tests/test_vector.py b/static/client/client_tests/test_vector.py index a8890de0..00a5de05 100644 --- a/static/client/client_tests/test_vector.py +++ b/static/client/client_tests/test_vector.py @@ -23,7 +23,7 @@ def setUp(self) -> None: zoom_point=Position(1, 1), zoom_direction=1, zoom_step=0.1, - offset=Position(0, 0) # Set to (0,0) for simpler tests + offset=Position(0, 0), # Set to (0,0) for simpler tests ) # Sync canvas state with coordinate mapper @@ -48,7 +48,7 @@ def test_init(self) -> None: self.assertEqual(self.vector.color, "green") def test_get_class_name(self) -> None: - self.assertEqual(self.vector.get_class_name(), 'Vector') + self.assertEqual(self.vector.get_class_name(), "Vector") def test_get_state(self) -> None: state = self.vector.get_state() @@ -129,4 +129,3 @@ def test_same_name_different_coords_produces_different_state(self) -> None: # But coordinates differ, so render cache should invalidate self.assertNotEqual(state1["_origin_coords"], state2["_origin_coords"]) self.assertNotEqual(state1["_tip_coords"], state2["_tip_coords"]) - diff --git a/static/client/client_tests/test_vector_manager.py b/static/client/client_tests/test_vector_manager.py index c5cfc664..a38f5c4e 100644 --- a/static/client/client_tests/test_vector_manager.py +++ b/static/client/client_tests/test_vector_manager.py @@ -127,4 +127,3 @@ def test_create_vector_from_points_creates_new_if_different_points_same_coords(s if __name__ == "__main__": unittest.main() - diff --git a/static/client/client_tests/test_window_mocks.py b/static/client/client_tests/test_window_mocks.py index a481f010..3b9ed84d 100644 --- a/static/client/client_tests/test_window_mocks.py +++ b/static/client/client_tests/test_window_mocks.py @@ -16,7 +16,7 @@ def setUp(self) -> None: self.mock_window = SimpleMock( setTimeout=SimpleMock(return_value=123), # Return a mock timer ID clearTimeout=SimpleMock(), - performance=self.mock_performance + performance=self.mock_performance, ) # Save original window references @@ -50,8 +50,10 @@ def test_performance_now(self) -> None: def test_set_timeout(self) -> None: """Test that setTimeout stores the callback and returns the expected timer ID.""" + def callback() -> None: return None + wait_time = 100 # Call setTimeout and verify return value @@ -78,6 +80,7 @@ def test_clear_timeout(self) -> None: def test_mock_chain(self) -> None: """Test that the entire mock chain works together.""" + def callback() -> None: return None diff --git a/static/client/client_tests/test_workspace_plots.py b/static/client/client_tests/test_workspace_plots.py index b84724b1..fdddd272 100644 --- a/static/client/client_tests/test_workspace_plots.py +++ b/static/client/client_tests/test_workspace_plots.py @@ -145,5 +145,3 @@ def test_restore_workspace_rebuilds_plot_bars_and_supports_delete_plot(self) -> self.assertNotIn("MyBars", self._names_for_class(canvas2, "DiscretePlot")) self.assertNotIn("Poll", self._names_for_class(canvas2, "BarsPlot")) self.assertNotIn("MyPlot", self._names_for_class(canvas2, "ContinuousPlot")) - - diff --git a/static/client/client_tests/test_zoom.py b/static/client/client_tests/test_zoom.py index ac117ce7..9e4db049 100644 --- a/static/client/client_tests/test_zoom.py +++ b/static/client/client_tests/test_zoom.py @@ -193,10 +193,9 @@ def test_zoom_calls_draw(self) -> None: canvas.draw.assert_called_once() draw_args, draw_kwargs = canvas.draw.calls[0] self.assertEqual(draw_args, ()) - self.assertEqual(draw_kwargs, {'apply_zoom': True}) + self.assertEqual(draw_kwargs, {"apply_zoom": True}) def test_zoom_returns_true(self) -> None: canvas = ZoomTestFixture.create_canvas(640, 480) result = canvas.zoom(center_x=0, center_y=0, range_val=5, range_axis="x") self.assertTrue(result) - diff --git a/static/client/client_tests/tests.py b/static/client/client_tests/tests.py index 04df9be8..d38986f8 100644 --- a/static/client/client_tests/tests.py +++ b/static/client/client_tests/tests.py @@ -258,6 +258,7 @@ ) from .test_result_processor_traced import TestGetResultsTraced + class Tests: """Encapsulates execution and formatting of client-side tests.""" @@ -328,7 +329,7 @@ async def run_tests_async( print("[ClientTests] Async test execution finished.") results = test_runner._format_results_for_ai(combined_result) - results['stopped'] = stopped + results["stopped"] = stopped return results except Exception as exc: # pragma: no cover - defensive path print(f"[ClientTests] Exception during run_tests_async: {repr(exc)}") diff --git a/static/client/command_autocomplete.py b/static/client/command_autocomplete.py index b89fba8d..6a2de6a4 100644 --- a/static/client/command_autocomplete.py +++ b/static/client/command_autocomplete.py @@ -159,6 +159,7 @@ def _on_blur(self, event: Any) -> None: # Use setTimeout to allow click events on popup items to fire first try: from browser import window + window.setTimeout(self._delayed_hide, 150) except Exception: self.hide() @@ -215,7 +216,8 @@ def filter(self, prefix: str) -> None: # Filter by prefix (case-insensitive) prefix_lower = prefix.lower() self.filtered_commands = [ - (cmd, desc) for cmd, desc in all_commands + (cmd, desc) + for cmd, desc in all_commands if cmd[1:].lower().startswith(prefix_lower) # Skip the "/" in comparison ] @@ -266,20 +268,14 @@ def _filter_models(self, model_prefix: str) -> List[Tuple[str, str]]: # Filter by prefix prefix_lower = model_prefix.lower() - filtered = [ - m for m in models - if m.lower().startswith(prefix_lower) - ] + filtered = [m for m in models if m.lower().startswith(prefix_lower)] # If no matches, show all models if not filtered and model_prefix: filtered = models # Build suggestions - suggestions = [ - (f"/model {m}", f"Switch to model '{m}'") - for m in filtered - ] + suggestions = [(f"/model {m}", f"Switch to model '{m}'") for m in filtered] return suggestions if suggestions else [("/model", "No matching models")] @@ -310,10 +306,7 @@ def _filter_workspaces(self, command: str, workspace_prefix: str) -> List[Tuple[ # Filter by prefix prefix_lower = workspace_prefix.lower() - filtered = [ - ws for ws in workspaces - if ws.lower().startswith(prefix_lower) - ] + filtered = [ws for ws in workspaces if ws.lower().startswith(prefix_lower)] # If no matches, show all workspaces if not filtered and workspace_prefix: @@ -464,6 +457,7 @@ def handler(event: Any) -> None: event.stopPropagation() self.selected_index = idx self.confirm_selection() + return handler item.bind("mousedown", make_click_handler(index)) @@ -475,6 +469,7 @@ def make_hover_handler(idx: int) -> Callable[[Any], None]: def handler(event: Any) -> None: if self.selected_index != idx: self._update_selection_visual(idx) + return handler item.bind("mouseenter", make_hover_handler(index)) diff --git a/static/client/coordinate_mapper.py b/static/client/coordinate_mapper.py index ac0d1d61..074c245f 100644 --- a/static/client/coordinate_mapper.py +++ b/static/client/coordinate_mapper.py @@ -257,22 +257,17 @@ def get_visible_bounds(self) -> Dict[str, float]: left_bound, top_bound = self.screen_to_math(0, 0) right_bound, bottom_bound = self.screen_to_math(self.canvas_width, self.canvas_height) - return { - 'left': left_bound, - 'right': right_bound, - 'top': top_bound, - 'bottom': bottom_bound - } + return {"left": left_bound, "right": right_bound, "top": top_bound, "bottom": bottom_bound} def get_visible_width(self) -> float: """Get mathematical width of visible area.""" bounds: Dict[str, float] = self.get_visible_bounds() - return bounds['right'] - bounds['left'] + return bounds["right"] - bounds["left"] def get_visible_height(self) -> float: """Get mathematical height of visible area.""" bounds: Dict[str, float] = self.get_visible_bounds() - return bounds['top'] - bounds['bottom'] + return bounds["top"] - bounds["bottom"] def get_visible_left_bound(self) -> float: """Get mathematical left boundary of visible area. @@ -398,14 +393,14 @@ def get_state(self) -> Dict[str, Any]: dict: Current coordinate mapper state """ return { - 'canvas_width': self.canvas_width, - 'canvas_height': self.canvas_height, - 'scale_factor': self.scale_factor, - 'offset': {'x': self.offset.x, 'y': self.offset.y}, - 'origin': {'x': self.origin.x, 'y': self.origin.y}, - 'zoom_point': {'x': self.zoom_point.x, 'y': self.zoom_point.y}, - 'zoom_direction': self.zoom_direction, - 'zoom_step': self.zoom_step + "canvas_width": self.canvas_width, + "canvas_height": self.canvas_height, + "scale_factor": self.scale_factor, + "offset": {"x": self.offset.x, "y": self.offset.y}, + "origin": {"x": self.origin.x, "y": self.origin.y}, + "zoom_point": {"x": self.zoom_point.x, "y": self.zoom_point.y}, + "zoom_direction": self.zoom_direction, + "zoom_step": self.zoom_step, } def set_state(self, state: Dict[str, Any]) -> None: @@ -415,20 +410,20 @@ def set_state(self, state: Dict[str, Any]) -> None: state (dict): State dictionary with mapper properties """ # Update canvas dimensions if provided - self.canvas_width = state.get('canvas_width', self.canvas_width) - self.canvas_height = state.get('canvas_height', self.canvas_height) + self.canvas_width = state.get("canvas_width", self.canvas_width) + self.canvas_height = state.get("canvas_height", self.canvas_height) - self.scale_factor = state.get('scale_factor', 1.0) - offset_data: Dict[str, float] = state.get('offset', {'x': 0, 'y': 0}) - self.offset = Position(offset_data['x'], offset_data['y']) - origin_data: Dict[str, float] = state.get('origin', {'x': self.canvas_width / 2, 'y': self.canvas_height / 2}) - self.origin = Position(origin_data['x'], origin_data['y']) + self.scale_factor = state.get("scale_factor", 1.0) + offset_data: Dict[str, float] = state.get("offset", {"x": 0, "y": 0}) + self.offset = Position(offset_data["x"], offset_data["y"]) + origin_data: Dict[str, float] = state.get("origin", {"x": self.canvas_width / 2, "y": self.canvas_height / 2}) + self.origin = Position(origin_data["x"], origin_data["y"]) # Zoom state - zoom_point_data: Dict[str, float] = state.get('zoom_point', {'x': 0, 'y': 0}) - self.zoom_point = Position(zoom_point_data['x'], zoom_point_data['y']) - self.zoom_direction = state.get('zoom_direction', 0) - self.zoom_step = state.get('zoom_step', 0.1) + zoom_point_data: Dict[str, float] = state.get("zoom_point", {"x": 0, "y": 0}) + self.zoom_point = Position(zoom_point_data["x"], zoom_point_data["y"]) + self.zoom_direction = state.get("zoom_direction", 0) + self.zoom_step = state.get("zoom_step", 0.1) def sync_from_canvas(self, canvas: "Canvas") -> None: """Synchronize coordinate mapper state with Canvas object. @@ -440,34 +435,34 @@ def sync_from_canvas(self, canvas: "Canvas") -> None: canvas: Canvas object with scale_factor, offset, center, etc. """ # Update basic transformation parameters - self.scale_factor = getattr(canvas, 'scale_factor', 1.0) + self.scale_factor = getattr(canvas, "scale_factor", 1.0) # Handle offset - Canvas uses Position objects - canvas_offset: Any = getattr(canvas, 'offset', None) + canvas_offset: Any = getattr(canvas, "offset", None) if canvas_offset: self.offset = Position(canvas_offset.x, canvas_offset.y) else: self.offset = Position(0, 0) # Update canvas dimensions first if they've changed - if hasattr(canvas, 'width') and hasattr(canvas, 'height'): + if hasattr(canvas, "width") and hasattr(canvas, "height"): self.canvas_width = canvas.width self.canvas_height = canvas.height # Handle origin - Use canvas.center as the base origin (before offset) # Note: cartesian2axis.origin already includes offset via math_to_screen, # so we must use canvas.center to avoid double-counting offset - canvas_center: Any = getattr(canvas, 'center', None) + canvas_center: Any = getattr(canvas, "center", None) if canvas_center: self.origin = Position(canvas_center.x, canvas_center.y) else: self.origin = Position(self.canvas_width / 2, self.canvas_height / 2) # Handle zoom state if available - if hasattr(canvas, 'zoom_point'): + if hasattr(canvas, "zoom_point"): zoom_point: Any = canvas.zoom_point self.zoom_point = Position(zoom_point.x, zoom_point.y) - if hasattr(canvas, 'zoom_direction'): + if hasattr(canvas, "zoom_direction"): self.zoom_direction = canvas.zoom_direction - if hasattr(canvas, 'zoom_step'): + if hasattr(canvas, "zoom_step"): self.zoom_step = canvas.zoom_step diff --git a/static/client/drawables/angle.py b/static/client/drawables/angle.py index 8e114225..f913c442 100644 --- a/static/client/drawables/angle.py +++ b/static/client/drawables/angle.py @@ -38,6 +38,7 @@ from drawables.segment import Segment import utils.math_utils as math_utils + class Angle(Drawable): """Represents an angle formed by two intersecting line segments with arc visualization. @@ -55,6 +56,7 @@ class Angle(Drawable): angle_degrees (float): Display angle (small or reflex based on is_reflex) (arc radius is provided by the renderer; a default constant is used when not specified) """ + def __init__( self, arg1: Segment | Point, @@ -91,7 +93,9 @@ def __init__( ) if not self._segments_form_angle(segment1, segment2): - raise ValueError("The segments do not form a valid angle (must share exactly one vertex and have distinct arms).") + raise ValueError( + "The segments do not form a valid angle (must share exactly one vertex and have distinct arms)." + ) self.segment1: Segment = segment1 self.segment2: Segment = segment2 @@ -100,7 +104,9 @@ def __init__( self.vertex_point: Optional[Point] self.arm1_point: Optional[Point] self.arm2_point: Optional[Point] - self.vertex_point, self.arm1_point, self.arm2_point = self._extract_defining_points(self.segment1, self.segment2) + self.vertex_point, self.arm1_point, self.arm2_point = self._extract_defining_points( + self.segment1, self.segment2 + ) # Name: prefer provided name; otherwise compute deterministically from segment endpoint names computed_name: Optional[str] = None @@ -116,8 +122,8 @@ def __init__( super().__init__(name=final_name, color=color) - self.raw_angle_degrees: Optional[float] = None # To store the fundamental CCW angle (0-360) - self.angle_degrees: Optional[float] = None # To store the display angle (small or reflex) + self.raw_angle_degrees: Optional[float] = None # To store the fundamental CCW angle (0-360) + self.angle_degrees: Optional[float] = None # To store the display angle (small or reflex) self._initialize() @@ -134,24 +140,34 @@ def _segments_form_angle(self, s1: Segment, s2: Segment) -> bool: Validates if two segments can form an angle. They must share exactly one common point (vertex) and form distinct, non-degenerate arms. """ - if not s1 or not s2: return False # Segments must exist - if not hasattr(s1, 'point1') or not hasattr(s1, 'point2') or \ - not hasattr(s2, 'point1') or not hasattr(s2, 'point2'): - return False # Segments must have point attributes - - common_vertex_point_obj: Optional[Point] = self._get_common_vertex(s1, s2) # This returns a Point object or None + if not s1 or not s2: + return False # Segments must exist + if ( + not hasattr(s1, "point1") + or not hasattr(s1, "point2") + or not hasattr(s2, "point1") + or not hasattr(s2, "point2") + ): + return False # Segments must have point attributes + + common_vertex_point_obj: Optional[Point] = self._get_common_vertex( + s1, s2 + ) # This returns a Point object or None if common_vertex_point_obj is None: - return False # Segments do not share exactly one common vertex + return False # Segments do not share exactly one common vertex # Identify the Point objects for the arms arm1_point_obj: Point = s1.point1 if s1.point1 != common_vertex_point_obj else s1.point2 arm2_point_obj: Point = s2.point1 if s2.point1 != common_vertex_point_obj else s2.point2 - return cast(bool, math_utils.MathUtils.are_points_valid_for_angle_geometry( - (common_vertex_point_obj.x, common_vertex_point_obj.y), - (arm1_point_obj.x, arm1_point_obj.y), - (arm2_point_obj.x, arm2_point_obj.y) - )) + return cast( + bool, + math_utils.MathUtils.are_points_valid_for_angle_geometry( + (common_vertex_point_obj.x, common_vertex_point_obj.y), + (arm1_point_obj.x, arm1_point_obj.y), + (arm2_point_obj.x, arm2_point_obj.y), + ), + ) def _extract_defining_points(self, s1: Segment, s2: Segment) -> Tuple[Point, Point, Point]: """Extracts vertex, arm1, and arm2 points. Assumes segments form a valid angle.""" @@ -163,7 +179,9 @@ def _extract_defining_points(self, s1: Segment, s2: Segment) -> Tuple[Point, Poi arm2_p: Point = s2.point1 if s2.point1 != vertex_p else s2.point2 return vertex_p, arm1_p, arm2_p - def _calculate_display_angle(self, raw_angle_degrees: Optional[float], is_reflex: bool, epsilon: float) -> Optional[float]: + def _calculate_display_angle( + self, raw_angle_degrees: Optional[float], is_reflex: bool, epsilon: float + ) -> Optional[float]: """Helper function to calculate the display angle based on raw angle and reflex state.""" if raw_angle_degrees is None: return None @@ -171,17 +189,17 @@ def _calculate_display_angle(self, raw_angle_degrees: Optional[float], is_reflex display_angle: Optional[float] = None if is_reflex: # Calculate reflex angle for display - if abs(raw_angle_degrees) < epsilon: # Raw angle is 0 + if abs(raw_angle_degrees) < epsilon: # Raw angle is 0 display_angle = 360.0 - elif raw_angle_degrees > epsilon and raw_angle_degrees < (180.0 - epsilon): # Raw is (0, 180) + elif raw_angle_degrees > epsilon and raw_angle_degrees < (180.0 - epsilon): # Raw is (0, 180) display_angle = 360.0 - raw_angle_degrees - else: # Raw is [180, 360) + else: # Raw is [180, 360) display_angle = raw_angle_degrees else: # Calculate non-reflex (small) angle for display - if raw_angle_degrees > (180.0 + epsilon): # Raw is (180, 360) + if raw_angle_degrees > (180.0 + epsilon): # Raw is (180, 360) display_angle = 360.0 - raw_angle_degrees - else: # Raw is [0, 180] + else: # Raw is [0, 180] display_angle = raw_angle_degrees return display_angle @@ -198,21 +216,22 @@ def _initialize(self) -> None: arm2_coords: Tuple[float, float] = (self.arm2_point.x, self.arm2_point.y) # Calculate the fundamental CCW angle from arm1 to arm2 (0-360 degrees) - self.raw_angle_degrees = math_utils.MathUtils.calculate_angle_degrees( - vertex_coords, arm1_coords, arm2_coords - ) + self.raw_angle_degrees = math_utils.MathUtils.calculate_angle_degrees(vertex_coords, arm1_coords, arm2_coords) - self.angle_degrees = self._calculate_display_angle(self.raw_angle_degrees, self.is_reflex, math_utils.MathUtils.EPSILON) + self.angle_degrees = self._calculate_display_angle( + self.raw_angle_degrees, self.is_reflex, math_utils.MathUtils.EPSILON + ) # Arc radius comes from renderer (or default constant when not provided) - def get_class_name(self) -> str: - return 'Angle' + return "Angle" # Removed unused _get_drawing_references (renderer derives directly via mapper) - def _calculate_arc_parameters(self, vx: float, vy: float, p1x: float, p1y: float, p2x: float, p2y: float, arc_radius: Optional[float] = None) -> Optional[Dict[str, Any]]: + def _calculate_arc_parameters( + self, vx: float, vy: float, p1x: float, p1y: float, p2x: float, p2y: float, arc_radius: Optional[float] = None + ) -> Optional[Dict[str, Any]]: """ Calculates SVG path parameters for the arc using screen coordinates for positioning and a fixed self.drawn_arc_radius for size. @@ -228,9 +247,9 @@ def _calculate_arc_parameters(self, vx: float, vy: float, p1x: float, p1y: float current_arc_radius: float = arc_radius if arc_radius is not None else DEFAULT_ANGLE_ARC_SCREEN_RADIUS if current_arc_radius <= 0: - return None + return None - epsilon: float = math_utils.MathUtils.EPSILON # Use MathUtils.EPSILON + epsilon: float = math_utils.MathUtils.EPSILON # Use MathUtils.EPSILON # Calculate angles in a y-up frame to preserve mathematical CCW orientation angle_v_p1_rad: float = math.atan2(vy - p1y, p1x - vx) @@ -250,7 +269,7 @@ def _calculate_arc_parameters(self, vx: float, vy: float, p1x: float, p1y: float return { "arc_radius_on_screen": current_arc_radius, - "angle_v_p1_rad": angle_v_p1_rad, # For text positioning + "angle_v_p1_rad": angle_v_p1_rad, # For text positioning "final_sweep_flag": final_sweep_flag, "final_large_arc_flag": final_large_arc_flag, "arc_start_x": arc_start_x, @@ -259,7 +278,9 @@ def _calculate_arc_parameters(self, vx: float, vy: float, p1x: float, p1y: float "arc_end_y": arc_end_y, } - def _get_arc_flags(self, display_angle_degrees: Optional[float], raw_angle_degrees: Optional[float], epsilon: float) -> Tuple[str, str]: + def _get_arc_flags( + self, display_angle_degrees: Optional[float], raw_angle_degrees: Optional[float], epsilon: float + ) -> Tuple[str, str]: """Determines sweep and large-arc flags for drawing the angle arc. Note: SVG uses a y-down coordinate system; sweep-flag=1 draws in the positive-angle (clockwise) direction. @@ -267,14 +288,14 @@ def _get_arc_flags(self, display_angle_degrees: Optional[float], raw_angle_degre when display angle goes the opposite direction (to get the small or reflex complement), use sweep=1. """ if display_angle_degrees is None or raw_angle_degrees is None: - return '0', '0' + return "0", "0" # Large-arc flag: 1 if display angle > 180 (or ~360) if abs(display_angle_degrees - 360.0) < epsilon: - large_arc_flag: str = '1' + large_arc_flag: str = "1" elif display_angle_degrees > 180.0 + epsilon: - large_arc_flag = '1' + large_arc_flag = "1" else: - large_arc_flag = '0' + large_arc_flag = "0" # Determine if display follows raw CCW direction (math-space) same_direction: bool = abs(display_angle_degrees - raw_angle_degrees) < epsilon @@ -285,7 +306,7 @@ def _get_arc_flags(self, display_angle_degrees: Optional[float], raw_angle_degre same_direction = False # In SVG, sweep=0 gives CCW visually (since y-down) - sweep_flag: str = '0' if same_direction else '1' + sweep_flag: str = "0" if same_direction else "1" return sweep_flag, large_arc_flag @@ -300,8 +321,8 @@ def get_state(self) -> Dict[str, Any]: "segment1_name": self.segment1.name, "segment2_name": self.segment2.name, "color": self.color, - "is_reflex": self.is_reflex - } + "is_reflex": self.is_reflex, + }, } def __deepcopy__(self, memo: Dict[int, Any]) -> Any: @@ -310,12 +331,7 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Any: new_segment1 = deepcopy(self.segment1, memo) new_segment2 = deepcopy(self.segment2, memo) - new_angle: Angle = Angle( - new_segment1, - new_segment2, - color=self.color, - is_reflex=self.is_reflex - ) + new_angle: Angle = Angle(new_segment1, new_segment2, color=self.color, is_reflex=self.is_reflex) memo[id(self)] = new_angle return new_angle @@ -332,7 +348,9 @@ def update_points_based_on_segments(self) -> bool: return False # Re-extract defining points as segments might have changed their internal point references - self.vertex_point, self.arm1_point, self.arm2_point = self._extract_defining_points(self.segment1, self.segment2) + self.vertex_point, self.arm1_point, self.arm2_point = self._extract_defining_points( + self.segment1, self.segment2 + ) self._initialize() return True diff --git a/static/client/drawables/attached_label.py b/static/client/drawables/attached_label.py index f62acff1..e494713f 100644 --- a/static/client/drawables/attached_label.py +++ b/static/client/drawables/attached_label.py @@ -84,5 +84,3 @@ def __init__( self.is_renderable = False except Exception: pass - - diff --git a/static/client/drawables/bar.py b/static/client/drawables/bar.py index 7dcf3e86..83a8f616 100644 --- a/static/client/drawables/bar.py +++ b/static/client/drawables/bar.py @@ -53,7 +53,9 @@ def __init__( label_text: Optional[str] = None, is_renderable: bool = True, ) -> None: - super().__init__(name=name, color=str(stroke_color) if stroke_color is not None else "", is_renderable=is_renderable) + super().__init__( + name=name, color=str(stroke_color) if stroke_color is not None else "", is_renderable=is_renderable + ) self.x_left: float = float(x_left) self.x_right: float = float(x_right) self.y_bottom: float = float(y_bottom) @@ -117,5 +119,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "Bar": ) memo[id(self)] = copied return copied - - diff --git a/static/client/drawables/bars_plot.py b/static/client/drawables/bars_plot.py index fa1f4bc2..ef16a7cf 100644 --- a/static/client/drawables/bars_plot.py +++ b/static/client/drawables/bars_plot.py @@ -121,5 +121,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "BarsPlot": ) memo[id(self)] = copied return copied - - diff --git a/static/client/drawables/circle.py b/static/client/drawables/circle.py index e1960ac3..8bef749a 100644 --- a/static/client/drawables/circle.py +++ b/static/client/drawables/circle.py @@ -45,6 +45,7 @@ class Circle(Drawable): radius (float): Radius in mathematical coordinate units circle_formula (dict): Algebraic circle equation coefficients """ + def __init__(self, center_point: Point, radius: float, color: str = default_color) -> None: """Initialize a circle with center point and radius. @@ -59,7 +60,7 @@ def __init__(self, center_point: Point, radius: float, color: str = default_colo super().__init__(name=self._generate_default_name(), color=color) def get_class_name(self) -> str: - return 'Circle' + return "Circle" def _calculate_circle_algebraic_formula(self) -> Dict[str, float]: x: float = self.center.x @@ -71,7 +72,10 @@ def _calculate_circle_algebraic_formula(self) -> Dict[str, float]: def get_state(self) -> Dict[str, Any]: radius: float = self.radius center: str = self.center.name - state: Dict[str, Any] = {"name": self.name, "args": {"center": center, "radius": radius, "circle_formula": self.circle_formula}} + state: Dict[str, Any] = { + "name": self.name, + "args": {"center": center, "radius": radius, "circle_formula": self.circle_formula}, + } return state def __deepcopy__(self, memo: Dict[int, Any]) -> Any: @@ -106,8 +110,7 @@ def scale(self, sx: float, sy: float, cx: float, cy: float) -> None: raise ValueError("Scale factor must not be zero") if abs(sx - sy) > 1e-9: raise ValueError( - "Non-uniform scaling of a circle is not supported; " - "convert to an ellipse first or use equal sx and sy" + "Non-uniform scaling of a circle is not supported; convert to an ellipse first or use equal sx and sy" ) self.center.scale(sx, sy, cx, cy) self.radius = abs(self.radius * sx) diff --git a/static/client/drawables/circle_arc.py b/static/client/drawables/circle_arc.py index 4b8173b9..3ad10e69 100644 --- a/static/client/drawables/circle_arc.py +++ b/static/client/drawables/circle_arc.py @@ -67,11 +67,7 @@ def __init__( chosen_color = str(color) if color is not None else DEFAULT_CIRCLE_ARC_COLOR computed_name = ( - name - if name - else self._build_default_name( - getattr(point1, "name", "P1"), getattr(point2, "name", "P2") - ) + name if name else self._build_default_name(getattr(point1, "name", "P1"), getattr(point2, "name", "P2")) ) super().__init__(name=computed_name, color=chosen_color) @@ -197,4 +193,3 @@ def reset(self) -> None: """Refresh cached angles when undo/redo restores previous state.""" self.sync_with_circle() self._refresh_angles() - diff --git a/static/client/drawables/closed_shape_colored_area.py b/static/client/drawables/closed_shape_colored_area.py index 8bed8785..36cfeb64 100644 --- a/static/client/drawables/closed_shape_colored_area.py +++ b/static/client/drawables/closed_shape_colored_area.py @@ -10,7 +10,12 @@ import copy from typing import Any, Dict, List, Optional, Tuple -from constants import default_area_fill_color, default_area_opacity, default_closed_shape_resolution, closed_shape_resolution_minimum +from constants import ( + default_area_fill_color, + default_area_opacity, + default_closed_shape_resolution, + closed_shape_resolution_minimum, +) from drawables.circle import Circle from drawables.colored_area import ColoredArea from drawables.ellipse import Ellipse @@ -197,5 +202,3 @@ def _snapshot_geometry(self) -> Dict[str, Any]: ] return snapshot - - diff --git a/static/client/drawables/colored_area.py b/static/client/drawables/colored_area.py index 8509d3ef..5a8b0c9f 100644 --- a/static/client/drawables/colored_area.py +++ b/static/client/drawables/colored_area.py @@ -24,6 +24,7 @@ from drawables.drawable import Drawable + class ColoredArea(Drawable): """Abstract base class for all colored area visualizations between geometric objects. @@ -34,6 +35,7 @@ class ColoredArea(Drawable): opacity (float): Fill opacity value between 0.0 and 1.0 color (str): CSS color value for area fill """ + def __init__(self, name: str, color: str = "lightblue", opacity: float = 0.3) -> None: """Initialize a colored area with basic properties. @@ -46,17 +48,11 @@ def __init__(self, name: str, color: str = "lightblue", opacity: float = 0.3) -> self.opacity: float = opacity def get_class_name(self) -> str: - return 'ColoredArea' + return "ColoredArea" def get_state(self) -> Dict[str, Any]: """Base state that all colored areas share""" - return { - "name": self.name, - "args": { - "color": self.color, - "opacity": self.opacity - } - } + return {"name": self.name, "args": {"color": self.color, "opacity": self.opacity}} def update_color(self, color: str) -> None: """Update the area fill color.""" diff --git a/static/client/drawables/continuous_plot.py b/static/client/drawables/continuous_plot.py index b3103c5a..1c78f3ef 100644 --- a/static/client/drawables/continuous_plot.py +++ b/static/client/drawables/continuous_plot.py @@ -76,6 +76,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "ContinuousPlot": ) memo[id(self)] = copied return copied - - - diff --git a/static/client/drawables/decagon.py b/static/client/drawables/decagon.py index 750aa1ce..0a941ccd 100644 --- a/static/client/drawables/decagon.py +++ b/static/client/drawables/decagon.py @@ -85,4 +85,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Decagon: new_decagon = Decagon(new_segments, color=self.color) memo[id(self)] = new_decagon return new_decagon - diff --git a/static/client/drawables/directed_graph.py b/static/client/drawables/directed_graph.py index 57a0f8c9..735273a0 100644 --- a/static/client/drawables/directed_graph.py +++ b/static/client/drawables/directed_graph.py @@ -49,7 +49,9 @@ def __init__( is_renderable=False, ) self._vectors: List["Vector"] = list(vectors or []) - self._cached_descriptors: Optional[tuple[List[GraphVertexDescriptor], List[GraphEdgeDescriptor], List[List[float]]]] = None + self._cached_descriptors: Optional[ + tuple[List[GraphVertexDescriptor], List[GraphEdgeDescriptor], List[List[float]]] + ] = None @property def directed(self) -> bool: diff --git a/static/client/drawables/discrete_plot.py b/static/client/drawables/discrete_plot.py index ffd270a7..98789c05 100644 --- a/static/client/drawables/discrete_plot.py +++ b/static/client/drawables/discrete_plot.py @@ -103,5 +103,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "DiscretePlot": ) memo[id(self)] = copied return copied - - diff --git a/static/client/drawables/drawable.py b/static/client/drawables/drawable.py index 9f7a4487..136d5973 100644 --- a/static/client/drawables/drawable.py +++ b/static/client/drawables/drawable.py @@ -37,6 +37,7 @@ class Drawable: name (str): Identifier for the object color (str): Color metadata (used by renderers) """ + def __init__(self, name: str = "", color: str = default_color, *, is_renderable: bool = True) -> None: """Initialize a drawable object with basic properties. diff --git a/static/client/drawables/ellipse.py b/static/client/drawables/ellipse.py index a792f64c..c0fcbdb0 100644 --- a/static/client/drawables/ellipse.py +++ b/static/client/drawables/ellipse.py @@ -33,6 +33,7 @@ from drawables.point import Point from utils.math_utils import MathUtils + class Ellipse(Drawable): """Represents an ellipse with center point, dual radii, and rotation angle. @@ -47,7 +48,15 @@ class Ellipse(Drawable): rotation_angle (float): Rotation angle in degrees for ellipse orientation ellipse_formula (dict): Algebraic ellipse equation coefficients """ - def __init__(self, center_point: Point, radius_x: float, radius_y: float, rotation_angle: float = 0, color: str = default_color) -> None: + + def __init__( + self, + center_point: Point, + radius_x: float, + radius_y: float, + rotation_angle: float = 0, + color: str = default_color, + ) -> None: """Initialize an ellipse with center point, radii, and rotation. Args: @@ -65,7 +74,7 @@ def __init__(self, center_point: Point, radius_x: float, radius_y: float, rotati super().__init__(name=self._generate_default_name(), color=color) def get_class_name(self) -> str: - return 'Ellipse' + return "Ellipse" def _calculate_ellipse_algebraic_formula(self) -> Dict[str, float]: x: float = self.center.x @@ -82,8 +91,8 @@ def get_state(self) -> Dict[str, Any]: "radius_x": self.radius_x, "radius_y": self.radius_y, "rotation_angle": self.rotation_angle, - "ellipse_formula": self.ellipse_formula - } + "ellipse_formula": self.ellipse_formula, + }, } return state @@ -93,9 +102,9 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Any: # Deep copy the center point new_center: Point = deepcopy(self.center, memo) # Create a new Ellipse instance with the copied center point and other properties - new_ellipse: Ellipse = Ellipse(new_center, self.radius_x, self.radius_y, - color=self.color, - rotation_angle=self.rotation_angle) + new_ellipse: Ellipse = Ellipse( + new_center, self.radius_x, self.radius_y, color=self.color, rotation_angle=self.rotation_angle + ) memo[id(self)] = new_ellipse return new_ellipse @@ -136,9 +145,7 @@ def scale(self, sx: float, sy: float, cx: float, cy: float) -> None: uniform = abs(sx - sy) < 1e-9 rotated = (self.rotation_angle % 180) > 1e-9 if not uniform and rotated: - raise ValueError( - "Non-uniform scaling of a rotated ellipse is not supported" - ) + raise ValueError("Non-uniform scaling of a rotated ellipse is not supported") self.center.scale(sx, sy, cx, cy) if uniform: self.radius_x = abs(self.radius_x * sx) diff --git a/static/client/drawables/function.py b/static/client/drawables/function.py index d66503c2..3bb53721 100644 --- a/static/client/drawables/function.py +++ b/static/client/drawables/function.py @@ -43,7 +43,22 @@ class Function(Drawable): estimated_period: Estimated period if is_periodic is True. undefined_at: Explicit list of x values where function is undefined. """ - def __init__(self, function_string: str, name: Optional[str] = None, step: float = default_point_size, color: str = default_color, left_bound: Optional[float] = None, right_bound: Optional[float] = None, vertical_asymptotes: Optional[List[float]] = None, horizontal_asymptotes: Optional[List[float]] = None, point_discontinuities: Optional[List[float]] = None, is_periodic: Optional[bool] = None, estimated_period: Optional[float] = None, undefined_at: Optional[List[float]] = None) -> None: + + def __init__( + self, + function_string: str, + name: Optional[str] = None, + step: float = default_point_size, + color: str = default_color, + left_bound: Optional[float] = None, + right_bound: Optional[float] = None, + vertical_asymptotes: Optional[List[float]] = None, + horizontal_asymptotes: Optional[List[float]] = None, + point_discontinuities: Optional[List[float]] = None, + is_periodic: Optional[bool] = None, + estimated_period: Optional[float] = None, + undefined_at: Optional[List[float]] = None, + ) -> None: self.step: float = step self.left_bound: Optional[float] = left_bound self.right_bound: Optional[float] = right_bound @@ -82,7 +97,7 @@ def function(self, x: float) -> float: """Evaluate the function at x. Returns NaN for explicit undefined points.""" for hole in self.undefined_at: if abs(x - hole) < 1e-12: - return float('nan') + return float("nan") return self._base_function(x) def _detect_periodicity(self) -> None: @@ -95,7 +110,7 @@ def _detect_periodicity(self) -> None: ) def get_class_name(self) -> str: - return 'Function' + return "Function" def get_state(self) -> Dict[str, Any]: function_string: str = self.function_string @@ -104,16 +119,16 @@ def get_state(self) -> Dict[str, Any]: "args": { "function_string": function_string, "left_bound": self.left_bound, - "right_bound": self.right_bound - } + "right_bound": self.right_bound, + }, } # Only include asymptotes and discontinuities lists that have values - if hasattr(self, 'vertical_asymptotes') and self.vertical_asymptotes: + if hasattr(self, "vertical_asymptotes") and self.vertical_asymptotes: state["args"]["vertical_asymptotes"] = self.vertical_asymptotes - if hasattr(self, 'horizontal_asymptotes') and self.horizontal_asymptotes: + if hasattr(self, "horizontal_asymptotes") and self.horizontal_asymptotes: state["args"]["horizontal_asymptotes"] = self.horizontal_asymptotes - if hasattr(self, 'point_discontinuities') and self.point_discontinuities: + if hasattr(self, "point_discontinuities") and self.point_discontinuities: state["args"]["point_discontinuities"] = self.point_discontinuities if self.undefined_at: state["args"]["undefined_at"] = self.undefined_at @@ -130,9 +145,15 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Any: color=self.color, left_bound=self.left_bound, right_bound=self.right_bound, - vertical_asymptotes=self.vertical_asymptotes.copy() if hasattr(self, 'vertical_asymptotes') and self.vertical_asymptotes is not None else None, - horizontal_asymptotes=self.horizontal_asymptotes.copy() if hasattr(self, 'horizontal_asymptotes') and self.horizontal_asymptotes is not None else None, - point_discontinuities=self.point_discontinuities.copy() if hasattr(self, 'point_discontinuities') and self.point_discontinuities is not None else None, + vertical_asymptotes=self.vertical_asymptotes.copy() + if hasattr(self, "vertical_asymptotes") and self.vertical_asymptotes is not None + else None, + horizontal_asymptotes=self.horizontal_asymptotes.copy() + if hasattr(self, "horizontal_asymptotes") and self.horizontal_asymptotes is not None + else None, + point_discontinuities=self.point_discontinuities.copy() + if hasattr(self, "point_discontinuities") and self.point_discontinuities is not None + else None, is_periodic=self.is_periodic, estimated_period=self.estimated_period, undefined_at=self.undefined_at.copy() if self.undefined_at else None, @@ -163,17 +184,18 @@ def translate(self, x_offset: float, y_offset: float) -> None: # First handle horizontal translation by replacing x with (x - x_offset) if x_offset != 0: import re + # Use all allowed functions from ExpressionValidator protected_funcs: list[str] = sorted(ExpressionValidator.ALLOWED_FUNCTIONS, key=len, reverse=True) # Create a regex pattern that matches standalone x while protecting function names - func_pattern: str = '|'.join(map(re.escape, protected_funcs)) + func_pattern: str = "|".join(map(re.escape, protected_funcs)) # Use word boundaries to match standalone 'x' - pattern: str = rf'\b(x)\b|({func_pattern})' + pattern: str = rf"\b(x)\b|({func_pattern})" def replace_match(match: Any) -> str: if match.group(1): # If it's a standalone 'x' - return f'(x - {x_offset})' + return f"(x - {x_offset})" elif match.group(2): # If it's a function name return cast(str, match.group(2)) # Return the function name unchanged return cast(str, match.group(0)) @@ -221,23 +243,21 @@ def _calculate_asymptotes_and_discontinuities(self) -> None: from utils.math_utils import MathUtils # Calculate asymptotes and discontinuities using MathUtil - self.vertical_asymptotes, self.horizontal_asymptotes, self.point_discontinuities = MathUtils.calculate_asymptotes_and_discontinuities( - self.function_string, - self.left_bound, - self.right_bound + self.vertical_asymptotes, self.horizontal_asymptotes, self.point_discontinuities = ( + MathUtils.calculate_asymptotes_and_discontinuities(self.function_string, self.left_bound, self.right_bound) ) def has_point_discontinuity_between_x(self, x1: float, x2: float) -> bool: """Check if there is a point discontinuity between x1 and x2""" - return (hasattr(self, 'point_discontinuities') and any(x1 < x < x2 for x in self.point_discontinuities)) + return hasattr(self, "point_discontinuities") and any(x1 < x < x2 for x in self.point_discontinuities) def has_vertical_asymptote_between_x(self, x1: float, x2: float) -> bool: """Check if there is a vertical asymptote between x1 and x2""" - return (hasattr(self, 'vertical_asymptotes') and any(x1 <= x < x2 for x in self.vertical_asymptotes)) + return hasattr(self, "vertical_asymptotes") and any(x1 <= x < x2 for x in self.vertical_asymptotes) def get_vertical_asymptote_between_x(self, x1: float, x2: float) -> Optional[float]: """Get the x value of a vertical asymptote between x1 and x2, if any exists""" - if hasattr(self, 'vertical_asymptotes'): + if hasattr(self, "vertical_asymptotes"): for x in self.vertical_asymptotes: if x1 <= x < x2: return x diff --git a/static/client/drawables/function_segment_bounded_colored_area.py b/static/client/drawables/function_segment_bounded_colored_area.py index fee74f05..41c8857e 100644 --- a/static/client/drawables/function_segment_bounded_colored_area.py +++ b/static/client/drawables/function_segment_bounded_colored_area.py @@ -5,13 +5,13 @@ Provides area geometry between a function and a segment in math space; the renderer maps to screen. Key Features: - - Function-to-segment area visualization - - Support for function objects, constants, and x-axis boundaries - - Math-space boundary intersection calculation + - Function-to-segment area visualization + - Support for function objects, constants, and x-axis boundaries + - Math-space boundary intersection calculation Dependencies: - - drawables.colored_area: Base class for area visualization - - drawables.function: Function objects for boundary definitions + - drawables.colored_area: Base class for area visualization + - drawables.function: Function objects for boundary definitions """ from __future__ import annotations @@ -24,194 +24,197 @@ from drawables.segment import Segment from utils.math_utils import MathUtils + class FunctionSegmentBoundedColoredArea(ColoredArea): - """Creates a colored area bounded by a mathematical function and a line segment. - - This class creates a visual representation of the area between a function - and a segment using math-space geometry. The renderer handles mapping to screen. - - Attributes: - func (Function, None, or number): The bounding function - segment (Segment): The bounding line segment - """ - def __init__(self, func: Union[Function, None, float, int], segment: Segment, color: str = "lightblue", opacity: float = 0.3) -> None: - """Initialize a function segment bounded colored area. - - Args: - func (Function, None, or number): The bounding function - segment (Segment): The bounding line segment - color (str): CSS color value for area fill - opacity (float): Opacity value between 0.0 and 1.0 - """ - name = self._generate_name(func, segment) - super().__init__(name=name, color=color, opacity=opacity) - self.func: Union[Function, None, float, int] = func - self.segment: Segment = segment - - def _generate_name(self, func: Union[Function, None, float, int], segment: Segment) -> str: - """Generate a descriptive name for the colored area.""" - f_name: str = self._get_function_display_name(func) - s_name: str = segment.name - return f"area_between_{f_name}_and_{s_name}" - - def _get_function_display_name(self, func: Union[Function, None, float, int]) -> str: - """Extract function name for display purposes.""" - if isinstance(func, Function): - return str(func.name) - elif func is not None and not isinstance(func, (int, float)) and self._is_function_like(func): - return cast(str, func.name) - elif func is None: - return 'x_axis' - else: - return f'y_{func}' - - def get_class_name(self) -> str: - """Return the class name 'FunctionSegmentBoundedColoredArea'.""" - return 'FunctionSegmentBoundedColoredArea' - - def _is_function_like(self, obj: Any) -> bool: - """Check if an object has the necessary attributes to be treated as a function (duck typing).""" - required_attrs = ['name', 'function'] - return all(hasattr(obj, attr) for attr in required_attrs) - - def _get_function_y_at_x(self, x: float) -> Optional[float]: - """Get y value for a given x from the function. Returns math coordinates.""" - if self.func is None: # x-axis - return 0 # Math coordinate: y = 0 - if isinstance(self.func, (int, float)): # constant function - return float(self.func) # Math coordinate: y = constant - if isinstance(self.func, Function) or self._is_function_like(self.func): - return self._calculate_function_y_value(x) # Math coordinate - return None - - def _calculate_function_y_value(self, x: float) -> Optional[float]: - """Calculate y value for Function objects with coordinate conversion.""" - try: - if not isinstance(self.func, Function) and not self._is_function_like(self.func): - return None - if isinstance(self.func, (int, float)) or self.func is None: - return None - # x is already in math coordinates - y: Any = self.func.function(x) - # Return math coordinate - if isinstance(y, (int, float)): - result: float = float(y) - if isinstance(result, float) and (result != result or abs(result) == float('inf')): - return None - return result - return None - except (ValueError, ZeroDivisionError): - return None - - def _get_bounds(self) -> Tuple[float, float]: - """Calculate the left and right bounds for the colored area.""" - # Get segment bounds - seg_left: float - seg_right: float - seg_left, seg_right = self._get_segment_bounds() - - # For function bounds - if isinstance(self.func, Function) or self._is_function_like(self.func): - return self._get_intersection_bounds(seg_left, seg_right) - else: - # For x-axis or constant function, use segment bounds - return seg_left, seg_right - - def _get_segment_bounds(self) -> Tuple[float, float]: - """Get the left and right bounds of the segment.""" - # Use math-space x directly from points - x1: float - x2: float - x1, x2 = self.segment.point1.x, self.segment.point2.x - return min(x1, x2), max(x1, x2) - - def _get_intersection_bounds(self, seg_left: float, seg_right: float) -> Tuple[float, float]: - """Get intersection of segment and function bounds.""" - if not isinstance(self.func, Function) and not self._is_function_like(self.func): - return seg_left, seg_right - if isinstance(self.func, (int, float)) or self.func is None: - return seg_left, seg_right - func_left: Optional[float] = self.func.left_bound if hasattr(self.func, 'left_bound') else None - func_right: Optional[float] = self.func.right_bound if hasattr(self.func, 'right_bound') else None - # Use intersection of bounds - left_bound: float = max(seg_left, func_left) if func_left is not None else seg_left - right_bound: float = min(seg_right, func_right) if func_right is not None else seg_right - return left_bound, right_bound - - def _generate_segment_points(self) -> List[Tuple[float, float]]: - """Generate points for the segment path (in reverse order).""" - return [(self.segment.point2.x, self.segment.point2.y), - (self.segment.point1.x, self.segment.point1.y)] - - def _generate_function_points(self, left_bound: float, right_bound: float, num_points: int, dx: float) -> List[Tuple[float, float]]: - """Generate math-space points; renderer does mapping.""" - points: List[Tuple[float, float]] = [] - for i in range(num_points): - x_math: float = left_bound + i * dx - y_math: Optional[float] = self._get_function_y_at_x(x_math) - if y_math is not None: - points.append((x_math, y_math)) - return points - - def uses_segment(self, segment: Segment) -> bool: - """Check if this colored area uses a specific segment. - - Supports comparison in math space. - """ - try: - def same_point(ax: float, ay: float, bx: float, by: float) -> bool: - return cast(bool, abs(ax - bx) < MathUtils.EPSILON and abs(ay - by) < MathUtils.EPSILON) - - # Math space comparison (both orders) - a1x: float - a1y: float - a2x: float - a2y: float - b1x: float - b1y: float - b2x: float - b2y: float - a1x, a1y = self.segment.point1.x, self.segment.point1.y - a2x, a2y = self.segment.point2.x, self.segment.point2.y - b1x, b1y = segment.point1.x, segment.point1.y - b2x, b2y = segment.point2.x, segment.point2.y - - if (same_point(a1x, a1y, b1x, b1y) and same_point(a2x, a2y, b2x, b2y)): - return True - if (same_point(a1x, a1y, b2x, b2y) and same_point(a2x, a2y, b1x, b1y)): - return True - except Exception: - return False - return False - - def get_state(self) -> Dict[str, Any]: - """Serialize function segment bounded area state for persistence.""" - state: Dict[str, Any] = super().get_state() - func_name: str - if isinstance(self.func, Function) or self._is_function_like(self.func): - if isinstance(self.func, (int, float)) or self.func is None: - func_name = str(self.func) - else: - func_name = cast(str, self.func.name) - else: - func_name = str(self.func) - state["args"].update({ - "func": func_name, - "segment": self.segment.name - }) - return state - - def __deepcopy__(self, memo: Dict[int, Any]) -> Any: - """Create a deep copy for undo/redo functionality.""" - if id(self) in memo: - return cast(FunctionSegmentBoundedColoredArea, memo[id(self)]) - - new_area: FunctionSegmentBoundedColoredArea = FunctionSegmentBoundedColoredArea( - func=copy.deepcopy(self.func, memo), - segment=copy.deepcopy(self.segment, memo), - color=self.color, - opacity=self.opacity - ) - new_area.name = self.name - memo[id(self)] = new_area - return new_area + """Creates a colored area bounded by a mathematical function and a line segment. + + This class creates a visual representation of the area between a function + and a segment using math-space geometry. The renderer handles mapping to screen. + + Attributes: + func (Function, None, or number): The bounding function + segment (Segment): The bounding line segment + """ + + def __init__( + self, func: Union[Function, None, float, int], segment: Segment, color: str = "lightblue", opacity: float = 0.3 + ) -> None: + """Initialize a function segment bounded colored area. + + Args: + func (Function, None, or number): The bounding function + segment (Segment): The bounding line segment + color (str): CSS color value for area fill + opacity (float): Opacity value between 0.0 and 1.0 + """ + name = self._generate_name(func, segment) + super().__init__(name=name, color=color, opacity=opacity) + self.func: Union[Function, None, float, int] = func + self.segment: Segment = segment + + def _generate_name(self, func: Union[Function, None, float, int], segment: Segment) -> str: + """Generate a descriptive name for the colored area.""" + f_name: str = self._get_function_display_name(func) + s_name: str = segment.name + return f"area_between_{f_name}_and_{s_name}" + + def _get_function_display_name(self, func: Union[Function, None, float, int]) -> str: + """Extract function name for display purposes.""" + if isinstance(func, Function): + return str(func.name) + elif func is not None and not isinstance(func, (int, float)) and self._is_function_like(func): + return cast(str, func.name) + elif func is None: + return "x_axis" + else: + return f"y_{func}" + + def get_class_name(self) -> str: + """Return the class name 'FunctionSegmentBoundedColoredArea'.""" + return "FunctionSegmentBoundedColoredArea" + + def _is_function_like(self, obj: Any) -> bool: + """Check if an object has the necessary attributes to be treated as a function (duck typing).""" + required_attrs = ["name", "function"] + return all(hasattr(obj, attr) for attr in required_attrs) + + def _get_function_y_at_x(self, x: float) -> Optional[float]: + """Get y value for a given x from the function. Returns math coordinates.""" + if self.func is None: # x-axis + return 0 # Math coordinate: y = 0 + if isinstance(self.func, (int, float)): # constant function + return float(self.func) # Math coordinate: y = constant + if isinstance(self.func, Function) or self._is_function_like(self.func): + return self._calculate_function_y_value(x) # Math coordinate + return None + + def _calculate_function_y_value(self, x: float) -> Optional[float]: + """Calculate y value for Function objects with coordinate conversion.""" + try: + if not isinstance(self.func, Function) and not self._is_function_like(self.func): + return None + if isinstance(self.func, (int, float)) or self.func is None: + return None + # x is already in math coordinates + y: Any = self.func.function(x) + # Return math coordinate + if isinstance(y, (int, float)): + result: float = float(y) + if isinstance(result, float) and (result != result or abs(result) == float("inf")): + return None + return result + return None + except (ValueError, ZeroDivisionError): + return None + + def _get_bounds(self) -> Tuple[float, float]: + """Calculate the left and right bounds for the colored area.""" + # Get segment bounds + seg_left: float + seg_right: float + seg_left, seg_right = self._get_segment_bounds() + + # For function bounds + if isinstance(self.func, Function) or self._is_function_like(self.func): + return self._get_intersection_bounds(seg_left, seg_right) + else: + # For x-axis or constant function, use segment bounds + return seg_left, seg_right + + def _get_segment_bounds(self) -> Tuple[float, float]: + """Get the left and right bounds of the segment.""" + # Use math-space x directly from points + x1: float + x2: float + x1, x2 = self.segment.point1.x, self.segment.point2.x + return min(x1, x2), max(x1, x2) + + def _get_intersection_bounds(self, seg_left: float, seg_right: float) -> Tuple[float, float]: + """Get intersection of segment and function bounds.""" + if not isinstance(self.func, Function) and not self._is_function_like(self.func): + return seg_left, seg_right + if isinstance(self.func, (int, float)) or self.func is None: + return seg_left, seg_right + func_left: Optional[float] = self.func.left_bound if hasattr(self.func, "left_bound") else None + func_right: Optional[float] = self.func.right_bound if hasattr(self.func, "right_bound") else None + # Use intersection of bounds + left_bound: float = max(seg_left, func_left) if func_left is not None else seg_left + right_bound: float = min(seg_right, func_right) if func_right is not None else seg_right + return left_bound, right_bound + + def _generate_segment_points(self) -> List[Tuple[float, float]]: + """Generate points for the segment path (in reverse order).""" + return [(self.segment.point2.x, self.segment.point2.y), (self.segment.point1.x, self.segment.point1.y)] + + def _generate_function_points( + self, left_bound: float, right_bound: float, num_points: int, dx: float + ) -> List[Tuple[float, float]]: + """Generate math-space points; renderer does mapping.""" + points: List[Tuple[float, float]] = [] + for i in range(num_points): + x_math: float = left_bound + i * dx + y_math: Optional[float] = self._get_function_y_at_x(x_math) + if y_math is not None: + points.append((x_math, y_math)) + return points + + def uses_segment(self, segment: Segment) -> bool: + """Check if this colored area uses a specific segment. + + Supports comparison in math space. + """ + try: + + def same_point(ax: float, ay: float, bx: float, by: float) -> bool: + return cast(bool, abs(ax - bx) < MathUtils.EPSILON and abs(ay - by) < MathUtils.EPSILON) + + # Math space comparison (both orders) + a1x: float + a1y: float + a2x: float + a2y: float + b1x: float + b1y: float + b2x: float + b2y: float + a1x, a1y = self.segment.point1.x, self.segment.point1.y + a2x, a2y = self.segment.point2.x, self.segment.point2.y + b1x, b1y = segment.point1.x, segment.point1.y + b2x, b2y = segment.point2.x, segment.point2.y + + if same_point(a1x, a1y, b1x, b1y) and same_point(a2x, a2y, b2x, b2y): + return True + if same_point(a1x, a1y, b2x, b2y) and same_point(a2x, a2y, b1x, b1y): + return True + except Exception: + return False + return False + + def get_state(self) -> Dict[str, Any]: + """Serialize function segment bounded area state for persistence.""" + state: Dict[str, Any] = super().get_state() + func_name: str + if isinstance(self.func, Function) or self._is_function_like(self.func): + if isinstance(self.func, (int, float)) or self.func is None: + func_name = str(self.func) + else: + func_name = cast(str, self.func.name) + else: + func_name = str(self.func) + state["args"].update({"func": func_name, "segment": self.segment.name}) + return state + + def __deepcopy__(self, memo: Dict[int, Any]) -> Any: + """Create a deep copy for undo/redo functionality.""" + if id(self) in memo: + return cast(FunctionSegmentBoundedColoredArea, memo[id(self)]) + + new_area: FunctionSegmentBoundedColoredArea = FunctionSegmentBoundedColoredArea( + func=copy.deepcopy(self.func, memo), + segment=copy.deepcopy(self.segment, memo), + color=self.color, + opacity=self.opacity, + ) + new_area.name = self.name + memo[id(self)] = new_area + return new_area diff --git a/static/client/drawables/functions_bounded_colored_area.py b/static/client/drawables/functions_bounded_colored_area.py index 626f7a76..cb065889 100644 --- a/static/client/drawables/functions_bounded_colored_area.py +++ b/static/client/drawables/functions_bounded_colored_area.py @@ -40,8 +40,16 @@ class FunctionsBoundedColoredArea(ColoredArea): num_sample_points (int): Number of points for path generation """ - def __init__(self, func1: Union[Function, None, float, int], func2: Optional[Union[Function, float, int]] = None, left_bound: Optional[float] = None, right_bound: Optional[float] = None, - color: str = "lightblue", opacity: float = 0.3, num_sample_points: int = 100) -> None: + def __init__( + self, + func1: Union[Function, None, float, int], + func2: Optional[Union[Function, float, int]] = None, + left_bound: Optional[float] = None, + right_bound: Optional[float] = None, + color: str = "lightblue", + opacity: float = 0.3, + num_sample_points: int = 100, + ) -> None: """Initialize a functions bounded colored area. Args: @@ -62,14 +70,31 @@ def __init__(self, func1: Union[Function, None, float, int], func2: Optional[Uni self.right_bound: Optional[float] = right_bound self.num_sample_points: int = num_sample_points - def _validate_parameters(self, func1: Union[Function, None, float, int], func2: Optional[Union[Function, float, int]], left_bound: Optional[float], right_bound: Optional[float], num_sample_points: int) -> None: + def _validate_parameters( + self, + func1: Union[Function, None, float, int], + func2: Optional[Union[Function, float, int]], + left_bound: Optional[float], + right_bound: Optional[float], + num_sample_points: int, + ) -> None: """Validate input parameters for function bounded area creation.""" # Validate that func1 is provided in valid format (use duck typing for testing) - if func1 is not None and not isinstance(func1, (int, float)) and not isinstance(func1, Function) and not self._is_function_like(func1): + if ( + func1 is not None + and not isinstance(func1, (int, float)) + and not isinstance(func1, Function) + and not self._is_function_like(func1) + ): raise ValueError("func1 must be provided as a Function, None, or a number") # Validate func2 type if provided (use duck typing for testing) - if func2 is not None and not isinstance(func2, (int, float)) and not isinstance(func2, Function) and not self._is_function_like(func2): + if ( + func2 is not None + and not isinstance(func2, (int, float)) + and not isinstance(func2, Function) + and not self._is_function_like(func2) + ): raise ValueError("func2 must be a Function, None, or a number") # Validate bounds if provided @@ -94,7 +119,7 @@ def _validate_parameters(self, func1: Union[Function, None, float, int], func2: def _is_function_like(self, obj: Any) -> bool: """Check if an object has the necessary attributes to be treated as a function (duck typing).""" - required_attrs = ['name', 'function'] + required_attrs = ["name", "function"] return all(hasattr(obj, attr) for attr in required_attrs) def _is_function_or_function_like(self, obj: Any) -> bool: @@ -108,11 +133,13 @@ def _get_function_name(self, func: Union[Function, None, float, int]) -> str: elif func is not None and not isinstance(func, (int, float)) and self._is_function_like(func): return cast(str, func.name) elif func is None: - return 'x_axis' + return "x_axis" else: - return f'y_{func}' + return f"y_{func}" - def _generate_name(self, func1: Union[Function, None, float, int], func2: Optional[Union[Function, float, int]]) -> str: + def _generate_name( + self, func1: Union[Function, None, float, int], func2: Optional[Union[Function, float, int]] + ) -> str: """Generate a descriptive name for the colored area.""" f1_name: str = self._get_function_name(func1) f2_name: str = self._get_function_name(func2) @@ -120,7 +147,7 @@ def _generate_name(self, func1: Union[Function, None, float, int], func2: Option def get_class_name(self) -> str: """Return the class name 'FunctionsBoundedColoredArea'.""" - return 'FunctionsBoundedColoredArea' + return "FunctionsBoundedColoredArea" def _get_function_y_at_x(self, func: Union[Function, None, float, int], x: float) -> Optional[float]: """Return math-space y for given x; mapping to screen is renderer's job.""" @@ -133,14 +160,16 @@ def _get_function_y_at_x(self, func: Union[Function, None, float, int], x: float y: Any = func.function(x) if y is None or not isinstance(y, (int, float)): return None - if isinstance(y, float) and (y != y or abs(y) == float('inf')): + if isinstance(y, float) and (y != y or abs(y) == float("inf")): return None return float(y) return None except (ValueError, ZeroDivisionError, TypeError): return None - def _apply_function_bounds(self, bounds: List[Optional[float]], func: Union[Function, None, float, int]) -> List[Optional[float]]: + def _apply_function_bounds( + self, bounds: List[Optional[float]], func: Union[Function, None, float, int] + ) -> List[Optional[float]]: """ Apply bounds from a function if it has defined bounds. @@ -160,7 +189,12 @@ def _apply_function_bounds(self, bounds: List[Optional[float]], func: Union[Func return bounds if isinstance(func, (int, float)) or func is None: return bounds - if hasattr(func, 'left_bound') and hasattr(func, 'right_bound') and func.left_bound is not None and func.right_bound is not None: + if ( + hasattr(func, "left_bound") + and hasattr(func, "right_bound") + and func.left_bound is not None + and func.right_bound is not None + ): if bounds[0] is None: bounds[0] = func.left_bound else: @@ -249,26 +283,35 @@ def _has_asymptote_at(self, func: Union[Function, None, float, int], x_orig: flo return False # Manual asymptote detection for functions with tangent asymptotes (e.g. f3) if isinstance(func, Function) or self._is_function_like(func): - if hasattr(func, 'name') and func.name == 'f3': + if hasattr(func, "name") and func.name == "f3": import math + has_asym: bool = False # Check known asymptote positions for tan(x/100) for n in range(-5, 6): # Check a reasonable range - asym_x: float = 100 * (math.pi/2 + n * math.pi) + asym_x: float = 100 * (math.pi / 2 + n * math.pi) # Only consider very, very close to asymptote (20% of dx) if abs(x_orig - asym_x) < dx * 0.2: has_asym = True break return has_asym # Default asymptote detection for other functions - if hasattr(func, 'has_vertical_asymptote_between_x'): + if hasattr(func, "has_vertical_asymptote_between_x"): return cast(bool, func.has_vertical_asymptote_between_x(x_orig - dx, x_orig + dx)) return False except Exception: # If asymptote detection fails, assume no asymptote return False - def _generate_path(self, func: Union[Function, None, float, int], left_bound: float, right_bound: float, dx: float, num_points: int, reverse: bool = False) -> List[Tuple[float, float]]: + def _generate_path( + self, + func: Union[Function, None, float, int], + left_bound: float, + right_bound: float, + dx: float, + num_points: int, + reverse: bool = False, + ) -> List[Tuple[float, float]]: """ Generate a path of points for a function between bounds. @@ -328,7 +371,9 @@ def _generate_path(self, func: Union[Function, None, float, int], left_bound: fl return points - def _get_function_y_at_x_with_asymptote_handling(self, func: Union[Function, None, float, int], x_orig: float, dx: float) -> Optional[float]: + def _get_function_y_at_x_with_asymptote_handling( + self, func: Union[Function, None, float, int], x_orig: float, dx: float + ) -> Optional[float]: """ Get y value for a function at x with asymptote handling. @@ -375,25 +420,25 @@ def _get_function_y_at_x_with_asymptote_handling(self, func: Union[Function, Non # Reject NaN or infinite values if isinstance(y, float): - if y != y or abs(y) == float('inf'): + if y != y or abs(y) == float("inf"): return None return float(y) except (ValueError, ZeroDivisionError, TypeError, OverflowError): return None - - def get_state(self) -> Dict[str, Any]: """Serialize functions bounded area state for persistence.""" state: Dict[str, Any] = cast(Dict[str, Any], super().get_state()) - state["args"].update({ - "func1": self._get_function_name(self.func1), - "func2": self._get_function_name(self.func2), - "left_bound": self.left_bound, - "right_bound": self.right_bound, - "num_sample_points": self.num_sample_points - }) + state["args"].update( + { + "func1": self._get_function_name(self.func1), + "func2": self._get_function_name(self.func2), + "left_bound": self.left_bound, + "right_bound": self.right_bound, + "num_sample_points": self.num_sample_points, + } + ) return state def update_left_bound(self, left_bound: Optional[float]) -> None: diff --git a/static/client/drawables/generic_polygon.py b/static/client/drawables/generic_polygon.py index 3cb50d0e..f49d0088 100644 --- a/static/client/drawables/generic_polygon.py +++ b/static/client/drawables/generic_polygon.py @@ -87,4 +87,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> GenericPolygon: new_polygon = GenericPolygon(new_segments, color=self.color) memo[id(self)] = new_polygon return new_polygon - diff --git a/static/client/drawables/heptagon.py b/static/client/drawables/heptagon.py index 16711f00..8fe4cbc8 100644 --- a/static/client/drawables/heptagon.py +++ b/static/client/drawables/heptagon.py @@ -85,4 +85,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Heptagon: new_heptagon = Heptagon(new_segments, color=self.color) memo[id(self)] = new_heptagon return new_heptagon - diff --git a/static/client/drawables/hexagon.py b/static/client/drawables/hexagon.py index fc7fa251..f371b3f6 100644 --- a/static/client/drawables/hexagon.py +++ b/static/client/drawables/hexagon.py @@ -85,4 +85,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Hexagon: new_hexagon = Hexagon(new_segments, color=self.color) memo[id(self)] = new_hexagon return new_hexagon - diff --git a/static/client/drawables/label.py b/static/client/drawables/label.py index 00d3dac5..f6332210 100644 --- a/static/client/drawables/label.py +++ b/static/client/drawables/label.py @@ -47,9 +47,7 @@ def __init__( self._text: str = "" self._lines: List[str] = [] self._rotation_degrees: float = ( - float(rotation_degrees) - if rotation_degrees is not None - else float(default_label_rotation_degrees) + float(rotation_degrees) if rotation_degrees is not None else float(default_label_rotation_degrees) ) self._reference_scale_factor: float = self._normalize_reference_scale(reference_scale_factor) self._visible: bool = bool(visible) diff --git a/static/client/drawables/label_render_mode.py b/static/client/drawables/label_render_mode.py index bb89924c..740a0861 100644 --- a/static/client/drawables/label_render_mode.py +++ b/static/client/drawables/label_render_mode.py @@ -124,6 +124,3 @@ def from_state(cls, raw: Any) -> "_ScreenOffsetLabelMode": "world": _WorldLabelMode, "screen_offset": _ScreenOffsetLabelMode, } - - - diff --git a/static/client/drawables/nonagon.py b/static/client/drawables/nonagon.py index 7fc2c61c..5dad78f4 100644 --- a/static/client/drawables/nonagon.py +++ b/static/client/drawables/nonagon.py @@ -85,4 +85,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Nonagon: new_nonagon = Nonagon(new_segments, color=self.color) memo[id(self)] = new_nonagon return new_nonagon - diff --git a/static/client/drawables/octagon.py b/static/client/drawables/octagon.py index 5a11f5a2..6acc5ed0 100644 --- a/static/client/drawables/octagon.py +++ b/static/client/drawables/octagon.py @@ -85,4 +85,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Octagon: new_octagon = Octagon(new_segments, color=self.color) memo[id(self)] = new_octagon return new_octagon - diff --git a/static/client/drawables/parametric_function.py b/static/client/drawables/parametric_function.py index e988099a..3b8d9e4e 100644 --- a/static/client/drawables/parametric_function.py +++ b/static/client/drawables/parametric_function.py @@ -68,12 +68,14 @@ def __init__( try: self.x_expression: str = ExpressionValidator.fix_math_expression(x_expression) self.y_expression: str = ExpressionValidator.fix_math_expression(y_expression) - self._x_function: Callable[[float], float] = ExpressionValidator.parse_parametric_expression(self.x_expression) - self._y_function: Callable[[float], float] = ExpressionValidator.parse_parametric_expression(self.y_expression) - except Exception as e: - raise ValueError( - f"Failed to parse parametric expressions x='{x_expression}', y='{y_expression}': {str(e)}" + self._x_function: Callable[[float], float] = ExpressionValidator.parse_parametric_expression( + self.x_expression + ) + self._y_function: Callable[[float], float] = ExpressionValidator.parse_parametric_expression( + self.y_expression ) + except Exception as e: + raise ValueError(f"Failed to parse parametric expressions x='{x_expression}', y='{y_expression}': {str(e)}") super().__init__(name=name or "p", color=color) @@ -116,7 +118,7 @@ def get_state(self) -> Dict[str, Any]: "t_min": self.t_min, "t_max": self.t_max, "color": self.color, - } + }, } def __deepcopy__(self, memo: Dict[int, Any]) -> "ParametricFunction": diff --git a/static/client/drawables/pentagon.py b/static/client/drawables/pentagon.py index e8aabe32..292fcf74 100644 --- a/static/client/drawables/pentagon.py +++ b/static/client/drawables/pentagon.py @@ -86,4 +86,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Pentagon: new_pentagon = Pentagon(new_segments, color=self.color) memo[id(self)] = new_pentagon return new_pentagon - diff --git a/static/client/drawables/piecewise_function.py b/static/client/drawables/piecewise_function.py index d6ae453c..8e13dc72 100644 --- a/static/client/drawables/piecewise_function.py +++ b/static/client/drawables/piecewise_function.py @@ -105,21 +105,24 @@ def _parse_intervals(self, pieces: List[Dict[str, Any]]) -> None: fixed_expr = ExpressionValidator.fix_math_expression(expression) evaluator = ExpressionValidator.parse_function_string(expression) - self.intervals.append(PiecewiseFunctionInterval( - expression=fixed_expr, - evaluator=evaluator, - left=piece_def.get("left"), - right=piece_def.get("right"), - left_inclusive=piece_def.get("left_inclusive", True), - right_inclusive=piece_def.get("right_inclusive", False), - undefined_at=piece_def.get("undefined_at"), - )) + self.intervals.append( + PiecewiseFunctionInterval( + expression=fixed_expr, + evaluator=evaluator, + left=piece_def.get("left"), + right=piece_def.get("right"), + left_inclusive=piece_def.get("left_inclusive", True), + right_inclusive=piece_def.get("right_inclusive", False), + undefined_at=piece_def.get("undefined_at"), + ) + ) def _sort_intervals(self) -> None: """Sort intervals by their left bound for consistent evaluation.""" + def sort_key(interval: PiecewiseFunctionInterval) -> float: if interval.left is None: - return float('-inf') + return float("-inf") return interval.left self.intervals.sort(key=sort_key) @@ -241,10 +244,10 @@ def function(self, x: float) -> float: for interval in self.intervals: if interval.contains(x): return interval.evaluate(x) - return float('nan') + return float("nan") def get_class_name(self) -> str: - return 'PiecewiseFunction' + return "PiecewiseFunction" def get_state(self) -> Dict[str, Any]: """Serialize function state for persistence.""" @@ -254,7 +257,7 @@ def get_state(self) -> Dict[str, Any]: "name": self.name, "args": { "pieces": pieces_data, - } + }, } if self.vertical_asymptotes: @@ -296,15 +299,13 @@ def translate(self, x_offset: float, y_offset: float) -> None: new_expr = interval.expression if x_offset != 0: - protected_funcs: list[str] = sorted( - ExpressionValidator.ALLOWED_FUNCTIONS, key=len, reverse=True - ) - func_pattern: str = '|'.join(map(re.escape, protected_funcs)) - pattern: str = rf'\b(x)\b|({func_pattern})' + protected_funcs: list[str] = sorted(ExpressionValidator.ALLOWED_FUNCTIONS, key=len, reverse=True) + func_pattern: str = "|".join(map(re.escape, protected_funcs)) + pattern: str = rf"\b(x)\b|({func_pattern})" def replace_match(match: Any) -> str: if match.group(1): - return f'(x - {x_offset})' + return f"(x - {x_offset})" elif match.group(2): return cast(str, match.group(2)) return cast(str, match.group(0)) @@ -318,14 +319,16 @@ def replace_match(match: Any) -> str: new_right = interval.right + x_offset if interval.right is not None else None new_undefined_at = [h + x_offset for h in interval.undefined_at] if interval.undefined_at else None - new_pieces.append({ - "expression": new_expr, - "left": new_left, - "right": new_right, - "left_inclusive": interval.left_inclusive, - "right_inclusive": interval.right_inclusive, - "undefined_at": new_undefined_at, - }) + new_pieces.append( + { + "expression": new_expr, + "left": new_left, + "right": new_right, + "left_inclusive": interval.left_inclusive, + "right_inclusive": interval.right_inclusive, + "undefined_at": new_undefined_at, + } + ) self.intervals = [] self._parse_intervals(new_pieces) diff --git a/static/client/drawables/piecewise_function_interval.py b/static/client/drawables/piecewise_function_interval.py index 4b8eeec2..46fdbb04 100644 --- a/static/client/drawables/piecewise_function_interval.py +++ b/static/client/drawables/piecewise_function_interval.py @@ -70,7 +70,7 @@ def is_undefined_at(self, x: float) -> bool: def evaluate(self, x: float) -> float: """Evaluate this interval's expression at x. Returns NaN for undefined points.""" if self.is_undefined_at(x): - return float('nan') + return float("nan") return self.evaluator(x) def to_dict(self) -> Dict[str, Any]: @@ -85,4 +85,3 @@ def to_dict(self) -> Dict[str, Any]: if self.undefined_at: result["undefined_at"] = self.undefined_at return result - diff --git a/static/client/drawables/plot.py b/static/client/drawables/plot.py index 4cdaebbf..a939cabc 100644 --- a/static/client/drawables/plot.py +++ b/static/client/drawables/plot.py @@ -81,4 +81,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "Plot": ) memo[id(self)] = copied return copied - diff --git a/static/client/drawables/point.py b/static/client/drawables/point.py index 6e690163..b7d2d1ba 100644 --- a/static/client/drawables/point.py +++ b/static/client/drawables/point.py @@ -31,6 +31,7 @@ from drawables.position import Position from utils.math_utils import MathUtils + class Point(Drawable): """Represents a point in 2D mathematical space with coordinate tracking and labeling. @@ -41,6 +42,7 @@ class Point(Drawable): Attributes: x, y (float): Mathematical coordinates (unaffected by zoom/pan) """ + def __init__(self, x: float, y: float, name: str = "", color: str = default_color) -> None: """Initialize a point with mathematical coordinates. @@ -68,12 +70,13 @@ def __init__(self, x: float, y: float, name: str = "", color: str = default_colo ) def get_class_name(self) -> str: - return 'Point' + return "Point" def __str__(self) -> str: def fmt(v: float) -> str: return str(int(v)) if isinstance(v, float) and v.is_integer() else str(v) - return f'{fmt(self.x)},{fmt(self.y)}' + + return f"{fmt(self.x)},{fmt(self.y)}" def get_state(self) -> Dict[str, Any]: state: Dict[str, Any] = {"name": self.name, "args": {"position": {"x": self.x, "y": self.y}}} @@ -185,7 +188,7 @@ def __hash__(self) -> int: """Computes hash based on rounded coordinates.""" # Hash based on coordinates rounded to a few decimal places # Adjust precision as needed, should be coarser than EPSILON allows differences - precision: int = 6 # e.g., 6 decimal places + precision: int = 6 # e.g., 6 decimal places rounded_x: float = round(self.x, precision) rounded_y: float = round(self.y, precision) return hash((rounded_x, rounded_y)) diff --git a/static/client/drawables/polygon.py b/static/client/drawables/polygon.py index 6bce6bb5..da2eafde 100644 --- a/static/client/drawables/polygon.py +++ b/static/client/drawables/polygon.py @@ -28,6 +28,7 @@ from drawables.drawable import Drawable from drawables.point import Point + class Polygon(Drawable): """Abstract base class for polygons that can be rotated around their geometric center. @@ -42,8 +43,7 @@ def _get_shape_center(self, points: Set[Point]) -> Tuple[float, float]: """Calculate center point of a shape given its vertices""" x_coords: list[float] = [p.x for p in points] y_coords: list[float] = [p.y for p in points] - return (sum(x_coords) / len(x_coords), - sum(y_coords) / len(y_coords)) + return (sum(x_coords) / len(x_coords), sum(y_coords) / len(y_coords)) def _rotate_point_around_center(self, point: Point, center_x: float, center_y: float, angle_rad: float) -> None: """Rotate a single point around a center by given angle in radians""" diff --git a/static/client/drawables/position.py b/static/client/drawables/position.py index a427a462..37dc23cb 100644 --- a/static/client/drawables/position.py +++ b/static/client/drawables/position.py @@ -28,6 +28,7 @@ class Position: x (float): X-coordinate in the mathematical coordinate system y (float): Y-coordinate in the mathematical coordinate system """ + def __init__(self, x: float, y: float) -> None: """Initialize a position with x and y coordinates. @@ -39,7 +40,7 @@ def __init__(self, x: float, y: float) -> None: self.y: float = y def __str__(self) -> str: - return f'Position: {self.x}, {self.y}' + return f"Position: {self.x}, {self.y}" def get_state(self) -> Dict[str, Dict[str, float]]: state: Dict[str, Dict[str, float]] = {"Position": {"x": self.x, "y": self.y}} diff --git a/static/client/drawables/quadrilateral.py b/static/client/drawables/quadrilateral.py index 32c1bcee..a54ff24c 100644 --- a/static/client/drawables/quadrilateral.py +++ b/static/client/drawables/quadrilateral.py @@ -143,4 +143,3 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Quadrilateral: ) memo[id(self)] = new_quad return new_quad - diff --git a/static/client/drawables/rectangle.py b/static/client/drawables/rectangle.py index 94491628..85a9a750 100644 --- a/static/client/drawables/rectangle.py +++ b/static/client/drawables/rectangle.py @@ -34,6 +34,7 @@ from drawables.segment import Segment from utils.math_utils import MathUtils + class Rectangle(Quadrilateral): """Represents a rectangle formed by four connected line segments. @@ -46,7 +47,10 @@ class Rectangle(Quadrilateral): segment3 (Segment): Third side of the rectangle segment4 (Segment): Fourth side of the rectangle """ - def __init__(self, segment1: Segment, segment2: Segment, segment3: Segment, segment4: Segment, color: str = default_color) -> None: + + def __init__( + self, segment1: Segment, segment2: Segment, segment3: Segment, segment4: Segment, color: str = default_color + ) -> None: """Initialize a rectangle from four connected line segments. Validates that the segments form a proper rectangle with right angles. @@ -63,34 +67,41 @@ def __init__(self, segment1: Segment, segment2: Segment, segment3: Segment, segm """ if not self._segments_form_rectangle(segment1, segment2, segment3, segment4): raise ValueError("The segments do not form a rectangle") - if not MathUtils.is_rectangle(segment1.point1.x, segment1.point1.y, - segment2.point1.x, segment2.point1.y, - segment3.point1.x, segment3.point1.y, - segment4.point1.x, segment4.point1.y): + if not MathUtils.is_rectangle( + segment1.point1.x, + segment1.point1.y, + segment2.point1.x, + segment2.point1.y, + segment3.point1.x, + segment3.point1.y, + segment4.point1.x, + segment4.point1.y, + ): raise ValueError("The quadrilateral formed by the segments is not a rectangle") super().__init__(segment1, segment2, segment3, segment4, color=color) self._set_base_type_labels(["quadrilateral", "rectangle"]) def get_class_name(self) -> str: - return 'Rectangle' + return "Rectangle" def _segments_form_rectangle(self, s1: Segment, s2: Segment, s3: Segment, s4: Segment) -> bool: # Check if the end point of one segment is the start point of the next correct_connections: bool = ( - s1.point2 == s2.point1 and - s2.point2 == s3.point1 and - s3.point2 == s4.point1 and - s4.point2 == s1.point1 + s1.point2 == s2.point1 and s2.point2 == s3.point1 and s3.point2 == s4.point1 and s4.point2 == s1.point1 ) return correct_connections def get_state(self) -> Dict[str, Any]: # Collect all point names into a list point_names: list[str] = [ - self.segment1.point1.name, self.segment1.point2.name, - self.segment2.point1.name, self.segment2.point2.name, - self.segment3.point1.name, self.segment3.point2.name, - self.segment4.point1.name, self.segment4.point2.name + self.segment1.point1.name, + self.segment1.point2.name, + self.segment2.point1.name, + self.segment2.point2.name, + self.segment3.point1.name, + self.segment3.point2.name, + self.segment4.point1.name, + self.segment4.point2.name, ] # Convert the list into a set to remove duplicates, then convert it back to a list and sort it points_names: list[str] = sorted(list(set(point_names))) diff --git a/static/client/drawables/segment.py b/static/client/drawables/segment.py index 7ac9fcb2..fabd6e81 100644 --- a/static/client/drawables/segment.py +++ b/static/client/drawables/segment.py @@ -36,6 +36,7 @@ from drawables.position import Position from utils.math_utils import MathUtils + class Segment(Drawable): """Represents a line segment between two points with mathematical line properties. @@ -47,6 +48,7 @@ class Segment(Drawable): point2 (Point): Second endpoint of the segment line_formula (dict): Algebraic line equation coefficients (a, b, c for ax + by + c = 0) """ + def __init__( self, p1: Point, @@ -80,7 +82,7 @@ def __init__( ) def get_class_name(self) -> str: - return 'Segment' + return "Segment" def _calculate_line_algebraic_formula(self) -> Dict[str, float]: p1: Point = self.point1 @@ -123,6 +125,7 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Any: # Check if the segment has already been deep copied if id(self) in memo: from typing import cast + return cast(Segment, memo[id(self)]) # Deepcopy points that define the segment @@ -257,7 +260,7 @@ def __eq__(self, other: object) -> Any: # Assumes self.point1 and self.point2 are not None # Requires Point class to have proper __eq__ and __hash__ if self.point1 is None or self.point2 is None or other.point1 is None or other.point2 is None: - return False # Or handle appropriately if None points are possible during comparison + return False # Or handle appropriately if None points are possible during comparison points_self: set[Point] = {self.point1, self.point2} points_other: set[Point] = {other.point1, other.point2} return points_self == points_other @@ -266,8 +269,8 @@ def __hash__(self) -> int: """Computes hash based on a frozenset of the hashes of its endpoint Points.""" # Hash is based on the IDs of the point objects, order-independent if self.point1 is None or self.point2 is None: - # Consistent with __eq__ if points can be None - return hash((None, None)) + # Consistent with __eq__ if points can be None + return hash((None, None)) # Use frozenset of point hashes to ensure hash is consistent regardless of point1/point2 order # and relies on Point.__hash__ which is value-based. return hash(frozenset([hash(self.point1), hash(self.point2)])) diff --git a/static/client/drawables/segments_bounded_colored_area.py b/static/client/drawables/segments_bounded_colored_area.py index 1949ed86..0da16272 100644 --- a/static/client/drawables/segments_bounded_colored_area.py +++ b/static/client/drawables/segments_bounded_colored_area.py @@ -24,6 +24,7 @@ from drawables.colored_area import ColoredArea from drawables.segment import Segment + class SegmentsBoundedColoredArea(ColoredArea): """Creates a colored area bounded by line segments with geometric overlap detection. @@ -35,7 +36,9 @@ class SegmentsBoundedColoredArea(ColoredArea): segment2 (Segment or None): The second bounding segment (None means x-axis) """ - def __init__(self, segment1: Segment, segment2: Optional[Segment] = None, color: str = "lightblue", opacity: float = 0.3) -> None: + def __init__( + self, segment1: Segment, segment2: Optional[Segment] = None, color: str = "lightblue", opacity: float = 0.3 + ) -> None: """Initialize a segments bounded colored area. Args: @@ -51,33 +54,35 @@ def __init__(self, segment1: Segment, segment2: Optional[Segment] = None, color: def _generate_name(self, segment1: Segment, segment2: Optional[Segment]) -> str: """Generate a descriptive name for the colored area based on segment names.""" - s1_name: str = segment1.name if segment1 else 'x_axis' - s2_name: str = segment2.name if segment2 else 'x_axis' + s1_name: str = segment1.name if segment1 else "x_axis" + s2_name: str = segment2.name if segment2 else "x_axis" return f"area_between_{s1_name}_and_{s2_name}" def get_class_name(self) -> str: """Return the class name 'SegmentsBoundedColoredArea'.""" - return 'SegmentsBoundedColoredArea' - - + return "SegmentsBoundedColoredArea" def uses_segment(self, segment: Segment) -> bool: """Check if this colored area uses a specific segment for dependency tracking.""" - def segments_match(s1: Segment, s2: Segment) -> bool: - return bool(s1.point1.x == s2.point1.x and - s1.point1.y == s2.point1.y and - s1.point2.x == s2.point2.x and - s1.point2.y == s2.point2.y) - return bool(segments_match(self.segment1, segment) or (self.segment2 and segments_match(self.segment2, segment))) + def segments_match(s1: Segment, s2: Segment) -> bool: + return bool( + s1.point1.x == s2.point1.x + and s1.point1.y == s2.point1.y + and s1.point2.x == s2.point2.x + and s1.point2.y == s2.point2.y + ) + + return bool( + segments_match(self.segment1, segment) or (self.segment2 and segments_match(self.segment2, segment)) + ) def get_state(self) -> Dict[str, Any]: """Serialize segments bounded area state for persistence.""" state: Dict[str, Any] = super().get_state() - state["args"].update({ - "segment1": self.segment1.name, - "segment2": self.segment2.name if self.segment2 else "x_axis" - }) + state["args"].update( + {"segment1": self.segment1.name, "segment2": self.segment2.name if self.segment2 else "x_axis"} + ) return state def __deepcopy__(self, memo: Dict[int, Any]) -> Any: @@ -89,10 +94,7 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Any: new_segment2 = copy.deepcopy(self.segment2, memo) if self.segment2 else None new_area: SegmentsBoundedColoredArea = SegmentsBoundedColoredArea( - segment1=new_segment1, - segment2=new_segment2, - color=self.color, - opacity=self.opacity + segment1=new_segment1, segment2=new_segment2, color=self.color, opacity=self.opacity ) new_area.name = self.name memo[id(self)] = new_area diff --git a/static/client/drawables/triangle.py b/static/client/drawables/triangle.py index 97c82b1a..dae04ea7 100644 --- a/static/client/drawables/triangle.py +++ b/static/client/drawables/triangle.py @@ -36,6 +36,7 @@ from drawables.segment import Segment from utils.geometry_utils import GeometryUtils + class Triangle(Polygon): """Represents a triangle formed by three connected line segments. @@ -47,6 +48,7 @@ class Triangle(Polygon): segment2 (Segment): Second side of the triangle segment3 (Segment): Third side of the triangle """ + def __init__(self, segment1: Segment, segment2: Segment, segment3: Segment, color: str = default_color) -> None: """Initialize a triangle from three connected line segments. @@ -74,14 +76,22 @@ def __init__(self, segment1: Segment, segment2: Segment, segment3: Segment, colo def _set_name(self) -> str: # Get unique vertices using a set first, then sort - vertices: Set[str] = {p.name for p in [self.segment1.point1, self.segment1.point2, - self.segment2.point1, self.segment2.point2, - self.segment3.point1, self.segment3.point2]} + vertices: Set[str] = { + p.name + for p in [ + self.segment1.point1, + self.segment1.point2, + self.segment2.point1, + self.segment2.point2, + self.segment3.point1, + self.segment3.point2, + ] + } vertices_list: list[str] = sorted(vertices) # Convert to sorted list return vertices_list[0] + vertices_list[1] + vertices_list[2] # Now we're guaranteed three unique points def get_class_name(self) -> str: - return 'Triangle' + return "Triangle" def _segments_form_triangle(self, s1: Segment, s2: Segment, s3: Segment) -> bool: points: list[Point] = [s1.point1, s1.point2, s2.point1, s2.point2, s3.point1, s3.point2] @@ -115,9 +125,12 @@ def is_right(self) -> bool: def get_state(self) -> Dict[str, Any]: # Collect all point names into a list point_names: list[str] = [ - self.segment1.point1.name, self.segment1.point2.name, - self.segment2.point1.name, self.segment2.point2.name, - self.segment3.point1.name, self.segment3.point2.name + self.segment1.point1.name, + self.segment1.point2.name, + self.segment2.point1.name, + self.segment2.point2.name, + self.segment3.point1.name, + self.segment3.point2.name, ] # Find the most frequent point most_frequent_point: str = max(set(point_names), key=point_names.count) @@ -147,9 +160,12 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Any: def get_vertices(self) -> Set[Point]: """Return the set of unique vertices of the triangle""" return { - self.segment1.point1, self.segment1.point2, - self.segment2.point1, self.segment2.point2, - self.segment3.point1, self.segment3.point2 + self.segment1.point1, + self.segment1.point2, + self.segment2.point1, + self.segment2.point2, + self.segment3.point1, + self.segment3.point2, } def update_color(self, color: str) -> None: diff --git a/static/client/drawables/undirected_graph.py b/static/client/drawables/undirected_graph.py index fb9e9ad3..fd9d5bb2 100644 --- a/static/client/drawables/undirected_graph.py +++ b/static/client/drawables/undirected_graph.py @@ -49,7 +49,9 @@ def __init__( is_renderable=False, ) self._segments: List["Segment"] = list(segments or []) - self._cached_descriptors: Optional[tuple[List[GraphVertexDescriptor], List[GraphEdgeDescriptor], List[List[float]]]] = None + self._cached_descriptors: Optional[ + tuple[List[GraphVertexDescriptor], List[GraphEdgeDescriptor], List[List[float]]] + ] = None @property def segments(self) -> List["Segment"]: diff --git a/static/client/drawables/vector.py b/static/client/drawables/vector.py index bfa4d206..e6707513 100644 --- a/static/client/drawables/vector.py +++ b/static/client/drawables/vector.py @@ -33,6 +33,7 @@ from drawables.point import Point from drawables.segment import Segment + class Vector(Drawable): """Represents a directed line segment (vector) with origin, tip, and arrow head visualization. @@ -44,6 +45,7 @@ class Vector(Drawable): origin (Point): Starting point of the vector (property access to segment.point1) tip (Point): Ending point of the vector (property access to segment.point2) """ + def __init__(self, origin: Point, tip: Point, color: str = default_color) -> None: """Initialize a vector with origin and tip points. @@ -67,7 +69,7 @@ def tip(self) -> Point: return self.segment.point2 def get_class_name(self) -> str: - return 'Vector' + return "Vector" def get_state(self) -> Dict[str, Any]: return { diff --git a/static/client/drawables_aggregator.py b/static/client/drawables_aggregator.py index c4463b8f..a926e8f4 100644 --- a/static/client/drawables_aggregator.py +++ b/static/client/drawables_aggregator.py @@ -39,21 +39,21 @@ # Re-export all classes for convenient importing __all__: list[str] = [ - 'Drawable', - 'Point', - 'Position', - 'Segment', - 'Vector', - 'Triangle', - 'Quadrilateral', - 'Pentagon', - 'Hexagon', - 'Rectangle', - 'Circle', - 'Ellipse', - 'Function', - 'Graph', - 'DirectedGraph', - 'UndirectedGraph', - 'Tree', + "Drawable", + "Point", + "Position", + "Segment", + "Vector", + "Triangle", + "Quadrilateral", + "Pentagon", + "Hexagon", + "Rectangle", + "Circle", + "Ellipse", + "Function", + "Graph", + "DirectedGraph", + "UndirectedGraph", + "Tree", ] diff --git a/static/client/expression_evaluator.py b/static/client/expression_evaluator.py index 7f966ffe..8af03c5a 100644 --- a/static/client/expression_evaluator.py +++ b/static/client/expression_evaluator.py @@ -47,7 +47,7 @@ def evaluate_numeric_expression(expression: str, variables: Dict[str, Any]) -> f float: The computed numeric result """ result: Any = MathUtils.evaluate(expression, variables) - print(f"Evaluated numeric expression: {expression} = {result}") # DEBUG + print(f"Evaluated numeric expression: {expression} = {result}") # DEBUG # Convert numeric results to float for consistency if isinstance(result, (int, float)): result = float(result) @@ -69,25 +69,25 @@ def evaluate_function(expression: str, canvas: "Canvas") -> float: Raises: ValueError: If canvas is None, expression format is invalid, or function not found """ - print(f"Evaluating function with expression: {expression}") # DEBUG + print(f"Evaluating function with expression: {expression}") # DEBUG if canvas is None: raise ValueError("Cannot evaluate function: no canvas available") - functions = canvas.get_drawables_by_class_name('Function') + functions = canvas.get_drawables_by_class_name("Function") # Split the expression into function name and argument - match: Optional[re.Match[str]] = re.match(r'(\w+)\((.+)\)', expression) + match: Optional[re.Match[str]] = re.match(r"(\w+)\((.+)\)", expression) if match: function_name: str argument: str function_name, argument = match.groups() - print(f"Function name: {function_name}, argument: {argument}") # DEBUG + print(f"Function name: {function_name}, argument: {argument}") # DEBUG else: raise ValueError(f"Invalid function expression: {expression}") for function in functions: if function.name.lower() == function_name.lower(): # If the function name matches, evaluate the function - print(f"Found function: {function.name} = {function.function_string}") # DEBUG + print(f"Found function: {function.name} = {function.function_string}") # DEBUG try: argument_val: float = float(argument) # Convert argument to float result: Any = function.function(argument_val) @@ -102,7 +102,9 @@ def evaluate_function(expression: str, canvas: "Canvas") -> float: raise ValueError(f"No function found with name: {function_name}") @staticmethod - def evaluate_expression(expression: str, variables: Optional[Dict[str, Any]] = None, canvas: Optional["Canvas"] = None) -> Union[float, str]: + def evaluate_expression( + expression: str, variables: Optional[Dict[str, Any]] = None, canvas: Optional["Canvas"] = None + ) -> Union[float, str]: """Main method to evaluate expressions with fallback from numeric to function evaluation. First attempts numeric evaluation, then falls back to function evaluation if available. diff --git a/static/client/expression_validator.py b/static/client/expression_validator.py index 410b54d6..513f3c94 100644 --- a/static/client/expression_validator.py +++ b/static/client/expression_validator.py @@ -82,6 +82,7 @@ class ExpressionValidator(ast.NodeVisitor): ALLOWED_NODES (set): Whitelist of permitted AST node types ALLOWED_FUNCTIONS (set): Whitelist of permitted mathematical functions """ + ALLOWED_NODES: Set[Type[ast.AST]] = { ast.Add, ast.Sub, @@ -99,69 +100,69 @@ class ExpressionValidator(ast.NodeVisitor): ast.List, # List literals (e.g., [1, 2, 3]) } ALLOWED_FUNCTIONS: Set[str] = { - 'sin', - 'cos', - 'tan', - 'sqrt', - 'log', - 'log10', - 'log2', - 'factorial', - 'asin', - 'acos', - 'atan', - 'sinh', - 'cosh', - 'tanh', - 'exp', - 'abs', - 'pi', - 'e', - 'pow', - 'det', - 'bin', - 'arrangements', - 'permutations', - 'combinations', - 'round', - 'ceil', - 'floor', - 'trunc', - 'max', - 'min', - 'sum', - 'limit', - 'derive', - 'integrate', - 'simplify', - 'expand', - 'factor', - 'solve', - 'gcd', - 'lcm', - 'is_prime', - 'prime_factors', - 'mod_pow', - 'mod_inverse', - 'next_prime', - 'prev_prime', - 'totient', - 'divisors', - 'mean', - 'median', - 'mode', - 'stdev', - 'variance', - 'random', - 'randint', - 'summation', - 'product', - 'arithmetic_sum', - 'geometric_sum', - 'geometric_sum_infinite', - 'ratio_test', - 'root_test', - 'p_series_test' + "sin", + "cos", + "tan", + "sqrt", + "log", + "log10", + "log2", + "factorial", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "exp", + "abs", + "pi", + "e", + "pow", + "det", + "bin", + "arrangements", + "permutations", + "combinations", + "round", + "ceil", + "floor", + "trunc", + "max", + "min", + "sum", + "limit", + "derive", + "integrate", + "simplify", + "expand", + "factor", + "solve", + "gcd", + "lcm", + "is_prime", + "prime_factors", + "mod_pow", + "mod_inverse", + "next_prime", + "prev_prime", + "totient", + "divisors", + "mean", + "median", + "mode", + "stdev", + "variance", + "random", + "randint", + "summation", + "product", + "arithmetic_sum", + "geometric_sum", + "geometric_sum_infinite", + "ratio_test", + "root_test", + "p_series_test", } def _is_allowed_node_type(self, node: ast.AST) -> bool: @@ -283,7 +284,7 @@ def validate_expression_tree(expression: str) -> None: """ try: # Parse the expression into an abstract syntax tree - tree = ast.parse(expression, mode='eval') + tree = ast.parse(expression, mode="eval") validator = ExpressionValidator() validator.visit(tree) @@ -318,80 +319,81 @@ def evaluate_expression(expression: str, x: float = 0) -> float: """ variables_and_functions = ExpressionValidator._get_variables_and_functions(x) # Parse the expression into an abstract syntax tree - tree = ast.parse(expression, mode='eval') + tree = ast.parse(expression, mode="eval") # Evaluate the expression using the abstract syntax tree and the variables dictionary - result = eval(compile(tree, '', mode='eval'), variables_and_functions) + result = eval(compile(tree, "", mode="eval"), variables_and_functions) return cast(float, result) @staticmethod def _get_variables_and_functions(x: float) -> Dict[str, Any]: """Create a dictionary with variables and functions for expression evaluation""" from utils.math_utils import MathUtils + return { - 'x': x, - 'sin': math.sin, - 'cos': math.cos, - 'tan': math.tan, - 'sqrt': MathUtils.sqrt, # Square root function - 'log': math.log, # Natural logarithm (base e) - 'log10': math.log10, # Logarithm base 10 - 'log2': math.log2, # Logarithm base 2 - 'factorial': math.factorial, # Factorial function - 'asin': math.asin, # Arcsine function - 'acos': math.acos, # Arccosine function - 'atan': math.atan, # Arctangent function - 'sinh': math.sinh, # Hyperbolic sine function - 'cosh': math.cosh, # Hyperbolic cosine function - 'tanh': math.tanh, # Hyperbolic tangent function - 'exp': math.exp, # Exponential function - 'abs': abs, # Absolute value function - 'pi': math.pi, # The constant pi - 'e': math.e, # The constant e - 'pow': MathUtils.pow, # Power function - 'bin': bin, # Binary representation of an integer - 'det': MathUtils.det, # Determinant of a matrix - 'arrangements': MathUtils.arrangements, # Arrangements aka permutations nPk - 'permutations': MathUtils.permutations, # Permutations - 'combinations': MathUtils.combinations, # Combinations - 'limit': MathUtils.limit, # Limit of a function - 'derive': MathUtils.derivative, # Derivative of a function - 'integrate': MathUtils.integral, # Indefinite integral of a function - 'simplify': MathUtils.simplify, # Simplify an expression - 'expand': MathUtils.expand, # Expand an expression - 'factor': MathUtils.factor, # Factor an expression - 'solve': MathUtils.solve, # Solve an equation - 'random': MathUtils.random, # Generate a random number - 'round': MathUtils.round, # Round a number - 'gcd': MathUtils.gcd, # Greatest common divisor - 'lcm': MathUtils.lcm, # Least common multiple - 'is_prime': MathUtils.is_prime, # Check if number is prime - 'prime_factors': MathUtils.prime_factors, # Prime factorization with multiplicity - 'mod_pow': MathUtils.mod_pow, # Modular exponentiation - 'mod_inverse': MathUtils.mod_inverse, # Modular multiplicative inverse - 'next_prime': MathUtils.next_prime, # Find smallest prime >= n - 'prev_prime': MathUtils.prev_prime, # Find largest prime <= n - 'totient': MathUtils.totient, # Euler's totient function - 'divisors': MathUtils.divisors, # All positive divisors - 'mean': MathUtils.mean, # Mean of a list of numbers - 'median': MathUtils.median, # Median of a list of numbers - 'mode': MathUtils.mode, # Mode of a list of numbers - 'stdev': MathUtils.stdev, # Standard deviation of a list of numbers - 'variance': MathUtils.variance, # Variance of a list of numbers - 'ceil': math.ceil, # Round up to the nearest integer - 'floor': math.floor, # Round down to the nearest integer - 'trunc': math.trunc, # Truncate to an integer - 'max': max, # Maximum of a list of numbers - 'min': min, # Minimum of a list of numbers - 'sum': sum, # Sum of a list of numbers - 'randint': lambda a, b: random.randint(a, b), # Random integer between a and b - 'summation': MathUtils.summation, - 'product': MathUtils.product, - 'arithmetic_sum': MathUtils.arithmetic_sum, - 'geometric_sum': MathUtils.geometric_sum, - 'geometric_sum_infinite': MathUtils.geometric_sum_infinite, - 'ratio_test': MathUtils.ratio_test, - 'root_test': MathUtils.root_test, - 'p_series_test': MathUtils.p_series_test + "x": x, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "sqrt": MathUtils.sqrt, # Square root function + "log": math.log, # Natural logarithm (base e) + "log10": math.log10, # Logarithm base 10 + "log2": math.log2, # Logarithm base 2 + "factorial": math.factorial, # Factorial function + "asin": math.asin, # Arcsine function + "acos": math.acos, # Arccosine function + "atan": math.atan, # Arctangent function + "sinh": math.sinh, # Hyperbolic sine function + "cosh": math.cosh, # Hyperbolic cosine function + "tanh": math.tanh, # Hyperbolic tangent function + "exp": math.exp, # Exponential function + "abs": abs, # Absolute value function + "pi": math.pi, # The constant pi + "e": math.e, # The constant e + "pow": MathUtils.pow, # Power function + "bin": bin, # Binary representation of an integer + "det": MathUtils.det, # Determinant of a matrix + "arrangements": MathUtils.arrangements, # Arrangements aka permutations nPk + "permutations": MathUtils.permutations, # Permutations + "combinations": MathUtils.combinations, # Combinations + "limit": MathUtils.limit, # Limit of a function + "derive": MathUtils.derivative, # Derivative of a function + "integrate": MathUtils.integral, # Indefinite integral of a function + "simplify": MathUtils.simplify, # Simplify an expression + "expand": MathUtils.expand, # Expand an expression + "factor": MathUtils.factor, # Factor an expression + "solve": MathUtils.solve, # Solve an equation + "random": MathUtils.random, # Generate a random number + "round": MathUtils.round, # Round a number + "gcd": MathUtils.gcd, # Greatest common divisor + "lcm": MathUtils.lcm, # Least common multiple + "is_prime": MathUtils.is_prime, # Check if number is prime + "prime_factors": MathUtils.prime_factors, # Prime factorization with multiplicity + "mod_pow": MathUtils.mod_pow, # Modular exponentiation + "mod_inverse": MathUtils.mod_inverse, # Modular multiplicative inverse + "next_prime": MathUtils.next_prime, # Find smallest prime >= n + "prev_prime": MathUtils.prev_prime, # Find largest prime <= n + "totient": MathUtils.totient, # Euler's totient function + "divisors": MathUtils.divisors, # All positive divisors + "mean": MathUtils.mean, # Mean of a list of numbers + "median": MathUtils.median, # Median of a list of numbers + "mode": MathUtils.mode, # Mode of a list of numbers + "stdev": MathUtils.stdev, # Standard deviation of a list of numbers + "variance": MathUtils.variance, # Variance of a list of numbers + "ceil": math.ceil, # Round up to the nearest integer + "floor": math.floor, # Round down to the nearest integer + "trunc": math.trunc, # Truncate to an integer + "max": max, # Maximum of a list of numbers + "min": min, # Minimum of a list of numbers + "sum": sum, # Sum of a list of numbers + "randint": lambda a, b: random.randint(a, b), # Random integer between a and b + "summation": MathUtils.summation, + "product": MathUtils.product, + "arithmetic_sum": MathUtils.arithmetic_sum, + "geometric_sum": MathUtils.geometric_sum, + "geometric_sum_infinite": MathUtils.geometric_sum_infinite, + "ratio_test": MathUtils.ratio_test, + "root_test": MathUtils.root_test, + "p_series_test": MathUtils.p_series_test, } @staticmethod @@ -419,22 +421,22 @@ def fix_math_expression(expression: str, python_compatible: bool = False) -> str @staticmethod def _convert_degrees(expression: str) -> str: """Convert degree symbols and text to radians""" - expression = expression.replace('°', ' deg') - expression = expression.replace('degrees', ' deg') - expression = expression.replace('degree', ' deg') - expression = re.sub(r'(\d+)\s*deg', lambda match: str(float(match.group(1)) * math.pi / 180), expression) + expression = expression.replace("°", " deg") + expression = expression.replace("degrees", " deg") + expression = expression.replace("degree", " deg") + expression = re.sub(r"(\d+)\s*deg", lambda match: str(float(match.group(1)) * math.pi / 180), expression) return expression @staticmethod def _handle_special_symbols(expression: str, python_compatible: bool) -> str: """Handle square roots, absolute values, and factorials""" # Handle square roots - expression = re.sub(r'√\((.*?)\)', r'sqrt(\1)', expression) - expression = re.sub(r'√([0-9a-zA-Z_]+)', r'sqrt(\1)', expression) + expression = re.sub(r"√\((.*?)\)", r"sqrt(\1)", expression) + expression = re.sub(r"√([0-9a-zA-Z_]+)", r"sqrt(\1)", expression) # Replace | | with the Python equivalent if needed if python_compatible: - expression = re.sub(r'\|(.*?)\|', r'abs(\1)', expression) + expression = re.sub(r"\|(.*?)\|", r"abs(\1)", expression) # Handle factorials with balanced operand extraction expression = ExpressionValidator._replace_factorials(expression) @@ -443,15 +445,15 @@ def _handle_special_symbols(expression: str, python_compatible: bool) -> str: @staticmethod def _replace_factorials(expression: str) -> str: """Replace factorial shorthand (n!) with factorial() calls using balanced parsing.""" - if '!' not in expression: + if "!" not in expression: return expression def is_token_char(char: str) -> bool: - return char.isalnum() or char in ['_', '.'] + return char.isalnum() or char in ["_", "."] - matching_pairs = {')': '(', ']': '[', '}': '{'} + matching_pairs = {")": "(", "]": "[", "}": "{"} - index = expression.find('!') + index = expression.find("!") while index != -1: left = index - 1 while left >= 0 and expression[left].isspace(): @@ -495,8 +497,8 @@ def is_token_char(char: str) -> bool: break replacement = f"factorial({operand})" - expression = expression[:start] + replacement + expression[index + 1:] - index = expression.find('!', start + len(replacement)) + expression = expression[:start] + replacement + expression[index + 1 :] + index = expression.find("!", start + len(replacement)) return expression @@ -512,8 +514,8 @@ def _replace_function_names(expression: str) -> str: def _get_function_replacements() -> Dict[str, str]: """Get a dictionary of function name replacements""" return { - 'π': 'pi', # Using the variable from the dictionary - 'ln': 'log', # Python's math.log is ln by default + "π": "pi", # Using the variable from the dictionary + "ln": "log", # Python's math.log is ln by default "absolute(": "abs(", "power(": "pow(", "binary(": "bin(", @@ -547,34 +549,34 @@ def _handle_power_and_imaginary(expression: str, python_compatible: bool) -> str """Handle power operators and imaginary numbers based on compatibility mode""" # Replace the power symbol with '**' if specified if python_compatible: - expression = expression.replace('^', '**') + expression = expression.replace("^", "**") else: - expression = expression.replace('**', '^') + expression = expression.replace("**", "^") # Replace 'i' with 'j' only in contexts likely to represent the imaginary unit - imaginary_unit = 'j' if python_compatible else 'i' - opposite_unit = 'i' if python_compatible else 'j' + imaginary_unit = "j" if python_compatible else "i" + opposite_unit = "i" if python_compatible else "j" # Assuming it's used in the form of numbers like '2i' or standalone 'i' - expression = re.sub(rf'(?<=\d){opposite_unit}\b', f'{imaginary_unit}', expression) # For numbers like '2i' - expression = re.sub(rf'\b{opposite_unit}\b', f'{imaginary_unit}', expression) # For standalone 'i' + expression = re.sub(rf"(?<=\d){opposite_unit}\b", f"{imaginary_unit}", expression) # For numbers like '2i' + expression = re.sub(rf"\b{opposite_unit}\b", f"{imaginary_unit}", expression) # For standalone 'i' return expression @staticmethod def _insert_multiplication_operators(expression: str, python_compatible: bool) -> str: """Insert multiplication operators where implicit multiplication is used""" - imaginary_unit = 'j' if python_compatible else 'i' + imaginary_unit = "j" if python_compatible else "i" # Step 1: Protect "log" followed by any number from being altered - expression = re.sub(r'log(\d+)', r'log[\1]', expression) + expression = re.sub(r"log(\d+)", r"log[\1]", expression) # Step 2: Insert '*' between a number and a variable, function name, or parenthesis, # excluding 'i' or 'j' immediately after a number - expression = re.sub(rf'(\d)(?!{imaginary_unit})([a-zA-Z_\(])', r'\1*\2', expression) + expression = re.sub(rf"(\d)(?!{imaginary_unit})([a-zA-Z_\(])", r"\1*\2", expression) # Step 3: Revert "log" followed by any number back to its original form - expression = re.sub(r'log\[(\d+)\]', r'log\1', expression) + expression = re.sub(r"log\[(\d+)\]", r"log\1", expression) return expression @@ -582,7 +584,8 @@ def _insert_multiplication_operators(expression: str, python_compatible: bool) - def _parse_with_mathjs(function_string: str) -> Callable[[float], Any]: """Parse a function string using mathjs (slower but more powerful)""" from utils.math_utils import MathUtils - return lambda x: MathUtils.evaluate(function_string, {'x': x}) + + return lambda x: MathUtils.evaluate(function_string, {"x": x}) @staticmethod def _parse_with_python(function_string: str) -> Callable[[float], float]: @@ -590,8 +593,8 @@ def _parse_with_python(function_string: str) -> Callable[[float], float]: function_string = ExpressionValidator.fix_math_expression(function_string, python_compatible=True) ExpressionValidator.validate_expression_tree(function_string) - tree = ast.parse(function_string, mode='eval') - compiled_code = compile(tree, '', mode='eval') + tree = ast.parse(function_string, mode="eval") + compiled_code = compile(tree, "", mode="eval") def evaluator(x: float) -> float: variables = ExpressionValidator._get_variables_and_functions(x) @@ -624,33 +627,34 @@ def _get_variables_and_functions_parametric(t: float) -> Dict[str, Any]: instead of 'x', for parametric curves like x(t), y(t). """ from utils.math_utils import MathUtils + return { - 't': t, - 'sin': math.sin, - 'cos': math.cos, - 'tan': math.tan, - 'sqrt': MathUtils.sqrt, - 'log': math.log, - 'log10': math.log10, - 'log2': math.log2, - 'factorial': math.factorial, - 'asin': math.asin, - 'acos': math.acos, - 'atan': math.atan, - 'sinh': math.sinh, - 'cosh': math.cosh, - 'tanh': math.tanh, - 'exp': math.exp, - 'abs': abs, - 'pi': math.pi, - 'e': math.e, - 'pow': MathUtils.pow, - 'ceil': math.ceil, - 'floor': math.floor, - 'trunc': math.trunc, - 'max': max, - 'min': min, - 'round': MathUtils.round, + "t": t, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "sqrt": MathUtils.sqrt, + "log": math.log, + "log10": math.log10, + "log2": math.log2, + "factorial": math.factorial, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "sinh": math.sinh, + "cosh": math.cosh, + "tanh": math.tanh, + "exp": math.exp, + "abs": abs, + "pi": math.pi, + "e": math.e, + "pow": MathUtils.pow, + "ceil": math.ceil, + "floor": math.floor, + "trunc": math.trunc, + "max": max, + "min": min, + "round": MathUtils.round, } @staticmethod @@ -662,8 +666,8 @@ def _parse_parametric_with_python(expression_string: str) -> Callable[[float], f expression_string = ExpressionValidator.fix_math_expression(expression_string, python_compatible=True) ExpressionValidator.validate_expression_tree(expression_string) - tree = ast.parse(expression_string, mode='eval') - compiled_code = compile(tree, '', mode='eval') + tree = ast.parse(expression_string, mode="eval") + compiled_code = compile(tree, "", mode="eval") def evaluator(t: float) -> float: variables = ExpressionValidator._get_variables_and_functions_parametric(t) diff --git a/static/client/function_registry.py b/static/client/function_registry.py index 19b13891..e8fb1a3e 100644 --- a/static/client/function_registry.py +++ b/static/client/function_registry.py @@ -69,7 +69,9 @@ def _convert_coordinates(coord1: float, coord2: float, from_system: str, to_syst return {"error": f"Invalid conversion: {from_system} to {to_system}"} @staticmethod - def get_available_functions(canvas: "Canvas", workspace_manager: "WorkspaceManager", ai_interface: Optional["AIInterface"] = None) -> Dict[str, Any]: + def get_available_functions( + canvas: "Canvas", workspace_manager: "WorkspaceManager", ai_interface: Optional["AIInterface"] = None + ) -> Dict[str, Any]: """Get the complete dictionary of all available functions with their implementations. Creates the mapping between AI function names and their bound Python methods, @@ -96,66 +98,53 @@ def get_available_functions(canvas: "Canvas", workspace_manager: "WorkspaceManag include_computations=include_computations, ), }, - # ===== POINT OPERATIONS ===== "create_point": canvas.create_point, "delete_point": canvas.delete_point, "update_point": canvas.update_point, - # ===== SEGMENT OPERATIONS ===== "create_segment": canvas.create_segment, "delete_segment": canvas.delete_segment, "update_segment": canvas.update_segment, - # ===== VECTOR OPERATIONS ===== "create_vector": canvas.create_vector, "delete_vector": canvas.delete_vector, "update_vector": canvas.update_vector, - # ===== POLYGON OPERATIONS ===== "create_polygon": canvas.create_polygon, "delete_polygon": canvas.delete_polygon, "update_polygon": canvas.update_polygon, - # ===== CIRCLE OPERATIONS ===== "create_circle": canvas.create_circle, "delete_circle": canvas.delete_circle, "update_circle": canvas.update_circle, - # ===== CIRCLE ARC OPERATIONS ===== "create_circle_arc": canvas.create_circle_arc, "delete_circle_arc": canvas.delete_circle_arc, "update_circle_arc": canvas.update_circle_arc, - # ===== ELLIPSE OPERATIONS ===== "create_ellipse": canvas.create_ellipse, "delete_ellipse": canvas.delete_ellipse, "update_ellipse": canvas.update_ellipse, - # ===== LABEL OPERATIONS ===== "create_label": canvas.create_label, "delete_label": canvas.delete_label, "update_label": canvas.update_label, - # ===== FUNCTION PLOTTING ===== "draw_function": canvas.draw_function, "delete_function": canvas.delete_function, "update_function": canvas.update_function, - # ===== PIECEWISE FUNCTION PLOTTING ===== "draw_piecewise_function": canvas.draw_piecewise_function, "delete_piecewise_function": canvas.delete_piecewise_function, "update_piecewise_function": canvas.update_piecewise_function, - # ===== PARAMETRIC FUNCTION PLOTTING ===== "draw_parametric_function": canvas.draw_parametric_function, "delete_parametric_function": canvas.delete_parametric_function, "update_parametric_function": canvas.update_parametric_function, - # ===== TANGENT AND NORMAL LINES ===== "draw_tangent_line": canvas.create_tangent_line, "draw_normal_line": canvas.create_normal_line, - # ===== GEOMETRIC CONSTRUCTIONS ===== "construct_midpoint": canvas.create_midpoint, "construct_perpendicular_bisector": canvas.create_perpendicular_bisector, @@ -164,14 +153,12 @@ def get_available_functions(canvas: "Canvas", workspace_manager: "WorkspaceManag "construct_parallel_line": canvas.create_parallel_line, "construct_circumcircle": canvas.create_circumcircle, "construct_incircle": canvas.create_incircle, - # ===== OBJECT TRANSFORMATIONS ===== "translate_object": canvas.translate_object, "rotate_object": canvas.rotate_object, "reflect_object": canvas.reflect_object, "scale_object": canvas.scale_object, "shear_object": canvas.shear_object, - # ===== MATHEMATICAL OPERATIONS ===== "evaluate_expression": ProcessFunctionCalls.evaluate_expression, "evaluate_linear_algebra_expression": ProcessFunctionCalls.evaluate_linear_algebra_expression, @@ -186,45 +173,36 @@ def get_available_functions(canvas: "Canvas", workspace_manager: "WorkspaceManag "solve": MathUtils.solve, "solve_system_of_equations": MathUtils.solve_system_of_equations, "solve_numeric": MathUtils.solve_numeric, - # ===== CANVAS HISTORY ===== "undo": canvas.undo, "redo": canvas.redo, - # ===== WORKSPACE OPERATIONS ===== "save_workspace": workspace_manager.save_workspace, "load_workspace": workspace_manager.load_workspace, "list_workspaces": workspace_manager.list_workspaces, "delete_workspace": workspace_manager.delete_workspace, - # ===== COLORED AREA OPERATIONS ===== "create_colored_area": canvas.create_colored_area, "create_region_colored_area": canvas.create_region_colored_area, "delete_colored_area": canvas.delete_colored_area, "update_colored_area": canvas.update_colored_area, - # ===== GRAPH OPERATIONS ===== "generate_graph": canvas.generate_graph, "delete_graph": canvas.delete_graph, "analyze_graph": canvas.analyze_graph, - # ===== RELATION INSPECTION ===== "inspect_relation": canvas.inspect_relation, - # ===== PLOT OPERATIONS ===== "plot_distribution": canvas.plot_distribution, "plot_bars": canvas.plot_bars, "delete_plot": canvas.delete_plot, "fit_regression": canvas.fit_regression, - # ===== ANGLE OPERATIONS ===== "create_angle": canvas.create_angle, "delete_angle": canvas.delete_angle, "update_angle": canvas.update_angle, - # ===== AREA CALCULATION ===== "calculate_area": lambda expression: ProcessFunctionCalls.calculate_area(expression, canvas), - # ===== COORDINATE SYSTEM OPERATIONS ===== "set_coordinate_system": canvas.set_coordinate_system, "convert_coordinates": FunctionRegistry._convert_coordinates, @@ -246,6 +224,7 @@ def _create_search_tools_handler() -> Callable[..., Dict[str, Any]]: Returns a function that makes a request to the backend /search_tools endpoint. """ + def search_tools(query: str, max_results: int | None = None) -> dict: """Search for tools matching a query description. @@ -327,66 +306,53 @@ def get_undoable_functions() -> Tuple[str, ...]: "clear_canvas", "reset_canvas", "zoom", - # Point operations "create_point", "delete_point", "update_point", - # Segment operations "create_segment", "delete_segment", "update_segment", - # Vector operations "create_vector", "delete_vector", "update_vector", - # Polygon operations "create_polygon", "delete_polygon", "update_polygon", - # Circle operations "create_circle", "delete_circle", "update_circle", - # Circle arc operations "create_circle_arc", "delete_circle_arc", "update_circle_arc", - # Ellipse operations "create_ellipse", "delete_ellipse", "update_ellipse", - # Label operations "create_label", "delete_label", "update_label", - # Function operations "draw_function", "delete_function", "update_function", - # Piecewise function operations "draw_piecewise_function", "delete_piecewise_function", "update_piecewise_function", - # Parametric function operations "draw_parametric_function", "delete_parametric_function", "update_parametric_function", - # Tangent and normal line operations "draw_tangent_line", "draw_normal_line", - # Geometric construction operations "construct_midpoint", "construct_perpendicular_bisector", @@ -395,32 +361,27 @@ def get_undoable_functions() -> Tuple[str, ...]: "construct_parallel_line", "construct_circumcircle", "construct_incircle", - # Object transformations "translate_object", "rotate_object", "reflect_object", "scale_object", "shear_object", - # Colored area operations "create_colored_area", "create_region_colored_area", "delete_colored_area", "update_colored_area", - # Graph operations "generate_graph", "delete_graph", - # Plot operations "plot_distribution", "plot_bars", "delete_plot", # Note: fit_regression is NOT undoable - it returns stats to the AI - # Angle operations "create_angle", "delete_angle", - "update_angle" + "update_angle", ) diff --git a/static/client/geometry/__init__.py b/static/client/geometry/__init__.py index 64f334cb..2c2eb0bf 100644 --- a/static/client/geometry/__init__.py +++ b/static/client/geometry/__init__.py @@ -23,18 +23,18 @@ from .region import Region __all__ = [ - 'PathElement', - 'LineSegment', - 'CircularArc', - 'EllipticalArc', - 'CompositePath', - 'Region', - 'line_line_intersection', - 'line_circle_intersection', - 'line_ellipse_intersection', - 'circle_circle_intersection', - 'circle_ellipse_intersection', - 'ellipse_ellipse_intersection', - 'element_element_intersection', - 'path_path_intersections', + "PathElement", + "LineSegment", + "CircularArc", + "EllipticalArc", + "CompositePath", + "Region", + "line_line_intersection", + "line_circle_intersection", + "line_ellipse_intersection", + "circle_circle_intersection", + "circle_ellipse_intersection", + "ellipse_ellipse_intersection", + "element_element_intersection", + "path_path_intersections", ] diff --git a/static/client/geometry/graph_state.py b/static/client/geometry/graph_state.py index 0e41c5ea..f37fc85a 100644 --- a/static/client/geometry/graph_state.py +++ b/static/client/geometry/graph_state.py @@ -156,6 +156,3 @@ def to_dict(self) -> Dict[str, Any]: data = super().to_dict() data["root"] = self.root return data - - - diff --git a/static/client/geometry/path/__init__.py b/static/client/geometry/path/__init__.py index ac319219..a100540f 100644 --- a/static/client/geometry/path/__init__.py +++ b/static/client/geometry/path/__init__.py @@ -22,18 +22,17 @@ ) __all__ = [ - 'PathElement', - 'LineSegment', - 'CircularArc', - 'EllipticalArc', - 'CompositePath', - 'line_line_intersection', - 'line_circle_intersection', - 'line_ellipse_intersection', - 'circle_circle_intersection', - 'circle_ellipse_intersection', - 'ellipse_ellipse_intersection', - 'element_element_intersection', - 'path_path_intersections', + "PathElement", + "LineSegment", + "CircularArc", + "EllipticalArc", + "CompositePath", + "line_line_intersection", + "line_circle_intersection", + "line_ellipse_intersection", + "circle_circle_intersection", + "circle_ellipse_intersection", + "ellipse_ellipse_intersection", + "element_element_intersection", + "path_path_intersections", ] - diff --git a/static/client/geometry/path/circular_arc.py b/static/client/geometry/path/circular_arc.py index df7c1fc4..b42c692f 100644 --- a/static/client/geometry/path/circular_arc.py +++ b/static/client/geometry/path/circular_arc.py @@ -75,14 +75,8 @@ def from_circle_arc(cls, arc: Any) -> "CircularArc": center = (float(arc.center_x), float(arc.center_y)) radius = float(arc.radius) - start_angle = math.atan2( - arc.point1.y - arc.center_y, - arc.point1.x - arc.center_x - ) - end_angle = math.atan2( - arc.point2.y - arc.center_y, - arc.point2.x - arc.center_x - ) + start_angle = math.atan2(arc.point1.y - arc.center_y, arc.point1.x - arc.center_x) + end_angle = math.atan2(arc.point2.y - arc.center_y, arc.point2.x - arc.center_x) ccw_span = end_angle - start_angle if ccw_span < 0: @@ -216,4 +210,3 @@ def __hash__(self) -> int: def __repr__(self) -> str: direction = "CW" if self._clockwise else "CCW" return f"CircularArc(center={self._center}, r={self._radius}, {self._start_angle:.3f} to {self._end_angle:.3f} {direction})" - diff --git a/static/client/geometry/path/composite_path.py b/static/client/geometry/path/composite_path.py index fa47b9fd..f52a5fa0 100644 --- a/static/client/geometry/path/composite_path.py +++ b/static/client/geometry/path/composite_path.py @@ -69,8 +69,7 @@ def append(self, element: PathElement) -> None: last_end = last.end_point() elem_start = element.start_point() raise ValueError( - f"Element does not connect: previous end {last_end} " - f"does not match new start {elem_start}" + f"Element does not connect: previous end {last_end} does not match new start {elem_start}" ) self._elements.append(element) @@ -89,8 +88,7 @@ def prepend(self, element: PathElement) -> None: elem_end = element.end_point() first_start = first.start_point() raise ValueError( - f"Element does not connect: new end {elem_end} " - f"does not match first start {first_start}" + f"Element does not connect: new end {elem_end} does not match first start {first_start}" ) self._elements.insert(0, element) @@ -238,4 +236,3 @@ def __eq__(self, other: object) -> bool: def __repr__(self) -> str: status = "closed" if self.is_closed() else "open" return f"CompositePath({len(self._elements)} elements, {status})" - diff --git a/static/client/geometry/path/elliptical_arc.py b/static/client/geometry/path/elliptical_arc.py index a0a6c8d5..56caead5 100644 --- a/static/client/geometry/path/elliptical_arc.py +++ b/static/client/geometry/path/elliptical_arc.py @@ -36,10 +36,7 @@ class EllipticalArc(PathElement): _clockwise: Direction of traversal """ - __slots__ = ( - "_center", "_radius_x", "_radius_y", "_rotation", - "_start_angle", "_end_angle", "_clockwise" - ) + __slots__ = ("_center", "_radius_x", "_radius_y", "_rotation", "_start_angle", "_end_angle", "_clockwise") def __init__( self, @@ -88,10 +85,7 @@ def from_ellipse(cls, ellipse: Any) -> "EllipticalArc": radius_x = float(ellipse.radius_x) radius_y = float(ellipse.radius_y) rotation = math.radians(float(getattr(ellipse, "rotation_angle", 0.0))) - return cls( - center, radius_x, radius_y, 0.0, 2 * math.pi, - rotation=rotation, clockwise=False - ) + return cls(center, radius_x, radius_y, 0.0, 2 * math.pi, rotation=rotation, clockwise=False) @property def center(self) -> Tuple[float, float]: @@ -230,10 +224,17 @@ def __eq__(self, other: object) -> bool: ) def __hash__(self) -> int: - return hash(( - self._center, self._radius_x, self._radius_y, - self._rotation, self._start_angle, self._end_angle, self._clockwise - )) + return hash( + ( + self._center, + self._radius_x, + self._radius_y, + self._rotation, + self._start_angle, + self._end_angle, + self._clockwise, + ) + ) def __repr__(self) -> str: direction = "CW" if self._clockwise else "CCW" @@ -243,4 +244,3 @@ def __repr__(self) -> str: f"rot={math.degrees(self._rotation):.1f}deg, " f"{self._start_angle:.3f} to {self._end_angle:.3f} {direction})" ) - diff --git a/static/client/geometry/path/intersections.py b/static/client/geometry/path/intersections.py index a5a51fa3..3db11736 100644 --- a/static/client/geometry/path/intersections.py +++ b/static/client/geometry/path/intersections.py @@ -25,6 +25,7 @@ def _get_geometry_utils(): global _GEOMETRY_UTILS if _GEOMETRY_UTILS is None: from utils.geometry_utils import GeometryUtils as _GeometryUtils + _GEOMETRY_UTILS = _GeometryUtils return _GEOMETRY_UTILS @@ -33,8 +34,7 @@ def line_line_intersection(seg1: LineSegment, seg2: LineSegment) -> List[Point]: """Find intersection point between two line segments.""" GeometryUtils = _get_geometry_utils() return GeometryUtils.line_line_intersection( - seg1.start_point(), seg1.end_point(), - seg2.start_point(), seg2.end_point() + seg1.start_point(), seg1.end_point(), seg2.start_point(), seg2.end_point() ) @@ -42,9 +42,7 @@ def line_circle_intersection(seg: LineSegment, arc: CircularArc) -> List[Point]: """Find intersection points between a line segment and a circular arc.""" GeometryUtils = _get_geometry_utils() return GeometryUtils.line_circle_intersection( - seg.start_point(), seg.end_point(), - arc.center, arc.radius, - arc.start_angle, arc.end_angle, arc.clockwise + seg.start_point(), seg.end_point(), arc.center, arc.radius, arc.start_angle, arc.end_angle, arc.clockwise ) @@ -52,9 +50,15 @@ def line_ellipse_intersection(seg: LineSegment, arc: EllipticalArc) -> List[Poin """Find intersection points between a line segment and an elliptical arc.""" GeometryUtils = _get_geometry_utils() return GeometryUtils.line_ellipse_intersection( - seg.start_point(), seg.end_point(), - arc.center, arc.radius_x, arc.radius_y, arc.rotation, - arc.start_angle, arc.end_angle, arc.clockwise + seg.start_point(), + seg.end_point(), + arc.center, + arc.radius_x, + arc.radius_y, + arc.rotation, + arc.start_angle, + arc.end_angle, + arc.clockwise, ) @@ -62,8 +66,16 @@ def circle_circle_intersection(arc1: CircularArc, arc2: CircularArc) -> List[Poi """Find intersection points between two circular arcs.""" GeometryUtils = _get_geometry_utils() return GeometryUtils.circle_circle_intersection( - arc1.center, arc1.radius, arc1.start_angle, arc1.end_angle, arc1.clockwise, - arc2.center, arc2.radius, arc2.start_angle, arc2.end_angle, arc2.clockwise + arc1.center, + arc1.radius, + arc1.start_angle, + arc1.end_angle, + arc1.clockwise, + arc2.center, + arc2.radius, + arc2.start_angle, + arc2.end_angle, + arc2.clockwise, ) @@ -71,9 +83,18 @@ def circle_ellipse_intersection(circle: CircularArc, ellipse: EllipticalArc) -> """Find intersection points between a circular arc and an elliptical arc.""" GeometryUtils = _get_geometry_utils() return GeometryUtils.circle_ellipse_intersection( - circle.center, circle.radius, circle.start_angle, circle.end_angle, circle.clockwise, - ellipse.center, ellipse.radius_x, ellipse.radius_y, ellipse.rotation, - ellipse.start_angle, ellipse.end_angle, ellipse.clockwise + circle.center, + circle.radius, + circle.start_angle, + circle.end_angle, + circle.clockwise, + ellipse.center, + ellipse.radius_x, + ellipse.radius_y, + ellipse.rotation, + ellipse.start_angle, + ellipse.end_angle, + ellipse.clockwise, ) @@ -81,10 +102,20 @@ def ellipse_ellipse_intersection(arc1: EllipticalArc, arc2: EllipticalArc) -> Li """Find intersection points between two elliptical arcs.""" GeometryUtils = _get_geometry_utils() return GeometryUtils.ellipse_ellipse_intersection( - arc1.center, arc1.radius_x, arc1.radius_y, arc1.rotation, - arc1.start_angle, arc1.end_angle, arc1.clockwise, - arc2.center, arc2.radius_x, arc2.radius_y, arc2.rotation, - arc2.start_angle, arc2.end_angle, arc2.clockwise + arc1.center, + arc1.radius_x, + arc1.radius_y, + arc1.rotation, + arc1.start_angle, + arc1.end_angle, + arc1.clockwise, + arc2.center, + arc2.radius_x, + arc2.radius_y, + arc2.rotation, + arc2.start_angle, + arc2.end_angle, + arc2.clockwise, ) @@ -121,10 +152,7 @@ def path_path_intersections(path1: CompositePath, path2: CompositePath) -> List[ for elem2 in path2: points = element_element_intersection(elem1, elem2) for point in points: - is_duplicate = any( - GeometryUtils._points_equal(point, existing) - for existing in results - ) + is_duplicate = any(GeometryUtils._points_equal(point, existing) for existing in results) if not is_duplicate: results.append(point) diff --git a/static/client/geometry/path/line_segment.py b/static/client/geometry/path/line_segment.py index d3bb094b..3705f7e1 100644 --- a/static/client/geometry/path/line_segment.py +++ b/static/client/geometry/path/line_segment.py @@ -94,4 +94,3 @@ def __hash__(self) -> int: def __repr__(self) -> str: return f"LineSegment({self._start}, {self._end})" - diff --git a/static/client/geometry/path/path_element.py b/static/client/geometry/path/path_element.py index 182b64ce..a05bfb19 100644 --- a/static/client/geometry/path/path_element.py +++ b/static/client/geometry/path/path_element.py @@ -73,4 +73,3 @@ def connects_to(self, other: PathElement, tolerance: float = 1e-9) -> bool: dx = end[0] - start[0] dy = end[1] - start[1] return (dx * dx + dy * dy) <= tolerance * tolerance - diff --git a/static/client/geometry/region.py b/static/client/geometry/region.py index e3477513..c147daae 100644 --- a/static/client/geometry/region.py +++ b/static/client/geometry/region.py @@ -24,6 +24,7 @@ def _get_geometry_utils(): global _GEOMETRY_UTILS if _GEOMETRY_UTILS is None: from utils.geometry_utils import GeometryUtils as _GeometryUtils + _GEOMETRY_UTILS = _GeometryUtils return _GEOMETRY_UTILS @@ -39,11 +40,7 @@ class Region: _holes: List of inner closed paths (holes) """ - def __init__( - self, - outer_boundary: CompositePath, - holes: Optional[List[CompositePath]] = None - ) -> None: + def __init__(self, outer_boundary: CompositePath, holes: Optional[List[CompositePath]] = None) -> None: """ Create a Region from an outer boundary and optional holes. @@ -96,17 +93,10 @@ def _path_area(self, path: CompositePath) -> float: for element in path: if isinstance(element, LineSegment): - total_area += GeometryUtils.line_segment_area_contribution( - element.start_point(), - element.end_point() - ) + total_area += GeometryUtils.line_segment_area_contribution(element.start_point(), element.end_point()) elif isinstance(element, CircularArc): total_area += GeometryUtils.circular_segment_area( - element.center, - element.radius, - element.start_angle, - element.end_angle, - element.clockwise + element.center, element.radius, element.start_angle, element.end_angle, element.clockwise ) elif isinstance(element, EllipticalArc): total_area += GeometryUtils.elliptical_segment_area( @@ -116,7 +106,7 @@ def _path_area(self, path: CompositePath) -> float: element.rotation, element.start_angle, element.end_angle, - element.clockwise + element.clockwise, ) return total_area @@ -151,12 +141,7 @@ def signed_area(self) -> float: return outer_area - hole_area - def _point_in_polygon( - self, - x: float, - y: float, - points: List[Tuple[float, float]] - ) -> bool: + def _point_in_polygon(self, x: float, y: float, points: List[Tuple[float, float]]) -> bool: """Ray casting algorithm for point-in-polygon test.""" n = len(points) if n < 3: @@ -252,11 +237,7 @@ def from_circle(cls, center: Tuple[float, float], radius: float) -> Region: @classmethod def from_ellipse( - cls, - center: Tuple[float, float], - radius_x: float, - radius_y: float, - rotation: float = 0.0 + cls, center: Tuple[float, float], radius_x: float, radius_y: float, rotation: float = 0.0 ) -> Region: """Create a Region from an ellipse. @@ -274,12 +255,7 @@ def from_ellipse( return cls(path) @classmethod - def from_half_plane( - cls, - point1: Tuple[float, float], - point2: Tuple[float, float], - size: float = 10000.0 - ) -> Region: + def from_half_plane(cls, point1: Tuple[float, float], point2: Tuple[float, float], size: float = 10000.0) -> Region: """Create a Region representing a half-plane bounded by a line. The half-plane is the area to the LEFT of the directed line from point1 to point2. @@ -331,10 +307,7 @@ def _sample_to_points(self, num_samples: int = 100) -> List[Tuple[float, float]] @staticmethod def _line_intersection( - p1: Tuple[float, float], - p2: Tuple[float, float], - p3: Tuple[float, float], - p4: Tuple[float, float] + p1: Tuple[float, float], p2: Tuple[float, float], p3: Tuple[float, float], p4: Tuple[float, float] ) -> Optional[Tuple[float, float]]: """Find intersection of lines (p1-p2) and (p3-p4).""" x1, y1 = p1 @@ -354,18 +327,16 @@ def _line_intersection( @staticmethod def _is_inside_edge( - point: Tuple[float, float], - edge_start: Tuple[float, float], - edge_end: Tuple[float, float] + point: Tuple[float, float], edge_start: Tuple[float, float], edge_end: Tuple[float, float] ) -> bool: """Check if point is on the inside (left) of a directed edge.""" - return ((edge_end[0] - edge_start[0]) * (point[1] - edge_start[1]) - - (edge_end[1] - edge_start[1]) * (point[0] - edge_start[0])) >= 0 + return ( + (edge_end[0] - edge_start[0]) * (point[1] - edge_start[1]) + - (edge_end[1] - edge_start[1]) * (point[0] - edge_start[0]) + ) >= 0 def _sutherland_hodgman_clip( - self, - subject: List[Tuple[float, float]], - clip: List[Tuple[float, float]] + self, subject: List[Tuple[float, float]], clip: List[Tuple[float, float]] ) -> List[Tuple[float, float]]: """Clip subject polygon against clip polygon using Sutherland-Hodgman.""" if len(subject) < 3 or len(clip) < 3: @@ -392,16 +363,12 @@ def _sutherland_hodgman_clip( if current_inside: if not previous_inside: - intersection = self._line_intersection( - previous, current, edge_start, edge_end - ) + intersection = self._line_intersection(previous, current, edge_start, edge_end) if intersection: output.append(intersection) output.append(current) elif previous_inside: - intersection = self._line_intersection( - previous, current, edge_start, edge_end - ) + intersection = self._line_intersection(previous, current, edge_start, edge_end) if intersection: output.append(intersection) diff --git a/static/client/main.py b/static/client/main.py index e3a5f5bf..ada767ec 100644 --- a/static/client/main.py +++ b/static/client/main.py @@ -77,26 +77,28 @@ def execute_tests() -> None: f"- **Errors:** {results.get('errors', 0)}\n" ) - if results.get('failing_tests'): + if results.get("failing_tests"): summary += "\n#### Failures:\n" - for fail in results['failing_tests']: + for fail in results["failing_tests"]: summary += f"- **{fail['test']}**: {fail['error']}\n" - if results.get('error_tests'): + if results.get("error_tests"): summary += "\n#### Errors:\n" - for err in results['error_tests']: + for err in results["error_tests"]: summary += f"- **{err['test']}**: {err['error']}\n" ai_interface._print_ai_message_in_chat(summary) except Exception as e: - _test_results = window.JSON.stringify({ - "tests_run": 0, - "failures": 0, - "errors": 1, - "failing_tests": [], - "error_tests": [{"test": "Test Runner", "error": str(e)}] - }) + _test_results = window.JSON.stringify( + { + "tests_run": 0, + "failures": 0, + "errors": 1, + "failing_tests": [], + "error_tests": [{"test": "Test Runner", "error": str(e)}], + } + ) if ai_interface is not None: ai_interface._print_ai_message_in_chat(f"Error running tests: {str(e)}") finally: @@ -147,7 +149,7 @@ def redraw_canvas() -> None: return # Update canvas dimensions from current viewport - viewport = document['math-svg'].getBoundingClientRect() + viewport = document["math-svg"].getBoundingClientRect() new_width = viewport.width new_height = viewport.height @@ -167,8 +169,8 @@ def redraw_canvas() -> None: # Update renderer surface size if available if _canvas.renderer is not None: - primitives = getattr(_canvas.renderer, '_shared_primitives', None) - if primitives is not None and hasattr(primitives, 'resize_surface'): + primitives = getattr(_canvas.renderer, "_shared_primitives", None) + if primitives is not None and hasattr(primitives, "resize_surface"): primitives.resize_surface(new_width, new_height) # Redraw the canvas @@ -184,7 +186,7 @@ def main() -> None: global _ai_interface, _canvas # Instantiate the canvas with current SVG viewport dimensions - viewport = document['math-svg'].getBoundingClientRect() + viewport = document["math-svg"].getBoundingClientRect() canvas = Canvas(viewport.width, viewport.height) _canvas = canvas diff --git a/static/client/managers/action_trace_collector.py b/static/client/managers/action_trace_collector.py index 0fd8d18d..06847151 100644 --- a/static/client/managers/action_trace_collector.py +++ b/static/client/managers/action_trace_collector.py @@ -22,9 +22,14 @@ _MAX_RESULT_STR_LEN = 500 # Functions that are not safe to replay (side-effects outside canvas state). -_NON_REPLAYABLE_FUNCTIONS = frozenset({ - "save_workspace", "load_workspace", "delete_workspace", "run_tests", -}) +_NON_REPLAYABLE_FUNCTIONS = frozenset( + { + "save_workspace", + "load_workspace", + "delete_workspace", + "run_tests", + } +) class ActionTraceCollector: @@ -225,7 +230,10 @@ def replay_trace( new_result: Any = None try: results = ResultProcessor.get_results( - [call], available_functions, undoable_functions, canvas, + [call], + available_functions, + undoable_functions, + canvas, ) # Get the single result value (first entry in the dict) if results: @@ -238,13 +246,15 @@ def replay_trace( original = tc.get("result") matched = self._results_match(original, new_result) - match_report.append({ - "function_name": fn, - "matched": matched, - "original_result": self._truncate(original), - "new_result": self._truncate(new_result), - "is_error": is_error, - }) + match_report.append( + { + "function_name": fn, + "matched": matched, + "original_result": self._truncate(original), + "new_result": self._truncate(new_result), + "is_error": is_error, + } + ) return {"match_report": match_report, "skipped": skipped} diff --git a/static/client/managers/angle_manager.py b/static/client/managers/angle_manager.py index 20a7c7ea..bc881fd1 100644 --- a/static/client/managers/angle_manager.py +++ b/static/client/managers/angle_manager.py @@ -60,6 +60,7 @@ from managers.segment_manager import SegmentManager from name_generator.drawable import DrawableNameGenerator + class AngleManager: """ Manages Angle drawables for a Canvas. @@ -147,14 +148,18 @@ def create_angle( # 2. Create/Retrieve segments using these points segment1 = self.segment_manager.create_segment( - vertex_point_obj.x, vertex_point_obj.y, - arm1_defining_point_obj.x, arm1_defining_point_obj.y, - extra_graphics=False + vertex_point_obj.x, + vertex_point_obj.y, + arm1_defining_point_obj.x, + arm1_defining_point_obj.y, + extra_graphics=False, ) segment2 = self.segment_manager.create_segment( - vertex_point_obj.x, vertex_point_obj.y, - arm2_defining_point_obj.x, arm2_defining_point_obj.y, - extra_graphics=False + vertex_point_obj.x, + vertex_point_obj.y, + arm2_defining_point_obj.x, + arm2_defining_point_obj.y, + extra_graphics=False, ) if not segment1 or not segment2: @@ -169,11 +174,11 @@ def create_angle( return existing_angle # 4. Instantiate the Angle (let Angle compute deterministic name if none provided) - angle_kwargs: Dict[str, Any] = {'is_reflex': is_reflex} + angle_kwargs: Dict[str, Any] = {"is_reflex": is_reflex} if color is not None: - angle_kwargs['color'] = color + angle_kwargs["color"] = color if angle_name is not None: - angle_kwargs['name'] = angle_name + angle_kwargs["name"] = angle_name try: new_angle = Angle(segment1, segment2, **angle_kwargs) except ValueError as e: @@ -220,7 +225,7 @@ def get_angle_by_name(self, name: str) -> Optional[Angle]: The Angle object if found, otherwise None. """ # Ensure self.drawables.Angles exists and is iterable - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): # print("AngleManager: DrawablesContainer has no 'Angles' list or it's not a list.") return None for angle in self.drawables.Angles: @@ -228,7 +233,9 @@ def get_angle_by_name(self, name: str) -> Optional[Angle]: return angle return None - def get_angle_by_segments(self, segment1: Segment, segment2: Segment, is_reflex_filter: Optional[bool] = None) -> Optional[Angle]: + def get_angle_by_segments( + self, segment1: Segment, segment2: Segment, is_reflex_filter: Optional[bool] = None + ) -> Optional[Angle]: """ Retrieves an Angle by its two defining Segment objects. The order of segments does not matter. @@ -242,28 +249,31 @@ def get_angle_by_segments(self, segment1: Segment, segment2: Segment, is_reflex_ Returns: The Angle object if found, otherwise None. """ - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): return None - if not segment1 or not segment2: # Ensure segments themselves are not None + if not segment1 or not segment2: # Ensure segments themselves are not None return None for angle in self.drawables.Angles: - if not (hasattr(angle, 'segment1') and hasattr(angle, 'segment2') and hasattr(angle, 'is_reflex')): - continue # Skip angles that don't have segment or is_reflex attributes properly set up + if not (hasattr(angle, "segment1") and hasattr(angle, "segment2") and hasattr(angle, "is_reflex")): + continue # Skip angles that don't have segment or is_reflex attributes properly set up if not (angle.segment1 and angle.segment2): continue - match_segments = (angle.segment1 is segment1 and angle.segment2 is segment2) or \ - (angle.segment1 is segment2 and angle.segment2 is segment1) + match_segments = (angle.segment1 is segment1 and angle.segment2 is segment2) or ( + angle.segment1 is segment2 and angle.segment2 is segment1 + ) if match_segments: - if is_reflex_filter is None: # No reflex filter, first match is fine + if is_reflex_filter is None: # No reflex filter, first match is fine return angle - elif angle.is_reflex == is_reflex_filter: # Reflex state also matches + elif angle.is_reflex == is_reflex_filter: # Reflex state also matches return angle return None - def get_angle_by_points(self, vertex_point: Point, arm1_point: Point, arm2_point: Point, is_reflex_filter: Optional[bool] = None) -> Optional[Angle]: + def get_angle_by_points( + self, vertex_point: Point, arm1_point: Point, arm2_point: Point, is_reflex_filter: Optional[bool] = None + ) -> Optional[Angle]: """ Retrieves an Angle defined by three Point objects: a common vertex, and one point on each arm. The order of arm1_point and arm2_point does not matter. @@ -278,15 +288,21 @@ def get_angle_by_points(self, vertex_point: Point, arm1_point: Point, arm2_point Returns: The Angle object if found, otherwise None. """ - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): return None if not all([vertex_point, arm1_point, arm2_point]): return None for angle in self.drawables.Angles: - if not all([hasattr(angle, 'vertex_point'), hasattr(angle, 'arm1_point'), - hasattr(angle, 'arm2_point'), hasattr(angle, 'is_reflex')]): - continue # Skip angles missing point or is_reflex attributes + if not all( + [ + hasattr(angle, "vertex_point"), + hasattr(angle, "arm1_point"), + hasattr(angle, "arm2_point"), + hasattr(angle, "is_reflex"), + ] + ): + continue # Skip angles missing point or is_reflex attributes if not all([angle.vertex_point, angle.arm1_point, angle.arm2_point]): continue @@ -297,9 +313,9 @@ def get_angle_by_points(self, vertex_point: Point, arm1_point: Point, arm2_point input_arm_points = {arm1_point, arm2_point} if angle_arm_points == input_arm_points: - if is_reflex_filter is None: # No reflex filter, first match is fine + if is_reflex_filter is None: # No reflex filter, first match is fine return angle - elif angle.is_reflex == is_reflex_filter: # Reflex state also matches + elif angle.is_reflex == is_reflex_filter: # Reflex state also matches return angle return None @@ -322,9 +338,9 @@ def delete_angle(self, angle_name: str) -> bool: segment1 = angle_to_delete.segment1 segment2 = angle_to_delete.segment2 - vertex_point = getattr(angle_to_delete, 'vertex_point', None) - arm1_point = getattr(angle_to_delete, 'arm1_point', None) - arm2_point = getattr(angle_to_delete, 'arm2_point', None) + vertex_point = getattr(angle_to_delete, "vertex_point", None) + arm1_point = getattr(angle_to_delete, "arm1_point", None) + arm2_point = getattr(angle_to_delete, "arm2_point", None) # 1. Unregister angle's dependencies on its parent segments if segment1: @@ -346,25 +362,23 @@ def delete_angle(self, angle_name: str) -> bool: # 3. Remove angle from the drawables container # This is the primary step that ensures it won't be drawn in the next canvas redraw. try: - if hasattr(self.drawables, 'Angles') and angle_to_delete in self.drawables.Angles: + if hasattr(self.drawables, "Angles") and angle_to_delete in self.drawables.Angles: self.drawables.Angles.remove(angle_to_delete) - elif hasattr(self.drawables, 'remove') and callable(self.drawables.remove): # Generic remove - self.drawables.remove(angle_to_delete) + elif hasattr(self.drawables, "remove") and callable(self.drawables.remove): # Generic remove + self.drawables.remove(angle_to_delete) else: print(f"AngleManager: Warning - Could not remove angle '{angle_name}' from drawables container.") except ValueError: print(f"AngleManager: Warning - Angle '{angle_name}' not found in Angles list for direct removal.") # 4. Attempt to delete the constituent segments (SegmentManager handles if they are still in use) - if segment1 and hasattr(segment1, 'point1') and hasattr(segment1, 'point2'): + if segment1 and hasattr(segment1, "point1") and hasattr(segment1, "point2"): self.segment_manager.delete_segment( - segment1.point1.x, segment1.point1.y, - segment1.point2.x, segment1.point2.y + segment1.point1.x, segment1.point1.y, segment1.point2.x, segment1.point2.y ) - if segment2 and hasattr(segment2, 'point1') and hasattr(segment2, 'point2'): + if segment2 and hasattr(segment2, "point1") and hasattr(segment2, "point2"): self.segment_manager.delete_segment( - segment2.point1.x, segment2.point1.y, - segment2.point2.x, segment2.point2.y + segment2.point1.x, segment2.point1.y, segment2.point2.x, segment2.point2.y ) # 5. Draw the canvas @@ -465,29 +479,31 @@ def handle_segment_updated(self, updated_segment_name: str) -> None: Args: updated_segment_name: The name of the segment that was updated. """ - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): return needs_redraw: bool = False - for angle in cast(List["Drawable"], list(self.drawables.Angles)): # Iterate over a copy in case of modification - if not (hasattr(angle, 'segment1') and hasattr(angle, 'segment2')): + for angle in cast(List["Drawable"], list(self.drawables.Angles)): # Iterate over a copy in case of modification + if not (hasattr(angle, "segment1") and hasattr(angle, "segment2")): continue if not (angle.segment1 and angle.segment2): continue if angle.segment1.name == updated_segment_name or angle.segment2.name == updated_segment_name: - if hasattr(angle, '_initialize') and callable(angle._initialize): + if hasattr(angle, "_initialize") and callable(angle._initialize): try: # Before re-initializing, remove old SVG elements - if hasattr(angle, 'remove_svg_elements') and callable(angle.remove_svg_elements): + if hasattr(angle, "remove_svg_elements") and callable(angle.remove_svg_elements): angle.remove_svg_elements() angle._initialize() needs_redraw = True except ValueError as e: # If _initialize fails (e.g., angle becomes invalid), remove the angle - print(f"AngleManager: Angle '{angle.name}' became invalid after segment '{updated_segment_name}' update. Error: {e}. Removing angle.") - self.delete_angle(angle.name) # This will handle its own draw call - needs_redraw = True # Ensure redraw happens even if this one is removed + print( + f"AngleManager: Angle '{angle.name}' became invalid after segment '{updated_segment_name}' update. Error: {e}. Removing angle." + ) + self.delete_angle(angle.name) # This will handle its own draw call + needs_redraw = True # Ensure redraw happens even if this one is removed else: print(f"AngleManager: Warning - Angle '{angle.name}' does not have _initialize method for update.") @@ -503,13 +519,13 @@ def handle_segment_removed(self, removed_segment_name: str) -> None: Args: removed_segment_name: The name of the segment that was removed. """ - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): return # Collect names first to avoid modification issues while iterating angles_to_remove_names: List[str] = [] - for angle in self.drawables.Angles: # No need for list copy if just collecting names - if not (hasattr(angle, 'segment1') and hasattr(angle, 'segment2')): + for angle in self.drawables.Angles: # No need for list copy if just collecting names + if not (hasattr(angle, "segment1") and hasattr(angle, "segment2")): continue if not (angle.segment1 and angle.segment2): continue @@ -522,8 +538,10 @@ def handle_segment_removed(self, removed_segment_name: str) -> None: # self.canvas.undo_redo_manager.archive() for angle_name in angles_to_remove_names: - print(f"AngleManager: Segment '{removed_segment_name}' was removed. Removing dependent angle '{angle_name}'.") - self.delete_angle(angle_name) # This handles its own draw and archive + print( + f"AngleManager: Segment '{removed_segment_name}' was removed. Removing dependent angle '{angle_name}'." + ) + self.delete_angle(angle_name) # This handles its own draw and archive # A final draw might not be necessary if delete_angle always draws and is the last action. # if self.canvas.draw_enabled: @@ -548,9 +566,9 @@ def load_angles(self, angles_data: List[Dict[str, Any]]) -> None: continue # Resolve segments via drawable_manager, then construct Angle without model canvas logic - args: Dict[str, Any] = angle_state.get('args', {}) - seg1_name = args.get('segment1_name') - seg2_name = args.get('segment2_name') + args: Dict[str, Any] = angle_state.get("args", {}) + seg1_name = args.get("segment1_name") + seg2_name = args.get("segment2_name") if not seg1_name or not seg2_name: continue segment1 = self.drawable_manager.get_segment_by_name(seg1_name) @@ -558,9 +576,9 @@ def load_angles(self, angles_data: List[Dict[str, Any]]) -> None: if not segment1 or not segment2: continue - angle_kwargs: Dict[str, Any] = {'is_reflex': args.get('is_reflex', False)} - if 'color' in args: - angle_kwargs['color'] = args['color'] + angle_kwargs: Dict[str, Any] = {"is_reflex": args.get("is_reflex", False)} + if "color" in args: + angle_kwargs["color"] = args["color"] new_angle = Angle(segment1, segment2, **angle_kwargs) @@ -570,23 +588,29 @@ def load_angles(self, angles_data: List[Dict[str, Any]]) -> None: # For simplicity, assume from_state returns a valid, potentially new, object. # A robust system might use get_angle_by_segments with segments derived from names in state. - self.drawables.add(new_angle) # Add to Angles list - if hasattr(new_angle, 'segment1') and new_angle.segment1 and \ - hasattr(new_angle, 'segment2') and new_angle.segment2: + self.drawables.add(new_angle) # Add to Angles list + if ( + hasattr(new_angle, "segment1") + and new_angle.segment1 + and hasattr(new_angle, "segment2") + and new_angle.segment2 + ): self.dependency_manager.register_dependency(child=new_angle, parent=new_angle.segment1) self.dependency_manager.register_dependency(child=new_angle, parent=new_angle.segment2) # Also register dependencies on the constituent points - if hasattr(new_angle, 'vertex_point') and new_angle.vertex_point: + if hasattr(new_angle, "vertex_point") and new_angle.vertex_point: self.dependency_manager.register_dependency(child=new_angle, parent=new_angle.vertex_point) - if hasattr(new_angle, 'arm1_point') and new_angle.arm1_point: + if hasattr(new_angle, "arm1_point") and new_angle.arm1_point: self.dependency_manager.register_dependency(child=new_angle, parent=new_angle.arm1_point) - if hasattr(new_angle, 'arm2_point') and new_angle.arm2_point: + if hasattr(new_angle, "arm2_point") and new_angle.arm2_point: self.dependency_manager.register_dependency(child=new_angle, parent=new_angle.arm2_point) else: - print(f"AngleManager: Warning - Loaded angle '{new_angle.name}' from state but segments are missing. Cannot register dependencies.") + print( + f"AngleManager: Warning - Loaded angle '{new_angle.name}' from state but segments are missing. Cannot register dependencies." + ) else: - angle_name_in_state = angle_state.get('name', '[Unknown Name]') + angle_name_in_state = angle_state.get("name", "[Unknown Name]") print(f"AngleManager: Failed to load angle from state: '{angle_name_in_state}'.") if self.canvas.draw_enabled: @@ -599,9 +623,13 @@ def get_angles_state(self) -> List[Dict[str, Any]]: Returns: A list of dictionaries, where each dict represents the state of an Angle. """ - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): return [] - return [angle.get_state() for angle in self.drawables.Angles if hasattr(angle, 'get_state') and callable(angle.get_state)] + return [ + angle.get_state() + for angle in self.drawables.Angles + if hasattr(angle, "get_state") and callable(angle.get_state) + ] def clear_angles(self) -> None: """ @@ -609,17 +637,23 @@ def clear_angles(self) -> None: """ self.canvas.undo_redo_manager.archive() - if not hasattr(self.drawables, 'Angles') or not isinstance(self.drawables.Angles, list): + if not hasattr(self.drawables, "Angles") or not isinstance(self.drawables.Angles, list): return - angle_names_to_remove: List[str] = [angle.name for angle in cast(List["Drawable"], list(self.drawables.Angles)) if hasattr(angle, 'name')] + angle_names_to_remove: List[str] = [ + angle.name for angle in cast(List["Drawable"], list(self.drawables.Angles)) if hasattr(angle, "name") + ] for angle_name in angle_names_to_remove: - self.delete_angle(angle_name) # This will handle individual draws and further dependencies. + self.delete_angle(angle_name) # This will handle individual draws and further dependencies. # A final draw call might be redundant if delete_angle always draws and handles all cleanup. # However, if delete_angle was optimized not to draw, this would be needed. - if self.canvas.draw_enabled and not angle_names_to_remove: # Only draw if nothing was removed (and drawn individually) + if ( + self.canvas.draw_enabled and not angle_names_to_remove + ): # Only draw if nothing was removed (and drawn individually) + self.canvas.draw() + elif ( + self.canvas.draw_enabled and angle_names_to_remove + ): # If angles were removed, they handled their draw, one final comprehensive draw self.canvas.draw() - elif self.canvas.draw_enabled and angle_names_to_remove: # If angles were removed, they handled their draw, one final comprehensive draw - self.canvas.draw() diff --git a/static/client/managers/arc_manager.py b/static/client/managers/arc_manager.py index 34c41225..65396954 100644 --- a/static/client/managers/arc_manager.py +++ b/static/client/managers/arc_manager.py @@ -45,9 +45,7 @@ def __init__( self.dependency_manager = dependency_manager self.point_manager = point_manager self.drawable_manager = drawable_manager_proxy - self.arc_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy( - "CircleArc" - ) + self.arc_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("CircleArc") # ------------------------------------------------------------------ # Creation helpers @@ -65,9 +63,7 @@ def _resolve_point( if existing: return existing, False if x is None or y is None: - raise ValueError( - f"{label} name '{point_name}' was not found and coordinates were not provided." - ) + raise ValueError(f"{label} name '{point_name}' was not found and coordinates were not provided.") created_point = self.point_manager.create_point( x, y, @@ -171,14 +167,10 @@ def _resolve_circle_arc_points( suggested_p1, suggested_p2 = self.name_generator.extract_point_names_from_arc_name(arc_name) # Resolve point 1 - point1, point1_is_new = self._resolve_arc_point( - "Point 1", point1_name, point1_x, point1_y, suggested_p1 - ) + point1, point1_is_new = self._resolve_arc_point("Point 1", point1_name, point1_x, point1_y, suggested_p1) # Resolve point 2 - point2, point2_is_new = self._resolve_arc_point( - "Point 2", point2_name, point2_x, point2_y, suggested_p2 - ) + point2, point2_is_new = self._resolve_arc_point("Point 2", point2_name, point2_x, point2_y, suggested_p2) # Resolve optional point 3 point3: Optional["Point"] = None @@ -213,7 +205,8 @@ def _resolve_arc_point( # Use create_point directly with suggested name (like segments do) # create_point handles: coordinate lookup, name validation via generate_point_name point = self.point_manager.create_point( - x, y, + x, + y, name=suggested_name or "", extra_graphics=False, ) @@ -282,7 +275,16 @@ def _determine_arc_geometry( center_y, radius, ) - return circle, resolved_center_x, resolved_center_y, resolved_radius, point1, point1_is_new, point2, point2_is_new + return ( + circle, + resolved_center_x, + resolved_center_y, + resolved_radius, + point1, + point1_is_new, + point2, + point2_is_new, + ) def _resolve_geometry_from_three_points( self, @@ -307,9 +309,7 @@ def _resolve_geometry_from_three_points( center_point, _ = points_info[center_choice] endpoint_labels = [ - label - for label in ("point1", "point2", "point3") - if label != center_choice and label in points_info + label for label in ("point1", "point2", "point3") if label != center_choice and label in points_info ] if len(endpoint_labels) != 2: raise ValueError("Creating a circle arc from three points requires exactly three defined points.") @@ -575,16 +575,8 @@ def load_circle_arcs(self, arcs_data: List[Dict[str, Any]]) -> None: point1_name = args.get("point1_name") point2_name = args.get("point2_name") circle_name = args.get("circle_name") - point1 = ( - self.point_manager.get_point_by_name(point1_name) - if point1_name - else None - ) - point2 = ( - self.point_manager.get_point_by_name(point2_name) - if point2_name - else None - ) + point1 = self.point_manager.get_point_by_name(point1_name) if point1_name else None + point2 = self.point_manager.get_point_by_name(point2_name) if point2_name else None if not point1 or not point2: continue @@ -619,4 +611,3 @@ def handle_circle_removed(self, circle_name: str) -> None: for arc_name in arcs_to_remove: self.delete_circle_arc(arc_name) - diff --git a/static/client/managers/bar_manager.py b/static/client/managers/bar_manager.py index 866a5203..6834a340 100644 --- a/static/client/managers/bar_manager.py +++ b/static/client/managers/bar_manager.py @@ -243,5 +243,3 @@ def _update_existing_bar( bar.label_text = bar.label_above_text except Exception: pass - - diff --git a/static/client/managers/circle_manager.py b/static/client/managers/circle_manager.py index 51d61282..015a2c67 100644 --- a/static/client/managers/circle_manager.py +++ b/static/client/managers/circle_manager.py @@ -49,6 +49,7 @@ from managers.point_manager import PointManager from name_generator.drawable import DrawableNameGenerator + class CircleManager: """ Manages circle drawables for a Canvas. @@ -101,9 +102,7 @@ def get_circle(self, center_x: float, center_y: float, radius: float) -> Optiona """ circles = self.drawables.Circles for circle in circles: - if (circle.center.x == center_x and - circle.center.y == center_y and - circle.radius == radius): + if circle.center.x == center_x and circle.center.y == center_y and circle.radius == radius: return circle return None @@ -217,9 +216,7 @@ def delete_circle(self, name: str) -> bool: pass # Remove from drawables - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, circle - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, circle) # Redraw if self.canvas.draw_enabled: diff --git a/static/client/managers/colored_area_manager.py b/static/client/managers/colored_area_manager.py index 7cd144c0..b961de02 100644 --- a/static/client/managers/colored_area_manager.py +++ b/static/client/managers/colored_area_manager.py @@ -74,6 +74,7 @@ from name_generator.drawable import DrawableNameGenerator from drawables.rectangle import Rectangle + class ColoredAreaManager: """ Manages colored area drawables for a Canvas. @@ -200,16 +201,18 @@ def get_y_at_x(segment: Segment, x: float) -> float: y = get_y_at_x(drawable1, x2_max) self.drawable_manager.create_point(x2_max, y) - colored_area: Union[SegmentsBoundedColoredArea, FunctionSegmentBoundedColoredArea, FunctionsBoundedColoredArea] + colored_area: Union[ + SegmentsBoundedColoredArea, FunctionSegmentBoundedColoredArea, FunctionsBoundedColoredArea + ] colored_area = SegmentsBoundedColoredArea(drawable1, drawable2, color=color, opacity=opacity) elif isinstance(drawable2, Segment): # Function-segment case (we know drawable1 is not a segment due to the swap above) colored_area = FunctionSegmentBoundedColoredArea(drawable1, drawable2, color=color, opacity=opacity) else: # Function-function case - colored_area = FunctionsBoundedColoredArea(drawable1, drawable2, - left_bound=left_bound, right_bound=right_bound, - color=color, opacity=opacity) + colored_area = FunctionsBoundedColoredArea( + drawable1, drawable2, left_bound=left_bound, right_bound=right_bound, color=color, opacity=opacity + ) # Add to drawables self.drawables.add(colored_area) @@ -684,9 +687,7 @@ def delete_colored_areas_for_segment( def _remove_colored_area_drawable(self, area: "Drawable") -> bool: """Remove a colored-area drawable and clean dependency graph entries.""" - return bool(remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, area - )) + return bool(remove_drawable_with_dependencies(self.drawables, self.dependency_manager, area)) def get_colored_areas_for_drawable(self, drawable: Union[Function, Segment]) -> List["Drawable"]: """ @@ -775,7 +776,9 @@ def update_colored_area( policy = self._get_policy_for_area(area) self._validate_policy(policy, list(pending_fields.keys())) - self._validate_colored_area_payload(area, pending_fields, new_color, new_opacity, new_left_bound, new_right_bound) + self._validate_colored_area_payload( + area, pending_fields, new_color, new_opacity, new_left_bound, new_right_bound + ) self.canvas.undo_redo_manager.archive() self._apply_colored_area_updates(area, pending_fields, new_color, new_opacity, new_left_bound, new_right_bound) @@ -868,11 +871,7 @@ def _validate_colored_area_payload( if "right_bound" in pending_fields and new_right_bound is not None: updated_right = float(new_right_bound) - if ( - updated_left is not None - and updated_right is not None - and updated_left >= updated_right - ): + if updated_left is not None and updated_right is not None and updated_left >= updated_right: raise ValueError("left_bound must be less than right_bound.") def _apply_colored_area_updates( diff --git a/static/client/managers/construction_manager.py b/static/client/managers/construction_manager.py index d384a444..2867dc7d 100644 --- a/static/client/managers/construction_manager.py +++ b/static/client/managers/construction_manager.py @@ -122,9 +122,7 @@ def _segment_slope(self, seg: "Segment") -> Optional[float]: dx = seg.point2.x - seg.point1.x dy = seg.point2.y - seg.point1.y if abs(dx) < MathUtils.EPSILON and abs(dy) < MathUtils.EPSILON: - raise ValueError( - f"Degenerate segment '{getattr(seg, 'name', '')}': endpoints coincide" - ) + raise ValueError(f"Degenerate segment '{getattr(seg, 'name', '')}': endpoints coincide") if abs(dx) < MathUtils.EPSILON: return None return dy / dx @@ -164,14 +162,13 @@ def create_midpoint( p1 = self._get_point(p1_name) p2 = self._get_point(p2_name) else: - raise ValueError( - "Provide either 'segment_name' or both 'p1_name' and 'p2_name'" - ) + raise ValueError("Provide either 'segment_name' or both 'p1_name' and 'p2_name'") mx, my = MathUtils.get_2D_midpoint(p1, p2) point = self.point_manager.create_point( - mx, my, + mx, + my, name=name or "", color=color or default_color, extra_graphics=False, @@ -218,7 +215,10 @@ def create_perpendicular_bisector( self._archive_for_undo() segment = self.segment_manager.create_segment( - x1, y1, x2, y2, + x1, + y1, + x2, + y2, name=name or "", color=color, extra_graphics=True, @@ -256,9 +256,12 @@ def create_perpendicular_from_point( color = default_color foot_x, foot_y = MathUtils.perpendicular_foot( - pt.x, pt.y, - seg.point1.x, seg.point1.y, - seg.point2.x, seg.point2.y, + pt.x, + pt.y, + seg.point1.x, + seg.point1.y, + seg.point2.x, + seg.point2.y, ) # Use suspend_archiving pattern for composite construction @@ -268,14 +271,18 @@ def create_perpendicular_from_point( try: foot_point = self.point_manager.create_point( - foot_x, foot_y, + foot_x, + foot_y, name="", color=color, extra_graphics=False, ) perp_segment = self.segment_manager.create_segment( - pt.x, pt.y, foot_x, foot_y, + pt.x, + pt.y, + foot_x, + foot_y, name=name or "", color=color, extra_graphics=True, @@ -350,9 +357,7 @@ def create_angle_bisector( p1x, p1y = p1.x, p1.y p2x, p2y = p2.x, p2.y else: - raise ValueError( - "Provide either 'angle_name' or all of 'vertex_name', 'p1_name', 'p2_name'" - ) + raise ValueError("Provide either 'angle_name' or all of 'vertex_name', 'p1_name', 'p2_name'") if length is None: length = DEFAULT_CONSTRUCTION_LENGTH @@ -367,7 +372,6 @@ def create_angle_bisector( dx, dy = -dx, -dy # Create segment from vertex along bisector direction - half = length / 2 x1 = vx y1 = vy x2 = vx + dx * length @@ -375,7 +379,10 @@ def create_angle_bisector( self._archive_for_undo() segment = self.segment_manager.create_segment( - x1, y1, x2, y2, + x1, + y1, + x2, + y2, name=name or "", color=color, extra_graphics=True, @@ -425,7 +432,10 @@ def create_parallel_line( self._archive_for_undo() segment = self.segment_manager.create_segment( - x1, y1, x2, y2, + x1, + y1, + x2, + y2, name=name or "", color=color, extra_graphics=True, @@ -454,9 +464,7 @@ def _triangle_vertices( p3 = self._get_point(p3_name) return (p1.x, p1.y, p2.x, p2.y, p3.x, p3.y) else: - raise ValueError( - "Provide either 'triangle_name' or all of 'p1_name', 'p2_name', 'p3_name'" - ) + raise ValueError("Provide either 'triangle_name' or all of 'p1_name', 'p2_name', 'p3_name'") def create_circumcircle( self, @@ -486,9 +494,7 @@ def create_circumcircle( Raises: ValueError: If inputs not found or points are collinear """ - x1, y1, x2, y2, x3, y3 = self._triangle_vertices( - triangle_name, p1_name, p2_name, p3_name - ) + x1, y1, x2, y2, x3, y3 = self._triangle_vertices(triangle_name, p1_name, p2_name, p3_name) if color is None: color = default_color @@ -501,7 +507,9 @@ def create_circumcircle( try: circle = self.proxy.create_circle( - cx, cy, radius, + cx, + cy, + radius, name=name or "", color=color, extra_graphics=True, @@ -547,9 +555,12 @@ def create_incircle( color = default_color cx, cy, radius = MathUtils.incenter_and_inradius( - verts[0].x, verts[0].y, - verts[1].x, verts[1].y, - verts[2].x, verts[2].y, + verts[0].x, + verts[0].y, + verts[1].x, + verts[1].y, + verts[2].x, + verts[2].y, ) # Use suspend_archiving since create_circle internally archives @@ -559,7 +570,9 @@ def create_incircle( try: circle = self.proxy.create_circle( - cx, cy, radius, + cx, + cy, + radius, name=name or "", color=color, extra_graphics=True, diff --git a/static/client/managers/dependency_removal.py b/static/client/managers/dependency_removal.py index f5b3866c..bc09fb28 100644 --- a/static/client/managers/dependency_removal.py +++ b/static/client/managers/dependency_removal.py @@ -35,10 +35,10 @@ def get_polygon_segments(polygon: Any) -> List[Any]: falling back to named ``segment1``..``segment4`` attributes (used by Triangle, Rectangle, and Quadrilateral). """ - if hasattr(polygon, 'get_segments') and callable(getattr(polygon, 'get_segments')): + if hasattr(polygon, "get_segments") and callable(getattr(polygon, "get_segments")): return polygon.get_segments() segments: List[Any] = [] - for attr in ('segment1', 'segment2', 'segment3', 'segment4'): + for attr in ("segment1", "segment2", "segment3", "segment4"): seg = getattr(polygon, attr, None) if seg is not None: segments.append(seg) diff --git a/static/client/managers/drawable_dependency_manager.py b/static/client/managers/drawable_dependency_manager.py index f315afae..11b5d89e 100644 --- a/static/client/managers/drawable_dependency_manager.py +++ b/static/client/managers/drawable_dependency_manager.py @@ -55,6 +55,7 @@ from managers.drawable_manager_proxy import DrawableManagerProxy from canvas import Canvas + class DrawableDependencyManager: """ Manages dependencies between drawable objects to maintain hierarchical structure. @@ -71,7 +72,7 @@ def __init__( debug_logging: bool = False, ) -> None: """Initialize the dependency manager""" - self.drawable_manager: Optional["DrawableManagerProxy"] = drawable_manager_proxy # Store the proxy + self.drawable_manager: Optional["DrawableManagerProxy"] = drawable_manager_proxy # Store the proxy self._debug_logging: bool = bool(debug_logging) # Re-add internal state maps needed by other methods self._parents: Dict[int, Set[int]] = {} @@ -79,33 +80,33 @@ def __init__( self._object_lookup: Dict[int, "Drawable"] = {} # Type hierarchy - which types depend on which other types self._type_hierarchy: Dict[str, List[str]] = { - 'Point': [], - 'Segment': ['Point'], - 'Vector': ['Segment'], - 'Triangle': ['Segment', 'Point'], - 'Rectangle': ['Segment', 'Point'], - 'Quadrilateral': ['Segment', 'Point'], - 'Pentagon': ['Segment', 'Point'], - 'Hexagon': ['Segment', 'Point'], - 'Heptagon': ['Segment', 'Point'], - 'Octagon': ['Segment', 'Point'], - 'Nonagon': ['Segment', 'Point'], - 'Decagon': ['Segment', 'Point'], - 'GenericPolygon': ['Segment', 'Point'], - 'Circle': ['Point'], - 'Ellipse': ['Point'], - 'Angle': ['Segment', 'Point'], - 'CircleArc': ['Point', 'Circle'], - 'Function': [], - 'ColoredArea': ['Function', 'Segment'], - 'SegmentsBoundedColoredArea': ['Segment'], - 'FunctionSegmentBoundedColoredArea': ['Function', 'Segment'], - 'FunctionsBoundedColoredArea': ['Function'], - 'ClosedShapeColoredArea': ['Segment', 'Circle', 'Ellipse'], - 'Graph': ['Segment', 'Vector', 'Point'], - 'DirectedGraph': ['Vector', 'Point'], - 'UndirectedGraph': ['Segment', 'Point'], - 'Tree': ['Segment', 'Point'], + "Point": [], + "Segment": ["Point"], + "Vector": ["Segment"], + "Triangle": ["Segment", "Point"], + "Rectangle": ["Segment", "Point"], + "Quadrilateral": ["Segment", "Point"], + "Pentagon": ["Segment", "Point"], + "Hexagon": ["Segment", "Point"], + "Heptagon": ["Segment", "Point"], + "Octagon": ["Segment", "Point"], + "Nonagon": ["Segment", "Point"], + "Decagon": ["Segment", "Point"], + "GenericPolygon": ["Segment", "Point"], + "Circle": ["Point"], + "Ellipse": ["Point"], + "Angle": ["Segment", "Point"], + "CircleArc": ["Point", "Circle"], + "Function": [], + "ColoredArea": ["Function", "Segment"], + "SegmentsBoundedColoredArea": ["Segment"], + "FunctionSegmentBoundedColoredArea": ["Function", "Segment"], + "FunctionsBoundedColoredArea": ["Function"], + "ClosedShapeColoredArea": ["Segment", "Circle", "Ellipse"], + "Graph": ["Segment", "Vector", "Point"], + "DirectedGraph": ["Vector", "Point"], + "UndirectedGraph": ["Segment", "Point"], + "Tree": ["Segment", "Point"], } def _describe_drawable(self, drawable: Optional["Drawable"]) -> str: @@ -139,14 +140,14 @@ def _debug_log_dependency_event( def _should_skip_point_point_dependency(self, child: "Drawable", parent: "Drawable") -> bool: """Check if a dependency registration should be skipped (e.g., Point as child of Point).""" is_child_point = ( - hasattr(child, 'get_class_name') - and callable(getattr(child, 'get_class_name', None)) - and child.get_class_name() == 'Point' + hasattr(child, "get_class_name") + and callable(getattr(child, "get_class_name", None)) + and child.get_class_name() == "Point" ) is_parent_point = ( - hasattr(parent, 'get_class_name') - and callable(getattr(parent, 'get_class_name', None)) - and parent.get_class_name() == 'Point' + hasattr(parent, "get_class_name") + and callable(getattr(parent, "get_class_name", None)) + and parent.get_class_name() == "Point" ) return is_child_point and is_parent_point @@ -161,8 +162,8 @@ def register_dependency(self, child: "Drawable", parent: "Drawable") -> None: if child is None or parent is None: return - child_getter = getattr(child, 'get_class_name', None) - parent_getter = getattr(parent, 'get_class_name', None) + child_getter = getattr(child, "get_class_name", None) + parent_getter = getattr(parent, "get_class_name", None) if not callable(child_getter) or not callable(parent_getter): return @@ -234,7 +235,7 @@ def _verify_get_class_name_method(self, obj: Any, obj_type_name: str) -> None: obj: The object to verify obj_type_name: A string indicating the type of object (e.g., "Child", "Parent") """ - if not hasattr(obj, 'get_class_name'): + if not hasattr(obj, "get_class_name"): print(f"WARNING: {obj_type_name} {obj} is missing get_class_name method") # If missing, let's make sure we can still identify the object print(f"{obj_type_name} object type: {type(obj)}") @@ -264,7 +265,7 @@ def _append_attr_dependency_if_present( dependency = getattr(drawable, attr_name) if require_truthy and not dependency: return - if require_get_class_name and not hasattr(dependency, 'get_class_name'): + if require_get_class_name and not hasattr(dependency, "get_class_name"): return self._append_and_register_dependency(drawable, dependency, dependencies) @@ -292,7 +293,7 @@ def _append_segment_attrs( ) -> None: """Append/register segment1..segmentN dependencies when present.""" for i in range(1, count + 1): - self._append_attr_dependency_if_present(drawable, f'segment{i}', dependencies) + self._append_attr_dependency_if_present(drawable, f"segment{i}", dependencies) def get_parents(self, drawable: Optional["Drawable"]) -> Set["Drawable"]: """ @@ -309,7 +310,11 @@ def get_parents(self, drawable: Optional["Drawable"]) -> Set["Drawable"]: return set() drawable_id = id(drawable) - return {self._object_lookup[parent_id] for parent_id in self._parents.get(drawable_id, set()) if parent_id in self._object_lookup} + return { + self._object_lookup[parent_id] + for parent_id in self._parents.get(drawable_id, set()) + if parent_id in self._object_lookup + } def get_children(self, drawable: Optional["Drawable"]) -> Set["Drawable"]: """ @@ -326,7 +331,11 @@ def get_children(self, drawable: Optional["Drawable"]) -> Set["Drawable"]: return set() drawable_id = id(drawable) - return {self._object_lookup[child_id] for child_id in self._children.get(drawable_id, set()) if child_id in self._object_lookup} + return { + self._object_lookup[child_id] + for child_id in self._children.get(drawable_id, set()) + if child_id in self._object_lookup + } def get_all_parents(self, drawable: Optional["Drawable"]) -> Set["Drawable"]: """ @@ -395,7 +404,7 @@ def remove_drawable(self, drawable: "Drawable") -> None: drawable: The drawable to remove """ drawable_id = id(drawable) - drawable_class = drawable.get_class_name() if hasattr(drawable, 'get_class_name') else "" + drawable_class = drawable.get_class_name() if hasattr(drawable, "get_class_name") else "" # Notify children (e.g., graphs) to remove references to this drawable for child_id in self._children.get(drawable_id, set()).copy(): @@ -423,18 +432,17 @@ def remove_drawable(self, drawable: "Drawable") -> None: def _notify_child_of_parent_removal(self, child: "Drawable", parent: "Drawable", parent_class: str) -> None: """Notify a child drawable that one of its parents has been removed.""" - child_class = child.get_class_name() if hasattr(child, 'get_class_name') else "" + child_class = child.get_class_name() if hasattr(child, "get_class_name") else "" # Handle graph types - remove the reference from internal lists - if child_class in ('Graph', 'DirectedGraph', 'UndirectedGraph', 'Tree'): - if parent_class == 'Segment' and hasattr(child, 'remove_segment'): + if child_class in ("Graph", "DirectedGraph", "UndirectedGraph", "Tree"): + if parent_class == "Segment" and hasattr(child, "remove_segment"): child.remove_segment(parent) - elif parent_class == 'Vector' and hasattr(child, 'remove_vector'): + elif parent_class == "Vector" and hasattr(child, "remove_vector"): child.remove_vector(parent) - elif parent_class == 'Point' and hasattr(child, 'remove_point'): + elif parent_class == "Point" and hasattr(child, "remove_point"): child.remove_point(parent) - def analyze_drawable_for_dependencies(self, drawable: "Drawable") -> List["Drawable"]: """ Analyze a drawable to find and register its dependencies @@ -451,120 +459,117 @@ def analyze_drawable_for_dependencies(self, drawable: "Drawable") -> List["Drawa self._verify_get_class_name_method(drawable, "Drawable") # Get class name safely - if not hasattr(drawable, 'get_class_name'): + if not hasattr(drawable, "get_class_name"): print(f"Cannot analyze dependencies for {drawable} without get_class_name method") return dependencies class_name = drawable.get_class_name() # Handle different drawable types - if class_name == 'Point': + if class_name == "Point": # Points don't have dependencies pass - elif class_name == 'Segment': - self._append_attr_dependency_if_present(drawable, 'point1', dependencies) - self._append_attr_dependency_if_present(drawable, 'point2', dependencies) + elif class_name == "Segment": + self._append_attr_dependency_if_present(drawable, "point1", dependencies) + self._append_attr_dependency_if_present(drawable, "point2", dependencies) - elif class_name == 'Vector': - self._append_attr_dependency_if_present(drawable, 'segment', dependencies) + elif class_name == "Vector": + self._append_attr_dependency_if_present(drawable, "segment", dependencies) - elif class_name == 'Triangle': + elif class_name == "Triangle": self._append_segment_attrs(drawable, dependencies, count=3) - elif class_name == 'Rectangle': + elif class_name == "Rectangle": self._append_segment_attrs(drawable, dependencies, count=4) - elif class_name == 'Quadrilateral': + elif class_name == "Quadrilateral": self._append_segment_attrs(drawable, dependencies, count=4) - elif class_name in ('Pentagon', 'Hexagon', 'Heptagon', 'Octagon', - 'Nonagon', 'Decagon', 'GenericPolygon'): - self._append_iterable_attr_dependencies( - drawable, '_segments', dependencies, require_truthy=True - ) + elif class_name in ("Pentagon", "Hexagon", "Heptagon", "Octagon", "Nonagon", "Decagon", "GenericPolygon"): + self._append_iterable_attr_dependencies(drawable, "_segments", dependencies, require_truthy=True) - elif class_name == 'Circle': - self._append_attr_dependency_if_present(drawable, 'center', dependencies) + elif class_name == "Circle": + self._append_attr_dependency_if_present(drawable, "center", dependencies) - elif class_name == 'Ellipse': - self._append_attr_dependency_if_present(drawable, 'center', dependencies) + elif class_name == "Ellipse": + self._append_attr_dependency_if_present(drawable, "center", dependencies) - elif class_name == 'Function': + elif class_name == "Function": # Functions typically don't have drawable dependencies pass - elif class_name == 'SegmentsBoundedColoredArea': - self._append_attr_dependency_if_present(drawable, 'segment1', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'segment2', dependencies, require_truthy=True) + elif class_name == "SegmentsBoundedColoredArea": + self._append_attr_dependency_if_present(drawable, "segment1", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "segment2", dependencies, require_truthy=True) - elif class_name == 'FunctionSegmentBoundedColoredArea': + elif class_name == "FunctionSegmentBoundedColoredArea": self._append_attr_dependency_if_present( drawable, - 'func', + "func", dependencies, require_truthy=True, require_get_class_name=True, ) - self._append_attr_dependency_if_present(drawable, 'segment', dependencies) + self._append_attr_dependency_if_present(drawable, "segment", dependencies) - elif class_name == 'FunctionsBoundedColoredArea': + elif class_name == "FunctionsBoundedColoredArea": self._append_attr_dependency_if_present( drawable, - 'func1', + "func1", dependencies, require_truthy=True, require_get_class_name=True, ) self._append_attr_dependency_if_present( drawable, - 'func2', + "func2", dependencies, require_truthy=True, require_get_class_name=True, ) - elif class_name == 'Angle': + elif class_name == "Angle": # Angles depend on their constituent segments and points - self._append_attr_dependency_if_present(drawable, 'segment1', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'segment2', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'vertex_point', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'arm1_point', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'arm2_point', dependencies, require_truthy=True) - - elif class_name == 'CircleArc': - self._append_attr_dependency_if_present(drawable, 'point1', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'point2', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'circle', dependencies, require_truthy=True) - - elif class_name == 'ClosedShapeColoredArea': - self._append_iterable_attr_dependencies(drawable, 'segments', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'circle', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'ellipse', dependencies, require_truthy=True) - self._append_attr_dependency_if_present(drawable, 'chord_segment', dependencies, require_truthy=True) - - elif class_name == 'ColoredArea': + self._append_attr_dependency_if_present(drawable, "segment1", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "segment2", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "vertex_point", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "arm1_point", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "arm2_point", dependencies, require_truthy=True) + + elif class_name == "CircleArc": + self._append_attr_dependency_if_present(drawable, "point1", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "point2", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "circle", dependencies, require_truthy=True) + + elif class_name == "ClosedShapeColoredArea": + self._append_iterable_attr_dependencies(drawable, "segments", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "circle", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "ellipse", dependencies, require_truthy=True) + self._append_attr_dependency_if_present(drawable, "chord_segment", dependencies, require_truthy=True) + + elif class_name == "ColoredArea": # Base ColoredArea type - self._append_attr_dependency_if_present(drawable, 'function', dependencies) - self._append_iterable_attr_dependencies(drawable, 'segments', dependencies) + self._append_attr_dependency_if_present(drawable, "function", dependencies) + self._append_iterable_attr_dependencies(drawable, "segments", dependencies) - elif class_name.endswith('ColoredArea'): + elif class_name.endswith("ColoredArea"): # Generic case for other ColoredArea types - self._append_attr_dependency_if_present(drawable, 'function', dependencies) - self._append_iterable_attr_dependencies(drawable, 'segments', dependencies) + self._append_attr_dependency_if_present(drawable, "function", dependencies) + self._append_iterable_attr_dependencies(drawable, "segments", dependencies) - elif class_name in ('Graph', 'DirectedGraph', 'UndirectedGraph', 'Tree'): + elif class_name in ("Graph", "DirectedGraph", "UndirectedGraph", "Tree"): # Graphs depend on their segments, vectors, and isolated points - self._append_iterable_attr_dependencies(drawable, '_segments', dependencies, require_truthy=True) - self._append_iterable_attr_dependencies(drawable, '_vectors', dependencies, require_truthy=True) - self._append_iterable_attr_dependencies(drawable, '_isolated_points', dependencies, require_truthy=True) + self._append_iterable_attr_dependencies(drawable, "_segments", dependencies, require_truthy=True) + self._append_iterable_attr_dependencies(drawable, "_vectors", dependencies, require_truthy=True) + self._append_iterable_attr_dependencies(drawable, "_isolated_points", dependencies, require_truthy=True) return dependencies def _find_segment_children(self, segment: Optional["Drawable"]) -> List["Drawable"]: """Finds children geometrically by iterating through all segments.""" # Safety check for segment and its points - if not segment or not hasattr(segment, 'point1') or not hasattr(segment, 'point2'): + if not segment or not hasattr(segment, "point1") or not hasattr(segment, "point2"): return [] # Safety check for points @@ -581,7 +586,7 @@ def _find_segment_children(self, segment: Optional["Drawable"]) -> List["Drawabl for s in all_segments: if s == segment: continue - if not hasattr(s, 'point1') or not hasattr(s, 'point2'): # Safety check + if not hasattr(s, "point1") or not hasattr(s, "point2"): # Safety check continue if not s.point1 or not s.point2: continue @@ -589,8 +594,9 @@ def _find_segment_children(self, segment: Optional["Drawable"]) -> List["Drawabl p1x, p1y = s.point1.x, s.point1.y p2x, p2y = s.point2.x, s.point2.y # Check if s is geometrically within segment - if MathUtils.is_point_on_segment(p1x, p1y, sp1x, sp1y, sp2x, sp2y) and \ - MathUtils.is_point_on_segment(p2x, p2y, sp1x, sp1y, sp2x, sp2y): + if MathUtils.is_point_on_segment(p1x, p1y, sp1x, sp1y, sp2x, sp2y) and MathUtils.is_point_on_segment( + p2x, p2y, sp1x, sp1y, sp2x, sp2y + ): children.append(s) return children diff --git a/static/client/managers/drawable_manager.py b/static/client/managers/drawable_manager.py index fb53ac74..2f986365 100644 --- a/static/client/managers/drawable_manager.py +++ b/static/client/managers/drawable_manager.py @@ -88,6 +88,7 @@ from drawables.circle_arc import CircleArc from geometry.graph_state import GraphState + class DrawableManager: """ Manages drawable objects for a Canvas. @@ -118,7 +119,9 @@ def __init__(self, canvas: "Canvas") -> None: self.proxy: DrawableManagerProxy = DrawableManagerProxy(self) # Instantiate DependencyManager with just the proxy - self.dependency_manager: DrawableDependencyManager = DrawableDependencyManager(drawable_manager_proxy=self.proxy) + self.dependency_manager: DrawableDependencyManager = DrawableDependencyManager( + drawable_manager_proxy=self.proxy + ) # Initialize specialized managers with the proxy self.point_manager: PointManager = PointManager( @@ -126,13 +129,11 @@ def __init__(self, canvas: "Canvas") -> None: ) self.segment_manager: SegmentManager = SegmentManager( - canvas, self.drawables, self.name_generator, self.dependency_manager, - self.point_manager, self.proxy + canvas, self.drawables, self.name_generator, self.dependency_manager, self.point_manager, self.proxy ) self.vector_manager: VectorManager = VectorManager( - canvas, self.drawables, self.name_generator, self.dependency_manager, - self.point_manager, self.proxy + canvas, self.drawables, self.name_generator, self.dependency_manager, self.point_manager, self.proxy ) self.polygon_manager: PolygonManager = PolygonManager( @@ -158,13 +159,11 @@ def __init__(self, canvas: "Canvas") -> None: ) self.circle_manager: CircleManager = CircleManager( - canvas, self.drawables, self.name_generator, self.dependency_manager, - self.point_manager, self.proxy + canvas, self.drawables, self.name_generator, self.dependency_manager, self.point_manager, self.proxy ) self.ellipse_manager: EllipseManager = EllipseManager( - canvas, self.drawables, self.name_generator, self.dependency_manager, - self.point_manager, self.proxy + canvas, self.drawables, self.name_generator, self.dependency_manager, self.point_manager, self.proxy ) self.colored_area_manager: ColoredAreaManager = ColoredAreaManager( @@ -172,8 +171,13 @@ def __init__(self, canvas: "Canvas") -> None: ) self.angle_manager: AngleManager = AngleManager( - canvas, self.drawables, self.name_generator, self.dependency_manager, - self.point_manager, self.segment_manager, self.proxy + canvas, + self.drawables, + self.name_generator, + self.dependency_manager, + self.point_manager, + self.segment_manager, + self.proxy, ) self.label_manager: LabelManager = LabelManager( @@ -454,7 +458,9 @@ def create_segment( label_visible=label_visible, ) - def delete_segment(self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False) -> bool: + def delete_segment( + self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False + ) -> bool: """Delete a segment at the specified coordinates""" return bool(self.segment_manager.delete_segment(x1, y1, x2, y2, delete_children, delete_parents)) @@ -889,18 +895,21 @@ def plot_distribution( fill_opacity: Optional[float], bar_count: Optional[float], ) -> Dict[str, Any]: - return cast(Dict[str, Any], self.statistics_manager.plot_distribution( - name=name, - representation=representation, - distribution_type=distribution_type, - distribution_params=distribution_params, - plot_bounds=plot_bounds, - shade_bounds=shade_bounds, - curve_color=curve_color, - fill_color=fill_color, - fill_opacity=fill_opacity, - bar_count=bar_count, - )) + return cast( + Dict[str, Any], + self.statistics_manager.plot_distribution( + name=name, + representation=representation, + distribution_type=distribution_type, + distribution_params=distribution_params, + plot_bounds=plot_bounds, + shade_bounds=shade_bounds, + curve_color=curve_color, + fill_color=fill_color, + fill_opacity=fill_opacity, + bar_count=bar_count, + ), + ) def plot_bars( self, @@ -917,19 +926,22 @@ def plot_bars( x_start: Optional[float], y_base: Optional[float], ) -> Dict[str, Any]: - return cast(Dict[str, Any], self.statistics_manager.plot_bars( - name=name, - values=values, - labels_below=labels_below, - labels_above=labels_above, - bar_spacing=bar_spacing, - bar_width=bar_width, - stroke_color=stroke_color, - fill_color=fill_color, - fill_opacity=fill_opacity, - x_start=x_start, - y_base=y_base, - )) + return cast( + Dict[str, Any], + self.statistics_manager.plot_bars( + name=name, + values=values, + labels_below=labels_below, + labels_above=labels_above, + bar_spacing=bar_spacing, + bar_width=bar_width, + stroke_color=stroke_color, + fill_color=fill_color, + fill_opacity=fill_opacity, + x_start=x_start, + y_base=y_base, + ), + ) def delete_plot(self, name: str) -> bool: return bool(self.statistics_manager.delete_plot(name)) @@ -947,17 +959,20 @@ def fit_regression( show_points: Optional[bool], point_color: Optional[str], ) -> Dict[str, Any]: - return cast(Dict[str, Any], self.statistics_manager.fit_regression( - name=name, - x_data=x_data, - y_data=y_data, - model_type=model_type, - degree=degree, - plot_bounds=plot_bounds, - curve_color=curve_color, - show_points=show_points, - point_color=point_color, - )) + return cast( + Dict[str, Any], + self.statistics_manager.fit_regression( + name=name, + x_data=x_data, + y_data=y_data, + model_type=model_type, + degree=degree, + plot_bounds=plot_bounds, + curve_color=curve_color, + show_points=show_points, + point_color=point_color, + ), + ) # ------------------- Graph Methods ------------------- def create_graph(self, graph_state: "GraphState") -> "Drawable": @@ -1000,14 +1015,31 @@ def capture_graph_state(self, name: str) -> None: self.graph_manager.capture_state(name) # ------------------- Angle Methods ------------------- - def create_angle(self, vx: float, vy: float, p1x: float, p1y: float, p2x: float, p2y: float, color: Optional[str] = None, angle_name: Optional[str] = None, is_reflex: bool = False, extra_graphics: bool = True) -> Optional["Angle"]: + def create_angle( + self, + vx: float, + vy: float, + p1x: float, + p1y: float, + p2x: float, + p2y: float, + color: Optional[str] = None, + angle_name: Optional[str] = None, + is_reflex: bool = False, + extra_graphics: bool = True, + ) -> Optional["Angle"]: """Creates an angle defined by three points.""" return self.angle_manager.create_angle( - vx, vy, p1x, p1y, p2x, p2y, + vx, + vy, + p1x, + p1y, + p2x, + p2y, color=color, angle_name=angle_name, is_reflex=is_reflex, - extra_graphics=extra_graphics + extra_graphics=extra_graphics, ) def delete_angle(self, name: str) -> bool: @@ -1144,9 +1176,7 @@ def create_tangent_line( Returns: The created Segment drawable """ - return self.tangent_manager.create_tangent_line( - curve_name, parameter, name=name, length=length, color=color - ) + return self.tangent_manager.create_tangent_line(curve_name, parameter, name=name, length=length, color=color) def create_normal_line( self, @@ -1168,9 +1198,7 @@ def create_normal_line( Returns: The created Segment drawable """ - return self.tangent_manager.create_normal_line( - curve_name, parameter, name=name, length=length, color=color - ) + return self.tangent_manager.create_normal_line(curve_name, parameter, name=name, length=length, color=color) # ------------------- Construction Methods ------------------- @@ -1227,8 +1255,7 @@ def create_angle_bisector( ) -> "Segment": """Create a segment along the bisector of an angle.""" return self.construction_manager.create_angle_bisector( - vertex_name, p1_name, p2_name, - angle_name=angle_name, length=length, name=name, color=color + vertex_name, p1_name, p2_name, angle_name=angle_name, length=length, name=name, color=color ) def create_circumcircle( @@ -1244,8 +1271,11 @@ def create_circumcircle( """Create the circumscribed circle of a triangle or three points.""" return self.construction_manager.create_circumcircle( triangle_name=triangle_name, - p1_name=p1_name, p2_name=p2_name, p3_name=p3_name, - name=name, color=color, + p1_name=p1_name, + p2_name=p2_name, + p3_name=p3_name, + name=name, + color=color, ) def create_incircle( @@ -1257,7 +1287,9 @@ def create_incircle( ) -> "Circle": """Create the inscribed circle of a triangle.""" return self.construction_manager.create_incircle( - triangle_name, name=name, color=color, + triangle_name, + name=name, + color=color, ) def create_parallel_line( diff --git a/static/client/managers/drawable_manager_proxy.py b/static/client/managers/drawable_manager_proxy.py index c17e294a..205574a2 100644 --- a/static/client/managers/drawable_manager_proxy.py +++ b/static/client/managers/drawable_manager_proxy.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from managers.drawable_manager import DrawableManager + class DrawableManagerProxy: """A proxy for the DrawableManager that forwards all attribute access to the real manager. diff --git a/static/client/managers/drawables_container.py b/static/client/managers/drawables_container.py index 4720f074..a1b08b04 100644 --- a/static/client/managers/drawables_container.py +++ b/static/client/managers/drawables_container.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from drawables.drawable import Drawable + class DrawablesContainer: """ A container for storing and accessing drawable objects by their class names. @@ -85,9 +86,7 @@ def _apply_layering(self, colored: List["Drawable"], others: List["Drawable"]) - for drawable in others: class_name = ( - drawable.get_class_name() - if hasattr(drawable, "get_class_name") - else drawable.__class__.__name__ + drawable.get_class_name() if hasattr(drawable, "get_class_name") else drawable.__class__.__name__ ) if class_name == "Circle": circles.append(drawable) @@ -174,7 +173,7 @@ def get_colored_areas(self) -> List["Drawable"]: """ colored_areas = [] for drawable_type in self._drawables: - if 'ColoredArea' in drawable_type: + if "ColoredArea" in drawable_type: colored_areas.extend(self._drawables[drawable_type]) return colored_areas @@ -187,7 +186,7 @@ def get_non_colored_areas(self) -> List["Drawable"]: """ other_drawables = [] for drawable_type in self._drawables: - if 'ColoredArea' not in drawable_type: + if "ColoredArea" not in drawable_type: other_drawables.extend(self._drawables[drawable_type]) return other_drawables @@ -209,7 +208,7 @@ def get_renderables_with_layering(self) -> List["Drawable"]: colored: List["Drawable"] = [] others: List["Drawable"] = [] for class_name, bucket in self._renderables.items(): - if 'ColoredArea' in class_name: + if "ColoredArea" in class_name: colored.extend(bucket) else: others.extend(bucket) @@ -229,7 +228,7 @@ def get_state(self) -> Dict[str, List[Dict[str, Any]]]: """ state_dict = {} for category, drawables in self._drawables.items(): - state_dict[category + 's'] = [drawable.get_state() for drawable in drawables] + state_dict[category + "s"] = [drawable.get_state() for drawable in drawables] return state_dict def rebuild_renderables(self) -> None: @@ -243,22 +242,22 @@ def rebuild_renderables(self) -> None: @property def Points(self) -> List["Drawable"]: """Get all Point objects.""" - return self.get_by_class_name('Point') + return self.get_by_class_name("Point") @property def Segments(self) -> List["Drawable"]: """Get all Segment objects.""" - return self.get_by_class_name('Segment') + return self.get_by_class_name("Segment") @property def Vectors(self) -> List["Drawable"]: """Get all Vector objects.""" - return self.get_by_class_name('Vector') + return self.get_by_class_name("Vector") @property def Triangles(self) -> List["Drawable"]: """Get all Triangle objects.""" - return self.get_by_class_name('Triangle') + return self.get_by_class_name("Triangle") def get_triangle_by_name(self, name: str) -> Optional["Drawable"]: """Retrieve a triangle by name.""" @@ -272,7 +271,7 @@ def get_triangle_by_name(self, name: str) -> Optional["Drawable"]: @property def Rectangles(self) -> List["Drawable"]: """Get all Rectangle objects.""" - return self.get_by_class_name('Rectangle') + return self.get_by_class_name("Rectangle") def get_rectangle_by_name(self, name: str) -> Optional["Drawable"]: """Retrieve a rectangle by name.""" @@ -286,7 +285,7 @@ def get_rectangle_by_name(self, name: str) -> Optional["Drawable"]: @property def Quadrilaterals(self) -> List["Drawable"]: """Get all Quadrilateral objects.""" - return self.get_by_class_name('Quadrilateral') + return self.get_by_class_name("Quadrilateral") def get_quadrilateral_by_name(self, name: str) -> Optional["Drawable"]: """Retrieve a quadrilateral by name.""" @@ -300,7 +299,7 @@ def get_quadrilateral_by_name(self, name: str) -> Optional["Drawable"]: @property def Pentagons(self) -> List["Drawable"]: """Get all Pentagon objects.""" - return self.get_by_class_name('Pentagon') + return self.get_by_class_name("Pentagon") def get_pentagon_by_name(self, name: str) -> Optional["Drawable"]: """Retrieve a pentagon by name.""" @@ -314,7 +313,7 @@ def get_pentagon_by_name(self, name: str) -> Optional["Drawable"]: @property def Hexagons(self) -> List["Drawable"]: """Get all Hexagon objects.""" - return self.get_by_class_name('Hexagon') + return self.get_by_class_name("Hexagon") def get_hexagon_by_name(self, name: str) -> Optional["Drawable"]: """Retrieve a hexagon by name.""" @@ -328,8 +327,16 @@ def get_hexagon_by_name(self, name: str) -> Optional["Drawable"]: def iter_polygons(self, allowed_classes: Optional[Iterable[str]] = None) -> Iterable["Drawable"]: """Iterate over stored polygon drawables, optionally filtered by class name.""" polygon_classes = ( - "Triangle", "Quadrilateral", "Rectangle", "Pentagon", "Hexagon", - "Heptagon", "Octagon", "Nonagon", "Decagon", "GenericPolygon", + "Triangle", + "Quadrilateral", + "Rectangle", + "Pentagon", + "Hexagon", + "Heptagon", + "Octagon", + "Nonagon", + "Decagon", + "GenericPolygon", ) target_classes = tuple(allowed_classes) if allowed_classes else polygon_classes for class_name in target_classes: @@ -348,67 +355,67 @@ def get_polygon_by_name(self, name: str, allowed_classes: Optional[Iterable[str] @property def Circles(self) -> List["Drawable"]: """Get all Circle objects.""" - return self.get_by_class_name('Circle') + return self.get_by_class_name("Circle") @property def Ellipses(self) -> List["Drawable"]: """Get all Ellipse objects.""" - return self.get_by_class_name('Ellipse') + return self.get_by_class_name("Ellipse") @property def Functions(self) -> List["Drawable"]: """Get all Function objects.""" - return self.get_by_class_name('Function') + return self.get_by_class_name("Function") @property def PiecewiseFunctions(self) -> List["Drawable"]: """Get all PiecewiseFunction objects.""" - return self.get_by_class_name('PiecewiseFunction') + return self.get_by_class_name("PiecewiseFunction") @property def ParametricFunctions(self) -> List["Drawable"]: """Get all ParametricFunction objects.""" - return self.get_by_class_name('ParametricFunction') + return self.get_by_class_name("ParametricFunction") @property def Labels(self) -> List["Drawable"]: """Get all Label objects.""" - return self.get_by_class_name('Label') + return self.get_by_class_name("Label") @property def ColoredAreas(self) -> List["Drawable"]: """Get all ColoredArea objects.""" - return self.get_by_class_name('ColoredArea') + return self.get_by_class_name("ColoredArea") @property def FunctionsBoundedColoredAreas(self) -> List["Drawable"]: """Get all FunctionsBoundedColoredArea objects.""" - return self.get_by_class_name('FunctionsBoundedColoredArea') + return self.get_by_class_name("FunctionsBoundedColoredArea") @property def Angles(self) -> List["Drawable"]: """Get all Angle objects.""" - return self.get_by_class_name('Angle') + return self.get_by_class_name("Angle") @property def CircleArcs(self) -> List["Drawable"]: """Get all CircleArc objects.""" - return self.get_by_class_name('CircleArc') + return self.get_by_class_name("CircleArc") @property def SegmentsBoundedColoredAreas(self) -> List["Drawable"]: """Get all SegmentsBoundedColoredArea objects.""" - return self.get_by_class_name('SegmentsBoundedColoredArea') + return self.get_by_class_name("SegmentsBoundedColoredArea") @property def FunctionSegmentBoundedColoredAreas(self) -> List["Drawable"]: """Get all FunctionSegmentBoundedColoredArea objects.""" - return self.get_by_class_name('FunctionSegmentBoundedColoredArea') + return self.get_by_class_name("FunctionSegmentBoundedColoredArea") @property def ClosedShapeColoredAreas(self) -> List["Drawable"]: """Get all ClosedShapeColoredArea objects.""" - return self.get_by_class_name('ClosedShapeColoredArea') + return self.get_by_class_name("ClosedShapeColoredArea") # Direct dictionary-like access def __getitem__(self, key: str) -> List["Drawable"]: diff --git a/static/client/managers/edit_policy.py b/static/client/managers/edit_policy.py index 9476689a..2bef9ee8 100644 --- a/static/client/managers/edit_policy.py +++ b/static/client/managers/edit_policy.py @@ -382,5 +382,3 @@ def get_drawable_edit_policy(drawable_type: str) -> Optional[DrawableEditPolicy] """Lookup the policy for a drawable type.""" return DRAWABLE_EDIT_POLICIES.get(drawable_type) - - diff --git a/static/client/managers/ellipse_manager.py b/static/client/managers/ellipse_manager.py index c0181463..eed4f8a6 100644 --- a/static/client/managers/ellipse_manager.py +++ b/static/client/managers/ellipse_manager.py @@ -48,6 +48,7 @@ from managers.point_manager import PointManager from name_generator.drawable import DrawableNameGenerator + class EllipseManager: """ Manages ellipse drawables for a Canvas. @@ -103,10 +104,12 @@ def get_ellipse(self, center_x: float, center_y: float, radius_x: float, radius_ """ ellipses = self.drawables.Ellipses for ellipse in ellipses: - if (ellipse.center.x == center_x and - ellipse.center.y == center_y and - ellipse.radius_x == radius_x and - ellipse.radius_y == radius_y): + if ( + ellipse.center.x == center_x + and ellipse.center.y == center_y + and ellipse.radius_x == radius_x + and ellipse.radius_y == radius_y + ): return ellipse return None @@ -227,9 +230,7 @@ def delete_ellipse(self, name: str) -> bool: pass # Remove from drawables - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, ellipse - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, ellipse) # Redraw if self.canvas.draw_enabled: @@ -334,7 +335,9 @@ def _enforce_ellipse_rules( ) -> None: if any(rule.requires_solitary for rule in rules.values()): if not self._is_ellipse_solitary(ellipse): - raise ValueError(f"Ellipse '{ellipse.name}' is referenced by other drawables and cannot be edited in place.") + raise ValueError( + f"Ellipse '{ellipse.name}' is referenced by other drawables and cannot be edited in place." + ) if "center" in pending_fields: if not self._is_center_point_exclusive(ellipse): diff --git a/static/client/managers/function_manager.py b/static/client/managers/function_manager.py index f959b54e..902f1f55 100644 --- a/static/client/managers/function_manager.py +++ b/static/client/managers/function_manager.py @@ -57,6 +57,7 @@ from managers.drawable_manager_proxy import DrawableManagerProxy from name_generator.drawable import DrawableNameGenerator + class FunctionManager: """Manages function drawables for a Canvas with mathematical expression support.""" @@ -143,7 +144,9 @@ def draw_function( # If it exists, update its expression try: existing_function.function_string = ExpressionValidator.fix_math_expression(function_string) - existing_function._base_function = ExpressionValidator.parse_function_string(function_string, use_mathjs=False) + existing_function._base_function = ExpressionValidator.parse_function_string( + function_string, use_mathjs=False + ) except Exception as e: raise ValueError(f"Failed to parse function string '{function_string}': {str(e)}") # Update the bounds @@ -210,9 +213,7 @@ def delete_function(self, name: str) -> bool: self.canvas.undo_redo_manager.archive() # Remove the function - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, function - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, function) # Also delete any colored areas associated with this function self.canvas.drawable_manager.delete_colored_areas_for_function(function, archive=False) diff --git a/static/client/managers/graph_manager.py b/static/client/managers/graph_manager.py index b2ed3815..616dae0f 100644 --- a/static/client/managers/graph_manager.py +++ b/static/client/managers/graph_manager.py @@ -245,9 +245,7 @@ def delete_graph(self, name: str) -> bool: if isinstance(existing, DirectedGraph): vectors: List["Vector"] = list(existing.vectors) for vector in vectors: - self.vector_manager.delete_vector( - vector.origin.x, vector.origin.y, vector.tip.x, vector.tip.y - ) + self.vector_manager.delete_vector(vector.origin.x, vector.origin.y, vector.tip.x, vector.tip.y) point_names.add(vector.origin.name) point_names.add(vector.tip.name) else: @@ -272,9 +270,7 @@ def delete_graph(self, name: str) -> bool: for v_name in point_names: self.point_manager.delete_point_by_name(v_name) - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, existing - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, existing) if removed and self.canvas.draw_enabled: self.canvas.draw() return bool(removed) diff --git a/static/client/managers/label_manager.py b/static/client/managers/label_manager.py index 9481fa80..c281a826 100644 --- a/static/client/managers/label_manager.py +++ b/static/client/managers/label_manager.py @@ -247,4 +247,3 @@ def _normalize_rotation( if not math.isfinite(numeric): raise ValueError("Label rotation must be a finite number.") return numeric - diff --git a/static/client/managers/parametric_function_manager.py b/static/client/managers/parametric_function_manager.py index 2fd405c7..691d56e8 100644 --- a/static/client/managers/parametric_function_manager.py +++ b/static/client/managers/parametric_function_manager.py @@ -164,9 +164,7 @@ def delete_parametric_function(self, name: str) -> bool: self._archive_for_undo() # Remove from container - result = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, func - ) + result = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, func) # Trigger render if enabled if result and getattr(self.canvas, "draw_enabled", False): diff --git a/static/client/managers/piecewise_function_manager.py b/static/client/managers/piecewise_function_manager.py index 983081ff..d7f3ed9a 100644 --- a/static/client/managers/piecewise_function_manager.py +++ b/static/client/managers/piecewise_function_manager.py @@ -62,7 +62,9 @@ def __init__( self.name_generator: "DrawableNameGenerator" = name_generator self.dependency_manager: "DrawableDependencyManager" = dependency_manager self.drawable_manager: "DrawableManagerProxy" = drawable_manager_proxy - self.piecewise_function_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy("PiecewiseFunction") + self.piecewise_function_edit_policy: Optional[DrawableEditPolicy] = get_drawable_edit_policy( + "PiecewiseFunction" + ) def get_piecewise_function(self, name: str) -> Optional[PiecewiseFunction]: """ @@ -144,9 +146,7 @@ def _validate_pieces(self, pieces: List[Dict[str, Any]]) -> None: if left is not None and right is not None: if left >= right: - raise ValueError( - f"Piece {i + 1}: left bound ({left}) must be less than right bound ({right})" - ) + raise ValueError(f"Piece {i + 1}: left bound ({left}) must be less than right bound ({right})") def delete_piecewise_function(self, name: str) -> bool: """ @@ -167,9 +167,7 @@ def delete_piecewise_function(self, name: str) -> bool: self.canvas.undo_redo_manager.archive() - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, pf - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, pf) if self.canvas.draw_enabled: self.canvas.draw() diff --git a/static/client/managers/point_manager.py b/static/client/managers/point_manager.py index 0d75ec87..3ba28336 100644 --- a/static/client/managers/point_manager.py +++ b/static/client/managers/point_manager.py @@ -50,6 +50,7 @@ from managers.drawable_manager_proxy import DrawableManagerProxy from name_generator.drawable import DrawableNameGenerator + class PointManager: """ Manages point drawables for a Canvas. @@ -202,9 +203,7 @@ def delete_point(self, x: float, y: float) -> bool: self._delete_point_dependencies(x, y) # Now remove the point itself - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, point - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, point) # Redraw the canvas if self.canvas.draw_enabled: @@ -251,18 +250,22 @@ def _delete_point_dependencies(self, x: float, y: float) -> None: # Get all children (including angles and circle arcs) that depend on this point dependent_children = self.dependency_manager.get_children(point_to_delete) for child in cast(List["Drawable"], list(dependent_children)): - if hasattr(child, 'get_class_name'): + if hasattr(child, "get_class_name"): class_name = child.get_class_name() else: class_name = child.__class__.__name__ - if class_name == 'Angle': - print(f"PointManager: Point at ({x}, {y}) is being deleted. Removing dependent angle '{child.name}'.") - if hasattr(self.drawable_manager, 'angle_manager') and self.drawable_manager.angle_manager: + if class_name == "Angle": + print( + f"PointManager: Point at ({x}, {y}) is being deleted. Removing dependent angle '{child.name}'." + ) + if hasattr(self.drawable_manager, "angle_manager") and self.drawable_manager.angle_manager: self.drawable_manager.angle_manager.delete_angle(child.name) - if class_name == 'CircleArc': - print(f"PointManager: Point at ({x}, {y}) is being deleted. Removing dependent circle arc '{child.name}'.") - if hasattr(self.drawable_manager, 'arc_manager') and self.drawable_manager.arc_manager: + if class_name == "CircleArc": + print( + f"PointManager: Point at ({x}, {y}) is being deleted. Removing dependent circle arc '{child.name}'." + ) + if hasattr(self.drawable_manager, "arc_manager") and self.drawable_manager.arc_manager: self.drawable_manager.arc_manager.delete_circle_arc(child.name) # Delete all polygons that contain the point @@ -270,16 +273,17 @@ def _delete_point_dependencies(self, x: float, y: float) -> None: polygon_segments = get_polygon_segments(polygon) if any(MathUtils.segment_has_end_point(seg, x, y) for seg in polygon_segments if seg is not None): polygon_name = getattr(polygon, "name", "") - if polygon_name and hasattr(self.drawable_manager, "delete_region_expression_colored_areas_referencing_name"): + if polygon_name and hasattr( + self.drawable_manager, "delete_region_expression_colored_areas_referencing_name" + ): try: self.drawable_manager.delete_region_expression_colored_areas_referencing_name( - polygon_name, archive=False, + polygon_name, + archive=False, ) except Exception: pass - remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, polygon - ) + remove_drawable_with_dependencies(self.drawables, self.dependency_manager, polygon) # Collect all segments that contain the point segments_to_delete: List[Segment] = [] @@ -305,8 +309,7 @@ def _delete_point_dependencies(self, x: float, y: float) -> None: p2x = segment.point2.x p2y = segment.point2.y # Use the proxy to call delete_segment - self.drawable_manager.delete_segment(p1x, p1y, p2x, p2y, - delete_children=True, delete_parents=False) + self.drawable_manager.delete_segment(p1x, p1y, p2x, p2y, delete_children=True, delete_parents=False) # Delete the vectors that contain the point vectors = self.drawables.Vectors @@ -361,9 +364,7 @@ def _can_bypass_solitary_rules( def _is_point_only_circle_center(self, point: Point) -> bool: circles = [ - circle - for circle in getattr(self.drawables, "Circles", []) - if getattr(circle, "center", None) is point + circle for circle in getattr(self.drawables, "Circles", []) if getattr(circle, "center", None) is point ] if not circles: return False @@ -472,17 +473,13 @@ def _point_is_locked_center(self, point: Point) -> bool: return True return False - def _compute_updated_name( - self, original_point: Point, pending_fields: Dict[str, str] - ) -> Optional[str]: + def _compute_updated_name(self, original_point: Point, pending_fields: Dict[str, str]) -> Optional[str]: if "name" not in pending_fields: return None candidate = pending_fields["name"] filtered_candidate: str = str( - self.name_generator.filter_string(candidate) - if hasattr(self.name_generator, "filter_string") - else candidate + self.name_generator.filter_string(candidate) if hasattr(self.name_generator, "filter_string") else candidate ) filtered_candidate = filtered_candidate.strip() if not filtered_candidate: @@ -512,8 +509,6 @@ def _compute_updated_coordinates( return (x_val, y_val) - def _validate_color_request( - self, pending_fields: Dict[str, str], new_color: Optional[str] - ) -> None: + def _validate_color_request(self, pending_fields: Dict[str, str], new_color: Optional[str]) -> None: if "color" in pending_fields and (new_color is None or not str(new_color).strip()): raise ValueError("Point color cannot be empty.") diff --git a/static/client/managers/polygon_manager.py b/static/client/managers/polygon_manager.py index 84da5a3e..0a7cc340 100644 --- a/static/client/managers/polygon_manager.py +++ b/static/client/managers/polygon_manager.py @@ -253,9 +253,7 @@ def delete_polygon( except Exception: pass - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, target - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, target) if removed: for segment in self._iter_polygon_segments(target): self.segment_manager.delete_segment( diff --git a/static/client/managers/polygon_type.py b/static/client/managers/polygon_type.py index c8626c4d..5d63c93e 100644 --- a/static/client/managers/polygon_type.py +++ b/static/client/managers/polygon_type.py @@ -34,4 +34,3 @@ def coerce(cls, value: str) -> "PolygonType": return cls(value.strip().lower()) except Exception as exc: raise ValueError(f"Unsupported polygon type '{value}'.") from exc - diff --git a/static/client/managers/segment_manager.py b/static/client/managers/segment_manager.py index 7346c5a4..f2eff33d 100644 --- a/static/client/managers/segment_manager.py +++ b/static/client/managers/segment_manager.py @@ -53,6 +53,7 @@ from managers.point_manager import PointManager from name_generator.drawable import DrawableNameGenerator + class SegmentManager: """ Manages segment drawables for a Canvas. @@ -267,7 +268,9 @@ def create_segment_from_points( label_visibility = bool(label_visible) if label_visible is not None else False if color_value: - segment = Segment(p1, p2, color=color_value, label_text=sanitized_label_text, label_visible=label_visibility) + segment = Segment( + p1, p2, color=color_value, label_text=sanitized_label_text, label_visible=label_visibility + ) else: segment = Segment(p1, p2, label_text=sanitized_label_text, label_visible=label_visibility) @@ -279,7 +282,9 @@ def create_segment_from_points( return segment - def delete_segment(self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False) -> bool: + def delete_segment( + self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False + ) -> bool: """ Delete a segment by its endpoint coordinates @@ -312,22 +317,22 @@ def delete_segment(self, x1: float, y1: float, x2: float, y2: float, delete_chil # Get all children (including angles) that depend on this segment dependent_children = self.dependency_manager.get_children(segment) for child in cast(List["Drawable"], list(dependent_children)): - if hasattr(child, 'get_class_name') and child.get_class_name() == 'Angle': - print(f"SegmentManager: Segment '{segment.name}' is being deleted. Removing dependent angle '{child.name}'.") - if hasattr(self.drawable_manager, 'angle_manager') and self.drawable_manager.angle_manager: + if hasattr(child, "get_class_name") and child.get_class_name() == "Angle": + print( + f"SegmentManager: Segment '{segment.name}' is being deleted. Removing dependent angle '{child.name}'." + ) + if hasattr(self.drawable_manager, "angle_manager") and self.drawable_manager.angle_manager: self.drawable_manager.angle_manager.delete_angle(child.name) # Also notify AngleManager if a segment is about to be removed (for backward compatibility) - if hasattr(self.drawable_manager, 'angle_manager') and self.drawable_manager.angle_manager: + if hasattr(self.drawable_manager, "angle_manager") and self.drawable_manager.angle_manager: self.drawable_manager.angle_manager.handle_segment_removed(segment.name) # Handle dependencies by calling the internal method self._delete_segment_dependencies(x1, y1, x2, y2, delete_children, delete_parents) # Now remove the segment itself - remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, segment - ) + remove_drawable_with_dependencies(self.drawables, self.dependency_manager, segment) # Redraw if self.canvas.draw_enabled: @@ -458,8 +463,9 @@ def _normalize_label_visibility( return None return bool(new_label_visible) - - def _delete_segment_dependencies(self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False) -> None: + def _delete_segment_dependencies( + self, x1: float, y1: float, x2: float, y2: float, delete_children: bool = True, delete_parents: bool = False + ) -> None: """ Delete all geometric objects that depend on the specified segment. @@ -481,24 +487,26 @@ def _delete_segment_dependencies(self, x1: float, y1: float, x2: float, y2: floa if delete_children: children = self.dependency_manager.get_all_children(segment) # print(f"Handling deletion of {len(children)} children for segment {segment.name}") # Keep commented unless debugging - for child in cast(List["Drawable"], list(children)): # Iterate over a copy - if hasattr(child, 'point1') and hasattr(child, 'point2'): # Check if child is segment-like + for child in cast(List["Drawable"], list(children)): # Iterate over a copy + if hasattr(child, "point1") and hasattr(child, "point2"): # Check if child is segment-like # Unlink child from the current segment being processed self.dependency_manager.unregister_dependency(child=child, parent=segment) # Check if the child still has any other SEGMENT parents remaining parents_of_child = self.dependency_manager.get_parents(child) has_segment_parent = any( - hasattr(p, 'get_class_name') and p.get_class_name() == 'Segment' - for p in parents_of_child + hasattr(p, "get_class_name") and p.get_class_name() == "Segment" for p in parents_of_child ) # If the child no longer has any segment parents, delete it recursively if not has_segment_parent: self.delete_segment( - child.point1.x, child.point1.y, - child.point2.x, child.point2.y, - delete_children=True, delete_parents=False + child.point1.x, + child.point1.y, + child.point2.x, + child.point2.y, + delete_children=True, + delete_parents=False, ) else: # Handle non-segment children. @@ -519,10 +527,15 @@ def _delete_segment_dependencies(self, x1: float, y1: float, x2: float, y2: floa parents_to_delete = self.dependency_manager.get_all_parents(segment) print(f"Handling deletion of {len(parents_to_delete)} parents for segment {segment.name}") for parent in cast(List["Drawable"], list(parents_to_delete)): - if hasattr(parent, 'point1') and hasattr(parent, 'point2'): - self.delete_segment(parent.point1.x, parent.point1.y, - parent.point2.x, parent.point2.y, - delete_children=True, delete_parents=False) + if hasattr(parent, "point1") and hasattr(parent, "point2"): + self.delete_segment( + parent.point1.x, + parent.point1.y, + parent.point2.x, + parent.point2.y, + delete_children=True, + delete_parents=False, + ) else: print(f"Warning: Parent {parent} of {segment.name} is not a segment, cannot recursively delete.") @@ -535,16 +548,17 @@ def _delete_segment_dependencies(self, x1: float, y1: float, x2: float, y2: floa polygon_segments = get_polygon_segments(polygon) if any(MathUtils.segment_matches_coordinates(s, x1, y1, x2, y2) for s in polygon_segments if s is not None): polygon_name = getattr(polygon, "name", "") - if polygon_name and hasattr(self.drawable_manager, "delete_region_expression_colored_areas_referencing_name"): + if polygon_name and hasattr( + self.drawable_manager, "delete_region_expression_colored_areas_referencing_name" + ): try: self.drawable_manager.delete_region_expression_colored_areas_referencing_name( - polygon_name, archive=False, + polygon_name, + archive=False, ) except Exception: pass - remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, polygon - ) + remove_drawable_with_dependencies(self.drawables, self.dependency_manager, polygon) def _split_segments_with_point(self, x: float, y: float) -> None: """ @@ -577,12 +591,12 @@ def _split_segments_with_point(self, x: float, y: float) -> None: self.dependency_manager.register_dependency(child=segment1, parent=segment) # Propagate dependency to SEGMENT ancestors of the original segment for ancestor in self.dependency_manager.get_all_parents(segment): - if hasattr(ancestor, 'get_class_name') and ancestor.get_class_name() == 'Segment': + if hasattr(ancestor, "get_class_name") and ancestor.get_class_name() == "Segment": self.dependency_manager.register_dependency(child=segment1, parent=ancestor) if segment2: # The original segment that was split is a direct parent of the new segment self.dependency_manager.register_dependency(child=segment2, parent=segment) # Propagate dependency to SEGMENT ancestors of the original segment for ancestor in self.dependency_manager.get_all_parents(segment): - if hasattr(ancestor, 'get_class_name') and ancestor.get_class_name() == 'Segment': + if hasattr(ancestor, "get_class_name") and ancestor.get_class_name() == "Segment": self.dependency_manager.register_dependency(child=segment2, parent=ancestor) diff --git a/static/client/managers/statistics_manager.py b/static/client/managers/statistics_manager.py index 68c6182c..adba6472 100644 --- a/static/client/managers/statistics_manager.py +++ b/static/client/managers/statistics_manager.py @@ -130,9 +130,7 @@ def plot_distribution( plot_right_raw = plot_bounds_dict.get("right_bound") resolved_left, resolved_right = self._resolve_bounds(mean, sigma, plot_left_raw, plot_right_raw) - plot_name = self._generate_unique_name( - self.name_generator.filter_string(name or "") or "normal_plot" - ) + plot_name = self._generate_unique_name(self.name_generator.filter_string(name or "") or "normal_plot") if rep == "discrete": result_payload = self._plot_distribution_discrete( @@ -288,24 +286,22 @@ def plot_bars( if not math.isfinite(y0): raise ValueError("y_base must be finite") - plot_name = self._generate_unique_name( - self.name_generator.filter_string(name or "") or "bars_plot" - ) + plot_name = self._generate_unique_name(self.name_generator.filter_string(name or "") or "bars_plot") plot = BarsPlot( - plot_name, - plot_type="bars", - values=[float(v) for v in values], - labels_below=[str(v) for v in labels_below], - labels_above=None if labels_above is None else [str(v) for v in labels_above], - bar_spacing=spacing, - bar_width=width, - x_start=x0, - y_base=y0, - stroke_color=None if stroke_color is None else str(stroke_color), - fill_color=self._normalize_fill_color(fill_color), - fill_opacity=self._normalize_fill_opacity(fill_opacity), - ) + plot_name, + plot_type="bars", + values=[float(v) for v in values], + labels_below=[str(v) for v in labels_below], + labels_above=None if labels_above is None else [str(v) for v in labels_above], + bar_spacing=spacing, + bar_width=width, + x_start=x0, + y_base=y0, + stroke_color=None if stroke_color is None else str(stroke_color), + fill_color=self._normalize_fill_color(fill_color), + fill_opacity=self._normalize_fill_opacity(fill_opacity), + ) self.drawables.add(plot) self.materialize_bars_plot(plot) @@ -393,8 +389,7 @@ def fit_regression( # Validate model type if model not in SUPPORTED_MODEL_TYPES: raise ValueError( - f"Unsupported model_type '{model_type}'. " - f"Supported: {', '.join(SUPPORTED_MODEL_TYPES)}" + f"Unsupported model_type '{model_type}'. Supported: {', '.join(SUPPORTED_MODEL_TYPES)}" ) # Validate degree for polynomial @@ -469,9 +464,7 @@ def fit_regression( try: logger = getattr(self.canvas, "logger", None) if logger is not None: - logger.debug( - f"Failed to create regression point {point_preferred}: {e}" - ) + logger.debug(f"Failed to create regression point {point_preferred}: {e}") except Exception: pass @@ -889,7 +882,7 @@ def _resolve_bar_count(self, bar_count: Optional[float]) -> int: def _normal_pdf_value(self, x: float, mean: float, sigma: float) -> float: # f(x) = 1/(sigma*sqrt(2*pi)) * exp(-(x-mean)^2/(2*sigma^2)) coeff = 1.0 / (sigma * math.sqrt(2.0 * math.pi)) - exponent = -((x - mean) ** 2) / (2.0 * (sigma ** 2)) + exponent = -((x - mean) ** 2) / (2.0 * (sigma**2)) return coeff * math.exp(exponent) def _delete_continuous_plot(self, plot: Any) -> None: diff --git a/static/client/managers/tangent_manager.py b/static/client/managers/tangent_manager.py index e0390b10..c999e8a4 100644 --- a/static/client/managers/tangent_manager.py +++ b/static/client/managers/tangent_manager.py @@ -298,9 +298,7 @@ def _compute_normal_data( normal_slope = MathUtils.normal_slope(tangent_slope) # Calculate normal endpoints - endpoints = MathUtils.tangent_line_endpoints( - normal_slope, tangent_data["point"], length - ) + endpoints = MathUtils.tangent_line_endpoints(normal_slope, tangent_data["point"], length) return { "point": tangent_data["point"], @@ -354,7 +352,10 @@ def create_tangent_line( # Create the segment using segment manager segment = self.segment_manager.create_segment( - x1, y1, x2, y2, + x1, + y1, + x2, + y2, name=name or "", color=color, extra_graphics=True, @@ -410,7 +411,10 @@ def create_normal_line( # Create the segment using segment manager segment = self.segment_manager.create_segment( - x1, y1, x2, y2, + x1, + y1, + x2, + y2, name=name or "", color=color, extra_graphics=True, diff --git a/static/client/managers/transformations_manager.py b/static/client/managers/transformations_manager.py index 3b67d18f..8cabc1d9 100644 --- a/static/client/managers/transformations_manager.py +++ b/static/client/managers/transformations_manager.py @@ -42,8 +42,16 @@ # Types that do not support geometric transforms via AI tools. _EXCLUDE_TRANSFORM: Tuple[str, ...] = ( - "Function", "ParametricFunction", "PiecewiseFunction", - "Graph", "Angle", "CircleArc", "ColoredArea", "Label", "Bar", "Plot", + "Function", + "ParametricFunction", + "PiecewiseFunction", + "Graph", + "Angle", + "CircleArc", + "ColoredArea", + "Label", + "Bar", + "Plot", ) @@ -117,15 +125,12 @@ def _validate_scale_support(self, drawable: Any, sx: float, sy: float) -> None: uniform = abs(sx - sy) < 1e-9 if cn == "Circle" and not uniform: raise ValueError( - "Non-uniform scaling of a circle is not supported; " - "convert to an ellipse first or use equal sx and sy" + "Non-uniform scaling of a circle is not supported; convert to an ellipse first or use equal sx and sy" ) if cn == "Ellipse" and not uniform: rot = getattr(drawable, "rotation_angle", 0) if (rot % 180) > 1e-9: - raise ValueError( - "Non-uniform scaling of a rotated ellipse is not supported" - ) + raise ValueError("Non-uniform scaling of a rotated ellipse is not supported") def _refresh_dependencies_after_transform( self, diff --git a/static/client/managers/undo_redo_manager.py b/static/client/managers/undo_redo_manager.py index 1fd6f592..5c3b4fe6 100644 --- a/static/client/managers/undo_redo_manager.py +++ b/static/client/managers/undo_redo_manager.py @@ -48,6 +48,7 @@ if TYPE_CHECKING: from canvas import Canvas + class UndoRedoManager: """ Manages undo and redo operations for a Canvas object. @@ -84,8 +85,8 @@ def archive(self) -> None: def capture_state(self) -> Dict[str, Any]: """Capture the current canvas state snapshot.""" return { - 'drawables': copy.deepcopy(self.canvas.drawable_manager.drawables._drawables), - 'computations': copy.deepcopy(self.canvas.computations), + "drawables": copy.deepcopy(self.canvas.drawable_manager.drawables._drawables), + "computations": copy.deepcopy(self.canvas.computations), } def push_undo_state(self, state: Dict[str, Any]) -> None: @@ -95,9 +96,9 @@ def push_undo_state(self, state: Dict[str, Any]) -> None: def restore_state(self, state: Dict[str, Any], redraw: bool = True) -> None: """Restore a captured state snapshot.""" - self.canvas.drawable_manager.drawables._drawables = copy.deepcopy(state['drawables']) + self.canvas.drawable_manager.drawables._drawables = copy.deepcopy(state["drawables"]) self.canvas.drawable_manager.drawables.rebuild_renderables() - self.canvas.computations = copy.deepcopy(state.get('computations', [])) + self.canvas.computations = copy.deepcopy(state.get("computations", [])) self._rebuild_dependency_graph() if redraw: self.canvas.draw() @@ -126,13 +127,13 @@ def undo(self) -> bool: # Archive current state for redo current_state = { - 'drawables': copy.deepcopy(self.canvas.drawable_manager.drawables._drawables), - 'computations': copy.deepcopy(self.canvas.computations) + "drawables": copy.deepcopy(self.canvas.drawable_manager.drawables._drawables), + "computations": copy.deepcopy(self.canvas.computations), } self.redo_stack.append(current_state) # Restore only the drawables from the last state - self.canvas.drawable_manager.drawables._drawables = copy.deepcopy(last_state['drawables']) + self.canvas.drawable_manager.drawables._drawables = copy.deepcopy(last_state["drawables"]) self.canvas.drawable_manager.drawables.rebuild_renderables() # Ensure all objects are properly initialized @@ -159,13 +160,13 @@ def redo(self) -> bool: # Archive current state for undo current_state = { - 'drawables': copy.deepcopy(self.canvas.drawable_manager.drawables._drawables), - 'computations': copy.deepcopy(self.canvas.computations) + "drawables": copy.deepcopy(self.canvas.drawable_manager.drawables._drawables), + "computations": copy.deepcopy(self.canvas.computations), } self.undo_stack.append(current_state) # Restore only the drawables from the next state - self.canvas.drawable_manager.drawables._drawables = copy.deepcopy(next_state['drawables']) + self.canvas.drawable_manager.drawables._drawables = copy.deepcopy(next_state["drawables"]) self.canvas.drawable_manager.drawables.rebuild_renderables() # Ensure all objects are properly initialized @@ -211,7 +212,7 @@ def _rebuild_dependency_graph(self) -> None: # This assumes self.canvas has a 'dependency_manager' attribute # which is an instance of DrawableDependencyManager. - if hasattr(self.canvas, 'dependency_manager') and self.canvas.dependency_manager is not None: + if hasattr(self.canvas, "dependency_manager") and self.canvas.dependency_manager is not None: dependency_manager = self.canvas.dependency_manager # Clear existing dependency relationships from the manager @@ -226,8 +227,10 @@ def _rebuild_dependency_graph(self) -> None: else: # Log a warning if the dependency manager isn't found on the canvas object. # This helps in debugging if the expected structure isn't met. - print("UndoRedoManager: Warning - Canvas instance does not have a 'dependency_manager' " + \ - "attribute or it is None. Skipping dependency graph rebuild for undo/redo operation.") + print( + "UndoRedoManager: Warning - Canvas instance does not have a 'dependency_manager' " + + "attribute or it is None. Skipping dependency graph rebuild for undo/redo operation." + ) def clear(self) -> None: """ diff --git a/static/client/managers/vector_manager.py b/static/client/managers/vector_manager.py index 147af3e7..29fcf0cf 100644 --- a/static/client/managers/vector_manager.py +++ b/static/client/managers/vector_manager.py @@ -52,6 +52,7 @@ from managers.point_manager import PointManager from name_generator.drawable import DrawableNameGenerator + class VectorManager: """ Manages vector drawables for a Canvas. @@ -107,8 +108,9 @@ def get_vector(self, x1: float, y1: float, x2: float, y2: float) -> Optional[Vec """ vectors = self.drawables.Vectors for vector in vectors: - if (MathUtils.point_matches_coordinates(vector.origin, x1, y1) and - MathUtils.point_matches_coordinates(vector.tip, x2, y2)): + if MathUtils.point_matches_coordinates(vector.origin, x1, y1) and MathUtils.point_matches_coordinates( + vector.tip, x2, y2 + ): return vector return None @@ -253,13 +255,14 @@ def delete_vector(self, origin_x: float, origin_y: float, tip_x: float, tip_y: f # Find the vector that matches these coordinates vectors = self.drawables.Vectors for vector in vectors.copy(): - if (MathUtils.point_matches_coordinates(vector.origin, origin_x, origin_y) and - MathUtils.point_matches_coordinates(vector.tip, tip_x, tip_y)): + if MathUtils.point_matches_coordinates( + vector.origin, origin_x, origin_y + ) and MathUtils.point_matches_coordinates(vector.tip, tip_x, tip_y): # Archive before deletion self.canvas.undo_redo_manager.archive() # Remove the vector's segment if it's not used by other objects - if hasattr(vector, 'segment'): + if hasattr(vector, "segment"): segment = vector.segment p1x = segment.point1.x p1y = segment.point1.y @@ -268,9 +271,7 @@ def delete_vector(self, origin_x: float, origin_y: float, tip_x: float, tip_y: f self.canvas.drawable_manager.delete_segment(p1x, p1y, p2x, p2y) # Remove the vector - removed = remove_drawable_with_dependencies( - self.drawables, self.dependency_manager, vector - ) + removed = remove_drawable_with_dependencies(self.drawables, self.dependency_manager, vector) # Redraw if self.canvas.draw_enabled: diff --git a/static/client/markdown_parser.py b/static/client/markdown_parser.py index af7f8244..d5ab5891 100644 --- a/static/client/markdown_parser.py +++ b/static/client/markdown_parser.py @@ -36,7 +36,7 @@ def parse(self, text: str) -> str: except Exception as e: print(f"Error in custom markdown parsing: {e}") # Ultimate fallback - return text.replace('\n', '
') + return text.replace("\n", "
") def _simple_markdown_parse(self, text: str) -> str: """Simple markdown parser for basic formatting using string operations.""" @@ -45,23 +45,23 @@ def _simple_markdown_parse(self, text: str) -> str: text = self._process_tables(text) # Split text into lines for processing - lines = text.split('\n') + lines = text.split("\n") html_lines: list[str] = [] in_code_block = False code_block_content: list[str] = [] for line in lines: # Skip table processing if already processed - if '
' in line or '
' in line or '' in line or '' in line or '' in line: + if "" in line or "
" in line or "" in line or "" in line or "" in line: html_lines.append(line) continue # Handle code blocks - if line.strip().startswith('```'): + if line.strip().startswith("```"): if in_code_block: # End code block - code_content = '\n'.join(code_block_content) - html_lines.append(f'
{code_content}
') + code_content = "\n".join(code_block_content) + html_lines.append(f"
{code_content}
") code_block_content = [] in_code_block = False else: @@ -79,16 +79,16 @@ def _simple_markdown_parse(self, text: str) -> str: heading_match = self._parse_heading(processed_line) if heading_match: level, heading_content = heading_match - processed_line = f'{heading_content}' + processed_line = f"{heading_content}" # Lists - handle ordered and unordered with indentation elif self._is_list_item(processed_line): processed_line = self._process_list_item(processed_line) # Blockquotes - elif processed_line.startswith('> '): - processed_line = f'
{processed_line[2:]}
' + elif processed_line.startswith("> "): + processed_line = f"
{processed_line[2:]}
" # Horizontal rules - elif processed_line.strip() == '---': - processed_line = '
' + elif processed_line.strip() == "---": + processed_line = "
" # Handle inline formatting processed_line = self._process_inline_markdown(processed_line) @@ -107,7 +107,7 @@ def _simple_markdown_parse(self, text: str) -> str: except Exception as e: print(f"Error in simple markdown parsing: {e}") # Ultimate fallback - return text.replace('\n', '
') + return text.replace("\n", "
") def _parse_heading(self, line: str) -> Optional[Tuple[int, str]]: """Parse markdown heading and return (level, content) if matched.""" @@ -129,7 +129,7 @@ def _parse_heading(self, line: str) -> Optional[Tuple[int, str]]: def _process_tables(self, text: str) -> str: """Process markdown tables using proper GFM table parsing algorithm.""" - lines = text.split('\n') + lines = text.split("\n") result_lines = [] i = 0 @@ -137,7 +137,7 @@ def _process_tables(self, text: str) -> str: line = lines[i] # Check if this could be start of a table - must start with pipe - if line.strip().startswith('|') and line.strip(): + if line.strip().startswith("|") and line.strip(): # Look ahead to see if next line is a delimiter if i + 1 < len(lines): next_line = lines[i + 1] @@ -148,7 +148,7 @@ def _process_tables(self, text: str) -> str: # Collect data rows - must also start with pipe while j < len(lines): - if lines[j].strip().startswith('|') and lines[j].strip(): + if lines[j].strip().startswith("|") and lines[j].strip(): table_lines.append(lines[j]) j += 1 else: @@ -164,12 +164,12 @@ def _process_tables(self, text: str) -> str: result_lines.append(line) i += 1 - return '\n'.join(result_lines) + return "\n".join(result_lines) def _build_table_html(self, table_lines: list[str]) -> str: """Build HTML table from table lines.""" if len(table_lines) < 2: - return '\n'.join(table_lines) + return "\n".join(table_lines) header_line = table_lines[0] delimiter_line = table_lines[1] @@ -178,51 +178,51 @@ def _build_table_html(self, table_lines: list[str]) -> str: # Parse header header_cells = self._parse_table_row(header_line) if not header_cells: - return '\n'.join(table_lines) + return "\n".join(table_lines) # Parse alignments alignments = self._parse_alignments(delimiter_line) # Build HTML - no internal newlines to avoid extra spacing - html = '' + html = "
" for i, cell in enumerate(header_cells): - align = alignments[i] if i < len(alignments) else '' - align_attr = f' style="text-align: {align};"' if align else '' - html += f'{cell}' - html += '' + align = alignments[i] if i < len(alignments) else "" + align_attr = f' style="text-align: {align};"' if align else "" + html += f"{cell}" + html += "" # Process data rows if data_lines: - html += '' + html += "" for data_line in data_lines: data_cells = self._parse_table_row(data_line) if data_cells: - html += '' + html += "" for i, cell in enumerate(data_cells): - align = alignments[i] if i < len(alignments) else '' - align_attr = f' style="text-align: {align};"' if align else '' - html += f'{cell}' - html += '' - html += '' + align = alignments[i] if i < len(alignments) else "" + align_attr = f' style="text-align: {align};"' if align else "" + html += f"{cell}" + html += "" + html += "" - html += '
' + html += "" return html def _is_delimiter_row(self, line: str) -> bool: """Check if a line is a valid table delimiter row.""" # Delimiter row must start with | (after whitespace) line = line.strip() - if not line.startswith('|'): + if not line.startswith("|"): return False # Remove leading/trailing pipes - if line.startswith('|'): + if line.startswith("|"): line = line[1:] - if line.endswith('|'): + if line.endswith("|"): line = line[:-1] # Split by pipes and check each cell - cells = [cell.strip() for cell in line.split('|')] + cells = [cell.strip() for cell in line.split("|")] # Must have at least one valid delimiter cell valid_cell_count = 0 @@ -231,7 +231,7 @@ def _is_delimiter_row(self, line: str) -> bool: # Empty cells are allowed but don't count toward validity continue # Must be hyphens with optional colons for alignment - if re.match(r'^:?-+:?$', cell): + if re.match(r"^:?-+:?$", cell): valid_cell_count += 1 else: return False # Invalid delimiter character @@ -245,13 +245,13 @@ def _parse_table_row(self, line: str) -> list[str]: line = line.strip() # Remove leading/trailing pipes if present - if line.startswith('|'): + if line.startswith("|"): line = line[1:] - if line.endswith('|'): + if line.endswith("|"): line = line[:-1] # Split by pipes and process each cell - cells = [cell.strip() for cell in line.split('|')] + cells = [cell.strip() for cell in line.split("|")] # Process inline markdown in each cell processed_cells = [] @@ -261,7 +261,7 @@ def _parse_table_row(self, line: str) -> list[str]: processed_cell = self._process_inline_markdown(cell) processed_cells.append(processed_cell) else: - processed_cells.append('') + processed_cells.append("") return processed_cells @@ -269,31 +269,31 @@ def _parse_alignments(self, delimiter_line: str) -> list[str]: """Parse column alignments from delimiter row.""" # Remove leading/trailing whitespace and optional pipes line = delimiter_line.strip() - if line.startswith('|'): + if line.startswith("|"): line = line[1:] - if line.endswith('|'): + if line.endswith("|"): line = line[:-1] # Split by pipes and determine alignment for each column - cells = [cell.strip() for cell in line.split('|')] + cells = [cell.strip() for cell in line.split("|")] alignments = [] for cell in cells: if not cell: - alignments.append('') + alignments.append("") continue - starts_with_colon = cell.startswith(':') - ends_with_colon = cell.endswith(':') + starts_with_colon = cell.startswith(":") + ends_with_colon = cell.endswith(":") if starts_with_colon and ends_with_colon: - alignments.append('center') + alignments.append("center") elif starts_with_colon: - alignments.append('left') + alignments.append("left") elif ends_with_colon: - alignments.append('right') + alignments.append("right") else: - alignments.append('') + alignments.append("") return alignments @@ -302,15 +302,15 @@ def _is_list_item(self, line: str) -> bool: stripped = line.strip() # Checkbox items - if stripped.startswith('- [x]') or stripped.startswith('- [ ]'): + if stripped.startswith("- [x]") or stripped.startswith("- [ ]"): return True # Unordered list - if stripped.startswith('- ') or stripped.startswith('* '): + if stripped.startswith("- ") or stripped.startswith("* "): return True # Ordered list (number followed by period and space) - if len(stripped) > 2 and stripped[1:3] == '. ': + if len(stripped) > 2 and stripped[1:3] == ". ": try: int(stripped[0]) # Check if first char is number return True @@ -318,7 +318,7 @@ def _is_list_item(self, line: str) -> bool: pass # Handle multi-digit numbers - parts = stripped.split('. ', 1) + parts = stripped.split(". ", 1) if len(parts) == 2: try: int(parts[0]) @@ -336,35 +336,39 @@ def _process_list_item(self, line: str) -> str: stripped = line.strip() # Handle checkboxes - if stripped.startswith('- [x]'): + if stripped.startswith("- [x]"): content = stripped[6:] # Remove '- [x] ' checkbox = '' - return f'
  • {checkbox}{content}
  • ' - elif stripped.startswith('- [ ]'): + return ( + f'
  • {checkbox}{content}
  • ' + ) + elif stripped.startswith("- [ ]"): content = stripped[6:] # Remove '- [ ] ' checkbox = '' - return f'
  • {checkbox}{content}
  • ' + return ( + f'
  • {checkbox}{content}
  • ' + ) # Determine list type and content - if stripped.startswith('- ') or stripped.startswith('* '): + if stripped.startswith("- ") or stripped.startswith("* "): # Unordered list content = stripped[2:] - list_type = 'ul' + list_type = "ul" else: # Ordered list (number followed by period) - parts = stripped.split('. ', 1) + parts = stripped.split(". ", 1) if len(parts) == 2: try: int(parts[0]) content = parts[1] - list_type = 'ol' + list_type = "ol" except: # Fallback content = stripped - list_type = 'ul' + list_type = "ul" else: content = stripped - list_type = 'ul' + list_type = "ul" # Add data attributes to track list type and indent level return f'
  • {content}
  • ' @@ -380,37 +384,37 @@ def _join_lines_with_smart_breaks(self, lines: list[str]) -> str: else: # Empty line - only add break if not between list items or table elements if i > 0 and i < len(lines) - 1: - prev_line = lines[i-1].strip() - next_line = lines[i+1].strip() + prev_line = lines[i - 1].strip() + next_line = lines[i + 1].strip() # Don't add breaks between list items - prev_is_list = '', '', '', '', '']) - next_is_table = any(tag in next_line for tag in ['', '
    ', '', '', '']) + prev_is_table = any(tag in prev_line for tag in ["", "
    ", "", "", ""]) + next_is_table = any(tag in next_line for tag in ["", "
    ", "", "", ""]) if not (prev_is_list and next_is_list) and not (prev_is_table or next_is_table): - result_lines.append('
    ') + result_lines.append("
    ") - return '
    '.join(result_lines) + return "
    ".join(result_lines) except Exception as e: print(f"Error joining lines: {e}") - return '
    '.join(lines) + return "
    ".join(lines) def _wrap_list_items_improved(self, html: str) -> str: """Wrap list items with proper
      /
        tags and handle REAL nesting.""" try: - lines = html.split('
        ') + lines = html.split("
        ") result = [] i = 0 while i < len(lines): line = lines[i].strip() - if '
      1. ' in line: + if "
      2. " in line: # Start processing a list list_items = [] current_index = i @@ -418,17 +422,17 @@ def _wrap_list_items_improved(self, html: str) -> str: # Collect all consecutive list items while current_index < len(lines): current_line = lines[current_index].strip() - if '
      3. ' in current_line: - list_type = self._extract_data_attr(current_line, 'data-list-type') - indent_level = int(self._extract_data_attr(current_line, 'data-indent') or '0') + if "
      4. " in current_line: + list_type = self._extract_data_attr(current_line, "data-list-type") + indent_level = int(self._extract_data_attr(current_line, "data-indent") or "0") # Ensure list_type is not None if list_type is None: - list_type = 'ul' # Default to unordered list + list_type = "ul" # Default to unordered list # Clean the line (remove data attributes) - clean_line = current_line.replace(f' data-list-type="{list_type}"', '') - clean_line = clean_line.replace(f' data-indent="{indent_level}"', '') + clean_line = current_line.replace(f' data-list-type="{list_type}"', "") + clean_line = clean_line.replace(f' data-indent="{indent_level}"', "") list_items.append((clean_line, list_type, indent_level)) current_index += 1 @@ -446,7 +450,7 @@ def _wrap_list_items_improved(self, html: str) -> str: i += 1 # Join with line breaks - return '
        '.join(result) + return "
        ".join(result) except Exception as e: print(f"Error in improved list wrapping: {e}") @@ -465,7 +469,7 @@ def _build_nested_list_html(self, list_items: list[Tuple[str, str, int]]) -> str """Build properly nested HTML from list items.""" try: if not list_items: - return '' + return "" result: list[str] = [] stack: list[Tuple[str, int]] = [] # Stack of (list_type, indent_level) @@ -478,13 +482,13 @@ def _build_nested_list_html(self, list_items: list[Tuple[str, str, int]]) -> str break # Close the deeper list closed_type = stack.pop()[0] - tag = 'ol' if closed_type == 'ol' else 'ul' - result.append(f'') + tag = "ol" if closed_type == "ol" else "ul" + result.append(f"") # Open new list if needed if not stack or stack[-1][1] < indent_level or stack[-1][0] != list_type: - tag = 'ol' if list_type == 'ol' else 'ul' - result.append(f'<{tag}>') + tag = "ol" if list_type == "ol" else "ul" + result.append(f"<{tag}>") stack.append((list_type, indent_level)) # Add the list item @@ -493,61 +497,57 @@ def _build_nested_list_html(self, list_items: list[Tuple[str, str, int]]) -> str # Close all remaining open lists while stack: closed_type = stack.pop()[0] - tag = 'ol' if closed_type == 'ol' else 'ul' - result.append(f'') + tag = "ol" if closed_type == "ol" else "ul" + result.append(f"") - return ''.join(result) + return "".join(result) except Exception as e: print(f"Error building nested list HTML: {e}") - return '' + return "" def _process_inline_markdown(self, text: str) -> str: """Process inline markdown elements like bold, italic, code.""" try: # Bold text (**text** and __text__) - while '**' in text: - start = text.find('**') + while "**" in text: + start = text.find("**") if start == -1: break - end = text.find('**', start + 2) + end = text.find("**", start + 2) if end == -1: break before = text[:start] - content = text[start + 2:end] - after = text[end + 2:] - text = before + f'{content}' + after + content = text[start + 2 : end] + after = text[end + 2 :] + text = before + f"{content}" + after # Process double underscores for bold - restart search after each replacement - while '__' in text: + while "__" in text: found_match = False pos = 0 while pos < len(text): - start = text.find('__', pos) + start = text.find("__", pos) if start == -1: break - end = text.find('__', start + 2) + end = text.find("__", start + 2) if end == -1: break # Check if double underscore is surrounded by proper word boundaries - char_before = text[start - 1] if start > 0 else ' ' - char_after = text[end + 2] if end + 2 < len(text) else ' ' + char_before = text[start - 1] if start > 0 else " " + char_after = text[end + 2] if end + 2 < len(text) else " " # Only apply bold formatting if both double underscores are at proper word boundaries # Must be preceded and followed by space, punctuation, or start/end of text - before_is_boundary = (char_before == ' ' or - char_before in '.,!?:;()[]{}"\'-' or - start == 0) - after_is_boundary = (char_after == ' ' or - char_after in '.,!?:;()[]{}"\'-' or - end + 2 >= len(text)) + before_is_boundary = char_before == " " or char_before in ".,!?:;()[]{}\"'-" or start == 0 + after_is_boundary = char_after == " " or char_after in ".,!?:;()[]{}\"'-" or end + 2 >= len(text) if before_is_boundary and after_is_boundary: before = text[:start] - content = text[start + 2:end] - after = text[end + 2:] - text = before + f'{content}' + after + content = text[start + 2 : end] + after = text[end + 2 :] + text = before + f"{content}" + after found_match = True break # Break inner loop to restart search from beginning else: @@ -559,48 +559,44 @@ def _process_inline_markdown(self, text: str) -> str: break # Italic text (*text* and _text_) - while '*' in text: # Process asterisks regardless of bold tags - start = text.find('*') + while "*" in text: # Process asterisks regardless of bold tags + start = text.find("*") if start == -1: break - end = text.find('*', start + 1) + end = text.find("*", start + 1) if end == -1: break before = text[:start] - content = text[start + 1:end] - after = text[end + 1:] - text = before + f'{content}' + after + content = text[start + 1 : end] + after = text[end + 1 :] + text = before + f"{content}" + after # Enhanced underscore italic processing - only format when surrounded by spaces or at word boundaries - while '_' in text: # Process single underscores regardless of bold tags + while "_" in text: # Process single underscores regardless of bold tags found_match = False pos = 0 while pos < len(text): - start = text.find('_', pos) + start = text.find("_", pos) if start == -1: break - end = text.find('_', start + 1) + end = text.find("_", start + 1) if end == -1: break # Check if underscore is surrounded by proper word boundaries - char_before = text[start - 1] if start > 0 else ' ' - char_after = text[end + 1] if end + 1 < len(text) else ' ' + char_before = text[start - 1] if start > 0 else " " + char_after = text[end + 1] if end + 1 < len(text) else " " # Only apply italic formatting if both underscores are at proper word boundaries # Must be preceded and followed by space, punctuation, or start/end of text - before_is_boundary = (char_before == ' ' or - char_before in '.,!?:;()[]{}"\'-' or - start == 0) - after_is_boundary = (char_after == ' ' or - char_after in '.,!?:;()[]{}"\'-' or - end + 1 >= len(text)) + before_is_boundary = char_before == " " or char_before in ".,!?:;()[]{}\"'-" or start == 0 + after_is_boundary = char_after == " " or char_after in ".,!?:;()[]{}\"'-" or end + 1 >= len(text) if before_is_boundary and after_is_boundary: before = text[:start] - content = text[start + 1:end] - after = text[end + 1:] - text = before + f'{content}' + after + content = text[start + 1 : end] + after = text[end + 1 :] + text = before + f"{content}" + after found_match = True break # Break inner loop to restart search from beginning else: @@ -612,46 +608,46 @@ def _process_inline_markdown(self, text: str) -> str: break # Strikethrough (~~text~~) - while '~~' in text: - start = text.find('~~') + while "~~" in text: + start = text.find("~~") if start == -1: break - end = text.find('~~', start + 2) + end = text.find("~~", start + 2) if end == -1: break before = text[:start] - content = text[start + 2:end] - after = text[end + 2:] - text = before + f'{content}' + after + content = text[start + 2 : end] + after = text[end + 2 :] + text = before + f"{content}" + after # Inline code (`text`) - while '`' in text: - start = text.find('`') + while "`" in text: + start = text.find("`") if start == -1: break - end = text.find('`', start + 1) + end = text.find("`", start + 1) if end == -1: break before = text[:start] - content = text[start + 1:end] - after = text[end + 1:] - text = before + f'{content}' + after + content = text[start + 1 : end] + after = text[end + 1 :] + text = before + f"{content}" + after # Links [text](url) - while '[' in text and '](' in text and ')' in text: - start = text.find('[') + while "[" in text and "](" in text and ")" in text: + start = text.find("[") if start == -1: break - middle = text.find('](', start) + middle = text.find("](", start) if middle == -1: break - end = text.find(')', middle) + end = text.find(")", middle) if end == -1: break before = text[:start] - link_text = text[start + 1:middle] - link_url = text[middle + 2:end] - after = text[end + 1:] + link_text = text[start + 1 : middle] + link_url = text[middle + 2 : end] + after = text[end + 1 :] text = before + f'{link_text}' + after return text @@ -668,13 +664,13 @@ def _process_math_expressions(self, text: str) -> str: block_matches = [] pos = 0 while True: - start = text.find('$$', pos) + start = text.find("$$", pos) if start == -1: break - end = text.find('$$', start + 2) + end = text.find("$$", start + 2) if end == -1: break - block_matches.append((start, end + 2, text[start + 2:end])) + block_matches.append((start, end + 2, text[start + 2 : end])) pos = end + 2 # Replace from end to beginning @@ -686,17 +682,17 @@ def _process_math_expressions(self, text: str) -> str: bracket_matches = [] pos = 0 while True: - start = text.find('\\[', pos) + start = text.find("\\[", pos) if start == -1: break - end = text.find('\\]', start + 2) + end = text.find("\\]", start + 2) if end == -1: break - bracket_matches.append((start, end + 2, text[start + 2:end])) + bracket_matches.append((start, end + 2, text[start + 2 : end])) pos = end + 2 for start, end, content in reversed(bracket_matches): - cleaned = content.replace('
        ', '\n').strip() + cleaned = content.replace("
        ", "\n").strip() replacement = f'
        $${cleaned}$$
        ' text = text[:start] + replacement + text[end:] @@ -704,13 +700,13 @@ def _process_math_expressions(self, text: str) -> str: inline_matches = [] pos = 0 while True: - start = text.find('\\(', pos) + start = text.find("\\(", pos) if start == -1: break - end = text.find('\\)', start + 2) + end = text.find("\\)", start + 2) if end == -1: break - inline_matches.append((start, end + 2, text[start + 2:end])) + inline_matches.append((start, end + 2, text[start + 2 : end])) pos = end + 2 # Replace from end to beginning diff --git a/static/client/name_generator/__init__.py b/static/client/name_generator/__init__.py index b01ded51..32b3304f 100644 --- a/static/client/name_generator/__init__.py +++ b/static/client/name_generator/__init__.py @@ -14,9 +14,9 @@ from .arc import ArcNameGenerator __all__ = [ - 'DrawableNameGenerator', - 'NameGenerator', - 'PointNameGenerator', - 'FunctionNameGenerator', - 'ArcNameGenerator', + "DrawableNameGenerator", + "NameGenerator", + "PointNameGenerator", + "FunctionNameGenerator", + "ArcNameGenerator", ] diff --git a/static/client/name_generator/arc.py b/static/client/name_generator/arc.py index a252a42e..2f92aac7 100644 --- a/static/client/name_generator/arc.py +++ b/static/client/name_generator/arc.py @@ -38,9 +38,16 @@ class ArcNameGenerator(NameGenerator): # Prefixes to strip before extracting point names (longer prefixes first) ARC_PREFIXES = ( - "ArcMajor_", "ArcMinor_", "ArcMaj_", "ArcMin_", - "ArcMajor", "ArcMinor", - "arc_", "Arc_", "arc", "Arc", + "ArcMajor_", + "ArcMinor_", + "ArcMaj_", + "ArcMin_", + "ArcMajor", + "ArcMinor", + "arc_", + "Arc_", + "arc", + "Arc", ) def __init__(self, canvas: Any, point_generator: "PointNameGenerator") -> None: @@ -53,9 +60,7 @@ def __init__(self, canvas: Any, point_generator: "PointNameGenerator") -> None: super().__init__(canvas) self.point_generator = point_generator - def extract_point_names_from_arc_name( - self, arc_name: Optional[str] - ) -> Tuple[Optional[str], Optional[str]]: + def extract_point_names_from_arc_name(self, arc_name: Optional[str]) -> Tuple[Optional[str], Optional[str]]: """Extract suggested point names from an arc name suggestion. First strips known arc prefixes, then uses the point generator's @@ -79,7 +84,7 @@ def extract_point_names_from_arc_name( name = arc_name for prefix in self.ARC_PREFIXES: if name.startswith(prefix): - name = name[len(prefix):] + name = name[len(prefix) :] break if not name: @@ -141,4 +146,3 @@ def _make_unique(self, base_name: str, existing_names: Set[str]) -> str: candidate = f"{base_name}_{suffix}" suffix += 1 return candidate - diff --git a/static/client/name_generator/base.py b/static/client/name_generator/base.py index a46c6f58..060dd94f 100644 --- a/static/client/name_generator/base.py +++ b/static/client/name_generator/base.py @@ -71,4 +71,4 @@ def filter_string(self, name: str) -> str: return "" pattern: str = r"[a-zA-Z0-9_'\(\)]+" matches: List[str] = re.findall(pattern, name) - return ''.join(matches) + return "".join(matches) diff --git a/static/client/name_generator/drawable.py b/static/client/name_generator/drawable.py index b4584461..e2fc7c6e 100644 --- a/static/client/name_generator/drawable.py +++ b/static/client/name_generator/drawable.py @@ -58,7 +58,7 @@ def __init__(self, canvas: Any) -> None: def reset_state(self) -> None: """Reset the state of all specialized name generators.""" self.point_generator.reset_state() - self.function_generator.reset_state() # Assuming FunctionNameGenerator might also have state + self.function_generator.reset_state() # Assuming FunctionNameGenerator might also have state if hasattr(self.label_generator, "reset_state"): self.label_generator.reset_state() @@ -186,9 +186,7 @@ def generate_label_name(self, preferred_name: Optional[str]) -> str: """Generate a unique label name, using preferred_name when provided.""" return str(self.label_generator.generate_label_name(preferred_name)) - def extract_point_names_from_arc_name( - self, arc_name: Optional[str] - ) -> Tuple[Optional[str], Optional[str]]: + def extract_point_names_from_arc_name(self, arc_name: Optional[str]) -> Tuple[Optional[str], Optional[str]]: """Extract suggested point names from an arc name suggestion. Args: @@ -219,9 +217,9 @@ def generate_arc_name( Returns: Unique arc name """ - return str(self.arc_generator.generate_arc_name( - proposed_name, point1_name, point2_name, use_major_arc, existing_names - )) + return str( + self.arc_generator.generate_arc_name(proposed_name, point1_name, point2_name, use_major_arc, existing_names) + ) def _is_valid_point_list(self, points: List[str]) -> bool: """Helper to check if a list of points is valid for angle name generation. @@ -235,8 +233,7 @@ def _is_valid_point_list(self, points: List[str]) -> bool: if not points or not isinstance(points, list) or len(points) != 2: return False # Ensure both point names are non-empty strings - if not (isinstance(points[0], str) and points[0] and \ - isinstance(points[1], str) and points[1]): + if not (isinstance(points[0], str) and points[0] and isinstance(points[1], str) and points[1]): return False return True @@ -257,11 +254,11 @@ def generate_angle_name_from_segments(self, segment1_name: str, segment2_name: s # Temporarily reset next_index for these specific segment names if they were parsed before # This ensures a fresh parse by split_point_names for this method's context - if hasattr(self.point_generator, 'used_letters_from_names'): # Check if attribute exists + if hasattr(self.point_generator, "used_letters_from_names"): # Check if attribute exists if segment1_name in self.point_generator.used_letters_from_names: - self.point_generator.used_letters_from_names[segment1_name]['next_index'] = 0 + self.point_generator.used_letters_from_names[segment1_name]["next_index"] = 0 if segment2_name in self.point_generator.used_letters_from_names: - self.point_generator.used_letters_from_names[segment2_name]['next_index'] = 0 + self.point_generator.used_letters_from_names[segment2_name]["next_index"] = 0 s1_points: List[str] = self.point_generator.split_point_names(segment1_name) s2_points: List[str] = self.point_generator.split_point_names(segment2_name) @@ -287,7 +284,9 @@ def generate_angle_name_from_segments(self, segment1_name: str, segment2_name: s # Identify and sort the two arm points (excluding the vertex) # Ensure elements are strings before comparison with vertex_name if there's any doubt - arm_point_candidates: List[str] = [str(p) for p in all_unique_points if p is not None and str(p) != str(vertex_name)] + arm_point_candidates: List[str] = [ + str(p) for p in all_unique_points if p is not None and str(p) != str(vertex_name) + ] arm_point_names: List[str] = sorted(arm_point_candidates) if len(arm_point_names) != 2: diff --git a/static/client/name_generator/function.py b/static/client/name_generator/function.py index a6814ffb..acc43eaa 100644 --- a/static/client/name_generator/function.py +++ b/static/client/name_generator/function.py @@ -44,10 +44,10 @@ def _extract_number_suffix(self, func_name: str) -> Tuple[str, Optional[int]]: Returns: tuple: (prefix, number) where number is None if no suffix found """ - match: Optional[re.Match[str]] = re.search(r'(?<=\w)(\d+)$', func_name) + match: Optional[re.Match[str]] = re.search(r"(?<=\w)(\d+)$", func_name) if match: number: int = int(match.group()) - prefix: str = func_name[:match.start()] + prefix: str = func_name[: match.start()] return prefix, number return func_name, None @@ -69,7 +69,7 @@ def _increment_function_name(self, func_name: str) -> str: if number is not None: return prefix + str(number + 1) else: - return func_name + '1' + return func_name + "1" def _try_function_name(self, letter: str, number: int, existing_names: List[str]) -> Optional[str]: """Try a function name with the given letter and number. @@ -96,8 +96,8 @@ def _generate_unique_function_name(self) -> str: Raises: ValueError: If all function names are somehow taken (highly unlikely) """ - func_alphabet: str = 'fghijklmnopqrstuvwxyzabcde' - function_names: List[str] = self.get_drawable_names('Function') + func_alphabet: str = "fghijklmnopqrstuvwxyzabcde" + function_names: List[str] = self.get_drawable_names("Function") for number in count(): for letter in func_alphabet: @@ -116,9 +116,9 @@ def _extract_function_name_before_parenthesis(self, preferred_name: str) -> str: Returns: str: Function name without parentheses and arguments """ - match: Optional[re.Match[str]] = re.search(r'(?<=\w)(?=\()', preferred_name) + match: Optional[re.Match[str]] = re.search(r"(?<=\w)(?=\()", preferred_name) if match: - return preferred_name[:match.start()] + return preferred_name[: match.start()] return preferred_name def _find_available_function_name(self, preferred_name: str, function_names: List[str]) -> str: @@ -153,7 +153,7 @@ def generate_function_name(self, preferred_name: Optional[str]) -> str: if not preferred_name: return self._generate_unique_function_name() - function_names: List[str] = self.get_drawable_names('Function') + function_names: List[str] = self.get_drawable_names("Function") # Extract name before parenthesis if present clean_name: str = self._extract_function_name_before_parenthesis(preferred_name) @@ -161,9 +161,7 @@ def generate_function_name(self, preferred_name: Optional[str]) -> str: # Find an available function name return self._find_available_function_name(clean_name, function_names) - def _try_parametric_function_name( - self, letter: str, number: int, existing_names: List[str] - ) -> Optional[str]: + def _try_parametric_function_name(self, letter: str, number: int, existing_names: List[str]) -> Optional[str]: """Try a parametric function name with the given letter and number. Args: @@ -188,14 +186,12 @@ def _generate_unique_parametric_function_name(self) -> str: Returns: Unique parametric function name """ - func_alphabet: str = 'fghijklmnopqrstuvwxyzabcde' - parametric_names: List[str] = self.get_drawable_names('ParametricFunction') + func_alphabet: str = "fghijklmnopqrstuvwxyzabcde" + parametric_names: List[str] = self.get_drawable_names("ParametricFunction") for number in count(): for letter in func_alphabet: - name: Optional[str] = self._try_parametric_function_name( - letter, number, parametric_names - ) + name: Optional[str] = self._try_parametric_function_name(letter, number, parametric_names) if name: return name @@ -218,7 +214,7 @@ def generate_parametric_function_name(self, preferred_name: Optional[str]) -> st return self._generate_unique_parametric_function_name() # If a preferred name is provided, check if it's available - parametric_names: List[str] = self.get_drawable_names('ParametricFunction') + parametric_names: List[str] = self.get_drawable_names("ParametricFunction") if preferred_name not in parametric_names: return preferred_name diff --git a/static/client/name_generator/label.py b/static/client/name_generator/label.py index 79ce8da2..7680202f 100644 --- a/static/client/name_generator/label.py +++ b/static/client/name_generator/label.py @@ -49,4 +49,3 @@ def _normalize_preferred_name(self, preferred_name: Optional[str]) -> str: return "" filtered = self.filter_string(trimmed) return filtered or trimmed - diff --git a/static/client/name_generator/point.py b/static/client/name_generator/point.py index 75ad76bf..b5a5bfb4 100644 --- a/static/client/name_generator/point.py +++ b/static/client/name_generator/point.py @@ -57,10 +57,10 @@ def _init_tracking_for_expression(self, expression: str) -> Dict[str, Any]: dict: Tracking data for the expression """ if expression not in self.used_letters_from_names: - matches: List[str] = re.findall(r'[A-Z][\']*', expression) + matches: List[str] = re.findall(r"[A-Z][\']*", expression) self.used_letters_from_names[expression] = { - 'letters': list(dict.fromkeys(matches)), # All letters - 'next_index': 0 # Next unused letter index + "letters": list(dict.fromkeys(matches)), # All letters + "next_index": 0, # Next unused letter index } return self.used_letters_from_names[expression] @@ -74,18 +74,18 @@ def _get_next_letters(self, name_data: Dict[str, Any], n: int) -> List[str]: Returns: list: List of the next n letters """ - available_letters: List[str] = name_data['letters'] - start_index: int = name_data['next_index'] + available_letters: List[str] = name_data["letters"] + start_index: int = name_data["next_index"] result: List[str] = [] for i in range(n): if start_index + i < len(available_letters): result.append(available_letters[start_index + i]) else: - result.append('') + result.append("") # Update the next index - name_data['next_index'] = min(start_index + n, len(available_letters)) + name_data["next_index"] = min(start_index + n, len(available_letters)) return result def split_point_names(self, expression: Optional[str], n: int = 2) -> List[str]: @@ -99,7 +99,7 @@ def split_point_names(self, expression: Optional[str], n: int = 2) -> List[str]: list: List of individual point names """ if expression is None or len(expression) < 1: - return [''] * n + return [""] * n expression = self.filter_string(expression) expression = expression.upper() @@ -116,7 +116,7 @@ def _generate_unique_point_name(self) -> str: Returns: str: Unique point name following alphabetical progression """ - point_names: List[str] = self.get_drawable_names('Point') + point_names: List[str] = self.get_drawable_names("Point") return self._find_available_name_from_alphabet(ALPHABET, point_names) @@ -148,10 +148,10 @@ def _init_tracking_for_preferred_name(self, preferred_name: str) -> Dict[str, An dict: Tracking data for the preferred name """ if preferred_name not in self.used_letters_from_names: - matches: List[str] = re.findall(r'[A-Z][\']*', preferred_name) + matches: List[str] = re.findall(r"[A-Z][\']*", preferred_name) self.used_letters_from_names[preferred_name] = { - 'letters': list(dict.fromkeys(matches)), # All available letters with their apostrophes - 'next_index': 0 # Next unused letter index + "letters": list(dict.fromkeys(matches)), # All available letters with their apostrophes + "next_index": 0, # Next unused letter index } return self.used_letters_from_names[preferred_name] @@ -175,7 +175,9 @@ def _find_available_name_from_preferred(self, letter_with_apostrophes: str, poin result: Optional[str] = self._try_add_apostrophes(base_letter, point_names) return result if result is not None else base_letter - def _try_add_apostrophes(self, base_letter: str, point_names: List[str], initial_count: int = 1, max_attempts: int = 5) -> Optional[str]: + def _try_add_apostrophes( + self, base_letter: str, point_names: List[str], initial_count: int = 1, max_attempts: int = 5 + ) -> Optional[str]: """Try adding apostrophes to a base letter until finding an unused name. Args: @@ -213,13 +215,13 @@ def generate_point_name(self, preferred_name: Optional[str]) -> str: # Filter and uppercase the preferred name preferred_name = self.filter_string(preferred_name).upper() - point_names: List[str] = self.get_drawable_names('Point') + point_names: List[str] = self.get_drawable_names("Point") # Initialize tracking for this name name_data: Dict[str, Any] = self._init_tracking_for_preferred_name(preferred_name) - available_letters: List[str] = name_data['letters'] - start_index: int = name_data['next_index'] + available_letters: List[str] = name_data["letters"] + start_index: int = name_data["next_index"] # Try each remaining letter from the preferred name for i in range(start_index, len(available_letters)): @@ -228,7 +230,7 @@ def generate_point_name(self, preferred_name: Optional[str]) -> str: name: str = self._find_available_name_from_preferred(letter_with_apostrophes, point_names) if name: - name_data['next_index'] = i + 1 + name_data["next_index"] = i + 1 return name # If no letters from preferred name are available, generate a unique name diff --git a/static/client/numeric_solver/__init__.py b/static/client/numeric_solver/__init__.py index b6f30a12..9f5d4f09 100644 --- a/static/client/numeric_solver/__init__.py +++ b/static/client/numeric_solver/__init__.py @@ -7,4 +7,4 @@ from .solver import solve_numeric -__all__ = ['solve_numeric'] +__all__ = ["solve_numeric"] diff --git a/static/client/numeric_solver/expression_utils.py b/static/client/numeric_solver/expression_utils.py index 0a35f906..89539834 100644 --- a/static/client/numeric_solver/expression_utils.py +++ b/static/client/numeric_solver/expression_utils.py @@ -13,14 +13,41 @@ # Math function names to exclude from variable detection -MATH_FUNCTIONS = frozenset({ - 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2', - 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', - 'log', 'ln', 'log10', 'log2', 'exp', - 'sqrt', 'cbrt', 'abs', 'sign', 'floor', 'ceil', 'round', - 'min', 'max', 'mod', 'pow', - 'pi', 'e', -}) +MATH_FUNCTIONS = frozenset( + { + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "atan2", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "log", + "ln", + "log10", + "log2", + "exp", + "sqrt", + "cbrt", + "abs", + "sign", + "floor", + "ceil", + "round", + "min", + "max", + "mod", + "pow", + "pi", + "e", + } +) def detect_variables(equations: Sequence[str]) -> List[str]: @@ -38,7 +65,7 @@ def detect_variables(equations: Sequence[str]) -> List[str]: # Pattern matches single letters that are not part of longer words # Uses negative lookbehind and lookahead to ensure it's a standalone letter - pattern = r'(? str: Returns: Residual expression string. """ - if '=' in equation: - parts = equation.split('=', 1) + if "=" in equation: + parts = equation.split("=", 1) lhs = parts[0].strip() rhs = parts[1].strip() return f"({lhs}) - ({rhs})" @@ -109,4 +136,5 @@ def evaluate_residuals( def _is_finite(value: float) -> bool: """Check if a value is finite (not NaN or infinity).""" import math + return math.isfinite(value) diff --git a/static/client/numeric_solver/solver.py b/static/client/numeric_solver/solver.py index 629153a8..66e8e5c5 100644 --- a/static/client/numeric_solver/solver.py +++ b/static/client/numeric_solver/solver.py @@ -94,8 +94,7 @@ def solve_numeric( if not unique_solutions: result["message"] = ( - "No solutions found in search range [-10, 10]. " - "Try providing initial_guesses closer to expected solutions." + "No solutions found in search range [-10, 10]. Try providing initial_guesses closer to expected solutions." ) return json.dumps(result) @@ -103,9 +102,11 @@ def solve_numeric( def _error_result(variables: List[str], message: str) -> str: """Create an error result JSON string.""" - return json.dumps({ - "solutions": [], - "variables": variables, - "method": "newton_raphson", - "error": message, - }) + return json.dumps( + { + "solutions": [], + "variables": variables, + "method": "newton_raphson", + "error": message, + } + ) diff --git a/static/client/polar_grid.py b/static/client/polar_grid.py index d54f2e31..f27179e3 100644 --- a/static/client/polar_grid.py +++ b/static/client/polar_grid.py @@ -125,13 +125,10 @@ def max_radius_screen(self) -> float: if self.width is None or self.height is None: return 0 ox, oy = self.origin_screen - corners = [ - (0, 0), (self.width, 0), - (0, self.height), (self.width, self.height) - ] + corners = [(0, 0), (self.width, 0), (0, self.height), (self.width, self.height)] max_dist = 0.0 for cx, cy in corners: - dist = math.sqrt((cx - ox)**2 + (cy - oy)**2) + dist = math.sqrt((cx - ox) ** 2 + (cy - oy) ** 2) if dist > max_dist: max_dist = dist return max_dist * 1.1 # Add 10% margin @@ -142,12 +139,12 @@ def max_radius_math(self) -> float: scale = self.coordinate_mapper.scale_factor if scale <= 0: scale = 1.0 - return self.max_radius_screen / scale + return float(self.max_radius_screen / scale) @property def display_spacing(self) -> float: """Get the spacing between circles in screen pixels.""" - return abs(self._current_radial_spacing) * self.coordinate_mapper.scale_factor + return float(abs(self._current_radial_spacing) * self.coordinate_mapper.scale_factor) def reset(self) -> None: """Reset polar grid to initial state.""" @@ -172,7 +169,7 @@ def get_radial_circles(self) -> List[float]: Returns: List of radii for each concentric circle """ - circles = [] + circles: list[float] = [] spacing = self.display_spacing if spacing <= 0: return circles @@ -248,8 +245,8 @@ def _find_appropriate_spacing(self, ideal_spacing: float) -> float: possible_spacings = [magnitude * i for i in [1, 2, 5, 10]] for spacing in possible_spacings: if spacing >= effective_ideal: - return spacing - return possible_spacings[0] + return float(spacing) + return float(possible_spacings[0]) def _invalidate_cache_on_zoom(self) -> None: """Update radial spacing for zoom operations.""" diff --git a/static/client/process_function_calls.py b/static/client/process_function_calls.py index 73de2a14..307221f3 100644 --- a/static/client/process_function_calls.py +++ b/static/client/process_function_calls.py @@ -56,10 +56,7 @@ def evaluate_expression( return ExpressionEvaluator.evaluate_expression(expression, variables, canvas) @staticmethod - def evaluate_linear_algebra_expression( - objects: List[LinearAlgebraObject], - expression: str - ) -> LinearAlgebraResult: + def evaluate_linear_algebra_expression(objects: List[LinearAlgebraObject], expression: str) -> LinearAlgebraResult: """Evaluates a linear algebra expression using predefined objects. Delegates to LinearAlgebraUtils for matrix and vector validation and @@ -121,9 +118,13 @@ def get_results_traced( Returns: Tuple of (results dict, list of traced call records) """ - return ResultProcessor.get_results_traced( - calls, available_functions, undoable_functions, canvas, + result: tuple[dict[str, Any], list[dict[str, Any]]] = ResultProcessor.get_results_traced( + calls, + available_functions, + undoable_functions, + canvas, ) + return result @staticmethod def validate_results(results: Dict[str, Any]) -> bool: diff --git a/static/client/rendering/cached_render_plan.py b/static/client/rendering/cached_render_plan.py index f5e0f9f0..a10361e3 100644 --- a/static/client/rendering/cached_render_plan.py +++ b/static/client/rendering/cached_render_plan.py @@ -265,7 +265,9 @@ def _screen_to_math_point(screen_point: Tuple[float, float], state: MapState) -> return (mx, my) -def _reproject_points(points: Iterable[Tuple[float, float]], old: MapState, new: MapState) -> Tuple[Tuple[float, float], ...]: +def _reproject_points( + points: Iterable[Tuple[float, float]], old: MapState, new: MapState +) -> Tuple[Tuple[float, float], ...]: """Reproject multiple screen points from one map state to another. Args: @@ -309,9 +311,7 @@ def _get_safe_scale(state: MapState, key: str = "scale") -> float: return 1.0 if value <= 0 else value -def _reproject_stroke_line( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_stroke_line(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a stroke_line command to a new map state. Updates the command's start/end points and geometry metadata in place. @@ -323,9 +323,7 @@ def _reproject_stroke_line( command.meta["geometry"] = _quantize_geometry((new_start, new_end)) -def _reproject_stroke_polyline( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_stroke_polyline(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a stroke_polyline command to a new map state. Updates the command's point list and geometry metadata in place. @@ -336,9 +334,7 @@ def _reproject_stroke_polyline( command.meta["geometry"] = _quantize_geometry(new_points) -def _reproject_stroke_circle( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_stroke_circle(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a stroke_circle command to a new map state. Updates center position and radius (unless screen_space flag is set). @@ -351,9 +347,7 @@ def _reproject_stroke_circle( command.meta["geometry"] = _quantize_geometry((new_center, new_radius)) -def _reproject_fill_circle( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_fill_circle(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a fill_circle command to a new map state. Updates center position, radius, and geometry metadata in place. @@ -366,9 +360,7 @@ def _reproject_fill_circle( command.meta["geometry"] = _quantize_geometry((new_center, new_radius)) -def _reproject_stroke_ellipse( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_stroke_ellipse(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a stroke_ellipse command to a new map state. Updates center position, both radii, and geometry metadata in place. @@ -381,9 +373,7 @@ def _reproject_stroke_ellipse( command.meta["geometry"] = _quantize_geometry((new_center, new_rx, new_ry, rotation)) -def _reproject_fill_joined_area( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_fill_joined_area(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a fill_joined_area command to a new map state. Updates both forward and reverse point arrays used for shaded regions. @@ -395,9 +385,7 @@ def _reproject_fill_joined_area( command.meta["geometry"] = _quantize_geometry(new_forward + new_reverse) -def _compute_vector_arrow_points( - vector_meta: Dict[str, Any], new_state: MapState -) -> Tuple[Tuple[float, float], ...]: +def _compute_vector_arrow_points(vector_meta: Dict[str, Any], new_state: MapState) -> Tuple[Tuple[float, float], ...]: """Compute arrow head triangle points for a vector in the new map state. Vector arrows are rendered in screen space with a fixed tip size, so they @@ -436,9 +424,7 @@ def _compute_vector_arrow_points( return (tip, base1, base2) -def _reproject_fill_polygon( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_fill_polygon(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a fill_polygon command to a new map state. Handles both regular polygons and vector arrow heads specially, since @@ -614,9 +600,7 @@ def _reproject_arc_with_circle_meta( ) -def _reproject_stroke_arc( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_stroke_arc(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a stroke_arc command to a new map state. Handles three cases: angle arcs (with angle metadata), circle arcs @@ -788,9 +772,7 @@ def _reproject_text_with_label_meta( return new_position, font -def _reproject_draw_text( - command: PrimitiveCommand, old_state: MapState, new_state: MapState -) -> None: +def _reproject_draw_text(command: PrimitiveCommand, old_state: MapState, new_state: MapState) -> None: """Reproject a draw_text command to a new map state. Handles angle labels, point labels, standalone labels, and plain text @@ -1277,7 +1259,9 @@ def _pool_styles(self, value: Any) -> Any: return {key: self._pool_styles(item) for key, item in value.items()} return value - def _record(self, op: str, args: PrimitiveArgs, kwargs: PrimitiveKwargs, *, style: Any = None, geometry: Iterable[Any] = ()) -> None: + def _record( + self, op: str, args: PrimitiveArgs, kwargs: PrimitiveKwargs, *, style: Any = None, geometry: Iterable[Any] = () + ) -> None: """Record a primitive operation as a command.""" command_key = f"{self._drawable_key}:{op}:{self._counter}" self._counter += 1 @@ -1330,7 +1314,9 @@ def uses_screen_space(self) -> bool: return self._screen_space_used def stroke_line(self, start, end, stroke, *, include_width=True): - self._record("stroke_line", (start, end, stroke), {"include_width": include_width}, style=stroke, geometry=(start, end)) + self._record( + "stroke_line", (start, end, stroke), {"include_width": include_width}, style=stroke, geometry=(start, end) + ) def stroke_polyline(self, points, stroke): self._record("stroke_polyline", (tuple(points), stroke), {}, style=stroke, geometry=points) @@ -1348,7 +1334,13 @@ def fill_circle(self, center, radius, fill, stroke=None, *, screen_space=False): ) def stroke_ellipse(self, center, radius_x, radius_y, rotation_rad, stroke): - self._record("stroke_ellipse", (center, radius_x, radius_y, rotation_rad, stroke), {}, style=stroke, geometry=(center, radius_x, radius_y, rotation_rad)) + self._record( + "stroke_ellipse", + (center, radius_x, radius_y, rotation_rad, stroke), + {}, + style=stroke, + geometry=(center, radius_x, radius_y, rotation_rad), + ) def fill_polygon(self, points, fill, stroke=None, *, screen_space=False, metadata=None): self._record( @@ -1360,7 +1352,13 @@ def fill_polygon(self, points, fill, stroke=None, *, screen_space=False, metadat ) def fill_joined_area(self, forward, reverse, fill): - self._record("fill_joined_area", (tuple(forward), tuple(reverse), fill), {}, style=fill, geometry=list(forward) + list(reverse)) + self._record( + "fill_joined_area", + (tuple(forward), tuple(reverse), fill), + {}, + style=fill, + geometry=list(forward) + list(reverse), + ) def stroke_arc( self, @@ -1581,4 +1579,3 @@ def build_plan_for_polar( ) plan.update_map_state(map_state) return plan - diff --git a/static/client/rendering/canvas2d_primitive_adapter.py b/static/client/rendering/canvas2d_primitive_adapter.py index a40c3efa..699ef2d1 100644 --- a/static/client/rendering/canvas2d_primitive_adapter.py +++ b/static/client/rendering/canvas2d_primitive_adapter.py @@ -319,7 +319,7 @@ def stroke_arc( end_angle_rad: float, sweep_clockwise: bool, stroke: StrokeStyle, - css_class: str = None, + css_class: Optional[str] = None, *, screen_space: bool = False, metadata: Optional[Dict[str, Any]] = None, @@ -601,7 +601,7 @@ def execute_optimized(self, command: Any) -> None: # ------------------------------------------------------------------ def _batch_stroke_line(self, command: Any) -> None: - args = getattr(command, "args", ()) + args: tuple[Any, ...] = getattr(command, "args", ()) kwargs = getattr(command, "kwargs", {}) if len(args) < 3: return @@ -610,7 +610,7 @@ def _batch_stroke_line(self, command: Any) -> None: self._queue_line_segment(start, end, stroke, include_width) def _batch_polyline(self, command: Any) -> None: - args = getattr(command, "args", ()) + args: tuple[Any, ...] = getattr(command, "args", ()) if len(args) < 2: return points, stroke = args[:2] @@ -667,7 +667,7 @@ def _flush_line_batch(self) -> None: # ------------------------------------------------------------------ def _batch_fill_polygon_from_command(self, command: Any) -> None: - args = getattr(command, "args", ()) + args: tuple[Any, ...] = getattr(command, "args", ()) if not args: return if command.op == "fill_polygon": @@ -697,12 +697,9 @@ def _batch_fill_polygon( is_joined_area, ) batch = getattr(self, "_polygon_batch", None) - if ( - batch is None - or batch["signature"] != signature - ): + if batch is None or batch["signature"] != signature: self._flush_polygon_batch() - self._polygon_batch = { + self._polygon_batch: dict[str, Any] | None = { "fill": fill, "stroke": stroke, "polygons": [], @@ -765,4 +762,3 @@ def _flush_polygon_batch(self) -> None: self._record_event("stroke_calls") self._polygon_batch = None self._reset_alpha_if_needed() - diff --git a/static/client/rendering/canvas2d_renderer.py b/static/client/rendering/canvas2d_renderer.py index e24b1d10..978ff7d7 100644 --- a/static/client/rendering/canvas2d_renderer.py +++ b/static/client/rendering/canvas2d_renderer.py @@ -259,9 +259,7 @@ def render_cartesian(self, cartesian: Any, coordinate_mapper: Any) -> None: drawable_name = "Cartesian2Axis" map_state = self._capture_map_state(coordinate_mapper) signature = self._compute_drawable_signature(cartesian, coordinate_mapper) - plan = self._resolve_cartesian_plan( - cartesian, coordinate_mapper, map_state, signature, drawable_name - ) + plan = self._resolve_cartesian_plan(cartesian, coordinate_mapper, map_state, signature, drawable_name) if plan is None: return apply_start = self._telemetry.mark_time() @@ -287,9 +285,7 @@ def render_polar(self, polar_grid: Any, coordinate_mapper: Any) -> None: drawable_name = "PolarGrid" map_state = self._capture_map_state(coordinate_mapper) signature = self._compute_drawable_signature(polar_grid, coordinate_mapper) - plan = self._resolve_polar_plan( - polar_grid, coordinate_mapper, map_state, signature, drawable_name - ) + plan = self._resolve_polar_plan(polar_grid, coordinate_mapper, map_state, signature, drawable_name) if plan is None: return apply_start = self._telemetry.mark_time() @@ -320,9 +316,7 @@ def _resolve_polar_plan( else: if plan_entry is not None: self._drop_plan_group(plan_entry.get("plan")) - plan = self._build_polar_plan_with_metrics( - polar_grid, coordinate_mapper, map_state, drawable_name - ) + plan = self._build_polar_plan_with_metrics(polar_grid, coordinate_mapper, map_state, drawable_name) if plan is None: self._cartesian_cache = None return None @@ -374,81 +368,99 @@ def register_default_drawables(self) -> None: """Register handlers for all standard drawable types.""" try: from drawables.point import Point as PointDrawable + self.register(PointDrawable, self._render_point) except Exception: pass try: from drawables.segment import Segment as SegmentDrawable + self.register(SegmentDrawable, self._render_segment) except Exception: pass try: from drawables.circle import Circle as CircleDrawable + self.register(CircleDrawable, self._render_circle) except Exception: pass try: from drawables.ellipse import Ellipse as EllipseDrawable + self.register(EllipseDrawable, self._render_ellipse) except Exception: pass try: from drawables.circle_arc import CircleArc as CircleArcDrawable + self.register(CircleArcDrawable, self._render_circle_arc) except Exception: pass try: from drawables.vector import Vector as VectorDrawable + self.register(VectorDrawable, self._render_vector) except Exception: pass try: from drawables.angle import Angle as AngleDrawable + self.register(AngleDrawable, self._render_angle) except Exception: pass try: from drawables.function import Function as FunctionDrawable + self.register(FunctionDrawable, self._render_function) except Exception: pass try: from drawables.piecewise_function import PiecewiseFunction as PiecewiseFunctionDrawable + self.register(PiecewiseFunctionDrawable, self._render_function) except Exception: pass try: from drawables.parametric_function import ParametricFunction as ParametricFunctionDrawable + self.register(ParametricFunctionDrawable, self._render_function) except Exception: pass try: from drawables.functions_bounded_colored_area import FunctionsBoundedColoredArea as FunctionsAreaDrawable + self.register(FunctionsAreaDrawable, self._render_functions_bounded_colored_area) except Exception: pass try: - from drawables.function_segment_bounded_colored_area import FunctionSegmentBoundedColoredArea as FunctionSegmentAreaDrawable + from drawables.function_segment_bounded_colored_area import ( + FunctionSegmentBoundedColoredArea as FunctionSegmentAreaDrawable, + ) + self.register(FunctionSegmentAreaDrawable, self._render_function_segment_bounded_colored_area) except Exception: pass try: from drawables.segments_bounded_colored_area import SegmentsBoundedColoredArea as SegmentsAreaDrawable + self.register(SegmentsAreaDrawable, self._render_segments_bounded_colored_area) except Exception: pass try: from drawables.closed_shape_colored_area import ClosedShapeColoredArea as ClosedShapeAreaDrawable + self.register(ClosedShapeAreaDrawable, self._render_closed_shape_colored_area) except Exception: pass try: from drawables.label import Label as LabelDrawable + self.register(LabelDrawable, self._render_label) except Exception: pass try: from drawables.bar import Bar as BarDrawable + self.register(BarDrawable, self._render_drawable) except Exception: pass @@ -549,9 +561,7 @@ def _render_drawable(self, drawable: Any, coordinate_mapper: Any) -> None: map_state = self._capture_map_state(coordinate_mapper) signature = self._compute_drawable_signature(drawable, coordinate_mapper) cache_key = self._plan_cache_key(drawable, drawable_name) - plan = self._resolve_drawable_plan( - drawable, coordinate_mapper, map_state, signature, drawable_name, cache_key - ) + plan = self._resolve_drawable_plan(drawable, coordinate_mapper, map_state, signature, drawable_name, cache_key) if plan is None: return apply_start = self._telemetry.mark_time() @@ -840,9 +850,7 @@ def _resolve_drawable_plan( self._plan_cache.pop(cache_key, None) return None - def _is_cached_plan_valid( - self, cache_entry: Optional[Dict[str, Any]], signature: Optional[Any] - ) -> bool: + def _is_cached_plan_valid(self, cache_entry: Optional[Dict[str, Any]], signature: Optional[Any]) -> bool: return bool( cache_entry and cache_entry.get("signature") == signature @@ -857,5 +865,3 @@ def _mark_screen_space_plan_dirty(self, plan: OptimizedPrimitivePlan) -> None: """Mark a screen-space plan as needing reapplication.""" if getattr(plan, "uses_screen_space", lambda: False)(): plan.mark_dirty() - - diff --git a/static/client/rendering/factory.py b/static/client/rendering/factory.py index 77e37b27..b7dc5651 100644 --- a/static/client/rendering/factory.py +++ b/static/client/rendering/factory.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import Callable, Optional, TypeVar +from typing import Callable, Optional, TypeVar, cast RendererType = TypeVar("RendererType") @@ -35,10 +35,11 @@ def _load_renderer(module_path: str, attr: str) -> Optional[RendererType]: """ try: module = __import__(module_path, fromlist=[attr]) - return getattr(module, attr) + return cast(Optional[RendererType], getattr(module, attr)) except Exception: return None + from rendering.interfaces import RendererProtocol SvgRenderer = _load_renderer("rendering.svg_renderer", "SvgRenderer") @@ -65,9 +66,7 @@ def _build_preference_chain(preferred: Optional[str]) -> list[str]: return chain -def _safe_instantiate( - factory: Callable[[], Optional[RendererProtocol]], *, error_message: str -) -> RendererProtocol: +def _safe_instantiate(factory: Callable[[], Optional[RendererProtocol]], *, error_message: str) -> RendererProtocol: """Instantiate a renderer and raise if it returns None.""" renderer = factory() if renderer is None: diff --git a/static/client/rendering/helpers/__init__.py b/static/client/rendering/helpers/__init__.py index 7f22c867..1b393391 100644 --- a/static/client/rendering/helpers/__init__.py +++ b/static/client/rendering/helpers/__init__.py @@ -83,4 +83,3 @@ "build_segments_colored_area", "build_closed_shape_colored_area", ] - diff --git a/static/client/rendering/helpers/angle_renderer.py b/static/client/rendering/helpers/angle_renderer.py index 94688870..7b38b0f1 100644 --- a/static/client/rendering/helpers/angle_renderer.py +++ b/static/client/rendering/helpers/angle_renderer.py @@ -99,8 +99,9 @@ def _compute_angle_label_params(vx, vy, p1y, clamped_radius, arc_radius, display return tx, ty, font, base_font_size, should_draw_label -def _build_angle_metadata(angle_obj, arc_radius, clamped_radius, min_arm_length, min_arm_length_math, - display_degrees, params, style): +def _build_angle_metadata( + angle_obj, arc_radius, clamped_radius, min_arm_length, min_arm_length_math, display_degrees, params, style +): """Build metadata dictionary for angle rendering debugging. Args: @@ -202,9 +203,7 @@ def render_angle_helper(primitives, angle_obj, coordinate_mapper, style): except Exception: return - params = angle_obj._calculate_arc_parameters( - vx, vy, p1x, p1y, p2x, p2y, arc_radius=style.get("angle_arc_radius") - ) + params = angle_obj._calculate_arc_parameters(vx, vy, p1x, p1y, p2x, p2y, arc_radius=style.get("angle_arc_radius")) if not params: return @@ -244,12 +243,12 @@ def render_angle_helper(primitives, angle_obj, coordinate_mapper, style): css_class = "angle-arc" if hasattr(primitives, "_surface") else None angle_metadata = _build_angle_metadata( - angle_obj, arc_radius, clamped_radius, min_arm_length, min_arm_length_math, - display_degrees, params, style + angle_obj, arc_radius, clamped_radius, min_arm_length, min_arm_length_math, display_degrees, params, style ) - _draw_angle_arc(primitives, vx, vy, clamped_radius, start_angle, end_angle, sweep_cw, stroke, css_class, angle_metadata) + _draw_angle_arc( + primitives, vx, vy, clamped_radius, start_angle, end_angle, sweep_cw, stroke, css_class, angle_metadata + ) if should_draw_label: _draw_angle_label(primitives, tx, ty, display_degrees, font, color, angle_metadata) - diff --git a/static/client/rendering/helpers/area_builders.py b/static/client/rendering/helpers/area_builders.py index 928d8e3c..3613e742 100644 --- a/static/client/rendering/helpers/area_builders.py +++ b/static/client/rendering/helpers/area_builders.py @@ -88,4 +88,3 @@ def build_closed_shape_colored_area(area_model, coordinate_mapper): return renderable.build_screen_area() except Exception: return None - diff --git a/static/client/rendering/helpers/bar_renderer.py b/static/client/rendering/helpers/bar_renderer.py index 750af298..d6c04030 100644 --- a/static/client/rendering/helpers/bar_renderer.py +++ b/static/client/rendering/helpers/bar_renderer.py @@ -143,5 +143,3 @@ def render_bar_helper(primitives, bar, coordinate_mapper, style): screen_space=True, metadata=below_meta, ) - - diff --git a/static/client/rendering/helpers/cartesian_renderer.py b/static/client/rendering/helpers/cartesian_renderer.py index e5763c1c..4cce84f1 100644 --- a/static/client/rendering/helpers/cartesian_renderer.py +++ b/static/client/rendering/helpers/cartesian_renderer.py @@ -70,18 +70,18 @@ def _format_tick_value(value: float, precision: int) -> str: # Use 2 significant figures for scientific notation formatted = f"{value:.1e}" # Clean up the exponent format (remove leading zeros) - if 'e' in formatted: - base, exp = formatted.split('e') - exp_sign = exp[0] if exp[0] in '+-' else '+' - exp_num = exp.lstrip('+-').lstrip('0') or '0' + if "e" in formatted: + base, exp = formatted.split("e") + exp_sign = exp[0] if exp[0] in "+-" else "+" + exp_num = exp.lstrip("+-").lstrip("0") or "0" formatted = f"{base}e{exp_sign}{exp_num}" return formatted if precision <= 0: return str(int(round(value))) formatted = f"{value:.{precision}f}" # Strip trailing zeros but keep at least one decimal place if precision > 0 - if '.' in formatted: - formatted = formatted.rstrip('0').rstrip('.') + if "." in formatted: + formatted = formatted.rstrip("0").rstrip(".") return formatted @@ -100,8 +100,20 @@ def _draw_cartesian_axes(primitives, ox, oy, width_px, height_px, axis_stroke): primitives.stroke_line((ox, 0.0), (ox, height_px), axis_stroke) -def _draw_cartesian_tick_x(primitives, x_pos, ox, oy, scale, tick_size, tick_font_float, font, - label_color, label_alignment, tick_stroke, precision=6): +def _draw_cartesian_tick_x( + primitives, + x_pos, + ox, + oy, + scale, + tick_size, + tick_font_float, + font, + label_color, + label_alignment, + tick_stroke, + precision=6, +): """Draw a single X-axis tick mark with label. Args: @@ -139,8 +151,9 @@ def _draw_cartesian_tick_x(primitives, x_pos, ox, oy, scale, tick_size, tick_fon ) -def _draw_cartesian_tick_y(primitives, y_pos, ox, oy, scale, tick_size, font, - label_color, label_alignment, tick_stroke, precision=6): +def _draw_cartesian_tick_y( + primitives, y_pos, ox, oy, scale, tick_size, font, label_color, label_alignment, tick_stroke, precision=6 +): """Draw a single Y-axis tick mark with label. Args: @@ -199,8 +212,7 @@ def _draw_cartesian_mid_tick_y(primitives, y_pos, ox, mid_tick_size, tick_stroke primitives.stroke_line((ox - mid_tick_size, y_pos), (ox + mid_tick_size, y_pos), tick_stroke) -def _draw_cartesian_grid_lines_x(primitives, ox, width_px, height_px, display_tick, grid_stroke, - minor_grid_stroke): +def _draw_cartesian_grid_lines_x(primitives, ox, width_px, height_px, display_tick, grid_stroke, minor_grid_stroke): """Draw vertical grid lines at regular X intervals. Args: @@ -215,6 +227,7 @@ def _draw_cartesian_grid_lines_x(primitives, ox, width_px, height_px, display_ti if display_tick <= 0: return import math + start_n = int(math.ceil(-ox / display_tick)) end_n = int(math.floor((width_px - ox) / display_tick)) for n in range(start_n, end_n + 1): @@ -227,8 +240,7 @@ def _draw_cartesian_grid_lines_x(primitives, ox, width_px, height_px, display_ti primitives.stroke_line((mid_x, 0.0), (mid_x, height_px), minor_grid_stroke) -def _draw_cartesian_grid_lines_y(primitives, oy, width_px, height_px, display_tick, grid_stroke, - minor_grid_stroke): +def _draw_cartesian_grid_lines_y(primitives, oy, width_px, height_px, display_tick, grid_stroke, minor_grid_stroke): """Draw horizontal grid lines at regular Y intervals. Args: @@ -243,6 +255,7 @@ def _draw_cartesian_grid_lines_y(primitives, oy, width_px, height_px, display_ti if display_tick <= 0: return import math + start_n = int(math.ceil(-oy / display_tick)) end_n = int(math.floor((height_px - oy) / display_tick)) for n in range(start_n, end_n + 1): @@ -255,9 +268,21 @@ def _draw_cartesian_grid_lines_y(primitives, oy, width_px, height_px, display_ti primitives.stroke_line((0.0, mid_y), (width_px, mid_y), minor_grid_stroke) -def _draw_cartesian_ticks_x(primitives, ox, oy, width_px, scale, display_tick, tick_size, - mid_tick_size, tick_font_float, font, label_color, label_alignment, - tick_stroke): +def _draw_cartesian_ticks_x( + primitives, + ox, + oy, + width_px, + scale, + display_tick, + tick_size, + mid_tick_size, + tick_font_float, + font, + label_color, + label_alignment, + tick_stroke, +): """Draw all tick marks and labels along the X axis. Args: @@ -278,6 +303,7 @@ def _draw_cartesian_ticks_x(primitives, ox, oy, width_px, scale, display_tick, t if display_tick <= 0: return import math + # Calculate math spacing and precision needed for labels math_spacing = display_tick / scale if scale > 0 else display_tick precision = _calculate_tick_precision(math_spacing) @@ -286,15 +312,39 @@ def _draw_cartesian_ticks_x(primitives, ox, oy, width_px, scale, display_tick, t for n in range(start_n, end_n + 1): x = ox + n * display_tick if 0 <= x <= width_px: - _draw_cartesian_tick_x(primitives, x, ox, oy, scale, tick_size, tick_font_float, font, - label_color, label_alignment, tick_stroke, precision) + _draw_cartesian_tick_x( + primitives, + x, + ox, + oy, + scale, + tick_size, + tick_font_float, + font, + label_color, + label_alignment, + tick_stroke, + precision, + ) mid_x = x + display_tick * 0.5 if 0 <= mid_x <= width_px: _draw_cartesian_mid_tick_x(primitives, mid_x, oy, mid_tick_size, tick_stroke) -def _draw_cartesian_ticks_y(primitives, ox, oy, height_px, scale, display_tick, tick_size, - mid_tick_size, font, label_color, label_alignment, tick_stroke): +def _draw_cartesian_ticks_y( + primitives, + ox, + oy, + height_px, + scale, + display_tick, + tick_size, + mid_tick_size, + font, + label_color, + label_alignment, + tick_stroke, +): """Draw all tick marks and labels along the Y axis. Args: @@ -314,6 +364,7 @@ def _draw_cartesian_ticks_y(primitives, ox, oy, height_px, scale, display_tick, if display_tick <= 0: return import math + # Calculate math spacing and precision needed for labels math_spacing = display_tick / scale if scale > 0 else display_tick precision = _calculate_tick_precision(math_spacing) @@ -322,8 +373,9 @@ def _draw_cartesian_ticks_y(primitives, ox, oy, height_px, scale, display_tick, for n in range(start_n, end_n + 1): y = oy + n * display_tick if 0 <= y <= height_px: - _draw_cartesian_tick_y(primitives, y, ox, oy, scale, tick_size, font, label_color, - label_alignment, tick_stroke, precision) + _draw_cartesian_tick_y( + primitives, y, ox, oy, scale, tick_size, font, label_color, label_alignment, tick_stroke, precision + ) mid_y = y + display_tick * 0.5 if 0 <= mid_y <= height_px: _draw_cartesian_mid_tick_y(primitives, mid_y, ox, mid_tick_size, tick_stroke) @@ -331,9 +383,23 @@ def _draw_cartesian_ticks_y(primitives, ox, oy, height_px, scale, display_tick, @_manages_shape def _render_cartesian_grid( - primitives, ox, oy, width_px, height_px, scale, display_tick, tick_size, mid_tick_size, - tick_font_float, font, label_color, label_alignment, axis_stroke, grid_stroke, - minor_grid_stroke, tick_stroke + primitives, + ox, + oy, + width_px, + height_px, + scale, + display_tick, + tick_size, + mid_tick_size, + tick_font_float, + font, + label_color, + label_alignment, + axis_stroke, + grid_stroke, + minor_grid_stroke, + tick_stroke, ): """Render the complete Cartesian grid with all components. @@ -357,15 +423,37 @@ def _render_cartesian_grid( tick_stroke: StrokeStyle for tick marks. """ _draw_cartesian_axes(primitives, ox, oy, width_px, height_px, axis_stroke) - _draw_cartesian_grid_lines_x(primitives, ox, width_px, height_px, display_tick, grid_stroke, - minor_grid_stroke) - _draw_cartesian_grid_lines_y(primitives, oy, width_px, height_px, display_tick, grid_stroke, - minor_grid_stroke) - _draw_cartesian_ticks_x(primitives, ox, oy, width_px, scale, display_tick, tick_size, - mid_tick_size, tick_font_float, font, label_color, label_alignment, - tick_stroke) - _draw_cartesian_ticks_y(primitives, ox, oy, height_px, scale, display_tick, tick_size, - mid_tick_size, font, label_color, label_alignment, tick_stroke) + _draw_cartesian_grid_lines_x(primitives, ox, width_px, height_px, display_tick, grid_stroke, minor_grid_stroke) + _draw_cartesian_grid_lines_y(primitives, oy, width_px, height_px, display_tick, grid_stroke, minor_grid_stroke) + _draw_cartesian_ticks_x( + primitives, + ox, + oy, + width_px, + scale, + display_tick, + tick_size, + mid_tick_size, + tick_font_float, + font, + label_color, + label_alignment, + tick_stroke, + ) + _draw_cartesian_ticks_y( + primitives, + ox, + oy, + height_px, + scale, + display_tick, + tick_size, + mid_tick_size, + font, + label_color, + label_alignment, + tick_stroke, + ) def _get_cartesian_styles(style): @@ -418,9 +506,7 @@ def _get_cartesian_styles(style): if not math.isfinite(minor_grid_width): minor_grid_width = 0.5 minor_grid_width = max(minor_grid_width, 0.0) - minor_grid_stroke = ( - StrokeStyle(color=minor_grid_color, width=minor_grid_width) if minor_grid_width > 0.0 else None - ) + minor_grid_stroke = StrokeStyle(color=minor_grid_color, width=minor_grid_width) if minor_grid_width > 0.0 else None return { "label_color": label_color, @@ -541,4 +627,3 @@ def render_cartesian_helper(primitives, cartesian, coordinate_mapper, style): styles["minor_grid_stroke"], styles["tick_stroke"], ) - diff --git a/static/client/rendering/helpers/circle_arc_renderer.py b/static/client/rendering/helpers/circle_arc_renderer.py index 4b88d7ce..ddf0b193 100644 --- a/static/client/rendering/helpers/circle_arc_renderer.py +++ b/static/client/rendering/helpers/circle_arc_renderer.py @@ -104,7 +104,9 @@ def _compute_circle_arc_sweep(circle_arc, center_screen, point1_screen): @_manages_shape -def _stroke_circle_arc(primitives, circle_arc, center_screen, radius_on_screen, start_angle, end_angle, sweep_clockwise, style): +def _stroke_circle_arc( + primitives, circle_arc, center_screen, radius_on_screen, start_angle, end_angle, sweep_clockwise, style +): """Stroke a circle arc with the computed parameters. Args: @@ -179,4 +181,3 @@ def render_circle_arc_helper(primitives, circle_arc, coordinate_mapper, style): sweep_clockwise, style, ) - diff --git a/static/client/rendering/helpers/circle_renderer.py b/static/client/rendering/helpers/circle_renderer.py index d4c2e823..21bbbd8c 100644 --- a/static/client/rendering/helpers/circle_renderer.py +++ b/static/client/rendering/helpers/circle_renderer.py @@ -37,4 +37,3 @@ def render_circle_helper(primitives, circle, coordinate_mapper, style): width=float(style.get("circle_stroke_width", 1) or 1), ) _render_circle(primitives, center, radius, stroke) - diff --git a/static/client/rendering/helpers/colored_area_renderer.py b/static/client/rendering/helpers/colored_area_renderer.py index 287772a4..2b6cf57f 100644 --- a/static/client/rendering/helpers/colored_area_renderer.py +++ b/static/client/rendering/helpers/colored_area_renderer.py @@ -204,4 +204,3 @@ def render_closed_shape_area_helper(primitives, area_model, coordinate_mapper, s if area is None: return render_colored_area_helper(primitives, area, coordinate_mapper, style) - diff --git a/static/client/rendering/helpers/ellipse_renderer.py b/static/client/rendering/helpers/ellipse_renderer.py index e71a1d25..68734c0a 100644 --- a/static/client/rendering/helpers/ellipse_renderer.py +++ b/static/client/rendering/helpers/ellipse_renderer.py @@ -77,4 +77,3 @@ def render_ellipse_helper(primitives, ellipse, coordinate_mapper, style): rotation_rad = 0.0 _render_ellipse(primitives, center, rx, ry, rotation_rad, stroke) - diff --git a/static/client/rendering/helpers/font_helpers.py b/static/client/rendering/helpers/font_helpers.py index 05451f85..35d0d042 100644 --- a/static/client/rendering/helpers/font_helpers.py +++ b/static/client/rendering/helpers/font_helpers.py @@ -84,4 +84,3 @@ def _compute_zoom_adjusted_font_size(base_size: float, label: Any, coordinate_ma if scaled <= label_vanish_threshold_px: return 0.0 return max(scaled, label_min_screen_font_px) - diff --git a/static/client/rendering/helpers/function_renderer.py b/static/client/rendering/helpers/function_renderer.py index a266c69b..1a1d215e 100644 --- a/static/client/rendering/helpers/function_renderer.py +++ b/static/client/rendering/helpers/function_renderer.py @@ -189,4 +189,3 @@ def render_function_helper(primitives, func, coordinate_mapper, style): _render_function_paths(primitives, screen_paths, stroke, width, height) _render_function_label(primitives, func, screen_paths, stroke, style) - diff --git a/static/client/rendering/helpers/label_overlap_resolver.py b/static/client/rendering/helpers/label_overlap_resolver.py index b3b5b7f0..dcf1cfbd 100644 --- a/static/client/rendering/helpers/label_overlap_resolver.py +++ b/static/client/rendering/helpers/label_overlap_resolver.py @@ -205,5 +205,3 @@ def get_or_place_dy(self, group: Any, base_rect: Rect, *, step: float) -> float: self._group_to_dy[group] = dy self._placed_rects.append(shift_rect_y(inflate_rect(base_rect, self._padding_px), dy)) return dy - - diff --git a/static/client/rendering/helpers/label_renderer.py b/static/client/rendering/helpers/label_renderer.py index db777573..4780fee7 100644 --- a/static/client/rendering/helpers/label_renderer.py +++ b/static/client/rendering/helpers/label_renderer.py @@ -486,4 +486,3 @@ def render_label_helper(primitives, label, coordinate_mapper, style): screen_x=screen_x, screen_y=screen_y, ) - diff --git a/static/client/rendering/helpers/parametric_function_renderer.py b/static/client/rendering/helpers/parametric_function_renderer.py index fd1d342c..4aac4a3e 100644 --- a/static/client/rendering/helpers/parametric_function_renderer.py +++ b/static/client/rendering/helpers/parametric_function_renderer.py @@ -108,10 +108,7 @@ def _render_parametric_label(primitives, func, screen_paths, stroke, style): first_point = screen_paths[0][0] label_offset_x = (1 + len(func.name)) * font_size / 2.0 position = (first_point[0] - label_offset_x, max(first_point[1], font_size)) - font_family = style.get( - "function_label_font_family", - style.get("font_family", default_font_family) - ) + font_family = style.get("function_label_font_family", style.get("font_family", default_font_family)) font = FontStyle(family=font_family, size=font_size) primitives.draw_text( func.name, diff --git a/static/client/rendering/helpers/point_renderer.py b/static/client/rendering/helpers/point_renderer.py index 4591e069..97392e35 100644 --- a/static/client/rendering/helpers/point_renderer.py +++ b/static/client/rendering/helpers/point_renderer.py @@ -68,4 +68,3 @@ def render_point_helper(primitives, point, coordinate_mapper, style): color=str(getattr(fill, "color", "#000")), style=style, ) - diff --git a/static/client/rendering/helpers/polar_renderer.py b/static/client/rendering/helpers/polar_renderer.py index c0d4f7df..e73e8a62 100644 --- a/static/client/rendering/helpers/polar_renderer.py +++ b/static/client/rendering/helpers/polar_renderer.py @@ -70,18 +70,18 @@ def _format_tick_value(value: float, precision: int) -> str: # Use 2 significant figures for scientific notation formatted = f"{value:.1e}" # Clean up the exponent format (remove leading zeros) - if 'e' in formatted: - base, exp = formatted.split('e') - exp_sign = exp[0] if exp[0] in '+-' else '+' - exp_num = exp.lstrip('+-').lstrip('0') or '0' + if "e" in formatted: + base, exp = formatted.split("e") + exp_sign = exp[0] if exp[0] in "+-" else "+" + exp_num = exp.lstrip("+-").lstrip("0") or "0" formatted = f"{base}e{exp_sign}{exp_num}" return formatted if precision <= 0: return str(int(round(value))) formatted = f"{value:.{precision}f}" # Strip trailing zeros but keep at least one decimal place if precision > 0 - if '.' in formatted: - formatted = formatted.rstrip('0').rstrip('.') + if "." in formatted: + formatted = formatted.rstrip("0").rstrip(".") return formatted @@ -117,8 +117,18 @@ def _draw_radial_lines(primitives, ox, oy, max_radius_screen, angular_step_degre primitives.stroke_line((ox, oy), (end_x, end_y), radial_stroke) -def _draw_angle_labels(primitives, ox, oy, label_radius_screen, angular_step_degrees, - font, label_color, label_alignment, width_px, height_px): +def _draw_angle_labels( + primitives, + ox, + oy, + label_radius_screen, + angular_step_degrees, + font, + label_color, + label_alignment, + width_px, + height_px, +): """Draw angle labels at the external boundary of the visible canvas.""" if angular_step_degrees <= 0: return @@ -168,8 +178,9 @@ def _draw_angle_labels(primitives, ox, oy, label_radius_screen, angular_step_deg primitives.draw_text(label_text, (label_x, label_y), font, label_color, label_alignment) -def _draw_radius_labels(primitives, ox, oy, scale, display_spacing, max_radius_screen, - tick_font_float, font, label_color, label_alignment): +def _draw_radius_labels( + primitives, ox, oy, scale, display_spacing, max_radius_screen, tick_font_float, font, label_color, label_alignment +): """Draw radius labels along the positive x-axis. Uses spacing-aware formatting to show minimum digits needed to distinguish @@ -200,9 +211,22 @@ def _draw_origin_marker(primitives, ox, oy, tick_font_float, font, label_color, @_manages_shape def _render_polar_grid( - primitives, ox, oy, width_px, height_px, scale, display_spacing, max_radius_screen, - angular_step_degrees, tick_font_float, font, label_color, label_alignment, - axis_stroke, circle_stroke, radial_stroke + primitives, + ox, + oy, + width_px, + height_px, + scale, + display_spacing, + max_radius_screen, + angular_step_degrees, + tick_font_float, + font, + label_color, + label_alignment, + axis_stroke, + circle_stroke, + radial_stroke, ): """Render the complete polar grid with all components. @@ -227,10 +251,30 @@ def _render_polar_grid( _draw_polar_axes(primitives, ox, oy, width_px, height_px, axis_stroke) _draw_concentric_circles(primitives, ox, oy, max_radius_screen, display_spacing, circle_stroke) _draw_radial_lines(primitives, ox, oy, max_radius_screen, angular_step_degrees, radial_stroke) - _draw_angle_labels(primitives, ox, oy, max_radius_screen, angular_step_degrees, - font, label_color, label_alignment, width_px, height_px) - _draw_radius_labels(primitives, ox, oy, scale, display_spacing, max_radius_screen, - tick_font_float, font, label_color, label_alignment) + _draw_angle_labels( + primitives, + ox, + oy, + max_radius_screen, + angular_step_degrees, + font, + label_color, + label_alignment, + width_px, + height_px, + ) + _draw_radius_labels( + primitives, + ox, + oy, + scale, + display_spacing, + max_radius_screen, + tick_font_float, + font, + label_color, + label_alignment, + ) _draw_origin_marker(primitives, ox, oy, tick_font_float, font, label_color, label_alignment) diff --git a/static/client/rendering/helpers/screen_offset_label_helper.py b/static/client/rendering/helpers/screen_offset_label_helper.py index cdaf75ea..cfe1de07 100644 --- a/static/client/rendering/helpers/screen_offset_label_helper.py +++ b/static/client/rendering/helpers/screen_offset_label_helper.py @@ -149,5 +149,3 @@ def draw_point_style_label_with_coords( screen_space=True, metadata=label_metadata, ) - - diff --git a/static/client/rendering/helpers/screen_offset_label_layout.py b/static/client/rendering/helpers/screen_offset_label_layout.py index 46376c7a..bbe0b923 100644 --- a/static/client/rendering/helpers/screen_offset_label_layout.py +++ b/static/client/rendering/helpers/screen_offset_label_layout.py @@ -109,6 +109,7 @@ class LabelTextCall: font_size: Font size in pixels. line_height: Line spacing in pixels. """ + __slots__ = ( "group", "order", @@ -172,6 +173,7 @@ class LabelBlock: base_rect: Bounding rectangle at dy=0. step: Vertical step size for displacement. """ + __slots__ = ("group", "order", "base_rect", "step") def __init__(self, *, group: Any, order: int, base_rect: Rect, step: float) -> None: @@ -193,6 +195,7 @@ class SpatialHash2D: _rects: Dict mapping groups to their bounding rectangles. _group_cells: Dict mapping groups to their occupied cells. """ + __slots__ = ("_cell_size", "_cells", "_rects", "_group_cells") def __init__(self, *, cell_size: float = 32.0) -> None: @@ -478,7 +481,7 @@ def better_dy(a: float, b: float, *, overlaps_a: int, overlaps_b: int) -> float: return a best_dy = self._dy.get(mover, 0.0) - best_overlaps = 10 ** 9 + best_overlaps = 10**9 for k in range(0, max_steps + 1): if k == 0: @@ -495,7 +498,7 @@ def better_dy(a: float, b: float, *, overlaps_a: int, overlaps_b: int) -> float: else: best_zero = better_dy(best_zero, candidate_dy, overlaps_a=0, overlaps_b=0) continue - if best_overlaps == 10 ** 9: + if best_overlaps == 10**9: best_overlaps = overlaps best_dy = candidate_dy continue @@ -886,7 +889,7 @@ def better(a: float, b: float, oa: int, ob: int) -> float: return a best_dy = dy.get(g, 0.0) or 0.0 - best_overlaps = 10 ** 9 + best_overlaps = 10**9 for k in range(0, max_k + 1): if k == 0: candidates = (0.0,) @@ -896,7 +899,7 @@ def better(a: float, b: float, oa: int, ob: int) -> float: for cand in candidates: cand_rect = shift_rect_y(rect0, cand) overlaps, _ = _count_overlaps(grid, cand_rect, ignore=ignore) - if best_overlaps == 10 ** 9: + if best_overlaps == 10**9: best_overlaps = overlaps best_dy = cand if overlaps == 0: @@ -969,7 +972,7 @@ def better(a: float, b: float, oa: int, ob: int) -> float: # Lookahead mover selection under the hard dy bound. best_group: Any = None best_dy: float = 0.0 - best_overlaps: int = 10 ** 9 + best_overlaps: int = 10**9 best_key: Optional[Tuple[int, float, float, int]] = None for cand in collision_set: if cand in hidden or grid.get_rect(cand) is None: @@ -1055,5 +1058,3 @@ def better(a: float, b: float, oa: int, ob: int) -> float: dy.pop(h, None) return dict(dy), hidden - - diff --git a/static/client/rendering/helpers/segment_renderer.py b/static/client/rendering/helpers/segment_renderer.py index a557e467..aee918c2 100644 --- a/static/client/rendering/helpers/segment_renderer.py +++ b/static/client/rendering/helpers/segment_renderer.py @@ -48,4 +48,3 @@ def render_segment_helper(primitives, segment, coordinate_mapper, style): embedded_label = getattr(segment, "label", None) if embedded_label is not None: render_label_helper(primitives, embedded_label, coordinate_mapper, style) - diff --git a/static/client/rendering/helpers/shape_decorator.py b/static/client/rendering/helpers/shape_decorator.py index ce1877aa..8a8c443a 100644 --- a/static/client/rendering/helpers/shape_decorator.py +++ b/static/client/rendering/helpers/shape_decorator.py @@ -24,6 +24,7 @@ def _manages_shape(render_fn): Returns: Wrapped function that calls begin_shape before and end_shape after. """ + def wrapper(primitives, *args, **kwargs): begin_shape = getattr(primitives, "begin_shape", None) end_shape = getattr(primitives, "end_shape", None) @@ -35,5 +36,5 @@ def wrapper(primitives, *args, **kwargs): finally: if managing: end_shape() - return wrapper + return wrapper diff --git a/static/client/rendering/helpers/vector_renderer.py b/static/client/rendering/helpers/vector_renderer.py index 4b0e9a17..278c67c9 100644 --- a/static/client/rendering/helpers/vector_renderer.py +++ b/static/client/rendering/helpers/vector_renderer.py @@ -65,4 +65,3 @@ def render_vector_helper(primitives, vector, coordinate_mapper, style): color = str(getattr(vector, "color", getattr(seg, "color", style.get("vector_color", "#000")))) stroke = StrokeStyle(color=color, width=float(style.get("segment_stroke_width", 1) or 1)) _render_vector(primitives, start, end, seg, color, stroke, style) - diff --git a/static/client/rendering/helpers/world_label_helper.py b/static/client/rendering/helpers/world_label_helper.py index dbbac1cd..5da615e1 100644 --- a/static/client/rendering/helpers/world_label_helper.py +++ b/static/client/rendering/helpers/world_label_helper.py @@ -108,7 +108,9 @@ def compute_world_label_font(label: Any, style: dict, coordinate_mapper: Any): return font, base_font_size, effective_font_size -def build_world_label_metadata(index: int, position: Any, offset_y: float, rotation_degrees: float, label: Any, base_font_size: float): +def build_world_label_metadata( + index: int, position: Any, offset_y: float, rotation_degrees: float, label: Any, base_font_size: float +): """Build metadata for world-space label reprojection. Args: @@ -189,5 +191,3 @@ def render_world_label_at_screen_point( alignment, metadata=metadata, ) - - diff --git a/static/client/rendering/interfaces.py b/static/client/rendering/interfaces.py index c7bb1553..c2e898f8 100644 --- a/static/client/rendering/interfaces.py +++ b/static/client/rendering/interfaces.py @@ -20,28 +20,18 @@ class RendererProtocol(Protocol): """Minimal renderer contract consumed by the canvas.""" - def clear(self) -> None: - ... + def clear(self) -> None: ... - def render(self, drawable: Any, coordinate_mapper: Any) -> bool: - ... + def render(self, drawable: Any, coordinate_mapper: Any) -> bool: ... - def render_cartesian(self, cartesian: Any, coordinate_mapper: Any) -> None: - ... + def render_cartesian(self, cartesian: Any, coordinate_mapper: Any) -> None: ... - def render_polar(self, polar_grid: Any, coordinate_mapper: Any) -> None: - ... + def render_polar(self, polar_grid: Any, coordinate_mapper: Any) -> None: ... - def register(self, cls: type, handler: Callable[[Any, Any], None]) -> None: - ... + def register(self, cls: type, handler: Callable[[Any, Any], None]) -> None: ... - def register_default_drawables(self) -> None: - ... - - def begin_frame(self) -> None: - ... - - def end_frame(self) -> None: - ... + def register_default_drawables(self) -> None: ... + def begin_frame(self) -> None: ... + def end_frame(self) -> None: ... diff --git a/static/client/rendering/primitives.py b/static/client/rendering/primitives.py index e7c989e2..c16d3a8f 100644 --- a/static/client/rendering/primitives.py +++ b/static/client/rendering/primitives.py @@ -65,7 +65,9 @@ def __init__(self, text: str, x: float, y: float) -> None: class StrokeStyle: __slots__ = ("color", "width", "line_join", "line_cap") - def __init__(self, color: str, width: float, line_join: Optional[str] = None, line_cap: Optional[str] = None, **kwargs: Any) -> None: + def __init__( + self, color: str, width: float, line_join: Optional[str] = None, line_cap: Optional[str] = None, **kwargs: Any + ) -> None: self.color = str(color) self.width = float(width) self.line_join = line_join @@ -113,7 +115,9 @@ def __init__(self, horizontal: str = "left", vertical: str = "alphabetic") -> No class RendererPrimitives: """Backend-specific primitive surface consumed by shared helpers.""" - def stroke_line(self, start: Tuple[float, float], end: Tuple[float, float], stroke: StrokeStyle, *, include_width: bool = True) -> None: + def stroke_line( + self, start: Tuple[float, float], end: Tuple[float, float], stroke: StrokeStyle, *, include_width: bool = True + ) -> None: raise NotImplementedError def stroke_polyline(self, points: List[Tuple[float, float]], stroke: StrokeStyle) -> None: @@ -122,10 +126,20 @@ def stroke_polyline(self, points: List[Tuple[float, float]], stroke: StrokeStyle def stroke_circle(self, center: Tuple[float, float], radius: float, stroke: StrokeStyle) -> None: raise NotImplementedError - def fill_circle(self, center: Tuple[float, float], radius: float, fill: FillStyle, stroke: Optional[StrokeStyle] = None, *, screen_space: bool = False) -> None: + def fill_circle( + self, + center: Tuple[float, float], + radius: float, + fill: FillStyle, + stroke: Optional[StrokeStyle] = None, + *, + screen_space: bool = False, + ) -> None: raise NotImplementedError - def stroke_ellipse(self, center: Tuple[float, float], radius_x: float, radius_y: float, rotation_rad: float, stroke: StrokeStyle) -> None: + def stroke_ellipse( + self, center: Tuple[float, float], radius_x: float, radius_y: float, rotation_rad: float, stroke: StrokeStyle + ) -> None: raise NotImplementedError def fill_polygon( @@ -139,7 +153,9 @@ def fill_polygon( ) -> None: raise NotImplementedError - def fill_joined_area(self, forward: List[Tuple[float, float]], reverse: List[Tuple[float, float]], fill: FillStyle) -> None: + def fill_joined_area( + self, forward: List[Tuple[float, float]], reverse: List[Tuple[float, float]], fill: FillStyle + ) -> None: raise NotImplementedError def stroke_arc( diff --git a/static/client/rendering/renderables/__init__.py b/static/client/rendering/renderables/__init__.py index 53faa6e7..2b676ee0 100644 --- a/static/client/rendering/renderables/__init__.py +++ b/static/client/rendering/renderables/__init__.py @@ -15,4 +15,3 @@ "ParametricFunctionRenderable", "SegmentsBoundedAreaRenderable", ] - diff --git a/static/client/rendering/renderables/adaptive_sampler.py b/static/client/rendering/renderables/adaptive_sampler.py index 94640521..02eeb87b 100644 --- a/static/client/rendering/renderables/adaptive_sampler.py +++ b/static/client/rendering/renderables/adaptive_sampler.py @@ -66,8 +66,7 @@ def generate_samples_with_asymptotes( if len(valid_asymptotes) > MAX_SUBRANGES - 1: # Too many asymptotes - just sample without splitting samples, _ = AdaptiveSampler.generate_samples( - left_bound, right_bound, eval_func, math_to_screen, - initial_segments, max_samples + left_bound, right_bound, eval_func, math_to_screen, initial_segments, max_samples ) return [samples] if samples else [] @@ -91,8 +90,7 @@ def generate_samples_with_asymptotes( continue samples, _ = AdaptiveSampler.generate_samples( - sub_left, sub_right, eval_func, math_to_screen, - initial_segments, max_samples + sub_left, sub_right, eval_func, math_to_screen, initial_segments, max_samples ) if samples: @@ -137,8 +135,7 @@ def generate_samples( min_expected_samples = 20 if len(results) < min_expected_samples and initial_segments is None: estimated_period = AdaptiveSampler._detect_periodicity( - left_bound, right_bound, results, - eval_func, math_to_screen + left_bound, right_bound, results, eval_func, math_to_screen ) if estimated_period is not None: range_width = right_bound - left_bound @@ -179,10 +176,7 @@ def _generate_with_segments( if p_left is not None and p_right is not None: AdaptiveSampler._subdivide( - x_left, x_right, - p_left, p_right, - eval_func, math_to_screen, - 0, results, max_samples + x_left, x_right, p_left, p_right, eval_func, math_to_screen, 0, results, max_samples ) return results @@ -300,12 +294,10 @@ def _subdivide( return AdaptiveSampler._subdivide( - x_left, x_mid, p_left, p_mid, - eval_func, math_to_screen, depth + 1, results, max_samples + x_left, x_mid, p_left, p_mid, eval_func, math_to_screen, depth + 1, results, max_samples ) AdaptiveSampler._subdivide( - x_mid, x_right, p_mid, p_right, - eval_func, math_to_screen, depth + 1, results, max_samples + x_mid, x_right, p_mid, p_right, eval_func, math_to_screen, depth + 1, results, max_samples ) @staticmethod @@ -331,4 +323,3 @@ def _is_straight( distance = cross / math.sqrt(length_sq) return distance < PIXEL_TOLERANCE - diff --git a/static/client/rendering/renderables/closed_shape_area_renderable.py b/static/client/rendering/renderables/closed_shape_area_renderable.py index 0bd475f8..371c795c 100644 --- a/static/client/rendering/renderables/closed_shape_area_renderable.py +++ b/static/client/rendering/renderables/closed_shape_area_renderable.py @@ -182,4 +182,3 @@ def _build_region_area( forward = [(float(x), float(y)) for x, y in points] reverse = list(reversed(forward)) return forward, reverse - diff --git a/static/client/rendering/renderables/function_renderable.py b/static/client/rendering/renderables/function_renderable.py index 7d81c380..b81eb2b5 100644 --- a/static/client/rendering/renderables/function_renderable.py +++ b/static/client/rendering/renderables/function_renderable.py @@ -64,13 +64,15 @@ def _get_screen_signature(self) -> Tuple[int, int]: screen_height = getattr(self.mapper, "canvas_height", None) return (int(screen_width or 0), int(screen_height or 0)) - def _update_cache_state(self, scale: Optional[float], bounds: Tuple[float, float], screen_sig: Tuple[int, int]) -> None: + def _update_cache_state( + self, scale: Optional[float], bounds: Tuple[float, float], screen_sig: Tuple[int, int] + ) -> None: self._last_scale = scale self._last_bounds = bounds self._last_screen_bounds = screen_sig def _should_regenerate(self) -> bool: - current_scale: Optional[float] = getattr(self.mapper, 'scale_factor', None) + current_scale: Optional[float] = getattr(self.mapper, "scale_factor", None) current_bounds: Tuple[float, float] = self._get_visible_bounds() screen_signature = self._get_screen_signature() if self._cached_screen_paths is None or not self._cache_valid: @@ -100,14 +102,14 @@ def _resolve_bounds(self, left_bound: Optional[float], right_bound: Optional[flo def _is_discontinuity(self, x: float) -> bool: try: - if getattr(self.func, 'point_discontinuities', None) and x in self.func.point_discontinuities: + if getattr(self.func, "point_discontinuities", None) and x in self.func.point_discontinuities: return True except Exception: pass return False def _get_asymptote_between(self, x1: float, x2: float) -> Optional[float]: - if not hasattr(self.func, 'get_vertical_asymptote_between_x'): + if not hasattr(self.func, "get_vertical_asymptote_between_x"): return None try: return cast(Optional[float], self.func.get_vertical_asymptote_between_x(x1, x2)) @@ -123,7 +125,7 @@ def _evaluate_function(self, x: float) -> Optional[float]: def _is_invalid_y(self, y: Optional[float]) -> bool: if y is None: return True - if isinstance(y, float) and (y != y or abs(y) == float('inf')): + if isinstance(y, float) and (y != y or abs(y) == float("inf")): return True return False @@ -182,8 +184,8 @@ def build_screen_paths(self) -> ScreenPolyline: def _get_effective_bounds(self) -> Tuple[float, float]: visible_left, visible_right = self._get_visible_bounds() - base_left: Optional[float] = getattr(self.func, 'left_bound', None) - base_right: Optional[float] = getattr(self.func, 'right_bound', None) + base_left: Optional[float] = getattr(self.func, "left_bound", None) + base_right: Optional[float] = getattr(self.func, "right_bound", None) # Use visible bounds when no explicit function bounds are set if base_left is None: base_left = visible_left @@ -192,8 +194,8 @@ def _get_effective_bounds(self) -> Tuple[float, float]: return max(visible_left, base_left), min(visible_right, base_right) def _get_screen_dimensions(self) -> Tuple[float, float]: - width: float = getattr(self.mapper, 'canvas_width', 0) or 0 - height: float = getattr(self.mapper, 'canvas_height', 0) or 0 + width: float = getattr(self.mapper, "canvas_width", 0) or 0 + height: float = getattr(self.mapper, "canvas_height", 0) or 0 return width, height def _calculate_sample_points_by_subrange(self, left_bound: float, right_bound: float) -> list[list[float]]: @@ -201,33 +203,43 @@ def _calculate_sample_points_by_subrange(self, left_bound: float, right_bound: f Calculate sample points, splitting at asymptotes and discontinuities into separate sub-ranges. Returns a list of sample lists, one per continuous sub-range. """ - canvas_width = int(getattr(self.mapper, 'canvas_width', 800) or 800) + canvas_width = int(getattr(self.mapper, "canvas_width", 800) or 800) initial_segments = None - if getattr(self.func, 'is_periodic', False) and getattr(self.func, 'estimated_period', None): + if getattr(self.func, "is_periodic", False) and getattr(self.func, "estimated_period", None): range_width = right_bound - left_bound num_periods = range_width / self.func.estimated_period initial_segments = min(canvas_width, max(8, int(num_periods * 4))) # Get asymptotes AND point discontinuities - both require splitting - asymptotes = getattr(self.func, 'vertical_asymptotes', []) or [] - point_discontinuities = getattr(self.func, 'point_discontinuities', []) or [] + asymptotes = getattr(self.func, "vertical_asymptotes", []) or [] + point_discontinuities = getattr(self.func, "point_discontinuities", []) or [] # Combine all split points (asymptotes and discontinuities) all_split_points = sorted(set(asymptotes + point_discontinuities)) if all_split_points: - return cast(list[list[float]], AdaptiveSampler.generate_samples_with_asymptotes( - left_bound, right_bound, self.func.function, - self.mapper.math_to_screen, all_split_points, initial_segments, - max_samples=canvas_width - )) + return cast( + list[list[float]], + AdaptiveSampler.generate_samples_with_asymptotes( + left_bound, + right_bound, + self.func.function, + self.mapper.math_to_screen, + all_split_points, + initial_segments, + max_samples=canvas_width, + ), + ) else: # No split points - single range samples, _ = AdaptiveSampler.generate_samples( - left_bound, right_bound, self.func.function, - self.mapper.math_to_screen, initial_segments, - max_samples=canvas_width + left_bound, + right_bound, + self.func.function, + self.mapper.math_to_screen, + initial_segments, + max_samples=canvas_width, ) return [samples] if samples else [] @@ -239,7 +251,6 @@ def _eval_scaled_point(self, x_val: float) -> Tuple[Tuple[Optional[float], Optio except Exception: return (None, None), None - def _adjust_point_for_asymptote_ahead( self, x: float, step: float, scaled_point: Tuple, y_val: Any ) -> Tuple[Tuple, Any, bool]: @@ -264,8 +275,7 @@ def _is_large_jump(self, prev_sy: float, sy: float, height: float) -> bool: return abs(prev_sy - sy) > height * 2 def _is_point_visible( - self, sx: float, sy: float, width: float, height: float, - visible_min_x: float, visible_max_x: float + self, sx: float, sy: float, width: float, height: float, visible_min_x: float, visible_max_x: float ) -> bool: if sy >= height or sy <= 0: return False @@ -281,11 +291,7 @@ def _clamp_screen_y(self, sy: float, height: float) -> float: return height # bottom of screen return sy - def _finalize_path( - self, - current_path: list[tuple[float, float]], - paths: list[list[tuple[float, float]]] - ) -> None: + def _finalize_path(self, current_path: list[tuple[float, float]], paths: list[list[tuple[float, float]]]) -> None: """Add current path to paths list if non-empty.""" if current_path: paths.append(current_path) @@ -328,9 +334,7 @@ def _is_on_screen(self, sy: float, height: float) -> bool: """Check if y coordinate is within screen bounds.""" return 0 <= sy <= height - def _build_path_from_samples( - self, sample_points: list[float], height: float - ) -> list[list[tuple[float, float]]]: + def _build_path_from_samples(self, sample_points: list[float], height: float) -> list[list[tuple[float, float]]]: """ Build screen paths from a list of sample x-values (within a single sub-range). Path breaks on: failed evaluation, large y-jumps, discontinuities, or asymptotes. @@ -399,8 +403,13 @@ def _build_path_from_samples( return paths def _extend_paths_to_boundaries( - self, paths: list[list[tuple[float, float]]], width: float, height: float, step: float, - left_bound: float, right_bound: float + self, + paths: list[list[tuple[float, float]]], + width: float, + height: float, + step: float, + left_bound: float, + right_bound: float, ) -> None: """ Ensures each sub-path extends to screen boundaries for complete rendering. @@ -638,9 +647,7 @@ def _compute_boundary_intersection( t = (height - y1) / (y2 - y1) return (x1 + t * (x2 - x1), height) - def _clamp_to_boundary( - self, x1: float, y1: float, x2: float, y2: float, height: float - ) -> tuple[float, float]: + def _clamp_to_boundary(self, x1: float, y1: float, x2: float, y2: float, height: float) -> tuple[float, float]: """ If (x2,y2) is outside screen, return intersection of line (x1,y1)→(x2,y2) with screen boundary. Uses linear interpolation: t = (target_y - y1) / (y2 - y1) @@ -656,10 +663,20 @@ def _clamp_to_boundary( return (x1 + t * (x2 - x1), height) def _handle_boundary_crossing( - self, x: float, left_bound: float, right_bound: float, width: float, height: float, - prev_sx: float, prev_sy: float, sx_val: float, sy: float, - neighbor_prev_scaled_point: Tuple, scaled_point: Tuple, - current_path: list, paths: list + self, + x: float, + left_bound: float, + right_bound: float, + width: float, + height: float, + prev_sx: float, + prev_sy: float, + sx_val: float, + sy: float, + neighbor_prev_scaled_point: Tuple, + scaled_point: Tuple, + current_path: list, + paths: list, ) -> Optional[Tuple[list, list, bool]]: if x <= left_bound: return None @@ -676,18 +693,14 @@ def _handle_boundary_crossing( if crossed_bound_onto_screen: # Compute intersection point at screen boundary instead of using off-screen point - boundary_pt = self._compute_boundary_intersection( - prev_sx, prev_sy, sx_val, sy, height - ) + boundary_pt = self._compute_boundary_intersection(prev_sx, prev_sy, sx_val, sy, height) current_path.append(boundary_pt) current_path.append((scaled_point[0], scaled_point[1])) return current_path, paths, True if crossed_bound_off_screen: # Compute intersection point at screen boundary - boundary_pt = self._compute_boundary_intersection( - prev_sx, prev_sy, sx_val, sy, height - ) + boundary_pt = self._compute_boundary_intersection(prev_sx, prev_sy, sx_val, sy, height) current_path.append(boundary_pt) if current_path: paths.append(current_path) @@ -715,4 +728,3 @@ def _handle_boundary_crossing( return current_path, paths, True return None - diff --git a/static/client/rendering/renderables/function_segment_area_renderable.py b/static/client/rendering/renderables/function_segment_area_renderable.py index 8733dad0..ac47900c 100644 --- a/static/client/rendering/renderables/function_segment_area_renderable.py +++ b/static/client/rendering/renderables/function_segment_area_renderable.py @@ -19,7 +19,7 @@ def _get_bounds(self) -> Tuple[float, float]: seg_right: float seg_left, seg_right = self.area._get_segment_bounds() func: Any = self.area.func - if hasattr(func, 'left_bound') and hasattr(func, 'right_bound'): + if hasattr(func, "left_bound") and hasattr(func, "right_bound"): return max(seg_left, func.left_bound), min(seg_right, func.right_bound) return seg_left, seg_right @@ -29,7 +29,7 @@ def _eval_function(self, x_math: float) -> Optional[float]: return 0.0 if isinstance(func, (int, float)): return float(func) - if hasattr(func, 'function'): + if hasattr(func, "function"): try: result: Any = func.function(x_math) return cast(Optional[float], result) @@ -37,7 +37,9 @@ def _eval_function(self, x_math: float) -> Optional[float]: return None return None - def _generate_function_points_math(self, left_bound: float, right_bound: float, num_points: int) -> List[Tuple[float, float]]: + def _generate_function_points_math( + self, left_bound: float, right_bound: float, num_points: int + ) -> List[Tuple[float, float]]: if num_points < 2: num_points = 2 dx: float = (right_bound - left_bound) / (num_points - 1) if num_points > 1 else 1.0 @@ -55,9 +57,9 @@ def _segment_reverse_points_math(self) -> Optional[List[Tuple[float, float]]]: p2: Any = self.area.segment.point2 if p1 is None or p2 is None: return None - if not hasattr(p1, 'x') or not hasattr(p1, 'y'): + if not hasattr(p1, "x") or not hasattr(p1, "y"): return None - if not hasattr(p2, 'x') or not hasattr(p2, 'y'): + if not hasattr(p2, "x") or not hasattr(p2, "y"): return None return [(p2.x, p2.y), (p1.x, p1.y)] @@ -76,4 +78,3 @@ def build_screen_area(self, num_points: int = 100) -> Optional[ClosedArea]: color=getattr(self.area, "color", None), opacity=getattr(self.area, "opacity", None), ) - diff --git a/static/client/rendering/renderables/functions_area_renderable.py b/static/client/rendering/renderables/functions_area_renderable.py index b2a653d6..ec0c472f 100644 --- a/static/client/rendering/renderables/functions_area_renderable.py +++ b/static/client/rendering/renderables/functions_area_renderable.py @@ -15,7 +15,7 @@ def __init__(self, area_model: Any, coordinate_mapper: Any) -> None: self.mapper: Any = coordinate_mapper def _is_function_like(self, f: Any) -> bool: - return hasattr(f, 'function') + return hasattr(f, "function") def _eval_y_math(self, f: Any, x_math: float) -> Optional[float]: if f is None: @@ -29,7 +29,7 @@ def _eval_y_math(self, f: Any, x_math: float) -> Optional[float]: return None if not isinstance(y, (int, float)): return None - if isinstance(y, float) and (y != y or abs(y) == float('inf')): + if isinstance(y, float) and (y != y or abs(y) == float("inf")): return None return y except Exception: @@ -44,12 +44,17 @@ def _get_bounds(self) -> Tuple[float, float]: except Exception: left, right = -10, 10 for f in (self.area.func1, self.area.func2): - if hasattr(f, 'left_bound') and hasattr(f, 'right_bound') and f.left_bound is not None and f.right_bound is not None: + if ( + hasattr(f, "left_bound") + and hasattr(f, "right_bound") + and f.left_bound is not None + and f.right_bound is not None + ): left = max(left, f.left_bound) right = min(right, f.right_bound) - if getattr(self.area, 'left_bound', None) is not None: + if getattr(self.area, "left_bound", None) is not None: left = max(left, self.area.left_bound) - if getattr(self.area, 'right_bound', None) is not None: + if getattr(self.area, "right_bound", None) is not None: right = min(right, self.area.right_bound) try: vis_left: float = self.mapper.get_visible_left_bound() @@ -63,7 +68,9 @@ def _get_bounds(self) -> Tuple[float, float]: left, right = c - 0.1, c + 0.1 return left, right - def _generate_pair_paths_screen(self, f1: Any, f2: Any, left: float, right: float, num_points: int) -> Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]: + def _generate_pair_paths_screen( + self, f1: Any, f2: Any, left: float, right: float, num_points: int + ) -> Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]: if num_points < 2: num_points = 2 dx: float = (right - left) / (num_points - 1) if num_points > 1 else 1.0 @@ -110,7 +117,7 @@ def build_screen_area(self, num_points: Optional[int] = None) -> Optional[Closed left: float right: float left, right = self._get_bounds() - n: int = num_points if num_points is not None else getattr(self.area, 'num_sample_points', 100) + n: int = num_points if num_points is not None else getattr(self.area, "num_sample_points", 100) fwd: List[Tuple[float, float]] rev: List[Tuple[float, float]] fwd, rev = self._generate_pair_paths_screen(self.area.func1, self.area.func2, left, right, n) @@ -123,4 +130,3 @@ def build_screen_area(self, num_points: Optional[int] = None) -> Optional[Closed color=getattr(self.area, "color", None), opacity=getattr(self.area, "opacity", None), ) - diff --git a/static/client/rendering/renderables/segments_area_renderable.py b/static/client/rendering/renderables/segments_area_renderable.py index 66db8fa8..802b69ff 100644 --- a/static/client/rendering/renderables/segments_area_renderable.py +++ b/static/client/rendering/renderables/segments_area_renderable.py @@ -17,7 +17,7 @@ def __init__(self, area_model: Any, coordinate_mapper: Any) -> None: self.mapper: Any = coordinate_mapper def _screen_xy(self, point: Any) -> Tuple[Optional[float], Optional[float]]: - if point is None or not hasattr(point, 'x') or not hasattr(point, 'y'): + if point is None or not hasattr(point, "x") or not hasattr(point, "y"): return None, None result: Any = self.mapper.math_to_screen(point.x, point.y) return cast(Tuple[Optional[float], Optional[float]], result) @@ -37,7 +37,7 @@ def _get_y_at_x_screen(self, segment: Any, x: float) -> Optional[float]: return y1 + t * (y2 - y1) def build_screen_area(self) -> Optional[ClosedArea]: - if not getattr(self.area, 'segment2', None): + if not getattr(self.area, "segment2", None): p1: Tuple[Optional[float], Optional[float]] = self._screen_xy(self.area.segment1.point1) p2: Tuple[Optional[float], Optional[float]] = self._screen_xy(self.area.segment1.point2) if None in p1 or None in p2: @@ -89,4 +89,3 @@ def build_screen_area(self) -> Optional[ClosedArea]: color=getattr(self.area, "color", None), opacity=getattr(self.area, "opacity", None), ) - diff --git a/static/client/rendering/style_manager.py b/static/client/rendering/style_manager.py index 6b508c32..525eb04f 100644 --- a/static/client/rendering/style_manager.py +++ b/static/client/rendering/style_manager.py @@ -45,59 +45,46 @@ "point_radius": default_point_size, "point_label_font_size": point_label_font_size, "point_label_font_family": default_font_family, - "label_text_color": default_color, "label_font_size": default_label_font_size, "label_font_family": default_font_family, - "segment_color": default_color, "segment_stroke_width": 1, - "circle_color": default_color, "circle_stroke_width": 1, - "ellipse_color": default_color, "ellipse_stroke_width": 1, - "vector_color": default_color, "vector_tip_size": default_point_size * 4, - "angle_color": default_color, "angle_arc_radius": DEFAULT_ANGLE_ARC_SCREEN_RADIUS, "angle_label_font_size": point_label_font_size, "angle_text_arc_radius_factor": DEFAULT_ANGLE_TEXT_ARC_RADIUS_FACTOR, "angle_label_font_family": default_font_family, - "circle_arc_color": DEFAULT_CIRCLE_ARC_COLOR, "circle_arc_stroke_width": DEFAULT_CIRCLE_ARC_STROKE_WIDTH, "circle_arc_radius_scale": DEFAULT_CIRCLE_ARC_RADIUS_SCALE, - "function_color": default_color, "function_stroke_width": 1, "function_label_font_size": point_label_font_size, "function_label_font_family": default_font_family, - "area_fill_color": default_area_fill_color, "area_opacity": 0.3, - "cartesian_axis_color": default_color, "cartesian_grid_color": "lightgrey", "cartesian_tick_size": 3, "cartesian_tick_font_size": 8, "cartesian_label_color": "grey", "cartesian_font_family": default_font_family, - "polar_axis_color": default_color, "polar_circle_color": "lightgrey", "polar_radial_color": "lightgrey", "polar_label_color": "grey", "polar_label_font_size": 8, "polar_font_family": default_font_family, - "fill_style": "rgba(0, 0, 0, 0)", "font_family": default_font_family, "canvas_background_color": "#ffffff", - # Bars "bar_label_padding_px": 6, } @@ -136,4 +123,3 @@ def get_default_style_value(key: str, default: Optional[Any] = None) -> Any: The default value for the given style key. """ return _BASE_STYLE.get(key, default) - diff --git a/static/client/rendering/svg_primitive_adapter.py b/static/client/rendering/svg_primitive_adapter.py index a3380dba..740490e8 100644 --- a/static/client/rendering/svg_primitive_adapter.py +++ b/static/client/rendering/svg_primitive_adapter.py @@ -439,7 +439,9 @@ def _format_number(self, value: Any) -> str: return str(int(num)) return str(num) - def _ensure_stroke_attrs(self, elem: Any, cache: Dict[str, Any], stroke: StrokeStyle, *, include_width: bool = True) -> None: + def _ensure_stroke_attrs( + self, elem: Any, cache: Dict[str, Any], stroke: StrokeStyle, *, include_width: bool = True + ) -> None: stroke_cache = cache.setdefault("stroke", {}) if stroke_cache.get("color") != stroke.color: self._set_attribute(elem, cache, "stroke", stroke.color) @@ -645,7 +647,9 @@ def _apply_draw_text_params( except Exception: rotation_deg = 0.0 if math.isfinite(rotation_deg) and rotation_deg != 0.0: - transform_value = f"rotate({-rotation_deg} {self._format_number(position[0])} {self._format_number(position[1])})" + transform_value = ( + f"rotate({-rotation_deg} {self._format_number(position[0])} {self._format_number(position[1])})" + ) self._set_attribute(elem, cache, "transform", transform_value) else: self._set_attribute(elem, cache, "transform", None) @@ -823,7 +827,9 @@ def stroke_arc( end_y = center[1] + radius * math.sin(end_angle_rad) large_arc_flag = "1" if abs(end_angle_rad - start_angle_rad) > math.pi else "0" sweep_flag = "1" if sweep_clockwise else "0" - radius_str = str(int(radius)) if isinstance(radius, (int, float)) and float(radius).is_integer() else str(radius) + radius_str = ( + str(int(radius)) if isinstance(radius, (int, float)) and float(radius).is_integer() else str(radius) + ) d = f"M {start_x} {start_y} A {radius_str} {radius_str} 0 {large_arc_flag} {sweep_flag} {end_x} {end_y}" kwargs = self._stroke_kwargs(stroke) kwargs["fill"] = "none" @@ -890,4 +896,3 @@ def resize_surface(self, width: float, height: float) -> None: surface = self._surface surface.setAttribute("width", str(width)) surface.setAttribute("height", str(height)) - diff --git a/static/client/rendering/svg_renderer.py b/static/client/rendering/svg_renderer.py index 0c836477..07f39d34 100644 --- a/static/client/rendering/svg_renderer.py +++ b/static/client/rendering/svg_renderer.py @@ -386,9 +386,7 @@ def render_cartesian(self, cartesian: Any, coordinate_mapper: Any) -> None: drawable_name = "Cartesian2Axis" map_state = self._capture_map_state(coordinate_mapper) signature = self._compute_drawable_signature(cartesian, coordinate_mapper) - plan_context = self._resolve_cartesian_plan( - cartesian, coordinate_mapper, map_state, signature, drawable_name - ) + plan_context = self._resolve_cartesian_plan(cartesian, coordinate_mapper, map_state, signature, drawable_name) if plan_context is None: return self._cartesian_rendered_this_frame = True @@ -403,9 +401,7 @@ def render_polar(self, polar_grid: Any, coordinate_mapper: Any) -> None: drawable_name = "PolarGrid" map_state = self._capture_map_state(coordinate_mapper) signature = self._compute_drawable_signature(polar_grid, coordinate_mapper) - plan_context = self._resolve_polar_plan( - polar_grid, coordinate_mapper, map_state, signature, drawable_name - ) + plan_context = self._resolve_polar_plan(polar_grid, coordinate_mapper, map_state, signature, drawable_name) if plan_context is None: return self._cartesian_rendered_this_frame = True @@ -433,9 +429,7 @@ def _resolve_polar_plan( else: if plan_entry is not None: self._drop_plan_group(plan_entry.get("plan")) - plan = self._build_polar_plan_with_metrics( - polar_grid, coordinate_mapper, map_state, drawable_name - ) + plan = self._build_polar_plan_with_metrics(polar_grid, coordinate_mapper, map_state, drawable_name) if plan is None: self._cartesian_cache = None return None @@ -808,9 +802,7 @@ def _resolve_cartesian_plan( else: if plan_entry is not None: self._drop_plan_group(plan_entry.get("plan")) - plan = self._build_cartesian_plan_with_metrics( - cartesian, coordinate_mapper, map_state, drawable_name - ) + plan = self._build_cartesian_plan_with_metrics(cartesian, coordinate_mapper, map_state, drawable_name) if plan is None: self._cartesian_cache = None return None @@ -871,9 +863,7 @@ def _resolve_drawable_plan_context( else: if cached_entry is not None: self._drop_plan_group(cached_entry.get("plan")) - plan = self._build_drawable_plan_with_metrics( - drawable, coordinate_mapper, map_state, drawable_name - ) + plan = self._build_drawable_plan_with_metrics(drawable, coordinate_mapper, map_state, drawable_name) if plan is None: self._plan_cache.pop(cache_key, None) return None @@ -1012,5 +1002,3 @@ def _apply_plan_transform(self, plan_key: Optional[str], plan: OptimizedPrimitiv transform = getattr(plan, "get_transform", lambda: None)() if callable(set_transform): set_transform(plan_key, transform) - - diff --git a/static/client/rendering/webgl_primitive_adapter.py b/static/client/rendering/webgl_primitive_adapter.py index 25b7d592..c4558982 100644 --- a/static/client/rendering/webgl_primitive_adapter.py +++ b/static/client/rendering/webgl_primitive_adapter.py @@ -215,4 +215,3 @@ def _sample_arc( theta = start_angle_rad + step * i samples.append((center[0] + radius * math.cos(theta), center[1] + radius * math.sin(theta))) return samples - diff --git a/static/client/rendering/webgl_renderer.py b/static/client/rendering/webgl_renderer.py index 0d2d46e5..71cb3f2e 100644 --- a/static/client/rendering/webgl_renderer.py +++ b/static/client/rendering/webgl_renderer.py @@ -101,51 +101,63 @@ def register(self, cls: type, handler: Callable[[Any, Any], None]) -> None: def register_default_drawables(self) -> None: try: from drawables.point import Point as PointDrawable + self.register(PointDrawable, self._render_point) except Exception: pass try: from drawables.segment import Segment as SegmentDrawable + self.register(SegmentDrawable, self._render_segment) except Exception: pass try: from drawables.circle import Circle as CircleDrawable + self.register(CircleDrawable, self._render_circle) except Exception: pass try: from drawables.circle_arc import CircleArc as CircleArcDrawable + self.register(CircleArcDrawable, self._render_circle_arc) except Exception: pass try: from drawables.functions_bounded_colored_area import FunctionsBoundedColoredArea as FunctionsAreaDrawable + self.register(FunctionsAreaDrawable, self._render_drawable) except Exception: pass try: - from drawables.function_segment_bounded_colored_area import FunctionSegmentBoundedColoredArea as FunctionSegmentAreaDrawable + from drawables.function_segment_bounded_colored_area import ( + FunctionSegmentBoundedColoredArea as FunctionSegmentAreaDrawable, + ) + self.register(FunctionSegmentAreaDrawable, self._render_drawable) except Exception: pass try: from drawables.segments_bounded_colored_area import SegmentsBoundedColoredArea as SegmentsAreaDrawable + self.register(SegmentsAreaDrawable, self._render_drawable) except Exception: pass try: from drawables.closed_shape_colored_area import ClosedShapeColoredArea as ClosedShapeAreaDrawable + self.register(ClosedShapeAreaDrawable, self._render_drawable) except Exception: pass try: from drawables.label import Label as LabelDrawable + self.register(LabelDrawable, self._render_label) except Exception: pass try: from drawables.bar import Bar as BarDrawable + self.register(BarDrawable, self._render_drawable) except Exception: pass @@ -194,7 +206,9 @@ def _render_drawable(self, drawable: Any, coordinate_mapper: Any) -> None: # ------------------------------------------------------------------ # Drawing helpers - def _draw_points(self, points: Sequence[Tuple[float, float]], color: Tuple[float, float, float, float], size: float) -> None: + def _draw_points( + self, points: Sequence[Tuple[float, float]], color: Tuple[float, float, float, float], size: float + ) -> None: self._use_program_for_draw() flat = self._prepare_vertices(points) self._set_color_uniform(color) @@ -440,5 +454,3 @@ def _apply_canvas_style_defaults(self, canvas_el: Any) -> None: canvas_el.style.pointerEvents = "none" canvas_el.style.display = "block" canvas_el.style.zIndex = "20" - - diff --git a/static/client/result_processor.py b/static/client/result_processor.py index 53c4fea4..2634a7eb 100644 --- a/static/client/result_processor.py +++ b/static/client/result_processor.py @@ -48,7 +48,12 @@ class ResultProcessor: """ @staticmethod - def get_results(calls: List[Dict[str, Any]], available_functions: Dict[str, Any], undoable_functions: Tuple[str, ...], canvas: "Canvas") -> Dict[str, Any]: + def get_results( + calls: List[Dict[str, Any]], + available_functions: Dict[str, Any], + undoable_functions: Tuple[str, ...], + canvas: "Canvas", + ) -> Dict[str, Any]: """ Process function calls and collect their results. @@ -66,20 +71,23 @@ def get_results(calls: List[Dict[str, Any]], available_functions: Dict[str, Any] results: Dict[str, Any] = {} # Use a dictionary for results non_computation_functions: Tuple[str, ...] unformattable_functions: Tuple[str, ...] - non_computation_functions, unformattable_functions = ResultProcessor._prepare_helper_variables(undoable_functions) + non_computation_functions, unformattable_functions = ResultProcessor._prepare_helper_variables( + undoable_functions + ) # Archive once at the start and then suspend archiving while calling undoable functions - contains_undoable_function: bool = any(call.get('function_name', '') in undoable_functions for call in calls) + contains_undoable_function: bool = any(call.get("function_name", "") in undoable_functions for call in calls) if contains_undoable_function: canvas.archive() # Process each function call for call in calls: try: - ResultProcessor._process_function_call(call, available_functions, - non_computation_functions, unformattable_functions, canvas, results) + ResultProcessor._process_function_call( + call, available_functions, non_computation_functions, unformattable_functions, canvas, results + ) except Exception as e: - function_name: str = call.get('function_name', '') + function_name: str = call.get("function_name", "") ResultProcessor._handle_exception(e, function_name, results) return results @@ -111,19 +119,21 @@ def get_results_traced( traced_calls: List[TracedCall] = [] non_computation_functions: Tuple[str, ...] unformattable_functions: Tuple[str, ...] - non_computation_functions, unformattable_functions = ResultProcessor._prepare_helper_variables(undoable_functions) + non_computation_functions, unformattable_functions = ResultProcessor._prepare_helper_variables( + undoable_functions + ) # Archive once at the start (same as get_results) - contains_undoable_function: bool = any(call.get('function_name', '') in undoable_functions for call in calls) + contains_undoable_function: bool = any(call.get("function_name", "") in undoable_functions for call in calls) if contains_undoable_function: canvas.archive() for seq, call in enumerate(calls): - function_name = call.get('function_name', '') - args = call.get('arguments', {}) + function_name = call.get("function_name", "") + args = call.get("arguments", {}) # Sanitize arguments for trace: exclude canvas ref, guard against non-dict if isinstance(args, dict): - sanitized_args = {k: v for k, v in args.items() if k != 'canvas'} + sanitized_args = {k: v for k, v in args.items() if k != "canvas"} else: sanitized_args = {"_raw": args} @@ -133,8 +143,12 @@ def get_results_traced( try: snapshot_before = dict(results) ResultProcessor._process_function_call( - call, available_functions, - non_computation_functions, unformattable_functions, canvas, results, + call, + available_functions, + non_computation_functions, + unformattable_functions, + canvas, + results, ) # Extract result: find the key that was added or changed for rk, rv in results.items(): @@ -152,19 +166,23 @@ def get_results_traced( is_error = True duration_ms = window.performance.now() - t0 - traced_calls.append({ - "seq": seq, - "function_name": function_name, - "arguments": sanitized_args, - "result": result_value, - "is_error": is_error, - "duration_ms": round(duration_ms, 2), - }) + traced_calls.append( + { + "seq": seq, + "function_name": function_name, + "arguments": sanitized_args, + "result": result_value, + "is_error": is_error, + "duration_ms": round(duration_ms, 2), + } + ) return results, traced_calls @staticmethod - def _validate_inputs(calls: List[Dict[str, Any]], available_functions: Dict[str, Any], undoable_functions: Tuple[str, ...]) -> None: + def _validate_inputs( + calls: List[Dict[str, Any]], available_functions: Dict[str, Any], undoable_functions: Tuple[str, ...] + ) -> None: """Validate the input parameters.""" if not isinstance(calls, list): raise ValueError("Invalid input for calls.") @@ -176,15 +194,25 @@ def _validate_inputs(calls: List[Dict[str, Any]], available_functions: Dict[str, @staticmethod def _prepare_helper_variables(undoable_functions: Tuple[str, ...]) -> Tuple[Tuple[str, ...], Tuple[str, ...]]: """Prepare helper variables needed for processing.""" - unformattable_functions: Tuple[str, ...] = undoable_functions + ('undo', 'redo') - non_computation_functions: Tuple[str, ...] = unformattable_functions + ('run_tests', 'list_workspaces', - 'save_workspace', 'load_workspace', - 'delete_workspace') + unformattable_functions: Tuple[str, ...] = undoable_functions + ("undo", "redo") + non_computation_functions: Tuple[str, ...] = unformattable_functions + ( + "run_tests", + "list_workspaces", + "save_workspace", + "load_workspace", + "delete_workspace", + ) return non_computation_functions, unformattable_functions @staticmethod - def _process_function_call(call: Dict[str, Any], available_functions: Dict[str, Any], - non_computation_functions: Tuple[str, ...], unformattable_functions: Tuple[str, ...], canvas: "Canvas", results: Dict[str, Any]) -> None: + def _process_function_call( + call: Dict[str, Any], + available_functions: Dict[str, Any], + non_computation_functions: Tuple[str, ...], + unformattable_functions: Tuple[str, ...], + canvas: "Canvas", + results: Dict[str, Any], + ) -> None: """ Process a single function call and update results. @@ -196,25 +224,28 @@ def _process_function_call(call: Dict[str, Any], available_functions: Dict[str, canvas: Canvas instance for adding computations results: Dictionary to update with the results """ - function_name: str = call.get('function_name', '') + function_name: str = call.get("function_name", "") # Check if function exists if not ResultProcessor._is_function_available(function_name, available_functions, results): return # Execute the function - args: Dict[str, Any] = call.get('arguments', {}) + args: Dict[str, Any] = call.get("arguments", {}) result: Any = ResultProcessor._execute_function(function_name, args, available_functions) # Format the key for results dictionary key: str = ResultProcessor._generate_result_key(function_name, args) # Process the result based on function type - ResultProcessor._process_result(function_name, args, result, key, unformattable_functions, - non_computation_functions, canvas, results) + ResultProcessor._process_result( + function_name, args, result, key, unformattable_functions, non_computation_functions, canvas, results + ) @staticmethod - def _is_function_available(function_name: str, available_functions: Dict[str, Any], results: Dict[str, Any]) -> bool: + def _is_function_available( + function_name: str, available_functions: Dict[str, Any], results: Dict[str, Any] + ) -> bool: """Check if the function exists and update results if not.""" if function_name not in available_functions: error_msg: str = f"Error: function {function_name} not found." @@ -236,20 +267,30 @@ def _generate_result_key(function_name: str, args: Dict[str, Any]) -> str: return f"{function_name}({formatted_args})" @staticmethod - def _process_result(function_name: str, args: Dict[str, Any], result: Any, key: str, unformattable_functions: Tuple[str, ...], - non_computation_functions: Tuple[str, ...], canvas: "Canvas", results: Dict[str, Any]) -> None: + def _process_result( + function_name: str, + args: Dict[str, Any], + result: Any, + key: str, + unformattable_functions: Tuple[str, ...], + non_computation_functions: Tuple[str, ...], + canvas: "Canvas", + results: Dict[str, Any], + ) -> None: """Process the result based on function type and update results dictionary.""" if function_name in unformattable_functions: # Handle unformattable functions (return success message) ResultProcessor._handle_unformattable_function(key, results) - elif function_name == 'evaluate_expression' and 'expression' in args: + elif function_name == "evaluate_expression" and "expression" in args: # Handle expression evaluation - ResultProcessor._handle_expression_evaluation(args, result, function_name, - non_computation_functions, canvas, results) + ResultProcessor._handle_expression_evaluation( + args, result, function_name, non_computation_functions, canvas, results + ) else: # Handle regular functions - ResultProcessor._handle_regular_function(key, result, function_name, - non_computation_functions, canvas, results) + ResultProcessor._handle_regular_function( + key, result, function_name, non_computation_functions, canvas, results + ) @staticmethod def _handle_unformattable_function(key: str, results: Dict[str, Any]) -> None: @@ -257,7 +298,14 @@ def _handle_unformattable_function(key: str, results: Dict[str, Any]) -> None: results[key] = successful_call_message @staticmethod - def _handle_regular_function(key: str, result: Any, function_name: str, non_computation_functions: Tuple[str, ...], canvas: "Canvas", results: Dict[str, Any]) -> None: + def _handle_regular_function( + key: str, + result: Any, + function_name: str, + non_computation_functions: Tuple[str, ...], + canvas: "Canvas", + results: Dict[str, Any], + ) -> None: """Handle result for regular functions.""" # Save computation to canvas state if it's not a non-computation function # DISABLED: Saving basic calculations to canvas state (takes up too many tokens, not useful info to store) @@ -268,18 +316,17 @@ def _handle_regular_function(key: str, result: Any, function_name: str, non_comp @staticmethod def _format_arguments(args: Dict[str, Any]) -> str: """Format function arguments for display.""" - return ', '.join(f"{k}:{v}" for k, v in args.items() if k != 'canvas') + return ", ".join(f"{k}:{v}" for k, v in args.items() if k != "canvas") @staticmethod - def _add_computation_if_needed(result: Any, function_name: str, non_computation_functions: Tuple[str, ...], - expression: str, canvas: "Canvas") -> None: + def _add_computation_if_needed( + result: Any, function_name: str, non_computation_functions: Tuple[str, ...], expression: str, canvas: "Canvas" + ) -> None: """Add the computation to canvas if it's not a non-computation function and succeeded.""" - if (not isinstance(result, str) or not result.startswith("Error:")) and \ - function_name not in non_computation_functions: - canvas.add_computation( - expression=expression, - result=result - ) + if ( + not isinstance(result, str) or not result.startswith("Error:") + ) and function_name not in non_computation_functions: + canvas.add_computation(expression=expression, result=result) @staticmethod def _handle_exception(exception: Exception, function_name: str, results: Dict[str, Any]) -> None: @@ -301,7 +348,14 @@ def _handle_exception(exception: Exception, function_name: str, results: Dict[st results[key] = f"Error: {str(exception)}" @staticmethod - def _handle_expression_evaluation(args: Dict[str, Any], result: Any, function_name: str, non_computation_functions: Tuple[str, ...], canvas: "Canvas", results: Dict[str, Any]) -> None: + def _handle_expression_evaluation( + args: Dict[str, Any], + result: Any, + function_name: str, + non_computation_functions: Tuple[str, ...], + canvas: "Canvas", + results: Dict[str, Any], + ) -> None: """ Handle the special case of expression evaluation results. @@ -313,10 +367,10 @@ def _handle_expression_evaluation(args: Dict[str, Any], result: Any, function_na canvas: Canvas instance for adding computations results: Dictionary to update with the result """ - expression: str = args.get('expression', '') + expression: str = args.get("expression", "") if not expression: return - expression = expression.replace(' ', '') + expression = expression.replace(" ", "") key: str = ResultProcessor._format_expression_key(expression, args) # DISABLED: Saving expression evaluation computations to canvas state (takes up too many tokens, not useful info to store) @@ -328,11 +382,11 @@ def _handle_expression_evaluation(args: Dict[str, Any], result: Any, function_na @staticmethod def _format_expression_key(expression: str, args: Dict[str, Any]) -> str: """Format a key for expression evaluation results.""" - if 'variables' in args: - variables_dict: Any = args.get('variables', {}) + if "variables" in args: + variables_dict: Any = args.get("variables", {}) if not isinstance(variables_dict, dict): variables_dict = {} - variables: str = ', '.join(f"{k}:{v}" for k, v in variables_dict.items()) + variables: str = ", ".join(f"{k}:{v}" for k, v in variables_dict.items()) return f"{expression} for {variables}" else: return expression diff --git a/static/client/slash_command_handler.py b/static/client/slash_command_handler.py index f8762005..9034b7db 100644 --- a/static/client/slash_command_handler.py +++ b/static/client/slash_command_handler.py @@ -44,6 +44,7 @@ class CommandResult: message: Human-readable result message data: Optional additional data from the command """ + success: bool message: str data: Optional[Any] = None @@ -59,6 +60,7 @@ class CommandInfo: handler: Callable that executes the command usage: Optional usage string (e.g., "/save [name]") """ + name: str description: str handler: Callable[[List[str]], CommandResult] @@ -653,6 +655,7 @@ def _cmd_status(self, args: List[str]) -> CommandResult: def _selected_model_has_vision(self) -> bool: """Check if the currently selected AI model supports vision.""" from browser import window + try: model_selector = document["ai-model-selector"] vision_models = list(window.VISION_MODELS) diff --git a/static/client/test_runner.py b/static/client/test_runner.py index c6f9ac50..6dc5c08b 100644 --- a/static/client/test_runner.py +++ b/static/client/test_runner.py @@ -49,7 +49,10 @@ class TestRunner: internal_errors (list): Collection of test errors for analysis internal_tests_run (int): Counter of executed test cases """ - def __init__(self, canvas: "Canvas", available_functions: Dict[str, Any], undoable_functions: Tuple[str, ...]) -> None: + + def __init__( + self, canvas: "Canvas", available_functions: Dict[str, Any], undoable_functions: Tuple[str, ...] + ) -> None: """Initialize test runner with canvas and function registry access. Sets up testing environment with access to canvas operations and function validation. @@ -113,9 +116,19 @@ def _build_graphics_call_key(self, call: Dict[str, Any]) -> str: # Extract the most identifying argument for the key key_parts: List[str] = [] - for arg_name in ("name", "triangle_name", "rectangle_name", "circle_name", - "ellipse_name", "arc_name", "angle_name", "expression", - "polygon_segment_names", "drawable1_name", "function_string"): + for arg_name in ( + "name", + "triangle_name", + "rectangle_name", + "circle_name", + "ellipse_name", + "arc_name", + "angle_name", + "expression", + "polygon_segment_names", + "drawable1_name", + "function_string", + ): if arg_name in args: value = args[arg_name] key_parts.append(f"{arg_name}:{value}") @@ -141,25 +154,18 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "label_visible": True, }, }, - { - "function_name": "create_point", - "arguments": {"x": -290.0, "y": 240.0, "name": "P"} - }, + {"function_name": "create_point", "arguments": {"x": -290.0, "y": 240.0, "name": "P"}}, { "function_name": "create_vector", - "arguments": {"origin_x": -143.0, "origin_y": 376.0, "tip_x": -82.0, "tip_y": 272.0, "name": "v1"} + "arguments": {"origin_x": -143.0, "origin_y": 376.0, "tip_x": -82.0, "tip_y": 272.0, "name": "v1"}, }, { "function_name": "create_polygon", "arguments": { - "vertices": [ - {"x": -60, "y": 380}, - {"x": 60, "y": 380}, - {"x": 0, "y": 260} - ], + "vertices": [{"x": -60, "y": 380}, {"x": 60, "y": 380}, {"x": 0, "y": 260}], "polygon_type": "triangle", - "name": "A'BD" - } + "name": "A'BD", + }, }, { "function_name": "create_polygon", @@ -168,20 +174,20 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: {"x": 170.0, "y": 380.0}, {"x": 238.0, "y": 380.0}, {"x": 238.0, "y": 260.0}, - {"x": 170.0, "y": 260.0} + {"x": 170.0, "y": 260.0}, ], "polygon_type": "rectangle", "color": None, - "name": "REFT" - } + "name": "REFT", + }, }, { "function_name": "create_circle", - "arguments": {"center_x": -360, "center_y": 320, "radius": 60, "name": "G(60)"} + "arguments": {"center_x": -360, "center_y": 320, "radius": 60, "name": "G(60)"}, }, { "function_name": "create_ellipse", - "arguments": {"center_x": 360, "center_y": 320, "radius_x": 80, "radius_y": 50, "name": "I(80, 50)"} + "arguments": {"center_x": 360, "center_y": 320, "radius_x": 80, "radius_y": 50, "name": "I(80, 50)"}, }, # 12-sided polygon below the x-axis { @@ -199,33 +205,60 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: {"x": 40.0, "y": -360.0}, {"x": -20.0, "y": -350.0}, {"x": -70.0, "y": -320.0}, - {"x": -110.0, "y": -280.0} + {"x": -110.0, "y": -280.0}, ], "polygon_type": "generic", "name": "JOKLMNQSUWXY", "color": None, - "subtype": None - } + "subtype": None, + }, }, { "function_name": "draw_function", - "arguments": {"function_string": "50 * sin(x / 50)", "name": "f1", "left_bound": -300, "right_bound": 300} + "arguments": { + "function_string": "50 * sin(x / 50)", + "name": "f1", + "left_bound": -300, + "right_bound": 300, + }, }, { "function_name": "draw_function", - "arguments": {"function_string": "100 * sin(x / 30)", "name": "f2", "left_bound": -300, "right_bound": 300} + "arguments": { + "function_string": "100 * sin(x / 30)", + "name": "f2", + "left_bound": -300, + "right_bound": 300, + }, }, { "function_name": "draw_function", - "arguments": {"function_string": "100 * sin(x / 50) + 50 * tan(x / 100)", "name": "f3", "left_bound": -300, "right_bound": 300} + "arguments": { + "function_string": "100 * sin(x / 50) + 50 * tan(x / 100)", + "name": "f3", + "left_bound": -300, + "right_bound": 300, + }, }, { "function_name": "draw_function", - "arguments": {"function_string": "-1/x", "name": "f4", "left_bound": -10, "right_bound": 10, "color": "red"} + "arguments": { + "function_string": "-1/x", + "name": "f4", + "left_bound": -10, + "right_bound": 10, + "color": "red", + }, }, { "function_name": "draw_function", - "arguments": {"function_string": "tan(x)", "name": "f5", "left_bound": -5, "right_bound": 5, "color": "orange"} + "arguments": { + "function_string": "tan(x)", + "name": "f5", + "left_bound": -5, + "right_bound": 5, + "color": "orange", + }, }, # Parametric function: red spiral { @@ -235,8 +268,8 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "y_expression": "t*sin(t)", "t_min": 0, "t_max": 50, - "color": "red" - } + "color": "red", + }, }, { "function_name": "create_circle_arc", @@ -250,8 +283,8 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "radius": 60.06, "arc_name": "arc_ZB'", "color": "orange", - "use_major_arc": False - } + "use_major_arc": False, + }, }, { "function_name": "create_circle_arc", @@ -265,8 +298,8 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "radius": 109.0, "arc_name": "arc_C'D'", "color": "purple", - "use_major_arc": True - } + "use_major_arc": True, + }, }, { "function_name": "create_segment", @@ -284,76 +317,47 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: { "function_name": "create_region_colored_area", "arguments": { - "polygon_segment_names": [ - "JO", - "OK", - "KL", - "LM", - "MN", - "NQ", - "QS", - "SU", - "UW", - "WX", - "XY", - "YJ" - ], + "polygon_segment_names": ["JO", "OK", "KL", "LM", "MN", "NQ", "QS", "SU", "UW", "WX", "XY", "YJ"], "color": "plum", - "opacity": 0.3 - } + "opacity": 0.3, + }, }, { "function_name": "create_colored_area", - "arguments": {"drawable1_name": "f1", "drawable2_name": "f2", "color": "orange", "opacity": 0.3} + "arguments": {"drawable1_name": "f1", "drawable2_name": "f2", "color": "orange", "opacity": 0.3}, }, { "function_name": "create_colored_area", - "arguments": {"drawable1_name": "f2", "drawable2_name": "x_axis", "color": "lightgreen", "opacity": 0.3} + "arguments": { + "drawable1_name": "f2", + "drawable2_name": "x_axis", + "color": "lightgreen", + "opacity": 0.3, + }, }, { "function_name": "create_colored_area", - "arguments": {"drawable1_name": "f1", "drawable2_name": "f3", "color": "lightblue", "opacity": 0.3} + "arguments": {"drawable1_name": "f1", "drawable2_name": "f3", "color": "lightblue", "opacity": 0.3}, }, { "function_name": "create_colored_area", - "arguments": { - "drawable1_name": "f3", - "drawable2_name": "CH", - "color": "lightgray", - "opacity": 0.25 - } + "arguments": {"drawable1_name": "f3", "drawable2_name": "CH", "color": "lightgray", "opacity": 0.25}, }, { "function_name": "create_region_colored_area", - "arguments": { - "triangle_name": "A'BD", - "color": "orange", - "opacity": 0.4 - } + "arguments": {"triangle_name": "A'BD", "color": "orange", "opacity": 0.4}, }, { "function_name": "create_region_colored_area", - "arguments": { - "circle_name": "G(60)", - "color": "red", - "opacity": 0.35 - } + "arguments": {"circle_name": "G(60)", "color": "red", "opacity": 0.35}, }, { "function_name": "create_region_colored_area", - "arguments": { - "rectangle_name": "REFT", - "color": "green", - "opacity": 0.35 - } + "arguments": {"rectangle_name": "REFT", "color": "green", "opacity": 0.35}, }, { "function_name": "create_region_colored_area", - "arguments": { - "ellipse_name": "I(80, 50)", - "color": "blue", - "opacity": 0.35 - } + "arguments": {"ellipse_name": "I(80, 50)", "color": "blue", "opacity": 0.35}, }, { "function_name": "create_angle", @@ -366,8 +370,8 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "p2y": 230.0, "color": "red", "angle_name": "Angle1", - "is_reflex": True - } + "is_reflex": True, + }, }, { "function_name": "create_angle", @@ -380,8 +384,8 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "p2y": 219.3, "color": "blue", "angle_name": "Angle2", - "is_reflex": False - } + "is_reflex": False, + }, }, { "function_name": "create_angle", @@ -394,8 +398,8 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "p2y": 150.3, "color": "green", "angle_name": "Angle3", - "is_reflex": False - } + "is_reflex": False, + }, }, { "function_name": "create_label", @@ -423,11 +427,7 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: }, { "function_name": "create_region_colored_area", - "arguments": { - "expression": "ArcMaj_C'D' & A''E'", - "color": "#D2B48C", - "opacity": 0.5 - } + "arguments": {"expression": "ArcMaj_C'D' & A''E'", "color": "#D2B48C", "opacity": 0.5}, }, { "function_name": "update_segment", @@ -447,12 +447,7 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: "directed": True, "root": "R", "layout": "tree", - "placement_box": { - "x": -500, - "y": -350, - "width": 300, - "height": 325 - }, + "placement_box": {"x": -500, "y": -350, "width": 300, "height": 325}, "vertices": [ {"name": "R", "x": None, "y": None, "color": None, "label": None}, {"name": "L1", "x": None, "y": None, "color": None, "label": None}, @@ -460,7 +455,7 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: {"name": "L1A", "x": None, "y": None, "color": None, "label": None}, {"name": "L1B", "x": None, "y": None, "color": None, "label": None}, {"name": "L2A", "x": None, "y": None, "color": None, "label": None}, - {"name": "L2B", "x": None, "y": None, "color": None, "label": None} + {"name": "L2B", "x": None, "y": None, "color": None, "label": None}, ], "edges": [ {"source": 0, "target": 1, "weight": None, "name": "R_L1", "color": None, "directed": True}, @@ -468,27 +463,15 @@ def _get_graphics_test_function_calls(self) -> List[Dict[str, Any]]: {"source": 1, "target": 3, "weight": None, "name": "L1_L1A", "color": None, "directed": True}, {"source": 1, "target": 4, "weight": None, "name": "L1_L1B", "color": None, "directed": True}, {"source": 2, "target": 5, "weight": None, "name": "L2_L2A", "color": None, "directed": True}, - {"source": 2, "target": 6, "weight": None, "name": "L2_L2B", "color": None, "directed": True} + {"source": 2, "target": 6, "weight": None, "name": "L2_L2B", "color": None, "directed": True}, ], - "adjacency_matrix": None - } - }, - { - "function_name": "clear_canvas", - "arguments": {} - }, - { - "function_name": "undo", - "arguments": {} - }, - { - "function_name": "redo", - "arguments": {} + "adjacency_matrix": None, + }, }, - { - "function_name": "undo", - "arguments": {} - } + {"function_name": "clear_canvas", "arguments": {}}, + {"function_name": "undo", "arguments": {}}, + {"function_name": "redo", "arguments": {}}, + {"function_name": "undo", "arguments": {}}, ] def _test_undoable_functions(self) -> bool: @@ -496,13 +479,13 @@ def _test_undoable_functions(self) -> bool: self.internal_tests_run += 1 try: self._validate_undoable_functions() - print("All undoable functions are available.") # DEBUG + print("All undoable functions are available.") # DEBUG return True except Exception as e: error_message: str = f"Error in undoable functions test: {str(e)}" print(error_message) - if not any(failure['test'] == 'Undoable Functions Test' for failure in self.internal_failures): - self._add_internal_failure('Undoable Functions Test', error_message) + if not any(failure["test"] == "Undoable Functions Test" for failure in self.internal_failures): + self._add_internal_failure("Undoable Functions Test", error_message) return False def _validate_undoable_functions(self) -> None: @@ -510,22 +493,16 @@ def _validate_undoable_functions(self) -> None: for function_name in self.undoable_functions: if function_name not in self.available_functions: error_message: str = f"Function '{function_name}' is not available." - self._add_internal_failure('Undoable Functions Test', error_message) + self._add_internal_failure("Undoable Functions Test", error_message) raise Exception(error_message) def _add_internal_failure(self, test_name: str, error_message: str) -> None: """Add a failure to the internal failures list.""" - self.internal_failures.append({ - 'test': test_name, - 'error': error_message - }) + self.internal_failures.append({"test": test_name, "error": error_message}) def _add_internal_error(self, test_name: str, error_message: str) -> None: """Add an error to the internal errors list.""" - self.internal_errors.append({ - 'test': test_name, - 'error': error_message - }) + self.internal_errors.append({"test": test_name, "error": error_message}) def run_tests(self) -> Dict[str, Any]: """Run unit tests for the graphics and function capabilities.""" @@ -612,6 +589,7 @@ def _run_client_tests(self) -> Dict[str, Any]: """Run the client-side tests and return the results.""" try: from client_tests.tests import run_tests + return cast(Dict[str, Any], run_tests()) except ImportError as e: print(f"client_tests import failed: {e}") @@ -625,6 +603,7 @@ async def _run_client_tests_async( """Run the client-side tests asynchronously and return the results.""" try: from client_tests.tests import run_tests_async + return cast(Dict[str, Any], await run_tests_async(should_stop=should_stop)) except ImportError as e: print(f"client_tests import failed: {e}") @@ -639,43 +618,43 @@ def _merge_test_results(self, client_results: Dict[str, Any]) -> Dict[str, Any]: merged_results: Dict[str, Any] = client_results.copy() # Add internal failures and errors to client test results - merged_results['failures'].extend(self.internal_failures) - merged_results['errors'].extend(self.internal_errors) + merged_results["failures"].extend(self.internal_failures) + merged_results["errors"].extend(self.internal_errors) # Update summary - merged_results['summary']['tests'] += self.internal_tests_run - merged_results['summary']['failures'] += len(self.internal_failures) - merged_results['summary']['errors'] += len(self.internal_errors) + merged_results["summary"]["tests"] += self.internal_tests_run + merged_results["summary"]["failures"] += len(self.internal_failures) + merged_results["summary"]["errors"] += len(self.internal_errors) return merged_results def _create_results_from_internal_only(self) -> Dict[str, Any]: """Create test results containing only internal test results.""" return { - 'failures': self.internal_failures, - 'errors': self.internal_errors, - 'summary': { - 'tests': self.internal_tests_run, - 'failures': len(self.internal_failures), - 'errors': len(self.internal_errors) - } + "failures": self.internal_failures, + "errors": self.internal_errors, + "summary": { + "tests": self.internal_tests_run, + "failures": len(self.internal_failures), + "errors": len(self.internal_errors), + }, } def _create_results_with_client_error(self, error_message: str) -> Dict[str, Any]: """Create test results with internal results plus a client test runner error.""" client_error: Dict[str, str] = { - 'test': 'Client Tests Runner', - 'error': f"Error running client tests: {error_message}" + "test": "Client Tests Runner", + "error": f"Error running client tests: {error_message}", } return { - 'failures': self.internal_failures, - 'errors': self.internal_errors + [client_error], - 'summary': { - 'tests': self.internal_tests_run, - 'failures': len(self.internal_failures), - 'errors': len(self.internal_errors) + 1 - } + "failures": self.internal_failures, + "errors": self.internal_errors + [client_error], + "summary": { + "tests": self.internal_tests_run, + "failures": len(self.internal_failures), + "errors": len(self.internal_errors) + 1, + }, } def get_test_results(self) -> Optional[Dict[str, Any]]: @@ -702,30 +681,24 @@ def format_results_for_ai(self, results: Dict[str, Any]) -> Dict[str, Any]: def _create_formatted_results_summary(self, results: Dict[str, Any]) -> Dict[str, Any]: """Create a basic summary of test results for AI consumption.""" return { - "tests_run": results['summary']['tests'], - "failures": results['summary']['failures'], - "errors": results['summary']['errors'], + "tests_run": results["summary"]["tests"], + "failures": results["summary"]["failures"], + "errors": results["summary"]["errors"], "failing_tests": [], - "error_tests": [] + "error_tests": [], } def _add_formatted_failure_details(self, formatted_results: Dict[str, Any], results: Dict[str, Any]) -> None: """Add failure details to the formatted results.""" - if results['failures']: - for failure in results['failures']: - formatted_results["failing_tests"].append({ - "test": failure['test'], - "error": failure['error'] - }) + if results["failures"]: + for failure in results["failures"]: + formatted_results["failing_tests"].append({"test": failure["test"], "error": failure["error"]}) def _add_formatted_error_details(self, formatted_results: Dict[str, Any], results: Dict[str, Any]) -> None: """Add error details to the formatted results.""" - if results['errors']: - for error in results['errors']: - formatted_results["error_tests"].append({ - "test": error['test'], - "error": error['error'] - }) + if results["errors"]: + for error in results["errors"]: + formatted_results["error_tests"].append({"test": error["test"], "error": error["error"]}) def _log_test_results_to_console(self, formatted_results: Dict[str, Any]) -> None: """Log detailed test results to the console for debugging.""" @@ -734,14 +707,14 @@ def _log_test_results_to_console(self, formatted_results: Dict[str, Any]) -> Non print(f"Failures: {formatted_results['failures']}") print(f"Errors: {formatted_results['errors']}") - if formatted_results['failing_tests']: + if formatted_results["failing_tests"]: print("\nFAILURES:") - for i, failure in enumerate(formatted_results['failing_tests'], 1): + for i, failure in enumerate(formatted_results["failing_tests"], 1): print(f"{i}. {failure['test']}: {failure['error']}") - if formatted_results['error_tests']: + if formatted_results["error_tests"]: print("\nERRORS:") - for i, error in enumerate(formatted_results['error_tests'], 1): + for i, error in enumerate(formatted_results["error_tests"], 1): print(f"{i}. {error['test']}: {error['error']}") print("===============================================================") diff --git a/static/client/tts_controller.py b/static/client/tts_controller.py index 99c45ee0..68ade466 100644 --- a/static/client/tts_controller.py +++ b/static/client/tts_controller.py @@ -144,17 +144,20 @@ def speak(self, text: str, voice: Optional[str] = None) -> None: # Make request to TTS endpoint try: req = ajax.ajax() - req.bind('complete', self._on_request_complete) - req.bind('error', self._on_request_error) - req.responseType = 'blob' - req.open('POST', '/api/tts', True) - req.set_header('Content-Type', 'application/json') + req.bind("complete", self._on_request_complete) + req.bind("error", self._on_request_error) + req.responseType = "blob" + req.open("POST", "/api/tts", True) + req.set_header("Content-Type", "application/json") import json - payload = json.dumps({ - 'text': text, - 'voice': voice, - }) + + payload = json.dumps( + { + "text": text, + "voice": voice, + } + ) req.send(payload) except Exception as e: @@ -187,7 +190,7 @@ def _extract_error_message(self, req: Any) -> str: Returns: Human-readable error message """ - status = getattr(req, 'status', 0) + status = getattr(req, "status", 0) # Handle common HTTP status codes with helpful messages if status == 503: @@ -200,12 +203,13 @@ def _extract_error_message(self, req: Any) -> str: # Try to parse JSON error response try: # Try responseText first (works better for non-blob responses) - response_text = getattr(req, 'responseText', None) or getattr(req, 'text', None) + response_text = getattr(req, "responseText", None) or getattr(req, "text", None) if response_text: import json + data = json.loads(response_text) if isinstance(data, dict): - msg = data.get('message', '') + msg = data.get("message", "") if msg: return f"TTS error: {msg}" except Exception: @@ -242,8 +246,8 @@ def on_error(event: Any) -> None: self._cleanup_audio() self._set_state("idle") - audio.addEventListener('ended', on_ended) - audio.addEventListener('error', on_error) + audio.addEventListener("ended", on_ended) + audio.addEventListener("error", on_error) # Store reference and play self._audio = audio @@ -259,7 +263,7 @@ def on_error(event: Any) -> None: def _cleanup_audio(self) -> None: """Clean up audio resources.""" try: - if hasattr(self, '_audio_url') and self._audio_url: + if hasattr(self, "_audio_url") and self._audio_url: window.URL.revokeObjectURL(self._audio_url) self._audio_url = None except Exception: diff --git a/static/client/typing/browser/__init__.pyi b/static/client/typing/browser/__init__.pyi new file mode 100644 index 00000000..fb4122e9 --- /dev/null +++ b/static/client/typing/browser/__init__.pyi @@ -0,0 +1,214 @@ +"""Type stubs for Brython's browser module. + +Provides MyPy-compatible type declarations for the top-level +``browser`` package used in MatHud's client-side Brython code. +Covers ``document``, ``window``, ``html``, ``svg``, ``console``, +and re-exports the ``ajax`` and ``aio`` submodules. +""" + +from __future__ import annotations + +from typing import Any, Callable + +from browser._dom import ClassList as ClassList +from browser._dom import DOMNode as DOMNode +from browser import ajax as ajax +from browser import aio as aio + +# --------------------------------------------------------------------------- +# Helper classes for typed Window sub-objects +# --------------------------------------------------------------------------- + +class _JSON: + def stringify(self, obj: Any) -> str: ... + def parse(self, s: str) -> Any: ... + +class _LocalStorage: + def getItem(self, key: str) -> str | None: ... + def setItem(self, key: str, value: str) -> None: ... + def removeItem(self, key: str) -> None: ... + +class _Performance: + def now(self) -> float: ... + +class _MathJS: + def evaluate(self, expr: str, scope: Any = ...) -> Any: ... + def format(self, val: Any) -> str: ... + def sqrt(self, x: Any) -> Any: ... + def pow(self, x: Any, exp: Any) -> Any: ... + def det(self, matrix: Any) -> Any: ... + def typeOf(self, val: Any) -> str: ... + def number(self, val: Any) -> Any: ... + def matrix(self, data: Any) -> Any: ... + +class _NerdamerExpr: + def text(self) -> str: ... + def evaluate(self) -> _NerdamerExpr: ... + def sub(self, var: str, val: Any) -> _NerdamerExpr: ... + def coeffs(self, var: str) -> _NerdamerExpr: ... + +class _Nerdamer: + def __call__(self, expr: str) -> _NerdamerExpr: ... + def solveEquations(self, eqs: Any) -> Any: ... + def coeffs(self, expr: Any, var: str) -> Any: ... + +class _Date: + def now(self) -> int: ... + def new(self) -> Any: ... + +class _URL: + def createObjectURL(self, blob: Any) -> str: ... + def revokeObjectURL(self, url: str) -> None: ... + +# --------------------------------------------------------------------------- +# Document +# --------------------------------------------------------------------------- + +class Document: + """The browser ``document`` object. + + Not a subclass of ``DOMNode`` to avoid Liskov substitution + conflicts with differing method signatures. + """ + + def getElementById(self, id: str) -> DOMNode | None: ... + def querySelector(self, selector: str) -> DOMNode | None: ... + def select(self, selector: str) -> list[DOMNode]: ... + def select_one(self, selector: str) -> DOMNode | None: ... + def createDocumentFragment(self) -> DOMNode: ... + def execCommand(self, command: str) -> bool: ... + def bind(self, event: str, handler: Callable[..., Any]) -> None: ... + def __getitem__(self, key: str) -> DOMNode: ... + def __contains__(self, key: str) -> bool: ... + def __le__(self, other: Any) -> None: ... + +# --------------------------------------------------------------------------- +# Window +# --------------------------------------------------------------------------- + +class Window: + """The browser ``window`` object. + + Typed sub-objects provide precision for commonly used APIs. + The ``__getattr__`` / ``__setattr__`` escape hatches cover + the many dynamic properties (MatHud custom globals, test + monkey-patching, etc.). + """ + + # Typed sub-objects + JSON: _JSON + localStorage: _LocalStorage + performance: _Performance + math: _MathJS + nerdamer: _Nerdamer + Date: _Date + URL: _URL + + # Direct methods + def setTimeout(self, callback: Callable[..., Any], delay: int) -> int: ... + def clearTimeout(self, timer_id: int) -> None: ... + def requestAnimationFrame(self, callback: Callable[..., Any]) -> int: ... + + # Constructor-like attributes + Audio: Any + Float32Array: Any + FileReader: Any + MouseEvent: Any + MathJax: Any + Math: Any + document: Any + + # Escape hatches for dynamic properties + def __getattr__(self, name: str) -> Any: ... + def __setattr__(self, name: str, value: Any) -> None: ... + def __getitem__(self, key: str) -> Any: ... + def __setitem__(self, key: str, value: Any) -> None: ... + +# --------------------------------------------------------------------------- +# HTMLFactory +# --------------------------------------------------------------------------- + +class HTMLFactory: + """Factory for creating HTML elements via ``html.DIV(...)``, etc.""" + + def BUTTON( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def CANVAS( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def DETAILS( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def DIV( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def H3( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def IMG( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def INPUT( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def LABEL( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def OPTION( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def P(self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any) -> DOMNode: ... + def SELECT( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def SPAN( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def SUMMARY( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def TEXTAREA( + self, content: Any = ..., *, Class: str = ..., id: str = ..., style: Any = ..., **kwargs: Any + ) -> DOMNode: ... + def __getattr__(self, name: str) -> Callable[..., DOMNode]: ... + +# --------------------------------------------------------------------------- +# SVGFactory +# --------------------------------------------------------------------------- + +class SVGFactory: + """Factory for creating SVG elements via ``svg.line(...)``, etc.""" + + def svg(self, **kwargs: Any) -> DOMNode: ... + def g(self, **kwargs: Any) -> DOMNode: ... + def line(self, **kwargs: Any) -> DOMNode: ... + def path(self, **kwargs: Any) -> DOMNode: ... + def circle(self, **kwargs: Any) -> DOMNode: ... + def ellipse(self, **kwargs: Any) -> DOMNode: ... + def polygon(self, **kwargs: Any) -> DOMNode: ... + def text(self, **kwargs: Any) -> DOMNode: ... + def __getattr__(self, name: str) -> Callable[..., DOMNode]: ... + +# --------------------------------------------------------------------------- +# Console +# --------------------------------------------------------------------------- + +class Console: + """The browser ``console`` object.""" + + def log(self, *args: Any) -> None: ... + def error(self, *args: Any) -> None: ... + def warn(self, *args: Any) -> None: ... + def groupCollapsed(self, label: str) -> None: ... + def groupEnd(self) -> None: ... + +# --------------------------------------------------------------------------- +# Module-level instances +# --------------------------------------------------------------------------- + +document: Document +window: Window +html: HTMLFactory +svg: SVGFactory +console: Console diff --git a/static/client/typing/browser/_dom.pyi b/static/client/typing/browser/_dom.pyi new file mode 100644 index 00000000..d4bf2ee0 --- /dev/null +++ b/static/client/typing/browser/_dom.pyi @@ -0,0 +1,70 @@ +"""Type stubs for Brython browser DOM types. + +Provides MyPy-compatible declarations for DOMNode and ClassList, +the shared base types used across all browser module stubs. +""" + +from __future__ import annotations + +from typing import Any, Callable + +class ClassList: + """Proxy for the DOM element's classList property.""" + + def add(self, token: str) -> None: ... + def remove(self, token: str) -> None: ... + def contains(self, token: str) -> bool: ... + +class DOMNode: + """Base DOM node type used throughout Brython's browser module. + + Represents an HTML or SVG element in the browser DOM. Brython + maps Python attribute access and operators onto the underlying + JavaScript DOM API. + """ + + # --- Properties --- + innerHTML: str + text: str + value: Any + disabled: bool + checked: bool + options: Any + scrollTop: float + scrollHeight: float + style: Any + classList: ClassList + attrs: dict[str, str] + parentNode: DOMNode | None + children: Any + firstChild: DOMNode | None + width: Any + height: Any + onload: Callable[..., Any] | None + result: Any + responseType: str + + # --- Methods --- + def getBoundingClientRect(self) -> Any: ... + def appendChild(self, child: DOMNode) -> DOMNode: ... + def removeChild(self, child: DOMNode) -> DOMNode: ... + def setAttribute(self, name: str, value: str) -> None: ... + def getAttribute(self, name: str) -> str | None: ... + def removeAttribute(self, name: str) -> None: ... + def insertBefore(self, new_child: DOMNode, ref_child: DOMNode | None) -> DOMNode: ... + def cloneNode(self, deep: bool = ...) -> DOMNode: ... + def bind(self, event: str, handler: Callable[..., Any]) -> None: ... + def focus(self) -> None: ... + def blur(self) -> None: ... + def click(self) -> None: ... + def clear(self) -> None: ... + def remove(self) -> None: ... + def getContext(self, context_type: str) -> Any: ... + def select(self, selector: str) -> list[DOMNode]: ... + def select_one(self, selector: str) -> DOMNode | None: ... + + # --- Operators --- + def __getitem__(self, key: str) -> Any: ... + def __setitem__(self, key: str, value: Any) -> None: ... + def __le__(self, other: Any) -> None: ... # Brython DOM append via <= + def __contains__(self, key: Any) -> bool: ... diff --git a/static/client/typing/browser/aio.pyi b/static/client/typing/browser/aio.pyi new file mode 100644 index 00000000..9745c709 --- /dev/null +++ b/static/client/typing/browser/aio.pyi @@ -0,0 +1,12 @@ +"""Type stubs for Brython's browser.aio module. + +Provides async utilities for scheduling coroutines and sleeping +within the Brython event loop. +""" + +from __future__ import annotations + +from typing import Any, Awaitable, Coroutine + +def run(coroutine: Coroutine[Any, Any, Any]) -> None: ... +def sleep(seconds: float) -> Awaitable[None]: ... diff --git a/static/client/typing/browser/ajax.pyi b/static/client/typing/browser/ajax.pyi new file mode 100644 index 00000000..0ce12b88 --- /dev/null +++ b/static/client/typing/browser/ajax.pyi @@ -0,0 +1,46 @@ +"""Type stubs for Brython's browser.ajax module. + +Provides HTTP request functionality for making AJAX calls from +the browser. Both ``ajax()`` (lowercase function) and ``Ajax()`` +(class constructor) return request objects with the same interface. +""" + +from __future__ import annotations + +from typing import Any, Callable + +class AjaxRequest: + """Object returned by ``ajax()`` — wraps a browser XMLHttpRequest.""" + + status: int + text: str + response: Any + responseType: str + + def bind(self, event: str, handler: Callable[..., Any]) -> None: ... + def open(self, method: str, url: str, async_: bool = ...) -> None: ... + def set_header(self, name: str, value: str) -> None: ... + def send(self, data: str | None = ...) -> None: ... + +class Ajax: + """Alternate constructor form used as ``ajax.Ajax()``.""" + + status: int + text: str + response: Any + responseType: str + + def bind(self, event: str, handler: Callable[..., Any]) -> None: ... + def open(self, method: str, url: str, async_: bool = ...) -> None: ... + def set_header(self, name: str, value: str) -> None: ... + def send(self, data: str | None = ...) -> None: ... + +def ajax(timeout: int = ...) -> AjaxRequest: ... +def post( + url: str, + *, + data: str = ..., + headers: dict[str, str] = ..., + oncomplete: Callable[..., Any] = ..., + onerror: Callable[..., Any] = ..., +) -> None: ... diff --git a/static/client/utils/area_expression_evaluator.py b/static/client/utils/area_expression_evaluator.py index e6a8ddb4..a67a4425 100644 --- a/static/client/utils/area_expression_evaluator.py +++ b/static/client/utils/area_expression_evaluator.py @@ -41,7 +41,10 @@ class _RegionWithSource: """Wrapper to track the source drawable for special handling.""" def __init__( - self, region: Region, source_type: str, source_drawable: Optional["Drawable"] = None + self, + region: Optional[Region], + source_type: str, + source_drawable: Optional["Drawable"] = None, ) -> None: self.region = region self.source_type = source_type # "arc", "segment", "polygon", "circle", "ellipse" @@ -71,6 +74,7 @@ def to_dict(self) -> Dict[str, Any]: class _ASTNode: """Base class for AST nodes.""" + pass @@ -337,9 +341,7 @@ def _evaluate_ast( if isinstance(node, _BinaryOpNode): left_result = AreaExpressionEvaluator._evaluate_ast(node.left, canvas) right_result = AreaExpressionEvaluator._evaluate_ast(node.right, canvas) - return AreaExpressionEvaluator._apply_operation( - left_result, right_result, node.op - ) + return AreaExpressionEvaluator._apply_operation(left_result, right_result, node.op) raise ValueError(f"Unknown AST node type: {type(node)}") @@ -377,7 +379,7 @@ def _drawable_to_region_with_source(drawable: "Drawable") -> _RegionWithSource: return _RegionWithSource(region, "arc", drawable) if class_name == "Segment": - return _RegionWithSource(None, "segment", drawable) # type: ignore + return _RegionWithSource(None, "segment", drawable) if hasattr(drawable, "get_segments"): region = Region.from_polygon(drawable) @@ -493,9 +495,7 @@ def _region_from_segments(segments: list) -> Region: raise ValueError("Not enough points to form a region") @staticmethod - def _normalize_to_single( - result: Union[_RegionWithSource, Region, List[Region], None] - ) -> Optional[Region]: + def _normalize_to_single(result: Union[_RegionWithSource, Region, List[Region], None]) -> Optional[Region]: """Convert a result to a single Region or None.""" if result is None: return None @@ -503,6 +503,8 @@ def _normalize_to_single( # For segments without a pre-computed region, create a half-plane if result.region is None and result.source_type == "segment": seg = result.source_drawable + if seg is None: + return None p1 = (seg.point1.x, seg.point1.y) p2 = (seg.point2.x, seg.point2.y) return Region.from_half_plane(p1, p2) @@ -536,9 +538,7 @@ def _apply_operation( # Handle segment intersection with shape specially if op == "&": - result = AreaExpressionEvaluator._handle_segment_intersection( - left_source, right_source - ) + result = AreaExpressionEvaluator._handle_segment_intersection(left_source, right_source) if result is not None: return result @@ -602,6 +602,8 @@ def _handle_segment_intersection( segment = segment_source.source_drawable shape = shape_source.source_drawable + if segment is None or shape is None: + return None if shape_source.source_type == "arc": return AreaExpressionEvaluator._arc_segment_enclosed_region(shape, segment) @@ -612,14 +614,15 @@ def _handle_segment_intersection( p1 = (segment.point1.x, segment.point1.y) p2 = (segment.point2.x, segment.point2.y) half_plane = Region.from_half_plane(p1, p2) + if shape_source.region is None: + return None return shape_source.region.intersection(half_plane) return None @staticmethod def _line_circle_intersections( - x1: float, y1: float, x2: float, y2: float, - cx: float, cy: float, radius: float + x1: float, y1: float, x2: float, y2: float, cx: float, cy: float, radius: float ) -> List[Dict[str, float]]: """Find intersections between a line (infinite) and a circle. @@ -657,9 +660,7 @@ def _line_circle_intersections( return intersections @staticmethod - def _arc_segment_enclosed_region( - arc: "Drawable", segment: "Drawable" - ) -> Optional[Region]: + def _arc_segment_enclosed_region(arc: "Drawable", segment: "Drawable") -> Optional[Region]: """Create the enclosed region bounded by an arc curve and a segment. Finds intersection points between the segment line and the arc, @@ -692,9 +693,7 @@ def _arc_segment_enclosed_region( # Find intersections between the segment LINE (extended) and the circle intersections = AreaExpressionEvaluator._line_circle_intersections( - segment.point1.x, segment.point1.y, - segment.point2.x, segment.point2.y, - center[0], center[1], radius + segment.point1.x, segment.point1.y, segment.point2.x, segment.point2.y, center[0], center[1], radius ) if len(intersections) < 2: @@ -769,9 +768,7 @@ def angle_on_arc(angle: float) -> bool: return Region.from_points(points) @staticmethod - def _circle_segment_enclosed_region( - circle: "Drawable", segment: "Drawable" - ) -> Optional[Region]: + def _circle_segment_enclosed_region(circle: "Drawable", segment: "Drawable") -> Optional[Region]: """Create the enclosed region (circular segment) cut by a segment from a circle. Creates the smaller circular segment (minor segment) on the side @@ -781,9 +778,7 @@ def _circle_segment_enclosed_region( radius = circle.radius intersections = AreaExpressionEvaluator._line_circle_intersections( - segment.point1.x, segment.point1.y, - segment.point2.x, segment.point2.y, - center[0], center[1], radius + segment.point1.x, segment.point1.y, segment.point2.x, segment.point2.y, center[0], center[1], radius ) if len(intersections) < 2: @@ -836,5 +831,3 @@ def _circle_segment_enclosed_region( return None return Region.from_points(points) - - diff --git a/static/client/utils/canonicalizers/__init__.py b/static/client/utils/canonicalizers/__init__.py index 8c8ec5e6..06b8e841 100644 --- a/static/client/utils/canonicalizers/__init__.py +++ b/static/client/utils/canonicalizers/__init__.py @@ -40,6 +40,3 @@ "canonicalize_rectangle", "canonicalize_triangle", ] - - - diff --git a/static/client/utils/canonicalizers/common.py b/static/client/utils/canonicalizers/common.py index 6e912b89..e71e42de 100644 --- a/static/client/utils/canonicalizers/common.py +++ b/static/client/utils/canonicalizers/common.py @@ -73,6 +73,3 @@ def nearest_point(points: Iterable[PointTuple], candidate: PointTuple) -> Option "contains_point", "nearest_point", ] - - - diff --git a/static/client/utils/canonicalizers/quadrilateral.py b/static/client/utils/canonicalizers/quadrilateral.py index d6bbc19f..3edb6b00 100644 --- a/static/client/utils/canonicalizers/quadrilateral.py +++ b/static/client/utils/canonicalizers/quadrilateral.py @@ -406,20 +406,12 @@ def _synthesize_rectangle( for orientation in (1.0, -1.0): v_actual_unit = (v_actual_unit_base[0] * orientation, v_actual_unit_base[1] * orientation) b_corner = ( - primary_anchor[0] - + u_actual_unit[0] * (proj_b_u * scale) - + v_actual_unit[0] * (proj_b_v * scale), - primary_anchor[1] - + u_actual_unit[1] * (proj_b_u * scale) - + v_actual_unit[1] * (proj_b_v * scale), + primary_anchor[0] + u_actual_unit[0] * (proj_b_u * scale) + v_actual_unit[0] * (proj_b_v * scale), + primary_anchor[1] + u_actual_unit[1] * (proj_b_u * scale) + v_actual_unit[1] * (proj_b_v * scale), ) d_corner = ( - primary_anchor[0] - + u_actual_unit[0] * (proj_d_u * scale) - + v_actual_unit[0] * (proj_d_v * scale), - primary_anchor[1] - + u_actual_unit[1] * (proj_d_u * scale) - + v_actual_unit[1] * (proj_d_v * scale), + primary_anchor[0] + u_actual_unit[0] * (proj_d_u * scale) + v_actual_unit[0] * (proj_d_v * scale), + primary_anchor[1] + u_actual_unit[1] * (proj_d_u * scale) + v_actual_unit[1] * (proj_d_v * scale), ) candidate = [primary_anchor, b_corner, opposite_anchor, d_corner] error = QuadrilateralCanonicalizer._rectangle_fit_error(candidate, source_vertices) @@ -726,9 +718,7 @@ def _signed_area(points: Sequence[PointTuple]) -> float: def _ensure_non_degenerate(points: Sequence[PointTuple], tolerance: float) -> None: area = abs(QuadrilateralCanonicalizer._signed_area(points)) if area <= tolerance: - raise PolygonCanonicalizationError( - "Provided vertices collapse to a line; cannot form a quadrilateral." - ) + raise PolygonCanonicalizationError("Provided vertices collapse to a line; cannot form a quadrilateral.") @staticmethod def _average_side_length(points: Sequence[PointTuple]) -> float: @@ -752,11 +742,7 @@ def _align_to_original_order( start_index = distances.index(min(distances)) aligned = list(vertices[start_index:]) + list(vertices[:start_index]) - if ( - QuadrilateralCanonicalizer._signed_area(aligned) - * QuadrilateralCanonicalizer._signed_area(original) - < 0 - ): + if QuadrilateralCanonicalizer._signed_area(aligned) * QuadrilateralCanonicalizer._signed_area(original) < 0: aligned = [aligned[0]] + list(reversed(aligned[1:])) return aligned diff --git a/static/client/utils/canonicalizers/triangle.py b/static/client/utils/canonicalizers/triangle.py index 68461333..2023a7ed 100644 --- a/static/client/utils/canonicalizers/triangle.py +++ b/static/client/utils/canonicalizers/triangle.py @@ -222,6 +222,7 @@ def _dedupe_vertices(points: Sequence[PointTuple], tolerance: float) -> List[Poi @staticmethod def _order_ccw(points: Sequence[PointTuple]) -> List[PointTuple]: centroid = TriangleCanonicalizer._compute_centroid(points) + def angle(point: PointTuple) -> float: return math.atan2(point[1] - centroid[1], point[0] - centroid[0]) @@ -318,6 +319,3 @@ def canonicalize_triangle( __all__ = ["TriangleCanonicalizer", "canonicalize_triangle"] - - - diff --git a/static/client/utils/computation_utils.py b/static/client/utils/computation_utils.py index 3bbb5a8f..5eef3143 100644 --- a/static/client/utils/computation_utils.py +++ b/static/client/utils/computation_utils.py @@ -37,6 +37,7 @@ class ComputationUtils: Provides static methods for managing computation history, detecting duplicates, and maintaining a record of mathematical calculations performed in the canvas. """ + @staticmethod def has_computation(computations: List[Dict[str, Any]], expression: str) -> bool: """ @@ -65,8 +66,5 @@ def add_computation(computations: List[Dict[str, Any]], expression: str, result: list: Updated list of computations """ if not ComputationUtils.has_computation(computations, expression): - computations.append({ - "expression": expression, - "result": result - }) + computations.append({"expression": expression, "result": result}) return computations diff --git a/static/client/utils/geometry_utils.py b/static/client/utils/geometry_utils.py index 7fe4f32c..53958974 100644 --- a/static/client/utils/geometry_utils.py +++ b/static/client/utils/geometry_utils.py @@ -494,9 +494,8 @@ def quadrilateral_type_flags(points: Sequence[PointLike]) -> Dict[str, bool]: angles = GeometryUtils._polygon_internal_angles(points) all_sides_equal = GeometryUtils._all_close(side_lengths) - opposite_sides_equal = ( - GeometryUtils._is_close(side_lengths[0], side_lengths[2]) - and GeometryUtils._is_close(side_lengths[1], side_lengths[3]) + opposite_sides_equal = GeometryUtils._is_close(side_lengths[0], side_lengths[2]) and GeometryUtils._is_close( + side_lengths[1], side_lengths[3] ) right_angles = all(GeometryUtils._is_close(angle, 90.0) for angle in angles) @@ -686,21 +685,12 @@ def is_irregular_polygon_from_segments(segments: List["Segment"]) -> bool: INTERSECTION_EPSILON = 1e-9 @staticmethod - def _points_equal( - p1: Tuple[float, float], - p2: Tuple[float, float], - tol: float = 1e-9 - ) -> bool: + def _points_equal(p1: Tuple[float, float], p2: Tuple[float, float], tol: float = 1e-9) -> bool: """Check if two points are equal within tolerance.""" return abs(p1[0] - p2[0]) < tol and abs(p1[1] - p2[1]) < tol @staticmethod - def _angle_in_arc_range( - angle: float, - start_angle: float, - end_angle: float, - clockwise: bool - ) -> bool: + def _angle_in_arc_range(angle: float, start_angle: float, end_angle: float, clockwise: bool) -> bool: """Check if angle is within the arc's angular range.""" two_pi = 2 * math.pi @@ -735,7 +725,7 @@ def line_line_intersection( seg1_start: Tuple[float, float], seg1_end: Tuple[float, float], seg2_start: Tuple[float, float], - seg2_end: Tuple[float, float] + seg2_end: Tuple[float, float], ) -> List[Tuple[float, float]]: """ Find intersection point between two line segments. @@ -776,27 +766,26 @@ def line_circle_intersection( radius: float, start_angle: float, end_angle: float, - clockwise: bool = False + clockwise: bool = False, ) -> List[Tuple[float, float]]: """ Find intersection points between a line segment and a circular arc. Delegates to MathUtils for core calculation, then filters by arc range. """ + class SegmentAdapter: def __init__(self, p1: Tuple[float, float], p2: Tuple[float, float]): - self.point1 = type('P', (), {'x': p1[0], 'y': p1[1]})() - self.point2 = type('P', (), {'x': p2[0], 'y': p2[1]})() + self.point1 = type("P", (), {"x": p1[0], "y": p1[1]})() + self.point2 = type("P", (), {"x": p2[0], "y": p2[1]})() segment = SegmentAdapter(seg_start, seg_end) - raw_intersections = MathUtils.circle_segment_intersections( - center[0], center[1], radius, segment - ) + raw_intersections = MathUtils.circle_segment_intersections(center[0], center[1], radius, segment) results: List[Tuple[float, float]] = [] for hit in raw_intersections: - angle = hit['angle'] + angle = hit["angle"] if GeometryUtils._angle_in_arc_range(angle, start_angle, end_angle, clockwise): - results.append((hit['x'], hit['y'])) + results.append((hit["x"], hit["y"])) return results @@ -810,16 +799,17 @@ def line_ellipse_intersection( rotation: float, start_angle: float, end_angle: float, - clockwise: bool = False + clockwise: bool = False, ) -> List[Tuple[float, float]]: """ Find intersection points between a line segment and an elliptical arc. Delegates to MathUtils for core calculation, then filters by arc range. """ + class SegmentAdapter: def __init__(self, p1: Tuple[float, float], p2: Tuple[float, float]): - self.point1 = type('P', (), {'x': p1[0], 'y': p1[1]})() - self.point2 = type('P', (), {'x': p2[0], 'y': p2[1]})() + self.point1 = type("P", (), {"x": p1[0], "y": p1[1]})() + self.point2 = type("P", (), {"x": p2[0], "y": p2[1]})() segment = SegmentAdapter(seg_start, seg_end) rotation_degrees = math.degrees(rotation) @@ -829,9 +819,9 @@ def __init__(self, p1: Tuple[float, float], p2: Tuple[float, float]): results: List[Tuple[float, float]] = [] for hit in raw_intersections: - angle = hit['angle'] + angle = hit["angle"] if GeometryUtils._angle_in_arc_range(angle, start_angle, end_angle, clockwise): - results.append((hit['x'], hit['y'])) + results.append((hit["x"], hit["y"])) return results @@ -846,7 +836,7 @@ def circle_circle_intersection( radius2: float, start_angle2: float, end_angle2: float, - clockwise2: bool + clockwise2: bool, ) -> List[Tuple[float, float]]: """ Find intersection points between two circular arcs. @@ -882,8 +872,9 @@ def circle_circle_intersection( point = (px, py) angle1 = math.atan2(point[1] - cy1, point[0] - cx1) angle2 = math.atan2(point[1] - cy2, point[0] - cx2) - if (GeometryUtils._angle_in_arc_range(angle1, start_angle1, end_angle1, clockwise1) and - GeometryUtils._angle_in_arc_range(angle2, start_angle2, end_angle2, clockwise2)): + if GeometryUtils._angle_in_arc_range( + angle1, start_angle1, end_angle1, clockwise1 + ) and GeometryUtils._angle_in_arc_range(angle2, start_angle2, end_angle2, clockwise2): results.append(point) else: offset_x = h * dy / d @@ -893,8 +884,9 @@ def circle_circle_intersection( point = (px + sign * offset_x, py - sign * offset_y) angle1 = math.atan2(point[1] - cy1, point[0] - cx1) angle2 = math.atan2(point[1] - cy2, point[0] - cx2) - if (GeometryUtils._angle_in_arc_range(angle1, start_angle1, end_angle1, clockwise1) and - GeometryUtils._angle_in_arc_range(angle2, start_angle2, end_angle2, clockwise2)): + if GeometryUtils._angle_in_arc_range( + angle1, start_angle1, end_angle1, clockwise1 + ) and GeometryUtils._angle_in_arc_range(angle2, start_angle2, end_angle2, clockwise2): results.append(point) return results @@ -912,7 +904,7 @@ def circle_ellipse_intersection( ellipse_rotation: float, ellipse_start: float, ellipse_end: float, - ellipse_cw: bool + ellipse_cw: bool, ) -> List[Tuple[float, float]]: """ Find intersection points between a circular arc and an elliptical arc. @@ -951,12 +943,11 @@ def circle_ellipse_intersection( circle_angle = math.atan2(world_y - circle_center[1], world_x - circle_center[0]) - if (GeometryUtils._angle_in_arc_range(angle, ellipse_start, ellipse_end, ellipse_cw) and - GeometryUtils._angle_in_arc_range(circle_angle, circle_start, circle_end, circle_cw)): - + if GeometryUtils._angle_in_arc_range( + angle, ellipse_start, ellipse_end, ellipse_cw + ) and GeometryUtils._angle_in_arc_range(circle_angle, circle_start, circle_end, circle_cw): is_duplicate = any( - GeometryUtils._points_equal((world_x, world_y), existing, tol=0.001) - for existing in results + GeometryUtils._points_equal((world_x, world_y), existing, tol=0.001) for existing in results ) if not is_duplicate: results.append((world_x, world_y)) @@ -978,7 +969,7 @@ def ellipse_ellipse_intersection( rotation2: float, start2: float, end2: float, - cw2: bool + cw2: bool, ) -> List[Tuple[float, float]]: """ Find intersection points between two elliptical arcs. @@ -1017,8 +1008,7 @@ def ellipse_ellipse_intersection( if GeometryUtils._angle_in_arc_range(angle2, start2, end2, cw2): is_duplicate = any( - GeometryUtils._points_equal((world_x, world_y), existing, tol=0.001) - for existing in results + GeometryUtils._points_equal((world_x, world_y), existing, tol=0.001) for existing in results ) if not is_duplicate: results.append((world_x, world_y)) @@ -1071,11 +1061,7 @@ def circular_sector_area(radius: float, angle_span: float) -> float: @staticmethod def circular_segment_area( - center: Tuple[float, float], - radius: float, - start_angle: float, - end_angle: float, - clockwise: bool = False + center: Tuple[float, float], radius: float, start_angle: float, end_angle: float, clockwise: bool = False ) -> float: """ Calculate the signed area contribution of a circular arc segment. @@ -1135,7 +1121,7 @@ def elliptical_segment_area( rotation: float, start_angle: float, end_angle: float, - clockwise: bool = False + clockwise: bool = False, ) -> float: """ Calculate the signed area contribution of an elliptical arc segment. @@ -1199,10 +1185,7 @@ def elliptical_segment_area( return area @staticmethod - def line_segment_area_contribution( - start: Tuple[float, float], - end: Tuple[float, float] - ) -> float: + def line_segment_area_contribution(start: Tuple[float, float], end: Tuple[float, float]) -> float: """ Calculate the signed area contribution of a line segment. diff --git a/static/client/utils/graph_analyzer.py b/static/client/utils/graph_analyzer.py index fed90323..1b43dafe 100644 --- a/static/client/utils/graph_analyzer.py +++ b/static/client/utils/graph_analyzer.py @@ -24,6 +24,14 @@ class GraphAnalyzer: + @staticmethod + def _string_or_none(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + @staticmethod def _build_edges(state: GraphState) -> List[Edge[str]]: return [Edge(edge.source, edge.target) for edge in state.edges] @@ -55,33 +63,35 @@ def _edge_name_for_endpoints( for edge in state.edges: if directed: if edge.source == u and edge.target == v: - return edge.name or edge.id + return GraphAnalyzer._string_or_none(edge.name) or GraphAnalyzer._string_or_none(edge.id) else: if (edge.source == u and edge.target == v) or (edge.source == v and edge.target == u): - return edge.name or edge.id + return GraphAnalyzer._string_or_none(edge.name) or GraphAnalyzer._string_or_none(edge.id) return None @staticmethod - def _resolve_root(state: GraphState, adjacency: Dict[str, set[str]], params: Optional[Dict[str, Any]]) -> Optional[str]: + def _resolve_root( + state: GraphState, adjacency: Dict[str, set[str]], params: Optional[Dict[str, Any]] + ) -> Optional[str]: """Resolve root from params or state, handling old internal ID format.""" params = params or {} root = params.get("root") or getattr(state, "root", None) if root is None: # No root specified, use first vertex from state if available if state.vertices: - return state.vertices[0].id + return GraphAnalyzer._string_or_none(state.vertices[0].id) if adjacency: return next(iter(adjacency.keys())) return None # If root is already in adjacency, use it directly - if root in adjacency: + if isinstance(root, str) and root in adjacency: return root # Handle old format: internal IDs like "v0", "v1" if isinstance(root, str) and root.startswith("v") and root[1:].isdigit(): idx = int(root[1:]) # First try state.vertices (more reliable ordering) if state.vertices and 0 <= idx < len(state.vertices): - candidate = state.vertices[idx].id + candidate = GraphAnalyzer._string_or_none(state.vertices[idx].id) if candidate in adjacency: return candidate # Fall back to sorted adjacency keys @@ -90,8 +100,9 @@ def _resolve_root(state: GraphState, adjacency: Dict[str, set[str]], params: Opt return vertex_ids[idx] # Try matching root against vertex names in state for v in state.vertices: - if v.name == root and v.id in adjacency: - return v.id + vertex_id = GraphAnalyzer._string_or_none(v.id) + if v.name == root and vertex_id in adjacency: + return vertex_id # Fallback: return first vertex from adjacency if adjacency: return next(iter(adjacency.keys())) @@ -194,7 +205,9 @@ def analyze(state: GraphState, operation: str, params: Optional[Dict[str, Any]] rooted = GraphUtils.root_tree(adjacency, root) if rooted is None: adj_keys = list(adjacency.keys())[:5] - return {"error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}"} + return { + "error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}" + } parent, children = rooted depths = GraphUtils.node_depths(root, adjacency) or {} lca_node = GraphUtils.lowest_common_ancestor(parent, depths, a, b) @@ -207,7 +220,9 @@ def analyze(state: GraphState, operation: str, params: Optional[Dict[str, Any]] rooted = GraphUtils.root_tree(adjacency, root) if rooted is None: adj_keys = list(adjacency.keys())[:5] - return {"error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}"} + return { + "error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}" + } _, children = rooted balanced = GraphUtils.balance_children(root, children) return {"children": balanced} @@ -219,7 +234,9 @@ def analyze(state: GraphState, operation: str, params: Optional[Dict[str, Any]] rooted = GraphUtils.root_tree(adjacency, root) if rooted is None: adj_keys = list(adjacency.keys())[:5] - return {"error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}"} + return { + "error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}" + } _, children = rooted inverted = GraphUtils.invert_children(children) return {"children": inverted} @@ -232,7 +249,9 @@ def analyze(state: GraphState, operation: str, params: Optional[Dict[str, Any]] rooted = GraphUtils.root_tree(adjacency, root) if rooted is None: adj_keys = list(adjacency.keys())[:5] - return {"error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}"} + return { + "error": f"invalid tree structure: root={root!r}, adjacency_sample={adj_keys}, edges={len(state.edges)}" + } parent, children = rooted rerooted = GraphUtils.reroot_tree(parent, children, new_root) if rerooted is None: @@ -263,11 +282,7 @@ def analyze(state: GraphState, operation: str, params: Optional[Dict[str, Any]] y = params.get("y") if x is None or y is None: return {"error": "x and y coordinates are required for point_in_hull"} - positions = [ - (float(v.x), float(v.y)) - for v in state.vertices - if v.x is not None and v.y is not None - ] + positions = [(float(v.x), float(v.y)) for v in state.vertices if v.x is not None and v.y is not None] hull = GeometryUtils.convex_hull(positions) inside = GeometryUtils.point_in_convex_hull((float(x), float(y)), hull) # Convert tuples to lists for JSON serialization @@ -301,6 +316,3 @@ def _edge_names_from_path(state: GraphState, path: List[str], directed: bool) -> if name: names.append(name) return names - - - diff --git a/static/client/utils/graph_layout.py b/static/client/utils/graph_layout.py index 607c2108..22ee4556 100644 --- a/static/client/utils/graph_layout.py +++ b/static/client/utils/graph_layout.py @@ -32,6 +32,7 @@ # Box Utilities # ============================================================================= + def _default_box(box: Optional[Dict[str, float]], width: float, height: float) -> Dict[str, float]: """Normalize placement box, using canvas dimensions as fallback.""" if not box: @@ -86,10 +87,7 @@ def _center_graph_in_box( return positions # Apply translation to all positions - return { - vid: (pos[0] + dx, pos[1] + dy) - for vid, pos in positions.items() - } + return {vid: (pos[0] + dx, pos[1] + dy) for vid, pos in positions.items()} def _add_missing_vertices( @@ -108,6 +106,7 @@ def _add_missing_vertices( # Layout Selector # ============================================================================= + def layout_vertices( vertex_ids: List[str], edges: List[Edge[str]], @@ -254,6 +253,7 @@ def _infer_root(vertex_ids: List[str], edges: List[Edge[str]]) -> Optional[str]: # Circular Layout # ============================================================================= + def _circular_layout(vertex_ids: List[str], box: Dict[str, float]) -> Dict[str, Tuple[float, float]]: """ Place vertices evenly spaced around a circle. @@ -279,6 +279,7 @@ def _circular_layout(vertex_ids: List[str], box: Dict[str, float]) -> Dict[str, # Grid Layout (TSM Orthogonal) # ============================================================================= + def _simple_grid_placement( vertex_ids: List[str], box: Dict[str, float], @@ -355,10 +356,7 @@ def _optimize_layout_if_needed( def _count_diagonal_edges(ortho_rep: OrthogonalRep) -> int: """Count edges with more than one direction (diagonal edges).""" - return sum( - 1 for edge_key in ortho_rep.edge_directions - if len(ortho_rep.edge_directions[edge_key]) > 1 - ) + return sum(1 for edge_key in ortho_rep.edge_directions if len(ortho_rep.edge_directions[edge_key]) > 1) def _ensure_float_positions( @@ -583,6 +581,7 @@ def _orthogonal_tree_layout( # TSM Phase 1: Planarity and Embedding # ============================================================================= + def _is_planar( vertex_ids: List[str], edges: List[Edge[str]], @@ -621,8 +620,7 @@ def _is_planar( # Each component must be planar for component in components: comp_vertices = list(component) - comp_edges = [e for e in edges - if e.source in component and e.target in component] + comp_edges = [e for e in edges if e.source in component and e.target in component] comp_planarity = _is_planar(comp_vertices, comp_edges) is_comp_planar = comp_planarity[0] if not is_comp_planar: @@ -654,6 +652,7 @@ def _trivial_embedding( class _DfsResult: """Result of DFS traversal for planarity checking.""" + def __init__(self) -> None: self.dfs_number: Dict[str, int] = {} self.dfs_parent: Dict[str, Optional[str]] = {} @@ -690,7 +689,7 @@ def _dfs_for_planarity( if neighbor not in result.visited: result.tree_edges.append((node, neighbor)) stack.append((neighbor, node)) - elif neighbor != parent and result.dfs_number.get(neighbor, float('inf')) < result.dfs_number.get(node, 0): + elif neighbor != parent and result.dfs_number.get(neighbor, float("inf")) < result.dfs_number.get(node, 0): result.back_edges.append((node, neighbor)) return result @@ -708,15 +707,11 @@ def _compute_lowpoints( lowpt[v] = dfs_result.dfs_number.get(v, 0) lowpt2[v] = dfs_result.dfs_number.get(v, 0) - sorted_by_dfs = sorted( - vertex_ids, - key=lambda x: dfs_result.dfs_number.get(x, 0), - reverse=True - ) + sorted_by_dfs = sorted(vertex_ids, key=lambda x: dfs_result.dfs_number.get(x, 0), reverse=True) for v in sorted_by_dfs: # Update from back edges - for (src, tgt) in dfs_result.back_edges: + for src, tgt in dfs_result.back_edges: if src == v: tgt_num = dfs_result.dfs_number.get(tgt, 0) if tgt_num < lowpt[v]: @@ -771,9 +766,7 @@ def _lr_planarity_check( return False, None # Accept graphs passing edge bounds - embedding = _build_embedding_from_dfs( - vertex_ids, adjacency, dfs_result.dfs_number, dfs_result.dfs_parent - ) + embedding = _build_embedding_from_dfs(vertex_ids, adjacency, dfs_result.dfs_number, dfs_result.dfs_parent) return True, embedding @@ -826,12 +819,14 @@ def _build_embedding_from_dfs( # TSM Phase 2: Orthogonalization # ============================================================================= + class OrthogonalRep: """ Orthogonal representation of a planar graph. Stores vertex grid positions and edge routing directions. """ + def __init__(self) -> None: self.vertex_pos: Dict[str, Tuple[int, int]] = {} self.edge_directions: Dict[Tuple[str, str], List[str]] = {} @@ -928,10 +923,7 @@ def _find_cycle_components( component_verts = _find_component_from_start(start, adj_no_bridges, visited) component_set = set(component_verts) - component_edges = [ - e for e in non_bridge_edges - if e.source in component_set and e.target in component_set - ] + component_edges = [e for e in non_bridge_edges if e.source in component_set and e.target in component_set] # Check if this component is a simple cycle if len(component_edges) == len(component_verts) and len(component_verts) >= 3: @@ -1087,18 +1079,18 @@ def _compute_edge_directions( directions: List[str] = [] if u_col == v_col: - directions.append('S' if v_row > u_row else 'N') + directions.append("S" if v_row > u_row else "N") elif u_row == v_row: - directions.append('E' if v_col > u_col else 'W') + directions.append("E" if v_col > u_col else "W") else: if v_row > u_row: - directions.append('S') + directions.append("S") else: - directions.append('N') + directions.append("N") if v_col > u_col: - directions.append('E') + directions.append("E") else: - directions.append('W') + directions.append("W") edge_directions[(u, v)] = directions @@ -1184,6 +1176,7 @@ def _orthogonalize_multi_cycle( class _BfsTreeResult: """Result of BFS tree construction.""" + def __init__(self) -> None: self.parent: Dict[str, Optional[str]] = {} self.children: Dict[str, List[str]] = {} @@ -1307,24 +1300,24 @@ def _compute_edge_directions_with_stats( if delta_col == 0: if delta_row > 0: - directions.append('S') + directions.append("S") elif delta_row < 0: - directions.append('N') + directions.append("N") elif delta_row == 0: if delta_col > 0: - directions.append('E') + directions.append("E") elif delta_col < 0: - directions.append('W') + directions.append("W") else: diagonal_count += 1 if delta_row > 0: - directions.append('S') + directions.append("S") elif delta_row < 0: - directions.append('N') + directions.append("N") if delta_col > 0: - directions.append('E') + directions.append("E") elif delta_col < 0: - directions.append('W') + directions.append("W") edge_directions[(u, v)] = directions @@ -1476,6 +1469,7 @@ def is_occupied(row: int, col: int, exclude: str) -> bool: # TSM Phase 3: Compaction # ============================================================================= + def _compact_orthogonal( ortho_rep: OrthogonalRep, vertex_ids: List[str], @@ -1534,6 +1528,7 @@ def _compact_orthogonal( # Planarization (for non-planar graphs) # ============================================================================= + def _find_all_edge_crossings( edges: List[Edge[str]], positions: Dict[str, Tuple[float, float]], @@ -1543,14 +1538,13 @@ def _find_all_edge_crossings( edge_list = list(edges) for i, e1 in enumerate(edge_list): - for e2 in edge_list[i + 1:]: + for e2 in edge_list[i + 1 :]: # Skip edges sharing a vertex if e1.source in (e2.source, e2.target) or e1.target in (e2.source, e2.target): continue crossing = _edge_intersection( - positions[e1.source], positions[e1.target], - positions[e2.source], positions[e2.target] + positions[e1.source], positions[e1.target], positions[e2.source], positions[e2.target] ) if crossing is not None: @@ -1626,15 +1620,12 @@ def _planarize_graph( if not crossings: return vertex_ids, edges, set() - dummy_vertices, dummy_ids, edge_splits, edges_to_remove = \ - _create_dummy_vertices_for_crossings(crossings) + dummy_vertices, dummy_ids, edge_splits, edges_to_remove = _create_dummy_vertices_for_crossings(crossings) new_vertices = list(vertex_ids) + dummy_vertices # Add edges that don't need splitting - new_edges: List[Edge[str]] = [ - e for e in edges if (e.source, e.target) not in edges_to_remove - ] + new_edges: List[Edge[str]] = [e for e in edges if (e.source, e.target) not in edges_to_remove] # Add split edges new_edges.extend(_create_split_edges(edge_splits)) @@ -1685,6 +1676,7 @@ def _edge_intersection( # Crossing Elimination # ============================================================================= + def _find_crossing_pairs( edges: List[Edge[str]], positions: Dict[str, Tuple[float, float]], @@ -1715,10 +1707,7 @@ def _find_crossing_pairs( continue # Check if edges share a vertex - shares_vertex = ( - e1.source in (e2.source, e2.target) or - e1.target in (e2.source, e2.target) - ) + shares_vertex = e1.source in (e2.source, e2.target) or e1.target in (e2.source, e2.target) if shares_vertex: # Check for adjacent edge overlap @@ -1801,6 +1790,7 @@ def _edges_collinear_overlap( Two edges overlap if they are on the same line and their projections overlap. """ + # Check if all 4 points are collinear def cross_product(o: Tuple[float, float], a: Tuple[float, float], b: Tuple[float, float]) -> float: return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0]) @@ -1930,10 +1920,7 @@ def _count_crossings_for_vertex( continue # Check if edges share a vertex - shares_vertex = ( - e1.source in (e2.source, e2.target) or - e1.target in (e2.source, e2.target) - ) + shares_vertex = e1.source in (e2.source, e2.target) or e1.target in (e2.source, e2.target) if shares_vertex: # Check for adjacent edge overlap (vertex on segment) @@ -2249,9 +2236,7 @@ def _phase1_eliminate_crossings( improved = False for crossing_pair in crossing_pairs: - best_move = _find_best_crossing_move( - crossing_pair, grid_pos, edges, grid_size, box, best_crossing_count - ) + best_move = _find_best_crossing_move(crossing_pair, grid_pos, edges, grid_size, box, best_crossing_count) if best_move is not None: v = best_move[0] @@ -2267,9 +2252,7 @@ def _phase1_eliminate_crossings( continue # Try expanded search - expanded_result = _try_expanded_search( - vertex_ids, grid_pos, edges, grid_size, box, best_crossing_count - ) + expanded_result = _try_expanded_search(vertex_ids, grid_pos, edges, grid_size, box, best_crossing_count) improved = expanded_result[0] best_crossing_count = expanded_result[1] best_grid_pos = expanded_result[2] @@ -2336,9 +2319,7 @@ def _search_nearby_ortho_positions( continue new_pos = (new_col, new_row) - ortho_result = _try_ortho_position( - v, new_pos, best_ortho_pos, grid_pos, edges, grid_size, best_ortho_count - ) + ortho_result = _try_ortho_position(v, new_pos, best_ortho_pos, grid_pos, edges, grid_size, best_ortho_count) best_ortho_pos = ortho_result[0] best_ortho_count = ortho_result[1] @@ -2367,9 +2348,7 @@ def _search_neighbor_aligned_positions( if (test_col, n_row) in occupied or (test_col, n_row) == old_pos: continue new_pos = (test_col, n_row) - ortho_result = _try_ortho_position( - v, new_pos, best_ortho_pos, grid_pos, edges, grid_size, best_ortho_count - ) + ortho_result = _try_ortho_position(v, new_pos, best_ortho_pos, grid_pos, edges, grid_size, best_ortho_count) best_ortho_pos = ortho_result[0] best_ortho_count = ortho_result[1] @@ -2378,9 +2357,7 @@ def _search_neighbor_aligned_positions( if (n_col, test_row) in occupied or (n_col, test_row) == old_pos: continue new_pos = (n_col, test_row) - ortho_result = _try_ortho_position( - v, new_pos, best_ortho_pos, grid_pos, edges, grid_size, best_ortho_count - ) + ortho_result = _try_ortho_position(v, new_pos, best_ortho_pos, grid_pos, edges, grid_size, best_ortho_count) best_ortho_pos = ortho_result[0] best_ortho_count = ortho_result[1] @@ -2416,16 +2393,13 @@ def _phase2_optimize_orthogonality( occupied = {grid_pos[vid] for vid in grid_pos if vid != v} # Strategy 1: Search nearby positions - search_result = _search_nearby_ortho_positions( - v, old_pos, grid_pos, edges, grid_size, occupied - ) + search_result = _search_nearby_ortho_positions(v, old_pos, grid_pos, edges, grid_size, occupied) best_ortho_pos = search_result[0] best_ortho_count = search_result[1] # Strategy 2: Try aligning with each neighbor align_result = _search_neighbor_aligned_positions( - v, adjacency, grid_pos, edges, grid_size, occupied, - old_pos, best_ortho_pos, best_ortho_count + v, adjacency, grid_pos, edges, grid_size, occupied, old_pos, best_ortho_pos, best_ortho_count ) best_ortho_pos = align_result[0] best_ortho_count = align_result[1] @@ -2469,9 +2443,7 @@ def _eliminate_crossings( grid_pos = _to_grid_coords(positions, box, grid_size) # Phase 1: Eliminate crossings - phase1_result = _phase1_eliminate_crossings( - vertex_ids, edges, grid_pos, grid_size, box, max_iterations - ) + phase1_result = _phase1_eliminate_crossings(vertex_ids, edges, grid_pos, grid_size, box, max_iterations) best_grid_pos = phase1_result[0] best_crossing_count = phase1_result[1] @@ -2530,8 +2502,8 @@ def _normalize_orthogonal_edge_lengths( return positions # Calculate gaps - y_gaps = [sorted_ys[i+1] - sorted_ys[i] for i in range(len(sorted_ys)-1)] if len(sorted_ys) > 1 else [] - x_gaps = [sorted_xs[i+1] - sorted_xs[i] for i in range(len(sorted_xs)-1)] if len(sorted_xs) > 1 else [] + y_gaps = [sorted_ys[i + 1] - sorted_ys[i] for i in range(len(sorted_ys) - 1)] if len(sorted_ys) > 1 else [] + x_gaps = [sorted_xs[i + 1] - sorted_xs[i] for i in range(len(sorted_xs) - 1)] if len(sorted_xs) > 1 else [] all_gaps = [g for g in y_gaps + x_gaps if g > 1.0] # Filter tiny gaps if not all_gaps: @@ -2808,9 +2780,9 @@ def _equalize_edge_lengths( def _get_edge_neighbor(edge: Edge[str], vid: str) -> Optional[str]: """Get the neighbor vertex in an edge, or None if vid is not in the edge.""" if edge.source == vid: - return edge.target + return str(edge.target) elif edge.target == vid: - return edge.source + return str(edge.source) return None @@ -2970,6 +2942,7 @@ def _force_to_grid_fallback( # Radial Layout # ============================================================================= + def _radial_layout( vertex_ids: List[str], edges: List[Edge[str]], @@ -3009,6 +2982,7 @@ def _radial_layout( # Tree Layout (Reingold-Tilford Style) # ============================================================================= + def _tree_layout( vertex_ids: List[str], edges: List[Edge[str]], @@ -3045,9 +3019,7 @@ def _tree_layout( # Phase 3: Assign positions top-down positions = _assign_tree_positions( - root_id, children, subtree_widths, - box["x"], box["y"], box["height"], - scale_x, layer_height + root_id, children, subtree_widths, box["x"], box["y"], box["height"], scale_x, layer_height ) _add_missing_vertices(positions, vertex_ids, box) @@ -3114,6 +3086,7 @@ def assign(node: str, left_x: float, depth: int) -> None: # Force-Directed Layout (Fruchterman-Reingold Style) # ============================================================================= + def _force_directed_layout( vertex_ids: List[str], edges: List[Edge[str]], @@ -3189,7 +3162,7 @@ def _compute_forces( pos1 = positions[v1] x1 = float(pos1[0]) y1 = float(pos1[1]) - for v2 in vertex_ids[i + 1:]: + for v2 in vertex_ids[i + 1 :]: pos2 = positions[v2] x2 = float(pos2[0]) y2 = float(pos2[1]) diff --git a/static/client/utils/graph_utils.py b/static/client/utils/graph_utils.py index 12695bce..825f940c 100644 --- a/static/client/utils/graph_utils.py +++ b/static/client/utils/graph_utils.py @@ -47,6 +47,7 @@ class Edge(Generic[V]): source: The starting vertex of the edge target: The ending vertex of the edge """ + __slots__ = ("_source", "_target") def __init__(self, source: V, target: V) -> None: @@ -557,9 +558,7 @@ def shortest_path_unweighted( directed: bool = False, ) -> Optional[List[V]]: adjacency = ( - GraphUtils.build_directed_adjacency_map(edges) - if directed - else GraphUtils.build_adjacency_map(edges) + GraphUtils.build_directed_adjacency_map(edges) if directed else GraphUtils.build_adjacency_map(edges) ) if start not in adjacency or goal not in adjacency: return None @@ -1236,6 +1235,7 @@ def edges_overlap( Returns: True if the edges are collinear and overlap. """ + # Cross product to check collinearity def cross_product( o: Tuple[float, float], @@ -1386,10 +1386,7 @@ def count_edge_overlaps( continue # Check if edges share a vertex - shares_vertex = ( - e1.source in (e2.source, e2.target) or - e1.target in (e2.source, e2.target) - ) + shares_vertex = e1.source in (e2.source, e2.target) or e1.target in (e2.source, e2.target) if shares_vertex: # Check for adjacent edge overlap @@ -1503,4 +1500,3 @@ def edge_length_uniformity_ratio( if total == 0: return 1.0 return same_count / total - diff --git a/static/client/utils/linear_algebra_utils.py b/static/client/utils/linear_algebra_utils.py index 4b5c5f6a..93058c3b 100644 --- a/static/client/utils/linear_algebra_utils.py +++ b/static/client/utils/linear_algebra_utils.py @@ -126,9 +126,7 @@ def _validate_identifiers(expression: str, scope: Dict[str, Any]) -> None: unknown_tokens.sort() if unknown_tokens: - raise ValueError( - "Unknown identifiers in expression: " + ", ".join(unknown_tokens) - ) + raise ValueError("Unknown identifiers in expression: " + ", ".join(unknown_tokens)) @staticmethod def _extract_object(entry: LinearAlgebraObject) -> Tuple[str, Any]: diff --git a/static/client/utils/math_utils.py b/static/client/utils/math_utils.py index 6c3eaa95..657b42dd 100644 --- a/static/client/utils/math_utils.py +++ b/static/client/utils/math_utils.py @@ -56,6 +56,7 @@ class MathUtils: Class Attributes: EPSILON (float): Global tolerance constant (1e-9) for floating-point comparisons """ + # Epsilon (tolerance) EPSILON = 1e-9 MAX_SERIES_TERMS = 1000 @@ -83,22 +84,22 @@ def format_number_for_cartesian(n: Number, max_digits: int = 6) -> str: if n == 0: return "0" # Use scientific notation for very large or very small numbers but not zero - elif abs(n) >= 10**max_digits or (abs(n) < 10**(-max_digits + 1)): + elif abs(n) >= 10**max_digits or (abs(n) < 10 ** (-max_digits + 1)): formatted_number = f"{n:.1e}" else: formatted_number = f"{n:.{max_digits}g}" # Process scientific notation to adjust exponent formatting - if 'e' in formatted_number: - base, exponent = formatted_number.split('e') - base = base.rstrip('0').rstrip('.') + if "e" in formatted_number: + base, exponent = formatted_number.split("e") + base = base.rstrip("0").rstrip(".") # Fix handling for exponent sign - sign = exponent[0] if exponent.startswith('-') else '+' - exponent_number = exponent.lstrip('+').lstrip('-').lstrip('0') or '0' + sign = exponent[0] if exponent.startswith("-") else "+" + exponent_number = exponent.lstrip("+").lstrip("-").lstrip("0") or "0" formatted_number = f"{base}e{sign}{exponent_number}" else: # Truncate to max_digits significant digits for non-scientific notation - if '.' in formatted_number: - formatted_number = formatted_number[:formatted_number.find('.') + max_digits] + if "." in formatted_number: + formatted_number = formatted_number[: formatted_number.find(".") + max_digits] return formatted_number @staticmethod @@ -144,8 +145,12 @@ def segment_matches_coordinates( Returns: bool: True if segment endpoints match coordinates (in either order), False otherwise """ - first_direction_match = MathUtils.point_matches_coordinates(segment.point1, x1, y1) and MathUtils.point_matches_coordinates(segment.point2, x2, y2) - second_direction_match = MathUtils.point_matches_coordinates(segment.point1, x2, y2) and MathUtils.point_matches_coordinates(segment.point2, x1, y1) + first_direction_match = MathUtils.point_matches_coordinates( + segment.point1, x1, y1 + ) and MathUtils.point_matches_coordinates(segment.point2, x2, y2) + second_direction_match = MathUtils.point_matches_coordinates( + segment.point1, x2, y2 + ) and MathUtils.point_matches_coordinates(segment.point2, x1, y1) return bool(first_direction_match or second_direction_match) @staticmethod @@ -330,13 +335,14 @@ def is_point_on_segment( # Check if point is on the line defined by the segment # Using the cross product approach to check if three points are collinear from drawables.point import Position + origin = Position(sp1x, sp1y) p1 = Position(sp2x, sp2y) p2 = Position(px, py) cross_product = MathUtils.cross_product(origin, p1, p2) # Calculate segment length for a better threshold - segment_length = math.sqrt((sp2x - sp1x)**2 + (sp2y - sp1y)**2) + segment_length = math.sqrt((sp2x - sp1x) ** 2 + (sp2y - sp1y) ** 2) # Calculate a threshold as a proportion of the segment length # This makes it work well for both small and large coordinate values @@ -790,6 +796,7 @@ def is_rectangle( y4: Number, ) -> bool: # points must be in clockwise or counterclockwise order from drawables.point import Position + points = [Position(x, y) for x, y in [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]] # Check for duplicate points with tolerance @@ -800,7 +807,9 @@ def is_rectangle( return False # Calculate all pairwise distances - distances = [MathUtils.get_2D_distance(p1, p2) for i, p1 in enumerate(points) for j, p2 in enumerate(points) if i < j] + distances = [ + MathUtils.get_2D_distance(p1, p2) for i, p1 in enumerate(points) for j, p2 in enumerate(points) if i < j + ] # Group similar distances using tolerance grouped_distances: List[List[float]] = [] @@ -836,7 +845,6 @@ def is_rectangle( return True - # DEPRECATED BUT FASTER @staticmethod def evaluate_expression_using_python(expression: str) -> float: @@ -851,6 +859,7 @@ def evaluate_expression_using_python(expression: str) -> float: float: Result of evaluating expression at x=0 """ from expression_validator import ExpressionValidator + result = ExpressionValidator.parse_function_string(expression)(0) return float(result) @@ -982,12 +991,13 @@ def get_ellipse_formula( Returns: String representation of the ellipse formula """ + def fmt_num(n: Any) -> str: try: n_float = float(n) if n_float.is_integer(): return str(int(n_float)) - return str(n_float).rstrip('0').rstrip('.') + return str(n_float).rstrip("0").rstrip(".") except Exception: return str(n) @@ -1003,9 +1013,9 @@ def fmt_num(n: Any) -> str: sin_a = math.sin(angle_rad) # Calculate coefficients for the rotated ellipse equation - A = (cos_a**2/rx**2) + (sin_a**2/ry**2) - B = 2*cos_a*sin_a*(1/rx**2 - 1/ry**2) - C = (sin_a**2/rx**2) + (cos_a**2/ry**2) + A = (cos_a**2 / rx**2) + (sin_a**2 / ry**2) + B = 2 * cos_a * sin_a * (1 / rx**2 - 1 / ry**2) + C = (sin_a**2 / rx**2) + (cos_a**2 / ry**2) # Format coefficients to 4 decimal places for readability A = round(A, 4) @@ -1057,10 +1067,22 @@ def convert(value: Number, from_unit: str, to_unit: str) -> Any: # Number theory functions that require Python evaluation (not available in Math.js) _PYTHON_ONLY_FUNCTIONS = { - 'is_prime', 'prime_factors', 'mod_pow', 'mod_inverse', - 'next_prime', 'prev_prime', 'totient', 'divisors', - 'summation', 'product', 'arithmetic_sum', 'geometric_sum', - 'geometric_sum_infinite', 'ratio_test', 'root_test', 'p_series_test' + "is_prime", + "prime_factors", + "mod_pow", + "mod_inverse", + "next_prime", + "prev_prime", + "totient", + "divisors", + "summation", + "product", + "arithmetic_sum", + "geometric_sum", + "geometric_sum_infinite", + "ratio_test", + "root_test", + "p_series_test", } @staticmethod @@ -1080,6 +1102,7 @@ def evaluate(expression: str, variables: Optional[Dict[str, Number]] = None) -> """ try: from expression_validator import ExpressionValidator + js_expression = ExpressionValidator.fix_math_expression(expression, python_compatible=False) python_expression = ExpressionValidator.fix_math_expression(expression, python_compatible=True) ExpressionValidator.validate_expression_tree(python_expression) @@ -1087,7 +1110,9 @@ def evaluate(expression: str, variables: Optional[Dict[str, Number]] = None) -> # Check if expression contains Python-only functions (number theory) if any(func in expression for func in MathUtils._PYTHON_ONLY_FUNCTIONS): # Use Python evaluation for number theory functions - result = ExpressionValidator.evaluate_expression(python_expression, variables.get('x', 0) if variables else 0) + result = ExpressionValidator.evaluate_expression( + python_expression, variables.get("x", 0) if variables else 0 + ) # Preserve boolean and list types for better display if isinstance(result, bool): return "True" if result else "False" @@ -1105,10 +1130,15 @@ def evaluate(expression: str, variables: Optional[Dict[str, Number]] = None) -> converted_result = MathUtils.try_convert_to_number(result) # Check for division by zero - if "lim" not in expression and "limit" not in expression and \ - (converted_result == float('-inf') or \ - converted_result == float('inf') or \ - str(converted_result).lower() in ['-inf', 'inf', 'infinity', '-infinity']): + if ( + "lim" not in expression + and "limit" not in expression + and ( + converted_result == float("-inf") + or converted_result == float("inf") + or str(converted_result).lower() in ["-inf", "inf", "infinity", "-infinity"] + ) + ): raise ZeroDivisionError() return converted_result @@ -1152,11 +1182,11 @@ def limit(expression: str, variable: str, value_to_approach: Union[Number, str]) str: Limit result as string or error message """ try: - value_to_approach = str(value_to_approach).lower().replace(' ', '') - if value_to_approach in ['inf', 'infinity']: - value_to_approach = 'Infinity' - elif value_to_approach in ['-inf', '-infinity']: - value_to_approach = '-Infinity' + value_to_approach = str(value_to_approach).lower().replace(" ", "") + if value_to_approach in ["inf", "infinity"]: + value_to_approach = "Infinity" + elif value_to_approach in ["-inf", "-infinity"]: + value_to_approach = "-Infinity" return str(window.nerdamer(f"limit({expression}, {variable}, {value_to_approach})").text()) except Exception as e: return f"Error: {e} {getattr(e, 'message', str(e))}" @@ -1225,9 +1255,7 @@ def numeric_integrate( if steps <= 0: raise ValueError("steps must be positive") if steps > MathUtils.MAX_NUMERIC_INTEGRATION_STEPS: - raise ValueError( - f"steps cannot exceed {MathUtils.MAX_NUMERIC_INTEGRATION_STEPS}" - ) + raise ValueError(f"steps cannot exceed {MathUtils.MAX_NUMERIC_INTEGRATION_STEPS}") expr = window.nerdamer(expression) @@ -1304,26 +1332,27 @@ def factor(expression: str) -> str: @staticmethod def get_equation_type(equation: str) -> str: import re + try: # Preprocess the equation by expanding it to eliminate parentheses expanded_equation = MathUtils.expand(equation) # Remove whitespaces for easier processing - expanded_equation = expanded_equation.replace(' ', '') + expanded_equation = expanded_equation.replace(" ", "") # Split into left and right sides if equation contains = - if '=' in expanded_equation: - left, right = expanded_equation.split('=') + if "=" in expanded_equation: + left, right = expanded_equation.split("=") # If one side is just 'y', use the other side for analysis - if left.strip() == 'y': + if left.strip() == "y": expanded_equation = right - elif right.strip() == 'y': + elif right.strip() == "y": expanded_equation = left # Check for higher order equations (power >= 5) # Pattern: x^5, y^6, z^10, etc. # Matches: 'x^5', 'y^9', 'x^10', 'y^123' # Does not match: 'x^2', 'x^3', 'x^4' - higher_order_match = re.search(r'\b[a-zA-Z]\^([5-9]|\d{2,})\b', expanded_equation) + higher_order_match = re.search(r"\b[a-zA-Z]\^([5-9]|\d{2,})\b", expanded_equation) if higher_order_match: power = higher_order_match.group(1) return f"Order {power}" @@ -1331,13 +1360,13 @@ def get_equation_type(equation: str) -> str: # Check for multiple variables # Pattern: any letters a-z or A-Z # Matches: 'x', 'y', 'X', 'Y' - set(re.findall(r'[a-zA-Z]', expanded_equation)) + set(re.findall(r"[a-zA-Z]", expanded_equation)) # Check for trigonometric equations # Pattern: trig function followed by parentheses and content # Matches: 'sin(x)', 'cos(2x)', 'tan(x+y)' # Does not match: 'sin', 'cos x', 'tan[x]' - trigonometric_match = re.search(r'\b(sin|cos|tan|csc|sec|cot)\s*\(([^)]+)\)', expanded_equation) + trigonometric_match = re.search(r"\b(sin|cos|tan|csc|sec|cot)\s*\(([^)]+)\)", expanded_equation) if trigonometric_match: return "Trigonometric" @@ -1347,14 +1376,14 @@ def get_equation_type(equation: str) -> str: # Does not match: 'x+y', 'x-y' # Note: Only check for variable products, not just multiple variables # (linear equations like x + y = 4 should not be flagged as non-linear) - if re.search(r'[a-zA-Z]\s*[*]?\s*[a-zA-Z]', expanded_equation): + if re.search(r"[a-zA-Z]\s*[*]?\s*[a-zA-Z]", expanded_equation): return "Other Non-linear" # Check for quartic equations # Pattern: letter followed by ^4 # Matches: 'x^4', 'y^4' # Does not match: 'x^2', 'x^5', 'x4' - quartic_match = re.search(r'\b[a-zA-Z]\^4\b', expanded_equation) + quartic_match = re.search(r"\b[a-zA-Z]\^4\b", expanded_equation) if quartic_match: return "Quartic" @@ -1362,7 +1391,7 @@ def get_equation_type(equation: str) -> str: # Pattern: letter followed by ^3 # Matches: 'x^3', 'y^3' # Does not match: 'x^2', 'x^4', 'x3' - cubic_match = re.search(r'\b[a-zA-Z]\^3\b', expanded_equation) + cubic_match = re.search(r"\b[a-zA-Z]\^3\b", expanded_equation) if cubic_match: return "Cubic" @@ -1370,7 +1399,7 @@ def get_equation_type(equation: str) -> str: # Pattern: letter followed by ^2 # Matches: 'x^2', 'y^2' # Does not match: 'x^3', 'x2', 'x^' - quadratic_match = re.search(r'\b[a-zA-Z]\^2\b', expanded_equation) + quadratic_match = re.search(r"\b[a-zA-Z]\^2\b", expanded_equation) if quadratic_match: return "Quadratic" @@ -1378,7 +1407,7 @@ def get_equation_type(equation: str) -> str: # Pattern: single letter # Matches: 'x', 'y' (when not part of another term) # Does not match: 'x^2', 'xy', '2' - linear_match = re.search(r'\b[a-zA-Z]\b', expanded_equation) + linear_match = re.search(r"\b[a-zA-Z]\b", expanded_equation) if linear_match: return "Linear" @@ -1400,12 +1429,7 @@ def determine_max_number_of_solutions(equations: Sequence[str]) -> int: equation_types = [MathUtils.get_equation_type(eq) for eq in equations] # Create a dictionary mapping equation types to their degrees - type_to_degree = { - "Linear": 1, - "Quadratic": 2, - "Cubic": 3, - "Quartic": 4 - } + type_to_degree = {"Linear": 1, "Quadratic": 2, "Cubic": 3, "Quartic": 4} # If any equation type is unknown, trigonometric, or contains 'Error' if any(t in ["Unknown", "Trigonometric"] or "Error" in t for t in equation_types): @@ -1465,13 +1489,13 @@ def solve_linear_system(equations: Sequence[str]) -> str: print(f"Attempting to solve a system of linear equations: {equations}") # Use nerdamer to solve the system of equations - solutions = window.nerdamer.solveEquations(equations) # returns [['x', 3], ['y', 1]] + solutions = window.nerdamer.solveEquations(equations) # returns [['x', 3], ['y', 1]] print(f"Solutions: {solutions}") # Prepare the solution dictionary solution_dict = {sol[0]: sol[1] for sol in solutions} # Convert solution_dict to string format solution_strings = [f"{k} = {v}" for k, v in solution_dict.items()] - return ', '.join(solution_strings) + return ", ".join(solution_strings) except ValueError as ve: raise ve except Exception as e: @@ -1488,15 +1512,16 @@ def solve_linear_quadratic_system(equations: Sequence[str]) -> str: print(f"Attempting to solve a system of linear and quadratic equations: {equations}") from expression_validator import ExpressionValidator + eq1 = MathUtils.expand(equations[0]) eq1 = ExpressionValidator.fix_math_expression(eq1, python_compatible=False) # Split by '=' to separate the left and right sides of the equation and take the side containing the variable - eq1 = eq1.split('=')[0] if 'x' in eq1.split('=')[0] else eq1.split('=')[1] + eq1 = eq1.split("=")[0] if "x" in eq1.split("=")[0] else eq1.split("=")[1] eq2 = MathUtils.expand(equations[1]) eq2 = ExpressionValidator.fix_math_expression(eq2, python_compatible=False) # Split by '=' to separate the left and right sides of the equation and take the side containing the variable - eq2 = eq2.split('=')[0] if 'x' in eq2.split('=')[0] else eq2.split('=')[1] + eq2 = eq2.split("=")[0] if "x" in eq2.split("=")[0] else eq2.split("=")[1] linear, quadratic = (eq1, eq2) if "^2" in eq2 else (eq2, eq1) @@ -1504,44 +1529,46 @@ def solve_linear_quadratic_system(equations: Sequence[str]) -> str: system_eq = MathUtils.expand(system_eq) # Extract m, n = coefficients of the linear equation (assuming y = mx + n form) - lin_coeffs_str = window.nerdamer.coeffs(linear, 'x').text() # The coefficients are placed in the index of their power. So constants are in the 0th place, x^2 would be in the 2nd place, etc. + lin_coeffs_str = window.nerdamer.coeffs( + linear, "x" + ).text() # The coefficients are placed in the index of their power. So constants are in the 0th place, x^2 would be in the 2nd place, etc. lin_coeffs = literal_eval(lin_coeffs_str) m, n = lin_coeffs[1], lin_coeffs[0] # Extract a, b, c = coefficients of system equation - quadratic_coeffs_str = window.nerdamer.coeffs(system_eq, 'x').text() + quadratic_coeffs_str = window.nerdamer.coeffs(system_eq, "x").text() quadratic_coeffs = literal_eval(quadratic_coeffs_str) a, b, c = quadratic_coeffs[2], quadratic_coeffs[1], quadratic_coeffs[0] # Solve the quadratic equation of the system - discriminant = b**2 - 4*a*c + discriminant = b**2 - 4 * a * c if discriminant < 0: raise ValueError(f"No real solution for the quadratic equation {quadratic}.") - x1 = (-b + math.sqrt(discriminant)) / (2*a) - x2 = (-b - math.sqrt(discriminant)) / (2*a) + x1 = (-b + math.sqrt(discriminant)) / (2 * a) + x2 = (-b - math.sqrt(discriminant)) / (2 * a) # If the linear equation is not directly in terms of y, adjust accordingly - y1 = m*x1 + n - y2 = m*x2 + n + y1 = m * x1 + n + y2 = m * x2 + n # Format solutions solutions = [] if discriminant > 0: # Two solutions - solutions.append(('x1', x1)) - solutions.append(('y1', y1)) - solutions.append(('x2', x2)) - solutions.append(('y2', y2)) + solutions.append(("x1", x1)) + solutions.append(("y1", y1)) + solutions.append(("x2", x2)) + solutions.append(("y2", y2)) elif discriminant == 0: # One solution - solutions.append(('x', x1)) - solutions.append(('y', y1)) + solutions.append(("x", x1)) + solutions.append(("y", y1)) # Prepare the solution dictionary (assuming a single solution format for simplification) solution_dict = {sol[0]: sol[1] for sol in solutions} # Convert solution_dict to string format solution_strings = [f"{k} = {v}" for k, v in solution_dict.items()] - return ', '.join(solution_strings) + return ", ".join(solution_strings) except ValueError as ve: raise ve @@ -1561,7 +1588,7 @@ def solve_quadratic_system(equations: Sequence[str]) -> str: for equation in equations: eq = MathUtils.expand(equation) eq = ExpressionValidator.fix_math_expression(eq, python_compatible=False) - eq = eq.split('=')[0] if 'x' in eq.split('=')[0] else eq.split('=')[1] + eq = eq.split("=")[0] if "x" in eq.split("=")[0] else eq.split("=")[1] eqs.append(eq) # Construct the system equation by setting the equations equal to each other @@ -1569,37 +1596,35 @@ def solve_quadratic_system(equations: Sequence[str]) -> str: system_eq = MathUtils.expand(system_eq) # Use nerdamer to solve the system equation for x - x_solutions_raw = MathUtils.solve(system_eq, 'x') + x_solutions_raw = MathUtils.solve(system_eq, "x") x_solutions_data = json.loads(x_solutions_raw) x_solutions = [float(r) for r in x_solutions_data] solution_dict = {} for x_solution in x_solutions: # Substitute x_solution into both original equations to find y - y_equation1 = eqs[0].replace('x', f"({x_solution})") - y_equation2 = eqs[1].replace('x', f"({x_solution})") + y_equation1 = eqs[0].replace("x", f"({x_solution})") + y_equation2 = eqs[1].replace("x", f"({x_solution})") - if 'y' not in y_equation1: - y_equation1 += ' = y' - if 'y' not in y_equation2: - y_equation2 += ' = y' + if "y" not in y_equation1: + y_equation1 += " = y" + if "y" not in y_equation2: + y_equation2 += " = y" - y1_raw = MathUtils.solve(y_equation1, 'y') - y2_raw = MathUtils.solve(y_equation2, 'y') + y1_raw = MathUtils.solve(y_equation1, "y") + y2_raw = MathUtils.solve(y_equation2, "y") y1_value: Optional[float] = float(json.loads(y1_raw)[0]) if y1_raw else None y2_value: Optional[float] = float(json.loads(y2_raw)[0]) if y2_raw else None - print( - f"Solving for x = {x_solution}: {y_equation1} = {y1_value}, {y_equation2} = {y2_value}" - ) + print(f"Solving for x = {x_solution}: {y_equation1} = {y1_value}, {y_equation2} = {y2_value}") if y1_value is not None and y2_value is not None and y1_value == y2_value: solution_dict[x_solution] = y1_value solution_strings = [f"(x = {k}, y = {v})" for k, v in solution_dict.items()] print(f"Solutions found: {solution_strings}") - return ', '.join(solution_strings) + return ", ".join(solution_strings) except ValueError as ve: raise ve @@ -1612,13 +1637,14 @@ def solve_system_of_equations(equations: Sequence[str]) -> str: raise ValueError("Invalid input for equations. Expected a list of equations.") try: # Split single equation strings into two equations - if len(equations) == 1 and 'x' in equations[0] and '=' in equations[0]: - eq1, eq2 = equations[0].split('=') + if len(equations) == 1 and "x" in equations[0] and "=" in equations[0]: + eq1, eq2 = equations[0].split("=") eq1 += "= y" eq2 += "= y" equations = [eq1, eq2] from expression_validator import ExpressionValidator + equations = [ExpressionValidator.fix_math_expression(eq, python_compatible=False) for eq in equations] max_solutions_of_system = MathUtils.determine_max_number_of_solutions(equations) @@ -1651,7 +1677,7 @@ def solve_system_of_equations(equations: Sequence[str]) -> str: try: solutions = window.nerdamer.solveEquations(equations) solution_strings = [f"{solution[0]} = {solution[1]}" for solution in solutions] - return ', '.join(solution_strings) + return ", ".join(solution_strings) except Exception as e: print(f"Nerdamer failed ({e}), falling back to numeric solver") return MathUtils.solve_numeric(equations) @@ -1683,6 +1709,7 @@ def solve_numeric( JSON string with solutions, variables, and method information. """ from numeric_solver import solve_numeric as _solve_numeric + return str(_solve_numeric(equations, variables, initial_guesses, tolerance, max_iterations)) @staticmethod @@ -1820,6 +1847,7 @@ def mod_inverse(a: Number, mod: Number) -> int: a, mod = int(a), int(mod) if mod <= 0: raise ValueError("mod_inverse requires a positive modulus") + # Extended Euclidean Algorithm def extended_gcd(a: int, b: int) -> Tuple[int, int, int]: if a == 0: @@ -1977,9 +2005,7 @@ def summation(expression: str, variable: str, start: int, end: int) -> str: return "0" term_count = end - start + 1 if term_count > MathUtils.MAX_SERIES_TERMS: - raise ValueError( - f"summation supports at most {MathUtils.MAX_SERIES_TERMS} terms" - ) + raise ValueError(f"summation supports at most {MathUtils.MAX_SERIES_TERMS} terms") total = 0.0 for i in range(start, end + 1): value = float(window.nerdamer(expression).sub(variable, i).evaluate().text()) @@ -2013,9 +2039,7 @@ def product(expression: str, variable: str, start: int, end: int) -> str: return "1" term_count = end - start + 1 if term_count > MathUtils.MAX_SERIES_TERMS: - raise ValueError( - f"product supports at most {MathUtils.MAX_SERIES_TERMS} terms" - ) + raise ValueError(f"product supports at most {MathUtils.MAX_SERIES_TERMS} terms") total = 1.0 for i in range(start, end + 1): value = float(window.nerdamer(expression).sub(variable, i).evaluate().text()) @@ -2080,7 +2104,7 @@ def geometric_sum(first: Number, ratio: Number, n: int) -> Number: raise ValueError("Number of terms must be at least 1") if ratio == 1: return first * n - result = first * (1 - ratio ** n) / (1 - ratio) + result = first * (1 - ratio**n) / (1 - ratio) # Return integer if it's a whole number if result == int(result): return int(result) @@ -2326,33 +2350,33 @@ def calculate_vertical_asymptotes( vertical_asymptotes: List[float] = [] # For logarithmic functions - if 'log' in function_string: + if "log" in function_string: vertical_asymptotes.append(0.0) # For rational functions - if '/' in function_string: - denominator = function_string.split('/')[-1].strip() + if "/" in function_string: + denominator = function_string.split("/")[-1].strip() try: # Try to solve denominator = 0 - zeros = json.loads(MathUtils.solve(denominator, 'x')) + zeros = json.loads(MathUtils.solve(denominator, "x")) vertical_asymptotes.extend(float(x) for x in zeros) except: pass # For tangent functions - if 'tan' in function_string: + if "tan" in function_string: # Find all tangent terms in the function - tan_matches = re.findall(r'tan\((.*?)(?:\)|$)', function_string) + tan_matches = re.findall(r"tan\((.*?)(?:\)|$)", function_string) for tan_arg in tan_matches: coeff = 1.0 # Check for x/divisor pattern first (e.g., x/100) - div_match = re.search(r'x\s*/\s*(\d+\.?\d*)', tan_arg) + div_match = re.search(r"x\s*/\s*(\d+\.?\d*)", tan_arg) if div_match: divisor = float(div_match.group(1)) coeff = 1.0 / divisor if divisor != 0 else 1.0 else: # Check for coefficient*x pattern (e.g., 2*x or 2x) - coeff_match = re.search(r'([+-]?\d+\.?\d*)\s*\*?\s*x', tan_arg) + coeff_match = re.search(r"([+-]?\d+\.?\d*)\s*\*?\s*x", tan_arg) if coeff_match: coeff = float(coeff_match.group(1)) @@ -2364,7 +2388,7 @@ def calculate_vertical_asymptotes( # Asymptotes occur at x = (pi/2 + n*pi)/coeff n = math.floor(left * coeff / math.pi - 0.5) while True: - x = (math.pi/2 + n*math.pi) / coeff + x = (math.pi / 2 + n * math.pi) / coeff if x > right: break if x >= left: @@ -2384,7 +2408,7 @@ def calculate_horizontal_asymptotes(function_string: str) -> List[float]: try: # Check limit as x approaches infinity - limit_inf = float(MathUtils.limit(function_string, 'x', 'inf')) + limit_inf = float(MathUtils.limit(function_string, "x", "inf")) if not math.isinf(limit_inf) and not math.isnan(limit_inf): horizontal_asymptotes.append(limit_inf) except: @@ -2392,7 +2416,7 @@ def calculate_horizontal_asymptotes(function_string: str) -> List[float]: try: # Check limit as x approaches negative infinity - limit_neg_inf = float(MathUtils.limit(function_string, 'x', '-inf')) + limit_neg_inf = float(MathUtils.limit(function_string, "x", "-inf")) if not math.isinf(limit_neg_inf) and not math.isnan(limit_neg_inf): horizontal_asymptotes.append(limit_neg_inf) except: @@ -2432,20 +2456,20 @@ def calculate_point_discontinuities( # For piecewise functions (indicated by presence of conditional operators) # Match both Python-style (if/else) and mathematical notation (<, >, etc.) - if any(op in function_string for op in ['if', 'else', '<', '>', '<=', '>=', '==']): + if any(op in function_string for op in ["if", "else", "<", ">", "<=", ">=", "=="]): # Extract transition points from conditions # Handle both styles of conditions condition_patterns = [ - r'(?:<=|>=|<|>|==)\s*(-?\d*\.?\d+)', # Mathematical notation - r'if\s+x\s*(?:<=|>=|<|>|==)\s*(-?\d*\.?\d+)', # Python if notation - r'(?:<=|>=|<|>|==)\s*x\s*(?:<=|>=|<|>|==)\s*(-?\d*\.?\d+)' # Double conditions + r"(?:<=|>=|<|>|==)\s*(-?\d*\.?\d+)", # Mathematical notation + r"if\s+x\s*(?:<=|>=|<|>|==)\s*(-?\d*\.?\d+)", # Python if notation + r"(?:<=|>=|<|>|==)\s*x\s*(?:<=|>=|<|>|==)\s*(-?\d*\.?\d+)", # Double conditions ] for pattern in condition_patterns: matches = re.findall(pattern, function_string) point_discontinuities_set.update(float(x) for x in matches) # For floor and ceil functions - if 'floor' in function_string or 'ceil' in function_string: + if "floor" in function_string or "ceil" in function_string: # If bounds are provided, check each integer within bounds if left_bound is not None and right_bound is not None: left = math.ceil(left_bound) @@ -2453,14 +2477,14 @@ def calculate_point_discontinuities( point_discontinuities_set.update(range(left, right + 1)) # For absolute value function at its corners - if 'abs' in function_string: + if "abs" in function_string: # Extract all arguments of abs functions - abs_pattern = r'abs\((.*?)\)' + abs_pattern = r"abs\((.*?)\)" matches = re.findall(abs_pattern, function_string) for match in matches: try: # Try to solve the argument = 0 to find the corner point - zeros = json.loads(MathUtils.solve(match, 'x')) + zeros = json.loads(MathUtils.solve(match, "x")) point_discontinuities_set.update(float(x) for x in zeros) except: pass @@ -2548,12 +2572,10 @@ def find_diagonal_points( # Sort by a combination of factors that make good rectangle diagonals: # 1. Prefer more balanced rectangles (closer dx/dy ratio to 1.0) # 2. Then by distance as secondary criterion - def diagonal_score( - diag_info: Tuple[PointLike, PointLike, float, float, float] - ) -> Tuple[float, float]: + def diagonal_score(diag_info: Tuple[PointLike, PointLike, float, float, float]) -> Tuple[float, float]: p1, p2, distance, dx, dy = diag_info # Calculate how balanced the rectangle would be (closer to 1.0 is better) - aspect_ratio = max(dx, dy) / min(dx, dy) if min(dx, dy) > 0 else float('inf') + aspect_ratio = max(dx, dy) / min(dx, dy) if min(dx, dy) > 0 else float("inf") balance_score = 1.0 / aspect_ratio # Higher score for more balanced rectangles # Return tuple for sorting: (balance_score descending, distance descending) return (-balance_score, -distance) @@ -2634,10 +2656,7 @@ def detect_function_periodicity( y_left = eval_func(seg_left) y_mid = eval_func(seg_mid) y_right = eval_func(seg_right) - if not all( - isinstance(v, (int, float)) and math.isfinite(v) - for v in [y_left, y_mid, y_right] - ): + if not all(isinstance(v, (int, float)) and math.isfinite(v) for v in [y_left, y_mid, y_right]): continue expected_mid = (y_left + y_right) / 2 deviation = abs(y_mid - expected_mid) @@ -2714,7 +2733,7 @@ def tangent_line_endpoints( return (px, py - half_length), (px, py + half_length) # Calculate dx from: length/2 = sqrt(dx^2 + (slope*dx)^2) = dx * sqrt(1 + slope^2) - dx = half_length / math.sqrt(1 + slope ** 2) + dx = half_length / math.sqrt(1 + slope**2) dy = slope * dx return (px - dx, py - dy), (px + dx, py + dy) @@ -2836,9 +2855,12 @@ def ellipse_tangent_slope_at_angle( @staticmethod def perpendicular_foot( - px: float, py: float, - x1: float, y1: float, - x2: float, y2: float, + px: float, + py: float, + x1: float, + y1: float, + x2: float, + y2: float, ) -> Tuple[float, float]: """Project a point onto a line defined by two points. @@ -2876,9 +2898,12 @@ def perpendicular_foot( @staticmethod def angle_bisector_direction( - vx: float, vy: float, - p1x: float, p1y: float, - p2x: float, p2y: float, + vx: float, + vy: float, + p1x: float, + p1y: float, + p2x: float, + p2y: float, ) -> Tuple[float, float]: """Compute the unit vector along the angle bisector. @@ -2937,9 +2962,12 @@ def angle_bisector_direction( @staticmethod def circumcenter( - x1: float, y1: float, - x2: float, y2: float, - x3: float, y3: float, + x1: float, + y1: float, + x2: float, + y2: float, + x3: float, + y3: float, ) -> Tuple[float, float, float]: """Compute the circumcenter and circumradius of a triangle. @@ -2978,9 +3006,12 @@ def circumcenter( @staticmethod def incenter_and_inradius( - x1: float, y1: float, - x2: float, y2: float, - x3: float, y3: float, + x1: float, + y1: float, + x2: float, + y2: float, + x3: float, + y3: float, ) -> Tuple[float, float, float]: """Compute the incenter and inradius of a triangle. diff --git a/static/client/utils/polygon_canonicalizer.py b/static/client/utils/polygon_canonicalizer.py index 42e237e9..9612aa8d 100644 --- a/static/client/utils/polygon_canonicalizer.py +++ b/static/client/utils/polygon_canonicalizer.py @@ -36,6 +36,3 @@ "canonicalize_rectangle", "canonicalize_triangle", ] - - - diff --git a/static/client/utils/polygon_subtypes.py b/static/client/utils/polygon_subtypes.py index b59b260b..45c5cac1 100644 --- a/static/client/utils/polygon_subtypes.py +++ b/static/client/utils/polygon_subtypes.py @@ -37,7 +37,7 @@ def iter_values(cls) -> Iterable[str]: return (member.value for member in cls) def __str__(self) -> str: - return self.value + return str(self.value) class TriangleSubtype(_BaseSubtype): @@ -60,4 +60,3 @@ class QuadrilateralSubtype(_BaseSubtype): __all__ = ["TriangleSubtype", "QuadrilateralSubtype"] - diff --git a/static/client/utils/relation_inspector.py b/static/client/utils/relation_inspector.py index 267bf0af..fb1a44cb 100644 --- a/static/client/utils/relation_inspector.py +++ b/static/client/utils/relation_inspector.py @@ -80,8 +80,10 @@ def _extract_coords(obj: Any, otype: str) -> List[float]: if otype in ("segment", "vector"): seg = RelationInspector._as_segment(obj, otype) return [ - float(seg.point1.x), float(seg.point1.y), - float(seg.point2.x), float(seg.point2.y), + float(seg.point1.x), + float(seg.point1.y), + float(seg.point2.x), + float(seg.point2.y), ] if otype == "circle": return [float(obj.center.x), float(obj.center.y), float(obj.radius)] @@ -92,8 +94,10 @@ def _extract_coords(obj: Any, otype: str) -> List[float]: return coords if otype == "ellipse": return [ - float(obj.center.x), float(obj.center.y), - float(obj.radius_x), float(obj.radius_y), + float(obj.center.x), + float(obj.center.y), + float(obj.radius_x), + float(obj.radius_y), ] if otype == "rectangle": coords_r: List[float] = [] @@ -251,13 +255,19 @@ def _check_collinear(objects: List[Any], object_types: List[str]) -> Dict[str, A continue if cross / denom > tol: return RelationInspector._ok( - "collinear", False, + "collinear", + False, f"Points are not collinear (point at index {i} deviates)", - tol, {"first_deviant_index": i}, + tol, + {"first_deviant_index": i}, ) return RelationInspector._ok( - "collinear", True, "All points are collinear", tol, {"point_count": len(objects)}, + "collinear", + True, + "All points are collinear", + tol, + {"point_count": len(objects)}, ) @staticmethod @@ -270,13 +280,17 @@ def _check_concyclic(objects: List[Any], object_types: List[str]) -> Dict[str, A p0, p1, p2 = objects[0], objects[1], objects[2] try: cx, cy, r = MathUtils.circumcenter( - float(p0.x), float(p0.y), - float(p1.x), float(p1.y), - float(p2.x), float(p2.y), + float(p0.x), + float(p0.y), + float(p1.x), + float(p1.y), + float(p2.x), + float(p2.y), ) except ValueError: return RelationInspector._ok( - "concyclic", False, + "concyclic", + False, "First three points are collinear — no common circle exists", RelationInspector.RELATION_TOLERANCE, {"reason": "collinear_first_three"}, @@ -288,15 +302,19 @@ def _check_concyclic(objects: List[Any], object_types: List[str]) -> Dict[str, A dist = math.hypot(float(pi.x) - cx, float(pi.y) - cy) if abs(dist - r) > tol: return RelationInspector._ok( - "concyclic", False, + "concyclic", + False, f"Point at index {i} is not on the common circle (distance from center: {dist:.6f}, radius: {r:.6f})", - tol, {"circumcenter": [cx, cy], "circumradius": r, "first_deviant_index": i}, + tol, + {"circumcenter": [cx, cy], "circumradius": r, "first_deviant_index": i}, ) return RelationInspector._ok( - "concyclic", True, + "concyclic", + True, f"All {len(objects)} points lie on a common circle (center: ({cx:.4f}, {cy:.4f}), radius: {r:.4f})", - tol, {"circumcenter": [cx, cy], "circumradius": r}, + tol, + {"circumcenter": [cx, cy], "circumradius": r}, ) @staticmethod @@ -319,7 +337,10 @@ def _check_equal_length(objects: List[Any], object_types: List[str]) -> Dict[str else f"Segments have different lengths ({len1:.6f} vs {len2:.6f})" ) return RelationInspector._ok( - "equal_length", equal, expl, tol, + "equal_length", + equal, + expl, + tol, {"length1": len1, "length2": len2, "difference": diff}, ) @@ -360,7 +381,10 @@ def _check_similar(objects: List[Any], object_types: List[str]) -> Dict[str, Any else f"Triangles are not similar (side ratios: {ratios[0]:.4f}, {ratios[1]:.4f}, {ratios[2]:.4f})" ) return RelationInspector._ok( - "similar", similar, expl, tol, + "similar", + similar, + expl, + tol, {"side_ratios": ratios, "scale_factor": ratios[0] if similar else None}, ) @@ -386,7 +410,10 @@ def _check_congruent(objects: List[Any], object_types: List[str]) -> Dict[str, A else f"Triangles are not congruent (sides1: {s1}, sides2: {s2})" ) return RelationInspector._ok( - "congruent", cong, expl, tol, + "congruent", + cong, + expl, + tol, {"sides1": s1, "sides2": s2}, ) @@ -403,7 +430,9 @@ def _check_tangent(objects: List[Any], object_types: List[str]) -> Dict[str, Any seg_idx = 0 if object_types[0] in ("segment", "vector") else 1 cir_idx = 1 - seg_idx return RelationInspector._tangent_segment_circle( - objects[seg_idx], object_types[seg_idx], objects[cir_idx], + objects[seg_idx], + object_types[seg_idx], + objects[cir_idx], ) # circle + circle @@ -462,9 +491,17 @@ def _tangent_segment_circle(seg_obj: Any, seg_type: str, circle: Any) -> Dict[st expl = f"Segment is not tangent to circle (distance from center to line: {dist:.6f}, radius: {r:.6f})" return RelationInspector._ok( - "tangent", is_tangent, expl, tol, - {"distance_to_line": dist, "radius": r, "tangent_point": [fx, fy], - "foot_on_segment": foot_on_segment, "t_parameter": t_param}, + "tangent", + is_tangent, + expl, + tol, + { + "distance_to_line": dist, + "radius": r, + "tangent_point": [fx, fy], + "foot_on_segment": foot_on_segment, + "t_parameter": t_param, + }, ) @staticmethod @@ -486,14 +523,18 @@ def _tangent_circle_circle(c1: Any, c2: Any) -> Dict[str, Any]: else: kind = "not tangent" - expl = ( - f"Circles are {kind} (center distance: {d:.6f}, r1+r2: {r1 + r2:.6f}, |r1-r2|: {abs(r1 - r2):.6f})" - ) + expl = f"Circles are {kind} (center distance: {d:.6f}, r1+r2: {r1 + r2:.6f}, |r1-r2|: {abs(r1 - r2):.6f})" return RelationInspector._ok( - "tangent", is_tangent, expl, tol, + "tangent", + is_tangent, + expl, + tol, { - "center_distance": d, "r1": r1, "r2": r2, - "externally_tangent": ext_tangent, "internally_tangent": int_tangent, + "center_distance": d, + "r1": r1, + "r2": r2, + "externally_tangent": ext_tangent, + "internally_tangent": int_tangent, }, ) @@ -512,7 +553,8 @@ def _check_concurrent(objects: List[Any], object_types: List[str]) -> Dict[str, maybe_ix, maybe_iy = RelationInspector._line_line_intersection(segs[0], segs[1]) if maybe_ix is None or maybe_iy is None: return RelationInspector._ok( - "concurrent", False, + "concurrent", + False, "First two lines are parallel — no single point of concurrence", RelationInspector.RELATION_TOLERANCE, {"reason": "parallel_pair"}, @@ -528,20 +570,25 @@ def _check_concurrent(objects: List[Any], object_types: List[str]) -> Dict[str, ref = max(1.0, abs(ix), abs(iy), RelationInspector._seg_length(segs[i])) if dist / ref > tol: return RelationInspector._ok( - "concurrent", False, + "concurrent", + False, f"Lines are not concurrent (line at index {i} misses intersection point)", - tol, {"intersection_of_first_two": [ix, iy], "first_deviant_index": i}, + tol, + {"intersection_of_first_two": [ix, iy], "first_deviant_index": i}, ) return RelationInspector._ok( - "concurrent", True, + "concurrent", + True, f"All {len(objects)} lines are concurrent at ({ix:.4f}, {iy:.4f})", - tol, {"intersection": [ix, iy]}, + tol, + {"intersection": [ix, iy]}, ) @staticmethod def _line_line_intersection( - s1: Any, s2: Any, + s1: Any, + s2: Any, ) -> Tuple[Optional[float], Optional[float]]: """Compute intersection of infinite lines through two segments. @@ -637,7 +684,10 @@ def _check_point_on_circle(objects: List[Any], object_types: List[str]) -> Dict[ else f"Point does not lie on the circle (distance from center: {dist:.6f}, radius: {r:.6f})" ) return RelationInspector._ok( - "point_on_circle", on_circle, expl, tol, + "point_on_circle", + on_circle, + expl, + tol, {"distance_from_center": dist, "radius": r, "deviation": abs(dist - r)}, ) @@ -649,7 +699,6 @@ def _check_point_on_circle(objects: List[Any], object_types: List[str]) -> Dict[ def _auto_inspect(objects: List[Any], object_types: List[str]) -> Dict[str, Any]: """Run all applicable checks for the given type combination.""" results: List[Dict[str, Any]] = [] - type_tuple = tuple(sorted(object_types)) n = len(objects) checks: List[str] = [] diff --git a/static/client/utils/statistics/distributions.py b/static/client/utils/statistics/distributions.py index 7ac43f0d..a2587f01 100644 --- a/static/client/utils/statistics/distributions.py +++ b/static/client/utils/statistics/distributions.py @@ -36,9 +36,7 @@ def normal_pdf_expression(mean: float, sigma: float) -> str: raise ValueError("sigma must be > 0") # f(x) = (1 / (sigma * sqrt(2*pi))) * exp(-((x-mean)^2) / (2*sigma^2)) - return ( - f"(1/(({sigma})*sqrt(2*pi)))*exp(-(((x-({mean}))^2)/(2*({sigma})^2)))" - ) + return f"(1/(({sigma})*sqrt(2*pi)))*exp(-(((x-({mean}))^2)/(2*({sigma})^2)))" def default_normal_bounds(mean: float, sigma: float, k: float = 4.0) -> Tuple[float, float]: @@ -50,6 +48,3 @@ def default_normal_bounds(mean: float, sigma: float, k: float = 4.0) -> Tuple[fl if k <= 0.0: raise ValueError("k must be > 0") return (mean - k * sigma, mean + k * sigma) - - - diff --git a/static/client/utils/statistics/regression.py b/static/client/utils/statistics/regression.py index 38638b33..ec459368 100644 --- a/static/client/utils/statistics/regression.py +++ b/static/client/utils/statistics/regression.py @@ -22,6 +22,7 @@ class RegressionResult(TypedDict): """Result of a regression fit.""" + expression: str coefficients: Dict[str, float] r_squared: float @@ -32,6 +33,7 @@ class RegressionResult(TypedDict): # Validation helpers # --------------------------------------------------------------------------- + def _validate_data(x_data: List[float], y_data: List[float], min_points: int = 2) -> None: """Validate input data arrays.""" if not isinstance(x_data, list) or not isinstance(y_data, list): @@ -58,6 +60,7 @@ def _validate_positive(values: List[float], name: str) -> None: # Pure Python matrix operations (no numpy dependency) # --------------------------------------------------------------------------- + def _matrix_multiply(A: List[List[float]], B: List[List[float]]) -> List[List[float]]: """Multiply two matrices.""" rows_A = len(A) @@ -147,6 +150,7 @@ def _solve_least_squares(X: List[List[float]], y: List[float]) -> List[float]: # R-squared calculation # --------------------------------------------------------------------------- + def calculate_r_squared(y_actual: List[float], y_predicted: List[float]) -> float: """ Calculate coefficient of determination (R²). @@ -177,6 +181,7 @@ def calculate_r_squared(y_actual: List[float], y_predicted: List[float]) -> floa # Expression building # --------------------------------------------------------------------------- + def _format_coefficient(value: float, precision: int = 6) -> str: """Format a coefficient for expression string. @@ -186,9 +191,9 @@ def _format_coefficient(value: float, precision: int = 6) -> str: return "0" # Use fixed-point notation to avoid scientific notation (e.g., 1e-05) # which MatHud would interpret as 1 * euler_number - 05 - formatted = f"{value:.{precision}f}".rstrip('0').rstrip('.') + formatted = f"{value:.{precision}f}".rstrip("0").rstrip(".") # Ensure we don't return empty string or just a minus sign - if not formatted or formatted == '-': + if not formatted or formatted == "-": return "0" return formatted @@ -282,6 +287,7 @@ def build_expression(model_type: str, coefficients: Dict[str, float]) -> str: # Model fitting functions # --------------------------------------------------------------------------- + def fit_linear(x_data: List[float], y_data: List[float]) -> RegressionResult: """ Fit linear model: y = mx + b @@ -335,7 +341,7 @@ def fit_polynomial(x_data: List[float], y_data: List[float], degree: int) -> Reg # Build Vandermonde matrix X: List[List[float]] = [] for x in x_data: - row = [x ** j for j in range(degree + 1)] + row = [x**j for j in range(degree + 1)] X.append(row) # Solve least squares @@ -344,7 +350,7 @@ def fit_polynomial(x_data: List[float], y_data: List[float], degree: int) -> Reg # Calculate R-squared y_predicted = [] for x in x_data: - y_pred = sum(coeffs[j] * (x ** j) for j in range(degree + 1)) + y_pred = sum(coeffs[j] * (x**j) for j in range(degree + 1)) y_predicted.append(y_pred) r_squared = calculate_r_squared(y_data, y_predicted) @@ -453,7 +459,7 @@ def fit_power(x_data: List[float], y_data: List[float]) -> RegressionResult: a = math.exp(ln_a) # Calculate R-squared on original scale - y_predicted = [a * (x ** b) for x in x_data] + y_predicted = [a * (x**b) for x in x_data] r_squared = calculate_r_squared(y_data, y_predicted) coefficients = {"a": a, "b": b} @@ -650,7 +656,7 @@ def fit_sinusoidal( if len(crossings) >= 2: # Estimate period as twice average distance between crossings - crossing_diffs = [crossings[i+1] - crossings[i] for i in range(len(crossings) - 1)] + crossing_diffs = [crossings[i + 1] - crossings[i] for i in range(len(crossings) - 1)] half_period = sum(crossing_diffs) / len(crossing_diffs) period = 2 * half_period b_init = 2 * math.pi / period if period > 1e-10 else 1.0 @@ -676,7 +682,7 @@ def compute_sse(a: float, b: float, c: float, d: float) -> float: # Try different periods b_values = [b_init * f for f in [0.5, 0.75, 1.0, 1.25, 1.5, 2.0]] - c_values = [0, math.pi/4, math.pi/2, 3*math.pi/4, math.pi] + c_values = [0, math.pi / 4, math.pi / 2, 3 * math.pi / 4, math.pi] for b_test in b_values: for c_test in c_values: @@ -772,10 +778,7 @@ def fit_regression( model = model_type.strip().lower() if isinstance(model_type, str) else "" if model not in SUPPORTED_MODEL_TYPES: - raise ValueError( - f"Unsupported model_type '{model_type}'. " - f"Supported: {', '.join(SUPPORTED_MODEL_TYPES)}" - ) + raise ValueError(f"Unsupported model_type '{model_type}'. Supported: {', '.join(SUPPORTED_MODEL_TYPES)}") if model == "linear": return fit_linear(x_data, y_data) diff --git a/static/client/utils/style_utils.py b/static/client/utils/style_utils.py index d0f2e938..b97bde54 100644 --- a/static/client/utils/style_utils.py +++ b/static/client/utils/style_utils.py @@ -36,36 +36,159 @@ class StyleUtils: Provides static methods for validating color values, opacity settings, and other styling properties used throughout the MatHud canvas system. """ + @staticmethod def is_valid_css_color(color: str) -> bool: """Validates if a string is a valid CSS color. Supports named colors, hex colors, rgb(), rgba(), hsl(), and hsla().""" # Basic named colors - if color.lower() in ["aliceblue", "antiquewhite", "aqua", "aquamarine", "azure", "beige", "bisque", "black", - "blanchedalmond", "blue", "blueviolet", "brown", "burlywood", "cadetblue", "chartreuse", - "chocolate", "coral", "cornflowerblue", "cornsilk", "crimson", "cyan", "darkblue", "darkcyan", - "darkgoldenrod", "darkgray", "darkgreen", "darkkhaki", "darkmagenta", "darkolivegreen", - "darkorange", "darkorchid", "darkred", "darksalmon", "darkseagreen", "darkslateblue", - "darkslategray", "darkturquoise", "darkviolet", "deeppink", "deepskyblue", "dimgray", - "dodgerblue", "firebrick", "floralwhite", "forestgreen", "fuchsia", "gainsboro", "ghostwhite", - "gold", "goldenrod", "gray", "green", "greenyellow", "honeydew", "hotpink", "indianred", - "indigo", "ivory", "khaki", "lavender", "lavenderblush", "lawngreen", "lemonchiffon", - "lightblue", "lightcoral", "lightcyan", "lightgoldenrodyellow", "lightgray", "lightgreen", - "lightpink", "lightsalmon", "lightseagreen", "lightskyblue", "lightslategray", "lightsteelblue", - "lightyellow", "lime", "limegreen", "linen", "magenta", "maroon", "mediumaquamarine", - "mediumblue", "mediumorchid", "mediumpurple", "mediumseagreen", "mediumslateblue", - "mediumspringgreen", "mediumturquoise", "mediumvioletred", "midnightblue", "mintcream", - "mistyrose", "moccasin", "navajowhite", "navy", "oldlace", "olive", "olivedrab", "orange", - "orangered", "orchid", "palegoldenrod", "palegreen", "paleturquoise", "palevioletred", - "papayawhip", "peachpuff", "peru", "pink", "plum", "powderblue", "purple", "rebeccapurple", - "red", "rosybrown", "royalblue", "saddlebrown", "salmon", "sandybrown", "seagreen", "seashell", - "sienna", "silver", "skyblue", "slateblue", "slategray", "snow", "springgreen", "steelblue", - "tan", "teal", "thistle", "tomato", "turquoise", "violet", "wheat", "white", "whitesmoke", - "yellow", "yellowgreen"]: + if color.lower() in [ + "aliceblue", + "antiquewhite", + "aqua", + "aquamarine", + "azure", + "beige", + "bisque", + "black", + "blanchedalmond", + "blue", + "blueviolet", + "brown", + "burlywood", + "cadetblue", + "chartreuse", + "chocolate", + "coral", + "cornflowerblue", + "cornsilk", + "crimson", + "cyan", + "darkblue", + "darkcyan", + "darkgoldenrod", + "darkgray", + "darkgreen", + "darkkhaki", + "darkmagenta", + "darkolivegreen", + "darkorange", + "darkorchid", + "darkred", + "darksalmon", + "darkseagreen", + "darkslateblue", + "darkslategray", + "darkturquoise", + "darkviolet", + "deeppink", + "deepskyblue", + "dimgray", + "dodgerblue", + "firebrick", + "floralwhite", + "forestgreen", + "fuchsia", + "gainsboro", + "ghostwhite", + "gold", + "goldenrod", + "gray", + "green", + "greenyellow", + "honeydew", + "hotpink", + "indianred", + "indigo", + "ivory", + "khaki", + "lavender", + "lavenderblush", + "lawngreen", + "lemonchiffon", + "lightblue", + "lightcoral", + "lightcyan", + "lightgoldenrodyellow", + "lightgray", + "lightgreen", + "lightpink", + "lightsalmon", + "lightseagreen", + "lightskyblue", + "lightslategray", + "lightsteelblue", + "lightyellow", + "lime", + "limegreen", + "linen", + "magenta", + "maroon", + "mediumaquamarine", + "mediumblue", + "mediumorchid", + "mediumpurple", + "mediumseagreen", + "mediumslateblue", + "mediumspringgreen", + "mediumturquoise", + "mediumvioletred", + "midnightblue", + "mintcream", + "mistyrose", + "moccasin", + "navajowhite", + "navy", + "oldlace", + "olive", + "olivedrab", + "orange", + "orangered", + "orchid", + "palegoldenrod", + "palegreen", + "paleturquoise", + "palevioletred", + "papayawhip", + "peachpuff", + "peru", + "pink", + "plum", + "powderblue", + "purple", + "rebeccapurple", + "red", + "rosybrown", + "royalblue", + "saddlebrown", + "salmon", + "sandybrown", + "seagreen", + "seashell", + "sienna", + "silver", + "skyblue", + "slateblue", + "slategray", + "snow", + "springgreen", + "steelblue", + "tan", + "teal", + "thistle", + "tomato", + "turquoise", + "violet", + "wheat", + "white", + "whitesmoke", + "yellow", + "yellowgreen", + ]: return True # Hex colors - if color.startswith('#') and len(color) in [4, 7]: # #RGB or #RRGGBB + if color.startswith("#") and len(color) in [4, 7]: # #RGB or #RRGGBB try: int(color[1:], 16) return True @@ -73,7 +196,7 @@ def is_valid_css_color(color: str) -> bool: return False # rgb(), rgba(), hsl(), hsla() - if color.startswith(('rgb(', 'rgba(', 'hsl(', 'hsla(')): + if color.startswith(("rgb(", "rgba(", "hsl(", "hsla(")): return True return False diff --git a/static/client/workspace_manager.py b/static/client/workspace_manager.py index 340b58b1..5235e64a 100644 --- a/static/client/workspace_manager.py +++ b/static/client/workspace_manager.py @@ -66,6 +66,7 @@ from drawables.point import Point from drawables.segment import Segment + class WorkspaceManager: """ Client-side workspace manager that handles workspace operations via AJAX communication. @@ -96,6 +97,7 @@ def save_workspace(self, name: Optional[str] = None) -> str: Returns: str: Success or error message from the save operation. """ + def on_complete(req: Any) -> str: return self._parse_save_workspace_response(req, name) @@ -114,7 +116,7 @@ def _parse_save_workspace_response(self, req: Any, name: Optional[str]) -> str: return f'Workspace "{name if name else "current"}" saved successfully.' return self._format_workspace_error("saving", response) except Exception as e: - return f'Error saving workspace: {str(e)}' + return f"Error saving workspace: {str(e)}" def _build_save_workspace_payload(self, name: Optional[str]) -> Dict[str, Any]: return { @@ -430,9 +432,7 @@ def _restore_rectangle(self, item_state: Dict[str, Any]) -> None: points = [self.canvas.get_point_by_name(name) for name in arg_point_names] if not all(points): - missing_names: List[str] = [ - str(arg_point_names[i]) for i, p in enumerate(points) if not p - ] + missing_names: List[str] = [str(arg_point_names[i]) for i, p in enumerate(points) if not p] self._warn_rectangle_missing_points(rect_name, missing_names) return @@ -477,13 +477,14 @@ def _resolve_rectangle_vertices( except PolygonCanonicalizationError: return self._canonicalize_rectangle_from_diagonal(points, rect_name) - def _canonicalize_rectangle_vertices( - self, points: List[Optional["Point"]] - ) -> List[Tuple[float, float]]: - return cast(List[Tuple[float, float]], canonicalize_rectangle( - [(point.x, point.y) for point in points if point is not None], - construction_mode="vertices", - )) + def _canonicalize_rectangle_vertices(self, points: List[Optional["Point"]]) -> List[Tuple[float, float]]: + return cast( + List[Tuple[float, float]], + canonicalize_rectangle( + [(point.x, point.y) for point in points if point is not None], + construction_mode="vertices", + ), + ) def _canonicalize_rectangle_from_diagonal( self, @@ -492,19 +493,18 @@ def _canonicalize_rectangle_from_diagonal( ) -> Optional[List[Tuple[float, float]]]: p_diag1, p_diag2 = MathUtils.find_diagonal_points(points, rect_name) if not p_diag1 or not p_diag2: - print( - f"Warning: Could not determine diagonal points for rectangle '{rect_name}'. Skipping." - ) + print(f"Warning: Could not determine diagonal points for rectangle '{rect_name}'. Skipping.") return None try: - return cast(Optional[List[Tuple[float, float]]], canonicalize_rectangle( - [(p_diag1.x, p_diag1.y), (p_diag2.x, p_diag2.y)], - construction_mode="diagonal", - )) - except PolygonCanonicalizationError: - print( - f"Warning: Unable to canonicalize rectangle '{rect_name}' from supplied coordinates. Skipping." + return cast( + Optional[List[Tuple[float, float]]], + canonicalize_rectangle( + [(p_diag1.x, p_diag1.y), (p_diag2.x, p_diag2.y)], + construction_mode="diagonal", + ), ) + except PolygonCanonicalizationError: + print(f"Warning: Unable to canonicalize rectangle '{rect_name}' from supplied coordinates. Skipping.") return None def _create_circles(self, state: Dict[str, Any]) -> None: @@ -515,9 +515,7 @@ def _create_circles(self, state: Dict[str, Any]) -> None: self._restore_circle(item_state) def _restore_circle(self, item_state: Dict[str, Any]) -> None: - center_point: Optional["Point"] = self.canvas.get_point_by_name( - item_state["args"]["center"] - ) + center_point: Optional["Point"] = self.canvas.get_point_by_name(item_state["args"]["center"]) if center_point: self.canvas.create_circle( center_point.x, @@ -534,9 +532,7 @@ def _create_ellipses(self, state: Dict[str, Any]) -> None: self._restore_ellipse(item_state) def _restore_ellipse(self, item_state: Dict[str, Any]) -> None: - center_point: Optional["Point"] = self.canvas.get_point_by_name( - item_state["args"]["center"] - ) + center_point: Optional["Point"] = self.canvas.get_point_by_name(item_state["args"]["center"]) if center_point: self.canvas.create_ellipse( center_point.x, @@ -638,9 +634,7 @@ def _apply_restored_colored_area_name(self, created: Any, item_state: Dict[str, created.name = name def _warn_colored_area_restore_failure(self, item_state: Dict[str, Any], exc: Exception) -> None: - print( - f"Warning: Could not restore colored area '{item_state.get('name', 'Unnamed')}': {exc}" - ) + print(f"Warning: Could not restore colored area '{item_state.get('name', 'Unnamed')}': {exc}") def _create_plots(self, state: Dict[str, Any]) -> None: """ @@ -914,9 +908,7 @@ def _restore_closed_shape_area(self, item_state: Dict[str, Any]) -> None: opacity=opacity, ) - def _closed_shape_style_args( - self, shape_args: Dict[str, Any] - ) -> Tuple[Any, Any, Any]: + def _closed_shape_style_args(self, shape_args: Dict[str, Any]) -> Tuple[Any, Any, Any]: color = shape_args.get("color", default_area_fill_color) opacity = shape_args.get("opacity", default_area_opacity) resolution = shape_args.get("resolution", default_closed_shape_resolution) @@ -962,15 +954,9 @@ def _restore_generic_colored_area(self, item_state: Dict[str, Any]) -> None: opacity=args.get("opacity", default_area_opacity), ) - def _generic_colored_area_drawable_names( - self, args: Dict[str, Any] - ) -> Tuple[Any, Any]: - drawable1_name = ( - args.get("drawable1_name") or args.get("segment1") or args.get("func1") - ) - drawable2_name = ( - args.get("drawable2_name") or args.get("segment2") or args.get("func2") - ) + def _generic_colored_area_drawable_names(self, args: Dict[str, Any]) -> Tuple[Any, Any]: + drawable1_name = args.get("drawable1_name") or args.get("segment1") or args.get("func1") + drawable2_name = args.get("drawable2_name") or args.get("segment2") or args.get("func2") return drawable1_name, drawable2_name def _create_angles(self, state: Dict[str, Any]) -> None: @@ -1011,9 +997,7 @@ def _restore_circle_arc(self, arc_state: Dict[str, Any]) -> None: return try: - self.canvas.create_circle_arc( - **self._build_restored_circle_arc_kwargs(arc_state, args, point1, point2) - ) + self.canvas.create_circle_arc(**self._build_restored_circle_arc_kwargs(arc_state, args, point1, point2)) except Exception as exc: self._warn_circle_arc_restore_failure(arc_state, exc) @@ -1214,9 +1198,7 @@ def _materialize_plot_family(self, class_name: str, materializer_name: str) -> N self._materialize_plots(plots, materializer) def _get_statistics_manager_for_materialization(self, materializer_name: str) -> Any: - stats_manager = getattr( - getattr(self.canvas, "drawable_manager", None), "statistics_manager", None - ) + stats_manager = getattr(getattr(self.canvas, "drawable_manager", None), "statistics_manager", None) if stats_manager is None or not hasattr(stats_manager, materializer_name): return None return stats_manager @@ -1260,15 +1242,16 @@ def load_workspace(self, name: Optional[str] = None) -> str: Returns: str: Success or error message from the load operation. """ + def on_complete(req: Any) -> str: return self._parse_load_workspace_response(req, name) - url: str = f'/load_workspace?name={name}' if name else '/load_workspace' + url: str = f"/load_workspace?name={name}" if name else "/load_workspace" return self._execute_sync_request( - method='GET', + method="GET", url=url, on_complete=on_complete, - error_prefix='Error loading workspace', + error_prefix="Error loading workspace", ) def _parse_load_workspace_response(self, req: Any, name: Optional[str]) -> str: @@ -1300,14 +1283,15 @@ def list_workspaces(self) -> str: Returns: str: Comma-separated list of workspace names, or 'None' if empty. """ + def on_complete(req: Any) -> str: return self._parse_list_workspaces_response(req) return self._execute_sync_request( - method='GET', - url='/list_workspaces', + method="GET", + url="/list_workspaces", on_complete=on_complete, - error_prefix='Error listing workspaces', + error_prefix="Error listing workspaces", ) def _parse_list_workspaces_response(self, req: Any) -> str: @@ -1335,15 +1319,16 @@ def delete_workspace(self, name: str) -> str: Returns: str: Success or error message from the delete operation. """ + def on_complete(req: Any) -> str: return self._parse_delete_workspace_response(req, name) - url: str = f'/delete_workspace?name={name}' + url: str = f"/delete_workspace?name={name}" return self._execute_sync_request( - method='GET', + method="GET", url=url, on_complete=on_complete, - error_prefix='Error deleting workspace', + error_prefix="Error deleting workspace", ) def _parse_delete_workspace_response(self, req: Any, name: str) -> str: @@ -1382,7 +1367,7 @@ def _workspace_list_from_response(self, response: Dict[str, Any]) -> List[str]: return cast(List[str], response.get("data", [])) def _format_workspace_error(self, action_gerund: str, response: Dict[str, Any]) -> str: - return f'Error {action_gerund} workspace: {response.get("message")}' + return f"Error {action_gerund} workspace: {response.get('message')}" def _execute_sync_request( self, diff --git a/static/functions_definitions.py b/static/functions_definitions.py index 449b9dda..892d7b3b 100644 --- a/static/functions_definitions.py +++ b/static/functions_definitions.py @@ -55,2970 +55,2656 @@ FUNCTIONS: List[Dict[str, Any]] = [ - { - "type": "function", - "function": { - "name": "reset_canvas", - "description": "Resets the canvas zoom and offset", - "strict": True, - "parameters": { - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "clear_canvas", - "description": "Clears the canvas by deleting all drawable objects", - "strict": True, - "parameters": { - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "zoom", - "description": "Centers viewport on (center_x, center_y). The range_val specifies half-width (if range_axis='x') or half-height (if range_axis='y'); the other axis scales with canvas aspect ratio. Example: 'zoom x in range +-2, y around 10' uses center_x=0, center_y=10, range_val=2, range_axis='x'.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "center_x": { - "type": "number", - "description": "X coordinate to center on" - }, - "center_y": { - "type": "number", - "description": "Y coordinate to center on" - }, - "range_val": { - "type": "number", - "description": "Half-size for the specified axis (shows center +/- this value)" - }, - "range_axis": { - "type": "string", - "enum": ["x", "y"], - "description": "Which axis range applies to; other axis scales with aspect ratio" - } - }, - "required": ["center_x", "center_y", "range_val", "range_axis"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "undo", - "description": "Undoes the last action on the canvas", - "strict": True, - "parameters": { - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "redo", - "description": "Redoes the last action on the canvas", - "strict": True, - "parameters": { - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "get_current_canvas_state", - "description": "Returns the current serialized canvas state (drawables, cartesian state, computations) without modifying the canvas. Optional filters can narrow by drawable collections or object names.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "drawable_types": { - "type": ["array", "null"], - "description": "Optional drawable collection filters (e.g. ['Points','Segments','Circles']). Case-insensitive; singular or plural accepted.", - "items": { - "type": "string" - } - }, - "object_names": { - "type": ["array", "null"], - "description": "Optional object-name filters across drawable collections (e.g. ['A','B','f']).", - "items": { - "type": "string" - } - }, - "include_computations": { - "type": ["boolean", "null"], - "description": "Whether to include the computations list in the returned state. Defaults to true." - } - }, - "required": ["drawable_types", "object_names", "include_computations"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "run_tests", - "description": "Runs the test suite for the canvas", - "strict": True, - "parameters": { - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_point", - "description": "Creates and draws a point at the given coordinates. If a name is provided, it will try to use the first available letter from that name as the point's name.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "x": { - "type": "number", - "description": "The X coordinate of the point" - }, - "y": { - "type": "number", - "description": "The Y coordinate of the point" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the point" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the point. If provided, the first available letter from this name will be used." - } - }, - "required": ["x", "y", "color", "name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_point", - "description": "Deletes the point with the given coordinates", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "x": { - "type": "number", - "description": "The X coordinate of the point" - }, - "y": { - "type": "number", - "description": "The Y coordinate of the point" - } - }, - "required": ["x", "y"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_point", - "description": "Updates the name, color, or position of a solitary point without recreating it. Provide at least one property to change.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "point_name": { - "type": "string", - "description": "Existing name of the point to edit" - }, - "new_name": { - "type": ["string", "null"], - "description": "Optional new name for the point" - }, - "new_x": { - "type": ["number", "null"], - "description": "Optional new x-coordinate (requires new_y)" - }, - "new_y": { - "type": ["number", "null"], - "description": "Optional new y-coordinate (requires new_x)" - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new display color for the point" - } - }, - "required": ["point_name", "new_name", "new_x", "new_y", "new_color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_segment", - "description": "Creates and draws a segment at the given coordinates for two points. If a name is provided, the first two available letters from that name will be used to name the endpoints.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "x1": { - "type": "number", - "description": "The X coordinate of the first point" - }, - "y1": { - "type": "number", - "description": "The Y coordinate of the first point" - }, - "x2": { - "type": "number", - "description": "The X coordinate of the second point" - }, - "y2": { - "type": "number", - "description": "The Y coordinate of the second point" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the segment" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the segment. If provided, the first two available letters will be used to name the endpoints." - }, - "label_text": { - "type": ["string", "null"], - "description": "Optional text for the segment-owned label (default empty)" - }, - "label_visible": { - "type": ["boolean", "null"], - "description": "Whether to display the segment-owned label (default false)" - } - }, - "required": ["x1", "y1", "x2", "y2", "color", "name", "label_text", "label_visible"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_segment", - "description": "Deletes the segment found at the given coordinates for two points. If only a name is given, search for appropriate point coordinates in the canvas state.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "x1": { - "type": "number", - "description": "The X coordinate of the first point" - }, - "y1": { - "type": "number", - "description": "The Y coordinate of the first point" - }, - "x2": { - "type": "number", - "description": "The X coordinate of the second point" - }, - "y2": { - "type": "number", - "description": "The Y coordinate of the second point" - } - }, - "required": ["x1", "y1", "x2", "y2"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_segment", - "description": "Updates editable properties of an existing segment (color, label text, or label visibility). Provide null for fields that should remain unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the segment to edit" - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the segment" - }, - "new_label_text": { - "type": ["string", "null"], - "description": "Optional new text for the segment-owned label" - }, - "new_label_visible": { - "type": ["boolean", "null"], - "description": "Optional visibility flag for the segment-owned label" - } - }, - "required": ["name", "new_color", "new_label_text", "new_label_visible"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_vector", - "description": "Creates and draws a vector at the given coordinates for two points called origin and tip. If a name is provided, the first two available letters will be used to name the origin and tip points.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "origin_x": { - "type": "number", - "description": "The X coordinate of the origin point" - }, - "origin_y": { - "type": "number", - "description": "The Y coordinate of the origin point" - }, - "tip_x": { - "type": "number", - "description": "The X coordinate of the tip point" - }, - "tip_y": { - "type": "number", - "description": "The Y coordinate of the tip point" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the vector" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the vector. If provided, the first two available letters will be used to name the origin and tip points." - } - }, - "required": ["origin_x", "origin_y", "tip_x", "tip_y", "color", "name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_vector", - "description": "Deletes the vector found at the given coordinates for two points called origin and tip. If only a name is given, search for appropriate point coordinates in the canvas state.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "origin_x": { - "type": "number", - "description": "The X coordinate of the origin point", - }, - "origin_y": { - "type": "number", - "description": "The Y coordinate of the origin point", - }, - "tip_x": { - "type": "number", - "description": "The X coordinate of the tip point", - }, - "tip_y": { - "type": "number", - "description": "The Y coordinate of the tip point", - } - }, - "required": ["origin_x", "origin_y", "tip_x", "tip_y"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_vector", - "description": "Updates editable properties of an existing vector (currently just color). Provide null for fields that should remain unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the vector to edit" - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the vector" - } - }, - "required": ["name", "new_color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_polygon", - "description": "Creates a polygon from ordered vertex coordinates. For rectangle and square types, coordinates are normalized through the canonicalizer so near-rectangles snap into valid rectangles. Triangle inputs can optionally request canonicalization toward special subtypes such as equilateral or right triangles.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "vertices": { - "type": "array", - "minItems": 3, - "description": "Ordered list of polygon vertex coordinates.", - "items": { - "type": "object", - "properties": { - "x": { - "type": "number", - "description": "X coordinate of the vertex." - }, - "y": { - "type": "number", - "description": "Y coordinate of the vertex." - } - }, - "required": ["x", "y"], - "additionalProperties": False - } - }, - "polygon_type": { - "type": ["string", "null"], - "description": "Optional polygon classification (triangle, quadrilateral, pentagon, hexagon, heptagon, octagon, nonagon, decagon, or generic).", - "enum": ["triangle", "quadrilateral", "pentagon", "hexagon", "heptagon", "octagon", "nonagon", "decagon", "generic", None], - }, - "color": { - "type": ["string", "null"], - "description": "Optional stroke color for the polygon edges." - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the polygon. Letters are reused to label vertices." - }, - "subtype": { - "type": ["string", "null"], - "description": "Optional polygon subtype hint. Triangles support equilateral, isosceles, right, right_isosceles. Quadrilaterals support rectangle, square, parallelogram, rhombus, kite, trapezoid, isosceles_trapezoid, right_trapezoid.", - "enum": POLYGON_SUBTYPE_VALUES + [None], - } - }, - "required": ["vertices", "polygon_type", "color", "name", "subtype"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_polygon", - "description": "Deletes a polygon by name or by matching a set of vertex coordinates. Specify polygon_type to limit the search.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "polygon_type": { - "type": ["string", "null"], - "description": "Optional polygon classification (triangle, quadrilateral, rectangle, square, pentagon, hexagon, heptagon, octagon, nonagon, decagon, or generic)." - }, - "name": { - "type": ["string", "null"], - "description": "Existing name of the polygon to delete." + { + "type": "function", + "function": { + "name": "reset_canvas", + "description": "Resets the canvas zoom and offset", + "strict": True, + "parameters": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "clear_canvas", + "description": "Clears the canvas by deleting all drawable objects", + "strict": True, + "parameters": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "zoom", + "description": "Centers viewport on (center_x, center_y). The range_val specifies half-width (if range_axis='x') or half-height (if range_axis='y'); the other axis scales with canvas aspect ratio. Example: 'zoom x in range +-2, y around 10' uses center_x=0, center_y=10, range_val=2, range_axis='x'.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "center_x": {"type": "number", "description": "X coordinate to center on"}, + "center_y": {"type": "number", "description": "Y coordinate to center on"}, + "range_val": { + "type": "number", + "description": "Half-size for the specified axis (shows center +/- this value)", + }, + "range_axis": { + "type": "string", + "enum": ["x", "y"], + "description": "Which axis range applies to; other axis scales with aspect ratio", + }, + }, + "required": ["center_x", "center_y", "range_val", "range_axis"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "undo", + "description": "Undoes the last action on the canvas", + "strict": True, + "parameters": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "redo", + "description": "Redoes the last action on the canvas", + "strict": True, + "parameters": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "get_current_canvas_state", + "description": "Returns the current serialized canvas state (drawables, cartesian state, computations) without modifying the canvas. Optional filters can narrow by drawable collections or object names.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "drawable_types": { + "type": ["array", "null"], + "description": "Optional drawable collection filters (e.g. ['Points','Segments','Circles']). Case-insensitive; singular or plural accepted.", + "items": {"type": "string"}, + }, + "object_names": { + "type": ["array", "null"], + "description": "Optional object-name filters across drawable collections (e.g. ['A','B','f']).", + "items": {"type": "string"}, + }, + "include_computations": { + "type": ["boolean", "null"], + "description": "Whether to include the computations list in the returned state. Defaults to true.", + }, + }, + "required": ["drawable_types", "object_names", "include_computations"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "run_tests", + "description": "Runs the test suite for the canvas", + "strict": True, + "parameters": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "create_point", + "description": "Creates and draws a point at the given coordinates. If a name is provided, it will try to use the first available letter from that name as the point's name.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "The X coordinate of the point"}, + "y": {"type": "number", "description": "The Y coordinate of the point"}, + "color": {"type": ["string", "null"], "description": "Optional color for the point"}, + "name": { + "type": ["string", "null"], + "description": "Optional name for the point. If provided, the first available letter from this name will be used.", + }, + }, + "required": ["x", "y", "color", "name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_point", + "description": "Deletes the point with the given coordinates", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "The X coordinate of the point"}, + "y": {"type": "number", "description": "The Y coordinate of the point"}, + }, + "required": ["x", "y"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_point", + "description": "Updates the name, color, or position of a solitary point without recreating it. Provide at least one property to change.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "point_name": {"type": "string", "description": "Existing name of the point to edit"}, + "new_name": {"type": ["string", "null"], "description": "Optional new name for the point"}, + "new_x": {"type": ["number", "null"], "description": "Optional new x-coordinate (requires new_y)"}, + "new_y": {"type": ["number", "null"], "description": "Optional new y-coordinate (requires new_x)"}, + "new_color": { + "type": ["string", "null"], + "description": "Optional new display color for the point", + }, + }, + "required": ["point_name", "new_name", "new_x", "new_y", "new_color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_segment", + "description": "Creates and draws a segment at the given coordinates for two points. If a name is provided, the first two available letters from that name will be used to name the endpoints.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "x1": {"type": "number", "description": "The X coordinate of the first point"}, + "y1": {"type": "number", "description": "The Y coordinate of the first point"}, + "x2": {"type": "number", "description": "The X coordinate of the second point"}, + "y2": {"type": "number", "description": "The Y coordinate of the second point"}, + "color": {"type": ["string", "null"], "description": "Optional color for the segment"}, + "name": { + "type": ["string", "null"], + "description": "Optional name for the segment. If provided, the first two available letters will be used to name the endpoints.", + }, + "label_text": { + "type": ["string", "null"], + "description": "Optional text for the segment-owned label (default empty)", + }, + "label_visible": { + "type": ["boolean", "null"], + "description": "Whether to display the segment-owned label (default false)", + }, + }, + "required": ["x1", "y1", "x2", "y2", "color", "name", "label_text", "label_visible"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_segment", + "description": "Deletes the segment found at the given coordinates for two points. If only a name is given, search for appropriate point coordinates in the canvas state.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "x1": {"type": "number", "description": "The X coordinate of the first point"}, + "y1": {"type": "number", "description": "The Y coordinate of the first point"}, + "x2": {"type": "number", "description": "The X coordinate of the second point"}, + "y2": {"type": "number", "description": "The Y coordinate of the second point"}, + }, + "required": ["x1", "y1", "x2", "y2"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_segment", + "description": "Updates editable properties of an existing segment (color, label text, or label visibility). Provide null for fields that should remain unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the segment to edit"}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the segment"}, + "new_label_text": { + "type": ["string", "null"], + "description": "Optional new text for the segment-owned label", + }, + "new_label_visible": { + "type": ["boolean", "null"], + "description": "Optional visibility flag for the segment-owned label", + }, + }, + "required": ["name", "new_color", "new_label_text", "new_label_visible"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_vector", + "description": "Creates and draws a vector at the given coordinates for two points called origin and tip. If a name is provided, the first two available letters will be used to name the origin and tip points.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "origin_x": {"type": "number", "description": "The X coordinate of the origin point"}, + "origin_y": {"type": "number", "description": "The Y coordinate of the origin point"}, + "tip_x": {"type": "number", "description": "The X coordinate of the tip point"}, + "tip_y": {"type": "number", "description": "The Y coordinate of the tip point"}, + "color": {"type": ["string", "null"], "description": "Optional color for the vector"}, + "name": { + "type": ["string", "null"], + "description": "Optional name for the vector. If provided, the first two available letters will be used to name the origin and tip points.", + }, + }, + "required": ["origin_x", "origin_y", "tip_x", "tip_y", "color", "name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_vector", + "description": "Deletes the vector found at the given coordinates for two points called origin and tip. If only a name is given, search for appropriate point coordinates in the canvas state.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "origin_x": { + "type": "number", + "description": "The X coordinate of the origin point", + }, + "origin_y": { + "type": "number", + "description": "The Y coordinate of the origin point", + }, + "tip_x": { + "type": "number", + "description": "The X coordinate of the tip point", + }, + "tip_y": { + "type": "number", + "description": "The Y coordinate of the tip point", + }, + }, + "required": ["origin_x", "origin_y", "tip_x", "tip_y"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_vector", + "description": "Updates editable properties of an existing vector (currently just color). Provide null for fields that should remain unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the vector to edit"}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the vector"}, + }, + "required": ["name", "new_color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_polygon", + "description": "Creates a polygon from ordered vertex coordinates. For rectangle and square types, coordinates are normalized through the canonicalizer so near-rectangles snap into valid rectangles. Triangle inputs can optionally request canonicalization toward special subtypes such as equilateral or right triangles.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "vertices": { + "type": "array", + "minItems": 3, + "description": "Ordered list of polygon vertex coordinates.", + "items": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "X coordinate of the vertex."}, + "y": {"type": "number", "description": "Y coordinate of the vertex."}, + }, + "required": ["x", "y"], + "additionalProperties": False, + }, + }, + "polygon_type": { + "type": ["string", "null"], + "description": "Optional polygon classification (triangle, quadrilateral, pentagon, hexagon, heptagon, octagon, nonagon, decagon, or generic).", + "enum": [ + "triangle", + "quadrilateral", + "pentagon", + "hexagon", + "heptagon", + "octagon", + "nonagon", + "decagon", + "generic", + None, + ], + }, + "color": { + "type": ["string", "null"], + "description": "Optional stroke color for the polygon edges.", + }, + "name": { + "type": ["string", "null"], + "description": "Optional name for the polygon. Letters are reused to label vertices.", + }, + "subtype": { + "type": ["string", "null"], + "description": "Optional polygon subtype hint. Triangles support equilateral, isosceles, right, right_isosceles. Quadrilaterals support rectangle, square, parallelogram, rhombus, kite, trapezoid, isosceles_trapezoid, right_trapezoid.", + "enum": POLYGON_SUBTYPE_VALUES + [None], + }, + }, + "required": ["vertices", "polygon_type", "color", "name", "subtype"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_polygon", + "description": "Deletes a polygon by name or by matching a set of vertex coordinates. Specify polygon_type to limit the search.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "polygon_type": { + "type": ["string", "null"], + "description": "Optional polygon classification (triangle, quadrilateral, rectangle, square, pentagon, hexagon, heptagon, octagon, nonagon, decagon, or generic).", + }, + "name": {"type": ["string", "null"], "description": "Existing name of the polygon to delete."}, + "vertices": { + "type": "array", + "minItems": 3, + "description": "Ordered list of polygon vertex coordinates used for lookup.", + "items": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "X coordinate of the vertex."}, + "y": {"type": "number", "description": "Y coordinate of the vertex."}, + }, + "required": ["x", "y"], + "additionalProperties": False, + }, + }, + }, + "required": ["polygon_type", "name", "vertices"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_polygon", + "description": "Updates editable properties of an existing polygon (currently just color). Provide null for fields that should remain unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "polygon_name": {"type": "string", "description": "Existing name of the polygon to edit."}, + "polygon_type": { + "type": ["string", "null"], + "description": "Optional polygon classification to disambiguate the lookup.", + }, + "new_color": { + "type": ["string", "null"], + "description": "Optional new color for the polygon edges.", + }, + }, + "required": ["polygon_name", "polygon_type", "new_color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_circle", + "description": "Creates and draws a circle with the specified center coordinates and radius. If a name is provided, it will be used to reference the circle.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "center_x": {"type": "number", "description": "The X coordinate of the circle's center"}, + "center_y": {"type": "number", "description": "The Y coordinate of the circle's center"}, + "radius": {"type": "number", "description": "The radius of the circle"}, + "color": {"type": ["string", "null"], "description": "Optional color to assign to the circle"}, + "name": {"type": ["string", "null"], "description": "Optional name for the circle"}, + }, + "required": ["center_x", "center_y", "radius", "color", "name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_circle", + "description": "Deletes the circle with the given name", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string", "description": "The name of the circle"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_circle", + "description": "Updates editable properties of an existing circle (color or center position). Provide null for fields to keep them unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the circle to edit."}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the circle."}, + "new_center_x": { + "type": ["number", "null"], + "description": "Optional new x-coordinate for the circle center (requires y value when provided).", + }, + "new_center_y": { + "type": ["number", "null"], + "description": "Optional new y-coordinate for the circle center (requires x value when provided).", + }, + }, + "required": ["name", "new_color", "new_center_x", "new_center_y"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_circle_arc", + "description": "Creates an arc on a circle. Use this for requests like 'draw an arc with center (x,y), radius r' or 'arc on circle C between two points'. Supports standalone center/radius arcs, arcs on an existing circle, or deriving from three points with center_point_choice.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "point1_x": { + "type": "number", + "description": "Reference X coordinate for the first arc point (snapped to the circle when center/radius are provided)", + }, + "point1_y": { + "type": "number", + "description": "Reference Y coordinate for the first arc point (snapped to the circle when center/radius are provided)", + }, + "point2_x": { + "type": "number", + "description": "Reference X coordinate for the second arc point (snapped to the circle when center/radius are provided)", + }, + "point2_y": { + "type": "number", + "description": "Reference Y coordinate for the second arc point (snapped to the circle when center/radius are provided)", + }, + "point1_name": {"type": ["string", "null"], "description": "Optional name for the first arc point"}, + "point2_name": { + "type": ["string", "null"], + "description": "Optional name for the second arc point", + }, + "point3_x": { + "type": ["number", "null"], + "description": "Optional reference X coordinate for a third point when deriving the circle from three points", + }, + "point3_y": { + "type": ["number", "null"], + "description": "Optional reference Y coordinate for a third point when deriving the circle from three points", + }, + "point3_name": { + "type": ["string", "null"], + "description": "Optional name for the third point (used when deriving the circle from three points)", + }, + "center_point_choice": { + "type": ["string", "null"], + "description": "Optional selector ('point1', 'point2', or 'point3') indicating which provided point should be treated as the circle center", + }, + "circle_name": {"type": ["string", "null"], "description": "Existing circle to attach the arc to"}, + "center_x": { + "type": ["number", "null"], + "description": "Circle center x-coordinate when defining a standalone arc", + }, + "center_y": { + "type": ["number", "null"], + "description": "Circle center y-coordinate when defining a standalone arc", + }, + "radius": { + "type": ["number", "null"], + "description": "Circle radius when defining a standalone arc", + }, + "use_major_arc": { + "type": "boolean", + "description": "True to draw the major arc, False for the minor arc", + }, + "arc_name": {"type": ["string", "null"], "description": "Optional custom arc name"}, + "color": { + "type": ["string", "null"], + "description": f"Optional CSS color for the arc (defaults to {DEFAULT_CIRCLE_ARC_COLOR})", + }, + }, + "required": [ + "point1_x", + "point1_y", + "point2_x", + "point2_y", + "point1_name", + "point2_name", + "point3_x", + "point3_y", + "point3_name", + "center_point_choice", + "circle_name", + "center_x", + "center_y", + "radius", + "use_major_arc", + "arc_name", + "color", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_circle_arc", + "description": "Deletes a circle arc by name.", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string", "description": "Name of the circle arc to delete"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_circle_arc", + "description": "Updates editable properties of an existing circle arc (color or major/minor toggle). Provide null for fields to keep them unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the arc to edit"}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the arc"}, + "use_major_arc": { + "type": ["boolean", "null"], + "description": "Set to true for the major arc, false for the minor arc", + }, + }, + "required": ["name", "new_color", "use_major_arc"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_ellipse", + "description": "Creates an ellipse with the specified center point, x-radius, y-radius, and optional rotation angle", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "center_x": {"type": "number", "description": "The x-coordinate of the ellipse's center"}, + "center_y": {"type": "number", "description": "The y-coordinate of the ellipse's center"}, + "radius_x": { + "type": "number", + "description": "The radius of the ellipse in the x-direction (half the width)", + }, + "radius_y": { + "type": "number", + "description": "The radius of the ellipse in the y-direction (half the height)", + }, + "rotation_angle": { + "type": ["number", "null"], + "description": "Optional angle in degrees to rotate the ellipse around its center (default: 0)", + }, + "color": {"type": ["string", "null"], "description": "Optional color for the ellipse"}, + "name": {"type": ["string", "null"], "description": "Optional name for the ellipse"}, + }, + "required": ["center_x", "center_y", "radius_x", "radius_y", "rotation_angle", "color", "name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_ellipse", + "description": "Deletes the ellipse with the given name", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string", "description": "The name of the ellipse"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_ellipse", + "description": "Updates editable properties of an existing ellipse (color, radii, rotation, or center). Provide null for fields that should remain unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the ellipse to edit."}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the ellipse."}, + "new_radius_x": { + "type": ["number", "null"], + "description": "Optional new horizontal radius (requires ellipse to be solitary).", + }, + "new_radius_y": { + "type": ["number", "null"], + "description": "Optional new vertical radius (requires ellipse to be solitary).", + }, + "new_rotation_angle": { + "type": ["number", "null"], + "description": "Optional new rotation angle in degrees.", + }, + "new_center_x": { + "type": ["number", "null"], + "description": "Optional new x-coordinate for the center (requires y value when provided).", + }, + "new_center_y": { + "type": ["number", "null"], + "description": "Optional new y-coordinate for the center (requires x value when provided).", + }, + }, + "required": [ + "name", + "new_color", + "new_radius_x", + "new_radius_y", + "new_rotation_angle", + "new_center_x", + "new_center_y", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_label", + "description": ( + "Creates a text label anchored at a math-space coordinate " + f"(max {LABEL_TEXT_MAX_LENGTH} chars, wraps every {LABEL_LINE_WRAP_THRESHOLD} chars)" + ), + "strict": True, + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "Math-space X coordinate for the label anchor"}, + "y": {"type": "number", "description": "Math-space Y coordinate for the label anchor"}, + "text": { + "type": "string", + "description": "Label text content; lines wrap after 40 characters", + "maxLength": 160, + }, + "name": {"type": ["string", "null"], "description": "Optional label name used for later updates"}, + "color": {"type": ["string", "null"], "description": "Optional CSS color for the label text"}, + "font_size": {"type": ["number", "null"], "description": "Optional font size in pixels"}, + "rotation_degrees": { + "type": ["number", "null"], + "description": "Optional angle in degrees to rotate the label text", + }, + }, + "required": [ + "x", + "y", + "text", + "name", + "color", + "font_size", + "rotation_degrees", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_label", + "description": "Deletes an existing label by name", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string", "description": "Name of the label to delete"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_label", + "description": "Updates editable properties of an existing label (text, color, position, font size, rotation). Provide null for fields that should remain unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the label to edit"}, + "new_text": {"type": ["string", "null"], "description": "Optional replacement text for the label"}, + "new_x": {"type": ["number", "null"], "description": "Optional new x-coordinate (requires new_y)"}, + "new_y": {"type": ["number", "null"], "description": "Optional new y-coordinate (requires new_x)"}, + "new_color": {"type": ["string", "null"], "description": "Optional new text color"}, + "new_font_size": { + "type": ["number", "null"], + "description": "Optional new font size in math-space units", + }, + "new_rotation_degrees": { + "type": ["number", "null"], + "description": "Optional rotation angle in degrees", + }, + }, + "required": [ + "name", + "new_text", + "new_x", + "new_y", + "new_color", + "new_font_size", + "new_rotation_degrees", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "draw_function", + "description": "Plots the given mathematical function on the canvas between the specified left and right bounds.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "function_string": { + "type": "string", + "description": "The mathematical expression represented as a string, e.g., '2*x + 3'.", + }, + "name": { + "type": ["string", "null"], + "description": "The name or label for the plotted function. Useful for referencing later.", + }, + "left_bound": { + "type": ["number", "null"], + "description": "The left bound of the interval on which to plot the function.", + }, + "right_bound": { + "type": ["number", "null"], + "description": "The right bound of the interval on which to plot the function.", + }, + "color": {"type": ["string", "null"], "description": "Optional color for the plotted function."}, + "undefined_at": { + "type": ["array", "null"], + "description": "Optional list of x-values where the function is explicitly undefined (holes). E.g., [0, 2] means the function has holes at x=0 and x=2.", + "items": {"type": "number"}, + }, + }, + "required": ["function_string", "name", "left_bound", "right_bound", "color", "undefined_at"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_function", + "description": "Removes the plotted mathematical function with the given name from the canvas.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name or label of the function to be deleted."} + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_function", + "description": "Updates editable properties of an existing plotted function (color and/or bounds). Provide null for fields to leave them unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the function to edit."}, + "new_color": { + "type": ["string", "null"], + "description": "Optional new color for the function plot.", + }, + "new_left_bound": {"type": ["number", "null"], "description": "Optional new left plotting bound."}, + "new_right_bound": { + "type": ["number", "null"], + "description": "Optional new right plotting bound.", + }, + }, + "required": ["name", "new_color", "new_left_bound", "new_right_bound"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "draw_piecewise_function", + "description": "Plots a piecewise-defined function with different expressions for different intervals. Each piece specifies an expression and its valid interval bounds. Use null for unbounded intervals (extending to infinity). Use undefined_at for explicit holes (points where the function is undefined).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "pieces": { + "type": "array", + "minItems": 1, + "description": "List of function pieces, each defining an expression and its interval.", + "items": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression for this piece (e.g., 'x^2', 'sin(x)').", + }, + "left": { + "type": ["number", "null"], + "description": "Left interval bound (null for negative infinity).", + }, + "right": { + "type": ["number", "null"], + "description": "Right interval bound (null for positive infinity).", + }, + "left_inclusive": { + "type": "boolean", + "description": "Whether the left bound is included in the interval.", + }, + "right_inclusive": { + "type": "boolean", + "description": "Whether the right bound is included in the interval.", + }, + "undefined_at": { + "type": ["array", "null"], + "description": "Optional list of x-values where this piece is explicitly undefined (holes).", + "items": {"type": "number"}, + }, }, - "vertices": { - "type": "array", - "minItems": 3, - "description": "Ordered list of polygon vertex coordinates used for lookup.", - "items": { - "type": "object", - "properties": { - "x": { - "type": "number", - "description": "X coordinate of the vertex." + "required": [ + "expression", + "left", + "right", + "left_inclusive", + "right_inclusive", + "undefined_at", + ], + "additionalProperties": False, + }, + }, + "name": {"type": ["string", "null"], "description": "Optional name for the piecewise function."}, + "color": {"type": ["string", "null"], "description": "Optional color for the plotted function."}, + }, + "required": ["pieces", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_piecewise_function", + "description": "Removes the plotted piecewise function with the given name from the canvas.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the piecewise function to delete."} + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_piecewise_function", + "description": "Updates editable properties of an existing piecewise function (currently just color). Provide null for fields to leave them unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the piecewise function to edit."}, + "new_color": { + "type": ["string", "null"], + "description": "Optional new color for the function plot.", + }, + }, + "required": ["name", "new_color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "draw_parametric_function", + "description": "Plots a parametric curve defined by x(t) and y(t) expressions. Use this for curves that cannot be expressed as y=f(x), such as circles, spirals, Lissajous figures, and other complex shapes. The parameter t ranges from t_min to t_max (default 0 to 2*pi for periodic curves).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "x_expression": { + "type": "string", + "description": "Mathematical expression for x as a function of t. Example: 'cos(t)' for a circle, 't*cos(t)' for a spiral.", + }, + "y_expression": { + "type": "string", + "description": "Mathematical expression for y as a function of t. Example: 'sin(t)' for a circle, 't*sin(t)' for a spiral.", + }, + "name": { + "type": ["string", "null"], + "description": "Optional name or label for the parametric curve. Useful for referencing later.", + }, + "t_min": {"type": ["number", "null"], "description": "Minimum value of parameter t. Default is 0."}, + "t_max": { + "type": ["number", "null"], + "description": "Maximum value of parameter t. Default is 2*pi (~6.283) for periodic curves.", + }, + "color": {"type": ["string", "null"], "description": "Optional color for the plotted curve."}, + }, + "required": ["x_expression", "y_expression", "name", "t_min", "t_max", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_parametric_function", + "description": "Removes the plotted parametric function with the given name from the canvas.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the parametric function to delete."} + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_parametric_function", + "description": "Updates editable properties of an existing parametric function (color, t_min, t_max). Provide null for fields to leave them unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the parametric function to edit."}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the curve."}, + "new_t_min": { + "type": ["number", "null"], + "description": "Optional new minimum value of parameter t.", + }, + "new_t_max": { + "type": ["number", "null"], + "description": "Optional new maximum value of parameter t.", + }, + }, + "required": ["name", "new_color", "new_t_min", "new_t_max"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "draw_tangent_line", + "description": "Draws a tangent line segment to a curve at a specified point. For functions y=f(x), the parameter is the x-coordinate. For parametric curves, it's the t value. For circles and ellipses, it's the angle in radians from the positive x-axis.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "curve_name": { + "type": "string", + "description": "Name of the curve to draw tangent to (function, parametric function, circle, or ellipse)", + }, + "parameter": { + "type": "number", + "description": "Position on curve: x-coordinate for functions, t-value for parametric curves, or angle (radians) for circles/ellipses", + }, + "name": {"type": ["string", "null"], "description": "Optional name for the tangent line segment"}, + "length": { + "type": ["number", "null"], + "description": "Total length of the tangent segment in math units (default: 4.0)", + }, + "color": { + "type": ["string", "null"], + "description": "Optional color for the tangent line (default: same as curve)", + }, + }, + "required": ["curve_name", "parameter", "name", "length", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "draw_normal_line", + "description": "Draws a normal line segment (perpendicular to tangent) to a curve at a specified point. For functions y=f(x), the parameter is the x-coordinate. For parametric curves, it's the t value. For circles and ellipses, it's the angle in radians from the positive x-axis.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "curve_name": { + "type": "string", + "description": "Name of the curve to draw normal to (function, parametric function, circle, or ellipse)", + }, + "parameter": { + "type": "number", + "description": "Position on curve: x-coordinate for functions, t-value for parametric curves, or angle (radians) for circles/ellipses", + }, + "name": {"type": ["string", "null"], "description": "Optional name for the normal line segment"}, + "length": { + "type": ["number", "null"], + "description": "Total length of the normal segment in math units (default: 4.0)", + }, + "color": { + "type": ["string", "null"], + "description": "Optional color for the normal line (default: same as curve)", + }, + }, + "required": ["curve_name", "parameter", "name", "length", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_midpoint", + "description": "Constructs a point at the midpoint of a segment or between two named points. Provide either 'segment_name' or both 'p1_name' and 'p2_name'.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "p1_name": { + "type": ["string", "null"], + "description": "Name of the first point (use with p2_name)", + }, + "p2_name": { + "type": ["string", "null"], + "description": "Name of the second point (use with p1_name)", + }, + "segment_name": { + "type": ["string", "null"], + "description": "Name of the segment whose midpoint to find (alternative to p1_name/p2_name)", + }, + "name": {"type": ["string", "null"], "description": "Optional name for the created midpoint"}, + "color": {"type": ["string", "null"], "description": "Optional color for the midpoint"}, + }, + "required": ["p1_name", "p2_name", "segment_name", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_perpendicular_bisector", + "description": "Constructs the perpendicular bisector of a segment. Creates a new segment that passes through the midpoint and is perpendicular to the original.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "segment_name": {"type": "string", "description": "Name of the segment to bisect perpendicularly"}, + "length": { + "type": ["number", "null"], + "description": "Total length of the bisector segment in math units (default: 6.0)", + }, + "name": { + "type": ["string", "null"], + "description": "Optional name for the created bisector segment", + }, + "color": {"type": ["string", "null"], "description": "Optional color for the bisector segment"}, + }, + "required": ["segment_name", "length", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_perpendicular_from_point", + "description": "Drops a perpendicular from a point to a segment. Creates the foot point on the line and a segment from the original point to the foot.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "point_name": {"type": "string", "description": "Name of the point to project onto the segment"}, + "segment_name": {"type": "string", "description": "Name of the target segment"}, + "name": {"type": ["string", "null"], "description": "Optional name for the perpendicular segment"}, + "color": {"type": ["string", "null"], "description": "Optional color for created drawables"}, + }, + "required": ["point_name", "segment_name", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_angle_bisector", + "description": "Constructs a segment along the bisector of an angle. Provide either 'angle_name' for an existing angle, or 'vertex_name', 'p1_name', 'p2_name' to define the angle by three points.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "vertex_name": { + "type": ["string", "null"], + "description": "Name of the angle vertex point (use with p1_name and p2_name)", + }, + "p1_name": { + "type": ["string", "null"], + "description": "Name of the first arm endpoint (use with vertex_name and p2_name)", + }, + "p2_name": { + "type": ["string", "null"], + "description": "Name of the second arm endpoint (use with vertex_name and p1_name)", + }, + "angle_name": { + "type": ["string", "null"], + "description": "Name of an existing angle to bisect (alternative to vertex/p1/p2)", + }, + "length": { + "type": ["number", "null"], + "description": "Length of the bisector segment in math units (default: 6.0)", + }, + "name": { + "type": ["string", "null"], + "description": "Optional name for the created bisector segment", + }, + "color": {"type": ["string", "null"], "description": "Optional color for the bisector segment"}, + }, + "required": ["vertex_name", "p1_name", "p2_name", "angle_name", "length", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_parallel_line", + "description": "Constructs a segment through a point that is parallel to a given segment. The new segment is centered on the specified point.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "segment_name": { + "type": "string", + "description": "Name of the reference segment to be parallel to", + }, + "point_name": { + "type": "string", + "description": "Name of the point the parallel line passes through", + }, + "length": { + "type": ["number", "null"], + "description": "Total length of the parallel segment in math units (default: 6.0)", + }, + "name": { + "type": ["string", "null"], + "description": "Optional name for the created parallel segment", + }, + "color": {"type": ["string", "null"], "description": "Optional color for the parallel segment"}, + }, + "required": ["segment_name", "point_name", "length", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_circumcircle", + "description": "Constructs the circumscribed circle (circumcircle) of a triangle or three points. The circumcircle passes through all three vertices. Provide either triangle_name or all three point names.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "triangle_name": { + "type": ["string", "null"], + "description": "Name of an existing triangle (alternative to specifying three points)", + }, + "p1_name": { + "type": ["string", "null"], + "description": "Name of the first point (used with p2_name and p3_name instead of triangle_name)", + }, + "p2_name": {"type": ["string", "null"], "description": "Name of the second point"}, + "p3_name": {"type": ["string", "null"], "description": "Name of the third point"}, + "name": {"type": ["string", "null"], "description": "Optional name for the created circumcircle"}, + "color": {"type": ["string", "null"], "description": "Optional color for the circumcircle"}, + }, + "required": ["triangle_name", "p1_name", "p2_name", "p3_name", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "construct_incircle", + "description": "Constructs the inscribed circle (incircle) of a triangle. The incircle is tangent to all three sides of the triangle.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "triangle_name": {"type": "string", "description": "Name of an existing triangle"}, + "name": {"type": ["string", "null"], "description": "Optional name for the created incircle"}, + "color": {"type": ["string", "null"], "description": "Optional color for the incircle"}, + }, + "required": ["triangle_name", "name", "color"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "evaluate_expression", + "description": "Calculate or evaluate a mathematical expression and return the numerical result. Use for arithmetic (+, -, *, /, ^), algebra, and math functions. Supports variables (x, y), constants (e, pi), and functions: sin, cos, tan, sqrt, log, log10, log2, factorial, arrangements, permutations, combinations, asin, acos, atan, sinh, cosh, tanh, exp, abs, pow, det, bin, round, ceil, floor, trunc, max, min, sum, gcd, lcm, is_prime, prime_factors, mod_pow, mod_inverse, next_prime, prev_prime, totient, divisors, mean, median, mode, stdev, variance, random, randint. Also supports sequence/series helpers such as summation('n^2','n',0,50), product('n+1','n',0,10), arithmetic_sum(1,2,20), geometric_sum(3,0.5,15), ratio_test('1/factorial(n)','n'), root_test('(1/2)^n','n'), and p_series_test(2).", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression to be evaluated. Example: '5*x - 1' or 'sin(x)'", + }, + "variables": { + "type": ["object", "null"], + "description": "Dictionary containing key-value pairs of the variables and values to be substituted in the expression. Example: {'x': 2, 'y': 3}", + }, + }, + "required": ["expression", "variables"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "evaluate_linear_algebra_expression", + "description": "Evaluates matrix/vector/scalar expressions (linear algebra). Use for determinant, inverse, transpose, matrix multiplication, and eigen computations with named objects.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "objects": { + "type": "array", + "minItems": 1, + "description": "List of named linear algebra objects available to the expression evaluator.", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Identifier used to reference the object inside the expression. Must start with a letter or underscore.", + }, + "value": { + "description": "Scalar, vector, or matrix definition for the object.", + "anyOf": [ + {"type": "number"}, + {"type": "array", "minItems": 1, "items": {"type": "number"}}, + { + "type": "array", + "minItems": 1, + "items": {"type": "array", "minItems": 1, "items": {"type": "number"}}, }, - "y": { - "type": "number", - "description": "Y coordinate of the vertex." - } - }, - "required": ["x", "y"], - "additionalProperties": False - } - } - }, - "required": ["polygon_type", "name", "vertices"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_polygon", - "description": "Updates editable properties of an existing polygon (currently just color). Provide null for fields that should remain unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "polygon_name": { - "type": "string", - "description": "Existing name of the polygon to edit." - }, - "polygon_type": { - "type": ["string", "null"], - "description": "Optional polygon classification to disambiguate the lookup." + ], + }, }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the polygon edges." - } - }, - "required": ["polygon_name", "polygon_type", "new_color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_circle", - "description": "Creates and draws a circle with the specified center coordinates and radius. If a name is provided, it will be used to reference the circle.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "center_x": { + "required": ["name", "value"], + "additionalProperties": False, + }, + }, + "expression": { + "type": "string", + "description": "Math.js compatible expression composed of the provided object names and supported linear algebra functions. Example: 'A + B' or 'inv(A) * b'.", + }, + }, + "required": ["objects", "expression"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "convert", + "description": "Converts a numeric value between units (e.g., degrees to radians, km to m, F to C). For Cartesian/polar coordinate conversion, use convert_coordinates instead.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "value": {"type": "number", "description": "The value to be converted"}, + "from_unit": {"type": "string", "description": "The unit to convert from"}, + "to_unit": {"type": "string", "description": "The unit to convert to"}, + }, + "required": ["value", "from_unit", "to_unit"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "limit", + "description": "Computes the limit of a function as it approaches a value", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression represented as a string. Example: 'log(x)^2'.", + }, + "variable": { + "type": "string", + "description": "The variable with respect to which the limit is computed.", + }, + "value_to_approach": {"type": "string", "description": "The value the variable approaches."}, + }, + "required": ["expression", "variable", "value_to_approach"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "derive", + "description": "Computes the derivative of a function with respect to a variable", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression represented as a string. Example: '2*x + 3'.", + }, + "variable": { + "type": "string", + "description": "The variable with respect to which the derivative is computed.", + }, + }, + "required": ["expression", "variable"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "integrate", + "description": "Computes the integral of a function with respect to a variable. Specify the lower and upper bounds only for definite integrals.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression represented as a string. Example: '2*x + 3'", + }, + "variable": { + "type": "string", + "description": "The variable with respect to which the integral is computed. Example: 'x'", + }, + "lower_bound": {"type": ["number", "null"], "description": "The lower bound of the integral."}, + "upper_bound": {"type": ["number", "null"], "description": "The upper bound of the integral."}, + }, + "required": ["expression", "variable", "lower_bound", "upper_bound"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "numeric_integrate", + "description": "Numerically approximate a definite integral over finite bounds using trapezoid, midpoint, or Simpson's rule.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Integrand expression such as 'sin(x)' or 'x^2 + 1'.", + }, + "variable": {"type": "string", "description": "Integration variable, typically 'x'."}, + "lower_bound": {"type": "number", "description": "Lower finite bound of integration."}, + "upper_bound": {"type": "number", "description": "Upper finite bound of integration."}, + "method": { + "type": "string", + "enum": ["trapezoid", "midpoint", "simpson"], + "description": "Numeric integration method. Optional; defaults to 'simpson' when omitted.", + }, + "steps": { + "type": "integer", + "description": "Number of subintervals. Must be a positive integer and <= 10000. Optional; defaults to 200 when omitted.", + }, + }, + "required": ["expression", "variable", "lower_bound", "upper_bound"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "simplify", + "description": "Simplifies a mathematical expression.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression represented as a string. Example: 'x^2 + 2*x + 1'", + } + }, + "required": ["expression"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "expand", + "description": "Expands a mathematical expression.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression represented as a string. Example: '(x+1)^2'", + } + }, + "required": ["expression"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "factor", + "description": "Factors a mathematical expression.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression represented as a string. Example: 'x^2 - 1'", + } + }, + "required": ["expression"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "solve", + "description": "Solves a mathematical equation for a given variable.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "equation": { + "type": "string", + "description": "The mathematical equation represented as a string. Example: 'x^2 - 1'", + }, + "variable": {"type": "string", "description": "The variable to solve for. Example: 'x'"}, + }, + "required": ["equation", "variable"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "solve_system_of_equations", + "description": "Solves a system of mathematical equations.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "equations": { + "type": "array", + "description": "An array of mathematical equations represented as strings. Example: ['2*x/3 = y', 'x-2 = y']", + "items": {"type": "string"}, + } + }, + "required": ["equations"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "solve_numeric", + "description": "Numerically solves a system of equations using multi-start Newton-Raphson. Use for transcendental, mixed nonlinear, or systems that can't be solved symbolically (e.g., sin(x) + y = 1, x^2 + y^2 = 4). Supports any number of variables. Returns multiple solutions when they exist.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "equations": { + "type": "array", + "description": "Array of equation strings. Use '=' for equations (e.g., ['sin(x) + y = 1', 'x^2 + y^2 = 4']). If no '=' is present, the expression is assumed equal to 0. Variables are auto-detected.", + "items": {"type": "string"}, + } + }, + "required": ["equations"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "translate_object", + "description": "Moves/shifts/translates an existing drawable object or function by x and y offsets (dx, dy).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The exact name of the object to translate taken from the canvas state", + }, + "x_offset": { + "type": "number", + "description": "The horizontal translation distance (positive moves right, negative moves left)", + }, + "y_offset": { + "type": "number", + "description": "The vertical translation distance (positive moves up, negative moves down)", + }, + }, + "required": ["name", "x_offset", "y_offset"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "rotate_object", + "description": "Rotates a drawable object by the specified angle. By default rotates around the object's own center. When center_x and center_y are provided, rotates around that arbitrary point (works for all types including points and circles).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the object to rotate"}, + "angle": { + "type": "number", + "description": "The angle in degrees to rotate the object (positive for counterclockwise)", + }, + "center_x": { + "type": ["number", "null"], + "description": "X-coordinate of the rotation center. Must be provided together with center_y for rotation around an arbitrary point. Omit (null) to rotate around the object's own center.", + }, + "center_y": { + "type": ["number", "null"], + "description": "Y-coordinate of the rotation center. Must be provided together with center_x for rotation around an arbitrary point. Omit (null) to rotate around the object's own center.", + }, + }, + "required": ["name", "angle", "center_x", "center_y"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "reflect_object", + "description": "Reflects (mirrors) a drawable object across an axis or line. Supports x-axis, y-axis, an arbitrary line (ax + by + c = 0), or a named segment as the reflection axis.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the object to reflect"}, + "axis": { + "type": "string", + "enum": ["x_axis", "y_axis", "line", "segment"], + "description": "The reflection axis type", + }, + "line_a": { + "type": ["number", "null"], + "description": "Coefficient a in ax + by + c = 0 (required when axis is 'line')", + }, + "line_b": { + "type": ["number", "null"], + "description": "Coefficient b in ax + by + c = 0 (required when axis is 'line')", + }, + "line_c": { + "type": ["number", "null"], + "description": "Coefficient c in ax + by + c = 0 (required when axis is 'line')", + }, + "segment_name": { + "type": ["string", "null"], + "description": "Name of a segment to use as the reflection axis (required when axis is 'segment')", + }, + }, + "required": ["name", "axis", "line_a", "line_b", "line_c", "segment_name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "scale_object", + "description": "Scales (dilates) a drawable object by the specified factors from a center point. Use equal sx and sy for uniform scaling. Circles require uniform scaling (equal sx and sy); for non-uniform scaling, convert to an ellipse first.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the object to scale"}, + "sx": { + "type": "number", + "description": "Horizontal scale factor (e.g. 2 to double width, 0.5 to halve)", + }, + "sy": { + "type": "number", + "description": "Vertical scale factor (e.g. 2 to double height, 0.5 to halve)", + }, + "cx": {"type": "number", "description": "X-coordinate of the scaling center"}, + "cy": {"type": "number", "description": "Y-coordinate of the scaling center"}, + }, + "required": ["name", "sx", "sy", "cx", "cy"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "shear_object", + "description": "Shears a drawable object along the specified axis from a center point. Not supported for circles and ellipses.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the object to shear"}, + "axis": { + "type": "string", + "enum": ["horizontal", "vertical"], + "description": "The shear direction", + }, + "factor": { + "type": "number", + "description": "The shear factor (e.g. 0.5 shifts x by 0.5*dy for horizontal shear)", + }, + "cx": {"type": "number", "description": "X-coordinate of the shear center"}, + "cy": {"type": "number", "description": "Y-coordinate of the shear center"}, + }, + "required": ["name", "axis", "factor", "cx", "cy"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "save_workspace", + "description": "Saves the current workspace state to a file. If no name is provided, saves to the current workspace file with timestamp. The workspace name MUST only contain alphanumeric characters, underscores, or hyphens (no spaces, dots, slashes, or other special characters).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": ["string", "null"], + "description": "Optional name for the workspace. Must contain only alphanumeric characters, underscores, or hyphens (e.g., 'my_workspace', 'workspace-1', 'test123'). If not provided, saves to current workspace.", + } + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "load_workspace", + "description": "Loads a workspace from a file. If no name is provided, loads the (most recent) current workspace. The workspace name MUST only contain alphanumeric characters, underscores, or hyphens (no spaces, dots, slashes, or other special characters).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": ["string", "null"], + "description": "Optional name of the workspace to load. Must contain only alphanumeric characters, underscores, or hyphens (e.g., 'my_workspace', 'workspace-1', 'test123'). If not provided, loads current workspace.", + } + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "list_workspaces", + "description": "Lists all saved workspaces. Only shows workspaces with valid names (containing only alphanumeric characters, underscores, or hyphens).", + "strict": True, + "parameters": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "delete_workspace", + "description": "Delete a workspace by name. The workspace name MUST only contain alphanumeric characters, underscores, or hyphens (no spaces, dots, slashes, or other special characters).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the workspace to delete. Must contain only alphanumeric characters, underscores, or hyphens (e.g., 'my_workspace', 'workspace-1', 'test123').", + } + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_colored_area", + "description": "Creates a colored area between two drawables (functions, segments, or a function and a segment). If only one drawable is provided, the area will be between that drawable and the x-axis.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "drawable1_name": { + "type": "string", + "description": "Name of the first drawable (function or segment). Use 'x_axis' for the x-axis.", + }, + "drawable2_name": { + "type": ["string", "null"], + "description": "Optional name of the second drawable (function or segment). Use 'x_axis' for the x-axis. If not provided, area will be between drawable1 and x-axis.", + }, + "left_bound": { + "type": ["number", "null"], + "description": "Optional left bound for function areas. Only used when at least one drawable is a function.", + }, + "right_bound": { + "type": ["number", "null"], + "description": "Optional right bound for function areas. Only used when at least one drawable is a function.", + }, + "color": { + "type": ["string", "null"], + "description": "Optional color for the area. Default is 'lightblue'.", + }, + "opacity": { + "type": ["number", "null"], + "description": "Optional opacity for the area between 0 and 1. Default is 0.3.", + }, + }, + "required": ["drawable1_name", "drawable2_name", "left_bound", "right_bound", "color", "opacity"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "create_region_colored_area", + "description": "Fill a region defined by a boolean expression or a closed shape. Supports expressions with operators (& | - ^), arcs, circles, ellipses, polygons, and segments. Expression takes precedence if provided.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": ["string", "null"], + "description": "Boolean region expression using shape names and operators. Examples: 'ArcMaj_AB & CD' (arc intersected with segment), 'circle_A - triangle_ABC' (difference). Takes precedence over other parameters.", + }, + "triangle_name": { + "type": ["string", "null"], + "description": "Name of an existing triangle to fill.", + }, + "rectangle_name": { + "type": ["string", "null"], + "description": "Name of an existing rectangle to fill.", + }, + "polygon_segment_names": { + "type": ["array", "null"], + "items": {"type": "string"}, + "description": "List of segment names that form a closed polygon loop (at least three segments).", + }, + "circle_name": { + "type": ["string", "null"], + "description": "Name of the circle to fill or to use with a chord segment.", + }, + "ellipse_name": { + "type": ["string", "null"], + "description": "Name of the ellipse to fill or to use with a chord segment.", + }, + "chord_segment_name": { + "type": ["string", "null"], + "description": "Segment name that serves as the chord/clip when creating a circle or ellipse segment region.", + }, + "arc_clockwise": { + "type": ["boolean", "null"], + "description": "Set to true to trace the arc clockwise when using a round shape with a chord segment. Default is false (counter-clockwise).", + }, + "resolution": { + "type": ["number", "null"], + "description": "Number of samples used to approximate curved boundaries. Defaults to 96.", + }, + "color": { + "type": ["string", "null"], + "description": "Optional color for the filled area. Default is 'lightblue'.", + }, + "opacity": { + "type": ["number", "null"], + "description": "Optional opacity between 0 and 1. Default is 0.3.", + }, + }, + "required": [ + "expression", + "triangle_name", + "rectangle_name", + "polygon_segment_names", + "circle_name", + "ellipse_name", + "chord_segment_name", + "arc_clockwise", + "resolution", + "color", + "opacity", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_colored_area", + "description": "Deletes a colored area by its name", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string", "description": "Name of the colored area to delete"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_colored_area", + "description": "Updates editable properties of an existing colored area (color, opacity, and for function-bounded areas, optional left/right bounds). Provide null for fields that should remain unchanged.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Existing name of the colored area to edit."}, + "new_color": {"type": ["string", "null"], "description": "Optional new color for the area."}, + "new_opacity": {"type": ["number", "null"], "description": "Optional new opacity between 0 and 1."}, + "new_left_bound": { + "type": ["number", "null"], + "description": "Optional new left bound (functions-bounded areas only).", + }, + "new_right_bound": { + "type": ["number", "null"], + "description": "Optional new right bound (functions-bounded areas only).", + }, + }, + "required": ["name", "new_color", "new_opacity", "new_left_bound", "new_right_bound"], + "additionalProperties": False, + }, + }, + }, + # START GRAPH FUNCTIONS + { + "type": "function", + "function": { + "name": "generate_graph", + "description": "Generates a graph or tree on the canvas using provided vertices/edges or an adjacency matrix. Returns the created graph state and drawable names for follow-up highlighting.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": ["string", "null"]}, + "graph_type": { + "type": "string", + "enum": ["graph", "tree", "dag"], + "description": "Type of graph to create.", + }, + "directed": {"type": ["boolean", "null"]}, + "root": {"type": ["string", "null"], "description": "Root id for trees."}, + "layout": { + "type": ["string", "null"], + "description": "Layout hint: 'tree' or 'hierarchical' for top-down tree display (default for trees), 'radial' for concentric rings from root, 'circular' for nodes on a circle, 'grid' for rectangular grid, 'force' for force-directed.", + }, + "placement_box": { + "type": ["object", "null"], + "description": "Bounding box for vertex placement. Defined from bottom-left corner in math coordinates (y increases upward). Box spans from (x, y) to (x + width, y + height).", + "properties": { + "x": {"type": "number", "description": "Left edge X coordinate (bottom-left corner)"}, + "y": { "type": "number", - "description": "The X coordinate of the circle's center" + "description": "Bottom edge Y coordinate (bottom-left corner, in math coords where y increases upward)", }, - "center_y": { + "width": { "type": "number", - "description": "The Y coordinate of the circle's center" + "description": "Box width extending rightward (positive X direction)", }, - "radius": { + "height": { "type": "number", - "description": "The radius of the circle" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color to assign to the circle" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the circle" - } - }, - "required": ["center_x", "center_y", "radius", "color", "name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_circle", - "description": "Deletes the circle with the given name", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the circle" - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_circle", - "description": "Updates editable properties of an existing circle (color or center position). Provide null for fields to keep them unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the circle to edit." - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the circle." - }, - "new_center_x": { - "type": ["number", "null"], - "description": "Optional new x-coordinate for the circle center (requires y value when provided)." - }, - "new_center_y": { - "type": ["number", "null"], - "description": "Optional new y-coordinate for the circle center (requires x value when provided)." - } - }, - "required": ["name", "new_color", "new_center_x", "new_center_y"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_circle_arc", - "description": "Creates an arc on a circle. Use this for requests like 'draw an arc with center (x,y), radius r' or 'arc on circle C between two points'. Supports standalone center/radius arcs, arcs on an existing circle, or deriving from three points with center_point_choice.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "point1_x": {"type": "number", "description": "Reference X coordinate for the first arc point (snapped to the circle when center/radius are provided)"}, - "point1_y": {"type": "number", "description": "Reference Y coordinate for the first arc point (snapped to the circle when center/radius are provided)"}, - "point2_x": {"type": "number", "description": "Reference X coordinate for the second arc point (snapped to the circle when center/radius are provided)"}, - "point2_y": {"type": "number", "description": "Reference Y coordinate for the second arc point (snapped to the circle when center/radius are provided)"}, - "point1_name": {"type": ["string", "null"], "description": "Optional name for the first arc point"}, - "point2_name": {"type": ["string", "null"], "description": "Optional name for the second arc point"}, - "point3_x": {"type": ["number", "null"], "description": "Optional reference X coordinate for a third point when deriving the circle from three points"}, - "point3_y": {"type": ["number", "null"], "description": "Optional reference Y coordinate for a third point when deriving the circle from three points"}, - "point3_name": {"type": ["string", "null"], "description": "Optional name for the third point (used when deriving the circle from three points)"}, - "center_point_choice": {"type": ["string", "null"], "description": "Optional selector ('point1', 'point2', or 'point3') indicating which provided point should be treated as the circle center"}, - "circle_name": {"type": ["string", "null"], "description": "Existing circle to attach the arc to"}, - "center_x": {"type": ["number", "null"], "description": "Circle center x-coordinate when defining a standalone arc"}, - "center_y": {"type": ["number", "null"], "description": "Circle center y-coordinate when defining a standalone arc"}, - "radius": {"type": ["number", "null"], "description": "Circle radius when defining a standalone arc"}, - "use_major_arc": {"type": "boolean", "description": "True to draw the major arc, False for the minor arc"}, - "arc_name": {"type": ["string", "null"], "description": "Optional custom arc name"}, - "color": {"type": ["string", "null"], "description": f"Optional CSS color for the arc (defaults to {DEFAULT_CIRCLE_ARC_COLOR})"} - }, - "required": [ - "point1_x", - "point1_y", - "point2_x", - "point2_y", - "point1_name", - "point2_name", - "point3_x", - "point3_y", - "point3_name", - "center_point_choice", - "circle_name", - "center_x", - "center_y", - "radius", - "use_major_arc", - "arc_name", - "color", + "description": "Box height extending upward (positive Y direction)", + }, + }, + "required": ["x", "y", "width", "height"], + "additionalProperties": False, + }, + "vertices": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": ["string", "null"]}, + "x": {"type": ["number", "null"]}, + "y": {"type": ["number", "null"]}, + "color": {"type": ["string", "null"]}, + "label": {"type": ["string", "null"]}, + }, + "required": ["name", "x", "y", "color", "label"], + "additionalProperties": False, + }, + "description": "List of vertex descriptors. Vertex id is implied by array index starting at 0.", + }, + "edges": { + "type": "array", + "items": { + "type": "object", + "properties": { + "source": { + "type": "number", + "description": "Source vertex index (0-based, matches vertices array order)", + }, + "target": { + "type": "number", + "description": "Target vertex index (0-based, matches vertices array order)", + }, + "weight": {"type": ["number", "null"]}, + "name": {"type": ["string", "null"]}, + "color": {"type": ["string", "null"]}, + "directed": {"type": ["boolean", "null"]}, + }, + "required": ["source", "target", "weight", "name", "color", "directed"], + "additionalProperties": False, + }, + "description": "List of edge descriptors.", + }, + "adjacency_matrix": { + "type": ["array", "null"], + "items": {"type": "array", "items": {"type": "number"}}, + "description": "Optional adjacency matrix (weights allowed). Rows/columns follow the order of the provided vertices array (0-based).", + }, + }, + "required": [ + "name", + "graph_type", + "directed", + "root", + "layout", + "placement_box", + "vertices", + "edges", + "adjacency_matrix", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_graph", + "description": "Deletes a graph or tree and its associated drawables by name.", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "analyze_graph", + "description": "Analyzes an existing graph/tree for connectivity and structural queries (connectedness, shortest path, BFS/DFS, bipartite, bridges, articulation points, diameter, etc.). Use generate_graph first if the graph does not exist yet.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "graph_name": { + "type": "string", + "description": "Existing graph name to analyze (must exist on canvas).", + }, + "operation": { + "type": "string", + "enum": [ + "shortest_path", + "mst", + "topological_sort", + "bridges", + "articulation_points", + "euler_status", + "bipartite", + "bfs", + "dfs", + "levels", + "diameter", + "lca", + "balance_children", + "invert_children", + "reroot", + "convex_hull", + "point_in_hull", ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_circle_arc", - "description": "Deletes a circle arc by name.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "Name of the circle arc to delete"} - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_circle_arc", - "description": "Updates editable properties of an existing circle arc (color or major/minor toggle). Provide null for fields to keep them unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "Existing name of the arc to edit"}, - "new_color": {"type": ["string", "null"], "description": "Optional new color for the arc"}, - "use_major_arc": {"type": ["boolean", "null"], "description": "Set to true for the major arc, false for the minor arc"} - }, - "required": [ - "name", - "new_color", - "use_major_arc" + }, + "params": { + "type": ["object", "null"], + "description": "Operation-specific parameters (start, goal, root, a, b, new_root, x, y for point_in_hull, etc.).", + "properties": { + "start": { + "type": ["string", "null"], + "description": "Start vertex for shortest_path, bfs, dfs.", + }, + "goal": {"type": ["string", "null"], "description": "Goal vertex for shortest_path."}, + "root": {"type": ["string", "null"], "description": "Root vertex for tree operations."}, + "a": {"type": ["string", "null"], "description": "First vertex for LCA."}, + "b": {"type": ["string", "null"], "description": "Second vertex for LCA."}, + "new_root": { + "type": ["string", "null"], + "description": "New root vertex for reroot operation.", + }, + "x": {"type": ["number", "null"], "description": "X coordinate for point_in_hull."}, + "y": {"type": ["number", "null"], "description": "Y coordinate for point_in_hull."}, + }, + "required": ["start", "goal", "root", "a", "b", "new_root", "x", "y"], + "additionalProperties": False, + }, + }, + "required": ["graph_name", "operation", "params"], + "additionalProperties": False, + }, + }, + }, + # END GRAPH FUNCTIONS + # START RELATION INSPECTION + { + "type": "function", + "function": { + "name": "inspect_relation", + "description": "Check and explain geometric relations between objects on the canvas. Supported: parallel, perpendicular, collinear, concyclic, equal_length, similar, congruent, tangent, concurrent, point_on_line, point_on_circle. Use 'auto' to check all applicable relations.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": [ + "parallel", + "perpendicular", + "collinear", + "concyclic", + "equal_length", + "similar", + "congruent", + "tangent", + "concurrent", + "point_on_line", + "point_on_circle", + "auto", ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_ellipse", - "description": "Creates an ellipse with the specified center point, x-radius, y-radius, and optional rotation angle", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "center_x": { - "type": "number", - "description": "The x-coordinate of the ellipse's center" - }, - "center_y": { - "type": "number", - "description": "The y-coordinate of the ellipse's center" - }, - "radius_x": { - "type": "number", - "description": "The radius of the ellipse in the x-direction (half the width)" - }, - "radius_y": { - "type": "number", - "description": "The radius of the ellipse in the y-direction (half the height)" - }, - "rotation_angle": { - "type": ["number", "null"], - "description": "Optional angle in degrees to rotate the ellipse around its center (default: 0)" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the ellipse" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the ellipse" - } - }, - "required": ["center_x", "center_y", "radius_x", "radius_y", "rotation_angle", "color", "name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_ellipse", - "description": "Deletes the ellipse with the given name", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the ellipse" - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_ellipse", - "description": "Updates editable properties of an existing ellipse (color, radii, rotation, or center). Provide null for fields that should remain unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the ellipse to edit." - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the ellipse." - }, - "new_radius_x": { - "type": ["number", "null"], - "description": "Optional new horizontal radius (requires ellipse to be solitary)." - }, - "new_radius_y": { - "type": ["number", "null"], - "description": "Optional new vertical radius (requires ellipse to be solitary)." - }, - "new_rotation_angle": { - "type": ["number", "null"], - "description": "Optional new rotation angle in degrees." - }, - "new_center_x": { - "type": ["number", "null"], - "description": "Optional new x-coordinate for the center (requires y value when provided)." - }, - "new_center_y": { + }, + "objects": { + "type": "array", + "items": {"type": "string"}, + "description": "Names of objects to check, e.g. ['s1', 's2']", + }, + "object_types": { + "type": "array", + "items": { + "type": "string", + "enum": ["point", "segment", "vector", "circle", "ellipse", "triangle", "rectangle"], + }, + "description": "Type of each object in same order as objects", + }, + }, + "required": ["operation", "objects", "object_types"], + "additionalProperties": False, + }, + }, + }, + # END RELATION INSPECTION + # START PLOT FUNCTIONS + { + "type": "function", + "function": { + "name": "plot_distribution", + "description": "Plots a probability distribution on the canvas. Choose representation 'continuous' for a function curve or 'discrete' for bar rectangles. For continuous plots, you can optionally draw the curve over plot_bounds while shading only over shade_bounds (clamped into plot_bounds). Creates a tracked plot composite for reliable deletion.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": ["string", "null"], + "description": "Optional plot name. If null, a name will be generated.", + }, + "representation": { + "type": "string", + "enum": ["continuous", "discrete"], + "description": "Plot representation. 'continuous' draws a smooth curve. 'discrete' draws bars (rectangles).", + }, + "distribution_type": { + "type": "string", + "enum": ["normal"], + "description": "Distribution to plot. v1 supports only 'normal' (Gaussian).", + }, + "distribution_params": { + "type": ["object", "null"], + "description": "Parameters for the selected distribution type. For 'normal', provide mean and sigma.", + "properties": { + "mean": { "type": ["number", "null"], - "description": "Optional new y-coordinate for the center (requires x value when provided)." - } - }, - "required": ["name", "new_color", "new_radius_x", "new_radius_y", "new_rotation_angle", "new_center_x", "new_center_y"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_label", - "description": ( - "Creates a text label anchored at a math-space coordinate " - f"(max {LABEL_TEXT_MAX_LENGTH} chars, wraps every {LABEL_LINE_WRAP_THRESHOLD} chars)" - ), - "strict": True, - "parameters": { - "type": "object", - "properties": { - "x": { - "type": "number", - "description": "Math-space X coordinate for the label anchor" - }, - "y": { - "type": "number", - "description": "Math-space Y coordinate for the label anchor" + "description": "Mean (mu) for the normal distribution. Defaults to 0 if null.", }, - "text": { - "type": "string", - "description": "Label text content; lines wrap after 40 characters", - "maxLength": 160 - }, - "name": { - "type": ["string", "null"], - "description": "Optional label name used for later updates" - }, - "color": { - "type": ["string", "null"], - "description": "Optional CSS color for the label text" - }, - "font_size": { + "sigma": { "type": ["number", "null"], - "description": "Optional font size in pixels" + "description": "Standard deviation (sigma) for the normal distribution. Defaults to 1 if null. Must be > 0.", }, - "rotation_degrees": { - "type": ["number", "null"], - "description": "Optional angle in degrees to rotate the label text" - } - }, - "required": [ - "x", - "y", - "text", - "name", - "color", - "font_size", - "rotation_degrees", - ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_label", - "description": "Deletes an existing label by name", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Name of the label to delete" - } }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_label", - "description": "Updates editable properties of an existing label (text, color, position, font size, rotation). Provide null for fields that should remain unchanged.", - "strict": True, - "parameters": { - "type": "object", + "required": ["mean", "sigma"], + "additionalProperties": False, + }, + "plot_bounds": { + "type": ["object", "null"], + "description": "Optional bounds for plotting the curve. If null, or either side is null, defaults to mean +/- 4*sigma.", "properties": { - "name": { - "type": "string", - "description": "Existing name of the label to edit" - }, - "new_text": { - "type": ["string", "null"], - "description": "Optional replacement text for the label" - }, - "new_x": { - "type": ["number", "null"], - "description": "Optional new x-coordinate (requires new_y)" - }, - "new_y": { + "left_bound": { "type": ["number", "null"], - "description": "Optional new y-coordinate (requires new_x)" - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new text color" + "description": "Optional left bound for plotting the curve. Defaults to mean - 4*sigma when null.", }, - "new_font_size": { + "right_bound": { "type": ["number", "null"], - "description": "Optional new font size in math-space units" + "description": "Optional right bound for plotting the curve. Defaults to mean + 4*sigma when null.", }, - "new_rotation_degrees": { - "type": ["number", "null"], - "description": "Optional rotation angle in degrees" - } }, - "required": [ - "name", - "new_text", - "new_x", - "new_y", - "new_color", - "new_font_size", - "new_rotation_degrees" - ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "draw_function", - "description": "Plots the given mathematical function on the canvas between the specified left and right bounds.", - "strict": True, - "parameters": { - "type": "object", + "required": ["left_bound", "right_bound"], + "additionalProperties": False, + }, + "shade_bounds": { + "type": ["object", "null"], + "description": "Continuous only. Optional bounds for shading under the curve. If null, defaults to plot_bounds. Bounds are clamped into plot_bounds.", "properties": { - "function_string": { - "type": "string", - "description": "The mathematical expression represented as a string, e.g., '2*x + 3'." - }, - "name": { - "type": ["string", "null"], - "description": "The name or label for the plotted function. Useful for referencing later." - }, "left_bound": { "type": ["number", "null"], - "description": "The left bound of the interval on which to plot the function." + "description": "Optional left bound for shading under the curve. If null, defaults to plot_bounds.left_bound.", }, "right_bound": { "type": ["number", "null"], - "description": "The right bound of the interval on which to plot the function." - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the plotted function." - }, - "undefined_at": { - "type": ["array", "null"], - "description": "Optional list of x-values where the function is explicitly undefined (holes). E.g., [0, 2] means the function has holes at x=0 and x=2.", - "items": { - "type": "number" - } - } - }, - "required": ["function_string", "name", "left_bound", "right_bound", "color", "undefined_at"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_function", - "description": "Removes the plotted mathematical function with the given name from the canvas.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name or label of the function to be deleted." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_function", - "description": "Updates editable properties of an existing plotted function (color and/or bounds). Provide null for fields to leave them unchanged.", - "strict": True, - "parameters": { - "type": "object", + "description": "Optional right bound for shading under the curve. If null, defaults to plot_bounds.right_bound.", + }, + }, + "required": ["left_bound", "right_bound"], + "additionalProperties": False, + }, + "curve_color": {"type": ["string", "null"], "description": "Optional color for the plotted curve."}, + "fill_color": { + "type": ["string", "null"], + "description": "Optional fill color for the area under the curve. Defaults to the standard area fill color.", + }, + "fill_opacity": { + "type": ["number", "null"], + "description": "Optional fill opacity (0 to 1). Defaults to the standard area opacity.", + }, + "bar_count": { + "type": ["number", "null"], + "description": "Discrete only. Number of bars to draw across the bounds. If null, a default is used.", + }, + }, + "required": [ + "name", + "representation", + "distribution_type", + "distribution_params", + "plot_bounds", + "shade_bounds", + "curve_color", + "fill_color", + "fill_opacity", + "bar_count", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "plot_bars", + "description": "Plots a bar chart from tabular data (values with labels). Creates a tracked plot composite for reliable deletion.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": ["string", "null"], + "description": "Optional plot name. If null, a name will be generated.", + }, + "values": { + "type": "array", + "items": {"type": "number"}, + "description": "Bar heights (math-space units). Must have at least one entry.", + }, + "labels_below": { + "type": "array", + "items": {"type": "string"}, + "description": "Label under each bar. Must have one label per value.", + }, + "labels_above": { + "type": ["array", "null"], + "items": {"type": "string"}, + "description": "Optional label above each bar (for example, formatted values). If provided, must have one label per value.", + }, + "bar_spacing": { + "type": ["number", "null"], + "description": "Optional spacing between bars in math-space units. Defaults to 0.2.", + }, + "bar_width": { + "type": ["number", "null"], + "description": "Optional bar width in math-space units. Defaults to 1.0.", + }, + "stroke_color": {"type": ["string", "null"], "description": "Optional stroke color for each bar."}, + "fill_color": {"type": ["string", "null"], "description": "Optional fill color for each bar."}, + "fill_opacity": {"type": ["number", "null"], "description": "Optional fill opacity (0 to 1)."}, + "x_start": { + "type": ["number", "null"], + "description": "Optional left x coordinate for the first bar. Defaults to 0.", + }, + "y_base": { + "type": ["number", "null"], + "description": "Optional baseline y coordinate for bars. Defaults to 0.", + }, + }, + "required": [ + "name", + "values", + "labels_below", + "labels_above", + "bar_spacing", + "bar_width", + "stroke_color", + "fill_color", + "fill_opacity", + "x_start", + "y_base", + ], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_plot", + "description": "Deletes a previously created plot composite by name, including any underlying components (curve and filled area, or derived bars).", + "strict": True, + "parameters": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "fit_regression", + "description": "Fits a regression model to data points and plots the resulting curve. Supported model types: linear (y = mx + b), polynomial (y = a0 + a1*x + ... + an*x^n), exponential (y = a*e^(bx)), logarithmic (y = a + b*ln(x)), power (y = a*x^b), logistic (y = L/(1+e^(-k(x-x0)))), and sinusoidal (y = a*sin(bx+c)+d). Returns the function_name, fitted expression, coefficients, R-squared, and point_names. Use delete_function to remove the curve; delete points individually.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": { + "type": ["string", "null"], + "description": "Optional base name for the function and data points. If null, a name will be generated based on model type.", + }, + "x_data": { + "type": "array", + "items": {"type": "number"}, + "description": "Array of x values (independent variable). Must have at least 2 points (more for polynomial).", + }, + "y_data": { + "type": "array", + "items": {"type": "number"}, + "description": "Array of y values (dependent variable). Must have same length as x_data.", + }, + "model_type": { + "type": "string", + "enum": [ + "linear", + "polynomial", + "exponential", + "logarithmic", + "power", + "logistic", + "sinusoidal", + ], + "description": "Type of regression model to fit. Note: exponential and power require positive y values; logarithmic and power require positive x values.", + }, + "degree": { + "type": ["integer", "null"], + "description": "Polynomial degree (required for polynomial model, ignored otherwise). Must be >= 1 and less than the number of data points.", + }, + "plot_bounds": { + "type": ["object", "null"], + "description": "Optional bounds for plotting the fitted curve. Defaults to data range with 10% padding.", "properties": { - "name": { - "type": "string", - "description": "Existing name of the function to edit." - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the function plot." - }, - "new_left_bound": { + "left_bound": { "type": ["number", "null"], - "description": "Optional new left plotting bound." + "description": "Left bound for plotting. Defaults to min(x_data) - 10% range.", }, - "new_right_bound": { + "right_bound": { "type": ["number", "null"], - "description": "Optional new right plotting bound." - } - }, - "required": ["name", "new_color", "new_left_bound", "new_right_bound"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "draw_piecewise_function", - "description": "Plots a piecewise-defined function with different expressions for different intervals. Each piece specifies an expression and its valid interval bounds. Use null for unbounded intervals (extending to infinity). Use undefined_at for explicit holes (points where the function is undefined).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "pieces": { - "type": "array", - "minItems": 1, - "description": "List of function pieces, each defining an expression and its interval.", - "items": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Mathematical expression for this piece (e.g., 'x^2', 'sin(x)')." - }, - "left": { - "type": ["number", "null"], - "description": "Left interval bound (null for negative infinity)." - }, - "right": { - "type": ["number", "null"], - "description": "Right interval bound (null for positive infinity)." - }, - "left_inclusive": { - "type": "boolean", - "description": "Whether the left bound is included in the interval." - }, - "right_inclusive": { - "type": "boolean", - "description": "Whether the right bound is included in the interval." - }, - "undefined_at": { - "type": ["array", "null"], - "description": "Optional list of x-values where this piece is explicitly undefined (holes).", - "items": { - "type": "number" - } - } - }, - "required": ["expression", "left", "right", "left_inclusive", "right_inclusive", "undefined_at"], - "additionalProperties": False - } - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the piecewise function." - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the plotted function." - } - }, - "required": ["pieces", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_piecewise_function", - "description": "Removes the plotted piecewise function with the given name from the canvas.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the piecewise function to delete." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_piecewise_function", - "description": "Updates editable properties of an existing piecewise function (currently just color). Provide null for fields to leave them unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the piecewise function to edit." - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the function plot." - } - }, - "required": ["name", "new_color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "draw_parametric_function", - "description": "Plots a parametric curve defined by x(t) and y(t) expressions. Use this for curves that cannot be expressed as y=f(x), such as circles, spirals, Lissajous figures, and other complex shapes. The parameter t ranges from t_min to t_max (default 0 to 2*pi for periodic curves).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "x_expression": { - "type": "string", - "description": "Mathematical expression for x as a function of t. Example: 'cos(t)' for a circle, 't*cos(t)' for a spiral." - }, - "y_expression": { - "type": "string", - "description": "Mathematical expression for y as a function of t. Example: 'sin(t)' for a circle, 't*sin(t)' for a spiral." - }, - "name": { - "type": ["string", "null"], - "description": "Optional name or label for the parametric curve. Useful for referencing later." - }, - "t_min": { - "type": ["number", "null"], - "description": "Minimum value of parameter t. Default is 0." - }, - "t_max": { - "type": ["number", "null"], - "description": "Maximum value of parameter t. Default is 2*pi (~6.283) for periodic curves." - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the plotted curve." - } - }, - "required": ["x_expression", "y_expression", "name", "t_min", "t_max", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_parametric_function", - "description": "Removes the plotted parametric function with the given name from the canvas.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the parametric function to delete." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_parametric_function", - "description": "Updates editable properties of an existing parametric function (color, t_min, t_max). Provide null for fields to leave them unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the parametric function to edit." - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the curve." - }, - "new_t_min": { - "type": ["number", "null"], - "description": "Optional new minimum value of parameter t." - }, - "new_t_max": { - "type": ["number", "null"], - "description": "Optional new maximum value of parameter t." - } - }, - "required": ["name", "new_color", "new_t_min", "new_t_max"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "draw_tangent_line", - "description": "Draws a tangent line segment to a curve at a specified point. For functions y=f(x), the parameter is the x-coordinate. For parametric curves, it's the t value. For circles and ellipses, it's the angle in radians from the positive x-axis.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "curve_name": { - "type": "string", - "description": "Name of the curve to draw tangent to (function, parametric function, circle, or ellipse)" - }, - "parameter": { - "type": "number", - "description": "Position on curve: x-coordinate for functions, t-value for parametric curves, or angle (radians) for circles/ellipses" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the tangent line segment" - }, - "length": { - "type": ["number", "null"], - "description": "Total length of the tangent segment in math units (default: 4.0)" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the tangent line (default: same as curve)" - } - }, - "required": ["curve_name", "parameter", "name", "length", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "draw_normal_line", - "description": "Draws a normal line segment (perpendicular to tangent) to a curve at a specified point. For functions y=f(x), the parameter is the x-coordinate. For parametric curves, it's the t value. For circles and ellipses, it's the angle in radians from the positive x-axis.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "curve_name": { - "type": "string", - "description": "Name of the curve to draw normal to (function, parametric function, circle, or ellipse)" - }, - "parameter": { - "type": "number", - "description": "Position on curve: x-coordinate for functions, t-value for parametric curves, or angle (radians) for circles/ellipses" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the normal line segment" - }, - "length": { - "type": ["number", "null"], - "description": "Total length of the normal segment in math units (default: 4.0)" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the normal line (default: same as curve)" - } - }, - "required": ["curve_name", "parameter", "name", "length", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_midpoint", - "description": "Constructs a point at the midpoint of a segment or between two named points. Provide either 'segment_name' or both 'p1_name' and 'p2_name'.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "p1_name": { - "type": ["string", "null"], - "description": "Name of the first point (use with p2_name)" - }, - "p2_name": { - "type": ["string", "null"], - "description": "Name of the second point (use with p1_name)" - }, - "segment_name": { - "type": ["string", "null"], - "description": "Name of the segment whose midpoint to find (alternative to p1_name/p2_name)" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the created midpoint" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the midpoint" - } - }, - "required": ["p1_name", "p2_name", "segment_name", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_perpendicular_bisector", - "description": "Constructs the perpendicular bisector of a segment. Creates a new segment that passes through the midpoint and is perpendicular to the original.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "segment_name": { - "type": "string", - "description": "Name of the segment to bisect perpendicularly" - }, - "length": { - "type": ["number", "null"], - "description": "Total length of the bisector segment in math units (default: 6.0)" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the created bisector segment" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the bisector segment" - } - }, - "required": ["segment_name", "length", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_perpendicular_from_point", - "description": "Drops a perpendicular from a point to a segment. Creates the foot point on the line and a segment from the original point to the foot.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "point_name": { - "type": "string", - "description": "Name of the point to project onto the segment" - }, - "segment_name": { - "type": "string", - "description": "Name of the target segment" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the perpendicular segment" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for created drawables" - } - }, - "required": ["point_name", "segment_name", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_angle_bisector", - "description": "Constructs a segment along the bisector of an angle. Provide either 'angle_name' for an existing angle, or 'vertex_name', 'p1_name', 'p2_name' to define the angle by three points.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "vertex_name": { - "type": ["string", "null"], - "description": "Name of the angle vertex point (use with p1_name and p2_name)" - }, - "p1_name": { - "type": ["string", "null"], - "description": "Name of the first arm endpoint (use with vertex_name and p2_name)" - }, - "p2_name": { - "type": ["string", "null"], - "description": "Name of the second arm endpoint (use with vertex_name and p1_name)" - }, - "angle_name": { - "type": ["string", "null"], - "description": "Name of an existing angle to bisect (alternative to vertex/p1/p2)" - }, - "length": { - "type": ["number", "null"], - "description": "Length of the bisector segment in math units (default: 6.0)" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the created bisector segment" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the bisector segment" - } - }, - "required": ["vertex_name", "p1_name", "p2_name", "angle_name", "length", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_parallel_line", - "description": "Constructs a segment through a point that is parallel to a given segment. The new segment is centered on the specified point.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "segment_name": { - "type": "string", - "description": "Name of the reference segment to be parallel to" - }, - "point_name": { - "type": "string", - "description": "Name of the point the parallel line passes through" - }, - "length": { - "type": ["number", "null"], - "description": "Total length of the parallel segment in math units (default: 6.0)" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the created parallel segment" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the parallel segment" - } - }, - "required": ["segment_name", "point_name", "length", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_circumcircle", - "description": "Constructs the circumscribed circle (circumcircle) of a triangle or three points. The circumcircle passes through all three vertices. Provide either triangle_name or all three point names.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "triangle_name": { - "type": ["string", "null"], - "description": "Name of an existing triangle (alternative to specifying three points)" - }, - "p1_name": { - "type": ["string", "null"], - "description": "Name of the first point (used with p2_name and p3_name instead of triangle_name)" - }, - "p2_name": { - "type": ["string", "null"], - "description": "Name of the second point" - }, - "p3_name": { - "type": ["string", "null"], - "description": "Name of the third point" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the created circumcircle" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the circumcircle" - } - }, - "required": ["triangle_name", "p1_name", "p2_name", "p3_name", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "construct_incircle", - "description": "Constructs the inscribed circle (incircle) of a triangle. The incircle is tangent to all three sides of the triangle.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "triangle_name": { - "type": "string", - "description": "Name of an existing triangle" - }, - "name": { - "type": ["string", "null"], - "description": "Optional name for the created incircle" - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the incircle" - } - }, - "required": ["triangle_name", "name", "color"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "evaluate_expression", - "description": "Calculate or evaluate a mathematical expression and return the numerical result. Use for arithmetic (+, -, *, /, ^), algebra, and math functions. Supports variables (x, y), constants (e, pi), and functions: sin, cos, tan, sqrt, log, log10, log2, factorial, arrangements, permutations, combinations, asin, acos, atan, sinh, cosh, tanh, exp, abs, pow, det, bin, round, ceil, floor, trunc, max, min, sum, gcd, lcm, is_prime, prime_factors, mod_pow, mod_inverse, next_prime, prev_prime, totient, divisors, mean, median, mode, stdev, variance, random, randint. Also supports sequence/series helpers such as summation('n^2','n',0,50), product('n+1','n',0,10), arithmetic_sum(1,2,20), geometric_sum(3,0.5,15), ratio_test('1/factorial(n)','n'), root_test('(1/2)^n','n'), and p_series_test(2).", - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression to be evaluated. Example: '5*x - 1' or 'sin(x)'" - }, - "variables": { - "type": ["object", "null"], - "description": "Dictionary containing key-value pairs of the variables and values to be substituted in the expression. Example: {'x': 2, 'y': 3}" - } - }, - "required": ["expression", "variables"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "evaluate_linear_algebra_expression", - "description": "Evaluates matrix/vector/scalar expressions (linear algebra). Use for determinant, inverse, transpose, matrix multiplication, and eigen computations with named objects.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "objects": { - "type": "array", - "minItems": 1, - "description": "List of named linear algebra objects available to the expression evaluator.", - "items": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Identifier used to reference the object inside the expression. Must start with a letter or underscore." - }, - "value": { - "description": "Scalar, vector, or matrix definition for the object.", - "anyOf": [ - { - "type": "number" - }, - { - "type": "array", - "minItems": 1, - "items": { - "type": "number" - } - }, - { - "type": "array", - "minItems": 1, - "items": { - "type": "array", - "minItems": 1, - "items": { - "type": "number" - } - } - } - ] - } - }, - "required": ["name", "value"], - "additionalProperties": False - } - }, - "expression": { - "type": "string", - "description": "Math.js compatible expression composed of the provided object names and supported linear algebra functions. Example: 'A + B' or 'inv(A) * b'." - } - }, - "required": ["objects", "expression"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "convert", - "description": "Converts a numeric value between units (e.g., degrees to radians, km to m, F to C). For Cartesian/polar coordinate conversion, use convert_coordinates instead.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "value": { - "type": "number", - "description": "The value to be converted" - }, - "from_unit": { - "type": "string", - "description": "The unit to convert from" - }, - "to_unit": { - "type": "string", - "description": "The unit to convert to" - } - }, - "required": ["value", "from_unit", "to_unit"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "limit", - "description": "Computes the limit of a function as it approaches a value", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression represented as a string. Example: 'log(x)^2'." - }, - "variable": { - "type": "string", - "description": "The variable with respect to which the limit is computed." - }, - "value_to_approach": { - "type": "string", - "description": "The value the variable approaches." - } - }, - "required": ["expression", "variable", "value_to_approach"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "derive", - "description": "Computes the derivative of a function with respect to a variable", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression represented as a string. Example: '2*x + 3'." - }, - "variable": { - "type": "string", - "description": "The variable with respect to which the derivative is computed." - } - }, - "required": ["expression", "variable"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "integrate", - "description": "Computes the integral of a function with respect to a variable. Specify the lower and upper bounds only for definite integrals.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression represented as a string. Example: '2*x + 3'" - }, - "variable": { - "type": "string", - "description": "The variable with respect to which the integral is computed. Example: 'x'" - }, - "lower_bound": { - "type": ["number", "null"], - "description": "The lower bound of the integral." - }, - "upper_bound": { - "type": ["number", "null"], - "description": "The upper bound of the integral." - } - }, - "required": ["expression", "variable", "lower_bound", "upper_bound"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "numeric_integrate", - "description": "Numerically approximate a definite integral over finite bounds using trapezoid, midpoint, or Simpson's rule.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Integrand expression such as 'sin(x)' or 'x^2 + 1'." - }, - "variable": { - "type": "string", - "description": "Integration variable, typically 'x'." - }, - "lower_bound": { - "type": "number", - "description": "Lower finite bound of integration." - }, - "upper_bound": { - "type": "number", - "description": "Upper finite bound of integration." - }, - "method": { - "type": "string", - "enum": ["trapezoid", "midpoint", "simpson"], - "description": "Numeric integration method. Optional; defaults to 'simpson' when omitted." - }, - "steps": { - "type": "integer", - "description": "Number of subintervals. Must be a positive integer and <= 10000. Optional; defaults to 200 when omitted." - } - }, - "required": [ - "expression", - "variable", - "lower_bound", - "upper_bound" - ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "simplify", - "description": "Simplifies a mathematical expression.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression represented as a string. Example: 'x^2 + 2*x + 1'" - } - }, - "required": ["expression"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "expand", - "description": "Expands a mathematical expression.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression represented as a string. Example: '(x+1)^2'" - } - }, - "required": ["expression"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "factor", - "description": "Factors a mathematical expression.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression represented as a string. Example: 'x^2 - 1'" - } - }, - "required": ["expression"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "solve", - "description": "Solves a mathematical equation for a given variable.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "equation": { - "type": "string", - "description": "The mathematical equation represented as a string. Example: 'x^2 - 1'" - }, - "variable": { - "type": "string", - "description": "The variable to solve for. Example: 'x'" - } - }, - "required": ["equation", "variable"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "solve_system_of_equations", - "description": "Solves a system of mathematical equations.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "equations": { - "type": "array", - "description": "An array of mathematical equations represented as strings. Example: ['2*x/3 = y', 'x-2 = y']", - "items": { - "type": "string" - } - } - }, - "required": ["equations"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "solve_numeric", - "description": "Numerically solves a system of equations using multi-start Newton-Raphson. Use for transcendental, mixed nonlinear, or systems that can't be solved symbolically (e.g., sin(x) + y = 1, x^2 + y^2 = 4). Supports any number of variables. Returns multiple solutions when they exist.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "equations": { - "type": "array", - "description": "Array of equation strings. Use '=' for equations (e.g., ['sin(x) + y = 1', 'x^2 + y^2 = 4']). If no '=' is present, the expression is assumed equal to 0. Variables are auto-detected.", - "items": {"type": "string"} - } - }, - "required": ["equations"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "translate_object", - "description": "Moves/shifts/translates an existing drawable object or function by x and y offsets (dx, dy).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The exact name of the object to translate taken from the canvas state" - }, - "x_offset": { - "type": "number", - "description": "The horizontal translation distance (positive moves right, negative moves left)" - }, - "y_offset": { - "type": "number", - "description": "The vertical translation distance (positive moves up, negative moves down)" - } - }, - "required": ["name", "x_offset", "y_offset"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "rotate_object", - "description": "Rotates a drawable object by the specified angle. By default rotates around the object's own center. When center_x and center_y are provided, rotates around that arbitrary point (works for all types including points and circles).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the object to rotate" - }, - "angle": { - "type": "number", - "description": "The angle in degrees to rotate the object (positive for counterclockwise)" - }, - "center_x": { - "type": ["number", "null"], - "description": "X-coordinate of the rotation center. Must be provided together with center_y for rotation around an arbitrary point. Omit (null) to rotate around the object's own center." - }, - "center_y": { - "type": ["number", "null"], - "description": "Y-coordinate of the rotation center. Must be provided together with center_x for rotation around an arbitrary point. Omit (null) to rotate around the object's own center." - } - }, - "required": ["name", "angle", "center_x", "center_y"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "reflect_object", - "description": "Reflects (mirrors) a drawable object across an axis or line. Supports x-axis, y-axis, an arbitrary line (ax + by + c = 0), or a named segment as the reflection axis.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the object to reflect" - }, - "axis": { - "type": "string", - "enum": ["x_axis", "y_axis", "line", "segment"], - "description": "The reflection axis type" - }, - "line_a": { - "type": ["number", "null"], - "description": "Coefficient a in ax + by + c = 0 (required when axis is 'line')" - }, - "line_b": { - "type": ["number", "null"], - "description": "Coefficient b in ax + by + c = 0 (required when axis is 'line')" - }, - "line_c": { - "type": ["number", "null"], - "description": "Coefficient c in ax + by + c = 0 (required when axis is 'line')" - }, - "segment_name": { - "type": ["string", "null"], - "description": "Name of a segment to use as the reflection axis (required when axis is 'segment')" - } - }, - "required": ["name", "axis", "line_a", "line_b", "line_c", "segment_name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "scale_object", - "description": "Scales (dilates) a drawable object by the specified factors from a center point. Use equal sx and sy for uniform scaling. Circles require uniform scaling (equal sx and sy); for non-uniform scaling, convert to an ellipse first.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the object to scale" - }, - "sx": { - "type": "number", - "description": "Horizontal scale factor (e.g. 2 to double width, 0.5 to halve)" - }, - "sy": { - "type": "number", - "description": "Vertical scale factor (e.g. 2 to double height, 0.5 to halve)" - }, - "cx": { - "type": "number", - "description": "X-coordinate of the scaling center" - }, - "cy": { - "type": "number", - "description": "Y-coordinate of the scaling center" - } - }, - "required": ["name", "sx", "sy", "cx", "cy"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "shear_object", - "description": "Shears a drawable object along the specified axis from a center point. Not supported for circles and ellipses.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the object to shear" - }, - "axis": { - "type": "string", - "enum": ["horizontal", "vertical"], - "description": "The shear direction" - }, - "factor": { - "type": "number", - "description": "The shear factor (e.g. 0.5 shifts x by 0.5*dy for horizontal shear)" - }, - "cx": { - "type": "number", - "description": "X-coordinate of the shear center" - }, - "cy": { - "type": "number", - "description": "Y-coordinate of the shear center" - } - }, - "required": ["name", "axis", "factor", "cx", "cy"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "save_workspace", - "description": "Saves the current workspace state to a file. If no name is provided, saves to the current workspace file with timestamp. The workspace name MUST only contain alphanumeric characters, underscores, or hyphens (no spaces, dots, slashes, or other special characters).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": ["string", "null"], - "description": "Optional name for the workspace. Must contain only alphanumeric characters, underscores, or hyphens (e.g., 'my_workspace', 'workspace-1', 'test123'). If not provided, saves to current workspace." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "load_workspace", - "description": "Loads a workspace from a file. If no name is provided, loads the (most recent) current workspace. The workspace name MUST only contain alphanumeric characters, underscores, or hyphens (no spaces, dots, slashes, or other special characters).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": ["string", "null"], - "description": "Optional name of the workspace to load. Must contain only alphanumeric characters, underscores, or hyphens (e.g., 'my_workspace', 'workspace-1', 'test123'). If not provided, loads current workspace." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "list_workspaces", - "description": "Lists all saved workspaces. Only shows workspaces with valid names (containing only alphanumeric characters, underscores, or hyphens).", - "strict": True, - "parameters": { - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_workspace", - "description": "Delete a workspace by name. The workspace name MUST only contain alphanumeric characters, underscores, or hyphens (no spaces, dots, slashes, or other special characters).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Name of the workspace to delete. Must contain only alphanumeric characters, underscores, or hyphens (e.g., 'my_workspace', 'workspace-1', 'test123')." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_colored_area", - "description": "Creates a colored area between two drawables (functions, segments, or a function and a segment). If only one drawable is provided, the area will be between that drawable and the x-axis.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "drawable1_name": { - "type": "string", - "description": "Name of the first drawable (function or segment). Use 'x_axis' for the x-axis." - }, - "drawable2_name": { - "type": ["string", "null"], - "description": "Optional name of the second drawable (function or segment). Use 'x_axis' for the x-axis. If not provided, area will be between drawable1 and x-axis." - }, - "left_bound": { - "type": ["number", "null"], - "description": "Optional left bound for function areas. Only used when at least one drawable is a function." - }, - "right_bound": { - "type": ["number", "null"], - "description": "Optional right bound for function areas. Only used when at least one drawable is a function." - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the area. Default is 'lightblue'." - }, - "opacity": { - "type": ["number", "null"], - "description": "Optional opacity for the area between 0 and 1. Default is 0.3." - } - }, - "required": ["drawable1_name", "drawable2_name", "left_bound", "right_bound", "color", "opacity"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "create_region_colored_area", - "description": "Fill a region defined by a boolean expression or a closed shape. Supports expressions with operators (& | - ^), arcs, circles, ellipses, polygons, and segments. Expression takes precedence if provided.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": ["string", "null"], - "description": "Boolean region expression using shape names and operators. Examples: 'ArcMaj_AB & CD' (arc intersected with segment), 'circle_A - triangle_ABC' (difference). Takes precedence over other parameters." - }, - "triangle_name": { - "type": ["string", "null"], - "description": "Name of an existing triangle to fill." - }, - "rectangle_name": { - "type": ["string", "null"], - "description": "Name of an existing rectangle to fill." - }, - "polygon_segment_names": { - "type": ["array", "null"], - "items": {"type": "string"}, - "description": "List of segment names that form a closed polygon loop (at least three segments)." - }, - "circle_name": { - "type": ["string", "null"], - "description": "Name of the circle to fill or to use with a chord segment." - }, - "ellipse_name": { - "type": ["string", "null"], - "description": "Name of the ellipse to fill or to use with a chord segment." - }, - "chord_segment_name": { - "type": ["string", "null"], - "description": "Segment name that serves as the chord/clip when creating a circle or ellipse segment region." - }, - "arc_clockwise": { - "type": ["boolean", "null"], - "description": "Set to true to trace the arc clockwise when using a round shape with a chord segment. Default is false (counter-clockwise)." - }, - "resolution": { - "type": ["number", "null"], - "description": "Number of samples used to approximate curved boundaries. Defaults to 96." - }, - "color": { - "type": ["string", "null"], - "description": "Optional color for the filled area. Default is 'lightblue'." - }, - "opacity": { - "type": ["number", "null"], - "description": "Optional opacity between 0 and 1. Default is 0.3." - } - }, - "required": [ - "expression", - "triangle_name", - "rectangle_name", - "polygon_segment_names", - "circle_name", - "ellipse_name", - "chord_segment_name", - "arc_clockwise", - "resolution", - "color", - "opacity" - ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_colored_area", - "description": "Deletes a colored area by its name", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Name of the colored area to delete" - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_colored_area", - "description": "Updates editable properties of an existing colored area (color, opacity, and for function-bounded areas, optional left/right bounds). Provide null for fields that should remain unchanged.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Existing name of the colored area to edit." - }, - "new_color": { - "type": ["string", "null"], - "description": "Optional new color for the area." - }, - "new_opacity": { - "type": ["number", "null"], - "description": "Optional new opacity between 0 and 1." - }, - "new_left_bound": { - "type": ["number", "null"], - "description": "Optional new left bound (functions-bounded areas only)." - }, - "new_right_bound": { - "type": ["number", "null"], - "description": "Optional new right bound (functions-bounded areas only)." - } - }, - "required": ["name", "new_color", "new_opacity", "new_left_bound", "new_right_bound"], - "additionalProperties": False - } - } - }, - # START GRAPH FUNCTIONS - { - "type": "function", - "function": { - "name": "generate_graph", - "description": "Generates a graph or tree on the canvas using provided vertices/edges or an adjacency matrix. Returns the created graph state and drawable names for follow-up highlighting.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": {"type": ["string", "null"]}, - "graph_type": { - "type": "string", - "enum": ["graph", "tree", "dag"], - "description": "Type of graph to create." - }, - "directed": {"type": ["boolean", "null"]}, - "root": {"type": ["string", "null"], "description": "Root id for trees."}, - "layout": {"type": ["string", "null"], "description": "Layout hint: 'tree' or 'hierarchical' for top-down tree display (default for trees), 'radial' for concentric rings from root, 'circular' for nodes on a circle, 'grid' for rectangular grid, 'force' for force-directed."}, - "placement_box": { - "type": ["object", "null"], - "description": "Bounding box for vertex placement. Defined from bottom-left corner in math coordinates (y increases upward). Box spans from (x, y) to (x + width, y + height).", - "properties": { - "x": {"type": "number", "description": "Left edge X coordinate (bottom-left corner)"}, - "y": {"type": "number", "description": "Bottom edge Y coordinate (bottom-left corner, in math coords where y increases upward)"}, - "width": {"type": "number", "description": "Box width extending rightward (positive X direction)"}, - "height": {"type": "number", "description": "Box height extending upward (positive Y direction)"} - }, - "required": ["x", "y", "width", "height"], - "additionalProperties": False - }, - "vertices": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": {"type": ["string", "null"]}, - "x": {"type": ["number", "null"]}, - "y": {"type": ["number", "null"]}, - "color": {"type": ["string", "null"]}, - "label": {"type": ["string", "null"]} - }, - "required": ["name", "x", "y", "color", "label"], - "additionalProperties": False - }, - "description": "List of vertex descriptors. Vertex id is implied by array index starting at 0." - }, - "edges": { - "type": "array", - "items": { - "type": "object", - "properties": { - "source": {"type": "number", "description": "Source vertex index (0-based, matches vertices array order)"}, - "target": {"type": "number", "description": "Target vertex index (0-based, matches vertices array order)"}, - "weight": {"type": ["number", "null"]}, - "name": {"type": ["string", "null"]}, - "color": {"type": ["string", "null"]}, - "directed": {"type": ["boolean", "null"]} - }, - "required": ["source", "target", "weight", "name", "color", "directed"], - "additionalProperties": False - }, - "description": "List of edge descriptors." - }, - "adjacency_matrix": { - "type": ["array", "null"], - "items": { - "type": "array", - "items": {"type": "number"} - }, - "description": "Optional adjacency matrix (weights allowed). Rows/columns follow the order of the provided vertices array (0-based)." - } - }, - "required": ["name", "graph_type", "directed", "root", "layout", "placement_box", "vertices", "edges", "adjacency_matrix"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_graph", - "description": "Deletes a graph or tree and its associated drawables by name.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": {"type": "string"} - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "analyze_graph", - "description": "Analyzes an existing graph/tree for connectivity and structural queries (connectedness, shortest path, BFS/DFS, bipartite, bridges, articulation points, diameter, etc.). Use generate_graph first if the graph does not exist yet.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "graph_name": {"type": "string", "description": "Existing graph name to analyze (must exist on canvas)."}, - "operation": { - "type": "string", - "enum": ["shortest_path", "mst", "topological_sort", "bridges", "articulation_points", "euler_status", "bipartite", "bfs", "dfs", "levels", "diameter", "lca", "balance_children", "invert_children", "reroot", "convex_hull", "point_in_hull"] - }, - "params": { - "type": ["object", "null"], - "description": "Operation-specific parameters (start, goal, root, a, b, new_root, x, y for point_in_hull, etc.).", - "properties": { - "start": {"type": ["string", "null"], "description": "Start vertex for shortest_path, bfs, dfs."}, - "goal": {"type": ["string", "null"], "description": "Goal vertex for shortest_path."}, - "root": {"type": ["string", "null"], "description": "Root vertex for tree operations."}, - "a": {"type": ["string", "null"], "description": "First vertex for LCA."}, - "b": {"type": ["string", "null"], "description": "Second vertex for LCA."}, - "new_root": {"type": ["string", "null"], "description": "New root vertex for reroot operation."}, - "x": {"type": ["number", "null"], "description": "X coordinate for point_in_hull."}, - "y": {"type": ["number", "null"], "description": "Y coordinate for point_in_hull."} - }, - "required": ["start", "goal", "root", "a", "b", "new_root", "x", "y"], - "additionalProperties": False - } - }, - "required": ["graph_name", "operation", "params"], - "additionalProperties": False - } - } - }, - # END GRAPH FUNCTIONS - # START RELATION INSPECTION - { - "type": "function", - "function": { - "name": "inspect_relation", - "description": "Check and explain geometric relations between objects on the canvas. Supported: parallel, perpendicular, collinear, concyclic, equal_length, similar, congruent, tangent, concurrent, point_on_line, point_on_circle. Use 'auto' to check all applicable relations.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "operation": { - "type": "string", - "enum": ["parallel", "perpendicular", "collinear", "concyclic", - "equal_length", "similar", "congruent", "tangent", - "concurrent", "point_on_line", "point_on_circle", "auto"] - }, - "objects": { - "type": "array", - "items": {"type": "string"}, - "description": "Names of objects to check, e.g. ['s1', 's2']" - }, - "object_types": { - "type": "array", - "items": { - "type": "string", - "enum": ["point", "segment", "vector", "circle", - "ellipse", "triangle", "rectangle"] - }, - "description": "Type of each object in same order as objects" - } - }, - "required": ["operation", "objects", "object_types"], - "additionalProperties": False - } - } - }, - # END RELATION INSPECTION - # START PLOT FUNCTIONS - { - "type": "function", - "function": { - "name": "plot_distribution", - "description": "Plots a probability distribution on the canvas. Choose representation 'continuous' for a function curve or 'discrete' for bar rectangles. For continuous plots, you can optionally draw the curve over plot_bounds while shading only over shade_bounds (clamped into plot_bounds). Creates a tracked plot composite for reliable deletion.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": ["string", "null"], - "description": "Optional plot name. If null, a name will be generated." - }, - "representation": { - "type": "string", - "enum": ["continuous", "discrete"], - "description": "Plot representation. 'continuous' draws a smooth curve. 'discrete' draws bars (rectangles)." - }, - "distribution_type": { - "type": "string", - "enum": ["normal"], - "description": "Distribution to plot. v1 supports only 'normal' (Gaussian)." - }, - "distribution_params": { - "type": ["object", "null"], - "description": "Parameters for the selected distribution type. For 'normal', provide mean and sigma.", - "properties": { - "mean": { - "type": ["number", "null"], - "description": "Mean (mu) for the normal distribution. Defaults to 0 if null." - }, - "sigma": { - "type": ["number", "null"], - "description": "Standard deviation (sigma) for the normal distribution. Defaults to 1 if null. Must be > 0." - } - }, - "required": ["mean", "sigma"], - "additionalProperties": False - }, - "plot_bounds": { - "type": ["object", "null"], - "description": "Optional bounds for plotting the curve. If null, or either side is null, defaults to mean +/- 4*sigma.", - "properties": { - "left_bound": { - "type": ["number", "null"], - "description": "Optional left bound for plotting the curve. Defaults to mean - 4*sigma when null." - }, - "right_bound": { - "type": ["number", "null"], - "description": "Optional right bound for plotting the curve. Defaults to mean + 4*sigma when null." - } - }, - "required": ["left_bound", "right_bound"], - "additionalProperties": False - }, - "shade_bounds": { - "type": ["object", "null"], - "description": "Continuous only. Optional bounds for shading under the curve. If null, defaults to plot_bounds. Bounds are clamped into plot_bounds.", - "properties": { - "left_bound": { - "type": ["number", "null"], - "description": "Optional left bound for shading under the curve. If null, defaults to plot_bounds.left_bound." - }, - "right_bound": { - "type": ["number", "null"], - "description": "Optional right bound for shading under the curve. If null, defaults to plot_bounds.right_bound." - } - }, - "required": ["left_bound", "right_bound"], - "additionalProperties": False - }, - "curve_color": { - "type": ["string", "null"], - "description": "Optional color for the plotted curve." - }, - "fill_color": { - "type": ["string", "null"], - "description": "Optional fill color for the area under the curve. Defaults to the standard area fill color." - }, - "fill_opacity": { - "type": ["number", "null"], - "description": "Optional fill opacity (0 to 1). Defaults to the standard area opacity." - }, - "bar_count": { - "type": ["number", "null"], - "description": "Discrete only. Number of bars to draw across the bounds. If null, a default is used." - } - }, - "required": [ - "name", - "representation", - "distribution_type", - "distribution_params", - "plot_bounds", - "shade_bounds", - "curve_color", - "fill_color", - "fill_opacity", - "bar_count" - ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "plot_bars", - "description": "Plots a bar chart from tabular data (values with labels). Creates a tracked plot composite for reliable deletion.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": ["string", "null"], - "description": "Optional plot name. If null, a name will be generated." - }, - "values": { - "type": "array", - "items": {"type": "number"}, - "description": "Bar heights (math-space units). Must have at least one entry." - }, - "labels_below": { - "type": "array", - "items": {"type": "string"}, - "description": "Label under each bar. Must have one label per value." - }, - "labels_above": { - "type": ["array", "null"], - "items": {"type": "string"}, - "description": "Optional label above each bar (for example, formatted values). If provided, must have one label per value." - }, - "bar_spacing": { - "type": ["number", "null"], - "description": "Optional spacing between bars in math-space units. Defaults to 0.2." - }, - "bar_width": { - "type": ["number", "null"], - "description": "Optional bar width in math-space units. Defaults to 1.0." - }, - "stroke_color": { - "type": ["string", "null"], - "description": "Optional stroke color for each bar." - }, - "fill_color": { - "type": ["string", "null"], - "description": "Optional fill color for each bar." - }, - "fill_opacity": { - "type": ["number", "null"], - "description": "Optional fill opacity (0 to 1)." - }, - "x_start": { - "type": ["number", "null"], - "description": "Optional left x coordinate for the first bar. Defaults to 0." - }, - "y_base": { - "type": ["number", "null"], - "description": "Optional baseline y coordinate for bars. Defaults to 0." - } - }, - "required": [ - "name", - "values", - "labels_below", - "labels_above", - "bar_spacing", - "bar_width", - "stroke_color", - "fill_color", - "fill_opacity", - "x_start", - "y_base" - ], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_plot", - "description": "Deletes a previously created plot composite by name, including any underlying components (curve and filled area, or derived bars).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": {"type": "string"} - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "fit_regression", - "description": "Fits a regression model to data points and plots the resulting curve. Supported model types: linear (y = mx + b), polynomial (y = a0 + a1*x + ... + an*x^n), exponential (y = a*e^(bx)), logarithmic (y = a + b*ln(x)), power (y = a*x^b), logistic (y = L/(1+e^(-k(x-x0)))), and sinusoidal (y = a*sin(bx+c)+d). Returns the function_name, fitted expression, coefficients, R-squared, and point_names. Use delete_function to remove the curve; delete points individually.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": ["string", "null"], - "description": "Optional base name for the function and data points. If null, a name will be generated based on model type." - }, - "x_data": { - "type": "array", - "items": {"type": "number"}, - "description": "Array of x values (independent variable). Must have at least 2 points (more for polynomial)." - }, - "y_data": { - "type": "array", - "items": {"type": "number"}, - "description": "Array of y values (dependent variable). Must have same length as x_data." - }, - "model_type": { - "type": "string", - "enum": ["linear", "polynomial", "exponential", "logarithmic", "power", "logistic", "sinusoidal"], - "description": "Type of regression model to fit. Note: exponential and power require positive y values; logarithmic and power require positive x values." - }, - "degree": { - "type": ["integer", "null"], - "description": "Polynomial degree (required for polynomial model, ignored otherwise). Must be >= 1 and less than the number of data points." - }, - "plot_bounds": { - "type": ["object", "null"], - "description": "Optional bounds for plotting the fitted curve. Defaults to data range with 10% padding.", - "properties": { - "left_bound": { - "type": ["number", "null"], - "description": "Left bound for plotting. Defaults to min(x_data) - 10% range." - }, - "right_bound": { - "type": ["number", "null"], - "description": "Right bound for plotting. Defaults to max(x_data) + 10% range." - } - }, - "required": ["left_bound", "right_bound"], - "additionalProperties": False - }, - "curve_color": { - "type": ["string", "null"], - "description": "Optional color for the fitted curve." - }, - "show_points": { - "type": ["boolean", "null"], - "description": "Whether to plot the data points. Defaults to true." - }, - "point_color": { - "type": ["string", "null"], - "description": "Optional color for data points (if show_points is true)." - } - }, - "required": [ - "name", - "x_data", - "y_data", - "model_type", - "degree", - "plot_bounds", - "curve_color", - "show_points", - "point_color" - ], - "additionalProperties": False - } - } - }, - # END PLOT FUNCTIONS - # START ANGLE FUNCTIONS - { - "type": "function", - "function": { - "name": "create_angle", - "description": "Creates and draws an angle defined by three points. The first point (vx, vy) is the common vertex, and the other two points (p1x, p1y and p2x, p2y) define the angle's arms. For example, in an angle ABC, (vx, vy) would be the coordinates of point B. The angle's visual representation (arc and degree value) will be drawn. Segments forming the angle will be created if they don't exist.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "vx": { - "type": "number", - "description": "The X coordinate of the common vertex point (e.g., point B in an angle ABC)." - }, - "vy": { - "type": "number", - "description": "The Y coordinate of the common vertex point (e.g., point B in an angle ABC)." - }, - "p1x": { - "type": "number", - "description": "The X coordinate of the first arm point." - }, - "p1y": { - "type": "number", - "description": "The Y coordinate of the first arm point." - }, - "p2x": { - "type": "number", - "description": "The X coordinate of the second arm point." - }, - "p2y": { - "type": "number", - "description": "The Y coordinate of the second arm point." - }, - "color": { - "type": [ "string", "null" ], - "description": "Optional color for the angle's arc and text. Defaults to the canvas default color." - }, - "angle_name": { - "type": [ "string", "null" ], - "description": "Optional name for the angle. If not provided, a name might be generated (e.g., 'angle_ABC')." - }, - "is_reflex": { - "type": ["boolean", "null"], - "description": "Optional. If true, the reflex angle will be created. Defaults to false (smallest angle)." - } - }, - "required": ["vx", "vy", "p1x", "p1y", "p2x", "p2y", "color", "angle_name", "is_reflex"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "delete_angle", - "description": "Removes an angle by its name. This will also attempt to remove its constituent segments if they are no longer part of other drawables.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the angle to remove (e.g., 'angle_ABC')." - } - }, - "required": ["name"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "update_angle", - "description": "Updates editable properties of an existing angle (currently just its color).", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The name of the angle to update." - }, - "new_color": { - "type": [ "string", "null" ], - "description": "The new color for the angle. Provide null to leave unchanged." - } - }, - "required": ["name", "new_color"], - "additionalProperties": False - } - } - }, - # END ANGLE FUNCTIONS - # START AREA CALCULATION FUNCTIONS - { - "type": "function", - "function": { - "name": "calculate_area", - "description": "Calculates geometric area (triangle, polygon, circle, arc segment, region unions/intersections) from canvas drawables or boolean region expressions. Use this for 'area of a triangle/circle/region'.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Boolean expression with drawable names. Examples: 'circle_A' (single shape), 'circle_A & triangle_ABC' (intersection), 'C(5) & AB' (circle cut by segment AB), 'ArcMaj_CD & triangle_ABC' (arc segment intersected with triangle), 'circle_A - triangle_ABC' (difference), '(circle_A & quad_ABCD) & EF' (shapes intersected then cut by segment)." - } - }, - "required": ["expression"], - "additionalProperties": False - } - } - }, - # END AREA CALCULATION FUNCTIONS - # START COORDINATE SYSTEM FUNCTIONS - { - "type": "function", - "function": { - "name": "set_coordinate_system", - "description": "Sets the coordinate system mode for the canvas grid. Choose 'cartesian' for the standard x-y grid or 'polar' for a polar coordinate grid with concentric circles and radial lines.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "mode": { - "type": "string", - "enum": ["cartesian", "polar"], - "description": "The coordinate system mode: 'cartesian' for x-y grid, 'polar' for polar grid" - } - }, - "required": ["mode"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "convert_coordinates", - "description": "Converts coordinates between rectangular (Cartesian) and polar coordinate systems. For rectangular to polar: returns (r, theta) where r is radius and theta is angle in radians. For polar to rectangular: returns (x, y) coordinates.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "coord1": { - "type": "number", - "description": "First coordinate: x for rectangular-to-polar, r (radius) for polar-to-rectangular" - }, - "coord2": { - "type": "number", - "description": "Second coordinate: y for rectangular-to-polar, theta (angle in radians) for polar-to-rectangular" - }, - "from_system": { - "type": "string", - "enum": ["rectangular", "cartesian", "polar"], - "description": "The source coordinate system ('rectangular' and 'cartesian' are equivalent)" - }, - "to_system": { - "type": "string", - "enum": ["rectangular", "cartesian", "polar"], - "description": "The target coordinate system ('rectangular' and 'cartesian' are equivalent)" - } - }, - "required": ["coord1", "coord2", "from_system", "to_system"], - "additionalProperties": False - } - } - }, - { - "type": "function", - "function": { - "name": "set_grid_visible", - "description": "Sets the visibility of the active coordinate grid (Cartesian or Polar). Use this to show or hide the grid lines without changing the coordinate system mode.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "visible": { - "type": "boolean", - "description": "Whether the grid should be visible (true to show, false to hide)" - } - }, - "required": ["visible"], - "additionalProperties": False - } - } - }, - # END COORDINATE SYSTEM FUNCTIONS - # START TOOL SEARCH FUNCTIONS - { - "type": "function", - "function": { - "name": "search_tools", - "description": "Search for the best tools to accomplish a task. Use this when you're unsure which specific tool to use. Provide a description of what you want to do, and receive the most relevant tool definitions.", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Description of what you want to accomplish (e.g., 'draw a triangle with vertices at specific coordinates', 'calculate the derivative of a function')" - }, - "max_results": { - "type": ["integer", "null"], - "description": "Maximum number of tools to return (default: 10, max: 20)" - } - }, - "required": ["query", "max_results"], - "additionalProperties": False - } - } - } - # END TOOL SEARCH FUNCTIONS - ] + "description": "Right bound for plotting. Defaults to max(x_data) + 10% range.", + }, + }, + "required": ["left_bound", "right_bound"], + "additionalProperties": False, + }, + "curve_color": {"type": ["string", "null"], "description": "Optional color for the fitted curve."}, + "show_points": { + "type": ["boolean", "null"], + "description": "Whether to plot the data points. Defaults to true.", + }, + "point_color": { + "type": ["string", "null"], + "description": "Optional color for data points (if show_points is true).", + }, + }, + "required": [ + "name", + "x_data", + "y_data", + "model_type", + "degree", + "plot_bounds", + "curve_color", + "show_points", + "point_color", + ], + "additionalProperties": False, + }, + }, + }, + # END PLOT FUNCTIONS + # START ANGLE FUNCTIONS + { + "type": "function", + "function": { + "name": "create_angle", + "description": "Creates and draws an angle defined by three points. The first point (vx, vy) is the common vertex, and the other two points (p1x, p1y and p2x, p2y) define the angle's arms. For example, in an angle ABC, (vx, vy) would be the coordinates of point B. The angle's visual representation (arc and degree value) will be drawn. Segments forming the angle will be created if they don't exist.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "vx": { + "type": "number", + "description": "The X coordinate of the common vertex point (e.g., point B in an angle ABC).", + }, + "vy": { + "type": "number", + "description": "The Y coordinate of the common vertex point (e.g., point B in an angle ABC).", + }, + "p1x": {"type": "number", "description": "The X coordinate of the first arm point."}, + "p1y": {"type": "number", "description": "The Y coordinate of the first arm point."}, + "p2x": {"type": "number", "description": "The X coordinate of the second arm point."}, + "p2y": {"type": "number", "description": "The Y coordinate of the second arm point."}, + "color": { + "type": ["string", "null"], + "description": "Optional color for the angle's arc and text. Defaults to the canvas default color.", + }, + "angle_name": { + "type": ["string", "null"], + "description": "Optional name for the angle. If not provided, a name might be generated (e.g., 'angle_ABC').", + }, + "is_reflex": { + "type": ["boolean", "null"], + "description": "Optional. If true, the reflex angle will be created. Defaults to false (smallest angle).", + }, + }, + "required": ["vx", "vy", "p1x", "p1y", "p2x", "p2y", "color", "angle_name", "is_reflex"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "delete_angle", + "description": "Removes an angle by its name. This will also attempt to remove its constituent segments if they are no longer part of other drawables.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the angle to remove (e.g., 'angle_ABC')."} + }, + "required": ["name"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "update_angle", + "description": "Updates editable properties of an existing angle (currently just its color).", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the angle to update."}, + "new_color": { + "type": ["string", "null"], + "description": "The new color for the angle. Provide null to leave unchanged.", + }, + }, + "required": ["name", "new_color"], + "additionalProperties": False, + }, + }, + }, + # END ANGLE FUNCTIONS + # START AREA CALCULATION FUNCTIONS + { + "type": "function", + "function": { + "name": "calculate_area", + "description": "Calculates geometric area (triangle, polygon, circle, arc segment, region unions/intersections) from canvas drawables or boolean region expressions. Use this for 'area of a triangle/circle/region'.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Boolean expression with drawable names. Examples: 'circle_A' (single shape), 'circle_A & triangle_ABC' (intersection), 'C(5) & AB' (circle cut by segment AB), 'ArcMaj_CD & triangle_ABC' (arc segment intersected with triangle), 'circle_A - triangle_ABC' (difference), '(circle_A & quad_ABCD) & EF' (shapes intersected then cut by segment).", + } + }, + "required": ["expression"], + "additionalProperties": False, + }, + }, + }, + # END AREA CALCULATION FUNCTIONS + # START COORDINATE SYSTEM FUNCTIONS + { + "type": "function", + "function": { + "name": "set_coordinate_system", + "description": "Sets the coordinate system mode for the canvas grid. Choose 'cartesian' for the standard x-y grid or 'polar' for a polar coordinate grid with concentric circles and radial lines.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "mode": { + "type": "string", + "enum": ["cartesian", "polar"], + "description": "The coordinate system mode: 'cartesian' for x-y grid, 'polar' for polar grid", + } + }, + "required": ["mode"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "convert_coordinates", + "description": "Converts coordinates between rectangular (Cartesian) and polar coordinate systems. For rectangular to polar: returns (r, theta) where r is radius and theta is angle in radians. For polar to rectangular: returns (x, y) coordinates.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "coord1": { + "type": "number", + "description": "First coordinate: x for rectangular-to-polar, r (radius) for polar-to-rectangular", + }, + "coord2": { + "type": "number", + "description": "Second coordinate: y for rectangular-to-polar, theta (angle in radians) for polar-to-rectangular", + }, + "from_system": { + "type": "string", + "enum": ["rectangular", "cartesian", "polar"], + "description": "The source coordinate system ('rectangular' and 'cartesian' are equivalent)", + }, + "to_system": { + "type": "string", + "enum": ["rectangular", "cartesian", "polar"], + "description": "The target coordinate system ('rectangular' and 'cartesian' are equivalent)", + }, + }, + "required": ["coord1", "coord2", "from_system", "to_system"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "set_grid_visible", + "description": "Sets the visibility of the active coordinate grid (Cartesian or Polar). Use this to show or hide the grid lines without changing the coordinate system mode.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "visible": { + "type": "boolean", + "description": "Whether the grid should be visible (true to show, false to hide)", + } + }, + "required": ["visible"], + "additionalProperties": False, + }, + }, + }, + # END COORDINATE SYSTEM FUNCTIONS + # START TOOL SEARCH FUNCTIONS + { + "type": "function", + "function": { + "name": "search_tools", + "description": "Search for the best tools to accomplish a task. Use this when you're unsure which specific tool to use. Provide a description of what you want to do, and receive the most relevant tool definitions.", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Description of what you want to accomplish (e.g., 'draw a triangle with vertices at specific coordinates', 'calculate the derivative of a function')", + }, + "max_results": { + "type": ["integer", "null"], + "description": "Maximum number of tools to return (default: 10, max: 20)", + }, + }, + "required": ["query", "max_results"], + "additionalProperties": False, + }, + }, + }, + # END TOOL SEARCH FUNCTIONS +] diff --git a/static/log_manager.py b/static/log_manager.py index b5399e83..d2f769ad 100644 --- a/static/log_manager.py +++ b/static/log_manager.py @@ -46,7 +46,7 @@ class LogManager: Supports optional forwarding of logs to the browser console. """ - def __init__(self, logs_dir: str = './logs/') -> None: + def __init__(self, logs_dir: str = "./logs/") -> None: """Initialize LogManager with specified logs directory. Args: @@ -90,7 +90,7 @@ def _get_log_file_name(self) -> str: Returns: str: Date-based log filename (e.g., 'mathud_session_24_03_15.log') """ - return datetime.now().strftime('mathud_session_%y_%m_%d.log') + return datetime.now().strftime("mathud_session_%y_%m_%d.log") def _setup_logging(self) -> None: """Initialize logging configuration. @@ -105,16 +105,14 @@ def _setup_logging(self) -> None: root_logger = logging.getLogger() if not root_logger.handlers: logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format='%(asctime)s %(levelname)s %(message)s' + filename=log_file_path, level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" ) self._logger.setLevel(logging.INFO) self._logger.propagate = False if not self._logger.handlers: - handler = logging.FileHandler(log_file_path, encoding='utf-8') - handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) + handler = logging.FileHandler(log_file_path, encoding="utf-8") + handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) self._logger.addHandler(handler) self.log_new_session() @@ -146,19 +144,19 @@ def log_user_message(self, user_message: str) -> None: svg_state = user_message_json.get("svg_state") if isinstance(svg_state, dict): - self._logger.info(f'### SVG state dimensions: {svg_state.get("dimensions")}') + self._logger.info(f"### SVG state dimensions: {svg_state.get('dimensions')}") canvas_state = user_message_json.get("canvas_state") if canvas_state is not None: - self._logger.info(f'### Canvas state: {canvas_state}') + self._logger.info(f"### Canvas state: {canvas_state}") previous_results = user_message_json.get("previous_results") if previous_results is not None: - self._logger.info(f'### Previously calculated results: {previous_results}') + self._logger.info(f"### Previously calculated results: {previous_results}") user_message_text = user_message_json.get("user_message") if user_message_text is not None: - self._logger.info(f'### User message: {user_message_text}') + self._logger.info(f"### User message: {user_message_text}") def log_ai_response(self, ai_message: str) -> None: """Log AI response message. @@ -166,7 +164,7 @@ def log_ai_response(self, ai_message: str) -> None: Args: ai_message: AI-generated response text """ - self._logger.info(f'### AI response: {ai_message}') + self._logger.info(f"### AI response: {ai_message}") def log_ai_tool_calls(self, ai_tool_calls: Sequence[ProcessedToolCall] | Sequence[Dict[str, Any]] | None) -> None: """Log AI tool calls. @@ -175,7 +173,7 @@ def log_ai_tool_calls(self, ai_tool_calls: Sequence[ProcessedToolCall] | Sequenc ai_tool_calls: List of AI-requested function calls (ProcessedToolCall or dict) """ if ai_tool_calls is not None: - self._logger.info(f'### AI tool calls: {list(ai_tool_calls)}') + self._logger.info(f"### AI tool calls: {list(ai_tool_calls)}") def log_action_trace(self, trace_summary: Dict[str, Any]) -> None: """Log a structured action trace summary as a JSON line. diff --git a/static/mirror_client_modules.py b/static/mirror_client_modules.py index dbba02d6..20caa1f9 100644 --- a/static/mirror_client_modules.py +++ b/static/mirror_client_modules.py @@ -46,4 +46,3 @@ def ensure_client_constants_available() -> None: def ensure_polygon_subtypes_available() -> None: """Mirror the polygon subtype enums into the server package if needed.""" _mirror_if_stale(_POLYGON_SUBTYPES_CLIENT_PATH, _POLYGON_SUBTYPES_SERVER_PATH) - diff --git a/static/openai_api_base.py b/static/openai_api_base.py index cb58ade9..9951396b 100644 --- a/static/openai_api_base.py +++ b/static/openai_api_base.py @@ -35,12 +35,14 @@ ToolMode = Literal["full", "search"] # Essential tool names that should always be available after injection -ESSENTIAL_TOOLS = frozenset({ - "search_tools", - "undo", - "redo", - "get_current_canvas_state", -}) +ESSENTIAL_TOOLS = frozenset( + { + "search_tools", + "undo", + "redo", + "get_current_canvas_state", + } +) def _build_search_mode_tools() -> List[FunctionDefinition]: @@ -93,9 +95,7 @@ def _initialize_api_key() -> str: api_key = os.getenv("OPENAI_API_KEY") if not api_key: - logging.getLogger("mathud").warning( - "OPENAI_API_KEY not found. OpenAI models will be unavailable." - ) + logging.getLogger("mathud").warning("OPENAI_API_KEY not found. OpenAI models will be unavailable.") return "not-configured" return api_key @@ -125,9 +125,7 @@ def __init__( self._custom_tools: Optional[Sequence[FunctionDefinition]] = tools self._injected_tools: bool = False # Track if tools were dynamically injected self.tools: Sequence[FunctionDefinition] = self._resolve_tools() - self.messages: List[MessageDict] = [ - {"role": "developer", "content": OpenAIAPIBase.DEV_MSG} - ] + self.messages: List[MessageDict] = [{"role": "developer", "content": OpenAIAPIBase.DEV_MSG}] def _resolve_tools(self) -> Sequence[FunctionDefinition]: """Resolve the active tool set based on mode and custom tools. @@ -322,11 +320,8 @@ def _create_enhanced_prompt_with_image( if include_canvas_snapshot: try: with open("canvas_snapshots/canvas.png", "rb") as image_file: - image_data = base64.b64encode(image_file.read()).decode('utf-8') - content.append({ - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{image_data}"} - }) + image_data = base64.b64encode(image_file.read()).decode("utf-8") + content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}}) has_images = True except Exception as e: error_msg = f"Failed to load canvas image: {e}" @@ -337,10 +332,7 @@ def _create_enhanced_prompt_with_image( if attached_images: for img_url in attached_images: if isinstance(img_url, str) and img_url.startswith("data:image"): - content.append({ - "type": "image_url", - "image_url": {"url": img_url} - }) + content.append({"type": "image_url", "image_url": {"url": img_url}}) has_images = True return content if has_images else None @@ -534,10 +526,7 @@ def _create_error_response( error_message: str = "I encountered an error processing your request. Please try again.", ) -> SimpleNamespace: """Create an error response that matches OpenAI's response structure.""" - return SimpleNamespace( - message=SimpleNamespace(content=error_message, tool_calls=[]), - finish_reason="error" - ) + return SimpleNamespace(message=SimpleNamespace(content=error_message, tool_calls=[]), finish_reason="error") def _create_tool_message(self, tool_call_id: Optional[str], content: str) -> MessageDict: """Create a tool message in response to a tool call.""" @@ -547,10 +536,7 @@ def _append_tool_messages(self, tool_calls: Sequence[Any] | None) -> None: """Create and append placeholder tool messages for each tool call.""" if tool_calls: for tool_call in tool_calls: - tool_message = self._create_tool_message( - getattr(tool_call, "id", None), - "Awaiting result..." - ) + tool_message = self._create_tool_message(getattr(tool_call, "id", None), "Awaiting result...") self.messages.append(tool_message) def _update_tool_messages_with_results(self, tool_call_results: str) -> None: diff --git a/static/openai_completions_api.py b/static/openai_completions_api.py index 4f10e9a6..e07d1cf2 100644 --- a/static/openai_completions_api.py +++ b/static/openai_completions_api.py @@ -37,8 +37,8 @@ def _create_assistant_message(self, response_message: Any) -> MessageDict: "type": "function", "function": { "name": getattr(getattr(tool_call, "function", None), "name", None), - "arguments": getattr(getattr(tool_call, "function", None), "arguments", None) - } + "arguments": getattr(getattr(tool_call, "function", None), "arguments", None), + }, } for tool_call in tool_calls ] @@ -260,13 +260,15 @@ def _normalize_tool_calls(self, accumulator: Dict[int, Dict[str, Any]]) -> List[ normalized = [] for tc in tool_calls_list: func = tc.get("function", {}) if isinstance(tc, dict) else {} - normalized.append({ - "id": tc.get("id") if isinstance(tc, dict) else None, - "function": { - "name": func.get("name") if isinstance(func, dict) else None, - "arguments": func.get("arguments") if isinstance(func, dict) else None, - }, - }) + normalized.append( + { + "id": tc.get("id") if isinstance(tc, dict) else None, + "function": { + "name": func.get("name") if isinstance(func, dict) else None, + "arguments": func.get("arguments") if isinstance(func, dict) else None, + }, + } + ) return normalized def _finalize_stream(self, accumulated_text: str, normalized_tool_calls: List[Dict[str, Any]]) -> None: @@ -287,22 +289,25 @@ def _finalize_stream(self, accumulated_text: str, normalized_tool_calls: List[Di assistant_message = self._create_assistant_message(assistant_message_like) self.messages.append(assistant_message) - self._append_tool_messages([ - SimpleNamespace( - id=tc.get("id"), - function=SimpleNamespace( - name=tc.get("function", {}).get("name"), - arguments=tc.get("function", {}).get("arguments"), - ), - ) - for tc in normalized_tool_calls - ]) + self._append_tool_messages( + [ + SimpleNamespace( + id=tc.get("id"), + function=SimpleNamespace( + name=tc.get("function", {}).get("name"), + arguments=tc.get("function", {}).get("arguments"), + ), + ) + for tc in normalized_tool_calls + ] + ) self._clean_conversation_history() def _prepare_tool_calls_for_response(self, normalized_tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Prepare tool calls for the final response.""" import json as _json + result = [] for tc in normalized_tool_calls: func = tc.get("function", {}) if isinstance(tc, dict) else {} @@ -312,9 +317,10 @@ def _prepare_tool_calls_for_response(self, normalized_tool_calls: List[Dict[str, func_args = _json.loads(func_args_raw) if func_args_raw else {} except Exception: func_args = {} - result.append({ - "function_name": func_name or "", - "arguments": func_args, - }) + result.append( + { + "function_name": func_name or "", + "arguments": func_args, + } + ) return result - diff --git a/static/openai_responses_api.py b/static/openai_responses_api.py index ca065c49..aaa92c42 100644 --- a/static/openai_responses_api.py +++ b/static/openai_responses_api.py @@ -101,12 +101,14 @@ def _convert_tools_for_responses_api(self) -> List[Dict[str, Any]]: for tool in self.tools: if isinstance(tool, dict) and tool.get("type") == "function": func = tool.get("function", {}) - converted_tools.append({ - "type": "function", - "name": func.get("name", ""), - "description": func.get("description", ""), - "parameters": func.get("parameters", {}), - }) + converted_tools.append( + { + "type": "function", + "name": func.get("name", ""), + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ) return converted_tools def _convert_messages_to_input(self) -> List[Dict[str, Any]]: @@ -136,18 +138,14 @@ def _convert_messages_to_input(self) -> List[Dict[str, Any]]: self._flush_log() return input_messages - def _handle_developer_message( - self, msg: Dict[str, Any], output: List[Dict[str, Any]], index: int - ) -> int: + def _handle_developer_message(self, msg: Dict[str, Any], output: List[Dict[str, Any]], index: int) -> int: """Convert developer message to system role for Responses API.""" content = msg.get("content", "") output.append({"role": "system", "content": content}) self._log(f"[Responses API] Including developer message at index {index}") return index + 1 - def _handle_assistant_with_tool_calls( - self, msg: Dict[str, Any], output: List[Dict[str, Any]], index: int - ) -> int: + def _handle_assistant_with_tool_calls(self, msg: Dict[str, Any], output: List[Dict[str, Any]], index: int) -> int: """Convert assistant tool calls and results to text message pairs.""" tool_calls = msg.get("tool_calls", []) tool_results, end_index = self._collect_tool_results(index + 1) @@ -157,7 +155,9 @@ def _handle_assistant_with_tool_calls( user_msg = self._create_tool_results_message(tool_results) output.append(assistant_msg) output.append(user_msg) - self._log(f"[Responses API] Converted tool call+results at index {index}-{end_index - 1} to assistant+user messages") + self._log( + f"[Responses API] Converted tool call+results at index {index}-{end_index - 1} to assistant+user messages" + ) return end_index else: self._log(f"[Responses API] Skipping assistant with pending tool calls at index {index}") @@ -168,9 +168,7 @@ def _handle_orphan_tool_message(self, index: int) -> int: self._log(f"[Responses API] Skipping orphan tool message at index {index}") return index + 1 - def _handle_regular_message( - self, msg: Dict[str, Any], role: str, output: List[Dict[str, Any]], index: int - ) -> int: + def _handle_regular_message(self, msg: Dict[str, Any], role: str, output: List[Dict[str, Any]], index: int) -> int: """Include regular user/assistant messages, converting content format if needed.""" content = msg.get("content", "") # Convert Chat Completions content format to Responses API format @@ -396,9 +394,7 @@ def _handle_reasoning_item(self, item: Any, state: Dict[str, Any]) -> Iterator[S yield {"type": "reasoning", "text": "(Reasoning in progress...)\n"} state["reasoning_placeholder_sent"] = True - def _handle_function_call_item( - self, item: Any, output_index: int, accumulator: Dict[int, Dict[str, Any]] - ) -> None: + def _handle_function_call_item(self, item: Any, output_index: int, accumulator: Dict[int, Dict[str, Any]]) -> None: """Handle function call items from output_item.added events.""" call_id = getattr(item, "call_id", None) name = getattr(item, "name", None) @@ -477,7 +473,9 @@ def _build_final_response(self, state: Dict[str, Any]) -> StreamEvent: self._log(f"[Responses API] Final ai_tool_calls: {ai_tool_calls}") final_finish_reason = "tool_calls" if ai_tool_calls else (state["finish_reason"] or "stop") - self._log(f"[Responses API] Yielding final event with {len(ai_tool_calls)} tool calls, finish_reason={final_finish_reason}") + self._log( + f"[Responses API] Yielding final event with {len(ai_tool_calls)} tool calls, finish_reason={final_finish_reason}" + ) self._flush_log() return { @@ -575,13 +573,15 @@ def _normalize_tool_calls(self, acc: Dict[int, Dict[str, Any]]) -> List[Dict[str for i in sorted(acc): tc = acc[i] func = tc.get("function", {}) - result.append({ - "id": tc.get("id"), - "function": { - "name": func.get("name"), - "arguments": func.get("arguments"), - }, - }) + result.append( + { + "id": tc.get("id"), + "function": { + "name": func.get("name"), + "arguments": func.get("arguments"), + }, + } + ) return result def _finalize_stream(self, text: str, tool_calls: List[Dict[str, Any]]) -> None: @@ -600,16 +600,18 @@ def _finalize_stream(self, text: str, tool_calls: List[Dict[str, Any]]) -> None: ], ) self.messages.append(self._create_assistant_message(assistant_msg)) - self._append_tool_messages([ - SimpleNamespace( - id=tc.get("id"), - function=SimpleNamespace( - name=tc.get("function", {}).get("name"), - arguments=tc.get("function", {}).get("arguments"), - ), - ) - for tc in tool_calls - ]) + self._append_tool_messages( + [ + SimpleNamespace( + id=tc.get("id"), + function=SimpleNamespace( + name=tc.get("function", {}).get("name"), + arguments=tc.get("function", {}).get("arguments"), + ), + ) + for tc in tool_calls + ] + ) self._clean_conversation_history() def _create_assistant_message(self, response_message: Any) -> MessageDict: @@ -624,8 +626,8 @@ def _create_assistant_message(self, response_message: Any) -> MessageDict: "type": "function", "function": { "name": getattr(getattr(tc, "function", None), "name", None), - "arguments": getattr(getattr(tc, "function", None), "arguments", None) - } + "arguments": getattr(getattr(tc, "function", None), "arguments", None), + }, } for tc in tool_calls ] diff --git a/static/providers/__init__.py b/static/providers/__init__.py index 430a3d91..3caf14a1 100644 --- a/static/providers/__init__.py +++ b/static/providers/__init__.py @@ -81,6 +81,7 @@ def is_provider_available(cls, provider_name: str) -> bool: # Local providers use server availability check if provider_name in LOCAL_PROVIDERS: from static.providers.local import LocalProviderRegistry + return LocalProviderRegistry.is_provider_available(provider_name) # API-based providers use API key check @@ -141,8 +142,9 @@ def get_provider_for_model(model_id: str) -> Optional[str]: Provider name, or None if unknown """ from static.ai_model import AIModel + model = AIModel.from_identifier(model_id) - return getattr(model, 'provider', PROVIDER_OPENAI) + return getattr(model, "provider", PROVIDER_OPENAI) def is_local_provider(provider_name: str) -> bool: @@ -173,6 +175,7 @@ def create_provider_instance( # For local providers, get class from LocalProviderRegistry if provider_name in LOCAL_PROVIDERS: from static.providers.local import LocalProviderRegistry + provider_class = LocalProviderRegistry.get_provider_class(provider_name) if provider_class is None: _logger.warning(f"Local provider not registered: {provider_name}") @@ -213,12 +216,14 @@ def discover_providers() -> None: # Import API-based provider modules - they self-register on import try: from static.providers import anthropic_api # noqa: F401 + _logger.debug("Loaded anthropic_api provider module") except ImportError as e: _logger.debug(f"Could not load anthropic_api: {e}") try: from static.providers import openrouter_api # noqa: F401 + _logger.debug("Loaded openrouter_api provider module") except ImportError as e: _logger.debug(f"Could not load openrouter_api: {e}") @@ -226,6 +231,7 @@ def discover_providers() -> None: # Import local provider modules - they self-register on import try: from static.providers.local import ollama_api # noqa: F401 + _logger.debug("Loaded ollama_api local provider module") except ImportError as e: _logger.debug(f"Could not load ollama_api: {e}") diff --git a/static/providers/anthropic_api.py b/static/providers/anthropic_api.py index 94534dd1..a2a9fa60 100644 --- a/static/providers/anthropic_api.py +++ b/static/providers/anthropic_api.py @@ -65,8 +65,7 @@ def __init__( import anthropic except ImportError as e: raise ImportError( - "anthropic package is required for Anthropic provider. " - "Install with: pip install anthropic" + "anthropic package is required for Anthropic provider. Install with: pip install anthropic" ) from e self._anthropic_client = anthropic.Anthropic(api_key=_get_anthropic_api_key()) @@ -159,12 +158,14 @@ def _convert_messages_to_anthropic(self) -> List[Dict[str, Any]]: args = json.loads(args_str) if isinstance(args_str, str) else args_str except json.JSONDecodeError: args = {} - content_blocks.append({ - "type": "tool_use", - "id": tc.get("id", ""), - "name": func.get("name", ""), - "input": args, - }) + content_blocks.append( + { + "type": "tool_use", + "id": tc.get("id", ""), + "name": func.get("name", ""), + "input": args, + } + ) anthropic_messages.append({"role": "assistant", "content": content_blocks}) else: anthropic_messages.append({"role": "assistant", "content": content or ""}) @@ -208,14 +209,16 @@ def _convert_content_blocks(self, content: List[Dict[str, Any]]) -> List[Dict[st parts = url.split(",", 1) if len(parts) == 2: media_type_part = parts[0].replace("data:", "").replace(";base64", "") - anthropic_blocks.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": media_type_part, - "data": parts[1], - }, - }) + anthropic_blocks.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type_part, + "data": parts[1], + }, + } + ) return anthropic_blocks @@ -274,14 +277,16 @@ def _process_anthropic_response(self, response: Any) -> Any: if block.type == "text": text_content += block.text elif block.type == "tool_use": - tool_calls.append({ - "id": block.id, - "type": "function", - "function": { - "name": block.name, - "arguments": json.dumps(block.input), - }, - }) + tool_calls.append( + { + "id": block.id, + "type": "function", + "function": { + "name": block.name, + "arguments": json.dumps(block.input), + }, + } + ) # Create assistant message for history assistant_message: MessageDict = {"role": "assistant", "content": text_content} @@ -292,11 +297,13 @@ def _process_anthropic_response(self, response: Any) -> Any: # Append placeholder tool messages if tool_calls: for tc in tool_calls: - self.messages.append({ - "role": "tool", - "tool_call_id": tc["id"], - "content": "Awaiting result...", - }) + self.messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": "Awaiting result...", + } + ) self._clean_conversation_history() @@ -314,7 +321,9 @@ def _process_anthropic_response(self, response: Any) -> Any: ), ) for tc in tool_calls - ] if tool_calls else None, + ] + if tool_calls + else None, ), finish_reason=finish_reason, ) @@ -408,9 +417,7 @@ def create_chat_completion_stream(self, full_prompt: str) -> Iterator[StreamEven "finish_reason": finish_reason or "stop", } - def _finalize_anthropic_stream( - self, accumulated_text: str, tool_calls: List[Dict[str, Any]] - ) -> None: + def _finalize_anthropic_stream(self, accumulated_text: str, tool_calls: List[Dict[str, Any]]) -> None: """Finalize the streaming response by updating messages.""" # Create assistant message assistant_message: MessageDict = {"role": "assistant", "content": accumulated_text} @@ -427,17 +434,17 @@ def _finalize_anthropic_stream( # Append placeholder tool messages for tc in tool_calls: - self.messages.append({ - "role": "tool", - "tool_call_id": tc["id"], - "content": "Awaiting result...", - }) + self.messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": "Awaiting result...", + } + ) self._clean_conversation_history() - def _prepare_tool_calls_for_response( - self, tool_calls: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + def _prepare_tool_calls_for_response(self, tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Prepare tool calls for the final response.""" result = [] for tc in tool_calls: @@ -448,10 +455,12 @@ def _prepare_tool_calls_for_response( func_args = json.loads(func_args_raw) if func_args_raw else {} except json.JSONDecodeError: func_args = {} - result.append({ - "function_name": func_name, - "arguments": func_args, - }) + result.append( + { + "function_name": func_name, + "arguments": func_args, + } + ) return result diff --git a/static/providers/local/__init__.py b/static/providers/local/__init__.py index 1fb8dff4..25afe94d 100644 --- a/static/providers/local/__init__.py +++ b/static/providers/local/__init__.py @@ -207,9 +207,7 @@ def __init__( self.tools: Sequence[FunctionDefinition] = self._resolve_tools() # Use developer message as system prompt - self.messages: List[Dict[str, Any]] = [ - {"role": "system", "content": OpenAIAPIBase.DEV_MSG} - ] + self.messages: List[Dict[str, Any]] = [{"role": "system", "content": OpenAIAPIBase.DEV_MSG}] @abstractmethod def _is_available(self) -> bool: @@ -265,10 +263,7 @@ def discover_models_with_tool_support(self) -> List[Dict[str, Any]]: if supports_tools(model_name): tool_capable.append(model_info) - _logger.info( - f"Discovered {len(tool_capable)} tool-capable models " - f"out of {len(all_models)} total" - ) + _logger.info(f"Discovered {len(tool_capable)} tool-capable models out of {len(all_models)} total") return tool_capable def reset_conversation(self) -> None: @@ -476,14 +471,16 @@ def _process_response(self, choice: Any) -> Any: # Build tool calls list tool_calls = [] for tc in raw_tool_calls: - tool_calls.append({ - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - }) + tool_calls.append( + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + ) # Add assistant message to history assistant_message: Dict[str, Any] = {"role": "assistant", "content": text_content} @@ -493,11 +490,13 @@ def _process_response(self, choice: Any) -> Any: # Add placeholder tool messages for tc in tool_calls: - self.messages.append({ - "role": "tool", - "tool_call_id": tc["id"], - "content": "Awaiting result...", - }) + self.messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": "Awaiting result...", + } + ) self._clean_conversation_history() @@ -515,14 +514,14 @@ def _process_response(self, choice: Any) -> Any: ), ) for tc in tool_calls - ] if tool_calls else None, + ] + if tool_calls + else None, ), finish_reason=finish_reason, ) - def _finalize_stream( - self, accumulated_text: str, tool_calls: List[Dict[str, Any]] - ) -> None: + def _finalize_stream(self, accumulated_text: str, tool_calls: List[Dict[str, Any]]) -> None: """Finalize the streaming response by updating messages.""" # Create assistant message assistant_message: Dict[str, Any] = {"role": "assistant", "content": accumulated_text} @@ -539,17 +538,17 @@ def _finalize_stream( # Add placeholder tool messages for tc in tool_calls: - self.messages.append({ - "role": "tool", - "tool_call_id": tc["id"], - "content": "Awaiting result...", - }) + self.messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": "Awaiting result...", + } + ) self._clean_conversation_history() - def _prepare_tool_calls_for_response( - self, tool_calls: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + def _prepare_tool_calls_for_response(self, tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Prepare tool calls for the final response.""" import json @@ -562,8 +561,10 @@ def _prepare_tool_calls_for_response( func_args = json.loads(func_args_raw) if func_args_raw else {} except json.JSONDecodeError: func_args = {} - result.append({ - "function_name": func_name, - "arguments": func_args, - }) + result.append( + { + "function_name": func_name, + "arguments": func_args, + } + ) return result diff --git a/static/providers/local/ollama_api.py b/static/providers/local/ollama_api.py index e4c77555..160a3bc6 100644 --- a/static/providers/local/ollama_api.py +++ b/static/providers/local/ollama_api.py @@ -109,11 +109,13 @@ def _discover_models(self) -> List[Dict[str, Any]]: models = [] for model in data.get("models", []): - models.append({ - "name": model.get("name", ""), - "size": model.get("size", 0), - "modified_at": model.get("modified_at", ""), - }) + models.append( + { + "name": model.get("name", ""), + "size": model.get("size", 0), + "modified_at": model.get("modified_at", ""), + } + ) return models except Exception as e: @@ -156,11 +158,13 @@ def get_tool_capable_models(cls) -> List[Dict[str, Any]]: for model in data.get("models", []): name = model.get("name", "") if supports_tools(name): - tool_capable.append({ - "name": name, - "size": model.get("size", 0), - "modified_at": model.get("modified_at", ""), - }) + tool_capable.append( + { + "name": name, + "size": model.get("size", 0), + "modified_at": model.get("modified_at", ""), + } + ) return tool_capable except Exception as e: diff --git a/static/providers/openrouter_api.py b/static/providers/openrouter_api.py index a0ac931b..dc32a69a 100644 --- a/static/providers/openrouter_api.py +++ b/static/providers/openrouter_api.py @@ -81,6 +81,7 @@ def __init__( # Initialize message history with developer message from static.openai_api_base import OpenAIAPIBase + self.messages = [{"role": "developer", "content": OpenAIAPIBase.DEV_MSG}] diff --git a/static/routes.py b/static/routes.py index 850a9b6b..63a52b52 100644 --- a/static/routes.py +++ b/static/routes.py @@ -77,8 +77,7 @@ def get_provider_for_model(app: MatHudFlask, model_id: str) -> OpenAIAPIBase: provider_instance = create_provider_instance(provider_name, **create_kwargs) if provider_instance is None: raise ValueError( - f"Provider '{provider_name}' is not available. " - f"Check that the API key is configured in .env." + f"Provider '{provider_name}' is not available. Check that the API key is configured in .env." ) app.providers[provider_name] = provider_instance @@ -114,7 +113,7 @@ def save_canvas_snapshot_from_data_url(data_url: str) -> bool: def extract_vision_payload( - request_payload: Dict[str, Any] + request_payload: Dict[str, Any], ) -> tuple[Optional[Dict[str, Any]], Optional[str], Optional[str], Optional[List[str]]]: """Extract vision-related data from the request payload. @@ -129,27 +128,27 @@ def extract_vision_payload( renderer_mode: Optional[str] = None attached_images: Optional[List[str]] = None - raw_svg_state = request_payload.get('svg_state') + raw_svg_state = request_payload.get("svg_state") if isinstance(raw_svg_state, dict): svg_state = raw_svg_state - raw_renderer = request_payload.get('renderer_mode') + raw_renderer = request_payload.get("renderer_mode") if isinstance(raw_renderer, str): renderer_mode = raw_renderer - vision_snapshot = request_payload.get('vision_snapshot') + vision_snapshot = request_payload.get("vision_snapshot") if isinstance(vision_snapshot, dict): - snapshot_svg = vision_snapshot.get('svg_state') + snapshot_svg = vision_snapshot.get("svg_state") if isinstance(snapshot_svg, dict): svg_state = snapshot_svg - snapshot_renderer = vision_snapshot.get('renderer_mode') + snapshot_renderer = vision_snapshot.get("renderer_mode") if isinstance(snapshot_renderer, str): renderer_mode = snapshot_renderer - snapshot_canvas = vision_snapshot.get('canvas_image') + snapshot_canvas = vision_snapshot.get("canvas_image") if isinstance(snapshot_canvas, str): canvas_image = snapshot_canvas # Extract attached images from the request payload - raw_attached_images = request_payload.get('attached_images') + raw_attached_images = request_payload.get("attached_images") if isinstance(raw_attached_images, list): attached_images = [img for img in raw_attached_images if isinstance(img, str)] @@ -243,12 +242,12 @@ def _intercept_search_tools( def _tool_call_name(call: Dict[str, Any]) -> Optional[str]: """Extract a tool call function name from normalized or nested shapes.""" - function_name = call.get('function_name') + function_name = call.get("function_name") if isinstance(function_name, str): return function_name - function = call.get('function') + function = call.get("function") if isinstance(function, dict): - nested_name = function.get('name') + nested_name = function.get("name") if isinstance(nested_name, str): return nested_name return None @@ -257,14 +256,14 @@ def _tool_call_name(call: Dict[str, Any]) -> Optional[str]: def _find_search_tools_call(tool_calls: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Return the first search_tools call, if any.""" for call in tool_calls: - if _tool_call_name(call) == 'search_tools': + if _tool_call_name(call) == "search_tools": return call return None def _normalize_search_tools_args(call: Dict[str, Any]) -> Dict[str, Any]: """Normalize search_tools arguments to a dictionary.""" - args = call.get('arguments', {}) + args = call.get("arguments", {}) if isinstance(args, dict): return args if isinstance(args, str): @@ -280,10 +279,10 @@ def _normalize_search_tools_args(call: Dict[str, Any]) -> Dict[str, Any]: def _extract_search_query_and_limit(call: Dict[str, Any]) -> tuple[str, int]: """Extract (query, max_results) from a search_tools call.""" args = _normalize_search_tools_args(call) - query = args.get('query', '') + query = args.get("query", "") if not isinstance(query, str): - query = '' - raw_max_results = args.get('max_results', 10) + query = "" + raw_max_results = args.get("max_results", 10) if isinstance(raw_max_results, int): max_results = raw_max_results else: @@ -296,11 +295,7 @@ def _collect_allowed_tool_names( essential_tools: AbstractSet[str], ) -> set[str]: """Collect allowed tool names from search result plus essentials.""" - allowed_names = { - t.get('function', {}).get('name') - for t in search_result - if isinstance(t, dict) - } + allowed_names = {t.get("function", {}).get("name") for t in search_result if isinstance(t, dict)} allowed_names.update(essential_tools) return allowed_names @@ -362,12 +357,13 @@ def require_auth(f: F) -> F: Returns: Wrapped function that checks authentication before proceeding """ + @functools.wraps(f) def decorated_function(*args: Any, **kwargs: Any) -> ResponseReturnValue: # Require authentication when deployed or explicitly enabled if AppManager.requires_auth(): - if not session.get('authenticated'): - return redirect(url_for('login')) + if not session.get("authenticated"): + return redirect(url_for("login")) return f(*args, **kwargs) return cast(F, decorated_function) @@ -383,26 +379,26 @@ def register_routes(app: MatHudFlask) -> None: app: Flask application instance """ - @app.route('/login', methods=['GET', 'POST']) + @app.route("/login", methods=["GET", "POST"]) def login() -> ResponseReturnValue: """Handle user authentication with access code.""" if not AppManager.requires_auth(): - return redirect(url_for('get_index')) + return redirect(url_for("get_index")) # If user is already authenticated, redirect to main page - if session.get('authenticated'): - return redirect(url_for('get_index')) + if session.get("authenticated"): + return redirect(url_for("get_index")) - if request.method == 'POST': - client_ip = request.environ.get('HTTP_X_FORWARDED_FOR', request.remote_addr) + if request.method == "POST": + client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr) current_time = time.time() - pin_submitted = request.form.get('pin', '') + pin_submitted = request.form.get("pin", "") auth_pin = AppManager.get_auth_pin() if not auth_pin: - flash('Authentication not configured') - return render_template('login.html') + flash("Authentication not configured") + return render_template("login.html") is_pin_correct = hmac.compare_digest(pin_submitted, auth_pin) @@ -410,9 +406,9 @@ def login() -> ResponseReturnValue: if is_pin_correct: # 1. PIN is correct. Login succeeds immediately. - session['authenticated'] = True - login_attempts.pop(client_ip, None) # Clear any old rate limit. - return redirect(url_for('get_index')) + session["authenticated"] = True + login_attempts.pop(client_ip, None) # Clear any old rate limit. + return redirect(url_for("get_index")) else: # 2. PIN is incorrect. Now we handle rate limiting. if client_ip in login_attempts: @@ -422,42 +418,41 @@ def login() -> ResponseReturnValue: # 2a. Cooldown is ACTIVE. Block and show countdown. remaining_cooldown = 5.0 - time_since_last_failed display_time = math.ceil(remaining_cooldown) - flash(f'Too many attempts. Please wait {display_time} seconds.') - return render_template('login.html') + flash(f"Too many attempts. Please wait {display_time} seconds.") + return render_template("login.html") # 2b. PIN was wrong, but no active cooldown. Start a new one. login_attempts[client_ip] = current_time - flash('Invalid access code') + flash("Invalid access code") # Cleanup logic (runs only after a failed attempt) if len(login_attempts) > 1000: - cutoff = current_time - 3600 # 1 hour + cutoff = current_time - 3600 # 1 hour for ip, timestamp in list(login_attempts.items()): if timestamp < cutoff: del login_attempts[ip] - return render_template('login.html') + return render_template("login.html") # For GET request - return render_template('login.html') + return render_template("login.html") - @app.route('/logout') + @app.route("/logout") def logout() -> ResponseReturnValue: """Handle user logout and session cleanup.""" - session.pop('authenticated', None) + session.pop("authenticated", None) if AppManager.requires_auth(): - return redirect(url_for('login')) - return redirect(url_for('get_index')) + return redirect(url_for("login")) + return redirect(url_for("get_index")) - @app.route('/auth_status') + @app.route("/auth_status") def auth_status() -> ResponseReturnValue: """Return authentication status information.""" - return AppManager.make_response(data={ - 'auth_required': AppManager.requires_auth(), - 'authenticated': session.get('authenticated', False) - }) + return AppManager.make_response( + data={"auth_required": AppManager.requires_auth(), "authenticated": session.get("authenticated", False)} + ) - @app.route('/api/available_models', methods=['GET']) + @app.route("/api/available_models", methods=["GET"]) @require_auth def get_available_models() -> ResponseReturnValue: """Return models grouped by provider, only for available providers.""" @@ -503,7 +498,7 @@ def get_available_models() -> ResponseReturnValue: return jsonify(models_by_provider) - @app.route('/api/preload_model', methods=['POST']) + @app.route("/api/preload_model", methods=["POST"]) @require_auth def preload_model() -> ResponseReturnValue: """Preload an Ollama model into memory. @@ -519,16 +514,16 @@ def preload_model() -> ResponseReturnValue: request_payload = request.get_json(silent=True) if not isinstance(request_payload, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) - model_id = request_payload.get('model_id') + model_id = request_payload.get("model_id") if not isinstance(model_id, str) or not model_id: return AppManager.make_response( - message='model_id is required', - status='error', + message="model_id is required", + status="error", code=400, ) @@ -536,24 +531,24 @@ def preload_model() -> ResponseReturnValue: model = AIModel.from_identifier(model_id) if model.provider != PROVIDER_OLLAMA: return AppManager.make_response( - message='Model preloading only supported for Ollama models', - status='error', + message="Model preloading only supported for Ollama models", + status="error", code=400, ) # Check if server is running if not OllamaAPI.is_server_running(): return AppManager.make_response( - message='Ollama server is not running', - status='error', + message="Ollama server is not running", + status="error", code=503, ) # Check if already loaded if OllamaAPI.is_model_loaded(model_id): return AppManager.make_response( - data={'already_loaded': True}, - message=f'Model {model_id} is already loaded', + data={"already_loaded": True}, + message=f"Model {model_id} is already loaded", ) # Preload the model (this may take a while) @@ -561,17 +556,17 @@ def preload_model() -> ResponseReturnValue: if success: return AppManager.make_response( - data={'already_loaded': False}, + data={"already_loaded": False}, message=message, ) else: return AppManager.make_response( message=message, - status='error', + status="error", code=500, ) - @app.route('/api/model_status', methods=['GET']) + @app.route("/api/model_status", methods=["GET"]) @require_auth def get_model_status() -> ResponseReturnValue: """Get the loading status of Ollama models. @@ -584,11 +579,11 @@ def get_model_status() -> ResponseReturnValue: """ from static.providers.local.ollama_api import OllamaAPI - model_id = request.args.get('model_id') + model_id = request.args.get("model_id") if not OllamaAPI.is_server_running(): return AppManager.make_response( - data={'server_running': False, 'loaded_models': []}, + data={"server_running": False, "loaded_models": []}, ) loaded_models = OllamaAPI.get_loaded_models() @@ -597,28 +592,28 @@ def get_model_status() -> ResponseReturnValue: is_loaded = OllamaAPI.is_model_loaded(model_id) return AppManager.make_response( data={ - 'server_running': True, - 'model_id': model_id, - 'is_loaded': is_loaded, - 'loaded_models': loaded_models, + "server_running": True, + "model_id": model_id, + "is_loaded": is_loaded, + "loaded_models": loaded_models, }, ) return AppManager.make_response( data={ - 'server_running': True, - 'loaded_models': loaded_models, + "server_running": True, + "loaded_models": loaded_models, }, ) - @app.route('/api/debug/conversation', methods=['GET']) + @app.route("/api/debug/conversation", methods=["GET"]) @require_auth def debug_conversation() -> ResponseReturnValue: """Debug endpoint to view conversation history for all providers. Returns the message history for debugging purposes. """ - provider_name = request.args.get('provider') + provider_name = request.args.get("provider") def summarize_message(msg: Dict[str, Any]) -> Dict[str, Any]: """Summarize a message for display, truncating long content.""" @@ -675,100 +670,83 @@ def summarize_message(msg: Dict[str, Any]) -> Dict[str, Any]: return AppManager.make_response(data=cast(JsonValue, result)) - @app.route('/api/debug/canvas-state-comparison', methods=['POST']) + @app.route("/api/debug/canvas-state-comparison", methods=["POST"]) @require_auth def debug_canvas_state_comparison() -> ResponseReturnValue: """Development-only endpoint for full vs summary canvas-state comparison.""" if AppManager.is_deployed(): return AppManager.make_response( - message='Not found', - status='error', + message="Not found", + status="error", code=404, ) payload = request.get_json(silent=True) if not isinstance(payload, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) - canvas_state = payload.get('canvas_state') + canvas_state = payload.get("canvas_state") if not isinstance(canvas_state, dict): return AppManager.make_response( - message='canvas_state must be an object', - status='error', + message="canvas_state must be an object", + status="error", code=400, ) comparison = compare_canvas_states(canvas_state) return AppManager.make_response(data=cast(JsonValue, comparison)) - @app.route('/') + @app.route("/") @require_auth def get_index() -> ResponseReturnValue: - return render_template('index.html') + return render_template("index.html") - @app.route('/init_webdriver') + @app.route("/init_webdriver") @require_auth def init_webdriver_route() -> ResponseReturnValue: """Route to initialize WebDriver after Flask has started""" if not app.webdriver_manager: try: from static.webdriver_manager import WebDriverManager - port = app.config.get('SERVER_PORT', 5000) + + port = app.config.get("SERVER_PORT", 5000) base_url = f"http://127.0.0.1:{port}/" app.webdriver_manager = WebDriverManager(base_url=base_url) except Exception as e: print(f"Failed to initialize WebDriverManager: {str(e)}") return AppManager.make_response( - message=f"WebDriver initialization failed: {str(e)}", - status='error', - code=500 + message=f"WebDriver initialization failed: {str(e)}", status="error", code=500 ) return AppManager.make_response(message="WebDriver initialization successful") - @app.route('/save_workspace', methods=['POST']) + @app.route("/save_workspace", methods=["POST"]) @require_auth def save_workspace_route() -> ResponseReturnValue: """Save the current workspace state.""" try: data = request.get_json(silent=True) if not isinstance(data, dict): - return AppManager.make_response( - message='Invalid request body', - status='error', - code=400 - ) + return AppManager.make_response(message="Invalid request body", status="error", code=400) - state = data.get('state') - name = data.get('name') + state = data.get("state") + name = data.get("name") if name is not None and not isinstance(name, str): - return AppManager.make_response( - message='Workspace name must be a string', - status='error', - code=400 - ) + return AppManager.make_response(message="Workspace name must be a string", status="error", code=400) success = app.workspace_manager.save_workspace(state, name) if success: - return AppManager.make_response(message='Workspace saved successfully') + return AppManager.make_response(message="Workspace saved successfully") else: - return AppManager.make_response( - message='Failed to save workspace', - status='error', - code=500 - ) + return AppManager.make_response(message="Failed to save workspace", status="error", code=500) except Exception as e: - return AppManager.make_response( - message=str(e), - status='error', - code=500 - ) + return AppManager.make_response(message=str(e), status="error", code=500) - @app.route('/send_message_stream', methods=['POST']) + @app.route("/send_message_stream", methods=["POST"]) @require_auth def send_message_stream() -> ResponseReturnValue: """Stream AI response tokens for the provided message payload. @@ -785,17 +763,17 @@ def send_message_stream() -> ResponseReturnValue: request_payload_raw: JsonValue = request.get_json(silent=True) if not isinstance(request_payload_raw, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) request_payload: JsonObject = request_payload_raw - message_raw = request_payload.get('message') + message_raw = request_payload.get("message") if not isinstance(message_raw, str) or not message_raw: return AppManager.make_response( - message='Message is required', - status='error', + message="Message is required", + status="error", code=400, ) message: str = message_raw @@ -804,26 +782,26 @@ def send_message_stream() -> ResponseReturnValue: message_json_value: JsonValue = json.loads(message) except (json.JSONDecodeError, TypeError): return AppManager.make_response( - message='Invalid message format', - status='error', + message="Invalid message format", + status="error", code=400, ) if not isinstance(message_json_value, dict): return AppManager.make_response( - message='Invalid message format', - status='error', + message="Invalid message format", + status="error", code=400, ) message_json: JsonObject = message_json_value svg_state, canvas_image_data, _, _ = extract_vision_payload(request_payload) - use_vision = bool(message_json.get('use_vision', False)) - ai_model_raw = message_json.get('ai_model') + use_vision = bool(message_json.get("use_vision", False)) + ai_model_raw = message_json.get("ai_model") ai_model = ai_model_raw if isinstance(ai_model_raw, str) else None # Extract attached images from the message JSON (not request_payload) - attached_images_raw = message_json.get('attached_images') + attached_images_raw = message_json.get("attached_images") attached_images: Optional[List[str]] = None if isinstance(attached_images_raw, list): attached_images = [img for img in attached_images_raw if isinstance(img, str)] @@ -841,7 +819,7 @@ def send_message_stream() -> ResponseReturnValue: app.log_manager.log_user_message(message) # Log action trace if present (sent as top-level field, not in prompt) - action_trace_raw = request_payload.get('action_trace') + action_trace_raw = request_payload.get("action_trace") if isinstance(action_trace_raw, dict): app.log_manager.log_action_trace(action_trace_raw) @@ -854,7 +832,7 @@ def send_message_stream() -> ResponseReturnValue: ) # Check for search_tools results and inject tools if found - tool_call_results_raw = message_json.get('tool_call_results') + tool_call_results_raw = message_json.get("tool_call_results") if isinstance(tool_call_results_raw, str) and tool_call_results_raw: _maybe_inject_search_tools(app.ai_api, tool_call_results_raw) _maybe_inject_search_tools(app.responses_api, tool_call_results_raw) @@ -892,26 +870,24 @@ def _yield_pending_logs() -> Iterator[str]: if isinstance(event, dict): event_dict = cast(StreamEventDict, event) - if event_dict.get('type') == 'final': + if event_dict.get("type") == "final": try: - app.log_manager.log_ai_response(str(event_dict.get('ai_message', ''))) - tool_calls = event_dict.get('ai_tool_calls') + app.log_manager.log_ai_response(str(event_dict.get("ai_message", ""))) + tool_calls = event_dict.get("ai_tool_calls") if isinstance(tool_calls, list): dict_tool_calls: List[Dict[str, Any]] = [ - cast(Dict[str, Any], call) - for call in tool_calls - if isinstance(call, dict) + cast(Dict[str, Any], call) for call in tool_calls if isinstance(call, dict) ] # Intercept search_tools and filter other tool calls if dict_tool_calls: filtered_calls = _intercept_search_tools(app, dict_tool_calls, provider) - event_dict['ai_tool_calls'] = cast(JsonValue, filtered_calls) + event_dict["ai_tool_calls"] = cast(JsonValue, filtered_calls) app.log_manager.log_ai_tool_calls(filtered_calls) except Exception: pass # Reset tools if AI finished (not requesting more tool calls) - finish_reason = event_dict.get('finish_reason') - if finish_reason != 'tool_calls': + finish_reason = event_dict.get("finish_reason") + if finish_reason != "tool_calls": if app.ai_api.has_injected_tools(): app.ai_api.reset_tools() if app.responses_api.has_injected_tools(): @@ -961,35 +937,27 @@ def _yield_pending_logs() -> Iterator[str]: } yield json.dumps(fallback_payload) + "\n" - response = Response(generate(), mimetype='application/x-ndjson') + response = Response(generate(), mimetype="application/x-ndjson") # Headers to reduce buffering in some proxies - response.headers['Cache-Control'] = 'no-cache' - response.headers['X-Accel-Buffering'] = 'no' + response.headers["Cache-Control"] = "no-cache" + response.headers["X-Accel-Buffering"] = "no" return response - @app.route('/load_workspace', methods=['GET']) + @app.route("/load_workspace", methods=["GET"]) @require_auth def load_workspace_route() -> ResponseReturnValue: """Load a workspace state.""" try: - name = request.args.get('name') + name = request.args.get("name") state = app.workspace_manager.load_workspace(name) - return AppManager.make_response(data={'state': state}) + return AppManager.make_response(data={"state": state}) except FileNotFoundError as e: - return AppManager.make_response( - message=str(e), - status='error', - code=404 - ) + return AppManager.make_response(message=str(e), status="error", code=404) except Exception as e: - return AppManager.make_response( - message=str(e), - status='error', - code=500 - ) + return AppManager.make_response(message=str(e), status="error", code=500) - @app.route('/list_workspaces', methods=['GET']) + @app.route("/list_workspaces", methods=["GET"]) @require_auth def list_workspaces_route() -> ResponseReturnValue: """List all saved workspaces.""" @@ -997,42 +965,26 @@ def list_workspaces_route() -> ResponseReturnValue: workspaces = app.workspace_manager.list_workspaces() return AppManager.make_response(data=cast(JsonValue, workspaces)) except Exception as e: - return AppManager.make_response( - message=str(e), - status='error', - code=500 - ) + return AppManager.make_response(message=str(e), status="error", code=500) - @app.route('/delete_workspace', methods=['GET']) + @app.route("/delete_workspace", methods=["GET"]) @require_auth def delete_workspace_route() -> ResponseReturnValue: """Delete a workspace.""" try: - name = request.args.get('name') + name = request.args.get("name") if not name: - return AppManager.make_response( - message='Workspace name is required', - status='error', - code=400 - ) + return AppManager.make_response(message="Workspace name is required", status="error", code=400) success = app.workspace_manager.delete_workspace(name) if success: - return AppManager.make_response(message='Workspace deleted successfully') + return AppManager.make_response(message="Workspace deleted successfully") else: - return AppManager.make_response( - message='Failed to delete workspace', - status='error', - code=404 - ) + return AppManager.make_response(message="Failed to delete workspace", status="error", code=404) except Exception as e: - return AppManager.make_response( - message=str(e), - status='error', - code=500 - ) + return AppManager.make_response(message=str(e), status="error", code=500) - @app.route('/new_conversation', methods=['POST']) + @app.route("/new_conversation", methods=["POST"]) @require_auth def new_conversation_route() -> ResponseReturnValue: """Reset the AI conversation history for a new session.""" @@ -1043,15 +995,11 @@ def new_conversation_route() -> ResponseReturnValue: for provider in app.providers.values(): provider.reset_conversation() app.log_manager.log_new_session() - return AppManager.make_response(message='New conversation started.') + return AppManager.make_response(message="New conversation started.") except Exception as e: - return AppManager.make_response( - message=str(e), - status='error', - code=500 - ) + return AppManager.make_response(message=str(e), status="error", code=500) - @app.route('/save_partial_response', methods=['POST']) + @app.route("/save_partial_response", methods=["POST"]) @require_auth def save_partial_response() -> ResponseReturnValue: """Save a partial AI response that was interrupted by the user.""" @@ -1059,16 +1007,16 @@ def save_partial_response() -> ResponseReturnValue: request_payload = request.get_json(silent=True) if not isinstance(request_payload, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) - partial_message = request_payload.get('partial_message', '') + partial_message = request_payload.get("partial_message", "") if not isinstance(partial_message, str) or not partial_message.strip(): return AppManager.make_response( - message='No partial message to save', - status='error', + message="No partial message to save", + status="error", code=400, ) @@ -1078,13 +1026,9 @@ def save_partial_response() -> ResponseReturnValue: for provider in app.providers.values(): provider.add_partial_assistant_message(partial_message) - return AppManager.make_response(message='Partial response saved.') + return AppManager.make_response(message="Partial response saved.") except Exception as e: - return AppManager.make_response( - message=str(e), - status='error', - code=500 - ) + return AppManager.make_response(message=str(e), status="error", code=500) def _process_ai_response(app: MatHudFlask, choice: Any) -> tuple[str, ToolCallList]: """Process the AI response choice and log the results. @@ -1111,22 +1055,22 @@ def _process_ai_response(app: MatHudFlask, choice: Any) -> tuple[str, ToolCallLi return ai_message, tool_calls - @app.route('/send_message', methods=['POST']) + @app.route("/send_message", methods=["POST"]) @require_auth def send_message() -> ResponseReturnValue: request_payload = request.get_json(silent=True) if not isinstance(request_payload, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) - message = request_payload.get('message') + message = request_payload.get("message") if not isinstance(message, str) or not message: return AppManager.make_response( - message='Message is required', - status='error', + message="Message is required", + status="error", code=400, ) @@ -1134,25 +1078,25 @@ def send_message() -> ResponseReturnValue: message_json_raw = json.loads(message) except (json.JSONDecodeError, TypeError): return AppManager.make_response( - message='Invalid message format', - status='error', + message="Invalid message format", + status="error", code=400, ) if not isinstance(message_json_raw, dict): return AppManager.make_response( - message='Invalid message format', - status='error', + message="Invalid message format", + status="error", code=400, ) svg_state, canvas_image_data, _, _ = extract_vision_payload(request_payload) - use_vision = bool(message_json_raw.get('use_vision', False)) - ai_model_raw = message_json_raw.get('ai_model') + use_vision = bool(message_json_raw.get("use_vision", False)) + ai_model_raw = message_json_raw.get("ai_model") ai_model = ai_model_raw if isinstance(ai_model_raw, str) else None # Extract attached images from the message JSON - attached_images_raw = message_json_raw.get('attached_images') + attached_images_raw = message_json_raw.get("attached_images") attached_images: Optional[List[str]] = None if isinstance(attached_images_raw, list): attached_images = [img for img in attached_images_raw if isinstance(img, str)] @@ -1170,7 +1114,7 @@ def send_message() -> ResponseReturnValue: app.log_manager.log_user_message(message) # Log action trace if present (sent as top-level field, not in prompt) - action_trace_raw_legacy = request_payload.get('action_trace') + action_trace_raw_legacy = request_payload.get("action_trace") if isinstance(action_trace_raw_legacy, dict): app.log_manager.log_action_trace(action_trace_raw_legacy) @@ -1183,7 +1127,7 @@ def send_message() -> ResponseReturnValue: ) # Check for search_tools results and inject tools if found - tool_call_results_raw = message_json_raw.get('tool_call_results') + tool_call_results_raw = message_json_raw.get("tool_call_results") if isinstance(tool_call_results_raw, str) and tool_call_results_raw: _maybe_inject_search_tools(app.ai_api, tool_call_results_raw) _maybe_inject_search_tools(app.responses_api, tool_call_results_raw) @@ -1196,7 +1140,7 @@ def send_message() -> ResponseReturnValue: def _reset_tools_if_needed(finish_reason: Any) -> None: """Reset tools if AI finished (not requesting more tool calls).""" - if finish_reason != 'tool_calls': + if finish_reason != "tool_calls": if app.ai_api.has_injected_tools(): app.ai_api.reset_tools() if app.responses_api.has_injected_tools(): @@ -1217,7 +1161,7 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: break if final_event is None: - _reset_tools_if_needed('error') + _reset_tools_if_needed("error") return AppManager.make_response( message="No final response event produced", status="error", @@ -1240,11 +1184,16 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: app.log_manager.log_ai_tool_calls(ai_tool_calls) _reset_tools_if_needed(finish_reason) - return AppManager.make_response(data=cast(JsonObject, { - "ai_message": ai_message, - "ai_tool_calls": cast(JsonValue, ai_tool_calls), - "finish_reason": finish_reason, - })) + return AppManager.make_response( + data=cast( + JsonObject, + { + "ai_message": ai_message, + "ai_tool_calls": cast(JsonValue, ai_tool_calls), + "finish_reason": finish_reason, + }, + ) + ) choice = provider.create_chat_completion(message) ai_message, ai_tool_calls_processed = _process_ai_response(app, choice) @@ -1252,23 +1201,28 @@ def _reset_tools_if_needed(finish_reason: Any) -> None: # Intercept search_tools and filter other tool calls if ai_tool_calls: ai_tool_calls = _intercept_search_tools(app, ai_tool_calls, provider) - finish_reason = getattr(choice, 'finish_reason', None) + finish_reason = getattr(choice, "finish_reason", None) _reset_tools_if_needed(finish_reason) - return AppManager.make_response(data=cast(JsonObject, { - "ai_message": ai_message, - "ai_tool_calls": cast(JsonValue, ai_tool_calls), - "finish_reason": finish_reason, - })) + return AppManager.make_response( + data=cast( + JsonObject, + { + "ai_message": ai_message, + "ai_tool_calls": cast(JsonValue, ai_tool_calls), + "finish_reason": finish_reason, + }, + ) + ) except Exception as exc: - _reset_tools_if_needed('error') + _reset_tools_if_needed("error") return AppManager.make_response( message=str(exc), - status='error', + status="error", code=500, ) - @app.route('/search_tools', methods=['POST']) + @app.route("/search_tools", methods=["POST"]) @require_auth def search_tools_route() -> ResponseReturnValue: """Search for tools matching a query description. @@ -1288,20 +1242,20 @@ def search_tools_route() -> ResponseReturnValue: request_payload = request.get_json(silent=True) if not isinstance(request_payload, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) - query = request_payload.get('query') + query = request_payload.get("query") if not isinstance(query, str) or not query.strip(): return AppManager.make_response( - message='Query is required', - status='error', + message="Query is required", + status="error", code=400, ) - max_results_raw = request_payload.get('max_results') + max_results_raw = request_payload.get("max_results") max_results = 10 # default if isinstance(max_results_raw, int): max_results = max_results_raw @@ -1309,7 +1263,7 @@ def search_tools_route() -> ResponseReturnValue: max_results = int(max_results_raw) # Get current model from request to use same provider for search - ai_model_raw = request_payload.get('ai_model') + ai_model_raw = request_payload.get("ai_model") ai_model = ai_model_raw if isinstance(ai_model_raw, str) else None try: @@ -1330,11 +1284,11 @@ def search_tools_route() -> ResponseReturnValue: except Exception as exc: return AppManager.make_response( message=str(exc), - status='error', + status="error", code=500, ) - @app.route('/api/tts', methods=['POST']) + @app.route("/api/tts", methods=["POST"]) @require_auth def generate_tts() -> ResponseReturnValue: """Generate text-to-speech audio from text. @@ -1352,28 +1306,28 @@ def generate_tts() -> ResponseReturnValue: request_payload = request.get_json(silent=True) if not isinstance(request_payload, dict): return AppManager.make_response( - message='Invalid request body', - status='error', + message="Invalid request body", + status="error", code=400, ) - text = request_payload.get('text') + text = request_payload.get("text") if not isinstance(text, str) or not text.strip(): return AppManager.make_response( - message='Text is required', - status='error', + message="Text is required", + status="error", code=400, ) - voice_raw = request_payload.get('voice') + voice_raw = request_payload.get("voice") voice = voice_raw if isinstance(voice_raw, str) else None tts_manager = get_tts_manager() if not tts_manager.is_available(): return AppManager.make_response( - message='TTS service is not available. Kokoro may not be installed.', - status='error', + message="TTS service is not available. Kokoro may not be installed.", + status="error", code=503, ) @@ -1386,21 +1340,21 @@ def generate_tts() -> ResponseReturnValue: if not success: return AppManager.make_response( message=str(result), - status='error', + status="error", code=500, ) # Return WAV audio bytes return Response( result, - mimetype='audio/wav', + mimetype="audio/wav", headers={ - 'Content-Type': 'audio/wav', - 'Content-Disposition': 'inline; filename="tts_output.wav"', - } + "Content-Type": "audio/wav", + "Content-Disposition": 'inline; filename="tts_output.wav"', + }, ) - @app.route('/api/tts/voices', methods=['GET']) + @app.route("/api/tts/voices", methods=["GET"]) @require_auth def get_tts_voices() -> ResponseReturnValue: """Get available TTS voices. @@ -1409,8 +1363,13 @@ def get_tts_voices() -> ResponseReturnValue: JSON response with list of voice IDs """ tts_manager = get_tts_manager() - return AppManager.make_response(data=cast(JsonValue, { - 'voices': tts_manager.get_voices(), - 'default_voice': tts_manager.DEFAULT_VOICE, - 'available': tts_manager.is_available(), - })) + return AppManager.make_response( + data=cast( + JsonValue, + { + "voices": tts_manager.get_voices(), + "default_voice": tts_manager.DEFAULT_VOICE, + "available": tts_manager.is_available(), + }, + ) + ) diff --git a/static/tool_argument_validator.py b/static/tool_argument_validator.py index 309b86cb..2e7e6862 100644 --- a/static/tool_argument_validator.py +++ b/static/tool_argument_validator.py @@ -82,7 +82,7 @@ def _truncate(value: Any) -> str: """Truncate a value's repr for safe inclusion in error messages.""" s = repr(value) if len(s) > _ERROR_VALUE_MAX_LEN: - return s[: _ERROR_VALUE_MAX_LEN] + "..." + return s[:_ERROR_VALUE_MAX_LEN] + "..." return s @@ -192,9 +192,7 @@ def _validate_value( test_value = copy.deepcopy(value) # Create a temporary container for canonical writes tmp: Dict[str, Any] = {"v": test_value} - _validate_value( - test_value, alt, path, tool_name, test_errors, None, "v", tmp - ) + _validate_value(test_value, alt, path, tool_name, test_errors, None, "v", tmp) if not test_errors: # This alternative matched; use its canonical form. result = tmp["v"] @@ -272,19 +270,13 @@ def _validate_value( # --- NaN / Infinity rejection for numeric values --- if isinstance(value, (int, float)) and not isinstance(value, bool): if _check_nan_inf(value): - errors.append( - f"Tool '{tool_name}': argument '{path}' must be a finite number, " - f"got {_truncate(value)}." - ) + errors.append(f"Tool '{tool_name}': argument '{path}' must be a finite number, got {_truncate(value)}.") return value # --- Enum validation --- enum_values = schema.get("enum") if enum_values is not None and value not in enum_values: - errors.append( - f"Tool '{tool_name}': argument '{path}' must be one of " - f"{enum_values}, got {_truncate(value)}." - ) + errors.append(f"Tool '{tool_name}': argument '{path}' must be one of {enum_values}, got {_truncate(value)}.") return value # --- maxLength validation (strings) --- @@ -292,8 +284,7 @@ def _validate_value( if max_length is not None and isinstance(value, str): if len(value) > max_length: errors.append( - f"Tool '{tool_name}': argument '{path}' must be at most " - f"{max_length} characters, got {len(value)}." + f"Tool '{tool_name}': argument '{path}' must be at most {max_length} characters, got {len(value)}." ) # --- Array validation --- @@ -302,8 +293,7 @@ def _validate_value( min_items = schema.get("minItems") if min_items is not None and len(value) < min_items: errors.append( - f"Tool '{tool_name}': argument '{path}' must have at least " - f"{min_items} items, got {len(value)}." + f"Tool '{tool_name}': argument '{path}' must have at least {min_items} items, got {len(value)}." ) # Validate each element against items schema @@ -344,11 +334,9 @@ def _validate_value( if key not in allowed_keys: allowed_list = sorted(allowed_keys) errors.append( - f"Tool '{tool_name}': unknown argument " - f"'{path}.{key}' (allowed: {', '.join(allowed_list)})." + f"Tool '{tool_name}': unknown argument '{path}.{key}' (allowed: {', '.join(allowed_list)})." if path - else f"Tool '{tool_name}': unknown argument " - f"'{key}' (allowed: {', '.join(allowed_list)})." + else f"Tool '{tool_name}': unknown argument '{key}' (allowed: {', '.join(allowed_list)})." ) # Recurse into each property @@ -411,8 +399,7 @@ def validate(function_name: str, arguments: Dict[str, Any]) -> ValidationResult: if schema is None: # Unknown function — log a warning but pass through. logger.warning( - "Tool '%s': no schema found in registry; " - "arguments pass through unvalidated.", + "Tool '%s': no schema found in registry; arguments pass through unvalidated.", function_name, ) return ValidationResult( @@ -433,9 +420,7 @@ def validate(function_name: str, arguments: Dict[str, Any]) -> ValidationResult: # Required fields for req_key in required: if req_key not in canonical_args: - errors.append( - f"Tool '{function_name}': missing required argument '{req_key}'." - ) + errors.append(f"Tool '{function_name}': missing required argument '{req_key}'.") # Unknown keys if additional is False and properties is not None: @@ -444,8 +429,7 @@ def validate(function_name: str, arguments: Dict[str, Any]) -> ValidationResult: if key not in allowed_keys: allowed_list = sorted(allowed_keys) errors.append( - f"Tool '{function_name}': unknown argument " - f"'{key}' (allowed: {', '.join(allowed_list)})." + f"Tool '{function_name}': unknown argument '{key}' (allowed: {', '.join(allowed_list)})." ) # Validate each property diff --git a/static/tool_search_service.py b/static/tool_search_service.py index c67864c8..fcca6779 100644 --- a/static/tool_search_service.py +++ b/static/tool_search_service.py @@ -49,9 +49,34 @@ class ToolSearchService: Return a JSON array of up to {max_results} tool names. Example: ["create_circle", "create_point"]""" _STOPWORDS = frozenset( { - "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", - "how", "i", "if", "in", "is", "it", "me", "of", "on", "or", - "please", "show", "the", "to", "up", "use", "with", "you", + "a", + "an", + "and", + "are", + "as", + "at", + "be", + "by", + "for", + "from", + "how", + "i", + "if", + "in", + "is", + "it", + "me", + "of", + "on", + "or", + "please", + "show", + "the", + "to", + "up", + "use", + "with", + "you", } ) @@ -257,11 +282,18 @@ def _tool_score(cls, query_tokens: List[str], tool_name: str, description: str) # Intent boosts for common confusion clusters. if any(t in query_tokens for t in ("move", "shift", "translate")) and tool_name == "translate_object": score += 4.0 - if any(t in query_tokens for t in ("area", "shade", "region")) and tool_name in {"calculate_area", "create_colored_area", "create_region_colored_area"}: + if any(t in query_tokens for t in ("area", "shade", "region")) and tool_name in { + "calculate_area", + "create_colored_area", + "create_region_colored_area", + }: score += 2.0 if any(t in query_tokens for t in ("distribution", "gaussian", "normal")) and tool_name == "plot_distribution": score += 4.0 - if any(t in query_tokens for t in ("determinant", "eigenvalue", "matrix", "vector")) and tool_name == "evaluate_linear_algebra_expression": + if ( + any(t in query_tokens for t in ("determinant", "eigenvalue", "matrix", "vector")) + and tool_name == "evaluate_linear_algebra_expression" + ): score += 4.0 if any(t in query_tokens for t in ("undo", "redo", "history")) and tool_name in {"undo", "redo"}: score += 4.0 diff --git a/static/tts_manager.py b/static/tts_manager.py index e915eba3..9bb50b90 100644 --- a/static/tts_manager.py +++ b/static/tts_manager.py @@ -39,11 +39,11 @@ class TTSManager: # Available voices VOICES: List[str] = [ "am_michael", # Male - clear, natural - "am_fenrir", # Male - deeper tone - "am_onyx", # Male - darker tone - "am_echo", # Male - resonant - "af_nova", # Female - clear - "af_bella", # Female - warm + "am_fenrir", # Male - deeper tone + "am_onyx", # Male - darker tone + "am_echo", # Male - resonant + "af_nova", # Female - clear + "af_bella", # Female - warm ] DEFAULT_VOICE: str = "am_michael" @@ -72,7 +72,7 @@ def _get_pipeline(self) -> Tuple[bool, Union[object, str]]: from kokoro import KPipeline # Initialize Kokoro pipeline for American English - self._pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M') + self._pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M") return True, self._pipeline except ImportError as e: @@ -181,9 +181,7 @@ def generate_speech_threaded( Tuple of (success, audio_bytes_or_error_message) """ try: - future: Future[Tuple[bool, Union[bytes, str]]] = self._executor.submit( - self.generate_speech, text, voice - ) + future: Future[Tuple[bool, Union[bytes, str]]] = self._executor.submit(self.generate_speech, text, voice) return future.result(timeout=timeout) except TimeoutError: return False, "TTS generation timed out" @@ -206,7 +204,7 @@ def _audio_to_wav(self, audio: np.ndarray, sample_rate: int) -> bytes: buffer = io.BytesIO() # Write WAV to buffer - sf.write(buffer, audio, sample_rate, format='WAV', subtype='PCM_16') + sf.write(buffer, audio, sample_rate, format="WAV", subtype="PCM_16") # Get bytes buffer.seek(0) diff --git a/static/webdriver_manager.py b/static/webdriver_manager.py index 7c92e430..8f6feaf0 100644 --- a/static/webdriver_manager.py +++ b/static/webdriver_manager.py @@ -90,7 +90,8 @@ def update_svg_state(self, svg_state: SvgState) -> None: print("Loading SVG state...") if self.driver is None: raise RuntimeError("WebDriver not initialized") - self.driver.execute_script(""" + self.driver.execute_script( + """ const svg = document.getElementById('math-svg'); const container = document.querySelector('.math-container'); @@ -111,7 +112,9 @@ def update_svg_state(self, svg_state: SvgState) -> None: } return true; // Confirm execution - """, svg_state) + """, + svg_state, + ) time.sleep(1) # Give time for the SVG to be redrawn def _setup_driver(self) -> None: @@ -122,7 +125,7 @@ def _setup_driver(self) -> None: """ print("Initializing WebDriver...") firefox_options = Options() - firefox_options.add_argument('--headless') + firefox_options.add_argument("--headless") max_retries = 3 for attempt in range(max_retries): @@ -209,15 +212,8 @@ def _wait_for_svg_elements(self) -> None: """ if self.driver is None: raise RuntimeError("WebDriver not initialized") - WebDriverWait(self.driver, 10).until( - EC.presence_of_element_located((By.ID, "math-svg")) - ) - WebDriverWait(self.driver, 10).until( - EC.presence_of_all_elements_located(( - By.CSS_SELECTOR, - "#math-svg > *" - )) - ) + WebDriverWait(self.driver, 10).until(EC.presence_of_element_located((By.ID, "math-svg"))) + WebDriverWait(self.driver, 10).until(EC.presence_of_all_elements_located((By.CSS_SELECTOR, "#math-svg > *"))) def _verify_svg_content(self) -> None: """Verify that the SVG content is not empty. @@ -259,7 +255,8 @@ def _configure_svg_size(self, dimensions: SvgDimensions) -> None: """ if self.driver is None: raise RuntimeError("WebDriver not initialized") - self.driver.execute_script(""" + self.driver.execute_script( + """ var container = document.querySelector('.math-container'); var svg = document.getElementById('math-svg'); @@ -284,7 +281,10 @@ def _configure_svg_size(self, dimensions: SvgDimensions) -> None: elements[i].style.visibility = 'visible'; elements[i].style.opacity = '1'; } - """, dimensions['width'], dimensions['height']) + """, + dimensions["width"], + dimensions["height"], + ) def cleanup(self) -> None: """Clean up WebDriver resources. diff --git a/static/workspace_manager.py b/static/workspace_manager.py index cc533493..4144190c 100644 --- a/static/workspace_manager.py +++ b/static/workspace_manager.py @@ -67,7 +67,7 @@ def _is_safe_workspace_name(self, name: Optional[str]) -> bool: if not name or not isinstance(name, str): return False - if not re.match(r'^[\w-]+\Z$', name): # Only allow alphanumeric characters, underscores, and hyphens + if not re.match(r"^[\w-]+\Z$", name): # Only allow alphanumeric characters, underscores, and hyphens return False return True @@ -163,7 +163,7 @@ def save_workspace( } file_path = self.get_workspace_path(name, test_dir) - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: json.dump(workspace_data, f, indent=2) return True @@ -187,7 +187,7 @@ def _get_most_recent_current_workspace(self, test_dir: Optional[str] = None) -> current_workspaces: List[str] = [] for filename in os.listdir(target_dir): - if filename.startswith('current_workspace_') and filename.endswith('.json'): + if filename.startswith("current_workspace_") and filename.endswith(".json"): file_path = os.path.join(target_dir, filename) current_workspaces.append(file_path) @@ -216,7 +216,7 @@ def load_workspace(self, name: Optional[str] = None, test_dir: Optional[str] = N if not os.path.exists(file_path): raise FileNotFoundError(f"No workspace found at {file_path}") - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: workspace_data_raw: JsonValue = json.load(f) if not isinstance(workspace_data_raw, dict): @@ -253,7 +253,9 @@ def _normalize_and_migrate_workspace_record( migrated_state = self._migrate_state(state, schema_version) metadata_name_raw = metadata.get("name") - metadata_name = metadata_name_raw if isinstance(metadata_name_raw, str) and metadata_name_raw else workspace_name + metadata_name = ( + metadata_name_raw if isinstance(metadata_name_raw, str) and metadata_name_raw else workspace_name + ) metadata_last_modified_raw = metadata.get("last_modified") metadata_last_modified = ( metadata_last_modified_raw @@ -287,8 +289,7 @@ def _migrate_state(self, state: WorkspaceState, schema_version: int) -> Workspac return state if schema_version > CURRENT_WORKSPACE_SCHEMA_VERSION: raise ValueError( - f"Unsupported workspace schema_version: {schema_version} " - f"(current: {CURRENT_WORKSPACE_SCHEMA_VERSION})" + f"Unsupported workspace schema_version: {schema_version} (current: {CURRENT_WORKSPACE_SCHEMA_VERSION})" ) return state @@ -312,7 +313,7 @@ def list_workspaces(self, test_dir: Optional[str] = None) -> List[str]: file_path = os.path.join(target_dir, filename) try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: data_raw: JsonValue = json.load(f) if isinstance(data_raw, dict): metadata_candidate = data_raw.get("metadata") From 23fbbaaa82c9b2eb91512f46991cd73061aea1cd Mon Sep 17 00:00:00 2001 From: Vlad Ciobanu <95963142+vl3c@users.noreply.github.com> Date: Tue, 17 Feb 2026 22:54:19 +0200 Subject: [PATCH 2/3] Fix 22 mypy errors in browser type stubs and consuming modules - Add missing DOMNode properties (id, src, outerHTML) and methods (replaceChild, setSelectionRange) to _dom.pyi - Make DOMNode.select() parameter optional for input text selection - Widen Window.setTimeout delay type to accept float - Add explicit type annotation for SvgRenderer._offscreen_surface - Use intermediate typed variables to avoid no-any-return in DrawableManager and TransformationsManager - Add type: ignore comments for intentional monkey-patching in tests - Replace os.getuid() with getattr fallback for Windows compatibility --- server_tests/test_workspace_management.py | 2 +- static/client/client_tests/test_throttle.py | 10 +++++----- static/client/client_tests/test_window_mocks.py | 10 +++++----- static/client/managers/drawable_manager.py | 3 ++- static/client/managers/transformations_manager.py | 3 ++- static/client/rendering/svg_renderer.py | 1 + static/client/typing/browser/__init__.pyi | 2 +- static/client/typing/browser/_dom.pyi | 7 ++++++- 8 files changed, 23 insertions(+), 15 deletions(-) diff --git a/server_tests/test_workspace_management.py b/server_tests/test_workspace_management.py index 77466eee..1ea2dbdf 100644 --- a/server_tests/test_workspace_management.py +++ b/server_tests/test_workspace_management.py @@ -96,7 +96,7 @@ def test_save_workspace_failure(self) -> None: ) self.assertFalse(success, "Save workspace should return False with invalid name") - if os.name != "nt" and os.getuid() != 0: # Skip on Windows and root + if os.name != "nt" and getattr(os, "getuid", lambda: -1)() != 0: # Skip on Windows and root test_dir = os.path.join(WORKSPACES_DIR, TEST_DIR) original_mode = os.stat(test_dir).st_mode try: diff --git a/static/client/client_tests/test_throttle.py b/static/client/client_tests/test_throttle.py index 50f2345f..35675741 100644 --- a/static/client/client_tests/test_throttle.py +++ b/static/client/client_tests/test_throttle.py @@ -32,9 +32,9 @@ def now() -> int: self.original_clearTimeout = browser_window.clearTimeout # Replace the browser window objects - browser_window.performance = self.mock_performance - browser_window.setTimeout = self.mock_window.setTimeout - browser_window.clearTimeout = self.mock_window.clearTimeout + browser_window.performance = self.mock_performance # type: ignore[assignment] + browser_window.setTimeout = self.mock_window.setTimeout # type: ignore[method-assign] + browser_window.clearTimeout = self.mock_window.clearTimeout # type: ignore[method-assign] def set_time(self, new_time: int) -> None: """Helper to update the mock time.""" @@ -43,8 +43,8 @@ def set_time(self, new_time: int) -> None: def tearDown(self) -> None: # Restore original window objects browser_window.performance = self.original_performance - browser_window.setTimeout = self.original_setTimeout - browser_window.clearTimeout = self.original_clearTimeout + browser_window.setTimeout = self.original_setTimeout # type: ignore[method-assign] + browser_window.clearTimeout = self.original_clearTimeout # type: ignore[method-assign] def test_throttle_first_call_executes_immediately(self) -> None: """Test that the first call to a throttled function executes immediately.""" diff --git a/static/client/client_tests/test_window_mocks.py b/static/client/client_tests/test_window_mocks.py index 3b9ed84d..15b6937b 100644 --- a/static/client/client_tests/test_window_mocks.py +++ b/static/client/client_tests/test_window_mocks.py @@ -25,15 +25,15 @@ def setUp(self) -> None: self.original_clearTimeout = browser_window.clearTimeout # Replace the browser window objects - browser_window.performance = self.mock_performance - browser_window.setTimeout = self.mock_window.setTimeout - browser_window.clearTimeout = self.mock_window.clearTimeout + browser_window.performance = self.mock_performance # type: ignore[assignment] + browser_window.setTimeout = self.mock_window.setTimeout # type: ignore[method-assign] + browser_window.clearTimeout = self.mock_window.clearTimeout # type: ignore[method-assign] def tearDown(self) -> None: # Restore original window objects browser_window.performance = self.original_performance - browser_window.setTimeout = self.original_setTimeout - browser_window.clearTimeout = self.original_clearTimeout + browser_window.setTimeout = self.original_setTimeout # type: ignore[method-assign] + browser_window.clearTimeout = self.original_clearTimeout # type: ignore[method-assign] def test_performance_now(self) -> None: """Test that window.performance.now() returns the correct time and updates properly.""" diff --git a/static/client/managers/drawable_manager.py b/static/client/managers/drawable_manager.py index 2f986365..91130425 100644 --- a/static/client/managers/drawable_manager.py +++ b/static/client/managers/drawable_manager.py @@ -1238,9 +1238,10 @@ def create_perpendicular_from_point( color: Optional[str] = None, ) -> Dict[str, Any]: """Drop a perpendicular from a point to a segment.""" - return self.construction_manager.create_perpendicular_from_point( + result: Dict[str, Any] = self.construction_manager.create_perpendicular_from_point( point_name, segment_name, name=name, color=color ) + return result def create_angle_bisector( self, diff --git a/static/client/managers/transformations_manager.py b/static/client/managers/transformations_manager.py index 8cabc1d9..1824b843 100644 --- a/static/client/managers/transformations_manager.py +++ b/static/client/managers/transformations_manager.py @@ -100,7 +100,8 @@ def _find_drawable_by_name( def _get_class_name(self, drawable: Any) -> str: getter = getattr(drawable, "get_class_name", None) - return getter() if callable(getter) else drawable.__class__.__name__ + name: str = getter() if callable(getter) else drawable.__class__.__name__ + return name def _gather_moved_points(self, drawable: Any) -> List[Any]: get_vertices = getattr(drawable, "get_vertices", None) diff --git a/static/client/rendering/svg_renderer.py b/static/client/rendering/svg_renderer.py index 07f39d34..0f43ca15 100644 --- a/static/client/rendering/svg_renderer.py +++ b/static/client/rendering/svg_renderer.py @@ -198,6 +198,7 @@ def __init__(self, style_config: Optional[Dict[str, Any]] = None, surface_id: st self._telemetry = SvgTelemetry() self._plan_cache: Dict[str, Dict[str, Any]] = {} self._cartesian_cache: Optional[Dict[str, Any]] = None + self._offscreen_surface: Optional[Any] = None self._initialize_plan_state() adapter_surface_id = self._configure_surfaces() self._shared_primitives: SvgPrimitiveAdapter = SvgPrimitiveAdapter( diff --git a/static/client/typing/browser/__init__.pyi b/static/client/typing/browser/__init__.pyi index fb4122e9..00554a8a 100644 --- a/static/client/typing/browser/__init__.pyi +++ b/static/client/typing/browser/__init__.pyi @@ -105,7 +105,7 @@ class Window: URL: _URL # Direct methods - def setTimeout(self, callback: Callable[..., Any], delay: int) -> int: ... + def setTimeout(self, callback: Callable[..., Any], delay: int | float) -> int: ... def clearTimeout(self, timer_id: int) -> None: ... def requestAnimationFrame(self, callback: Callable[..., Any]) -> int: ... diff --git a/static/client/typing/browser/_dom.pyi b/static/client/typing/browser/_dom.pyi index d4bf2ee0..c04513ee 100644 --- a/static/client/typing/browser/_dom.pyi +++ b/static/client/typing/browser/_dom.pyi @@ -24,7 +24,10 @@ class DOMNode: """ # --- Properties --- + id: str + src: str innerHTML: str + outerHTML: str text: str value: Any disabled: bool @@ -52,6 +55,7 @@ class DOMNode: def getAttribute(self, name: str) -> str | None: ... def removeAttribute(self, name: str) -> None: ... def insertBefore(self, new_child: DOMNode, ref_child: DOMNode | None) -> DOMNode: ... + def replaceChild(self, new_child: DOMNode, old_child: DOMNode) -> DOMNode: ... def cloneNode(self, deep: bool = ...) -> DOMNode: ... def bind(self, event: str, handler: Callable[..., Any]) -> None: ... def focus(self) -> None: ... @@ -60,8 +64,9 @@ class DOMNode: def clear(self) -> None: ... def remove(self) -> None: ... def getContext(self, context_type: str) -> Any: ... - def select(self, selector: str) -> list[DOMNode]: ... + def select(self, selector: str = ...) -> list[DOMNode]: ... def select_one(self, selector: str) -> DOMNode | None: ... + def setSelectionRange(self, start: int, end: int) -> None: ... # --- Operators --- def __getitem__(self, key: str) -> Any: ... From fe9fc9b1ce1c549e1f0a9505066fa52bc622aa76 Mon Sep 17 00:00:00 2001 From: vl3c <95963142+vl3c@users.noreply.github.com> Date: Wed, 18 Feb 2026 22:41:46 +0200 Subject: [PATCH 3/3] Fix tool-search model params and restore client startup --- mypy.ini | 3 +- server_tests/test_browser_typing_stubs.py | 10 ++-- server_tests/test_tool_search_service.py | 48 +++++++++++++++++++ .../client/managers/colored_area_manager.py | 25 +++++----- .../browser/__init__.pyi | 0 .../{typing => type_stubs}/browser/_dom.pyi | 0 .../{typing => type_stubs}/browser/aio.pyi | 0 .../{typing => type_stubs}/browser/ajax.pyi | 0 static/tool_search_service.py | 25 +++++----- 9 files changed, 80 insertions(+), 31 deletions(-) rename static/client/{typing => type_stubs}/browser/__init__.pyi (100%) rename static/client/{typing => type_stubs}/browser/_dom.pyi (100%) rename static/client/{typing => type_stubs}/browser/aio.pyi (100%) rename static/client/{typing => type_stubs}/browser/ajax.pyi (100%) diff --git a/mypy.ini b/mypy.ini index 993077ec..d4b1d96a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,7 +11,6 @@ ignore_missing_imports = True implicit_reexport = False incremental = True explicit_package_bases = True -mypy_path = static/client/typing +mypy_path = static/client/type_stubs files = app.py, static/app_manager.py, static/workspace_manager.py, static/log_manager.py, static/openai_api_base.py, static/openai_completions_api.py, static/openai_responses_api.py, static/routes.py, static/tool_call_processor.py, static/ai_model.py, static/webdriver_manager.py, static/functions_definitions.py, run_server_tests.py, server_tests/test_mocks.py, server_tests/test_routes.py, server_tests/test_workspace_management.py, static/client/constants.py, static/client/expression_validator.py, static/client/markdown_parser.py, static/client/main.py, static/client/ai_interface.py, static/client/canvas.py, static/client/canvas_event_handler.py, static/client/cartesian_system_2axis.py, static/client/coordinate_mapper.py, static/client/expression_evaluator.py, static/client/function_registry.py, static/client/process_function_calls.py, static/client/result_processor.py, static/client/result_validator.py, static/client/workspace_manager.py, static/client/utils/math_utils.py, static/client/utils/computation_utils.py, static/client/utils/geometry_utils.py, static/client/utils/style_utils.py, static/client/utils/linear_algebra_utils.py, static/client/name_generator/base.py, static/client/name_generator/drawable.py, static/client/name_generator/function.py, static/client/name_generator/point.py, static/client/managers/undo_redo_manager.py, static/client/managers/transformations_manager.py, static/client/managers/drawable_manager.py, static/client/managers/drawables_container.py, static/client/managers/drawable_manager_proxy.py, static/client/managers/drawable_dependency_manager.py, static/client/managers/point_manager.py, static/client/managers/segment_manager.py, static/client/managers/vector_manager.py, static/client/managers/circle_manager.py, static/client/managers/ellipse_manager.py, static/client/managers/function_manager.py, static/client/managers/colored_area_manager.py, static/client/managers/angle_manager.py, static/client/drawables/position.py, static/client/drawables/drawable.py, static/client/drawables/point.py, static/client/drawables/segment.py, static/client/drawables/vector.py, static/client/drawables/triangle.py, static/client/drawables/rectangle.py, static/client/drawables/circle.py, static/client/drawables/ellipse.py, static/client/drawables/function.py, static/client/drawables/angle.py, static/client/drawables/colored_area.py, static/client/drawables/functions_bounded_colored_area.py, static/client/drawables/segments_bounded_colored_area.py, static/client/drawables/function_segment_bounded_colored_area.py, static/client/test_runner.py, static/client/rendering/interfaces.py, static/client/rendering/primitives.py, static/client/rendering/renderables/function_renderable.py, static/client/rendering/renderables/functions_area_renderable.py, static/client/rendering/renderables/segments_area_renderable.py, static/client/rendering/renderables/function_segment_area_renderable.py, static/client/rendering/svg_renderer.py, static/client/client_tests/test_angle.py, static/client/client_tests/test_angle_manager.py, static/client/client_tests/test_canvas.py, static/client/client_tests/test_cartesian.py, static/client/client_tests/test_circle.py, static/client/client_tests/test_custom_drawable_names.py, static/client/client_tests/test_drawable_dependency_manager.py, static/client/client_tests/test_drawable_name_generator.py, static/client/client_tests/test_drawables_container.py, static/client/client_tests/test_ellipse.py, static/client/client_tests/test_event_handler.py, static/client/client_tests/test_expression_validator.py, static/client/client_tests/test_function.py, static/client/client_tests/test_function_bounded_colored_area_integration.py, static/client/client_tests/test_function_calling.py, static/client/client_tests/test_linear_algebra_utils.py, static/client/client_tests/test_function_segment_bounded_colored_area.py, static/client/client_tests/test_functions_bounded_colored_area.py, static/client/client_tests/test_math_functions.py, static/client/client_tests/test_point.py, static/client/client_tests/test_rectangle.py, static/client/client_tests/test_segment.py, static/client/client_tests/test_segments_bounded_colored_area.py, static/client/client_tests/test_throttle.py, static/client/client_tests/test_triangle.py, static/client/client_tests/test_vector.py, static/client/client_tests/test_window_mocks.py, static/client/client_tests/ai_result_formatter.py, static/client/client_tests/brython_io.py, static/client/client_tests/simple_mock.py, static/client/client_tests/tests.py, generate_diagrams_launcher.py, scripts/linear_algebra_expected_values.py, diagrams/scripts/utils.py, diagrams/scripts/generate_diagrams.py, diagrams/scripts/generate_arch.py, diagrams/scripts/generate_brython_diagrams.py, diagrams/scripts/setup_diagram_tools.py, static/client/client_tests/__init__.py, static/client/drawables/__init__.py, static/client/managers/__init__.py, static/client/name_generator/__init__.py, static/client/rendering/__init__.py, static/client/utils/__init__.py, server_tests/__init__.py, server_tests/python_path_setup.py, documentation/metrics/project_metrics_analyzer.py, server_tests/test_browser_typing_stubs.py follow_imports = skip - diff --git a/server_tests/test_browser_typing_stubs.py b/server_tests/test_browser_typing_stubs.py index ed791c05..989fec71 100644 --- a/server_tests/test_browser_typing_stubs.py +++ b/server_tests/test_browser_typing_stubs.py @@ -1,6 +1,6 @@ """Acceptance tests for Brython browser module type stubs. -Validates that the .pyi stubs under static/client/typing/browser/ +Validates that the .pyi stubs under static/client/type_stubs/browser/ are syntactically valid, export the expected names, and pass MyPy type-checking for common browser API usage patterns. """ @@ -21,8 +21,8 @@ # --------------------------------------------------------------------------- _ROOT: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -_STUBS_DIR: str = os.path.join(_ROOT, "static", "client", "typing", "browser") -_MYPY_PATH: str = os.path.join(_ROOT, "static", "client", "typing") +_STUBS_DIR: str = os.path.join(_ROOT, "static", "client", "type_stubs", "browser") +_MYPY_PATH: str = os.path.join(_ROOT, "static", "client", "type_stubs") _INIT_PYI: str = os.path.join(_STUBS_DIR, "__init__.pyi") _DOM_PYI: str = os.path.join(_STUBS_DIR, "_dom.pyi") @@ -1391,9 +1391,9 @@ def test_mypy_ini_has_mypy_path(self) -> None: "mypy.ini must contain 'mypy_path' setting", ) self.assertIn( - "static/client/typing", + "static/client/type_stubs", content, - "mypy_path must reference 'static/client/typing'", + "mypy_path must reference 'static/client/type_stubs'", ) def test_mypy_ini_has_test_file(self) -> None: diff --git a/server_tests/test_tool_search_service.py b/server_tests/test_tool_search_service.py index f62ffb01..873cd7a2 100644 --- a/server_tests/test_tool_search_service.py +++ b/server_tests/test_tool_search_service.py @@ -258,6 +258,54 @@ def test_search_uses_correct_model(self, service: ToolSearchService, mock_client call_args = mock_client.chat.completions.create.call_args assert call_args.kwargs.get("model") == "gpt-4.1" + def test_search_openai_reasoning_model_uses_max_completion_tokens( + self, service: ToolSearchService, mock_client: MagicMock + ) -> None: + """OpenAI reasoning models should use max_completion_tokens.""" + self._setup_mock_response(mock_client, '["create_circle"]') + model = AIModel.from_identifier("o4-mini") + + service.search_tools("draw", model=model) + + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs.get("model") == "o4-mini" + assert call_args.kwargs.get("max_completion_tokens") == 500 + assert "max_tokens" not in call_args.kwargs + + def test_search_non_openai_reasoning_model_uses_max_tokens( + self, service: ToolSearchService, mock_client: MagicMock + ) -> None: + """Non-OpenAI reasoning models should keep max_tokens.""" + self._setup_mock_response(mock_client, '["create_circle"]') + model = AIModel( + identifier="reasoning-openrouter-test", + has_vision=False, + is_reasoning_model=True, + provider="openrouter", + ) + + service.search_tools("draw", model=model) + + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs.get("model") == "reasoning-openrouter-test" + assert call_args.kwargs.get("max_tokens") == 500 + assert "max_completion_tokens" not in call_args.kwargs + + def test_search_uses_reasoning_default_model_when_none(self, mock_client: MagicMock) -> None: + """Default OpenAI reasoning model should be used as configured.""" + self._setup_mock_response(mock_client, '["create_circle"]') + service = ToolSearchService( + client=mock_client, + default_model=AIModel.from_identifier("o3"), + ) + + service.search_tools("draw") + + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs.get("model") == "o3" + assert call_args.kwargs.get("max_completion_tokens") == 500 + assert "max_tokens" not in call_args.kwargs + def test_search_uses_default_model_when_none(self, service: ToolSearchService, mock_client: MagicMock) -> None: """search_tools should use gpt-4.1-mini when no model specified.""" self._setup_mock_response(mock_client, '["create_circle"]') diff --git a/static/client/managers/colored_area_manager.py b/static/client/managers/colored_area_manager.py index b961de02..7cd144c0 100644 --- a/static/client/managers/colored_area_manager.py +++ b/static/client/managers/colored_area_manager.py @@ -74,7 +74,6 @@ from name_generator.drawable import DrawableNameGenerator from drawables.rectangle import Rectangle - class ColoredAreaManager: """ Manages colored area drawables for a Canvas. @@ -201,18 +200,16 @@ def get_y_at_x(segment: Segment, x: float) -> float: y = get_y_at_x(drawable1, x2_max) self.drawable_manager.create_point(x2_max, y) - colored_area: Union[ - SegmentsBoundedColoredArea, FunctionSegmentBoundedColoredArea, FunctionsBoundedColoredArea - ] + colored_area: Union[SegmentsBoundedColoredArea, FunctionSegmentBoundedColoredArea, FunctionsBoundedColoredArea] colored_area = SegmentsBoundedColoredArea(drawable1, drawable2, color=color, opacity=opacity) elif isinstance(drawable2, Segment): # Function-segment case (we know drawable1 is not a segment due to the swap above) colored_area = FunctionSegmentBoundedColoredArea(drawable1, drawable2, color=color, opacity=opacity) else: # Function-function case - colored_area = FunctionsBoundedColoredArea( - drawable1, drawable2, left_bound=left_bound, right_bound=right_bound, color=color, opacity=opacity - ) + colored_area = FunctionsBoundedColoredArea(drawable1, drawable2, + left_bound=left_bound, right_bound=right_bound, + color=color, opacity=opacity) # Add to drawables self.drawables.add(colored_area) @@ -687,7 +684,9 @@ def delete_colored_areas_for_segment( def _remove_colored_area_drawable(self, area: "Drawable") -> bool: """Remove a colored-area drawable and clean dependency graph entries.""" - return bool(remove_drawable_with_dependencies(self.drawables, self.dependency_manager, area)) + return bool(remove_drawable_with_dependencies( + self.drawables, self.dependency_manager, area + )) def get_colored_areas_for_drawable(self, drawable: Union[Function, Segment]) -> List["Drawable"]: """ @@ -776,9 +775,7 @@ def update_colored_area( policy = self._get_policy_for_area(area) self._validate_policy(policy, list(pending_fields.keys())) - self._validate_colored_area_payload( - area, pending_fields, new_color, new_opacity, new_left_bound, new_right_bound - ) + self._validate_colored_area_payload(area, pending_fields, new_color, new_opacity, new_left_bound, new_right_bound) self.canvas.undo_redo_manager.archive() self._apply_colored_area_updates(area, pending_fields, new_color, new_opacity, new_left_bound, new_right_bound) @@ -871,7 +868,11 @@ def _validate_colored_area_payload( if "right_bound" in pending_fields and new_right_bound is not None: updated_right = float(new_right_bound) - if updated_left is not None and updated_right is not None and updated_left >= updated_right: + if ( + updated_left is not None + and updated_right is not None + and updated_left >= updated_right + ): raise ValueError("left_bound must be less than right_bound.") def _apply_colored_area_updates( diff --git a/static/client/typing/browser/__init__.pyi b/static/client/type_stubs/browser/__init__.pyi similarity index 100% rename from static/client/typing/browser/__init__.pyi rename to static/client/type_stubs/browser/__init__.pyi diff --git a/static/client/typing/browser/_dom.pyi b/static/client/type_stubs/browser/_dom.pyi similarity index 100% rename from static/client/typing/browser/_dom.pyi rename to static/client/type_stubs/browser/_dom.pyi diff --git a/static/client/typing/browser/aio.pyi b/static/client/type_stubs/browser/aio.pyi similarity index 100% rename from static/client/typing/browser/aio.pyi rename to static/client/type_stubs/browser/aio.pyi diff --git a/static/client/typing/browser/ajax.pyi b/static/client/type_stubs/browser/ajax.pyi similarity index 100% rename from static/client/typing/browser/ajax.pyi rename to static/client/type_stubs/browser/ajax.pyi diff --git a/static/tool_search_service.py b/static/tool_search_service.py index fcca6779..90cfb188 100644 --- a/static/tool_search_service.py +++ b/static/tool_search_service.py @@ -197,14 +197,8 @@ def search_tools( max_results = max(1, min(20, max_results)) # Use provided model, instance default, or fallback to gpt-4.1-mini. - # Tool search is a lightweight classification task — reasoning models - # are overkill and their max_completion_tokens budget includes internal - # reasoning tokens, so 500 tokens may not leave enough room for output. - # For OpenAI reasoning models, downgrade to gpt-4.1-mini automatically. if model is None: model = self.default_model or AIModel.from_identifier("gpt-4.1-mini") - if getattr(model, "is_reasoning_model", False) and getattr(model, "provider", "") == "openai": - model = AIModel.from_identifier("gpt-4.1-mini") # Build the prompt tool_descriptions = self.build_tool_descriptions() @@ -216,12 +210,19 @@ def search_tools( try: # Call the AI model - response = self.client.chat.completions.create( - model=model.id, - messages=[{"role": "user", "content": prompt}], - temperature=0.0, # Deterministic for consistent results - max_tokens=500, # Tool names are short - ) + request_kwargs: Dict[str, Any] = { + "model": model.id, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.0, # Deterministic for consistent results + } + # OpenAI reasoning models in Chat Completions reject max_tokens + # and require max_completion_tokens. + if model.is_reasoning_model and model.provider == "openai": + request_kwargs["max_completion_tokens"] = 500 + else: + request_kwargs["max_tokens"] = 500 + + response = self.client.chat.completions.create(**request_kwargs) # Extract the response content content = response.choices[0].message.content