diff --git a/examples/01_standalone_sdk/19_llm_routing.py b/examples/01_standalone_sdk/19_llm_routing.py
index 00d304a559..ba4b3fc056 100644
--- a/examples/01_standalone_sdk/19_llm_routing.py
+++ b/examples/01_standalone_sdk/19_llm_routing.py
@@ -1,4 +1,5 @@
 import os
+import uuid
 
 from pydantic import SecretStr
 
@@ -56,8 +57,14 @@ def conversation_callback(event: Event):
         llm_messages.append(event.to_llm_message())
 
 
+conversation_id = uuid.uuid4()
+
 conversation = Conversation(
-    agent=agent, callbacks=[conversation_callback], workspace=os.getcwd()
+    agent=agent,
+    callbacks=[conversation_callback],
+    conversation_id=conversation_id,
+    workspace=os.getcwd(),
+    persistence_dir="./.conversations",
 )
 
 conversation.send_message(
@@ -81,6 +88,24 @@ def conversation_callback(event: Event):
 )
 conversation.run()
 
+# Test conversation serialization
+print("Conversation finished. Got the following LLM messages:")
+for i, message in enumerate(llm_messages):
+    print(f"Message {i}: {str(message)[:200]}")
+
+print("Serializing conversation...")
+
+del conversation
+
+print("Deserializing conversation...")
+
+conversation = Conversation(
+    agent=agent,
+    callbacks=[conversation_callback],
+    persistence_dir="./.conversations",
+    conversation_id=conversation_id,
+)
+
 conversation.send_message(
     message=Message(
         role="user",
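Review note: the change above relies on the SDK writing conversation state under persistence_dir as the run progresses, so "serializing" is really just dropping the in-memory object and reopening by id. A minimal sketch of the reopen pattern, assuming `agent` is already in scope and the paths/ids are illustrative:

import uuid

from openhands.sdk import Conversation

conversation_id = uuid.uuid4()  # fixed id so the stored state can be found again
# ... run a conversation created with persistence_dir="./.conversations" ...
# Recreating with the same id + persistence_dir replays the persisted events:
conversation = Conversation(
    agent=agent,
    conversation_id=conversation_id,
    persistence_dir="./.conversations",
)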
diff --git a/examples/01_standalone_sdk/25_llm_manual_switch.py b/examples/01_standalone_sdk/25_llm_manual_switch.py
new file mode 100644
index 0000000000..c01937f244
--- /dev/null
+++ b/examples/01_standalone_sdk/25_llm_manual_switch.py
@@ -0,0 +1,158 @@
+import os
+import uuid
+
+from pydantic import SecretStr
+
+from openhands.sdk import (
+    LLM,
+    Agent,
+    Conversation,
+    Event,
+    LLMConvertibleEvent,
+    Message,
+    TextContent,
+    get_logger,
+)
+from openhands.sdk.llm.router.impl.dynamic import DynamicRouter
+from openhands.tools.preset.default import get_default_tools
+
+
+logger = get_logger(__name__)
+
+# Configure initial LLM
+api_key = os.getenv("LLM_API_KEY")
+assert api_key is not None, "LLM_API_KEY environment variable is not set."
+
+# Create DynamicRouter with 2 initial LLMs
+claude_llm = LLM(
+    service_id="agent-initial",
+    model="litellm_proxy/anthropic/claude-sonnet-4-5-20250929",
+    base_url="https://llm-proxy.eval.all-hands.dev",
+    api_key=SecretStr(api_key),
+)
+
+gpt_5_llm = LLM(
+    service_id="gpt-5",
+    model="litellm_proxy/openai/gpt-5-2025-08-07",
+    base_url="https://llm-proxy.eval.all-hands.dev",
+    api_key=SecretStr(api_key),
+)
+
+dynamic_router = DynamicRouter(
+    service_id="dynamic-router",
+    llms_for_routing={
+        "primary": claude_llm,
+        "gpt-5": gpt_5_llm,
+    },  # primary is the default
+)
+
+# Tools
+cwd = os.getcwd()
+tools = get_default_tools()
+
+# Agent with dynamic router
+agent = Agent(llm=dynamic_router, tools=tools)
+
+llm_messages = []  # collect raw LLM messages
+
+
+def conversation_callback(event: Event):
+    if isinstance(event, LLMConvertibleEvent):
+        llm_messages.append(event.to_llm_message())
+
+
+# Set up conversation with persistence for serialization demo
+conversation_id = uuid.uuid4()
+
+conversation = Conversation(
+    agent=agent,
+    callbacks=[conversation_callback],
+    conversation_id=conversation_id,
+    workspace=os.getcwd(),
+    persistence_dir="./.conversations",
+)
+
+print(f"Starting with LLM: {dynamic_router.active_llm_identifier}")
+print(f"Available LLMs: {list(dynamic_router.llms_for_routing.keys())}")
+
+# First interaction with Claude - primary LLM
+conversation.send_message(
+    message=Message(
+        role="user",
+        content=[TextContent(text="Hi there!")],
+    )
+)
+conversation.run()
+
+print("=" * 50)
+print("Switching to GPT-5...")
+
+# Manually switch to GPT-5
+success = dynamic_router.switch_to_llm("gpt-5")
+print(f"GPT-5 switched successfully: {success}")
+print(f"Current LLM: {dynamic_router.active_llm_identifier}")
+
+# Interaction with GPT-5
+conversation.send_message(
+    message=Message(
+        role="user",
+        content=[TextContent(text="Who trained you as an LLM?")],
+    )
+)
+conversation.run()
+
+
+# Show current state before serialization
+print(f"Before serialization - Current LLM: {dynamic_router.active_llm_identifier}")
+print(f"Available LLMs: {list(dynamic_router.llms_for_routing.keys())}")
+
+# Delete conversation to simulate restart
+del conversation
+
+# Recreate conversation from persistence
+print("Recreating conversation from persistence...")
+conversation = Conversation(
+    agent=agent,
+    callbacks=[conversation_callback],
+    conversation_id=conversation_id,
+    persistence_dir="./.conversations",
+)
+
+print(f"After deserialization - Current LLM: {dynamic_router.active_llm_identifier}")
+assert dynamic_router.active_llm_identifier == "gpt-5"
+print(f"Available LLMs: {list(dynamic_router.llms_for_routing.keys())}")
+
+# Continue conversation after persistence
+conversation.send_message(
+    message=Message(
+        role="user",
+        content=[TextContent(text="What did we talk about earlier?")],
+    )
+)
+conversation.run()
+
+# Switch back to the primary model for a complex task
+print("Switching back to Claude for complex reasoning...")
+
+dynamic_router.switch_to_llm("primary")
+print(f"Switched to LLM: {dynamic_router.active_llm_identifier}")
+
+conversation.send_message(
+    message=Message(
+        role="user",
+        content=[
+            TextContent(
+                text="Explain the concept of dynamic programming in one sentence."
+            )
+        ],
+    )
+)
+conversation.run()
+
+print("Demonstrated persistence with LLM switching.")
+
+
+print("=" * 100)
+print("Conversation finished. Got the following LLM messages:")
+for i, message in enumerate(llm_messages):
+    print(f"Message {i}: {str(message)[:200]}")
diff --git a/openhands-agent-server/openhands/agent_server/conversation_service.py b/openhands-agent-server/openhands/agent_server/conversation_service.py
index 736dbd1de6..5276280058 100644
--- a/openhands-agent-server/openhands/agent_server/conversation_service.py
+++ b/openhands-agent-server/openhands/agent_server/conversation_service.py
@@ -20,7 +20,7 @@
 from openhands.agent_server.pub_sub import Subscriber
 from openhands.agent_server.server_details_router import update_last_execution_time
 from openhands.agent_server.utils import utc_now
-from openhands.sdk import LLM, Event, Message
+from openhands.sdk import Event, LLMBase, Message
 from openhands.sdk.conversation.state import AgentExecutionStatus, ConversationState
 
 
@@ -276,7 +276,7 @@ async def get_event_service(self, conversation_id: UUID) -> EventService | None:
         return self._event_services.get(conversation_id)
 
     async def generate_conversation_title(
-        self, conversation_id: UUID, max_length: int = 50, llm: LLM | None = None
+        self, conversation_id: UUID, max_length: int = 50, llm: LLMBase | None = None
     ) -> str | None:
         """Generate a title for the conversation using LLM."""
         if self._event_services is None:
diff --git a/openhands-agent-server/openhands/agent_server/event_service.py b/openhands-agent-server/openhands/agent_server/event_service.py
index 60cc92610e..a413afd9f7 100644
--- a/openhands-agent-server/openhands/agent_server/event_service.py
+++ b/openhands-agent-server/openhands/agent_server/event_service.py
@@ -11,7 +11,7 @@
 )
 from openhands.agent_server.pub_sub import PubSub, Subscriber
 from openhands.agent_server.utils import utc_now
-from openhands.sdk import LLM, Agent, Event, Message, get_logger
+from openhands.sdk import Agent, Event, LLMBase, Message, get_logger
 from openhands.sdk.conversation.impl.local_conversation import LocalConversation
 from openhands.sdk.conversation.secrets_manager import SecretValue
 from openhands.sdk.conversation.state import AgentExecutionStatus, ConversationState
@@ -255,7 +255,7 @@ async def close(self):
             loop.run_in_executor(None, self._conversation.close)
 
     async def generate_title(
-        self, llm: "LLM | None" = None, max_length: int = 50
+        self, llm: "LLMBase | None" = None, max_length: int = 50
     ) -> str:
         """Generate a title for the conversation.
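Review note: widening these server signatures from LLM to LLMBase means any LLM subclass, including routers, can drive title generation. A hedged sketch of the call shape, assuming `conversation` is in scope and the key/model strings are illustrative:

from pydantic import SecretStr

from openhands.sdk import LLM
from openhands.sdk.llm.router.impl.dynamic import DynamicRouter

primary = LLM(service_id="p", model="gpt-4o-mini", api_key=SecretStr("sk-..."))
router = DynamicRouter(service_id="r", llms_for_routing={"primary": primary})
# Anything typed `llm: LLMBase | None` now accepts either object:
title = conversation.generate_title(llm=router)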
diff --git a/openhands-agent-server/openhands/agent_server/models.py b/openhands-agent-server/openhands/agent_server/models.py
index c7ffeccc89..3653c1a372 100644
--- a/openhands-agent-server/openhands/agent_server/models.py
+++ b/openhands-agent-server/openhands/agent_server/models.py
@@ -7,7 +7,7 @@
 from pydantic import BaseModel, Field
 
 from openhands.agent_server.utils import utc_now
-from openhands.sdk import LLM, AgentBase, Event, ImageContent, Message, TextContent
+from openhands.sdk import AgentBase, Event, ImageContent, LLMBase, Message, TextContent
 from openhands.sdk.conversation.secret_source import SecretSource
 from openhands.sdk.conversation.state import AgentExecutionStatus, ConversationState
 from openhands.sdk.llm.utils.metrics import MetricsSnapshot
@@ -176,7 +176,7 @@ class GenerateTitleRequest(BaseModel):
     max_length: int = Field(
         default=50, ge=1, le=200, description="Maximum length of the generated title"
     )
-    llm: LLM | None = Field(
+    llm: LLMBase | None = Field(
         default=None, description="Optional LLM to use for title generation"
     )
diff --git a/openhands-sdk/openhands/sdk/__init__.py b/openhands-sdk/openhands/sdk/__init__.py
index 07ec5f48a1..b846361afc 100644
--- a/openhands-sdk/openhands/sdk/__init__.py
+++ b/openhands-sdk/openhands/sdk/__init__.py
@@ -19,6 +19,7 @@
 from openhands.sdk.llm import (
     LLM,
     ImageContent,
+    LLMBase,
     LLMRegistry,
     Message,
     RedactedThinkingBlock,
@@ -56,6 +57,7 @@
 __version__ = "0.0.0"  # fallback for editable/unbuilt environments
 
 __all__ = [
+    "LLMBase",
     "LLM",
     "LLMRegistry",
     "ConversationStats",
diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py
index c17bf01bdc..d81e55cecf 100644
--- a/openhands-sdk/openhands/sdk/agent/base.py
+++ b/openhands-sdk/openhands/sdk/agent/base.py
@@ -3,6 +3,7 @@
 import sys
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable
+from types import MappingProxyType
 from typing import TYPE_CHECKING, Any
 
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -11,7 +12,7 @@
 from openhands.sdk.context.agent_context import AgentContext
 from openhands.sdk.context.condenser import CondenserBase, LLMSummarizingCondenser
 from openhands.sdk.context.prompts.prompt import render_template
-from openhands.sdk.llm import LLM
+from openhands.sdk.llm import LLM, LLMBase
 from openhands.sdk.logger import get_logger
 from openhands.sdk.mcp import create_mcp_tools
 from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
@@ -37,7 +38,7 @@ class AgentBase(DiscriminatedUnionMixin, ABC):
         arbitrary_types_allowed=True,
     )
 
-    llm: LLM = Field(
+    llm: LLMBase = Field(
         ...,
         description="LLM configuration for the agent.",
         examples=[
@@ -377,7 +378,7 @@ def _walk(obj: object) -> Iterable[LLM]:
             return model_out
 
         # Built-in containers
-        if isinstance(obj, dict):
+        if isinstance(obj, dict) or isinstance(obj, MappingProxyType):
             dict_out: list[LLM] = []
             for k, v in obj.items():
                 dict_out.extend(_walk(k))
diff --git a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py
index a5d518494f..c08e7be30e 100644
--- a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py
+++ b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py
@@ -7,11 +7,11 @@
 from openhands.sdk.context.view import View
 from openhands.sdk.event.condenser import Condensation
 from openhands.sdk.event.llm_convertible import MessageEvent
-from openhands.sdk.llm import LLM, Message, TextContent
+from openhands.sdk.llm import LLMBase, Message, TextContent
 
 
 class LLMSummarizingCondenser(RollingCondenser):
-    llm: LLM
+    llm: LLMBase
     max_size: int = Field(default=120, gt=0)
     keep_first: int = Field(default=4, ge=0)
diff --git a/openhands-sdk/openhands/sdk/conversation/base.py b/openhands-sdk/openhands/sdk/conversation/base.py
index 59384dbfc9..a2bd079048 100644
--- a/openhands-sdk/openhands/sdk/conversation/base.py
+++ b/openhands-sdk/openhands/sdk/conversation/base.py
@@ -7,7 +7,7 @@
 from openhands.sdk.conversation.events_list_base import EventsListBase
 from openhands.sdk.conversation.secrets_manager import SecretValue
 from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID
-from openhands.sdk.llm.llm import LLM
+from openhands.sdk.llm.llm import LLMBase
 from openhands.sdk.llm.message import Message, content_to_str
 from openhands.sdk.security.confirmation_policy import (
     ConfirmationPolicyBase,
@@ -123,7 +123,7 @@ def update_secrets(self, secrets: Mapping[str, SecretValue]) -> None: ...
     def close(self) -> None: ...
 
     @abstractmethod
-    def generate_title(self, llm: LLM | None = None, max_length: int = 50) -> str:
+    def generate_title(self, llm: LLMBase | None = None, max_length: int = 50) -> str:
         """Generate a title for the conversation based on the first user message.
 
         Args:
diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py
index 976b9d254a..9dbf6a55ca 100644
--- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py
+++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py
@@ -19,7 +19,7 @@
     PauseEvent,
     UserRejectObservation,
 )
-from openhands.sdk.llm import LLM, Message, TextContent
+from openhands.sdk.llm import LLMBase, Message, TextContent
 from openhands.sdk.llm.llm_registry import LLMRegistry
 from openhands.sdk.logger import get_logger
 from openhands.sdk.security.confirmation_policy import (
@@ -366,7 +366,7 @@ def close(self) -> None:
             except Exception as e:
                 logger.warning(f"Error closing executor for tool '{tool.name}': {e}")
 
-    def generate_title(self, llm: LLM | None = None, max_length: int = 50) -> str:
+    def generate_title(self, llm: LLMBase | None = None, max_length: int = 50) -> str:
         """Generate a title for the conversation based on the first user message.
 
         Args:
diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py
index 10426ab73b..9e41cb316f 100644
--- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py
+++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py
@@ -26,7 +26,7 @@
     FULL_STATE_KEY,
     ConversationStateUpdateEvent,
 )
-from openhands.sdk.llm import LLM, Message, TextContent
+from openhands.sdk.llm import LLMBase, Message, TextContent
 from openhands.sdk.logger import get_logger
 from openhands.sdk.security.confirmation_policy import (
     ConfirmationPolicyBase,
@@ -601,7 +601,7 @@ def update_secrets(self, secrets: Mapping[str, SecretValue]) -> None:
             self._client, "POST", f"/api/conversations/{self._id}/secrets", json=payload
         )
 
-    def generate_title(self, llm: LLM | None = None, max_length: int = 50) -> str:
+    def generate_title(self, llm: LLMBase | None = None, max_length: int = 50) -> str:
         """Generate a title for the conversation based on the first user message.
 
         Args:
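Review note: the MappingProxyType branch added to AgentBase._walk matters because RouterLLM now freezes llms_for_routing into a read-only mapping (see router/base.py below), which iterates like a dict but fails an isinstance(obj, dict) check. A self-contained sketch of the traversal idea:

from types import MappingProxyType

def walk(obj):
    # Read-only mappings must be handled alongside plain dicts
    # when collecting nested LLMs from keys and values.
    if isinstance(obj, (dict, MappingProxyType)):
        for k, v in obj.items():
            yield from walk(k)
            yield from walk(v)
    else:
        yield obj

frozen = MappingProxyType({"primary": "llm-a", "fast": "llm-b"})
assert not isinstance(frozen, dict)  # why the extra branch is needed
assert list(walk(frozen)) == ["primary", "llm-a", "fast", "llm-b"]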
diff --git a/openhands-sdk/openhands/sdk/conversation/title_utils.py b/openhands-sdk/openhands/sdk/conversation/title_utils.py
index a971b1be1a..5f3b86558b 100644
--- a/openhands-sdk/openhands/sdk/conversation/title_utils.py
+++ b/openhands-sdk/openhands/sdk/conversation/title_utils.py
@@ -4,7 +4,7 @@
 
 from openhands.sdk.event import MessageEvent
 from openhands.sdk.event.base import Event
-from openhands.sdk.llm import LLM, Message, TextContent
+from openhands.sdk.llm import LLMBase, Message, TextContent
 from openhands.sdk.logger import get_logger
 
 
@@ -56,7 +56,9 @@ def extract_first_user_message(events: Sequence[Event]) -> str | None:
     return None
 
 
-def generate_title_with_llm(message: str, llm: LLM, max_length: int = 50) -> str | None:
+def generate_title_with_llm(
+    message: str, llm: LLMBase, max_length: int = 50
+) -> str | None:
     """Generate a conversation title using LLM.
 
     Args:
@@ -155,7 +157,7 @@ def generate_fallback_title(message: str, max_length: int = 50) -> str:
 
 
 def generate_conversation_title(
-    events: Sequence[Event], llm: LLM | None = None, max_length: int = 50
+    events: Sequence[Event], llm: LLMBase | None = None, max_length: int = 50
 ) -> str:
     """Generate a title for a conversation based on the first user message.
diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py
index fabed357d1..5c3124c95e 100644
--- a/openhands-sdk/openhands/sdk/llm/__init__.py
+++ b/openhands-sdk/openhands/sdk/llm/__init__.py
@@ -1,4 +1,4 @@
-from openhands.sdk.llm.llm import LLM
+from openhands.sdk.llm.llm import LLM, LLMBase
 from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
 from openhands.sdk.llm.llm_response import LLMResponse
 from openhands.sdk.llm.message import (
@@ -11,7 +11,12 @@
     ThinkingBlock,
     content_to_str,
 )
-from openhands.sdk.llm.router import RouterLLM
+from openhands.sdk.llm.router import (
+    DynamicRouter,
+    MultimodalRouter,
+    RandomRouter,
+    RouterLLM,
+)
 from openhands.sdk.llm.utils.metrics import Metrics, MetricsSnapshot
 from openhands.sdk.llm.utils.unverified_models import (
     UNVERIFIED_MODELS_EXCLUDING_BEDROCK,
@@ -22,9 +27,13 @@
 
 __all__ = [
     "LLMResponse",
    "LLMBase",
     "LLM",
     "LLMRegistry",
     "RouterLLM",
+    "RandomRouter",
+    "DynamicRouter",
+    "MultimodalRouter",
     "RegistryEvent",
     "Message",
     "MessageToolCall",
""" - if persisted.__class__ is not self.__class__: - raise ValueError( - f"Cannot resolve_diff_from_deserialized between {self.__class__} " - f"and {persisted.__class__}" - ) - - # Copy allowed fields from runtime llm into the persisted llm - llm_updates = {} - persisted_dump = persisted.model_dump(exclude_none=True) - for field in self.OVERRIDE_ON_SERIALIZE: - if field in persisted_dump.keys(): - llm_updates[field] = getattr(self, field) - if llm_updates: - reconciled = persisted.model_copy(update=llm_updates) - else: - reconciled = persisted - - if self.model_dump(exclude_none=True) != reconciled.model_dump( - exclude_none=True - ): - raise ValueError( - "The LLM provided is different from the one in persisted state.\n" - f"Diff: {pretty_pydantic_diff(self, reconciled)}" - ) - return reconciled + pass @staticmethod def is_context_window_exceeded_exception(exception: Exception) -> bool: @@ -1101,3 +980,146 @@ def is_context_window_exceeded_exception(exception: Exception) -> bool: # window exceeded error, we'll have to assume it's not and rely on the call-site # context to handle it appropriately. return False + + +class LLM(LLMBase): + def completion( + self, + messages: list[Message], + tools: Sequence[ToolBase] | None = None, + _return_metrics: bool = False, + add_security_risk_prediction: bool = False, + **kwargs, + ) -> LLMResponse: + # Check if streaming is requested + if kwargs.get("stream", False): + raise ValueError("Streaming is not supported") + + # 1) serialize messages + formatted_messages = self.format_messages_for_llm(messages) + + # 2) choose function-calling strategy + use_native_fc = self.is_function_calling_active() + original_fncall_msgs = copy.deepcopy(formatted_messages) + + # Convert Tool objects to ChatCompletionToolParam once here + cc_tools: list[ChatCompletionToolParam] = [] + if tools: + cc_tools = [ + t.to_openai_tool( + add_security_risk_prediction=add_security_risk_prediction + ) + for t in tools + ] + + use_mock_tools = self.should_mock_tool_calls(cc_tools) + if use_mock_tools: + logger.debug( + "LLM.completion: mocking function-calling via prompt " + f"for model {self.model}" + ) + formatted_messages, kwargs = self.pre_request_prompt_mock( + formatted_messages, cc_tools or [], kwargs + ) + + # 3) normalize provider params + # Only pass tools when native FC is active + kwargs["tools"] = cc_tools if (bool(cc_tools) and use_native_fc) else None + has_tools_flag = bool(cc_tools) and use_native_fc + # Behavior-preserving: delegate to select_chat_options + call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag) + + # 4) optional request logging context (kept small) + assert self._telemetry is not None + log_ctx = None + if self._telemetry.log_enabled: + log_ctx = { + "messages": formatted_messages[:], # already simple dicts + "tools": tools, + "kwargs": {k: v for k, v in call_kwargs.items()}, + "context_window": self.max_input_tokens, + } + if tools and not use_native_fc: + log_ctx["raw_messages"] = original_fncall_msgs + self._telemetry.on_request(log_ctx=log_ctx) + + # 5) do the call with retries + @self.retry_decorator( + num_retries=self.num_retries, + retry_exceptions=LLM_RETRY_EXCEPTIONS, + retry_min_wait=self.retry_min_wait, + retry_max_wait=self.retry_max_wait, + retry_multiplier=self.retry_multiplier, + retry_listener=self.retry_listener, + ) + def _one_attempt(**retry_kwargs) -> ModelResponse: + assert self._telemetry is not None + # Merge retry-modified kwargs (like temperature) with call_kwargs + final_kwargs = {**call_kwargs, 
**retry_kwargs} + resp = self._transport_call(messages=formatted_messages, **final_kwargs) + raw_resp: ModelResponse | None = None + if use_mock_tools: + raw_resp = copy.deepcopy(resp) + resp = self.post_response_prompt_mock( + resp, nonfncall_msgs=formatted_messages, tools=cc_tools + ) + # 6) telemetry + self._telemetry.on_response(resp, raw_resp=raw_resp) + + # Ensure at least one choice + if not resp.get("choices") or len(resp["choices"]) < 1: + raise LLMNoResponseError( + "Response choices is less than 1. Response: " + str(resp) + ) + + return resp + + try: + resp = _one_attempt() + + # Convert the first choice to an OpenHands Message + first_choice = resp["choices"][0] + message = Message.from_llm_chat_message(first_choice["message"]) + + # Get current metrics snapshot + metrics_snapshot = MetricsSnapshot( + model_name=self.metrics.model_name, + accumulated_cost=self.metrics.accumulated_cost, + max_budget_per_task=self.metrics.max_budget_per_task, + accumulated_token_usage=self.metrics.accumulated_token_usage, + ) + + # Create and return LLMResponse + return LLMResponse( + message=message, metrics=metrics_snapshot, raw_response=resp + ) + except Exception as e: + self._telemetry.on_error(e) + raise + + def resolve_diff_from_deserialized(self, persisted: LLMBase) -> LLMBase: + if persisted.__class__ is not self.__class__: + raise ValueError( + f"Cannot resolve_diff_from_deserialized between {self.__class__} " + f"and {persisted.__class__}" + ) + + # Copy allowed fields from runtime llm into the persisted llm + llm_updates = {} + persisted_dump = persisted.model_dump(exclude_none=True) + for field in self.OVERRIDE_ON_SERIALIZE: + if field in persisted_dump.keys(): + llm_updates[field] = getattr(self, field) + if llm_updates: + reconciled = persisted.model_copy(update=llm_updates) + else: + reconciled = persisted + + if self.model_dump(exclude_none=True) != reconciled.model_dump( + exclude_none=True + ): + raise ValueError( + "The LLM provided is different from the one in persisted state.\n" + f"Diff: {pretty_pydantic_diff(self, reconciled)}" + ) + return reconciled diff --git a/openhands-sdk/openhands/sdk/llm/message.py b/openhands-sdk/openhands/sdk/llm/message.py index 8ba235a557..b353686f8b 100644 --- a/openhands-sdk/openhands/sdk/llm/message.py +++ b/openhands-sdk/openhands/sdk/llm/message.py @@ -225,6 +225,7 @@ class Message(BaseModel): description="Intermediate reasoning/thinking content from reasoning models", ) # Anthropic-specific thinking blocks (not normalized by LiteLLM) + extended_thinking_enabled: bool = False thinking_blocks: Sequence[ThinkingBlock | RedactedThinkingBlock] = Field( default_factory=list, description="Raw Anthropic thinking blocks for extended thinking feature", @@ -296,7 +297,7 @@ def _list_serializer(self) -> dict[str, Any]: # Add thinking blocks first (for Anthropic extended thinking) # Only add thinking blocks for assistant messages - if self.role == "assistant": + if self.role == "assistant" and self.extended_thinking_enabled: thinking_blocks = list( self.thinking_blocks ) # Copy to avoid modifying original diff --git a/openhands-sdk/openhands/sdk/llm/router/__init__.py b/openhands-sdk/openhands/sdk/llm/router/__init__.py index 37e7baca4a..171020f694 100644 --- a/openhands-sdk/openhands/sdk/llm/router/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/router/__init__.py @@ -1,4 +1,5 @@ from openhands.sdk.llm.router.base import RouterLLM +from openhands.sdk.llm.router.impl.dynamic import DynamicRouter from openhands.sdk.llm.router.impl.multimodal import 
MultimodalRouter from openhands.sdk.llm.router.impl.random import RandomRouter @@ -7,4 +8,5 @@ "RouterLLM", "RandomRouter", "MultimodalRouter", + "DynamicRouter", ] diff --git a/openhands-sdk/openhands/sdk/llm/router/base.py b/openhands-sdk/openhands/sdk/llm/router/base.py index 7a1286ce13..917d0dc459 100644 --- a/openhands-sdk/openhands/sdk/llm/router/base.py +++ b/openhands-sdk/openhands/sdk/llm/router/base.py @@ -1,17 +1,20 @@ from abc import abstractmethod from collections.abc import Sequence +from types import MappingProxyType from pydantic import ( Field, + field_serializer, field_validator, model_validator, ) -from openhands.sdk.llm.llm import LLM +from openhands.sdk.llm.llm import LLM, LLMBase from openhands.sdk.llm.llm_response import LLMResponse from openhands.sdk.llm.message import Message from openhands.sdk.logger import get_logger from openhands.sdk.tool.tool import ToolBase +from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff logger = get_logger(__name__) @@ -30,11 +33,12 @@ class RouterLLM(LLM): """ router_name: str = Field(default="base_router", description="Name of the router") - llms_for_routing: dict[str, LLM] = Field( + llms_for_routing: dict[str, LLMBase] = Field( default_factory=dict ) # Mapping of LLM name to LLM instance for routing - active_llm: LLM | None = Field( - default=None, description="Currently selected LLM instance" + + active_llm_identifier: str | None = Field( + default=None, description="Currently selected LLM's identifier" ) @field_validator("llms_for_routing") @@ -46,11 +50,23 @@ def validate_llms_not_empty(cls, v): ) return v + @model_validator(mode="after") + def make_immutable(self): + object.__setattr__( + self, "llms_for_routing", MappingProxyType(self.llms_for_routing) + ) + return self + + @field_serializer("llms_for_routing") + def serialize_llms_for_routing(self, v): + # Convert MappingProxyType back to a serializable dict + return dict(v) + def completion( self, messages: list[Message], tools: Sequence[ToolBase] | None = None, - return_metrics: bool = False, + _return_metrics: bool = False, add_security_risk_prediction: bool = False, **kwargs, ) -> LLMResponse: @@ -59,16 +75,16 @@ def completion( underlying LLM based on the routing logic implemented in select_llm(). """ # Select appropriate LLM - selected_model = self.select_llm(messages) - self.active_llm = self.llms_for_routing[selected_model] + self.active_llm_identifier = self.select_llm(messages) + active_llm = self.llms_for_routing[self.active_llm_identifier] - logger.info(f"RouterLLM routing to {selected_model}...") + logger.info(f"RouterLLM routing to {self.active_llm_identifier}...") # Delegate to selected LLM - return self.active_llm.completion( + return active_llm.completion( messages=messages, tools=tools, - return_metrics=return_metrics, + _return_metrics=_return_metrics, add_security_risk_prediction=add_security_risk_prediction, **kwargs, ) @@ -112,3 +128,76 @@ def set_placeholder_model(cls, data): d["model"] = d.get("router_name", "router") return d + + def resolve_diff_from_deserialized(self, persisted: "LLMBase") -> "LLMBase": + """Resolve differences between a deserialized RouterLLM and the current + instance. + + This method handles the reconciliation of nested LLMs in llms_for_routing, + ensuring that secret fields (like api_key) are properly restored from the + runtime instance to the deserialized instance. 
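Review note: after this split, LLMBase carries all configuration and helpers while declaring completion and resolve_diff_from_deserialized abstract, and LLM is the concrete transport implementation. A hedged sketch of the contract any new subclass must satisfy; EchoLLM is hypothetical and the stub bodies are illustrative only:

from openhands.sdk.llm.llm import LLMBase

class EchoLLM(LLMBase):  # hypothetical subclass for illustration
    def completion(self, messages, tools=None, _return_metrics=False,
                   add_security_risk_prediction=False, **kwargs):
        # Real subclasses put their transport logic here and
        # must return an LLMResponse.
        raise NotImplementedError

    def resolve_diff_from_deserialized(self, persisted):
        # Real subclasses must re-apply OVERRIDE_ON_SERIALIZE fields
        # (e.g. api_key) from `self` onto `persisted` before comparing.
        return persisted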
diff --git a/openhands-sdk/openhands/sdk/llm/message.py b/openhands-sdk/openhands/sdk/llm/message.py
index 8ba235a557..b353686f8b 100644
--- a/openhands-sdk/openhands/sdk/llm/message.py
+++ b/openhands-sdk/openhands/sdk/llm/message.py
@@ -225,6 +225,7 @@ class Message(BaseModel):
         description="Intermediate reasoning/thinking content from reasoning models",
     )
     # Anthropic-specific thinking blocks (not normalized by LiteLLM)
+    extended_thinking_enabled: bool = False
     thinking_blocks: Sequence[ThinkingBlock | RedactedThinkingBlock] = Field(
         default_factory=list,
         description="Raw Anthropic thinking blocks for extended thinking feature",
@@ -296,7 +297,7 @@ def _list_serializer(self) -> dict[str, Any]:
 
         # Add thinking blocks first (for Anthropic extended thinking)
         # Only add thinking blocks for assistant messages
-        if self.role == "assistant":
+        if self.role == "assistant" and self.extended_thinking_enabled:
             thinking_blocks = list(
                 self.thinking_blocks
             )  # Copy to avoid modifying original
diff --git a/openhands-sdk/openhands/sdk/llm/router/__init__.py b/openhands-sdk/openhands/sdk/llm/router/__init__.py
index 37e7baca4a..171020f694 100644
--- a/openhands-sdk/openhands/sdk/llm/router/__init__.py
+++ b/openhands-sdk/openhands/sdk/llm/router/__init__.py
@@ -1,4 +1,5 @@
 from openhands.sdk.llm.router.base import RouterLLM
+from openhands.sdk.llm.router.impl.dynamic import DynamicRouter
 from openhands.sdk.llm.router.impl.multimodal import MultimodalRouter
 from openhands.sdk.llm.router.impl.random import RandomRouter
 
@@ -7,4 +8,5 @@
     "RouterLLM",
     "RandomRouter",
     "MultimodalRouter",
+    "DynamicRouter",
 ]
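Review note: the new extended_thinking_enabled flag (message.py above) gates thinking-block serialization, and format_messages_for_llm now sets it from the model's feature detection, so blocks are only emitted for models that support extended thinking. A small sketch of the observable default:

from openhands.sdk.llm.message import Message, TextContent

msg = Message(role="assistant", content=[TextContent(text="hi")])
assert msg.extended_thinking_enabled is False  # default: blocks are dropped
# _list_serializer only prepends thinking_blocks when the flag is True
# (see the `role == "assistant" and self.extended_thinking_enabled` guard).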
diff --git a/openhands-sdk/openhands/sdk/llm/router/base.py b/openhands-sdk/openhands/sdk/llm/router/base.py
index 7a1286ce13..917d0dc459 100644
--- a/openhands-sdk/openhands/sdk/llm/router/base.py
+++ b/openhands-sdk/openhands/sdk/llm/router/base.py
@@ -1,17 +1,20 @@
 from abc import abstractmethod
 from collections.abc import Sequence
+from types import MappingProxyType
 
 from pydantic import (
     Field,
+    field_serializer,
     field_validator,
     model_validator,
 )
 
-from openhands.sdk.llm.llm import LLM
+from openhands.sdk.llm.llm import LLM, LLMBase
 from openhands.sdk.llm.llm_response import LLMResponse
 from openhands.sdk.llm.message import Message
 from openhands.sdk.logger import get_logger
 from openhands.sdk.tool.tool import ToolBase
+from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff
 
 
 logger = get_logger(__name__)
@@ -30,11 +33,12 @@ class RouterLLM(LLM):
     """
 
     router_name: str = Field(default="base_router", description="Name of the router")
-    llms_for_routing: dict[str, LLM] = Field(
+    llms_for_routing: dict[str, LLMBase] = Field(
         default_factory=dict
     )  # Mapping of LLM name to LLM instance for routing
-    active_llm: LLM | None = Field(
-        default=None, description="Currently selected LLM instance"
+
+    active_llm_identifier: str | None = Field(
+        default=None, description="Currently selected LLM's identifier"
     )
 
     @field_validator("llms_for_routing")
@@ -46,11 +50,23 @@ def validate_llms_not_empty(cls, v):
             )
         return v
 
+    @model_validator(mode="after")
+    def make_immutable(self):
+        object.__setattr__(
+            self, "llms_for_routing", MappingProxyType(self.llms_for_routing)
+        )
+        return self
+
+    @field_serializer("llms_for_routing")
+    def serialize_llms_for_routing(self, v):
+        # Convert MappingProxyType back to a serializable dict
+        return dict(v)
+
     def completion(
         self,
         messages: list[Message],
         tools: Sequence[ToolBase] | None = None,
-        return_metrics: bool = False,
+        _return_metrics: bool = False,
         add_security_risk_prediction: bool = False,
         **kwargs,
     ) -> LLMResponse:
@@ -59,16 +75,16 @@ def completion(
         underlying LLM based on the routing logic implemented in select_llm().
         """
         # Select appropriate LLM
-        selected_model = self.select_llm(messages)
-        self.active_llm = self.llms_for_routing[selected_model]
+        self.active_llm_identifier = self.select_llm(messages)
+        active_llm = self.llms_for_routing[self.active_llm_identifier]
 
-        logger.info(f"RouterLLM routing to {selected_model}...")
+        logger.info(f"RouterLLM routing to {self.active_llm_identifier}...")
 
         # Delegate to selected LLM
-        return self.active_llm.completion(
+        return active_llm.completion(
             messages=messages,
             tools=tools,
-            return_metrics=return_metrics,
+            _return_metrics=_return_metrics,
             add_security_risk_prediction=add_security_risk_prediction,
             **kwargs,
         )
@@ -112,3 +128,76 @@ def set_placeholder_model(cls, data):
                 d["model"] = d.get("router_name", "router")
 
         return d
+
+    def resolve_diff_from_deserialized(self, persisted: "LLMBase") -> "LLMBase":
+        """Resolve differences between a deserialized RouterLLM and the current
+        instance.
+
+        This method handles the reconciliation of nested LLMs in llms_for_routing,
+        ensuring that secret fields (like api_key) are properly restored from the
+        runtime instance to the deserialized instance.
+
+        Args:
+            persisted: The deserialized RouterLLM instance from persistence
+
+        Returns:
+            A new RouterLLM instance equivalent to `persisted` but with secrets
+            from the runtime instance properly restored in all nested LLMs
+
+        Raises:
+            ValueError: If the classes don't match or if reconciliation fails
+        """
+        # If persisted is not a RouterLLM at all, this is an incompatible state
+        if not isinstance(persisted, RouterLLM):
+            # Check if the persisted data even has the router fields
+            persisted_dict = persisted.model_dump()
+            if "llms_for_routing" not in persisted_dict:
+                raise ValueError(
+                    f"Cannot resolve_diff_from_deserialized: persisted LLM is not a "
+                    "RouterLLM and doesn't contain router data. Got "
+                    f"{persisted.__class__}"
+                )
+            # Try to reconstruct as the correct RouterLLM subclass
+            persisted = self.__class__.model_validate(persisted_dict)
+
+        # Check classes match exactly
+        if type(persisted) is not type(self):
+            raise ValueError(
+                f"Cannot resolve_diff_from_deserialized between {type(self)} "
+                f"and {type(persisted)}"
+            )
+
+        # Reconcile each nested LLM in llms_for_routing
+        reconciled_llms = {}
+        for name, persisted_llm in persisted.llms_for_routing.items():
+            if name not in self.llms_for_routing:
+                raise ValueError(
+                    f"LLM '{name}' found in persisted state but not in runtime router"
+                )
+            runtime_llm = self.llms_for_routing[name]
+            reconciled_llms[name] = runtime_llm.resolve_diff_from_deserialized(
+                persisted_llm
+            )
+
+        # Check for LLMs in runtime that aren't in persisted state
+        for name in self.llms_for_routing:
+            if name not in persisted.llms_for_routing:
+                raise ValueError(
+                    f"LLM '{name}' found in runtime router but not in persisted state"
+                )
+
+        # Create reconciled router with updated nested LLMs
+        # Note: active_llm_identifier is carried over from `persisted` as-is;
+        # only the nested LLMs need secret reconciliation
+        reconciled = persisted.model_copy(update={"llms_for_routing": reconciled_llms})
+
+        # Validate that the reconciled router matches the runtime router
+        runtime_dump = self.model_dump(exclude_none=True)
+        reconciled_dump = reconciled.model_dump(exclude_none=True)
+
+        if runtime_dump != reconciled_dump:
+            raise ValueError(
+                "The RouterLLM provided is different from the one in persisted state.\n"
+                f"Diff: {pretty_pydantic_diff(self, reconciled)}"
+            )
+
+        return reconciled
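Review note: RouterLLM now freezes llms_for_routing into a MappingProxyType after validation, so the set of routable LLMs is fixed at construction and mutation fails loudly. A hedged usage sketch (model strings and keys illustrative):

from pydantic import SecretStr

from openhands.sdk import LLM
from openhands.sdk.llm import DynamicRouter

router = DynamicRouter(
    service_id="r",
    llms_for_routing={
        "primary": LLM(service_id="p", model="gpt-4o-mini",
                       api_key=SecretStr("sk-...")),
    },
)
try:
    router.llms_for_routing["extra"] = None  # read-only view
except TypeError:
    pass  # MappingProxyType rejects item assignment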
diff --git a/openhands-sdk/openhands/sdk/llm/router/impl/dynamic.py b/openhands-sdk/openhands/sdk/llm/router/impl/dynamic.py
new file mode 100644
index 0000000000..acc74f6351
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/llm/router/impl/dynamic.py
@@ -0,0 +1,74 @@
+"""
+Dynamic Router implementation for OpenHands SDK.
+
+This router lets users switch manually between pre-configured LLMs,
+with full serialization/deserialization support.
+"""
+
+from pydantic import model_validator
+
+from openhands.sdk.llm.message import Message
+from openhands.sdk.llm.router.base import RouterLLM
+from openhands.sdk.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class DynamicRouter(RouterLLM):
+    """
+    A RouterLLM that supports manual LLM switching.
+    Users need to provide all LLMs they want to switch to at initialization.
+    """
+
+    PRIMARY_MODEL_KEY: str = "primary"
+
+    router_name: str = "dynamic_router"
+    manual_selection: str | None = None
+
+    def select_llm(self, messages: list[Message]) -> str:  # noqa: ARG002
+        """
+        Select the manually chosen LLM, falling back to the primary LLM.
+
+        Args:
+            messages: List of messages (not used in manual selection)
+
+        Returns:
+            Name of the selected LLM
+        """
+        if self.manual_selection:
+            return self.manual_selection
+
+        # Use the primary LLM if no manual selection
+        return self.PRIMARY_MODEL_KEY
+
+    def switch_to_llm(
+        self,
+        identifier: str,
+    ) -> bool:
+        """
+        Switch to an LLM by identifier.
+
+        Args:
+            identifier: Key of the LLM instance in llms_for_routing
+
+        Returns:
+            True if the switch was successful, False otherwise
+        """
+        if identifier not in self.llms_for_routing:
+            logger.warning(f"Failed to switch to LLM {identifier}: not found")
+            return False
+
+        self.manual_selection = identifier
+        self.active_llm_identifier = identifier
+        logger.info(f"Switched to existing LLM: {identifier}")
+        return True
+
+    @model_validator(mode="after")
+    def _validate_llms_for_routing(self) -> "DynamicRouter":
+        """Ensure required models are present in llms_for_routing."""
+        if self.PRIMARY_MODEL_KEY not in self.llms_for_routing:
+            raise ValueError(
+                f"Primary LLM key '{self.PRIMARY_MODEL_KEY}' not found"
+                " in llms_for_routing."
+            )
+        return self
diff --git a/openhands-tools/openhands/tools/preset/default.py b/openhands-tools/openhands/tools/preset/default.py
index 0bffe02071..a1e5bf7e06 100644
--- a/openhands-tools/openhands/tools/preset/default.py
+++ b/openhands-tools/openhands/tools/preset/default.py
@@ -5,7 +5,7 @@
     LLMSummarizingCondenser,
 )
 from openhands.sdk.context.condenser.base import CondenserBase
-from openhands.sdk.llm.llm import LLM
+from openhands.sdk.llm.llm import LLMBase
 from openhands.sdk.logger import get_logger
 from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
 from openhands.sdk.tool import Tool, register_tool
@@ -54,7 +54,7 @@ def get_default_tools(
     return tools
 
 
-def get_default_condenser(llm: LLM) -> CondenserBase:
+def get_default_condenser(llm: LLMBase) -> CondenserBase:
     # Create a condenser to manage the context. The condenser will automatically
     # truncate conversation history when it exceeds max_size, and replaces the dropped
     # events with an LLM-generated summary.
@@ -64,7 +64,7 @@ def get_default_condenser(llm: LLM) -> CondenserBase:
 
 
 def get_default_agent(
-    llm: LLM,
+    llm: LLMBase,
     cli_mode: bool = False,
     add_security_analyzer: bool = False,
 ) -> Agent:
diff --git a/openhands-tools/openhands/tools/preset/planning.py b/openhands-tools/openhands/tools/preset/planning.py
index 59a411dc12..86eba49a97 100644
--- a/openhands-tools/openhands/tools/preset/planning.py
+++ b/openhands-tools/openhands/tools/preset/planning.py
@@ -2,7 +2,7 @@
 
 from openhands.sdk import Agent
 from openhands.sdk.context.condenser import LLMSummarizingCondenser
-from openhands.sdk.llm.llm import LLM
+from openhands.sdk.llm.llm import LLMBase
 from openhands.sdk.logger import get_logger
 from openhands.sdk.tool import Tool, register_tool
 
@@ -119,7 +119,7 @@ def get_planning_tools() -> list[Tool]:
     ]
 
 
-def get_planning_condenser(llm: LLM) -> LLMSummarizingCondenser:
+def get_planning_condenser(llm: LLMBase) -> LLMSummarizingCondenser:
     """Get a condenser optimized for planning workflows.
 
     Args:
@@ -138,7 +138,7 @@ def get_planning_condenser(llm: LLM) -> LLMSummarizingCondenser:
 
 
 def get_planning_agent(
-    llm: LLM,
+    llm: LLMBase,
 ) -> Agent:
     """Get a configured planning agent.
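Review note: DynamicRouter requires the "primary" key at construction (the _validate_llms_for_routing validator above), which is the failure mode the validator guards against. A hedged sketch of what that looks like to callers; pydantic wraps the validator's ValueError in a ValidationError:

from pydantic import SecretStr, ValidationError

from openhands.sdk import LLM
from openhands.sdk.llm import DynamicRouter

llm = LLM(service_id="x", model="gpt-4o-mini", api_key=SecretStr("sk-..."))
try:
    DynamicRouter(service_id="r", llms_for_routing={"fast": llm})
except ValidationError:
    pass  # rejected: no "primary" entry in llms_for_routing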
diff --git a/tests/sdk/llm/test_dynamic_router.py b/tests/sdk/llm/test_dynamic_router.py
new file mode 100644
index 0000000000..d74be48672
--- /dev/null
+++ b/tests/sdk/llm/test_dynamic_router.py
@@ -0,0 +1,184 @@
+"""
+Tests for the DynamicRouter implementation.
+"""
+
+import json
+
+import pytest
+from pydantic import SecretStr
+
+from openhands.sdk.llm import LLM
+from openhands.sdk.llm.router.impl.dynamic import DynamicRouter
+
+
+class TestDynamicRouter:
+    """Test suite for DynamicRouter functionality."""
+
+    def test_initialization(self):
+        """Test basic router initialization."""
+        initial_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("test-key"), service_id="agent"
+        )
+
+        router = DynamicRouter(
+            service_id="test_router", llms_for_routing={"primary": initial_llm}
+        )
+
+        assert router.router_name == "dynamic_router"
+        assert router.manual_selection is None
+        assert len(router.llms_for_routing) == 1
+        assert "primary" in router.llms_for_routing
+
+    def test_default_selection(self):
+        """Test default LLM selection when no manual selection is set."""
+        initial_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("test-key"), service_id="agent"
+        )
+        second_llm = LLM(
+            model="gpt-4o", api_key=SecretStr("test-key2"), service_id="agent2"
+        )
+
+        router = DynamicRouter(
+            service_id="test_router",
+            llms_for_routing={"primary": initial_llm, "secondary": second_llm},
+        )
+
+        selected = router.select_llm([])
+        assert selected == "primary"
+
+    def test_manual_selection(self):
+        """Test manual LLM selection."""
+        llm1 = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("test-key1"), service_id="llm1"
+        )
+        llm2 = LLM(model="gpt-4o", api_key=SecretStr("test-key2"), service_id="llm2")
+
+        router = DynamicRouter(
+            service_id="test_router", llms_for_routing={"primary": llm1, "llm2": llm2}
+        )
+
+        # Test switching to existing LLM
+        success = router.switch_to_llm("llm2")
+        assert success is True
+        assert router.manual_selection == "llm2"
+        assert router.select_llm([]) == "llm2"
+
+    def test_switch_to_non_existent_model(self):
+        """Test that switching to a non-existent model fails."""
+        # Create with a minimal dummy LLM to satisfy base class validation
+        dummy_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("dummy"), service_id="dummy"
+        )
+        router = DynamicRouter(
+            service_id="test_router", llms_for_routing={"primary": dummy_llm}
+        )
+
+        success = router.switch_to_llm("invalid")
+        assert success is False
+        assert router.manual_selection is None
+        assert "invalid" not in router.llms_for_routing
+
+    def test_get_available_llms(self):
+        """Test getting list of available LLMs."""
+        initial_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("test-key"), service_id="agent"
+        )
+        claude_llm = LLM(
+            service_id="claude",
+            model="claude-3-5-sonnet-20241022",
+            api_key=SecretStr("claude-key"),
+        )
+
+        router = DynamicRouter(
+            service_id="test_router",
+            llms_for_routing={"primary": initial_llm, "claude": claude_llm},
+        )
+
+        # Both LLMs are pre-configured at construction time
+        available = router.llms_for_routing
+        assert available["primary"].model == "gpt-4o-mini"
+
+        # Switch to the other pre-configured LLM
+        router.switch_to_llm("claude")
+
+        available = router.llms_for_routing
+        assert len(available) == 2
+        assert available["primary"].model == "gpt-4o-mini"
+        assert available["claude"].model == "claude-3-5-sonnet-20241022"
+
+    def test_get_current_llm_name(self):
+        """Test getting current LLM name."""
+        dummy_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("dummy"), service_id="dummy"
+        )
+        claude = LLM(
+            service_id="claude",
+            model="claude-3-5-sonnet-20241022",
+            api_key=SecretStr("claude-key"),
+        )
+
+        router = DynamicRouter(
+            service_id="test_router",
+            llms_for_routing={"primary": dummy_llm, "claude": claude},
+        )
+
+        assert router.active_llm_identifier is None
+
+        router.switch_to_llm("claude")
+        assert router.active_llm_identifier == "claude"
+
+    def test_serialization_with_dynamic_llms(self):
+        """Test serialization of router state after manual switches."""
+        initial_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("test-key"), service_id="agent"
+        )
+        claude_llm = LLM(
+            service_id="claude",
+            model="claude-3-5-sonnet-20241022",
+            api_key=SecretStr("claude-key"),
+        )
+        gemini_llm = LLM(
+            service_id="gemini",
+            model="gemini-1.5-pro",
+            api_key=SecretStr("gemini-key"),
+        )
+
+        router = DynamicRouter(
+            service_id="test_router",
+            llms_for_routing={
+                "primary": initial_llm,
+                "claude": claude_llm,
+                "gemini": gemini_llm,
+            },
+        )
+
+        # Switch between the pre-configured LLMs
+        router.switch_to_llm("claude")
+        router.switch_to_llm("gemini")
+
+        # Serialize
+        serialized = router.model_dump_json(exclude_none=True)
+        data = json.loads(serialized)
+
+        assert data["manual_selection"] == "gemini"  # Last selected
+
+    def test_manually_modify_llms_for_routing_raise_error(self):
+        """Test that manually modifying llms_for_routing is not allowed."""
+        initial_llm = LLM(
+            model="gpt-4o-mini", api_key=SecretStr("test-key"), service_id="agent"
+        )
+        router = DynamicRouter(
+            service_id="test_router", llms_for_routing={"primary": initial_llm}
+        )
+        with pytest.raises(TypeError):
+            router.llms_for_routing["new_llm"] = LLM(
+                model="gpt-4o", api_key=SecretStr("test-key2"), service_id="agent2"
+            )
diff --git a/tests/sdk/llm/test_thinking_blocks.py b/tests/sdk/llm/test_thinking_blocks.py
index 023f3d96b9..8a50a08326 100644
--- a/tests/sdk/llm/test_thinking_blocks.py
+++ b/tests/sdk/llm/test_thinking_blocks.py
@@ -173,6 +173,7 @@ def test_message_list_serializer_with_thinking_blocks():
         role="assistant",
         content=[TextContent(text="The answer is 42.")],
         thinking_blocks=[thinking_block],
+        extended_thinking_enabled=True,
     )
 
     serialized = message._list_serializer()
@@ -242,6 +243,7 @@ def test_multiple_thinking_blocks():
         role="assistant",
         content=[TextContent(text="Conclusion")],
         thinking_blocks=thinking_blocks,
+        extended_thinking_enabled=True,
    )
 
     assert len(message.thinking_blocks) == 2
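Closing review note: because manual_selection is an ordinary pydantic field, the manually chosen LLM survives a serialize/deserialize cycle, which is exactly what 25_llm_manual_switch.py and test_serialization_with_dynamic_llms exercise. A compact sketch mirroring that test (keys and model strings illustrative):

import json

from pydantic import SecretStr

from openhands.sdk import LLM
from openhands.sdk.llm import DynamicRouter

router = DynamicRouter(
    service_id="r",
    llms_for_routing={
        "primary": LLM(service_id="a", model="gpt-4o-mini", api_key=SecretStr("k")),
        "fast": LLM(service_id="b", model="gpt-4o", api_key=SecretStr("k")),
    },
)
router.switch_to_llm("fast")
data = json.loads(router.model_dump_json(exclude_none=True))
assert data["manual_selection"] == "fast"  # the choice is part of the dump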