diff --git a/openhands-sdk/openhands/sdk/llm/capabilities.py b/openhands-sdk/openhands/sdk/llm/capabilities.py new file mode 100644 index 0000000000..d4d676a724 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/capabilities.py @@ -0,0 +1,272 @@ +"""Model capability detection and validation for LLM instances. + +This module extracts capability-related logic from the LLM class to improve +maintainability and testability. It handles: +- Model information lookup from litellm +- Context window validation +- Vision support detection +- Prompt caching support detection +- Responses API support detection +""" + +from __future__ import annotations + +import os +import warnings +from dataclasses import dataclass +from typing import Final + +from litellm.types.utils import ModelInfo +from litellm.utils import supports_vision +from pydantic import SecretStr + +from openhands.sdk.llm.exceptions import LLMContextWindowTooSmallError +from openhands.sdk.llm.utils.model_features import get_features +from openhands.sdk.llm.utils.model_info import get_litellm_model_info +from openhands.sdk.logger import get_logger + + +logger = get_logger(__name__) + +__all__ = ["CapabilitiesConfig", "LLMCapabilities", "ModelInfo"] + + +@dataclass(frozen=True, slots=True) +class CapabilitiesConfig: + """Configuration for LLMCapabilities initialization. + + Groups the parameters needed to detect model capabilities, keeping + the LLMCapabilities constructor signature stable as new fields are added. + """ + + model: str + model_canonical_name: str | None + base_url: str | None + api_key: SecretStr | str | None + disable_vision: bool + caching_prompt: bool + + +# Minimum context window size required for OpenHands to function properly. +# Based on typical usage: system prompt (~2k) + conversation history (~4k) +# + tool definitions (~2k) + working memory (~8k) = ~16k minimum. +MIN_CONTEXT_WINDOW_TOKENS: Final[int] = 16384 + +# Environment variable to override the minimum context window check +ENV_ALLOW_SHORT_CONTEXT_WINDOWS: Final[str] = "ALLOW_SHORT_CONTEXT_WINDOWS" + +# Default max output tokens when model info only provides 'max_tokens' (ambiguous). +# Some providers use 'max_tokens' for the total context window, not output limit. +# This cap prevents requesting output that exceeds the context window. +# 16384 is a safe default that works for most models (GPT-4o: 16k, Claude: 8k). +DEFAULT_MAX_OUTPUT_TOKENS_CAP: Final[int] = 16384 + +# Model-specific output token limits. +# These override litellm's model_info when a substring match is found. +# The limit is applied as an upper cap: if litellm reports a higher value, +# it's clamped down; if the model isn't in model_info at all, this value is used. +MODEL_OUTPUT_TOKEN_LIMITS: Final[dict[str, int]] = { + "claude-3-7-sonnet": 64000, + "claude-sonnet-4": 64000, + "kimi-k2-thinking": 64000, + "o3": 100000, +} + + +class LLMCapabilities: + """Detects and caches model capabilities. + + This class encapsulates capability detection for LLM models, including: + - Vision support + - Prompt caching support + - Responses API support + - Context window validation + - Auto-detection of token limits from model info + + It is initialized with model configuration and caches model info from litellm. + Token limits are auto-detected and exposed as ``detected_max_input_tokens`` + and ``detected_max_output_tokens``. The caller (LLM) owns the resolution of + user overrides vs detected values. + + Example: + >>> config = CapabilitiesConfig( + ... model="claude-sonnet-4-20250514", + ... 
model_canonical_name=None, + ... base_url=None, + ... api_key=SecretStr("key"), + ... disable_vision=False, + ... caching_prompt=True, + ... ) + >>> caps = LLMCapabilities(config) + >>> caps.vision_is_active() + True + >>> caps.is_caching_prompt_active() + True + """ + + def __init__(self, config: CapabilitiesConfig) -> None: + """Initialize capabilities detection. + + Args: + config: Configuration for capability detection. + """ + self._config = config + + # Auto-detected token limits (never user overrides) + self.detected_max_input_tokens: int | None = None + self.detected_max_output_tokens: int | None = None + + # Internal cache for model info + self._model_info: ModelInfo | None = None + + # Initialize model info and capabilities + self._init_model_info_and_caps() + + @property + def model(self) -> str: + """Return the model name.""" + return self._config.model + + @property + def model_name_for_capabilities(self) -> str: + """Return canonical name for capability lookups (e.g., vision support).""" + return self._config.model_canonical_name or self._config.model + + @property + def model_info(self) -> ModelInfo | None: + """Return the cached model info dictionary.""" + return self._model_info + + def _init_model_info_and_caps(self) -> None: + """Initialize model info and auto-detect token limits.""" + self._model_info = get_litellm_model_info( + secret_api_key=self._config.api_key, + base_url=self._config.base_url, + model=self.model_name_for_capabilities, + ) + + # Context window (max_input_tokens) + if self._model_info is not None and isinstance( + self._model_info.get("max_input_tokens"), int + ): + self.detected_max_input_tokens = self._model_info.get("max_input_tokens") + + # Validate context window size + self._validate_context_window_size() + + # Auto-detect max_output_tokens + self._auto_detect_max_output_tokens() + + def _auto_detect_max_output_tokens(self) -> None: + """Auto-detect max_output_tokens from model info.""" + model = self._config.model + + # 1. Check model-specific overrides (from MODEL_OUTPUT_TOKEN_LIMITS) + for model_prefix, limit in MODEL_OUTPUT_TOKEN_LIMITS.items(): + if model_prefix in model: + self.detected_max_output_tokens = limit + logger.debug( + "Setting max_output_tokens to %s for %s (model-specific limit)", + limit, + model, + ) + return + + # 2. Fall back to model_info detection + if self._model_info is not None: + if isinstance(self._model_info.get("max_output_tokens"), int): + self.detected_max_output_tokens = self._model_info.get( + "max_output_tokens" + ) + elif isinstance( + max_tokens_value := self._model_info.get("max_tokens"), int + ): + # 'max_tokens' is ambiguous: some providers use it for total + # context window, not output limit. Cap it to avoid requesting + # output that exceeds the context window. 
+ self.detected_max_output_tokens = min( + max_tokens_value, DEFAULT_MAX_OUTPUT_TOKENS_CAP + ) + if max_tokens_value > DEFAULT_MAX_OUTPUT_TOKENS_CAP: + logger.debug( + "Capping max_output_tokens from %s to %s for %s " + "(max_tokens may be context window, not output)", + max_tokens_value, + self.detected_max_output_tokens, + model, + ) + + def _validate_context_window_size(self) -> None: + """Validate that the context window is large enough for OpenHands.""" + # Allow override via environment variable + if os.environ.get(ENV_ALLOW_SHORT_CONTEXT_WINDOWS, "").lower() in ( + "true", + "1", + "yes", + ): + return + + # Unknown context window - cannot validate + if self.detected_max_input_tokens is None: + return + + # Check minimum requirement + if self.detected_max_input_tokens < MIN_CONTEXT_WINDOW_TOKENS: + raise LLMContextWindowTooSmallError( + self.detected_max_input_tokens, MIN_CONTEXT_WINDOW_TOKENS + ) + + def vision_is_active(self) -> bool: + """Check if vision is supported and enabled. + + Returns: + True if the model supports vision and it's not disabled. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return not self._config.disable_vision and self._supports_vision() + + def _supports_vision(self) -> bool: + """Check if the model supports vision capabilities. + + Returns: + True if model is vision capable. Returns False if model not + supported by litellm. + """ + # litellm.supports_vision currently returns False for 'openai/gpt-...' + # or 'anthropic/claude-...' (with prefixes) but model_info will have + # the correct value for some reason. + # Check both the full model name and the name after proxy prefix + model_for_caps = self.model_name_for_capabilities + return bool( + supports_vision(model_for_caps) + or supports_vision(model_for_caps.split("/")[-1]) + or ( + self._model_info is not None + and self._model_info.get("supports_vision", False) + ) + ) + + def is_caching_prompt_active(self) -> bool: + """Check if prompt caching is supported and enabled for current model. + + Returns: + True if prompt caching is supported and enabled for the given model. + """ + if not self._config.caching_prompt: + return False + # We don't need to look-up model_info, because + # only Anthropic models need explicit caching breakpoints + return ( + self._config.caching_prompt + and get_features(self.model_name_for_capabilities).supports_prompt_cache + ) + + def uses_responses_api(self) -> bool: + """Check if this model uses the OpenAI Responses API path. + + Returns: + True if the model should use the Responses API. 
+ """ + # by default, uses = supports + return get_features(self.model_name_for_capabilities).supports_responses_api diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index d9c1cfc3b7..cb72852d98 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -21,8 +21,12 @@ ) from pydantic.json_schema import SkipJsonSchema +from openhands.sdk.llm.capabilities import ( + CapabilitiesConfig, + LLMCapabilities, + ModelInfo, +) from openhands.sdk.llm.fallback_strategy import FallbackStrategy -from openhands.sdk.llm.utils.model_info import get_litellm_model_info from openhands.sdk.utils.deprecation import warn_deprecated from openhands.sdk.utils.pydantic_secrets import serialize_secret, validate_secret @@ -70,12 +74,10 @@ ) from litellm.utils import ( create_pretrained_tokenizer, - supports_vision, token_counter, ) from openhands.sdk.llm.exceptions import ( - LLMContextWindowTooSmallError, LLMNoResponseError, map_provider_exception, ) @@ -114,20 +116,6 @@ LLMNoResponseError, ) -# Minimum context window size required for OpenHands to function properly. -# Based on typical usage: system prompt (~2k) + conversation history (~4k) -# + tool definitions (~2k) + working memory (~8k) = ~16k minimum. -MIN_CONTEXT_WINDOW_TOKENS: Final[int] = 16384 - -# Environment variable to override the minimum context window check -ENV_ALLOW_SHORT_CONTEXT_WINDOWS: Final[str] = "ALLOW_SHORT_CONTEXT_WINDOWS" - -# Default max output tokens when model info only provides 'max_tokens' (ambiguous). -# Some providers use 'max_tokens' for the total context window, not output limit. -# This cap prevents requesting output that exceeds the context window. -# 16384 is a safe default that works for most models (GPT-4o: 16k, Claude: 8k). -DEFAULT_MAX_OUTPUT_TOKENS_CAP: Final[int] = 16384 - class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): """Language model interface for OpenHands agents. @@ -385,7 +373,8 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): ) _metrics: Metrics | None = PrivateAttr(default=None) # Runtime-only private attrs - _model_info: Any = PrivateAttr(default=None) + # set in _set_env_side_effects validator + _capabilities: LLMCapabilities = PrivateAttr(default=None) # type: ignore[assignment] _tokenizer: Any = PrivateAttr(default=None) _telemetry: Telemetry | None = PrivateAttr(default=None) _is_subscription: bool = PrivateAttr(default=False) @@ -482,8 +471,25 @@ def _set_env_side_effects(self): if self.custom_tokenizer: self._tokenizer = create_pretrained_tokenizer(self.custom_tokenizer) - # Capabilities + model info - self._init_model_info_and_caps() + # LLMCapabilities owns capability detection (vision, caching, token limits). + # It auto-detects token limits from model info. LLM owns the final resolved + # values: user override wins, otherwise use auto-detected. 
+ self._capabilities = LLMCapabilities( + CapabilitiesConfig( + model=self.model, + model_canonical_name=self.model_canonical_name, + base_url=self.base_url, + api_key=self.api_key, + disable_vision=self.disable_vision + if self.disable_vision is not None + else False, + caching_prompt=self.caching_prompt, + ) + ) + if self.max_input_tokens is None: + self.max_input_tokens = self._capabilities.detected_max_input_tokens + if self.max_output_tokens is None: + self.max_output_tokens = self._capabilities.detected_max_output_tokens logger.debug( f"LLM ready: model={self.model} base_url={self.base_url} " @@ -1074,151 +1080,40 @@ def _litellm_modify_params_ctx(self, flag: bool): litellm.modify_params = old # ========================================================================= - # Capabilities, formatting, and info + # Capabilities (delegated to LLMCapabilities) # ========================================================================= def _model_name_for_capabilities(self) -> str: """Return canonical name for capability lookups (e.g., vision support).""" - return self.model_canonical_name or self.model - - def _init_model_info_and_caps(self) -> None: - self._model_info = get_litellm_model_info( - secret_api_key=self.api_key, - base_url=self.base_url, - model=self._model_name_for_capabilities(), - ) - - # Context window and max_output_tokens - if ( - self.max_input_tokens is None - and self._model_info is not None - and isinstance(self._model_info.get("max_input_tokens"), int) - ): - self.max_input_tokens = self._model_info.get("max_input_tokens") - - # Validate context window size - self._validate_context_window_size() - - if self.max_output_tokens is None: - if any( - m in self.model - for m in [ - "claude-3-7-sonnet", - "claude-sonnet-4", - "kimi-k2-thinking", - ] - ): - self.max_output_tokens = ( - 64000 # practical cap (litellm may allow 128k with header) - ) - logger.debug( - f"Setting max_output_tokens to {self.max_output_tokens} " - f"for {self.model}" - ) - elif self._model_info is not None: - if isinstance(self._model_info.get("max_output_tokens"), int): - self.max_output_tokens = self._model_info.get("max_output_tokens") - elif isinstance(self._model_info.get("max_tokens"), int): - # 'max_tokens' is ambiguous: some providers use it for total - # context window, not output limit. Cap it to avoid requesting - # output that exceeds the context window. 
- max_tokens_value = self._model_info.get("max_tokens") - assert isinstance(max_tokens_value, int) # for type checker - self.max_output_tokens = min( - max_tokens_value, DEFAULT_MAX_OUTPUT_TOKENS_CAP - ) - if max_tokens_value > DEFAULT_MAX_OUTPUT_TOKENS_CAP: - logger.debug( - "Capping max_output_tokens from %s to %s for %s " - "(max_tokens may be context window, not output)", - max_tokens_value, - self.max_output_tokens, - self.model, - ) - - if "o3" in self.model: - o3_limit = 100000 - if self.max_output_tokens is None or self.max_output_tokens > o3_limit: - self.max_output_tokens = o3_limit - logger.debug( - "Clamping max_output_tokens to %s for %s", - self.max_output_tokens, - self.model, - ) - - def _validate_context_window_size(self) -> None: - """Validate that the context window is large enough for OpenHands.""" - # Allow override via environment variable - if os.environ.get(ENV_ALLOW_SHORT_CONTEXT_WINDOWS, "").lower() in ( - "true", - "1", - "yes", - ): - return - - # Unknown context window - cannot validate - if self.max_input_tokens is None: - return - - # Check minimum requirement - if self.max_input_tokens < MIN_CONTEXT_WINDOW_TOKENS: - raise LLMContextWindowTooSmallError( - self.max_input_tokens, MIN_CONTEXT_WINDOW_TOKENS - ) + return self._capabilities.model_name_for_capabilities def vision_is_active(self) -> bool: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return not self.disable_vision and self._supports_vision() - - def _supports_vision(self) -> bool: - """Acquire from litellm if model is vision capable. + """Check if vision is supported and enabled. Returns: - bool: True if model is vision capable. Return False if model not - supported by litellm. + True if the model supports vision and it's not disabled. """ - # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes) # noqa: E501 - # but model_info will have the correct value for some reason. - # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers # noqa: E501 - # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608 # noqa: E501 - # Check both the full model name and the name after proxy prefix for vision support # noqa: E501 - model_for_caps = self._model_name_for_capabilities() - return ( - supports_vision(model_for_caps) - or supports_vision(model_for_caps.split("/")[-1]) - or ( - self._model_info is not None - and self._model_info.get("supports_vision", False) - ) - or False # fallback to False if model_info is None - ) + return self._capabilities.vision_is_active() def is_caching_prompt_active(self) -> bool: """Check if prompt caching is supported and enabled for current model. Returns: - boolean: True if prompt caching is supported and enabled for the given - model. + True if prompt caching is supported and enabled for the given model. """ - if not self.caching_prompt: - return False - # We don't need to look-up model_info, because - # only Anthropic models need explicit caching breakpoints - return ( - self.caching_prompt - and get_features(self._model_name_for_capabilities()).supports_prompt_cache - ) + return self._capabilities.is_caching_prompt_active() def uses_responses_api(self) -> bool: - """Whether this model uses the OpenAI Responses API path.""" + """Whether this model uses the OpenAI Responses API path. 
- # by default, uses = supports - return get_features(self._model_name_for_capabilities()).supports_responses_api + Returns: + True if the model should use the Responses API. + """ + return self._capabilities.uses_responses_api() @property - def model_info(self) -> dict | None: + def model_info(self) -> ModelInfo | None: """Returns the model info dictionary.""" - return self._model_info + return self._capabilities.model_info # ========================================================================= # Utilities preserved from previous class diff --git a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py index 475e403639..0c431fc0d3 100644 --- a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py +++ b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py @@ -3,6 +3,7 @@ import pytest from litellm.types.utils import ModelResponse +from pydantic import SecretStr from openhands.sdk.context.condenser.base import ( CondensationRequirement, @@ -76,6 +77,14 @@ def create_completion_result(content: str) -> LLMResponse: mock_llm.input_cost_per_token = None mock_llm.output_cost_per_token = None + # Attributes required by LLMCapabilities (created in _set_env_side_effects) + mock_llm.model_canonical_name = None + mock_llm.api_key = SecretStr("test-key") + mock_llm.disable_vision = False + mock_llm.caching_prompt = False + mock_llm.max_input_tokens = 128000 + mock_llm.max_output_tokens = 4096 + mock_llm._metrics = None # Helper method to set mock response content diff --git a/tests/sdk/llm/test_capabilities.py b/tests/sdk/llm/test_capabilities.py new file mode 100644 index 0000000000..a9fe4825d4 --- /dev/null +++ b/tests/sdk/llm/test_capabilities.py @@ -0,0 +1,263 @@ +"""Tests for the LLMCapabilities class.""" + +from unittest.mock import patch + +import pytest +from pydantic import SecretStr + +from openhands.sdk.llm.capabilities import ( + DEFAULT_MAX_OUTPUT_TOKENS_CAP, + MIN_CONTEXT_WINDOW_TOKENS, + CapabilitiesConfig, + LLMCapabilities, +) +from openhands.sdk.llm.exceptions import LLMContextWindowTooSmallError + + +@pytest.fixture +def mock_model_info(): + """Default mock model info for testing.""" + return { + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "supports_vision": True, + } + + +@pytest.fixture +def base_config_kwargs(): + """Base kwargs for creating CapabilitiesConfig.""" + return { + "model": "claude-sonnet-4-20250514", + "model_canonical_name": None, + "base_url": None, + "api_key": SecretStr("test-key"), + "disable_vision": False, + "caching_prompt": True, + } + + +def _make_caps(kwargs: dict) -> LLMCapabilities: + """Helper to create LLMCapabilities from a kwargs dict.""" + return LLMCapabilities(CapabilitiesConfig(**kwargs)) + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_capabilities_initialization( + mock_get_info, mock_model_info, base_config_kwargs +): + """Test basic initialization of LLMCapabilities.""" + mock_get_info.return_value = mock_model_info + # Use a model that doesn't have special output token handling + base_config_kwargs["model"] = "gpt-4o" + + caps = _make_caps(base_config_kwargs) + + assert caps.model == "gpt-4o" + assert caps.model_name_for_capabilities == "gpt-4o" + assert caps.detected_max_input_tokens == 128000 + assert caps.detected_max_output_tokens == 16384 + assert caps.model_info == mock_model_info + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_model_canonical_name_override( + 
mock_get_info, mock_model_info, base_config_kwargs +): + """Test that model_canonical_name overrides model for capabilities.""" + mock_get_info.return_value = mock_model_info + base_config_kwargs["model_canonical_name"] = "anthropic/claude-sonnet-4" + + caps = _make_caps(base_config_kwargs) + + assert caps.model == "claude-sonnet-4-20250514" + assert caps.model_name_for_capabilities == "anthropic/claude-sonnet-4" + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_max_output_tokens_cap_from_max_tokens(mock_get_info, base_config_kwargs): + """Test that max_tokens is capped to DEFAULT_MAX_OUTPUT_TOKENS_CAP.""" + mock_get_info.return_value = { + "max_input_tokens": 200000, + "max_tokens": 200000, # Ambiguous - could be context window + } + # Use a model that doesn't have special output token handling + base_config_kwargs["model"] = "gpt-4o" + + caps = _make_caps(base_config_kwargs) + + # Should be capped to avoid exceeding context window + assert caps.detected_max_output_tokens == DEFAULT_MAX_OUTPUT_TOKENS_CAP + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_claude_extended_output_tokens(mock_get_info, base_config_kwargs): + """Test that Claude models get extended max_output_tokens.""" + mock_get_info.return_value = {"max_input_tokens": 200000} + base_config_kwargs["model"] = "claude-sonnet-4-20250514" + + caps = _make_caps(base_config_kwargs) + + assert caps.detected_max_output_tokens == 64000 + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_o3_output_tokens_clamped(mock_get_info, base_config_kwargs): + """Test that o3 models have output tokens clamped to 100k.""" + mock_get_info.return_value = { + "max_input_tokens": 200000, + "max_output_tokens": 200000, + } + base_config_kwargs["model"] = "o3-2025-04-16" + + caps = _make_caps(base_config_kwargs) + + assert caps.detected_max_output_tokens == 100000 + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.supports_vision", return_value=True) +def test_vision_is_active_when_supported(mock_sv, mock_get_info, base_config_kwargs): + """Test vision_is_active returns True when model supports vision.""" + mock_get_info.return_value = {"supports_vision": True} + base_config_kwargs["disable_vision"] = False + + caps = _make_caps(base_config_kwargs) + + assert caps.vision_is_active() is True + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.supports_vision", return_value=True) +def test_vision_is_active_when_disabled(mock_sv, mock_get_info, base_config_kwargs): + """Test vision_is_active returns False when disable_vision=True.""" + mock_get_info.return_value = {"supports_vision": True} + base_config_kwargs["disable_vision"] = True + + caps = _make_caps(base_config_kwargs) + + assert caps.vision_is_active() is False + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.supports_vision", return_value=False) +def test_vision_is_active_when_not_supported( + mock_sv, mock_get_info, base_config_kwargs +): + """Test vision_is_active returns False when model doesn't support vision.""" + mock_get_info.return_value = {"supports_vision": False} + base_config_kwargs["disable_vision"] = False + + caps = _make_caps(base_config_kwargs) + + assert caps.vision_is_active() is False + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_caching_prompt_active_for_claude(mock_get_info, 
base_config_kwargs): + """Test that caching is active for Claude models when enabled.""" + mock_get_info.return_value = {} + base_config_kwargs["model"] = "claude-3-5-sonnet" + base_config_kwargs["caching_prompt"] = True + + caps = _make_caps(base_config_kwargs) + + assert caps.is_caching_prompt_active() is True + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_caching_prompt_inactive_when_disabled(mock_get_info, base_config_kwargs): + """Test that caching is inactive when caching_prompt=False.""" + mock_get_info.return_value = {} + base_config_kwargs["model"] = "claude-3-5-sonnet" + base_config_kwargs["caching_prompt"] = False + + caps = _make_caps(base_config_kwargs) + + assert caps.is_caching_prompt_active() is False + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_caching_prompt_inactive_for_unsupported_model( + mock_get_info, base_config_kwargs +): + """Test that caching is inactive for models that don't support it.""" + mock_get_info.return_value = {} + base_config_kwargs["model"] = "gpt-4o" + base_config_kwargs["caching_prompt"] = True + + caps = _make_caps(base_config_kwargs) + + assert caps.is_caching_prompt_active() is False + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_uses_responses_api_for_gpt5(mock_get_info, base_config_kwargs): + """Test that GPT-5 models use the Responses API.""" + mock_get_info.return_value = {} + base_config_kwargs["model"] = "gpt-5.2" + + caps = _make_caps(base_config_kwargs) + + assert caps.uses_responses_api() is True + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_uses_responses_api_false_for_older_models(mock_get_info, base_config_kwargs): + """Test that older models don't use the Responses API.""" + mock_get_info.return_value = {} + base_config_kwargs["model"] = "gpt-4o" + + caps = _make_caps(base_config_kwargs) + + assert caps.uses_responses_api() is False + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_context_window_too_small_raises_error(mock_get_info, base_config_kwargs): + """Test that small context windows raise LLMContextWindowTooSmallError.""" + mock_get_info.return_value = {"max_input_tokens": 4096} + + with pytest.raises(LLMContextWindowTooSmallError) as exc_info: + _make_caps(base_config_kwargs) + + # Check the error message contains expected values + assert "4,096" in str(exc_info.value) + assert str(MIN_CONTEXT_WINDOW_TOKENS) in str(exc_info.value).replace(",", "") + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +@patch.dict("os.environ", {"ALLOW_SHORT_CONTEXT_WINDOWS": "true"}) +def test_context_window_check_can_be_bypassed(mock_get_info, base_config_kwargs): + """Test that context window check can be bypassed with env var.""" + mock_get_info.return_value = {"max_input_tokens": 4096} + + # Should not raise + caps = _make_caps(base_config_kwargs) + + assert caps.detected_max_input_tokens == 4096 + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_unknown_context_window_passes_validation(mock_get_info, base_config_kwargs): + """Test that unknown context window (None) doesn't fail validation.""" + mock_get_info.return_value = {} # No max_input_tokens + + # Should not raise + caps = _make_caps(base_config_kwargs) + + assert caps.detected_max_input_tokens is None + + +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") +def test_model_info_returns_cached_info( + mock_get_info, mock_model_info, base_config_kwargs +): + """Test 
that model_info property returns the cached model info.""" + mock_get_info.return_value = mock_model_info + + caps = _make_caps(base_config_kwargs) + + assert caps.model_info is mock_model_info + # Verify get_litellm_model_info was only called once + mock_get_info.assert_called_once() diff --git a/tests/sdk/llm/test_llm.py b/tests/sdk/llm/test_llm.py index 9711d3e681..15dac07b0a 100644 --- a/tests/sdk/llm/test_llm.py +++ b/tests/sdk/llm/test_llm.py @@ -979,11 +979,11 @@ def test_unmapped_model_with_logging_enabled(mock_transport): # Context Window Validation Tests -@patch("openhands.sdk.llm.llm.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") def test_llm_raises_error_on_small_context_window(mock_get_model_info): """Test that LLM raises error when context window is too small.""" + from openhands.sdk.llm.capabilities import MIN_CONTEXT_WINDOW_TOKENS from openhands.sdk.llm.exceptions import LLMContextWindowTooSmallError - from openhands.sdk.llm.llm import MIN_CONTEXT_WINDOW_TOKENS mock_get_model_info.return_value = {"max_input_tokens": 2048} @@ -999,12 +999,12 @@ def test_llm_raises_error_on_small_context_window(mock_get_model_info): assert "docs.openhands.dev" in str(exc_info.value) -@patch("openhands.sdk.llm.llm.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") def test_llm_respects_allow_short_context_windows_env_var(mock_get_model_info): """Test that ALLOW_SHORT_CONTEXT_WINDOWS env var bypasses validation.""" import os - from openhands.sdk.llm.llm import ENV_ALLOW_SHORT_CONTEXT_WINDOWS + from openhands.sdk.llm.capabilities import ENV_ALLOW_SHORT_CONTEXT_WINDOWS mock_get_model_info.return_value = {"max_input_tokens": 2048} @@ -1073,7 +1073,7 @@ def test_llm_reset_metrics(): # max_output_tokens Capping Tests -@patch("openhands.sdk.llm.llm.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") def test_max_output_tokens_capped_when_using_max_tokens_fallback(mock_get_model_info): """Test that max_output_tokens is capped when falling back to max_tokens. 
@@ -1083,7 +1083,7 @@ def test_max_output_tokens_capped_when_using_max_tokens_fallback(mock_get_model_ See: https://github.com/OpenHands/software-agent-sdk/issues/XXX """ - from openhands.sdk.llm.llm import DEFAULT_MAX_OUTPUT_TOKENS_CAP + from openhands.sdk.llm.capabilities import DEFAULT_MAX_OUTPUT_TOKENS_CAP # Simulate a model where max_tokens = context window (200k) but # max_output_tokens is not set @@ -1105,7 +1105,7 @@ def test_max_output_tokens_capped_when_using_max_tokens_fallback(mock_get_model_ assert llm.max_output_tokens < 200000 -@patch("openhands.sdk.llm.llm.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") def test_max_output_tokens_uses_actual_value_when_available(mock_get_model_info): """Test that actual max_output_tokens is used when available.""" # Simulate a model with proper max_output_tokens @@ -1125,10 +1125,10 @@ def test_max_output_tokens_uses_actual_value_when_available(mock_get_model_info) assert llm.max_output_tokens == 8192 -@patch("openhands.sdk.llm.llm.get_litellm_model_info") +@patch("openhands.sdk.llm.capabilities.get_litellm_model_info") def test_max_output_tokens_small_max_tokens_not_capped(mock_get_model_info): """Test that small max_tokens fallback is not unnecessarily capped.""" - from openhands.sdk.llm.llm import DEFAULT_MAX_OUTPUT_TOKENS_CAP + from openhands.sdk.llm.capabilities import DEFAULT_MAX_OUTPUT_TOKENS_CAP # Simulate a model where max_tokens is small (actual output limit) mock_get_model_info.return_value = { diff --git a/tests/sdk/llm/test_llm_completion.py b/tests/sdk/llm/test_llm_completion.py index de0f482816..aecd395c1a 100644 --- a/tests/sdk/llm/test_llm_completion.py +++ b/tests/sdk/llm/test_llm_completion.py @@ -408,11 +408,14 @@ def test_llm_token_counting_basic(default_config): def test_llm_model_info_initialization(default_config): - """Test model info initialization.""" + """Test model info initialization. + + Model info is initialized during LLM construction through LLMCapabilities. 
+ """ llm = default_config - # Model info initialization should complete without errors - llm._init_model_info_and_caps() + # Capabilities are initialized during construction + assert llm._capabilities is not None # Model info might be None for unknown models, which is fine assert llm.model_info is None or isinstance(llm.model_info, dict) diff --git a/tests/sdk/llm/test_llm_serialization.py b/tests/sdk/llm/test_llm_serialization.py index eeceef9723..b9961c66cc 100644 --- a/tests/sdk/llm/test_llm_serialization.py +++ b/tests/sdk/llm/test_llm_serialization.py @@ -106,14 +106,13 @@ def test_llm_private_attributes_not_serialized() -> None: llm = LLM(model="test-model", usage_id="test-llm") # Set private attributes (these would normally be set internally) - llm._model_info = {"some": "info"} llm._tokenizer = "mock-tokenizer" # Serialize to dict llm_dict = llm.model_dump() # Private attributes should not be present - assert "_model_info" not in llm_dict + assert "_capabilities" not in llm_dict assert "_tokenizer" not in llm_dict assert "_telemetry" not in llm_dict @@ -121,10 +120,10 @@ def test_llm_private_attributes_not_serialized() -> None: llm_json = llm.model_dump_json() deserialized_llm = LLM.model_validate_json(llm_json) - # Private attributes should have default values - # (LLM creates telemetry automatically) - assert deserialized_llm._model_info is None - assert deserialized_llm._tokenizer is None + # Private attributes should have default values or be re-initialized + # (LLM creates capabilities and telemetry automatically during validation) + assert deserialized_llm._capabilities is not None # Re-initialized on deserialize + assert deserialized_llm._tokenizer is None # Default is None assert deserialized_llm.native_tool_calling is True assert ( deserialized_llm._telemetry is not None diff --git a/tests/sdk/llm/test_model_canonical_name_resolution.py b/tests/sdk/llm/test_model_canonical_name_resolution.py index e3a2a7d05e..3bba277255 100644 --- a/tests/sdk/llm/test_model_canonical_name_resolution.py +++ b/tests/sdk/llm/test_model_canonical_name_resolution.py @@ -37,10 +37,14 @@ def fake_get_features(model: str): return DummyFeatures(model) monkeypatch.setattr( - "openhands.sdk.llm.llm.get_litellm_model_info", fake_get_model_info + "openhands.sdk.llm.capabilities.get_litellm_model_info", fake_get_model_info + ) + monkeypatch.setattr( + "openhands.sdk.llm.capabilities.supports_vision", fake_supports_vision + ) + monkeypatch.setattr( + "openhands.sdk.llm.capabilities.get_features", fake_get_features ) - monkeypatch.setattr("openhands.sdk.llm.llm.supports_vision", fake_supports_vision) - monkeypatch.setattr("openhands.sdk.llm.llm.get_features", fake_get_features) real_llm = LLM(model="openai/gpt-5-mini") proxy_llm = LLM( diff --git a/tests/sdk/llm/test_vision_support.py b/tests/sdk/llm/test_vision_support.py index e56f94dd65..c5370bf28c 100644 --- a/tests/sdk/llm/test_vision_support.py +++ b/tests/sdk/llm/test_vision_support.py @@ -78,10 +78,10 @@ def test_chat_serializes_images_when_vision_supported(model): @patch( - "openhands.sdk.llm.llm.get_litellm_model_info", + "openhands.sdk.llm.capabilities.get_litellm_model_info", return_value={"supports_vision": False}, ) -@patch("openhands.sdk.llm.llm.supports_vision", return_value=False) +@patch("openhands.sdk.llm.capabilities.supports_vision", return_value=False) def test_message_with_image_does_not_enable_vision_for_text_only_model( mock_sv, _mock_model_info ): @@ -147,10 +147,10 @@ def test_disable_vision_overrides_litellm_detection(): 
@patch( - "openhands.sdk.llm.llm.get_litellm_model_info", + "openhands.sdk.llm.capabilities.get_litellm_model_info", return_value={"supports_vision": False}, ) -@patch("openhands.sdk.llm.llm.supports_vision", return_value=False) +@patch("openhands.sdk.llm.capabilities.supports_vision", return_value=False) def test_message_with_image_in_responses_does_not_include_input_image( mock_sv, _mock_model_info ):
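
Usage sketch (illustrative, not part of the patch): the snippet below mirrors the LLMCapabilities class docstring and the test fixtures above, showing how the extracted module is exercised on its own after this refactor; the API key value is a placeholder.

    from pydantic import SecretStr

    from openhands.sdk.llm.capabilities import CapabilitiesConfig, LLMCapabilities

    config = CapabilitiesConfig(
        model="claude-sonnet-4-20250514",
        model_canonical_name=None,
        base_url=None,
        api_key=SecretStr("placeholder-key"),
        disable_vision=False,
        caching_prompt=True,
    )
    caps = LLMCapabilities(config)

    # Capability checks that LLM now delegates to this object.
    caps.vision_is_active()          # vision supported and not disabled
    caps.is_caching_prompt_active()  # prompt caching enabled and supported by the model
    caps.uses_responses_api()        # model features report Responses API support

    # Auto-detected token limits; LLM applies user overrides on top of these when set.
    caps.detected_max_input_tokens
    caps.detected_max_output_tokens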