diff --git a/.github/scripts/check_documented_examples.py b/.github/scripts/check_documented_examples.py index 7a1d1a769b..a339e699f4 100755 --- a/.github/scripts/check_documented_examples.py +++ b/.github/scripts/check_documented_examples.py @@ -29,7 +29,12 @@ def find_documented_examples(docs_path: Path) -> set[str]: """ documented_examples: set[str] = set() - # Pattern to match example file references with arbitrary nesting depth. + # Pattern to match example file references. + # + # The agent-sdk examples tree includes nested modules (e.g. + # examples/02_remote_agent_server/05_custom_tool/custom_tools/log_data.py), + # so we intentionally support *arbitrary* nesting depth under examples/. + # + # Matches: examples/<dir>/.../<file>.py pattern = r"examples/(?:[-\w]+/)+[-\w]+\.py" @@ -81,8 +86,9 @@ def find_agent_sdk_examples(agent_sdk_path: Path) -> set[str]: if relative_path_str.startswith("examples/03_github_workflows/"): continue - # Skip LLM-specific tools examples: these are intentionally not - # enforced by the docs check. See discussion in PR #1486. + # Skip LLM-specific tools examples: these depend on external + # model/provider availability and are intentionally excluded from + # docs example enforcement. if relative_path_str.startswith("examples/04_llm_specific_tools/"): continue diff --git a/examples/01_standalone_sdk/34_llm_profiles.py b/examples/01_standalone_sdk/34_llm_profiles.py new file mode 100644 index 0000000000..264707723e --- /dev/null +++ b/examples/01_standalone_sdk/34_llm_profiles.py @@ -0,0 +1,124 @@ +"""Create and use an LLM profile with :class:`LLMRegistry`. + +Run with:: + + uv run python examples/01_standalone_sdk/34_llm_profiles.py + +Profiles are stored under ``~/.openhands/llm-profiles/<name>.json`` by default. +Set ``LLM_PROFILE_NAME`` to pick a profile. 
+ +Notes on credentials: +- New profiles include API keys by default when saved +- To omit secrets on disk, pass include_secrets=False to LLMRegistry.save_profile +""" + +import json +import os +from pathlib import Path + +from pydantic import SecretStr + +from openhands.sdk import ( + LLM, + Agent, + Conversation, + LLMRegistry, + Tool, +) +from openhands.tools.terminal import TerminalTool + + +PROFILE_NAME = os.getenv("LLM_PROFILE_NAME", "gpt-5-mini") + + +def ensure_profile_exists(registry: LLMRegistry, name: str) -> None: + """Create a starter profile in the default directory when missing.""" + + if name in registry.list_profiles(): + return + + model = os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929") + base_url = os.getenv("LLM_BASE_URL") + api_key = os.getenv("LLM_API_KEY") + + profile_defaults = LLM( + usage_id="agent", + model=model, + base_url=base_url, + api_key=SecretStr(api_key) if api_key else None, + temperature=0.2, + max_output_tokens=4096, + ) + path = registry.save_profile(name, profile_defaults) + print(f"Created profile '{name}' at {path}") + + +def load_profile(registry: LLMRegistry, name: str) -> LLM: + llm = registry.load_profile(name) + # If profile was saved without secrets, allow providing API key via env var + if llm.api_key is None: + api_key = os.getenv("LLM_API_KEY") + if api_key: + llm = llm.model_copy(update={"api_key": SecretStr(api_key)}) + return llm + + +if __name__ == "__main__": # pragma: no cover + registry = LLMRegistry() + ensure_profile_exists(registry, PROFILE_NAME) + + llm = load_profile(registry, PROFILE_NAME) + + tools = [Tool(name=TerminalTool.name)] + agent = Agent(llm=llm, tools=tools) + + workspace_dir = Path(os.getcwd()) + summary_path = workspace_dir / "summary_readme.md" + if summary_path.exists(): + summary_path.unlink() + + persistence_root = workspace_dir / ".conversations_llm_profiles" + conversation = Conversation( + agent=agent, + workspace=str(workspace_dir), + 
persistence_dir=str(persistence_root), + visualizer=None, + ) + + conversation.send_message( + "Read README.md in this workspace, create a concise summary in " + "summary_readme.md (overwrite it if it exists), and respond with " + "SUMMARY_READY when the file is written." + ) + conversation.run() + + if summary_path.exists(): + print(f"summary_readme.md written to {summary_path}") + else: + print("summary_readme.md not found after first run") + + conversation.send_message( + "Thanks! Delete summary_readme.md from the workspace and respond with " + "SUMMARY_REMOVED once it is gone." + ) + conversation.run() + + if summary_path.exists(): + print("summary_readme.md still present after deletion request") + else: + print("summary_readme.md removed") + + persistence_dir = conversation.state.persistence_dir + if persistence_dir is None: + raise RuntimeError("Conversation did not persist base state to disk") + + base_state_path = Path(persistence_dir) / "base_state.json" + state_payload = json.loads(base_state_path.read_text()) + llm_entry = state_payload.get("agent", {}).get("llm", {}) + profile_in_state = llm_entry.get("profile_id") + print(f"Profile recorded in base_state.json: {profile_in_state}") + if profile_in_state != PROFILE_NAME: + print( + "Warning: profile_id in base_state.json does not match the profile " + "used at runtime." 
+ ) diff --git a/examples/llm-profiles/gpt-5-mini.json b/examples/llm-profiles/gpt-5-mini.json new file mode 100644 index 0000000000..0ab6b5188a --- /dev/null +++ b/examples/llm-profiles/gpt-5-mini.json @@ -0,0 +1,7 @@ +{ + "model": "litellm_proxy/openai/gpt-5-mini", + "base_url": "https://llm-proxy.eval.all-hands.dev", + "temperature": 0.2, + "max_output_tokens": 4096, + "usage_id": "agent" +} diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index a364dbadea..24ef1a10a7 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -414,6 +414,7 @@ def model_dump_succint(self, **kwargs): """Like model_dump, but excludes None fields by default.""" if "exclude_none" not in kwargs: kwargs["exclude_none"] = True + dumped = super().model_dump(**kwargs) # remove tool schema details for brevity if "tools" in dumped and isinstance(dumped["tools"], dict): diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 9da8fa681a..8a6d769320 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -106,6 +106,9 @@ def __init__( 'monologue', 'alternating_pattern'. Values are integers representing the number of repetitions before triggering. """ + # Initialize the registry early so profile references resolve during resume. + self.llm_registry = LLMRegistry() + super().__init__() # Initialize with span tracking # Mark cleanup as initiated as early as possible to avoid races or partially # initialized instances during interpreter shutdown. 
@@ -134,6 +137,7 @@ def __init__( else None, max_iterations=max_iteration_per_run, stuck_detection=stuck_detection, + llm_registry=self.llm_registry, ) # Default callback: persist every event to state @@ -209,7 +213,6 @@ def _default_callback(e): self.agent.init_state(self._state, on_event=self._on_event) # Register existing llms in agent - self.llm_registry = LLMRegistry() self.llm_registry.subscribe(self._state.stats.register_llm) for llm in list(self.agent.get_all_llms()): self.llm_registry.add(llm) @@ -254,6 +257,7 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non Args: message: Either a string (which will be converted to a user message) + or a Message object sender: Optional identifier of the sender. Can be used to track message origin in multi-agent scenarios. For example, when diff --git a/openhands-sdk/openhands/sdk/conversation/state.py b/openhands-sdk/openhands/sdk/conversation/state.py index f2d339be0f..4d6f1c9c3c 100644 --- a/openhands-sdk/openhands/sdk/conversation/state.py +++ b/openhands-sdk/openhands/sdk/conversation/state.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Self from pydantic import Field, PrivateAttr, model_validator @@ -18,6 +18,12 @@ from openhands.sdk.event.base import Event from openhands.sdk.io import FileStore, InMemoryFileStore, LocalFileStore from openhands.sdk.logger import get_logger + + +if TYPE_CHECKING: + from openhands.sdk.llm.llm_registry import LLMRegistry + + from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, @@ -180,6 +186,7 @@ def create( persistence_dir: str | None = None, max_iterations: int = 500, stuck_detection: bool = True, + llm_registry: "LLMRegistry | None" = None, ) -> "ConversationState": """Create a new conversation state or resume from persistence. 
@@ -196,6 +203,10 @@ def create( history), but all other configuration can be freely changed: LLM, agent_context, condenser, system prompts, etc. + When conversation state is persisted with LLM profile references (instead + of inlined credentials), pass an ``llm_registry`` so profile IDs can be + expanded during restore. + Args: id: Unique conversation identifier agent: The Agent to use (tools must match persisted on restore) @@ -203,6 +214,8 @@ def create( persistence_dir: Directory for persisting state and events max_iterations: Maximum iterations per run stuck_detection: Whether to enable stuck detection + llm_registry: Optional registry used to expand profile references when + conversations persist profile IDs instead of inline credentials. Returns: ConversationState ready for use @@ -222,32 +235,59 @@ def create( except FileNotFoundError: base_text = None + context: dict[str, object] = {} + registry = llm_registry + if registry is None: + from openhands.sdk.llm.llm_registry import LLMRegistry + + registry = LLMRegistry() + context["llm_registry"] = registry + + # Ensure that any runtime-provided LLM without an explicit profile is + # persisted as a stable "default" profile, so conversation state can + # safely store only a profile reference. 
+ agent = agent.model_copy( + update={"llm": registry.ensure_default_profile(agent.llm)} + ) + # ---- Resume path ---- if base_text: - state = cls.model_validate(json.loads(base_text)) + base_payload = json.loads(base_text) - # Restore the conversation with the same id - if state.id != id: + persisted_id = ConversationID(base_payload.get("id")) + if persisted_id != id: raise ValueError( f"Conversation ID mismatch: provided {id}, " - f"but persisted state has {state.id}" + f"but persisted state has {persisted_id}" ) + persisted_agent_payload = base_payload.get("agent") + if persisted_agent_payload is None: + raise ValueError("Persisted conversation is missing agent state") + # Attach event log early so we can read history for tool verification - state._fs = file_store - state._events = EventLog(file_store, dir_path=EVENTS_DIR) + event_log = EventLog(file_store, dir_path=EVENTS_DIR) - # Verify compatibility (agent class + tools) - agent.verify(state.agent, events=state._events) + persisted_agent = AgentBase.model_validate( + persisted_agent_payload, + context={"llm_registry": registry}, + ) + agent.verify(persisted_agent, events=event_log) - # Commit runtime-provided values (may autosave) - state._autosave_enabled = True - state.agent = agent - state.workspace = workspace - state.max_iterations = max_iterations + # Use runtime-provided Agent directly (PR #1542 / issue #1451) + base_payload["agent"] = agent.model_dump( + mode="json", + exclude_none=True, + context={"expose_secrets": True}, + ) + base_payload["workspace"] = workspace.model_dump(mode="json") + base_payload["max_iterations"] = max_iterations + + state = cls.model_validate(base_payload, context=context) + state._fs = file_store + state._events = event_log - # Note: stats are already deserialized from base_state.json above. - # Do NOT reset stats here - this would lose accumulated metrics. 
+ state._autosave_enabled = True logger.info( f"Resumed conversation {state.id} from persistent storage.\n" diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index c990fdbd40..86ea3bdcb3 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -4,7 +4,7 @@ import json import os import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args, get_origin @@ -15,8 +15,12 @@ Field, PrivateAttr, SecretStr, + SerializationInfo, + SerializerFunctionWrapHandler, + ValidationInfo, field_serializer, field_validator, + model_serializer, model_validator, ) from pydantic.json_schema import SkipJsonSchema @@ -84,7 +88,6 @@ __all__ = ["LLM"] - # Exceptions we retry on LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = ( APIConnectionError, @@ -289,6 +292,10 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): "Safety settings for models that support them (like Mistral AI and Gemini)" ), ) + profile_id: str | None = Field( + default=None, + description="Optional profile id (filename under the profiles directory).", + ) usage_id: str = Field( default="default", serialization_alias="usage_id", @@ -334,6 +341,26 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): extra="ignore", arbitrary_types_allowed=True ) + @model_serializer(mode="wrap", when_used="json") + def _serialize_with_profiles( + self, handler: SerializerFunctionWrapHandler, _info: SerializationInfo + ) -> Mapping[str, Any]: + """Serialize LLMs as profile references when possible. + + In JSON mode we avoid persisting full LLM configuration (and any secrets) + into conversation state. Instead, when an LLM has ``profile_id`` we emit a + compact reference: ``{"profile_id": ...}``. 
+ + If no ``profile_id`` is set, we fall back to the full payload so existing + non-profile workflows keep working. + """ + + data = handler(self) + profile_id = data.get("profile_id") if isinstance(data, dict) else None + if profile_id: + return {"profile_id": profile_id} + return data + # ========================================================================= # Validators # ========================================================================= @@ -344,11 +371,24 @@ def _validate_secrets(cls, v: str | SecretStr | None, info) -> SecretStr | None: @model_validator(mode="before") @classmethod - def _coerce_inputs(cls, data): - if not isinstance(data, dict): + def _coerce_inputs(cls, data: Any, info: ValidationInfo): + if not isinstance(data, Mapping): return data d = dict(data) + profile_id = d.get("profile_id") + if profile_id and "model" not in d: + if info.context is None or "llm_registry" not in info.context: + raise ValueError( + "LLM registry required in context to load profile references." 
+ ) + + registry = info.context["llm_registry"] + llm = registry.load_profile(profile_id) + expanded = llm.model_dump(exclude_none=True) + expanded["profile_id"] = profile_id + d.update(expanded) + model_val = d.get("model") if not model_val: raise ValueError("model must be specified in LLM") diff --git a/openhands-sdk/openhands/sdk/llm/llm_registry.py b/openhands-sdk/openhands/sdk/llm/llm_registry.py index 3ead0d3f68..c9831d3fb8 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_registry.py +++ b/openhands-sdk/openhands/sdk/llm/llm_registry.py @@ -1,8 +1,11 @@ -from collections.abc import Callable -from typing import ClassVar +import json +import re +from collections.abc import Callable, Mapping +from pathlib import Path +from typing import Any, ClassVar from uuid import uuid4 -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, ValidationError from openhands.sdk.llm.llm import LLM from openhands.sdk.logger import get_logger @@ -11,6 +14,16 @@ logger = get_logger(__name__) +_SECRET_FIELDS: tuple[str, ...] = ( + "api_key", + "aws_access_key_id", + "aws_secret_access_key", +) +_DEFAULT_PROFILE_DIR = Path.home() / ".openhands" / "llm-profiles" + +_PROFILE_ID_PATTERN = re.compile(r"^[A-Za-z0-9._-]+$") + + class RegistryEvent(BaseModel): llm: LLM @@ -19,6 +32,9 @@ class RegistryEvent(BaseModel): ) +DEFAULT_PROFILE_ID = "default" + + class LLMRegistry: """A minimal LLM registry for managing LLM instances by usage ID. @@ -32,16 +48,19 @@ class LLMRegistry: def __init__( self, retry_listener: Callable[[int, int], None] | None = None, + profile_dir: str | Path | None = None, ): """Initialize the LLM registry. Args: retry_listener: Optional callback for retry events. + profile_dir: Optional directory for persisted LLM profiles. 
""" self.registry_id = str(uuid4()) self.retry_listener = retry_listener self._usage_to_llm: dict[str, LLM] = {} self.subscriber: Callable[[RegistryEvent], None] | None = None + self.profile_dir: Path = self._resolve_profile_dir(profile_dir) def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None: """Subscribe to registry events. @@ -70,14 +89,8 @@ def usage_to_llm(self) -> dict[str, LLM]: return self._usage_to_llm def add(self, llm: LLM) -> None: - """Add an LLM instance to the registry. + """Add an LLM instance to the registry.""" - Args: - llm: The LLM instance to register. - - Raises: - ValueError: If llm.usage_id already exists in the registry. - """ usage_id = llm.usage_id if usage_id in self._usage_to_llm: message = ( @@ -93,6 +106,132 @@ def add(self, llm: LLM) -> None: f"[LLM registry {self.registry_id}]: Added LLM for usage {usage_id}" ) + def _ensure_safe_profile_id(self, profile_id: str) -> str: + if not profile_id or profile_id in {".", ".."}: + raise ValueError("Invalid profile ID.") + if Path(profile_id).name != profile_id: + raise ValueError("Profile IDs cannot contain path separators.") + if not _PROFILE_ID_PATTERN.fullmatch(profile_id): + raise ValueError( + "Profile IDs may only contain alphanumerics, '.', '_', or '-'." 
+ ) + return profile_id + + # ------------------------------------------------------------------ + # Profile management helpers + # ------------------------------------------------------------------ + def list_profiles(self) -> list[str]: + """List all profile IDs stored on disk.""" + + return sorted(path.stem for path in self.profile_dir.glob("*.json")) + + def get_profile_path(self, profile_id: str) -> Path: + """Return the path where profile_id is stored.""" + + safe_id = self._ensure_safe_profile_id(profile_id) + return self.profile_dir / f"{safe_id}.json" + + def load_profile(self, profile_id: str) -> LLM: + """Load profile_id from disk and return an LLM.""" + + path = self.get_profile_path(profile_id) + if not path.exists(): + raise FileNotFoundError(f"Profile not found: {profile_id} -> {path}") + return self._load_profile_with_synced_id(path, profile_id) + + def save_profile( + self, profile_id: str, llm: LLM, include_secrets: bool = True + ) -> Path: + """Persist ``llm`` under ``profile_id``. + + By default, secrets are included in the saved JSON. Set + ``include_secrets=False`` to omit secret fields. 
+ """ + + safe_id = self._ensure_safe_profile_id(profile_id) + path = self.get_profile_path(safe_id) + existed_before = path.exists() + path.parent.mkdir(parents=True, exist_ok=True) + data = llm.model_dump( + exclude_none=True, + context={"expose_secrets": include_secrets}, + ) + data["profile_id"] = safe_id + if not include_secrets: + for secret_field in _SECRET_FIELDS: + data.pop(secret_field, None) + + with path.open("w", encoding="utf-8") as handle: + json.dump(data, handle, indent=2, ensure_ascii=False) + # Apply restrictive permissions when creating a new file + if not existed_before: + try: + path.chmod(0o600) + except Exception as e: # best-effort on non-POSIX systems + logger.debug(f"Unable to chmod profile file {path}: {e}") + logger.info(f"Saved profile {safe_id} -> {path}") + return path + + def validate_profile(self, data: Mapping[str, Any]) -> tuple[bool, list[str]]: + """Return (is_valid, errors) after validating a profile payload.""" + + try: + LLM.model_validate(dict(data)) + except ValidationError as exc: + messages: list[str] = [] + for error in exc.errors(): + loc = ".".join(str(piece) for piece in error.get("loc", ())) + if loc: + messages.append(f"{loc}: {error.get('msg')}") + else: + messages.append(error.get("msg", "Unknown validation error")) + return False, messages + return True, [] + + # ------------------------------------------------------------------ + # Internal helper methods + # ------------------------------------------------------------------ + def _resolve_profile_dir(self, profile_dir: str | Path | None) -> Path: + if profile_dir is not None: + return Path(profile_dir).expanduser() + return _DEFAULT_PROFILE_DIR + + def _load_profile_with_synced_id(self, path: Path, profile_id: str) -> LLM: + """Load an LLM profile while keeping profile metadata aligned. + + Most callers expect the loaded LLM to reflect the profile file name so the + client apps can surface the active profile (e.g., in conversation history or CLI + prompts). 
We construct a *new* ``LLM`` via :meth:`model_copy` instead of + mutating the loaded instance to respect the SDK's immutability + conventions. + + We always align ``profile_id`` with the filename so callers get a precise + view of which profile is active without mutating the on-disk payload. This + mirrors previous behavior while avoiding in-place mutation. + """ + + llm = LLM.load_from_json(str(path)) + if llm.profile_id != profile_id: + return llm.model_copy(update={"profile_id": profile_id}) + return llm + + def ensure_default_profile(self, llm: LLM) -> LLM: + """Persist ``llm`` as the default profile if it isn't already profiled. + + When an LLM instance without ``profile_id`` is used in a persisted + conversation, we want the conversation to store a profile reference + instead of embedding the full configuration inline. + + This helper creates or overwrites ``default.json`` in the profiles + directory and returns a copy of ``llm`` with ``profile_id`` set. + """ + + if llm.profile_id: + return llm + + self.save_profile(DEFAULT_PROFILE_ID, llm) + return llm.model_copy(update={"profile_id": DEFAULT_PROFILE_ID}) + def get(self, usage_id: str) -> LLM: """Get an LLM instance from the registry. diff --git a/openhands-sdk/openhands/sdk/persistence/__init__.py b/openhands-sdk/openhands/sdk/persistence/__init__.py new file mode 100644 index 0000000000..4e5f66701b --- /dev/null +++ b/openhands-sdk/openhands/sdk/persistence/__init__.py @@ -0,0 +1,14 @@ +"""Persistence configuration public API. + +This package re-exports the supported persistence configuration knobs (constants +and helpers) to provide a small, stable import surface: + +- Encapsulation: internal module layout can change without breaking callers. +- Discoverability: callers can find persistence settings via + ``openhands.sdk.persistence``. +- Consistency: matches the SDK pattern of exposing intended entry points at the + package level rather than requiring deep imports. 
+ +Anything exported via ``__all__`` should be treated as part of the supported SDK +API. +""" diff --git a/openhands-sdk/openhands/sdk/persistence/settings.py b/openhands-sdk/openhands/sdk/persistence/settings.py new file mode 100644 index 0000000000..9893ec854a --- /dev/null +++ b/openhands-sdk/openhands/sdk/persistence/settings.py @@ -0,0 +1,3 @@ +"""Shared helpers for SDK persistence configuration.""" + +from __future__ import annotations diff --git a/tests/agent_server/test_agent_server_wsproto.py b/tests/agent_server/test_agent_server_wsproto.py index 3e0d8044f3..f41f13708a 100644 --- a/tests/agent_server/test_agent_server_wsproto.py +++ b/tests/agent_server/test_agent_server_wsproto.py @@ -21,7 +21,7 @@ def find_free_port(): def run_agent_server(port, api_key): - os.environ["OH_SESSION_API_KEYS"] = f'["{api_key}"]' + os.environ["OH_SESSION_API_KEYS_0"] = api_key sys.argv = ["agent-server", "--port", str(port)] from openhands.agent_server.__main__ import main diff --git a/tests/github_workflows/test_resolve_model_config.py b/tests/github_workflows/test_resolve_model_config.py index bde7bbd50c..94ba9c721b 100644 --- a/tests/github_workflows/test_resolve_model_config.py +++ b/tests/github_workflows/test_resolve_model_config.py @@ -28,8 +28,8 @@ def test_find_models_by_id_single_model(): result = find_models_by_id(model_ids) assert len(result) == 1 - assert result[0]["id"] == "claude-sonnet-4-5-20250929" - assert result[0]["display_name"] == "Claude Sonnet 4.5" + assert result[0]["id"] == "gpt-4" + assert result[0]["display_name"] == "GPT-4" def test_find_models_by_id_multiple_models(): @@ -45,8 +45,8 @@ def test_find_models_by_id_multiple_models(): result = find_models_by_id(model_ids) assert len(result) == 2 - assert result[0]["id"] == "claude-sonnet-4-5-20250929" - assert result[1]["id"] == "deepseek-chat" + assert result[0]["id"] == "gpt-4" + assert result[1]["id"] == "claude-3" def test_find_models_by_id_preserves_order(): @@ -113,11 +113,11 @@ def 
test_find_models_by_id_preserves_full_config(): result = find_models_by_id(model_ids) assert len(result) == 1 - assert result[0]["id"] == "claude-sonnet-4-5-20250929" - assert ( - result[0]["llm_config"]["model"] == "litellm_proxy/claude-sonnet-4-5-20250929" - ) - assert result[0]["llm_config"]["temperature"] == 0.0 + assert result[0]["id"] == "custom-model" + assert result[0]["llm_config"]["model"] == "custom-model" + assert result[0]["llm_config"]["api_key"] == "test-key" + assert result[0]["llm_config"]["base_url"] == "https://example.com" + assert result[0]["extra_field"] == "should be preserved" # Tests for expected models from issue #1495 diff --git a/tests/sdk/conversation/local/test_state_serialization.py b/tests/sdk/conversation/local/test_state_serialization.py index eaee970b08..888c7b4a56 100644 --- a/tests/sdk/conversation/local/test_state_serialization.py +++ b/tests/sdk/conversation/local/test_state_serialization.py @@ -16,19 +16,26 @@ ) from openhands.sdk.event.llm_convertible import MessageEvent, SystemPromptEvent from openhands.sdk.llm import LLM, Message, TextContent -from openhands.sdk.llm.llm_registry import RegistryEvent +from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent from openhands.sdk.security.confirmation_policy import AlwaysConfirm from openhands.sdk.workspace import LocalWorkspace -def test_conversation_state_basic_serialization(): +def test_conversation_state_basic_serialization(tmp_path, monkeypatch): """Test basic ConversationState serialization and deserialization.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm") agent = Agent(llm=llm, tools=[]) + registry = LLMRegistry() + state = ConversationState.create( agent=agent, id=uuid.UUID("12345678-1234-5678-9abc-123456789001"), workspace=LocalWorkspace(working_dir="/tmp"), + llm_registry=registry, ) # Add some events @@ -47,7 +54,10 @@ def 
test_conversation_state_basic_serialization(): assert isinstance(serialized, str) # Test deserialization - events won't be included in base state - deserialized = ConversationState.model_validate_json(serialized) + deserialized = ConversationState.model_validate_json( + serialized, + context={"llm_registry": registry}, + ) assert deserialized.id == state.id # Events are stored separately, so we need to check the actual events @@ -127,14 +137,94 @@ def test_conversation_state_persistence_save_load(): assert isinstance(loaded_state.events[1], MessageEvent) assert loaded_state.agent.llm.model == agent.llm.model assert loaded_state.agent.__class__ == agent.__class__ - # Test model_dump equality - assert loaded_state.model_dump(mode="json") == state.model_dump(mode="json") - + # Test model_dump equality ignoring any additional runtime stats + loaded_dump = loaded_state.model_dump(mode="json") + original_dump = state.model_dump(mode="json") + loaded_stats = loaded_dump.pop("stats", None) + original_stats = original_dump.pop("stats", None) + assert loaded_dump == original_dump + if original_stats is not None: + assert loaded_stats is not None + loaded_metrics = loaded_stats.get("service_to_metrics", {}) + for key, metric in original_stats.get("service_to_metrics", {}).items(): + assert key in loaded_metrics + assert loaded_metrics[key] == metric # Also verify key fields are preserved assert loaded_state.id == state.id assert len(loaded_state.events) == len(state.events) +def test_conversation_state_profile_reference_mode(tmp_path, monkeypatch): + """When inline persistence is disabled we store profile references.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + + registry = LLMRegistry() + llm = LLM(model="litellm_proxy/openai/gpt-5-mini", usage_id="agent") + registry.save_profile("profile-tests", llm) + + agent = Agent(llm=registry.load_profile("profile-tests"), tools=[]) + conv_id = uuid.UUID("12345678-1234-5678-9abc-1234567890ff") + 
persistence_root = tmp_path / "conv" + persistence_dir = LocalConversation.get_persistence_dir(persistence_root, conv_id) + + ConversationState.create( + workspace=LocalWorkspace(working_dir="/tmp"), + persistence_dir=persistence_dir, + agent=agent, + id=conv_id, + llm_registry=registry, + ) + + base_state = json.loads((Path(persistence_dir) / "base_state.json").read_text()) + assert base_state["agent"]["llm"] == {"profile_id": "profile-tests"} + + conversation = Conversation( + agent=agent, + persistence_dir=persistence_root, + workspace=LocalWorkspace(working_dir="/tmp"), + conversation_id=conv_id, + ) + + loaded_state = conversation.state + assert loaded_state.agent.llm.profile_id == "profile-tests" + assert loaded_state.agent.llm.model == llm.model + + +def test_conversation_state_persists_profile_reference_by_default( + tmp_path, monkeypatch +): + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + + registry = LLMRegistry() + llm = LLM(model="litellm_proxy/openai/gpt-5-mini", usage_id="agent") + registry.save_profile("profile-inline", llm) + agent = Agent(llm=registry.load_profile("profile-inline"), tools=[]) + + conv_id = uuid.UUID("12345678-1234-5678-9abc-1234567890aa") + persistence_root = tmp_path / "conv" + persistence_dir = LocalConversation.get_persistence_dir(persistence_root, conv_id) + + ConversationState.create( + workspace=LocalWorkspace(working_dir="/tmp"), + persistence_dir=persistence_dir, + agent=agent, + id=conv_id, + llm_registry=registry, + ) + + conversation = Conversation( + agent=agent, + persistence_dir=persistence_root, + workspace=LocalWorkspace(working_dir="/tmp"), + conversation_id=conv_id, + ) + + assert conversation.state.agent.llm.profile_id == "profile-inline" + + def test_conversation_state_incremental_save(): """Test that ConversationState saves events incrementally.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -187,8 +277,18 @@ def test_conversation_state_incremental_save(): assert 
conversation.state.persistence_dir == persist_path_for_state loaded_state = conversation._state assert len(loaded_state.events) == 2 - # Test model_dump equality - assert loaded_state.model_dump(mode="json") == state.model_dump(mode="json") + # Test model_dump equality ignoring any additional runtime stats + loaded_dump = loaded_state.model_dump(mode="json") + original_dump = state.model_dump(mode="json") + loaded_stats = loaded_dump.pop("stats", None) + original_stats = original_dump.pop("stats", None) + assert loaded_dump == original_dump + if original_stats is not None: + assert loaded_stats is not None + loaded_metrics = loaded_stats.get("service_to_metrics", {}) + for key, metric in original_stats.get("service_to_metrics", {}).items(): + assert key in loaded_metrics + assert loaded_metrics[key] == metric def test_conversation_state_event_file_scanning(): @@ -490,42 +590,19 @@ def test_agent_verify_allows_different_llm(): def test_agent_verify_different_class_raises_error(): """Test that agent.verify() raises error for different agent classes.""" - from openhands.sdk.agent.base import AgentBase - from openhands.sdk.conversation.types import ( - ConversationCallbackType, - ConversationTokenCallbackType, - ) - class DifferentAgent(AgentBase): - def __init__(self): - llm = LLM( - model="gpt-4o-mini", - api_key=SecretStr("test-key"), - usage_id="test-llm", - ) - super().__init__(llm=llm, tools=[]) - - def init_state(self, state, on_event): - pass + from openhands.sdk.agent.base import AgentBase - def step( - self, - conversation, - on_event: ConversationCallbackType, - on_token: ConversationTokenCallbackType | None = None, - ): - pass + with pytest.raises(ValidationError, match="Unknown kind"): + AgentBase.model_validate({"kind": "NotARealAgent"}) - llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm") - original_agent = Agent(llm=llm, tools=[]) - different_agent = DifferentAgent() - with pytest.raises(ValueError, match="Cannot load 
from persisted"): - original_agent.verify(different_agent) +def test_conversation_state_flags_persistence(tmp_path, monkeypatch): + """Test that conversation state flags are properly persisted.""" + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) -def test_conversation_state_flags_persistence(): - """Test that conversation state flags are properly persisted.""" with tempfile.TemporaryDirectory() as temp_dir: llm = LLM( model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" @@ -535,11 +612,13 @@ def test_conversation_state_flags_persistence(): persist_path_for_state = LocalConversation.get_persistence_dir( temp_dir, conv_id ) + registry = LLMRegistry() state = ConversationState.create( workspace=LocalWorkspace(working_dir="/tmp"), persistence_dir=persist_path_for_state, agent=agent, id=conv_id, + llm_registry=registry, ) state.stats.register_llm(RegistryEvent(llm=llm)) @@ -555,6 +634,7 @@ def test_conversation_state_flags_persistence(): persistence_dir=persist_path_for_state, agent=agent, id=conv_id, + llm_registry=registry, ) # Verify key fields are preserved @@ -568,8 +648,12 @@ def test_conversation_state_flags_persistence(): assert loaded_state.model_dump(mode="json") == state.model_dump(mode="json") -def test_conversation_with_agent_different_llm_config(): +def test_conversation_with_agent_different_llm_config(tmp_path, monkeypatch): """Test conversation with agent having different LLM configuration.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + with tempfile.TemporaryDirectory() as temp_dir: # Create conversation with original LLM config original_llm = LLM( @@ -622,8 +706,12 @@ def test_conversation_with_agent_different_llm_config(): assert new_dump == original_state_dump -def test_resume_uses_runtime_workspace_and_max_iterations(): +def test_resume_uses_runtime_workspace_and_max_iterations(tmp_path, monkeypatch): """Test that resume uses runtime-provided workspace and 
max_iterations.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + with tempfile.TemporaryDirectory() as temp_dir: llm = LLM( model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" @@ -655,8 +743,14 @@ def test_resume_uses_runtime_workspace_and_max_iterations(): assert resumed_state.max_iterations == 200 -def test_resume_preserves_persisted_execution_status_and_stuck_detection(): +def test_resume_preserves_persisted_execution_status_and_stuck_detection( + tmp_path, monkeypatch +): """Test that resume preserves execution_status and stuck_detection.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + with tempfile.TemporaryDirectory() as temp_dir: llm = LLM( model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" @@ -686,8 +780,12 @@ def test_resume_preserves_persisted_execution_status_and_stuck_detection(): assert resumed_state.stuck_detection is False -def test_resume_preserves_blocked_actions_and_messages(): +def test_resume_preserves_blocked_actions_and_messages(tmp_path, monkeypatch): """Test that resume preserves blocked_actions and blocked_messages.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + with tempfile.TemporaryDirectory() as temp_dir: llm = LLM( model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" @@ -716,8 +814,12 @@ def test_resume_preserves_blocked_actions_and_messages(): assert resumed_state.blocked_messages["msg-1"] == "inappropriate content" -def test_conversation_state_stats_preserved_on_resume(): +def test_conversation_state_stats_preserved_on_resume(tmp_path, monkeypatch): """Regression: stats should not be reset when resuming a conversation.""" + + home_dir = tmp_path / "home" + monkeypatch.setenv("HOME", str(home_dir)) + with tempfile.TemporaryDirectory() as temp_dir: llm = LLM( model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" diff --git 
a/tests/sdk/llm/test_llm_registry_profiles.py b/tests/sdk/llm/test_llm_registry_profiles.py new file mode 100644 index 0000000000..b26bc33cae --- /dev/null +++ b/tests/sdk/llm/test_llm_registry_profiles.py @@ -0,0 +1,356 @@ +import json + +import pytest +from pydantic import SecretStr + +from openhands.sdk.llm.llm import LLM +from openhands.sdk.llm.llm_registry import LLMRegistry + + +def test_list_profiles_returns_sorted_names(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + (tmp_path / "b.json").write_text("{}", encoding="utf-8") + (tmp_path / "a.json").write_text("{}", encoding="utf-8") + + assert registry.list_profiles() == ["a", "b"] + + +def test_save_profile_includes_secret_fields_by_default(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + llm = LLM( + model="gpt-4o-mini", + usage_id="service", + api_key=SecretStr("secret"), + aws_access_key_id=SecretStr("id"), + aws_secret_access_key=SecretStr("value"), + ) + + path = registry.save_profile("sample", llm) + data = json.loads(path.read_text(encoding="utf-8")) + + assert data["profile_id"] == "sample" + assert data["usage_id"] == "service" + assert data["api_key"] == "secret" + assert data["aws_access_key_id"] == "id" + assert data["aws_secret_access_key"] == "value" + + +def test_save_profile_can_exclude_secret_fields(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + llm = LLM( + model="gpt-4o-mini", + usage_id="service", + api_key=SecretStr("secret"), + aws_access_key_id=SecretStr("id"), + aws_secret_access_key=SecretStr("value"), + ) + + path = registry.save_profile("sample", llm, include_secrets=False) + data = json.loads(path.read_text(encoding="utf-8")) + + assert "api_key" not in data + assert "aws_access_key_id" not in data + assert "aws_secret_access_key" not in data + + +def test_load_profile_assigns_profile_id_when_missing(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + profile_path = tmp_path / "foo.json" + profile_path.write_text( + json.dumps({"model": 
"gpt-4o-mini", "usage_id": "svc"}), + encoding="utf-8", + ) + + llm = registry.load_profile("foo") + + assert llm.profile_id == "foo" + assert llm.usage_id == "svc" + + +def test_load_profile_ignores_unknown_fields(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + profile_path = tmp_path / "legacy.json" + profile_path.write_text( + json.dumps( + { + "model": "gpt-4o-mini", + "usage_id": "svc", + "metadata": {"profile_description": "Legacy profile payload"}, + "unknown_field": 123, + } + ), + encoding="utf-8", + ) + + llm = registry.load_profile("legacy") + assert llm.usage_id == "svc" + + +def test_llm_serializer_emits_profile_reference_when_profile_id_present(): + llm = LLM(model="gpt-4o-mini", usage_id="service", profile_id="sample") + + payload = llm.model_dump(mode="json") + assert payload == {"profile_id": "sample"} + + +def test_llm_validator_loads_profile_reference(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + source_llm = LLM(model="gpt-4o-mini", usage_id="service") + registry.save_profile("profile-tests", source_llm) + + parsed = LLM.model_validate( + {"profile_id": "profile-tests"}, + context={"llm_registry": registry}, + ) + + assert parsed.model == source_llm.model + assert parsed.profile_id == "profile-tests" + assert parsed.usage_id == source_llm.usage_id + + +def test_validate_profile_reports_errors(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + + ok, errors = registry.validate_profile({"model": "gpt-4o-mini", "usage_id": "svc"}) + assert ok + assert errors == [] + + ok, errors = registry.validate_profile({"usage_id": "svc"}) + assert not ok + assert any("model" in message for message in errors) + + +def test_get_profile_path_rejects_traversal(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + with pytest.raises(ValueError): + registry.get_profile_path("../secret") + + +def test_load_profile_syncs_mismatched_profile_id(tmp_path): + """Test that load_profile syncs profile_id when file name differs from 
stored id.""" + registry = LLMRegistry(profile_dir=tmp_path) + profile_path = tmp_path / "correct-name.json" + profile_path.write_text( + json.dumps( + { + "model": "gpt-4o-mini", + "usage_id": "svc", + "profile_id": "wrong-name", # Mismatched with filename + } + ), + encoding="utf-8", + ) + + llm = registry.load_profile("correct-name") + + # Should use filename as authoritative profile_id + assert llm.profile_id == "correct-name" + assert llm.usage_id == "svc" + + +def test_profile_id_validation_rejects_invalid_characters(tmp_path): + """Test that profile IDs with invalid characters are rejected.""" + registry = LLMRegistry(profile_dir=tmp_path) + llm = LLM(model="gpt-4o-mini", usage_id="svc") + + # Test various invalid profile IDs + invalid_ids = [ + "", # Empty string + ".", # Single dot + "..", # Double dot + "profile/with/slashes", # Path separators + "profile\\with\\backslashes", # Windows path separators + "profile with spaces", # Spaces (valid per pattern but let's test) + "profile@special!", # Special characters + ] + + for invalid_id in invalid_ids: + with pytest.raises(ValueError): + registry.save_profile(invalid_id, llm) + + +def test_ensure_default_profile_creates_and_overwrites_default(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + + llm_a = LLM(model="gpt-4o-mini", usage_id="svc", api_key=SecretStr("k1")) + profiled_a = registry.ensure_default_profile(llm_a) + assert profiled_a.profile_id == "default" + + path = registry.get_profile_path("default") + assert path.exists() + + llm_b = LLM(model="gpt-4o", usage_id="svc", api_key=SecretStr("k2")) + profiled_b = registry.ensure_default_profile(llm_b) + assert profiled_b.profile_id == "default" + + loaded = registry.load_profile("default") + assert loaded.model == "gpt-4o" + + +def test_profile_id_validation_accepts_valid_characters(tmp_path): + """Test that profile IDs with valid characters are accepted.""" + registry = LLMRegistry(profile_dir=tmp_path) + llm = LLM(model="gpt-4o-mini", 
usage_id="svc") + + # Test various valid profile IDs + valid_ids = [ + "simple", + "with-dashes", + "with_underscores", + "with.dots", + "Mixed123Case", + "all-valid_chars.123", + ] + + for valid_id in valid_ids: + path = registry.save_profile(valid_id, llm) + assert path.exists() + assert path.stem == valid_id + + +def test_llm_model_copy_updates_profile_id(): + """Test that LLM.model_copy can update profile_id.""" + original = LLM(model="gpt-4o-mini", usage_id="svc", profile_id="original") + + updated = original.model_copy(update={"profile_id": "updated"}) + + assert original.profile_id == "original" + assert updated.profile_id == "updated" + assert updated.model == original.model + assert updated.usage_id == original.usage_id + + +def test_load_profile_without_registry_context_requires_registry(tmp_path): + """Profile stubs always need a registry in validation context.""" + + registry = LLMRegistry(profile_dir=tmp_path) + llm = LLM(model="gpt-4o-mini", usage_id="svc") + registry.save_profile("test-profile", llm) + + with pytest.raises(ValueError, match="LLM registry required"): + LLM.model_validate({"profile_id": "test-profile"}) + + +def test_save_profile_sets_restrictive_permissions_on_create(tmp_path): + registry = LLMRegistry(profile_dir=tmp_path) + llm = LLM(model="gpt-4o-mini", usage_id="svc", api_key=SecretStr("k")) + + path = registry.save_profile("perm-test", llm) + + # On POSIX, expect 0o600. On platforms without chmod, at least ensure + # the file is not executable and is owner-readable/writable. 
+ mode = path.stat().st_mode + # Mask to permission bits + perm_bits = mode & 0o777 + assert (perm_bits & 0o111) == 0 # no execute bits + # Expect owner read/write + assert (perm_bits & 0o600) == 0o600 + + +def test_profile_directory_created_on_save_profile(tmp_path): + """Profile directory is created when saving profiles (not on init).""" + + profile_dir = tmp_path / "new" / "nested" / "dir" + assert not profile_dir.exists() + + registry = LLMRegistry(profile_dir=profile_dir) + assert registry.profile_dir == profile_dir + assert registry.list_profiles() == [] + assert not profile_dir.exists() + + llm = LLM(model="gpt-4o-mini", usage_id="svc") + registry.save_profile("dir-create-test", llm) + + assert profile_dir.exists() + assert profile_dir.is_dir() + + +def test_profile_id_preserved_through_serialization_roundtrip(tmp_path): + """Test that profile_id is preserved through save/load cycle.""" + llm = LLM(model="gpt-4o-mini", usage_id="svc", profile_id="test-profile") + + # Serialize + inline_data = llm.model_dump(mode="json") + assert inline_data == {"profile_id": "test-profile"} + + # Deserialize requires a registry (to expand profile) + registry = LLMRegistry(profile_dir=tmp_path) + registry.save_profile("test-profile", llm) + restored = LLM.model_validate(inline_data, context={"llm_registry": registry}) + assert restored.profile_id == "test-profile" + assert restored.model == "gpt-4o-mini" + + +def test_registry_list_usage_ids_after_multiple_adds(tmp_path): + """Test that list_usage_ids correctly tracks multiple LLM instances.""" + registry = LLMRegistry(profile_dir=tmp_path) + + llm1 = LLM(model="gpt-4o-mini", usage_id="service-1") + llm2 = LLM(model="gpt-4o", usage_id="service-2") + llm3 = LLM(model="claude-3-sonnet", usage_id="service-3") + + registry.add(llm1) + registry.add(llm2) + registry.add(llm3) + + usage_ids = registry.list_usage_ids() + assert len(usage_ids) == 3 + assert "service-1" in usage_ids + assert "service-2" in usage_ids + assert 
"service-3" in usage_ids + + +def test_save_profile_overwrites_existing_file(tmp_path): + """Test that saving a profile overwrites existing file with same name.""" + registry = LLMRegistry(profile_dir=tmp_path) + + # Save initial profile + llm1 = LLM(model="gpt-4o-mini", usage_id="original") + registry.save_profile("test", llm1) + + # Save updated profile with same name + llm2 = LLM(model="gpt-4o", usage_id="updated") + registry.save_profile("test", llm2) + + # Load and verify it's the updated version + loaded = registry.load_profile("test") + assert loaded.model == "gpt-4o" + assert loaded.usage_id == "updated" + + +def test_load_profile_not_found_raises_file_not_found_error(tmp_path): + """Test that loading non-existent profile raises FileNotFoundError.""" + registry = LLMRegistry(profile_dir=tmp_path) + + with pytest.raises(FileNotFoundError, match="Profile not found"): + registry.load_profile("nonexistent") + + +def test_registry_subscriber_notification_on_add(tmp_path): + """Test that registry notifies subscriber when LLM is added.""" + registry = LLMRegistry(profile_dir=tmp_path) + notifications = [] + + def subscriber(event): + notifications.append(event) + + registry.subscribe(subscriber) + + llm = LLM(model="gpt-4o-mini", usage_id="test") + registry.add(llm) + + assert len(notifications) == 1 + assert notifications[0].llm.model == "gpt-4o-mini" + assert notifications[0].llm.usage_id == "test" + + +def test_profile_serialization_mode_reference_only(tmp_path): + """Test that non-inline mode returns only profile_id reference.""" + llm = LLM(model="gpt-4o-mini", usage_id="svc", profile_id="ref-test") + + ref_data = llm.model_dump(mode="json") + + # Should only contain profile_id + assert ref_data == {"profile_id": "ref-test"} + assert "model" not in ref_data + assert "usage_id" not in ref_data diff --git a/tests/sdk/utils/test_discriminated_union.py b/tests/sdk/utils/test_discriminated_union.py index 07da9dcb83..f4608b824b 100644 --- 
a/tests/sdk/utils/test_discriminated_union.py +++ b/tests/sdk/utils/test_discriminated_union.py @@ -1,3 +1,4 @@ +import gc from abc import ABC, abstractmethod from typing import ClassVar @@ -228,7 +229,7 @@ def test_model_containing_polymorphic_field(): def test_duplicate_kind(): - # nAn error should be raised when a duplicate class name is detected + # An error should be raised when a duplicate class definition is detected. with pytest.raises(ValueError) as exc_info: @@ -246,6 +247,9 @@ class SomeImpl(SomeBase): ) assert expected in error_message + # Ensure the failed subclass definition does not leak into subsequent tests. + gc.collect() + def test_enhanced_error_message_with_validation(): """Test that the enhanced error message appears during model validation."""