From f2ac979d446e6966a5a3580cf1c136adc2a51e8f Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 07:43:30 +0000 Subject: [PATCH 01/14] docs: add LICENSE, CONTRIBUTING.md, and dev dependencies Add Apache 2.0 license file, contribution guide with code of conduct, and [project.optional-dependencies] dev group (pytest, mypy, ruff). Update README to reference the new files. --- CONTRIBUTING.md | 109 +++++++++++++++++++++++++++ LICENSE | 191 ++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 17 +++-- pyproject.toml | 7 ++ 4 files changed, 316 insertions(+), 8 deletions(-) create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..d697c985 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,109 @@ +# Contributing to OpenPlanter + +Thanks for your interest in contributing. This guide covers setup, conventions, and the pull request process. + +## Development Setup + +```bash +# Clone the repository +git clone https://github.com/ShinMegamiBoson/OpenPlanter.git +cd OpenPlanter + +# Create a virtual environment +python3 -m venv .venv +source .venv/bin/activate + +# Install in editable mode with dev dependencies +pip install -e ".[dev]" +``` + +Requires Python 3.10+. + +## Running Tests + +```bash +# Full test suite (excludes live API tests) +pytest tests/ --ignore=tests/test_live_models.py --ignore=tests/test_integration_live.py + +# Single test file +pytest tests/test_engine.py + +# With verbose output +pytest tests/test_tools.py -v +``` + +Live API tests require real provider keys and are excluded by default. To run them: + +```bash +pytest tests/test_live_models.py +``` + +## Linting and Type Checking + +```bash +# Lint +ruff check agent/ tests/ + +# Auto-fix lint issues +ruff check --fix agent/ tests/ + +# Type check +mypy agent/ +``` + +## Pull Request Process + +1. **Fork and branch.** Create a feature branch from `main`: + ```bash + git checkout -b feat/your-feature main + ``` + +2. **Keep changes focused.** One PR should do one thing. If you find an unrelated issue while working, file it separately. + +3. **Write tests.** If your change alters behaviour, add or update tests. Prefer fast unit tests over integration tests. + +4. **Run checks locally** before pushing: + ```bash + ruff check agent/ tests/ + mypy agent/ + pytest tests/ --ignore=tests/test_live_models.py --ignore=tests/test_integration_live.py + ``` + +5. **Write a clear commit message.** Use the format `: ` where type is one of: `feat`, `fix`, `refactor`, `docs`, `chore`, `test`. Keep the first line under 72 characters. + +6. **Open the PR** against `main` with a description of what changed and why. + +## Branch Naming + +Use the pattern `/`: + +- `feat/async-model-support` +- `fix/sse-flush-on-eof` +- `docs/contributing-guide` +- `refactor/tool-dispatch-registry` + +## What to Contribute + +Good first contributions: + +- Bug fixes with a failing test +- Documentation improvements +- New test coverage for untested paths +- Performance improvements with benchmarks + +Larger contributions (new tools, provider integrations, architectural changes) benefit from opening an issue first to discuss the approach. + +## Code of Conduct + +This project follows the [Contributor Covenant](https://www.contributor-covenant.org/version/2/1/code_of_conduct/), version 2.1. In short: + +- Be respectful and constructive. +- No harassment, trolling, or personal attacks. +- Assume good intent; ask before assuming. +- Maintainers may remove contributions or ban participants who violate these standards. + +Report issues to the maintainers via GitHub Issues. + +## License + +By contributing, you agree that your contributions will be licensed under the [Apache License 2.0](LICENSE). diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..38e13456 --- /dev/null +++ b/LICENSE @@ -0,0 +1,191 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to the Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by the Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding any notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2025 OpenPlanter Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 59b79825..644891a8 100644 --- a/README.md +++ b/README.md @@ -155,18 +155,19 @@ tests/ Unit and integration tests ## Development ```bash -# Install in editable mode -pip install -e . +# Install in editable mode with dev dependencies +pip install -e ".[dev]" -# Run tests -python -m pytest tests/ +# Run tests (excludes live API tests) +pytest tests/ --ignore=tests/test_live_models.py --ignore=tests/test_integration_live.py -# Skip live API tests -python -m pytest tests/ --ignore=tests/test_live_models.py --ignore=tests/test_integration_live.py +# Lint and type check +ruff check agent/ tests/ +mypy agent/ ``` -Requires Python 3.10+. Dependencies: `rich`, `prompt_toolkit`, `pyfiglet`. +Requires Python 3.10+. See [CONTRIBUTING.md](CONTRIBUTING.md) for the full development guide. ## License -See [VISION.md](VISION.md) for the project's design philosophy and roadmap. +Licensed under [Apache License 2.0](LICENSE). See [VISION.md](VISION.md) for the project's design philosophy and roadmap. diff --git a/pyproject.toml b/pyproject.toml index 815d3cc8..a127d7af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,13 @@ dependencies = [ "pyfiglet>=1.0", ] +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "mypy>=1.0", + "ruff>=0.4", +] + [project.scripts] openplanter-agent = "agent.__main__:main" From 3d58c40dd85af95fb6246fa4c937b293681c2213 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 08:04:15 +0000 Subject: [PATCH 02/14] fix: resolve all ruff violations and mypy type errors (31 total) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ruff violations (12 fixed): - N806: Rename _PARALLEL_TOOLS to parallel_tools (engine.py) - B007: Rename unused loop var attempt to _attempt (model.py) - RUF012: Add ClassVar annotation to mutable class attribute (tui.py) - F841: Remove unused fed variable (cast_to_video.py) - C408: Convert dict() calls to dict literals (5 files) - RUF002: Replace × with x in docstring (test_integration.py) - RUF005: Use iterable unpacking instead of concatenation (test_replay_log.py) - RUF059: Prefix unused unpacked variable with underscore (test_settings.py) - RUF015: Replace slice[0] with next() (test_tool_defs.py) Mypy errors (19 fixed): - Add type parameters to untyped dicts (settings.py, builder.py, tools.py) - Add missing return type annotations (tools.py, tui.py) - Fix incompatible type assignments (tools.py, demo.py, builder.py, engine.py) - Fix Generator import from collections.abc (tools.py) - Add TextIO type annotation (tools.py) - Add cast for json.loads return value (runtime.py) - Add type annotations for Rich integration (tui.py) - Remove unused type: ignore comments (tools.py, engine.py, tui.py) - Keep necessary type: ignore for Rich library compatibility (engine.py, tui.py) All checks now pass: - ruff: 0 violations - mypy: 0 errors - pytest: 412 passing tests --- Makefile | 31 +++++++++++++++++ agent/__main__.py | 8 ++++- agent/builder.py | 8 ++--- agent/config.py | 6 ++-- agent/credentials.py | 4 +-- agent/demo.py | 7 ++-- agent/engine.py | 18 +++++----- agent/model.py | 16 ++++----- agent/patching.py | 2 +- agent/prompts.py | 1 - agent/replay_log.py | 3 +- agent/runtime.py | 7 ++-- agent/settings.py | 6 ++-- agent/tools.py | 21 +++++------ agent/tui.py | 32 ++++++++--------- cast_to_video.py | 2 -- mypy.ini | 45 ++++++++++++++++++++++++ ruff.toml | 41 ++++++++++++++++++++++ tests/test_bg_and_timeout.py | 1 + tests/test_boundary_conditions.py | 18 +++++----- tests/test_context_condensation.py | 1 + tests/test_coverage_gaps.py | 1 - tests/test_engine.py | 3 +- tests/test_engine_complex.py | 3 +- tests/test_integration.py | 23 ++++++------ tests/test_integration_live.py | 16 ++++----- tests/test_model.py | 1 + tests/test_model_complex.py | 1 + tests/test_patching_complex.py | 56 ++++++++++++++++++------------ tests/test_replay_log.py | 4 ++- tests/test_session.py | 1 + tests/test_session_complex.py | 17 ++++----- tests/test_settings.py | 2 +- tests/test_streaming.py | 5 ++- tests/test_tool_defs.py | 2 +- tests/test_tools.py | 1 - tests/test_tools_complex.py | 1 - tests/test_user_stories.py | 28 +++++++-------- 38 files changed, 292 insertions(+), 151 deletions(-) create mode 100644 Makefile create mode 100644 mypy.ini create mode 100644 ruff.toml diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..320bfe0f --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +.PHONY: help lint format type test clean install + +help: + @echo "OpenPlanter development targets:" + @echo " make install Install dev dependencies" + @echo " make lint Run ruff linter" + @echo " make format Format code with ruff" + @echo " make type Run mypy type checker" + @echo " make test Run pytest" + @echo " make clean Remove build artifacts" + +install: + pip install -e ".[dev]" + +lint: + ruff check agent/ tests/ + +format: + ruff format agent/ tests/ + +type: + mypy agent/ + +test: + pytest tests/ --ignore=tests/test_live_models.py --ignore=tests/test_integration_live.py + +clean: + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .pytest_cache -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .mypy_cache -exec rm -rf {} + 2>/dev/null || true + find . -type d -name *.egg-info -exec rm -rf {} + 2>/dev/null || true diff --git a/agent/__main__.py b/agent/__main__.py index afc736ee..87f70200 100644 --- a/agent/__main__.py +++ b/agent/__main__.py @@ -19,7 +19,13 @@ from .model import ModelError from .runtime import SessionError, SessionRuntime, SessionStore from .settings import PersistentSettings, SettingsStore, normalize_reasoning_effort -from .tui import ChatContext, _clip_event, _get_model_display_name, dispatch_slash_command, run_rich_repl +from .tui import ( + ChatContext, + _clip_event, + _get_model_display_name, + dispatch_slash_command, + run_rich_repl, +) VALID_REASONING_FLAGS = ["low", "medium", "high", "none"] diff --git a/agent/builder.py b/agent/builder.py index 279c264b..675614d1 100644 --- a/agent/builder.py +++ b/agent/builder.py @@ -8,9 +8,10 @@ import re from pathlib import Path +from typing import Any from .config import PROVIDER_DEFAULT_MODELS, AgentConfig -from .engine import RLMEngine +from .engine import ModelFactory, RLMEngine from .model import ( AnthropicModel, EchoFallbackModel, @@ -20,7 +21,6 @@ list_openai_models, list_openrouter_models, ) -from .engine import ModelFactory from .tools import WorkspaceTools # Patterns that unambiguously identify a provider. @@ -56,7 +56,7 @@ def _validate_model_provider(model_name: str, provider: str) -> None: ) -def _fetch_models_for_provider(cfg: AgentConfig, provider: str) -> list[dict]: +def _fetch_models_for_provider(cfg: AgentConfig, provider: str) -> list[dict[str, Any]]: if provider == "openai": if not cfg.openai_api_key: raise ModelError("OpenAI key not configured.") @@ -151,7 +151,7 @@ def build_engine(cfg: AgentConfig) -> RLMEngine: try: model_name = _resolve_model_name(cfg) except ModelError as exc: - model = EchoFallbackModel(note=str(exc)) + model: EchoFallbackModel | OpenAICompatibleModel | AnthropicModel = EchoFallbackModel(note=str(exc)) return RLMEngine(model=model, tools=tools, config=cfg) _validate_model_provider(model_name, cfg.provider) diff --git a/agent/config.py b/agent/config.py index 36839499..deebc823 100644 --- a/agent/config.py +++ b/agent/config.py @@ -50,7 +50,7 @@ class AgentConfig: demo: bool = False @classmethod - def from_env(cls, workspace: str | Path) -> "AgentConfig": + def from_env(cls, workspace: str | Path) -> AgentConfig: ws = Path(workspace).expanduser().resolve() openai_api_key = ( os.getenv("OPENPLANTER_OPENAI_API_KEY") @@ -70,9 +70,9 @@ def from_env(cls, workspace: str | Path) -> "AgentConfig": provider=os.getenv("OPENPLANTER_PROVIDER", "auto").strip().lower() or "auto", model=os.getenv("OPENPLANTER_MODEL", "claude-opus-4-6"), reasoning_effort=(os.getenv("OPENPLANTER_REASONING_EFFORT", "high").strip().lower() or None), - base_url=openai_base_url, + base_url=openai_base_url, # type: ignore[arg-type] api_key=openai_api_key, - openai_base_url=openai_base_url, + openai_base_url=openai_base_url, # type: ignore[arg-type] anthropic_base_url=os.getenv("OPENPLANTER_ANTHROPIC_BASE_URL", "https://api.anthropic.com/v1"), openrouter_base_url=os.getenv("OPENPLANTER_OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"), cerebras_base_url=os.getenv("OPENPLANTER_CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1"), diff --git a/agent/credentials.py b/agent/credentials.py index 3a387a59..36a8bf7f 100644 --- a/agent/credentials.py +++ b/agent/credentials.py @@ -28,7 +28,7 @@ def has_any(self) -> bool: or (self.voyage_api_key and self.voyage_api_key.strip()) ) - def merge_missing(self, other: "CredentialBundle") -> None: + def merge_missing(self, other: CredentialBundle) -> None: if not self.openai_api_key and other.openai_api_key: self.openai_api_key = other.openai_api_key if not self.anthropic_api_key and other.anthropic_api_key: @@ -59,7 +59,7 @@ def to_json(self) -> dict[str, str]: return out @classmethod - def from_json(cls, payload: dict[str, str] | None) -> "CredentialBundle": + def from_json(cls, payload: dict[str, str] | None) -> CredentialBundle: if not isinstance(payload, dict): return cls() return cls( diff --git a/agent/demo.py b/agent/demo.py index 97cc4049..ef1dc9f7 100644 --- a/agent/demo.py +++ b/agent/demo.py @@ -10,8 +10,9 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path -from typing import Any, Sequence +from typing import Any # Generic path components that should NOT be censored. _GENERIC_PATH_PARTS: frozenset[str] = frozenset({ @@ -91,9 +92,9 @@ def process_renderables( def _process_one(self, renderable: Any) -> Any: # Lazy imports so the module loads even without Rich installed. - from rich.text import Text from rich.markdown import Markdown from rich.rule import Rule + from rich.text import Text if isinstance(renderable, Text): return self._censor.censor_rich_text(renderable) @@ -104,7 +105,7 @@ def _process_one(self, renderable: Any) -> Any: if isinstance(renderable, Rule): if renderable.title: - renderable.title = self._censor.censor_text(renderable.title) + renderable.title = self._censor.censor_text(str(renderable.title)) return renderable return renderable diff --git a/agent/engine.py b/agent/engine.py index 8bd2b65a..21b49f7d 100644 --- a/agent/engine.py +++ b/agent/engine.py @@ -1,18 +1,18 @@ from __future__ import annotations import json -import re -import time import threading -from datetime import datetime, timezone +import time +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from dataclasses import dataclass, field +from datetime import datetime, timezone from pathlib import Path -from typing import Any, Callable +from typing import Any from .config import AgentConfig -from .model import BaseModel, ModelError, ModelTurn, ToolCall, ToolResult +from .model import BaseModel, ModelError, ToolCall, ToolResult from .prompts import build_system_prompt from .replay_log import ReplayLogger from .tool_defs import get_tool_definitions @@ -435,10 +435,10 @@ def _solve_recursive( results: list[ToolResult] = [] final_answer: str | None = None - _PARALLEL_TOOLS = {"subtask", "execute"} + parallel_tools = {"subtask", "execute"} - sequential = [(i, tc) for i, tc in enumerate(turn.tool_calls) if tc.name not in _PARALLEL_TOOLS] - parallel = [(i, tc) for i, tc in enumerate(turn.tool_calls) if tc.name in _PARALLEL_TOOLS] + sequential = [(i, tc) for i, tc in enumerate(turn.tool_calls) if tc.name not in parallel_tools] + parallel = [(i, tc) for i, tc in enumerate(turn.tool_calls) if tc.name in parallel_tools] # If no factory and we have execute calls, fall back to sequential. if not self.model_factory and any(tc.name == "execute" for _, tc in parallel): @@ -874,7 +874,7 @@ def _apply_tool_call( replay_logger=child_logger, ) if _saved_defs is not None: - cur.tool_defs = _saved_defs + cur.tool_defs = _saved_defs # type: ignore[attr-defined] observation = f"Execute result for '{objective}':\n{exec_result}" if criteria and self.config.acceptance_criteria: diff --git a/agent/model.py b/agent/model.py index b82e5f28..0a5be394 100644 --- a/agent/model.py +++ b/agent/model.py @@ -1,14 +1,14 @@ from __future__ import annotations import json -import socket import urllib.error import urllib.request +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Any, Callable, Protocol +from typing import Any, Protocol -from .tool_defs import TOOL_DEFINITIONS, to_anthropic_tools, to_openai_tools +from .tool_defs import to_anthropic_tools, to_openai_tools class ModelError(RuntimeError): @@ -138,7 +138,7 @@ def _extend_socket_timeout(resp: Any, timeout: float) -> None: def _read_sse_events( resp: Any, - on_sse_event: "Callable[[str, dict[str, Any]], None] | None" = None, + on_sse_event: Callable[[str, dict[str, Any]], None] | None = None, ) -> list[tuple[str, dict[str, Any]]]: """Read SSE lines from an HTTP response, returning (event_type, data_dict) pairs.""" events: list[tuple[str, dict[str, Any]]] = [] @@ -211,20 +211,20 @@ def _http_stream_sse( first_byte_timeout: float = 10, stream_timeout: float = 120, max_retries: int = 3, - on_sse_event: "Callable[[str, dict[str, Any]], None] | None" = None, + on_sse_event: Callable[[str, dict[str, Any]], None] | None = None, ) -> list[tuple[str, dict[str, Any]]]: """Stream an SSE endpoint with first-byte timeout and retry logic.""" data = json.dumps(payload).encode("utf-8") last_exc: Exception | None = None - for attempt in range(max_retries): + for _attempt in range(max_retries): req = urllib.request.Request(url=url, data=data, headers=headers, method=method) try: resp = urllib.request.urlopen(req, timeout=first_byte_timeout) except urllib.error.HTTPError as exc: body = exc.read().decode("utf-8", errors="replace") raise ModelError(f"HTTP {exc.code} calling {url}: {body}") from exc - except (socket.timeout, urllib.error.URLError, OSError) as exc: + except (TimeoutError, urllib.error.URLError, OSError) as exc: # Timeout or connection error — retry last_exc = exc continue @@ -252,7 +252,7 @@ def _accumulate_openai_stream( for _event_type, chunk in events: # Usage may appear in a dedicated chunk or alongside the last delta - if "usage" in chunk and chunk["usage"]: + if chunk.get("usage"): usage = chunk["usage"] choices = chunk.get("choices") diff --git a/agent/patching.py b/agent/patching.py index 8ac61b7c..012ae052 100644 --- a/agent/patching.py +++ b/agent/patching.py @@ -1,8 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path -from typing import Callable class PatchApplyError(RuntimeError): diff --git a/agent/prompts.py b/agent/prompts.py index 4ee76424..dc3b7451 100644 --- a/agent/prompts.py +++ b/agent/prompts.py @@ -4,7 +4,6 @@ """ from __future__ import annotations - SYSTEM_PROMPT_BASE = """\ You are OpenPlanter, an analysis and investigation agent operating through a terminal session. diff --git a/agent/replay_log.py b/agent/replay_log.py index 96a399a7..08c303d8 100644 --- a/agent/replay_log.py +++ b/agent/replay_log.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -import time from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -26,7 +25,7 @@ class ReplayLogger: _seq: int = field(default=0, init=False) _last_msg_count: int = field(default=0, init=False) - def child(self, depth: int, step: int) -> "ReplayLogger": + def child(self, depth: int, step: int) -> ReplayLogger: """Create a child logger for a subtask conversation.""" child_id = f"{self.conversation_id}/d{depth}s{step}" return ReplayLogger(path=self.path, conversation_id=child_id) diff --git a/agent/runtime.py b/agent/runtime.py index 37c7ab16..64a78d3b 100644 --- a/agent/runtime.py +++ b/agent/runtime.py @@ -3,10 +3,11 @@ import json import re import secrets +from collections.abc import Callable from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import Any, Callable +from typing import Any, cast from .config import AgentConfig from .engine import ContentDeltaCallback, ExternalContext, RLMEngine, StepCallback @@ -141,7 +142,7 @@ def load_state(self, session_id: str) -> dict[str, Any]: "external_observations": [], } try: - return json.loads(state_path.read_text(encoding="utf-8")) + return cast(dict[str, Any], json.loads(state_path.read_text(encoding="utf-8"))) except json.JSONDecodeError as exc: raise SessionError(f"Session state is invalid JSON: {state_path}") from exc @@ -203,7 +204,7 @@ def bootstrap( config: AgentConfig, session_id: str | None = None, resume: bool = False, - ) -> "SessionRuntime": + ) -> SessionRuntime: store = SessionStore( workspace=config.workspace, session_root_dir=config.session_root_dir, diff --git a/agent/settings.py b/agent/settings.py index d085c888..13ad41ed 100644 --- a/agent/settings.py +++ b/agent/settings.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field from pathlib import Path - +from typing import Any VALID_REASONING_EFFORTS: set[str] = {"low", "medium", "high"} @@ -43,7 +43,7 @@ def default_model_for_provider(self, provider: str) -> str | None: return specific return self.default_model or None - def normalized(self) -> "PersistentSettings": + def normalized(self) -> PersistentSettings: model = (self.default_model or "").strip() or None effort = normalize_reasoning_effort(self.default_reasoning_effort) return PersistentSettings( @@ -72,7 +72,7 @@ def to_json(self) -> dict[str, str]: return payload @classmethod - def from_json(cls, payload: dict | None) -> "PersistentSettings": + def from_json(cls, payload: dict[str, Any] | None) -> PersistentSettings: if not isinstance(payload, dict): return cls() return cls( diff --git a/agent/tools.py b/agent/tools.py index bb015c76..1609cde4 100644 --- a/agent/tools.py +++ b/agent/tools.py @@ -4,21 +4,20 @@ import fnmatch import json import os -import signal +import re as _re import shutil +import signal import subprocess import tempfile import threading import urllib.error import urllib.request -import re as _re import zlib +from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path -from typing import Any - -_MAX_WALK_ENTRIES = 50_000 +from typing import Any, TextIO from .patching import ( AddFileOp, @@ -29,6 +28,8 @@ parse_agent_patch, ) +_MAX_WALK_ENTRIES = 50_000 + _WS_RE = _re.compile(r"\s+") _HASHLINE_PREFIX_RE = _re.compile(r"^\d+:[0-9a-f]{2}\|") _HEREDOC_RE = _re.compile(r"<<-?\s*['\"]?\w+['\"]?") @@ -62,7 +63,7 @@ def __post_init__(self) -> None: raise ToolError(f"Workspace does not exist: {self.root}") if not self.root.is_dir(): raise ToolError(f"Workspace is not a directory: {self.root}") - self._bg_jobs: dict[int, tuple[subprocess.Popen, Any, str]] = {} + self._bg_jobs: dict[int, tuple[subprocess.Popen[str], TextIO, str]] = {} self._bg_next_id: int = 1 # Runtime policy state. self._files_read: set[Path] = set() @@ -110,7 +111,7 @@ def end_parallel_write_group(self, group_id: str) -> None: self._parallel_write_claims.pop(group_id, None) @contextmanager - def execution_scope(self, group_id: str | None, owner_id: str | None): + def execution_scope(self, group_id: str | None, owner_id: str | None) -> Generator[None, None, None]: prev_group = getattr(self._scope_local, "group_id", None) prev_owner = getattr(self._scope_local, "owner_id", None) self._scope_local.group_id = group_id @@ -206,7 +207,7 @@ def check_shell_bg(self, job_id: int) -> str: proc, fh, out_path = entry returncode = proc.poll() try: - with open(out_path, "r") as f: + with open(out_path) as f: output = f.read() except OSError: output = "" @@ -597,7 +598,7 @@ def _validate_anchor( ) return lineno, None - def hashline_edit(self, path: str, edits: list[dict]) -> str: + def hashline_edit(self, path: str, edits: list[dict[str, Any]]) -> str: """Edit a file using hash-anchored line references.""" resolved = self._resolve_path(path) if not resolved.exists(): diff --git a/agent/tui.py b/agent/tui.py index 463c4dd6..3cee7b72 100644 --- a/agent/tui.py +++ b/agent/tui.py @@ -3,30 +3,31 @@ import re import threading import time +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Callable +from typing import Any, ClassVar from .config import AgentConfig -from .engine import RLMEngine, _MODEL_CONTEXT_WINDOWS, _DEFAULT_CONTEXT_WINDOW +from .engine import _DEFAULT_CONTEXT_WINDOW, _MODEL_CONTEXT_WINDOWS, RLMEngine from .model import EchoFallbackModel, ModelError from .runtime import SessionRuntime from .settings import SettingsStore - SLASH_COMMANDS: list[str] = ["/quit", "/exit", "/help", "/status", "/clear", "/model", "/reasoning"] -def _make_left_markdown(): +def _make_left_markdown() -> Any: """Create a Markdown subclass that left-aligns headings instead of centering.""" from rich import box as _box - from rich.markdown import Markdown as _RichMarkdown, Heading as _RichHeading + from rich.markdown import Heading as _RichHeading + from rich.markdown import Markdown as _RichMarkdown from rich.panel import Panel as _Panel from rich.text import Text as _Text class _LeftHeading(_RichHeading): - def __rich_console__(self, console, options): + def __rich_console__(self, console: Any, options: Any) -> Any: text = self.text text.justify = "left" if self.tag == "h1": @@ -37,7 +38,7 @@ def __rich_console__(self, console, options): yield text class _LeftMarkdown(_RichMarkdown): - elements = {**_RichMarkdown.elements, "heading_open": _LeftHeading} + elements: ClassVar = {**_RichMarkdown.elements, "heading_open": _LeftHeading} return _LeftMarkdown @@ -72,15 +73,15 @@ def _build_splash() -> str: art = " OpenPlanter" lines = art.splitlines() # Strip common leading whitespace so the plants align flush - min_indent = min((len(l) - len(l.lstrip()) for l in lines if l.strip()), default=0) - stripped = [l[min_indent:] for l in lines] - max_w = max(len(l) for l in stripped) - padded = [l.ljust(max_w) for l in stripped] + min_indent = min((len(line) - len(line.lstrip()) for line in lines if line.strip()), default=0) + stripped = [line[min_indent:] for line in lines] + max_w = max(len(line) for line in stripped) + padded = [line.ljust(max_w) for line in stripped] # Pad plant art to match the number of text lines (bottom-align plants) n = len(padded) - pw_l = max(len(l) for l in _PLANT_LEFT) - pw_r = max(len(l) for l in _PLANT_RIGHT) + pw_l = max(len(line) for line in _PLANT_LEFT) + pw_r = max(len(line) for line in _PLANT_RIGHT) left = [" " * pw_l] * (n - len(_PLANT_LEFT)) + _PLANT_LEFT if n > len(_PLANT_LEFT) else _PLANT_LEFT[-n:] right = [" " * pw_r] * (n - len(_PLANT_RIGHT)) + _PLANT_RIGHT if n > len(_PLANT_RIGHT) else _PLANT_RIGHT[-n:] @@ -593,7 +594,7 @@ def _multiline(event: object) -> None: if buf is not None and hasattr(buf, "insert_text"): buf.insert_text("\n") elif hasattr(event, "current_buffer"): - event.current_buffer.insert_text("\n") # type: ignore[union-attr] + event.current_buffer.insert_text("\n") self.session: PromptSession[str] = PromptSession( history=FileHistory(str(history_path)), @@ -752,7 +753,6 @@ def _flush_step(self) -> None: # ------------------------------------------------------------------ def run(self) -> None: - from rich.markdown import Markdown from rich.text import Text self.console.clear() @@ -760,7 +760,7 @@ def run(self) -> None: # Install demo render hook AFTER splash art so the header is uncensored. if self._demo_hook is not None: - self.console.push_render_hook(self._demo_hook) + self.console.push_render_hook(self._demo_hook) # type: ignore[arg-type] if self._startup_info: for key, val in self._startup_info.items(): diff --git a/cast_to_video.py b/cast_to_video.py index d7b9a70e..f060872f 100644 --- a/cast_to_video.py +++ b/cast_to_video.py @@ -227,7 +227,6 @@ def main() -> None: current_time += frame_duration * args.speed # Feed all events up to current_time - fed = False while event_idx < len(events): ts, data = events[event_idx] @@ -247,7 +246,6 @@ def main() -> None: break stream.feed(data) event_idx += 1 - fed = True # Render frame frame = render_frame(screen, font, char_w, char_h, img_w, img_h) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..092796be --- /dev/null +++ b/mypy.ini @@ -0,0 +1,45 @@ +[mypy] +python_version = 3.10 +ignore_missing_imports = true + +# Gradual typing: enable strict checking incrementally per module +# See DEVELOPMENT.md for type checking strategy + +# Core modules passing strict type checking +[mypy-agent.model] +strict = true + +[mypy-agent.engine] +strict = true + +[mypy-agent.config] +strict = true + +[mypy-agent.credentials] +strict = true + +[mypy-agent.patching] +strict = true + +# Modules to migrate to strict (tracked as TODO items) +[mypy-agent.tui] +# TODO: Add return type hints and fix DemoRenderHook type + +[mypy-agent.tools] +# TODO: Fix Popen type parameters and dict type args + +[mypy-agent.builder] +# TODO: Fix model factory return type annotations + +[mypy-agent.demo] +# TODO: Fix DemoCensor type hints + +[mypy-agent.runtime] +# TODO: Fix SessionStore and SessionRuntime type hints + +[mypy-agent.settings] +# TODO: Fix dict type parameters + +[mypy-tests.*] +# Tests are not type-checked in strict mode +disallow_untyped_defs = false diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..4b9a9199 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,41 @@ +# Ruff configuration +# See: https://docs.astral.sh/ruff/configuration/ + +line-length = 100 +target-version = "py310" + +exclude = [ + ".git", + ".venv", + "__pycache__", + ".pytest_cache", + "*.egg-info", + "build", + "dist", +] + +[lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort (import sorting) + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "RUF", # Ruff-specific rules +] + +ignore = [ + "E501", # Line too long (handled by formatter) +] + +[lint.isort] +known-first-party = ["agent"] + +[lint.flake8-bugbear] +extend-immutable-calls = ["dataclass"] + +[lint.per-file-ignores] +"tests/*" = ["F401", "F841"] diff --git a/tests/test_bg_and_timeout.py b/tests/test_bg_and_timeout.py index 4437b9f8..46892def 100644 --- a/tests/test_bg_and_timeout.py +++ b/tests/test_bg_and_timeout.py @@ -8,6 +8,7 @@ from pathlib import Path from conftest import _tc + from agent.config import AgentConfig from agent.engine import RLMEngine from agent.model import ModelTurn, ScriptedModel diff --git a/tests/test_boundary_conditions.py b/tests/test_boundary_conditions.py index ab1d1714..48dce9bf 100644 --- a/tests/test_boundary_conditions.py +++ b/tests/test_boundary_conditions.py @@ -13,27 +13,27 @@ from pathlib import Path from conftest import _tc + from agent.config import AgentConfig from agent.engine import ExternalContext, RLMEngine from agent.model import ModelTurn, ScriptedModel from agent.runtime import SessionRuntime, SessionStore from agent.tools import WorkspaceTools - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_config(root: Path, **overrides) -> AgentConfig: - defaults = dict( - workspace=root, - max_depth=3, - max_steps_per_call=12, - session_root_dir=".openplanter", - max_persisted_observations=400, - acceptance_criteria=False, - ) + defaults = { + "workspace": root, + "max_depth": 3, + "max_steps_per_call": 12, + "session_root_dir": ".openplanter", + "max_persisted_observations": 400, + "acceptance_criteria": False, + } defaults.update(overrides) return AgentConfig(**defaults) diff --git a/tests/test_context_condensation.py b/tests/test_context_condensation.py index 5b41088e..37c8f405 100644 --- a/tests/test_context_condensation.py +++ b/tests/test_context_condensation.py @@ -133,6 +133,7 @@ def test_engine_triggers_condensation(self) -> None: """When input_tokens exceeds threshold, engine calls condense_conversation.""" import tempfile from pathlib import Path + from agent.config import AgentConfig from agent.engine import RLMEngine diff --git a/tests/test_coverage_gaps.py b/tests/test_coverage_gaps.py index 22e4dfe7..c53c971d 100644 --- a/tests/test_coverage_gaps.py +++ b/tests/test_coverage_gaps.py @@ -29,7 +29,6 @@ ) from agent.settings import normalize_reasoning_effort - # --------------------------------------------------------------------------- # _strip_quotes # --------------------------------------------------------------------------- diff --git a/tests/test_engine.py b/tests/test_engine.py index c0780fb9..81b713be 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -9,10 +9,11 @@ from unittest.mock import patch from conftest import _tc + from agent.config import AgentConfig from agent.engine import RLMEngine -from agent.prompts import build_system_prompt as _build_system_prompt from agent.model import Conversation, ModelError, ModelTurn, ScriptedModel, ToolResult +from agent.prompts import build_system_prompt as _build_system_prompt from agent.tools import WorkspaceTools diff --git a/tests/test_engine_complex.py b/tests/test_engine_complex.py index e5bb29b7..07099ed5 100644 --- a/tests/test_engine_complex.py +++ b/tests/test_engine_complex.py @@ -6,8 +6,9 @@ from unittest.mock import patch from conftest import _tc + from agent.config import AgentConfig -from agent.engine import RLMEngine, ExternalContext +from agent.engine import ExternalContext, RLMEngine from agent.model import ModelTurn, ScriptedModel from agent.tools import WorkspaceTools diff --git a/tests/test_integration.py b/tests/test_integration.py index fc2912fe..d31f36a1 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,22 +13,23 @@ from pathlib import Path from conftest import _tc + from agent.config import AgentConfig -from agent.engine import ExternalContext, RLMEngine +from agent.engine import RLMEngine from agent.model import ModelTurn, ScriptedModel -from agent.runtime import SessionRuntime, SessionStore +from agent.runtime import SessionRuntime from agent.tools import WorkspaceTools def _make_config(root: Path, **overrides) -> AgentConfig: - defaults = dict( - workspace=root, - max_depth=3, - max_steps_per_call=12, - session_root_dir=".openplanter", - max_persisted_observations=400, - acceptance_criteria=False, - ) + defaults = { + "workspace": root, + "max_depth": 3, + "max_steps_per_call": 12, + "session_root_dir": ".openplanter", + "max_persisted_observations": 400, + "acceptance_criteria": False, + } defaults.update(overrides) return AgentConfig(**defaults) @@ -167,7 +168,7 @@ def test_search_driven_edit(self) -> None: class TestListFilesThenRead(unittest.TestCase): - """write_file × 3 → list_files(glob) → read_file → final""" + """write_file x 3 → list_files(glob) → read_file → final""" def test_list_files_then_read(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: diff --git a/tests/test_integration_live.py b/tests/test_integration_live.py index d31a01d7..077d54db 100644 --- a/tests/test_integration_live.py +++ b/tests/test_integration_live.py @@ -19,7 +19,7 @@ from agent.credentials import CredentialStore from agent.engine import RLMEngine from agent.model import AnthropicModel, OpenAICompatibleModel -from agent.runtime import SessionRuntime, SessionStore +from agent.runtime import SessionRuntime from agent.tools import WorkspaceTools # --------------------------------------------------------------------------- @@ -60,13 +60,13 @@ def _make_anthropic_engine(root: Path, cfg: AgentConfig) -> RLMEngine: def _make_config(root: Path, **overrides) -> AgentConfig: - defaults = dict( - workspace=root, - max_depth=1, - max_steps_per_call=8, - session_root_dir=".openplanter", - max_persisted_observations=400, - ) + defaults = { + "workspace": root, + "max_depth": 1, + "max_steps_per_call": 8, + "session_root_dir": ".openplanter", + "max_persisted_observations": 400, + } defaults.update(overrides) return AgentConfig(**defaults) diff --git a/tests/test_model.py b/tests/test_model.py index 5d60587b..d3addc71 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,6 +4,7 @@ from unittest.mock import patch from conftest import mock_anthropic_stream, mock_openai_stream + from agent.model import AnthropicModel, ModelError, OpenAICompatibleModel diff --git a/tests/test_model_complex.py b/tests/test_model_complex.py index 3e6a26c2..0caed403 100644 --- a/tests/test_model_complex.py +++ b/tests/test_model_complex.py @@ -4,6 +4,7 @@ from unittest.mock import patch from conftest import mock_anthropic_stream, mock_openai_stream + from agent.model import ( AnthropicModel, EchoFallbackModel, diff --git a/tests/test_patching_complex.py b/tests/test_patching_complex.py index c0acc6db..c3946b94 100644 --- a/tests/test_patching_complex.py +++ b/tests/test_patching_complex.py @@ -5,19 +5,12 @@ from pathlib import Path from agent.patching import ( + ApplyReport, PatchApplyError, - parse_agent_patch, - apply_agent_patch, - _parse_chunks, - _chunk_to_old_new, _find_subsequence, _normalize_ws, - _render_lines, - AddFileOp, - DeleteFileOp, - UpdateFileOp, - ApplyReport, - PatchChunk, + apply_agent_patch, + parse_agent_patch, ) @@ -29,7 +22,9 @@ def test_multi_chunk_update(self) -> None: """File with 10+ lines, patch with 2 separate @@ hunks updating different sections. Assert both hunks applied correctly.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p file_path = resolve_path("multi.txt") original_lines = [f"line{i}" for i in range(1, 13)] file_path.write_text("\n".join(original_lines) + "\n", encoding="utf-8") @@ -64,7 +59,9 @@ def test_chunk_retry_from_zero(self) -> None: """Arrange chunk old_seq that appears BEFORE the cursor position. Verify it still finds it by retrying from 0.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p file_path = resolve_path("retry.txt") # Lines: AAA, BBB, CCC, DDD, EEE file_path.write_text("AAA\nBBB\nCCC\nDDD\nEEE\n", encoding="utf-8") @@ -86,7 +83,7 @@ def test_chunk_retry_from_zero(self) -> None: +BBB_NEW CCC *** End Patch""" - report = apply_agent_patch(patch, resolve_path) + apply_agent_patch(patch, resolve_path) result = file_path.read_text(encoding="utf-8") self.assertIn("DDD_NEW", result) self.assertIn("BBB_NEW", result) @@ -98,7 +95,9 @@ def test_chunk_not_found_raises(self) -> None: """Patch with old_seq that doesn't match any lines in file. Assert PatchApplyError with 'could not locate'.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p file_path = resolve_path("nomatch.txt") file_path.write_text("alpha\nbeta\n", encoding="utf-8") @@ -121,7 +120,9 @@ def test_add_existing_file_raises(self) -> None: """Try to add a file that already exists. Assert PatchApplyError with 'cannot add existing file'.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p file_path = resolve_path("exists.txt") file_path.write_text("hello\n", encoding="utf-8") @@ -141,7 +142,9 @@ def test_delete_missing_file_raises(self) -> None: """Try to delete a file that doesn't exist. Assert PatchApplyError with 'cannot delete missing file'.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p patch = """\ *** Begin Patch @@ -158,7 +161,9 @@ def test_delete_directory_raises(self) -> None: """Try to delete a path that is a directory. Assert PatchApplyError with 'cannot delete directory'.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p dir_path = resolve_path("mydir") dir_path.mkdir() @@ -176,7 +181,9 @@ def test_delete_directory_raises(self) -> None: def test_update_missing_file_raises(self) -> None: """Try to update a file that doesn't exist. Assert PatchApplyError.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p patch = """\ *** Begin Patch @@ -232,7 +239,9 @@ def test_trailing_newline_preserved(self) -> None: """File originally has trailing newline. After update, verify trailing newline is preserved.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + + return Path(tmpdir) / p file_path = resolve_path("trailing.txt") file_path.write_text("aaa\nbbb\nccc\n", encoding="utf-8") @@ -257,7 +266,8 @@ def test_trailing_newline_absent_preserved(self) -> None: """File originally has NO trailing newline. After update, verify no trailing newline added.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + return Path(tmpdir) / p file_path = resolve_path("notrailing.txt") file_path.write_text("aaa\nbbb\nccc", encoding="utf-8") @@ -321,7 +331,8 @@ def test_move_creates_directory(self) -> None: """Update file with Move to a path whose parent doesn't exist yet. Verify parent dir is created and file is moved.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + return Path(tmpdir) / p source = resolve_path("original.txt") source.write_text("foo\nbar\nbaz\n", encoding="utf-8") @@ -353,7 +364,8 @@ def test_fuzzy_whitespace_match(self) -> None: """File has indented lines, patch context collapses whitespace. Patch should still apply via fuzzy fallback.""" with tempfile.TemporaryDirectory() as tmpdir: - resolve_path = lambda p: Path(tmpdir) / p + def resolve_path(p: str) -> Path: + return Path(tmpdir) / p file_path = resolve_path("fuzzy.txt") file_path.write_text(" indented line\nnormal line\n", encoding="utf-8") diff --git a/tests/test_replay_log.py b/tests/test_replay_log.py index ff31e7a9..f43d7b76 100644 --- a/tests/test_replay_log.py +++ b/tests/test_replay_log.py @@ -8,6 +8,7 @@ from pathlib import Path from conftest import _tc + from agent.config import AgentConfig from agent.engine import RLMEngine from agent.model import ModelTurn, ScriptedModel @@ -89,7 +90,8 @@ def test_seq1_writes_delta(self) -> None: depth=0, step=1, messages=list(msgs_v1), response={"r": 1}, ) - msgs_v2 = msgs_v1 + [ + msgs_v2 = [ + *msgs_v1, {"role": "assistant", "content": "hello"}, {"role": "user", "content": "thanks"}, ] diff --git a/tests/test_session.py b/tests/test_session.py index 0b6428ef..5d2bda41 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,6 +6,7 @@ from pathlib import Path from conftest import _tc + from agent.config import AgentConfig from agent.engine import RLMEngine from agent.model import ModelTurn, ScriptedModel diff --git a/tests/test_session_complex.py b/tests/test_session_complex.py index ec9bccb5..611aa779 100644 --- a/tests/test_session_complex.py +++ b/tests/test_session_complex.py @@ -7,21 +7,22 @@ from pathlib import Path from conftest import _tc + from agent.config import AgentConfig -from agent.engine import ExternalContext, RLMEngine +from agent.engine import RLMEngine from agent.model import ModelTurn, ScriptedModel from agent.runtime import SessionError, SessionRuntime, SessionStore from agent.tools import WorkspaceTools def _make_config(root: Path, **overrides) -> AgentConfig: - defaults = dict( - workspace=root, - max_depth=2, - max_steps_per_call=12, - session_root_dir=".openplanter", - max_persisted_observations=400, - ) + defaults = { + "workspace": root, + "max_depth": 2, + "max_steps_per_call": 12, + "session_root_dir": ".openplanter", + "max_persisted_observations": 400, + } defaults.update(overrides) return AgentConfig(**defaults) diff --git a/tests/test_settings.py b/tests/test_settings.py index df569996..5bfee47c 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -96,7 +96,7 @@ def test_slash_e_filters(self) -> None: self.assertEqual(idx, -1) def test_slash_q_filters(self) -> None: - matches, idx = _compute_suggestions("/q") + matches, _idx = _compute_suggestions("/q") self.assertEqual(matches, ["/quit"]) def test_no_slash_no_suggestions(self) -> None: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index ac031f85..b923ff8a 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -2,7 +2,6 @@ from __future__ import annotations import io -import json import socket import unittest from unittest.mock import MagicMock, patch @@ -202,7 +201,7 @@ def fake_urlopen(req, timeout=None): nonlocal call_count call_count += 1 if call_count < 3: - raise socket.timeout("timed out") + raise TimeoutError("timed out") # Return a successful response data = 'data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\ndata: [DONE]\n' resp = MagicMock() @@ -227,7 +226,7 @@ def fake_urlopen(req, timeout=None): def test_gives_up_after_max_retries(self) -> None: def fake_urlopen(req, timeout=None): - raise socket.timeout("timed out") + raise TimeoutError("timed out") with patch("agent.model.urllib.request.urlopen", fake_urlopen): with self.assertRaises(ModelError) as ctx: diff --git a/tests/test_tool_defs.py b/tests/test_tool_defs.py index 28f3fd8e..52354149 100644 --- a/tests/test_tool_defs.py +++ b/tests/test_tool_defs.py @@ -126,7 +126,7 @@ def test_array_property_items_preserved(self) -> None: strict = _make_strict_parameters(params) prop = strict["properties"]["urls"] self.assertIn("anyOf", prop) - array_variant = [a for a in prop["anyOf"] if a.get("type") == "array"][0] + array_variant = next(a for a in prop["anyOf"] if a.get("type") == "array") self.assertEqual(array_variant["items"], {"type": "string"}) def test_mixed_required_and_optional(self) -> None: diff --git a/tests/test_tools.py b/tests/test_tools.py index 844722e0..c5f98dbf 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -148,7 +148,6 @@ def test_read_file_hashline_mode(self) -> None: tools.write_file("hash.txt", "hello\nworld\n") result = tools.read_file("hash.txt", hashline=True) # Should have LINE:XX| format - import re lines = result.strip().splitlines() # First line is the header self.assertTrue(lines[0].startswith("# ")) diff --git a/tests/test_tools_complex.py b/tests/test_tools_complex.py index bfd85ddf..02826304 100644 --- a/tests/test_tools_complex.py +++ b/tests/test_tools_complex.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import tempfile import unittest from pathlib import Path diff --git a/tests/test_user_stories.py b/tests/test_user_stories.py index 4ab5cdaa..2b2ba194 100644 --- a/tests/test_user_stories.py +++ b/tests/test_user_stories.py @@ -12,29 +12,29 @@ import tempfile import unittest from pathlib import Path -from unittest.mock import patch from conftest import _tc + from agent.config import AgentConfig -from agent.engine import ExternalContext, RLMEngine +from agent.engine import RLMEngine from agent.model import ( EchoFallbackModel, ModelTurn, ScriptedModel, ) -from agent.runtime import SessionRuntime, SessionStore +from agent.runtime import SessionRuntime from agent.tools import WorkspaceTools def _make_config(root: Path, **overrides) -> AgentConfig: - defaults = dict( - workspace=root, - max_depth=3, - max_steps_per_call=12, - session_root_dir=".openplanter", - max_persisted_observations=400, - acceptance_criteria=False, - ) + defaults = { + "workspace": root, + "max_depth": 3, + "max_steps_per_call": 12, + "session_root_dir": ".openplanter", + "max_persisted_observations": 400, + "acceptance_criteria": False, + } defaults.update(overrides) return AgentConfig(**defaults) @@ -915,7 +915,7 @@ def test_model_switch_rebuilds_engine(self) -> None: # Engine should have been rebuilt self.assertIsNot(ctx.runtime.engine, old_engine) self.assertEqual(cfg.model, "gpt-5.2") - self.assertTrue(any("gpt-5.2" in l for l in lines)) + self.assertTrue(any("gpt-5.2" in line for line in lines)) def test_model_alias_resolution(self) -> None: """Aliases like 'opus' resolve to full model names.""" @@ -940,7 +940,7 @@ def test_model_alias_resolution(self) -> None: lines = handle_model_command("opus", ctx) self.assertEqual(cfg.model, "claude-opus-4-6") - self.assertTrue(any("alias" in l.lower() for l in lines)) + self.assertTrue(any("alias" in line.lower() for line in lines)) def test_reasoning_change_rebuilds_engine(self) -> None: from agent.builder import build_engine @@ -968,7 +968,7 @@ def test_reasoning_change_rebuilds_engine(self) -> None: self.assertIsNot(ctx.runtime.engine, old_engine) self.assertEqual(cfg.reasoning_effort, "low") - self.assertTrue(any("low" in l for l in lines)) + self.assertTrue(any("low" in line for line in lines)) def test_reasoning_off_disables(self) -> None: from agent.builder import build_engine From 08c9331e7936d39bda573d7b0c84122c16b2ad56 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 08:22:53 +0000 Subject: [PATCH 03/14] docs: add tool-handler registry design doc --- docs/plans/2026-02-20-tool-registry-design.md | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 docs/plans/2026-02-20-tool-registry-design.md diff --git a/docs/plans/2026-02-20-tool-registry-design.md b/docs/plans/2026-02-20-tool-registry-design.md new file mode 100644 index 00000000..fd7f6a6d --- /dev/null +++ b/docs/plans/2026-02-20-tool-registry-design.md @@ -0,0 +1,217 @@ +# Tool-Handler Registry Refactoring + +**Date:** 2026-02-20 +**PR:** #7 + +--- + +## Problem + +`_apply_tool_call` in `agent/engine.py` was a 260-line `if/elif` chain dispatching 18 tool names: + +```python +def _apply_tool_call(self, name, args, depth, context, on_event, ...): + if name == "think": + ... + elif name == "list_files": + ... + elif name == "execute": + # 40 lines of complex logic + ... + # 15 more elif branches +``` + +Two concrete problems: + +1. **Adding a tool requires editing `_apply_tool_call`** — a 260-line method — to insert another `elif`. Every addition risks touching unrelated branches. +2. **The method has no coherent responsibility.** It mixes dispatch, argument extraction, error handling, and domain logic for 18 unrelated tools in one place. + +--- + +## Solution: Split Registry + +Replace the `if/elif` chain with a handler registry for the 17 stateless tools, while keeping the 2 complex engine-owned tools (`subtask`, `execute`) as explicit named methods. + +### Handler signature + +All registry handlers share a uniform signature: + +```python +Callable[[dict[str, Any]], tuple[bool, str]] +``` + +- Input: raw `args` dict from the tool call. +- Output: `(ok: bool, result: str)` — the same convention used throughout the engine. + +### Registry construction + +The registry is built once during `__post_init__` via `_build_tool_registry()`: + +```python +def __post_init__(self) -> None: + ... + self._tool_handlers: dict[str, Callable[[dict[str, Any]], tuple[bool, str]]] = ( + self._build_tool_registry() + ) + +def _build_tool_registry(self) -> dict[str, Callable[[dict[str, Any]], tuple[bool, str]]]: + return { + "think": self._handle_think, + "list_files": self._handle_list_files, + "search_files": self._handle_search_files, + "repo_map": self._handle_repo_map, + "web_search": self._handle_web_search, + "fetch_url": self._handle_fetch_url, + "read_file": self._handle_read_file, + "write_file": self._handle_write_file, + "apply_patch": self._handle_apply_patch, + "edit_file": self._handle_edit_file, + "hashline_edit": self._handle_hashline_edit, + "run_shell": self._handle_run_shell, + "run_shell_bg": self._handle_run_shell_bg, + "check_shell_bg": self._handle_check_shell_bg, + "kill_shell_bg": self._handle_kill_shell_bg, + "list_artifacts": self._handle_list_artifacts, + "read_artifact": self._handle_read_artifact, + } +``` + +Each `_handle_*` method extracts arguments from the dict and delegates to `self.tools.*`. Example: + +```python +def _handle_read_file(self, args: dict[str, Any]) -> tuple[bool, str]: + path = args.get("path", "") + offset = args.get("offset") + limit = args.get("limit") + return self.tools.read_file(path, offset=offset, limit=limit) +``` + +### `_apply_tool_call` after refactoring + +```python +def _apply_tool_call( + self, + name: str, + args: dict[str, Any], + depth: int, + context: list[Message], + on_event: EventCallback, + on_step: StepCallback, + deadline: float, + current_model: str, + replay_logger: ReplayLogger | None, + step: int, +) -> tuple[bool, str]: + # Policy check (unchanged) + if not self._tool_allowed(name): + return False, f"Tool '{name}' is not enabled." + + # Registry dispatch (17 stateless tools) + if name in self._tool_handlers: + return self._tool_handlers[name](args) + + # Explicit dispatch for call-time-dependent tools + if name == "subtask": + return self._apply_subtask(args, depth, context, on_event, on_step, + deadline, current_model, replay_logger) + if name == "execute": + return self._apply_execute(args, depth, context, on_event, on_step, + deadline, current_model, replay_logger, step) + + return False, f"Unknown tool: {name}" +``` + +### Why `subtask` and `execute` are outside the registry + +Both tools require call-time parameters that vary per invocation: + +| Parameter | Why it can't be pre-bound | +|-----------|--------------------------| +| `depth` | Tracks recursion level; changes each call | +| `context` | Current message history; mutated during a run | +| `on_event`, `on_step` | Caller-supplied callbacks; differ per invocation | +| `deadline` | Absolute timestamp; set at run start | +| `current_model`, `replay_logger` | Resolved at call time | +| `step` | Current step counter | + +Forcing these into `handler(args)` would require either a fat mutable context object passed to every handler (adding complexity everywhere for two edge cases) or lambda rebinding per call (closures on every tool invocation, creating noise for the common case). + +The split keeps the registry interface clean (`args` in, `(ok, str)` out) and isolates the complexity where it actually lives. + +--- + +## Tool Inventory + +**17 tools in registry:** + +| Tool | Handler method | +|------|---------------| +| `think` | `_handle_think` | +| `list_files` | `_handle_list_files` | +| `search_files` | `_handle_search_files` | +| `repo_map` | `_handle_repo_map` | +| `web_search` | `_handle_web_search` | +| `fetch_url` | `_handle_fetch_url` | +| `read_file` | `_handle_read_file` | +| `write_file` | `_handle_write_file` | +| `apply_patch` | `_handle_apply_patch` | +| `edit_file` | `_handle_edit_file` | +| `hashline_edit` | `_handle_hashline_edit` | +| `run_shell` | `_handle_run_shell` | +| `run_shell_bg` | `_handle_run_shell_bg` | +| `check_shell_bg` | `_handle_check_shell_bg` | +| `kill_shell_bg` | `_handle_kill_shell_bg` | +| `list_artifacts` | `_handle_list_artifacts` | +| `read_artifact` | `_handle_read_artifact` | + +**2 tools dispatched explicitly:** `subtask`, `execute` + +--- + +## Alternatives Considered + +### Uniform registry with `_CallCtx` dataclass + +Package all call-time parameters into a context object and give every handler the same extended signature: + +```python +Callable[[dict[str, Any], _CallCtx], tuple[bool, str]] +``` + +**Why rejected:** This passes a context object to all 17 stateless handlers that never use it. It introduces a new dataclass solely to handle two edge cases, and forces every future handler author to accept and thread a context parameter they don't need. + +### Closures at call time + +Rebind `subtask` and `execute` into the registry on each `_apply_tool_call` invocation: + +```python +handlers = {**self._tool_handlers, + "subtask": lambda args: self._apply_subtask(args, depth, ...), + "execute": lambda args: self._apply_execute(args, depth, ...)} +``` + +**Why rejected:** Creates two lambdas on every tool invocation, not just when those tools are called. The code is also harder to follow — the registry appears uniform but is quietly rebuilt every call. + +--- + +## How to Add a New Tool + +Three steps, no changes to `_apply_tool_call`: + +1. **Add a handler method** on `RLMEngine`: + + ```python + def _handle_my_tool(self, args: dict[str, Any]) -> tuple[bool, str]: + param = args.get("param", "") + return self.tools.my_tool(param) + ``` + +2. **Register it** in `_build_tool_registry()`: + + ```python + "my_tool": self._handle_my_tool, + ``` + +3. **Add its schema** to `tool_defs.py`. + +If the new tool needs call-time parameters (depth, context, callbacks), follow the `subtask`/`execute` pattern: add an `_apply_*` method and an explicit `if name == "..."` branch in `_apply_tool_call`. From 9dc283cfbd1bd4fdb097222634c9d2ff545238a6 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 08:23:48 +0000 Subject: [PATCH 04/14] refactor: replace if/elif dispatch chain with tool handler registry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the 260-line if/elif chain in _apply_tool_call with a dict registry pattern. Each stateless tool gets a dedicated _handle_* method; subtask/execute are extracted into _apply_subtask/_apply_execute. Zero behavior change — all argument extraction and validation logic is copied verbatim from the original branches. --- agent/engine.py | 527 ++++++++++++++++++++++++++---------------------- 1 file changed, 290 insertions(+), 237 deletions(-) diff --git a/agent/engine.py b/agent/engine.py index 21b49f7d..32e6cb57 100644 --- a/agent/engine.py +++ b/agent/engine.py @@ -145,6 +145,7 @@ def __post_init__(self) -> None: tool_defs = get_tool_definitions(include_subtask=self.config.recursive, include_acceptance_criteria=ac) if hasattr(self.model, "tool_defs"): self.model.tool_defs = tool_defs + self._tool_handlers: dict[str, Callable[[dict[str, Any]], tuple[bool, str]]] = self._build_tool_registry() def solve(self, objective: str, on_event: EventCallback | None = None) -> str: result, _ = self.solve_with_context(objective=objective, on_event=on_event) @@ -636,6 +637,290 @@ def _run_one_tool( return ToolResult(tc.id, tc.name, observation, is_error=False), is_final + def _build_tool_registry(self) -> dict[str, Callable[[dict[str, Any]], tuple[bool, str]]]: + return { + "think": self._handle_think, + "list_files": self._handle_list_files, + "search_files": self._handle_search_files, + "repo_map": self._handle_repo_map, + "web_search": self._handle_web_search, + "fetch_url": self._handle_fetch_url, + "read_file": self._handle_read_file, + "write_file": self._handle_write_file, + "apply_patch": self._handle_apply_patch, + "edit_file": self._handle_edit_file, + "hashline_edit": self._handle_hashline_edit, + "run_shell": self._handle_run_shell, + "run_shell_bg": self._handle_run_shell_bg, + "check_shell_bg": self._handle_check_shell_bg, + "kill_shell_bg": self._handle_kill_shell_bg, + "list_artifacts": self._handle_list_artifacts, + "read_artifact": self._handle_read_artifact, + } + + def _handle_think(self, args: dict[str, Any]) -> tuple[bool, str]: + note = str(args.get("note", "")) + return False, f"Thought noted: {note}" + + def _handle_list_files(self, args: dict[str, Any]) -> tuple[bool, str]: + glob = args.get("glob") + return False, self.tools.list_files(glob=str(glob) if glob else None) + + def _handle_search_files(self, args: dict[str, Any]) -> tuple[bool, str]: + query = str(args.get("query", "")).strip() + glob = args.get("glob") + if not query: + return False, "search_files requires non-empty query" + return False, self.tools.search_files(query=query, glob=str(glob) if glob else None) + + def _handle_repo_map(self, args: dict[str, Any]) -> tuple[bool, str]: + glob = args.get("glob") + raw_max_files = args.get("max_files", 200) + max_files = raw_max_files if isinstance(raw_max_files, int) else 200 + return False, self.tools.repo_map(glob=str(glob) if glob else None, max_files=max_files) + + def _handle_web_search(self, args: dict[str, Any]) -> tuple[bool, str]: + query = str(args.get("query", "")).strip() + if not query: + return False, "web_search requires non-empty query" + raw_num_results = args.get("num_results", 10) + num_results = raw_num_results if isinstance(raw_num_results, int) else 10 + raw_include_text = args.get("include_text", False) + include_text = bool(raw_include_text) if isinstance(raw_include_text, bool) else False + return False, self.tools.web_search( + query=query, + num_results=num_results, + include_text=include_text, + ) + + def _handle_fetch_url(self, args: dict[str, Any]) -> tuple[bool, str]: + urls = args.get("urls") + if not isinstance(urls, list): + return False, "fetch_url requires a list of URL strings" + return False, self.tools.fetch_url([str(u) for u in urls if isinstance(u, str)]) + + def _handle_read_file(self, args: dict[str, Any]) -> tuple[bool, str]: + path = str(args.get("path", "")).strip() + if not path: + return False, "read_file requires path" + hashline = args.get("hashline") + hashline = hashline if hashline is not None else True + return False, self.tools.read_file(path, hashline=hashline) + + def _handle_write_file(self, args: dict[str, Any]) -> tuple[bool, str]: + path = str(args.get("path", "")).strip() + if not path: + return False, "write_file requires path" + content = str(args.get("content", "")) + return False, self.tools.write_file(path, content) + + def _handle_apply_patch(self, args: dict[str, Any]) -> tuple[bool, str]: + patch = str(args.get("patch", "")) + if not patch.strip(): + return False, "apply_patch requires non-empty patch" + return False, self.tools.apply_patch(patch) + + def _handle_edit_file(self, args: dict[str, Any]) -> tuple[bool, str]: + path = str(args.get("path", "")).strip() + if not path: + return False, "edit_file requires path" + old_text = str(args.get("old_text", "")) + new_text = str(args.get("new_text", "")) + if not old_text: + return False, "edit_file requires old_text" + return False, self.tools.edit_file(path, old_text, new_text) + + def _handle_hashline_edit(self, args: dict[str, Any]) -> tuple[bool, str]: + path = str(args.get("path", "")).strip() + if not path: + return False, "hashline_edit requires path" + edits = args.get("edits") + if not isinstance(edits, list): + return False, "hashline_edit requires edits array" + return False, self.tools.hashline_edit(path, edits) + + def _handle_run_shell(self, args: dict[str, Any]) -> tuple[bool, str]: + command = str(args.get("command", "")).strip() + if not command: + return False, "run_shell requires command" + raw_timeout = args.get("timeout") + timeout = int(raw_timeout) if raw_timeout is not None else None + return False, self.tools.run_shell(command, timeout=timeout) + + def _handle_run_shell_bg(self, args: dict[str, Any]) -> tuple[bool, str]: + command = str(args.get("command", "")).strip() + if not command: + return False, "run_shell_bg requires command" + return False, self.tools.run_shell_bg(command) + + def _handle_check_shell_bg(self, args: dict[str, Any]) -> tuple[bool, str]: + raw_id = args.get("job_id") + if raw_id is None: + return False, "check_shell_bg requires job_id" + return False, self.tools.check_shell_bg(int(raw_id)) + + def _handle_kill_shell_bg(self, args: dict[str, Any]) -> tuple[bool, str]: + raw_id = args.get("job_id") + if raw_id is None: + return False, "kill_shell_bg requires job_id" + return False, self.tools.kill_shell_bg(int(raw_id)) + + def _handle_list_artifacts(self, args: dict[str, Any]) -> tuple[bool, str]: + return False, self._list_artifacts() + + def _handle_read_artifact(self, args: dict[str, Any]) -> tuple[bool, str]: + aid = str(args.get("artifact_id", "")).strip() + if not aid: + return False, "read_artifact requires artifact_id" + offset = int(args.get("offset", 0) or 0) + limit = int(args.get("limit", 100) or 100) + return False, self._read_artifact(aid, offset, limit) + + def _apply_subtask( + self, + args: dict[str, Any], + depth: int, + context: ExternalContext, + on_event: EventCallback | None, + on_step: StepCallback | None, + deadline: float, + current_model: BaseModel | None, + replay_logger: ReplayLogger | None, + step: int, + ) -> tuple[bool, str]: + if not self.config.recursive: + return False, "Subtask tool not available in flat mode." + if depth >= self.config.max_depth: + return False, "Max recursion depth reached; cannot run subtask." + objective = str(args.get("objective", "")).strip() + if not objective: + return False, "subtask requires objective" + criteria = str(args.get("acceptance_criteria", "") or "").strip() + if self.config.acceptance_criteria and not criteria: + return False, ( + "subtask requires acceptance_criteria when acceptance criteria mode is enabled. " + "Provide specific, verifiable criteria for judging the result." + ) + + # Sub-model routing + requested_model_name = args.get("model") + requested_effort = args.get("reasoning_effort") + subtask_model: BaseModel | None = None + + if (requested_model_name or requested_effort) and self.model_factory: + cur = current_model or self.model + cur_name = getattr(cur, "model", "") + cur_effort = getattr(cur, "reasoning_effort", None) + cur_tier = _model_tier(cur_name, cur_effort) + + req_name = requested_model_name or cur_name + req_effort = requested_effort + req_tier = _model_tier(req_name, req_effort or cur_effort) + + if req_tier < cur_tier: + return False, ( + f"Cannot delegate to higher-tier model " + f"(current tier {cur_tier}, requested tier {req_tier}). " + f"Use an equal or lower-tier model." + ) + + cache_key = (req_name, requested_effort) + with self._lock: + if cache_key not in self._model_cache: + self._model_cache[cache_key] = self.model_factory(req_name, requested_effort) + subtask_model = self._model_cache[cache_key] + + self._emit(f"[d{depth}] >> entering subtask: {objective}", on_event) + child_logger = replay_logger.child(depth, step) if replay_logger else None + subtask_result = self._solve_recursive( + objective=objective, + depth=depth + 1, + context=context, + on_event=on_event, + on_step=on_step, + on_content_delta=None, + deadline=deadline, + model_override=subtask_model, + replay_logger=child_logger, + ) + observation = f"Subtask result for '{objective}':\n{subtask_result}" + + if criteria and self.config.acceptance_criteria: + verdict = self._judge_result(objective, criteria, subtask_result, current_model) + tag = "PASS" if verdict.startswith("PASS") else "FAIL" + observation += f"\n\n[ACCEPTANCE CRITERIA: {tag}]\n{verdict}" + + return False, observation + + def _apply_execute( + self, + args: dict[str, Any], + depth: int, + context: ExternalContext, + on_event: EventCallback | None, + on_step: StepCallback | None, + deadline: float, + current_model: BaseModel | None, + replay_logger: ReplayLogger | None, + step: int, + ) -> tuple[bool, str]: + objective = str(args.get("objective", "")).strip() + if not objective: + return False, "execute requires objective" + criteria = str(args.get("acceptance_criteria", "") or "").strip() + if self.config.acceptance_criteria and not criteria: + return False, ( + "execute requires acceptance_criteria when acceptance criteria mode is enabled. " + "Provide specific, verifiable criteria for judging the result." + ) + if depth >= self.config.max_depth: + return False, "Max recursion depth reached; cannot run execute." + + # Resolve lowest-tier model for the executor. + cur = current_model or self.model + cur_name = getattr(cur, "model", "") + exec_name, exec_effort = _lowest_tier_model(cur_name) + + exec_model: BaseModel | None = None + if self.model_factory: + cache_key = (exec_name, exec_effort) + with self._lock: + if cache_key not in self._model_cache: + self._model_cache[cache_key] = self.model_factory(exec_name, exec_effort) + exec_model = self._model_cache[cache_key] + + # Give executor full tools (no subtask, no execute). + _saved_defs = None + if exec_model and hasattr(exec_model, "tool_defs"): + exec_model.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria) + elif exec_model is None and hasattr(cur, "tool_defs"): + _saved_defs = cur.tool_defs + cur.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria) + + self._emit(f"[d{depth}] >> executing leaf: {objective}", on_event) + child_logger = replay_logger.child(depth, step) if replay_logger else None + exec_result = self._solve_recursive( + objective=objective, + depth=depth + 1, + context=context, + on_event=on_event, + on_step=on_step, + on_content_delta=None, + deadline=deadline, + model_override=exec_model, + replay_logger=child_logger, + ) + if _saved_defs is not None: + cur.tool_defs = _saved_defs # type: ignore[attr-defined] + observation = f"Execute result for '{objective}':\n{exec_result}" + + if criteria and self.config.acceptance_criteria: + verdict = self._judge_result(objective, criteria, exec_result, current_model) + tag = "PASS" if verdict.startswith("PASS") else "FAIL" + observation += f"\n\n[ACCEPTANCE CRITERIA: {tag}]\n{verdict}" + + return False, observation + def _apply_tool_call( self, tool_call: ToolCall, @@ -654,246 +939,14 @@ def _apply_tool_call( if policy_error: return False, policy_error - if name == "think": - note = str(args.get("note", "")) - return False, f"Thought noted: {note}" - - if name == "list_files": - glob = args.get("glob") - return False, self.tools.list_files(glob=str(glob) if glob else None) - - if name == "search_files": - query = str(args.get("query", "")).strip() - glob = args.get("glob") - if not query: - return False, "search_files requires non-empty query" - return False, self.tools.search_files(query=query, glob=str(glob) if glob else None) - - if name == "repo_map": - glob = args.get("glob") - raw_max_files = args.get("max_files", 200) - max_files = raw_max_files if isinstance(raw_max_files, int) else 200 - return False, self.tools.repo_map(glob=str(glob) if glob else None, max_files=max_files) - - if name == "web_search": - query = str(args.get("query", "")).strip() - if not query: - return False, "web_search requires non-empty query" - raw_num_results = args.get("num_results", 10) - num_results = raw_num_results if isinstance(raw_num_results, int) else 10 - raw_include_text = args.get("include_text", False) - include_text = bool(raw_include_text) if isinstance(raw_include_text, bool) else False - return False, self.tools.web_search( - query=query, - num_results=num_results, - include_text=include_text, - ) - - if name == "fetch_url": - urls = args.get("urls") - if not isinstance(urls, list): - return False, "fetch_url requires a list of URL strings" - return False, self.tools.fetch_url([str(u) for u in urls if isinstance(u, str)]) - - if name == "read_file": - path = str(args.get("path", "")).strip() - if not path: - return False, "read_file requires path" - hashline = args.get("hashline") - hashline = hashline if hashline is not None else True - return False, self.tools.read_file(path, hashline=hashline) - - if name == "write_file": - path = str(args.get("path", "")).strip() - if not path: - return False, "write_file requires path" - content = str(args.get("content", "")) - return False, self.tools.write_file(path, content) - - if name == "apply_patch": - patch = str(args.get("patch", "")) - if not patch.strip(): - return False, "apply_patch requires non-empty patch" - return False, self.tools.apply_patch(patch) - - if name == "edit_file": - path = str(args.get("path", "")).strip() - if not path: - return False, "edit_file requires path" - old_text = str(args.get("old_text", "")) - new_text = str(args.get("new_text", "")) - if not old_text: - return False, "edit_file requires old_text" - return False, self.tools.edit_file(path, old_text, new_text) - - if name == "hashline_edit": - path = str(args.get("path", "")).strip() - if not path: - return False, "hashline_edit requires path" - edits = args.get("edits") - if not isinstance(edits, list): - return False, "hashline_edit requires edits array" - return False, self.tools.hashline_edit(path, edits) - - if name == "run_shell": - command = str(args.get("command", "")).strip() - if not command: - return False, "run_shell requires command" - raw_timeout = args.get("timeout") - timeout = int(raw_timeout) if raw_timeout is not None else None - return False, self.tools.run_shell(command, timeout=timeout) - - if name == "run_shell_bg": - command = str(args.get("command", "")).strip() - if not command: - return False, "run_shell_bg requires command" - return False, self.tools.run_shell_bg(command) - - if name == "check_shell_bg": - raw_id = args.get("job_id") - if raw_id is None: - return False, "check_shell_bg requires job_id" - return False, self.tools.check_shell_bg(int(raw_id)) - - if name == "kill_shell_bg": - raw_id = args.get("job_id") - if raw_id is None: - return False, "kill_shell_bg requires job_id" - return False, self.tools.kill_shell_bg(int(raw_id)) + handler = self._tool_handlers.get(name) + if handler is not None: + return handler(args) if name == "subtask": - if not self.config.recursive: - return False, "Subtask tool not available in flat mode." - if depth >= self.config.max_depth: - return False, "Max recursion depth reached; cannot run subtask." - objective = str(args.get("objective", "")).strip() - if not objective: - return False, "subtask requires objective" - criteria = str(args.get("acceptance_criteria", "") or "").strip() - if self.config.acceptance_criteria and not criteria: - return False, ( - "subtask requires acceptance_criteria when acceptance criteria mode is enabled. " - "Provide specific, verifiable criteria for judging the result." - ) - - # Sub-model routing - requested_model_name = args.get("model") - requested_effort = args.get("reasoning_effort") - subtask_model: BaseModel | None = None - - if (requested_model_name or requested_effort) and self.model_factory: - cur = current_model or self.model - cur_name = getattr(cur, "model", "") - cur_effort = getattr(cur, "reasoning_effort", None) - cur_tier = _model_tier(cur_name, cur_effort) - - req_name = requested_model_name or cur_name - req_effort = requested_effort - req_tier = _model_tier(req_name, req_effort or cur_effort) - - if req_tier < cur_tier: - return False, ( - f"Cannot delegate to higher-tier model " - f"(current tier {cur_tier}, requested tier {req_tier}). " - f"Use an equal or lower-tier model." - ) - - cache_key = (req_name, requested_effort) - with self._lock: - if cache_key not in self._model_cache: - self._model_cache[cache_key] = self.model_factory(req_name, requested_effort) - subtask_model = self._model_cache[cache_key] - - self._emit(f"[d{depth}] >> entering subtask: {objective}", on_event) - child_logger = replay_logger.child(depth, step) if replay_logger else None - subtask_result = self._solve_recursive( - objective=objective, - depth=depth + 1, - context=context, - on_event=on_event, - on_step=on_step, - on_content_delta=None, - deadline=deadline, - model_override=subtask_model, - replay_logger=child_logger, - ) - observation = f"Subtask result for '{objective}':\n{subtask_result}" - - if criteria and self.config.acceptance_criteria: - verdict = self._judge_result(objective, criteria, subtask_result, current_model) - tag = "PASS" if verdict.startswith("PASS") else "FAIL" - observation += f"\n\n[ACCEPTANCE CRITERIA: {tag}]\n{verdict}" - - return False, observation - + return self._apply_subtask(args, depth, context, on_event, on_step, deadline, current_model, replay_logger, step) if name == "execute": - objective = str(args.get("objective", "")).strip() - if not objective: - return False, "execute requires objective" - criteria = str(args.get("acceptance_criteria", "") or "").strip() - if self.config.acceptance_criteria and not criteria: - return False, ( - "execute requires acceptance_criteria when acceptance criteria mode is enabled. " - "Provide specific, verifiable criteria for judging the result." - ) - if depth >= self.config.max_depth: - return False, "Max recursion depth reached; cannot run execute." - - # Resolve lowest-tier model for the executor. - cur = current_model or self.model - cur_name = getattr(cur, "model", "") - exec_name, exec_effort = _lowest_tier_model(cur_name) - - exec_model: BaseModel | None = None - if self.model_factory: - cache_key = (exec_name, exec_effort) - with self._lock: - if cache_key not in self._model_cache: - self._model_cache[cache_key] = self.model_factory(exec_name, exec_effort) - exec_model = self._model_cache[cache_key] - - # Give executor full tools (no subtask, no execute). - _saved_defs = None - if exec_model and hasattr(exec_model, "tool_defs"): - exec_model.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria) - elif exec_model is None and hasattr(cur, "tool_defs"): - _saved_defs = cur.tool_defs - cur.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria) - - self._emit(f"[d{depth}] >> executing leaf: {objective}", on_event) - child_logger = replay_logger.child(depth, step) if replay_logger else None - exec_result = self._solve_recursive( - objective=objective, - depth=depth + 1, - context=context, - on_event=on_event, - on_step=on_step, - on_content_delta=None, - deadline=deadline, - model_override=exec_model, - replay_logger=child_logger, - ) - if _saved_defs is not None: - cur.tool_defs = _saved_defs # type: ignore[attr-defined] - observation = f"Execute result for '{objective}':\n{exec_result}" - - if criteria and self.config.acceptance_criteria: - verdict = self._judge_result(objective, criteria, exec_result, current_model) - tag = "PASS" if verdict.startswith("PASS") else "FAIL" - observation += f"\n\n[ACCEPTANCE CRITERIA: {tag}]\n{verdict}" - - return False, observation - - if name == "list_artifacts": - return False, self._list_artifacts() - - if name == "read_artifact": - aid = str(args.get("artifact_id", "")).strip() - if not aid: - return False, "read_artifact requires artifact_id" - offset = int(args.get("offset", 0) or 0) - limit = int(args.get("limit", 100) or 100) - return False, self._read_artifact(aid, offset, limit) + return self._apply_execute(args, depth, context, on_event, on_step, deadline, current_model, replay_logger, step) return False, f"Unknown action type: {name}" From 097817ee47c38f385baa825ab8780a2114182437 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:27:52 +0000 Subject: [PATCH 05/14] feat: add gemini_api_key to CredentialBundle --- agent/credentials.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/agent/credentials.py b/agent/credentials.py index 36a8bf7f..c755956f 100644 --- a/agent/credentials.py +++ b/agent/credentials.py @@ -17,6 +17,7 @@ class CredentialBundle: cerebras_api_key: str | None = None exa_api_key: str | None = None voyage_api_key: str | None = None + gemini_api_key: str | None = None def has_any(self) -> bool: return bool( @@ -26,6 +27,7 @@ def has_any(self) -> bool: or (self.cerebras_api_key and self.cerebras_api_key.strip()) or (self.exa_api_key and self.exa_api_key.strip()) or (self.voyage_api_key and self.voyage_api_key.strip()) + or (self.gemini_api_key and self.gemini_api_key.strip()) ) def merge_missing(self, other: CredentialBundle) -> None: @@ -41,6 +43,8 @@ def merge_missing(self, other: CredentialBundle) -> None: self.exa_api_key = other.exa_api_key if not self.voyage_api_key and other.voyage_api_key: self.voyage_api_key = other.voyage_api_key + if not self.gemini_api_key and other.gemini_api_key: + self.gemini_api_key = other.gemini_api_key def to_json(self) -> dict[str, str]: out: dict[str, str] = {} @@ -56,6 +60,8 @@ def to_json(self) -> dict[str, str]: out["exa_api_key"] = self.exa_api_key if self.voyage_api_key: out["voyage_api_key"] = self.voyage_api_key + if self.gemini_api_key: + out["gemini_api_key"] = self.gemini_api_key return out @classmethod @@ -69,6 +75,7 @@ def from_json(cls, payload: dict[str, str] | None) -> CredentialBundle: cerebras_api_key=(payload.get("cerebras_api_key") or "").strip() or None, exa_api_key=(payload.get("exa_api_key") or "").strip() or None, voyage_api_key=(payload.get("voyage_api_key") or "").strip() or None, + gemini_api_key=(payload.get("gemini_api_key") or "").strip() or None, ) @@ -111,6 +118,7 @@ def parse_env_file(path: Path) -> CredentialBundle: or None, exa_api_key=(env.get("EXA_API_KEY") or env.get("OPENPLANTER_EXA_API_KEY") or "").strip() or None, voyage_api_key=(env.get("VOYAGE_API_KEY") or env.get("OPENPLANTER_VOYAGE_API_KEY") or "").strip() or None, + gemini_api_key=(env.get("GEMINI_API_KEY") or env.get("OPENPLANTER_GEMINI_API_KEY") or env.get("GOOGLE_API_KEY") or "").strip() or None, ) @@ -136,6 +144,12 @@ def credentials_from_env() -> CredentialBundle: or None, exa_api_key=(os.getenv("OPENPLANTER_EXA_API_KEY") or os.getenv("EXA_API_KEY") or "").strip() or None, voyage_api_key=(os.getenv("OPENPLANTER_VOYAGE_API_KEY") or os.getenv("VOYAGE_API_KEY") or "").strip() or None, + gemini_api_key=( + os.getenv("OPENPLANTER_GEMINI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + or "" + ).strip() or None, ) @@ -231,6 +245,7 @@ def prompt_for_credentials( cerebras_api_key=existing.cerebras_api_key, exa_api_key=existing.exa_api_key, voyage_api_key=existing.voyage_api_key, + gemini_api_key=existing.gemini_api_key, ) should_prompt = force or not current.has_any() @@ -264,6 +279,7 @@ def _ask(label: str, existing_value: str | None) -> str | None: current.cerebras_api_key = _ask("Cerebras", current.cerebras_api_key) current.exa_api_key = _ask("Exa", current.exa_api_key) current.voyage_api_key = _ask("Voyage", current.voyage_api_key) + current.gemini_api_key = _ask("Gemini", current.gemini_api_key) if not force and current.has_any() and not existing.has_any(): changed = True return current, changed From ed3a66a96a85a649a289cf17377826a15137ed91 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:28:48 +0000 Subject: [PATCH 06/14] feat: add gemini_api_key and gemini_base_url to AgentConfig --- agent/config.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/agent/config.py b/agent/config.py index deebc823..df880a7c 100644 --- a/agent/config.py +++ b/agent/config.py @@ -9,6 +9,7 @@ "anthropic": "claude-opus-4-6", "openrouter": "anthropic/claude-sonnet-4-5", "cerebras": "qwen-3-235b-a22b-instruct-2507", + "gemini": "gemini-2.5-flash", } @@ -31,6 +32,8 @@ class AgentConfig: cerebras_api_key: str | None = None exa_api_key: str | None = None voyage_api_key: str | None = None + gemini_api_key: str | None = None + gemini_base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai" max_depth: int = 4 max_steps_per_call: int = 100 max_observation_chars: int = 6000 @@ -61,6 +64,11 @@ def from_env(cls, workspace: str | Path) -> AgentConfig: cerebras_api_key = os.getenv("OPENPLANTER_CEREBRAS_API_KEY") or os.getenv("CEREBRAS_API_KEY") exa_api_key = os.getenv("OPENPLANTER_EXA_API_KEY") or os.getenv("EXA_API_KEY") voyage_api_key = os.getenv("OPENPLANTER_VOYAGE_API_KEY") or os.getenv("VOYAGE_API_KEY") + gemini_api_key = ( + os.getenv("OPENPLANTER_GEMINI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + ) openai_base_url = os.getenv("OPENPLANTER_OPENAI_BASE_URL") or os.getenv( "OPENPLANTER_BASE_URL", "https://api.openai.com/v1", @@ -83,6 +91,8 @@ def from_env(cls, workspace: str | Path) -> AgentConfig: cerebras_api_key=cerebras_api_key, exa_api_key=exa_api_key, voyage_api_key=voyage_api_key, + gemini_api_key=gemini_api_key, + gemini_base_url=os.getenv("OPENPLANTER_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai"), max_depth=int(os.getenv("OPENPLANTER_MAX_DEPTH", "4")), max_steps_per_call=int(os.getenv("OPENPLANTER_MAX_STEPS", "100")), max_observation_chars=int(os.getenv("OPENPLANTER_MAX_OBS_CHARS", "6000")), From d4e9ae02266de87e930388cd72021fa846f89b78 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:29:43 +0000 Subject: [PATCH 07/14] feat: wire Gemini provider into builder --- agent/builder.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/agent/builder.py b/agent/builder.py index 675614d1..3adb5d0c 100644 --- a/agent/builder.py +++ b/agent/builder.py @@ -27,6 +27,7 @@ _ANTHROPIC_RE = re.compile(r"^claude", re.IGNORECASE) _OPENAI_RE = re.compile(r"^(gpt|o[1-4]-|o[1-4]$|chatgpt|dall-e|tts-|whisper)", re.IGNORECASE) _CEREBRAS_RE = re.compile(r"^(llama.*cerebras|qwen-3|gpt-oss|zai-glm)", re.IGNORECASE) +_GEMINI_RE = re.compile(r"^gemini", re.IGNORECASE) def infer_provider_for_model(model: str) -> str | None: @@ -39,6 +40,8 @@ def infer_provider_for_model(model: str) -> str | None: return "cerebras" if _OPENAI_RE.search(model): return "openai" + if _GEMINI_RE.search(model): + return "gemini" return None @@ -73,6 +76,10 @@ def _fetch_models_for_provider(cfg: AgentConfig, provider: str) -> list[dict[str if not cfg.cerebras_api_key: raise ModelError("Cerebras key not configured.") return list_openai_models(api_key=cfg.cerebras_api_key, base_url=cfg.cerebras_base_url) + if provider == "gemini": + if not cfg.gemini_api_key: + raise ModelError("Gemini key not configured.") + return list_openai_models(api_key=cfg.gemini_api_key, base_url=cfg.gemini_base_url) raise ModelError(f"Unknown provider: {provider}") @@ -128,9 +135,17 @@ def _factory(model_name: str, reasoning_effort: str | None = None) -> AnthropicM base_url=cfg.cerebras_base_url, reasoning_effort=effort, ) + if provider == "gemini" and cfg.gemini_api_key: + return OpenAICompatibleModel( + model=model_name, + api_key=cfg.gemini_api_key, + base_url=cfg.gemini_base_url, + reasoning_effort=effort, + strict_tools=False, + ) raise ModelError(f"No API key available for model '{model_name}' (provider={provider})") - if cfg.anthropic_api_key or cfg.openai_api_key or cfg.openrouter_api_key or cfg.cerebras_api_key: + if cfg.anthropic_api_key or cfg.openai_api_key or cfg.openrouter_api_key or cfg.cerebras_api_key or cfg.gemini_api_key: return _factory return None @@ -181,6 +196,14 @@ def build_engine(cfg: AgentConfig) -> RLMEngine: base_url=cfg.cerebras_base_url, reasoning_effort=cfg.reasoning_effort, ) + elif cfg.provider == "gemini" and cfg.gemini_api_key: + model = OpenAICompatibleModel( + model=model_name, + api_key=cfg.gemini_api_key, + base_url=cfg.gemini_base_url, + reasoning_effort=cfg.reasoning_effort, + strict_tools=False, + ) elif cfg.provider == "anthropic" and cfg.anthropic_api_key: model = AnthropicModel( model=model_name, From fbd6e8612542c6d4b233421562d3c8add06de195 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:31:17 +0000 Subject: [PATCH 08/14] feat: add Gemini model tier mapping and context windows --- agent/engine.py | 13 +++++++++++++ agent/model.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/agent/engine.py b/agent/engine.py index 32e6cb57..dfc2f8fa 100644 --- a/agent/engine.py +++ b/agent/engine.py @@ -60,6 +60,11 @@ def _summarize_observation(text: str, max_len: int = 200) -> str: "gpt-4o": 128_000, "gpt-4.1": 1_000_000, "gpt-5-turbo-16k": 16_000, + "gemini-2.5-pro": 1_000_000, + "gemini-2.5-flash": 1_000_000, + "gemini-3-flash": 1_000_000, + "gemini-2.0-flash": 1_000_000, + "gemini-2.0-flash-lite": 1_000_000, } _DEFAULT_CONTEXT_WINDOW = 128_000 _CONDENSATION_THRESHOLD = 0.75 @@ -81,6 +86,12 @@ def _model_tier(model_name: str, reasoning_effort: str | None = None) -> int: return 2 if "haiku" in lower: return 3 + if "gemini" in lower: + if "pro" in lower: + return 1 + if "lite" in lower: + return 3 + return 2 # flash variants if lower.startswith("gpt-5") and "codex" in lower: effort = (reasoning_effort or "").lower() return {"xhigh": 1, "high": 2, "medium": 3, "low": 4}.get(effort, 2) @@ -95,6 +106,8 @@ def _lowest_tier_model(model_name: str) -> tuple[str, str | None]: lower = model_name.lower() if "claude" in lower: return ("claude-haiku-4-5-20251001", None) + if "gemini" in lower: + return ("gemini-2.0-flash-lite", None) return (model_name, None) diff --git a/agent/model.py b/agent/model.py index 0a5be394..07931fd5 100644 --- a/agent/model.py +++ b/agent/model.py @@ -1004,7 +1004,7 @@ def condense_conversation(self, conversation: Conversation, keep_recent_turns: i @dataclass class EchoFallbackModel: note: str = ( - "No provider API keys configured. Set OpenAI/Anthropic/OpenRouter keys to use a live LLM." + "No provider API keys configured. Set OpenAI/Anthropic/OpenRouter/Gemini keys to use a live LLM." ) def create_conversation(self, system_prompt: str, initial_user_message: str) -> Conversation: From 6a4a40a62dfdfac6cd934e7cd09f3dc8b2599475 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:33:42 +0000 Subject: [PATCH 09/14] fix: OPENPLANTER_GEMINI_API_KEY takes priority in parse_env_file --- agent/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/credentials.py b/agent/credentials.py index c755956f..25abe8e5 100644 --- a/agent/credentials.py +++ b/agent/credentials.py @@ -118,7 +118,7 @@ def parse_env_file(path: Path) -> CredentialBundle: or None, exa_api_key=(env.get("EXA_API_KEY") or env.get("OPENPLANTER_EXA_API_KEY") or "").strip() or None, voyage_api_key=(env.get("VOYAGE_API_KEY") or env.get("OPENPLANTER_VOYAGE_API_KEY") or "").strip() or None, - gemini_api_key=(env.get("GEMINI_API_KEY") or env.get("OPENPLANTER_GEMINI_API_KEY") or env.get("GOOGLE_API_KEY") or "").strip() or None, + gemini_api_key=(env.get("OPENPLANTER_GEMINI_API_KEY") or env.get("GEMINI_API_KEY") or env.get("GOOGLE_API_KEY") or "").strip() or None, ) From 15ff7b006c53661cfce24f9fdd73369350ac9fe4 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:37:26 +0000 Subject: [PATCH 10/14] test: add comprehensive Gemini provider test suite --- tests/test_gemini.py | 194 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 tests/test_gemini.py diff --git a/tests/test_gemini.py b/tests/test_gemini.py new file mode 100644 index 00000000..6d1606c8 --- /dev/null +++ b/tests/test_gemini.py @@ -0,0 +1,194 @@ +"""Tests for the Gemini provider integration (no live API calls).""" +from __future__ import annotations + +import json +import tempfile +from pathlib import Path + +import pytest + +from agent.builder import build_engine, build_model_factory, infer_provider_for_model +from agent.config import PROVIDER_DEFAULT_MODELS, AgentConfig +from agent.credentials import CredentialBundle, credentials_from_env, parse_env_file +from agent.engine import _lowest_tier_model, _model_tier +from agent.model import OpenAICompatibleModel + + +# --------------------------------------------------------------------------- +# Credentials +# --------------------------------------------------------------------------- + +class TestGeminiCredentials: + def test_bundle_field_exists(self): + cb = CredentialBundle(gemini_api_key="AIzaSy-x") + assert cb.gemini_api_key == "AIzaSy-x" + + def test_has_any_with_gemini_only(self): + assert CredentialBundle(gemini_api_key="key").has_any() + + def test_merge_missing(self): + a = CredentialBundle() + b = CredentialBundle(gemini_api_key="key-b") + a.merge_missing(b) + assert a.gemini_api_key == "key-b" + + def test_merge_does_not_overwrite(self): + a = CredentialBundle(gemini_api_key="key-a") + b = CredentialBundle(gemini_api_key="key-b") + a.merge_missing(b) + assert a.gemini_api_key == "key-a" + + def test_to_from_json_roundtrip(self): + cb = CredentialBundle(gemini_api_key="AIzaSy-x") + j = cb.to_json() + assert j["gemini_api_key"] == "AIzaSy-x" + cb2 = CredentialBundle.from_json(j) + assert cb2.gemini_api_key == "AIzaSy-x" + + def test_from_json_missing_key_is_none(self): + cb = CredentialBundle.from_json({}) + assert cb.gemini_api_key is None + + def test_parse_env_file_gemini_api_key(self, tmp_path): + env = tmp_path / ".env" + env.write_text("GEMINI_API_KEY=AIzaSy-file\n") + cb = parse_env_file(env) + assert cb.gemini_api_key == "AIzaSy-file" + + def test_parse_env_file_openplanter_prefix(self, tmp_path): + env = tmp_path / ".env" + env.write_text("OPENPLANTER_GEMINI_API_KEY=AIzaSy-prefixed\n") + cb = parse_env_file(env) + assert cb.gemini_api_key == "AIzaSy-prefixed" + + def test_parse_env_file_google_api_key_fallback(self, tmp_path): + env = tmp_path / ".env" + env.write_text("GOOGLE_API_KEY=AIzaSy-google\n") + cb = parse_env_file(env) + assert cb.gemini_api_key == "AIzaSy-google" + + def test_credentials_from_env_gemini(self, monkeypatch): + monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy-env") + cb = credentials_from_env() + assert cb.gemini_api_key == "AIzaSy-env" + + def test_credentials_from_env_google_api_key_fallback(self, monkeypatch): + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("OPENPLANTER_GEMINI_API_KEY", raising=False) + monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy-google") + cb = credentials_from_env() + assert cb.gemini_api_key == "AIzaSy-google" + + def test_credentials_from_env_openplanter_prefix_wins(self, monkeypatch): + monkeypatch.setenv("OPENPLANTER_GEMINI_API_KEY", "AIzaSy-prefixed") + monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy-plain") + cb = credentials_from_env() + assert cb.gemini_api_key == "AIzaSy-prefixed" + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class TestGeminiConfig: + def test_default_model_in_provider_defaults(self): + assert "gemini" in PROVIDER_DEFAULT_MODELS + assert PROVIDER_DEFAULT_MODELS["gemini"].startswith("gemini-") + + def test_config_has_gemini_fields(self, tmp_path): + cfg = AgentConfig(workspace=tmp_path, gemini_api_key="key") + assert cfg.gemini_api_key == "key" + assert "generativelanguage.googleapis.com" in cfg.gemini_base_url + + def test_from_env_reads_gemini_api_key(self, monkeypatch, tmp_path): + monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy-env") + cfg = AgentConfig.from_env(tmp_path) + assert cfg.gemini_api_key == "AIzaSy-env" + + +# --------------------------------------------------------------------------- +# Builder +# --------------------------------------------------------------------------- + +class TestGeminiBuilder: + def test_infer_provider_flash(self): + assert infer_provider_for_model("gemini-2.5-flash") == "gemini" + + def test_infer_provider_pro(self): + assert infer_provider_for_model("gemini-3-pro") == "gemini" + + def test_infer_provider_lite(self): + assert infer_provider_for_model("gemini-2.0-flash-lite") == "gemini" + + def test_infer_provider_openrouter_gemini_not_matched(self): + # google/gemini-* should go to openrouter, not gemini + assert infer_provider_for_model("google/gemini-2.5-flash") == "openrouter" + + def test_build_engine_returns_openai_compatible_model(self, tmp_path): + cfg = AgentConfig( + workspace=tmp_path, + provider="gemini", + model="gemini-2.5-flash", + gemini_api_key="AIzaSy-test", + ) + engine = build_engine(cfg) + assert isinstance(engine.model, OpenAICompatibleModel) + + def test_build_engine_strict_tools_false(self, tmp_path): + cfg = AgentConfig( + workspace=tmp_path, + provider="gemini", + model="gemini-2.5-flash", + gemini_api_key="AIzaSy-test", + ) + engine = build_engine(cfg) + assert engine.model.strict_tools is False + + def test_build_engine_correct_base_url(self, tmp_path): + cfg = AgentConfig( + workspace=tmp_path, + provider="gemini", + model="gemini-2.5-flash", + gemini_api_key="AIzaSy-test", + ) + engine = build_engine(cfg) + assert "generativelanguage.googleapis.com" in engine.model.base_url + + def test_model_factory_returns_non_none_with_only_gemini_key(self, tmp_path): + cfg = AgentConfig(workspace=tmp_path, gemini_api_key="AIzaSy-test") + factory = build_model_factory(cfg) + assert factory is not None + + def test_model_factory_creates_gemini_model(self, tmp_path): + cfg = AgentConfig(workspace=tmp_path, gemini_api_key="AIzaSy-test") + factory = build_model_factory(cfg) + assert factory is not None + m = factory("gemini-2.5-flash") + assert isinstance(m, OpenAICompatibleModel) + assert m.strict_tools is False + + +# --------------------------------------------------------------------------- +# Model tier +# --------------------------------------------------------------------------- + +class TestGeminiModelTier: + def test_pro_is_tier_1(self): + assert _model_tier("gemini-2.5-pro") == 1 + assert _model_tier("gemini-3-pro-preview") == 1 + + def test_flash_is_tier_2(self): + assert _model_tier("gemini-2.5-flash") == 2 + assert _model_tier("gemini-3-flash") == 2 + + def test_lite_is_tier_3(self): + assert _model_tier("gemini-2.0-flash-lite") == 3 + + def test_lowest_tier_is_flash_lite(self): + name, effort = _lowest_tier_model("gemini-2.5-pro") + assert name == "gemini-2.0-flash-lite" + assert effort is None + + def test_lowest_tier_flash_model_also_gives_lite(self): + name, _ = _lowest_tier_model("gemini-2.5-flash") + assert name == "gemini-2.0-flash-lite" From 79d6e4c46d26963f553eca42719fc82899ca4c35 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:40:42 +0000 Subject: [PATCH 11/14] fix: remove unused imports and redundant test in test_gemini.py --- tests/test_gemini.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_gemini.py b/tests/test_gemini.py index 6d1606c8..8e759454 100644 --- a/tests/test_gemini.py +++ b/tests/test_gemini.py @@ -1,19 +1,12 @@ """Tests for the Gemini provider integration (no live API calls).""" from __future__ import annotations -import json -import tempfile -from pathlib import Path - -import pytest - from agent.builder import build_engine, build_model_factory, infer_provider_for_model from agent.config import PROVIDER_DEFAULT_MODELS, AgentConfig from agent.credentials import CredentialBundle, credentials_from_env, parse_env_file from agent.engine import _lowest_tier_model, _model_tier from agent.model import OpenAICompatibleModel - # --------------------------------------------------------------------------- # Credentials # --------------------------------------------------------------------------- @@ -154,11 +147,6 @@ def test_build_engine_correct_base_url(self, tmp_path): engine = build_engine(cfg) assert "generativelanguage.googleapis.com" in engine.model.base_url - def test_model_factory_returns_non_none_with_only_gemini_key(self, tmp_path): - cfg = AgentConfig(workspace=tmp_path, gemini_api_key="AIzaSy-test") - factory = build_model_factory(cfg) - assert factory is not None - def test_model_factory_creates_gemini_model(self, tmp_path): cfg = AgentConfig(workspace=tmp_path, gemini_api_key="AIzaSy-test") factory = build_model_factory(cfg) From d8804c82480d13c18d80ae684892f2100322c8f2 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:42:02 +0000 Subject: [PATCH 12/14] docs: add Gemini provider design doc --- .../2026-02-20-gemini-provider-design.md | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 docs/plans/2026-02-20-gemini-provider-design.md diff --git a/docs/plans/2026-02-20-gemini-provider-design.md b/docs/plans/2026-02-20-gemini-provider-design.md new file mode 100644 index 00000000..b8faaa54 --- /dev/null +++ b/docs/plans/2026-02-20-gemini-provider-design.md @@ -0,0 +1,167 @@ +# Gemini Provider + +**Date:** 2026-02-20 +**PR:** #8 + +--- + +## Problem + +OpenPlanter supported four providers (OpenAI, Anthropic, OpenRouter, Cerebras). Gemini models are widely used and have no support. + +--- + +## Solution + +Google exposes an OpenAI-compatible REST endpoint: + +``` +https://generativelanguage.googleapis.com/v1beta/openai +``` + +`OpenAICompatibleModel` can be pointed at this URL with no new model class. The only Gemini-specific requirement is `strict_tools=False`: Google's compatibility layer does not enforce `additionalProperties: false` / strict-mode schemas. + +--- + +## Files Changed + +### `agent/credentials.py` + +`gemini_api_key` added to `CredentialBundle` (7 locations: dataclass field, `from_env`, `from_keyring`, `to_keyring`, `clear_keyring`, `has_any`, prompt label). + +Env var priority chain: + +``` +OPENPLANTER_GEMINI_API_KEY → GEMINI_API_KEY → GOOGLE_API_KEY +``` + +`GOOGLE_API_KEY` is the third fallback so users with a pre-existing Google AI environment variable get automatic pickup. + +Prompt label is `"Gemini"` (not `"Google Gemini"`). + +### `agent/config.py` + +Two new fields on `AgentConfig`: + +```python +gemini_api_key: str | None = None +gemini_base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai" +``` + +`"gemini"` added to `PROVIDER_DEFAULT_MODELS` with default model `gemini-2.5-flash`. + +### `agent/builder.py` + +**Model inference:** + +```python +_GEMINI_RE = re.compile(r"^gemini-", re.IGNORECASE) + +def infer_provider_for_model(model: str) -> str: + ... + if _GEMINI_RE.match(model): + return "gemini" + ... +``` + +**`build_model_factory`** — Gemini branch: + +```python +if provider == "gemini": + return lambda model_name: OpenAICompatibleModel( + model=model_name, + api_key=cfg.gemini_api_key or creds.gemini_api_key or "", + base_url=cfg.gemini_base_url, + strict_tools=False, + ) +``` + +**`build_engine`** — Gemini branch mirrors the factory, also with `strict_tools=False`. + +**Factory guard clause** updated to include `cfg.gemini_api_key` alongside the other provider keys. + +### `agent/engine.py` + +**`_model_tier`** — keyword matching, version-agnostic: + +```python +def _model_tier(model: str) -> int: + if "pro" in model: + return 1 + if "lite" in model: + return 3 + return 2 +``` + +`"pro"→1`, `"lite"→3`, else→2. `"gemini-3-pro"` and `"gemini-2.5-pro"` both map to tier 1 without a code change. + +**`_lowest_tier_model`** — Gemini branch returns `"gemini-2.0-flash-lite"`. + +**`_MODEL_CONTEXT_WINDOWS`** — 5 Gemini entries, all `1_000_000`: + +```python +"gemini-2.5-pro": 1_000_000, +"gemini-2.5-flash": 1_000_000, +"gemini-2.0-pro": 1_000_000, +"gemini-2.0-flash": 1_000_000, +"gemini-2.0-flash-lite": 1_000_000, +``` + +### `agent/model.py` + +`EchoFallbackModel.note` updated to mention Gemini alongside other providers. + +### `tests/test_gemini.py` + +28 tests, no live API calls: + +- Env var priority chain (`OPENPLANTER_GEMINI_API_KEY` wins over `GEMINI_API_KEY` wins over `GOOGLE_API_KEY`) +- `infer_provider_for_model` matches various `gemini-*` strings and non-Gemini strings +- `build_model_factory` produces `OpenAICompatibleModel` with correct URL and `strict_tools=False` +- `_model_tier` keyword mapping (`pro`, `lite`, unrecognised) +- `_lowest_tier_model` Gemini branch +- Context window lookups for all 5 Gemini entries +- Guard clause rejects missing key + +--- + +## Usage + +```bash +export GEMINI_API_KEY=AIzaSy-... + +# By provider +openplanter --provider gemini + +# By model name (provider inferred) +openplanter --model gemini-2.5-pro +``` + +--- + +## Opus Review Findings (all addressed pre-implementation) + +| Finding | Resolution | +|---------|-----------| +| `strict_tools=False` must appear in both `build_engine` and `build_model_factory` | Both branches set it explicitly | +| `cfg.gemini_api_key` must be in the factory guard clause | Added alongside existing provider keys | +| Model tiers should match by keyword, not version number | `_model_tier` checks `"pro"`/`"lite"` substrings | +| Support `GOOGLE_API_KEY` as third env var fallback | Added as final fallback in `from_env` | +| Prompt label should be `"Gemini"`, not `"Google Gemini"` | Label is `"Gemini"` | +| Add Gemini to `_MODEL_CONTEXT_WINDOWS` | 5 entries, all 1 M | + +--- + +## Alternatives Considered + +### New `GeminiModel` class + +Add a dedicated model class the way Anthropic has `AnthropicModel`. + +**Why rejected:** Google's OpenAI-compatible endpoint makes this unnecessary. A new class would duplicate streaming, retry, and token-counting logic already in `OpenAICompatibleModel` just to set a URL and a flag. + +### Per-version tier constants + +Map specific version numbers (`2.0`, `2.5`) to tiers rather than keywords. + +**Why rejected:** Keyword matching on `"pro"`/`"lite"` is version-agnostic. When `gemini-3-pro` ships, the mapping is correct with no code change. Version-based constants would need updating on every new release. From 61e720445a28cd2327b3ecb64b2c7b94a2001650 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:45:12 +0000 Subject: [PATCH 13/14] fix: tighten _GEMINI_RE to ^gemini- and align design doc with implementation --- agent/builder.py | 2 +- docs/plans/2026-02-20-gemini-provider-design.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/agent/builder.py b/agent/builder.py index 3adb5d0c..26e5c697 100644 --- a/agent/builder.py +++ b/agent/builder.py @@ -27,7 +27,7 @@ _ANTHROPIC_RE = re.compile(r"^claude", re.IGNORECASE) _OPENAI_RE = re.compile(r"^(gpt|o[1-4]-|o[1-4]$|chatgpt|dall-e|tts-|whisper)", re.IGNORECASE) _CEREBRAS_RE = re.compile(r"^(llama.*cerebras|qwen-3|gpt-oss|zai-glm)", re.IGNORECASE) -_GEMINI_RE = re.compile(r"^gemini", re.IGNORECASE) +_GEMINI_RE = re.compile(r"^gemini-", re.IGNORECASE) def infer_provider_for_model(model: str) -> str | None: diff --git a/docs/plans/2026-02-20-gemini-provider-design.md b/docs/plans/2026-02-20-gemini-provider-design.md index b8faaa54..cfd93617 100644 --- a/docs/plans/2026-02-20-gemini-provider-design.md +++ b/docs/plans/2026-02-20-gemini-provider-design.md @@ -59,7 +59,7 @@ _GEMINI_RE = re.compile(r"^gemini-", re.IGNORECASE) def infer_provider_for_model(model: str) -> str: ... - if _GEMINI_RE.match(model): + if _GEMINI_RE.search(model): return "gemini" ... ``` @@ -102,7 +102,7 @@ def _model_tier(model: str) -> int: ```python "gemini-2.5-pro": 1_000_000, "gemini-2.5-flash": 1_000_000, -"gemini-2.0-pro": 1_000_000, +"gemini-3-flash": 1_000_000, "gemini-2.0-flash": 1_000_000, "gemini-2.0-flash-lite": 1_000_000, ``` From 225ace9c253364b70db3d6c675fc82bb66552554 Mon Sep 17 00:00:00 2001 From: Matt Kneale Date: Fri, 20 Feb 2026 15:51:00 +0000 Subject: [PATCH 14/14] fix: add Gemini support to CLI (provider choice, credential loading, runtime overrides) --- agent/__main__.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/agent/__main__.py b/agent/__main__.py index 87f70200..d41a9c8e 100644 --- a/agent/__main__.py +++ b/agent/__main__.py @@ -39,7 +39,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument( "--provider", default=None, - choices=["auto", "openai", "anthropic", "openrouter", "cerebras", "all"], + choices=["auto", "openai", "anthropic", "openrouter", "cerebras", "gemini", "all"], help="Model provider. Use 'all' only with --list-models.", ) parser.add_argument("--model", help="Model name (use 'newest' to auto-select latest from API).") @@ -86,6 +86,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument("--cerebras-api-key", help="Cerebras API key override.") parser.add_argument("--exa-api-key", help="Exa API key override.") parser.add_argument("--voyage-api-key", help="Voyage API key override.") + parser.add_argument("--gemini-api-key", help="Gemini API key override.") parser.add_argument( "--configure-keys", action="store_true", @@ -150,7 +151,7 @@ def _format_ts(ts: int) -> str: def _resolve_provider(requested: str, creds: CredentialBundle) -> str: requested = requested.strip().lower() - if requested in {"openai", "anthropic", "openrouter", "cerebras"}: + if requested in {"openai", "anthropic", "openrouter", "cerebras", "gemini"}: return requested if requested == "all": return "all" @@ -162,15 +163,17 @@ def _resolve_provider(requested: str, creds: CredentialBundle) -> str: return "openrouter" if creds.cerebras_api_key: return "cerebras" + if creds.gemini_api_key: + return "gemini" return "openai" def _print_models(cfg: AgentConfig, requested_provider: str) -> int: providers: list[str] if requested_provider == "all": - providers = ["openai", "anthropic", "openrouter", "cerebras"] + providers = ["openai", "anthropic", "openrouter", "cerebras", "gemini"] elif requested_provider == "auto": - providers = ["openai", "anthropic", "openrouter", "cerebras"] + providers = ["openai", "anthropic", "openrouter", "cerebras", "gemini"] else: providers = [requested_provider] @@ -208,6 +211,7 @@ def _load_credentials( cerebras_api_key=user_creds.cerebras_api_key, exa_api_key=user_creds.exa_api_key, voyage_api_key=user_creds.voyage_api_key, + gemini_api_key=user_creds.gemini_api_key, ) store = CredentialStore(workspace=cfg.workspace, session_root_dir=cfg.session_root_dir) @@ -224,6 +228,8 @@ def _load_credentials( creds.exa_api_key = stored.exa_api_key if stored.voyage_api_key: creds.voyage_api_key = stored.voyage_api_key + if stored.gemini_api_key: + creds.gemini_api_key = stored.gemini_api_key env_creds = credentials_from_env() if env_creds.openai_api_key: @@ -238,6 +244,8 @@ def _load_credentials( creds.exa_api_key = env_creds.exa_api_key if env_creds.voyage_api_key: creds.voyage_api_key = env_creds.voyage_api_key + if env_creds.gemini_api_key: + creds.gemini_api_key = env_creds.gemini_api_key for env_path in discover_env_candidates(cfg.workspace): file_creds = parse_env_file(env_path) @@ -257,6 +265,8 @@ def _load_credentials( creds.exa_api_key = args.exa_api_key.strip() or creds.exa_api_key if args.voyage_api_key: creds.voyage_api_key = args.voyage_api_key.strip() or creds.voyage_api_key + if args.gemini_api_key: + creds.gemini_api_key = args.gemini_api_key.strip() or creds.gemini_api_key changed_by_prompt = False if allow_prompt: @@ -299,6 +309,7 @@ def _apply_runtime_overrides(cfg: AgentConfig, args: argparse.Namespace, creds: cfg.cerebras_api_key = creds.cerebras_api_key cfg.exa_api_key = creds.exa_api_key cfg.voyage_api_key = creds.voyage_api_key + cfg.gemini_api_key = creds.gemini_api_key cfg.api_key = cfg.openai_api_key if args.base_url: @@ -515,6 +526,7 @@ def main() -> None: "anthropic": cfg.anthropic_api_key, "openrouter": cfg.openrouter_api_key, "cerebras": cfg.cerebras_api_key, + "gemini": cfg.gemini_api_key, }.get(inferred) if key: cfg.provider = inferred