diff --git a/examples/01_standalone_sdk/43_interrupt_example.py b/examples/01_standalone_sdk/43_interrupt_example.py new file mode 100644 index 0000000000..933dacfc32 --- /dev/null +++ b/examples/01_standalone_sdk/43_interrupt_example.py @@ -0,0 +1,173 @@ +"""Example: Interrupting agent execution with Ctrl+C. + +This example demonstrates how to use conversation.interrupt() to immediately +cancel an in-flight LLM call when the user presses Ctrl+C. + +Unlike pause(), which waits for the current LLM call to complete, +interrupt() cancels the call immediately by: +- Cancelling the async task running the LLM call +- Closing the HTTP connection +- Raising LLMCancelledError + +This is useful for: +- Long-running reasoning tasks that you want to stop immediately +- Expensive API calls you want to cancel to save costs +- Interactive applications where responsiveness is important + +Usage: + LLM_API_KEY=your_key python 43_interrupt_example.py + +Press Ctrl+C at any time to interrupt the agent. +""" + +import os +import signal +import sys +import threading +import time + +from openhands.sdk import LLM, Agent, Conversation, Tool +from openhands.tools.terminal import TerminalTool + + +PROMPT = """ +I need you to solve this complex logic puzzle step by step, showing your reasoning: + +There are 5 houses in a row, each a different color (Red, Green, Blue, Yellow, White). +Each house is occupied by a person of different nationality. +Each person has a different pet, drink, and cigarette brand. + +Clues: +1. The British person lives in the red house. +2. The Swedish person keeps dogs as pets. +3. The Danish person drinks tea. +4. The green house is on the left of the white house. +5. The green house's owner drinks coffee. +6. The person who smokes Pall Mall rears birds. +7. The owner of the yellow house smokes Dunhill. +8. The person living in the center house drinks milk. +9. The Norwegian lives in the first house. +10. The person who smokes Blend lives next to the one who keeps cats. +11. The person who keeps horses lives next to the one who smokes Dunhill. +12. The person who smokes Blue Master drinks beer. +13. The German smokes Prince. +14. The Norwegian lives next to the blue house. +15. The person who smokes Blend has a neighbor who drinks water. + +Question: Who owns the fish? + +Please solve this completely, showing your full reasoning process with all deductions. +After solving, create a file called 'puzzle_solution.txt' with your complete solution. +""" + + +def main(): + # Track timing + start_time: float | None = None + interrupt_time: float | None = None + + # Configure LLM - use gpt-5.2 for long reasoning tasks + # Falls back to environment variable model if gpt-5.2 not available + api_key = os.getenv("LLM_API_KEY") + if not api_key: + print("Error: LLM_API_KEY environment variable is not set.") + sys.exit(1) + + model = os.getenv("LLM_MODEL", "openai/gpt-5.2") + base_url = os.getenv("LLM_BASE_URL") + + print("=" * 70) + print("Interrupt Example - Press Ctrl+C to immediately stop the agent") + print("=" * 70) + print() + + llm = LLM( + usage_id="reasoning-agent", + model=model, + base_url=base_url, + api_key=api_key, + ) + + print(f"Using model: {model}") + print() + + # Create agent with minimal tools + agent = Agent( + llm=llm, + tools=[Tool(name=TerminalTool.name)], + ) + + conversation = Conversation(agent=agent, workspace=os.getcwd()) + + # Set up Ctrl+C handler + def signal_handler(_signum, _frame): + nonlocal interrupt_time + interrupt_time = time.time() + print("\n") + print("=" * 70) + print("Ctrl+C detected! Interrupting agent...") + print("=" * 70) + + # Call interrupt() - this immediately cancels any in-flight LLM call + conversation.interrupt() + + signal.signal(signal.SIGINT, signal_handler) + + # Send a task that requires long reasoning + print("Sending a complex reasoning task to the agent...") + print("(This task is designed to take a while - press Ctrl+C to interrupt)") + print() + + conversation.send_message(PROMPT) + print(f"Agent status: {conversation.state.execution_status}") + print() + + # Run in background thread so we can handle signals + def run_agent(): + conversation.run() + + start_time = time.time() + thread = threading.Thread(target=run_agent) + thread.start() + + print("Agent is working... (press Ctrl+C to interrupt)") + print() + + # Wait for thread to complete (either normally or via interrupt) + thread.join() + + end_time = time.time() + + # Report timing + print() + print("=" * 70) + print("Results") + print("=" * 70) + print() + print(f"Final status: {conversation.state.execution_status}") + print() + + if interrupt_time: + interrupt_latency = end_time - interrupt_time + total_time = end_time - start_time + print(f"Total time from start to stop: {total_time:.2f} seconds") + print(f"Time from Ctrl+C to full stop: {interrupt_latency:.3f} seconds") + print() + print("The agent was interrupted immediately!") + print("Without interrupt(), you would have had to wait for the full") + print("LLM response to complete before the agent would stop.") + else: + total_time = end_time - start_time + print(f"Total time: {total_time:.2f} seconds") + print("Agent completed normally (was not interrupted)") + + print() + + # Report cost + cost = llm.metrics.accumulated_cost + print(f"Accumulated cost: ${cost:.6f}") + print(f"EXAMPLE_COST: {cost}") + + +if __name__ == "__main__": + main() diff --git a/openhands-sdk/openhands/sdk/conversation/base.py b/openhands-sdk/openhands/sdk/conversation/base.py index bfb3aa366a..2217213f7a 100644 --- a/openhands-sdk/openhands/sdk/conversation/base.py +++ b/openhands-sdk/openhands/sdk/conversation/base.py @@ -193,6 +193,21 @@ def reject_pending_actions( @abstractmethod def pause(self) -> None: ... + @abstractmethod + def interrupt(self) -> None: + """Interrupt the agent immediately, cancelling any in-flight LLM calls. + + Unlike pause(), which waits for the current step to complete, + interrupt() attempts to cancel ongoing LLM calls immediately: + + - Streaming calls: Cancelled at the next chunk boundary (immediate) + - Non-streaming calls: The async task is cancelled, closing the HTTP connection + + This method is thread-safe and can be called from any thread. + After interruption, the conversation status is set to PAUSED. + """ + ... + @abstractmethod def update_secrets(self, secrets: Mapping[str, SecretValue]) -> None: ... diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 29c003b875..3203381c19 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -37,6 +37,7 @@ from openhands.sdk.hooks import HookConfig, HookEventProcessor, create_hook_callback from openhands.sdk.io import LocalFileStore from openhands.sdk.llm import LLM, Message, TextContent +from openhands.sdk.llm.exceptions import LLMCancelledError from openhands.sdk.llm.llm_profile_store import LLMProfileStore from openhands.sdk.llm.llm_registry import LLMRegistry from openhands.sdk.logger import get_logger @@ -680,6 +681,10 @@ def run(self) -> None: ) ) break + except LLMCancelledError: + # LLM call was cancelled via interrupt() - this is not an error + # Status is already set to PAUSED by interrupt() + logger.info("Agent step cancelled by interrupt") except Exception as e: self._state.execution_status = ConversationExecutionStatus.ERROR @@ -759,6 +764,36 @@ def pause(self) -> None: self._on_event(pause_event) logger.info("Agent execution pause requested") + def interrupt(self) -> None: + """Interrupt the agent immediately, cancelling any in-flight LLM calls. + + Unlike pause(), which waits for the current step to complete, + interrupt() cancels ongoing LLM calls immediately: + + - Streaming calls: Cancelled at the next chunk boundary (immediate) + - Non-streaming calls: The async task is cancelled, closing the HTTP connection + + This method is thread-safe and can be called from any thread. + After interruption, the conversation status is set to PAUSED. + """ + from openhands.sdk.event.user_action import InterruptEvent + + # Cancel all LLMs first (main agent LLM + any in registry) + self.agent.llm.cancel() + for llm in self.llm_registry.usage_to_llm.values(): + llm.cancel() + + # Set paused status + with self._state: + if self._state.execution_status in [ + ConversationExecutionStatus.IDLE, + ConversationExecutionStatus.RUNNING, + ]: + self._state.execution_status = ConversationExecutionStatus.PAUSED + interrupt_event = InterruptEvent() + self._on_event(interrupt_event) + logger.info("Agent execution interrupted") + def update_secrets(self, secrets: Mapping[str, SecretValue]) -> None: """Add secrets to the conversation's secret registry. diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py index f0f813432e..98fefe584a 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py @@ -1119,6 +1119,14 @@ def reject_pending_actions(self, reason: str = "User rejected the action") -> No def pause(self) -> None: _send_request(self._client, "POST", f"/api/conversations/{self._id}/pause") + def interrupt(self) -> None: + """Interrupt the conversation immediately, cancelling any in-flight LLM calls. + + For remote conversations, this sends an interrupt request to the server. + The server will cancel the current operation and set the status to paused. + """ + _send_request(self._client, "POST", f"/api/conversations/{self._id}/interrupt") + def update_secrets(self, secrets: Mapping[str, SecretValue]) -> None: # Convert SecretValue to strings for JSON serialization # SecretValue can be str or callable, we need to handle both diff --git a/openhands-sdk/openhands/sdk/event/__init__.py b/openhands-sdk/openhands/sdk/event/__init__.py index 27da310db4..bd27ad4767 100644 --- a/openhands-sdk/openhands/sdk/event/__init__.py +++ b/openhands-sdk/openhands/sdk/event/__init__.py @@ -19,7 +19,7 @@ ) from openhands.sdk.event.token import TokenEvent from openhands.sdk.event.types import EventID, ToolCallID -from openhands.sdk.event.user_action import PauseEvent +from openhands.sdk.event.user_action import InterruptEvent, PauseEvent __all__ = [ @@ -36,6 +36,7 @@ "UserRejectObservation", "RejectionSource", "PauseEvent", + "InterruptEvent", "Condensation", "CondensationRequest", "CondensationSummaryEvent", diff --git a/openhands-sdk/openhands/sdk/event/user_action.py b/openhands-sdk/openhands/sdk/event/user_action.py index 949ad4b711..9efa1b7d36 100644 --- a/openhands-sdk/openhands/sdk/event/user_action.py +++ b/openhands-sdk/openhands/sdk/event/user_action.py @@ -19,3 +19,28 @@ def visualize(self) -> Text: def __str__(self) -> str: """Plain text string representation for PauseEvent.""" return f"{self.__class__.__name__} ({self.source}): Agent execution paused" + + +class InterruptEvent(Event): + """Event indicating that the agent execution was interrupted. + + Unlike PauseEvent, InterruptEvent indicates that an in-flight LLM call + was cancelled. This provides immediate interruption rather than waiting + for the current step to complete. + """ + + source: SourceType = "user" + reason: str = "User requested interrupt" + + @property + def visualize(self) -> Text: + """Return Rich Text representation of this interrupt event.""" + content = Text() + content.append("Conversation Interrupted", style="bold red") + if self.reason != "User requested interrupt": + content.append(f" - {self.reason}", style="dim") + return content + + def __str__(self) -> str: + """Plain text string representation for InterruptEvent.""" + return f"{self.__class__.__name__} ({self.source}): {self.reason}" diff --git a/openhands-sdk/openhands/sdk/llm/exceptions/__init__.py b/openhands-sdk/openhands/sdk/llm/exceptions/__init__.py index 5659fc19fb..8b02ccef16 100644 --- a/openhands-sdk/openhands/sdk/llm/exceptions/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/exceptions/__init__.py @@ -6,6 +6,7 @@ FunctionCallValidationError, LLMAuthenticationError, LLMBadRequestError, + LLMCancelledError, LLMContextWindowExceedError, LLMContextWindowTooSmallError, LLMError, @@ -38,6 +39,7 @@ "LLMTimeoutError", "LLMServiceUnavailableError", "LLMBadRequestError", + "LLMCancelledError", "UserCancelledError", "OperationCancelled", # Helpers diff --git a/openhands-sdk/openhands/sdk/llm/exceptions/types.py b/openhands-sdk/openhands/sdk/llm/exceptions/types.py index 9fd5373fcd..48f877ab53 100644 --- a/openhands-sdk/openhands/sdk/llm/exceptions/types.py +++ b/openhands-sdk/openhands/sdk/llm/exceptions/types.py @@ -128,3 +128,17 @@ def __init__(self, message: str = "User cancelled the request") -> None: class OperationCancelled(Exception): def __init__(self, message: str = "Operation was cancelled") -> None: super().__init__(message) + + +class LLMCancelledError(Exception): + """Raised when an LLM call is cancelled by user interrupt. + + This exception is raised when `LLM.cancel()` is called during an in-flight + LLM request. For streaming calls, cancellation happens immediately at the + next chunk boundary. For non-streaming calls, cancellation stops the async + task and closes the HTTP connection. + """ + + def __init__(self, message: str = "LLM call was cancelled") -> None: + super().__init__(message) + self.message = message diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 90c80b9815..399bac71a3 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -44,7 +44,7 @@ ChatCompletionToolParam, CustomStreamWrapper, ResponseInputParam, - completion as litellm_completion, + acompletion as litellm_acompletion, ) from litellm.exceptions import ( APIConnectionError, @@ -53,8 +53,12 @@ ServiceUnavailableError, Timeout as LiteLLMTimeout, ) -from litellm.responses.main import responses as litellm_responses -from litellm.responses.streaming_iterator import SyncResponsesAPIStreamingIterator +from litellm.responses.main import ( + aresponses as litellm_aresponses, +) +from litellm.responses.streaming_iterator import ( + ResponsesAPIStreamingIterator, +) from litellm.types.llms.openai import ( OutputTextDeltaEvent, ReasoningSummaryTextDeltaEvent, @@ -85,6 +89,7 @@ from openhands.sdk.llm.message import ( Message, ) +from openhands.sdk.llm.mixins.async_cancellation import AsyncRunner from openhands.sdk.llm.mixins.non_native_fc import NonNativeToolCallingMixin from openhands.sdk.llm.options.chat_options import select_chat_options from openhands.sdk.llm.options.responses_options import select_responses_options @@ -134,8 +139,8 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): The LLM class provides a unified interface for interacting with various language models through the litellm library. It handles model configuration, - API authentication, - retry logic, and tool calling capabilities. + API authentication, retry logic, tool calling capabilities, and async + cancellation support. Example: >>> from openhands.sdk import LLM @@ -392,6 +397,9 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): _is_subscription: bool = PrivateAttr(default=False) _litellm_provider: str | None = PrivateAttr(default=None) + # Async runner for interruptible LLM calls + _async_runner: AsyncRunner | None = PrivateAttr(default=None) + model_config: ClassVar[ConfigDict] = ConfigDict( extra="ignore", arbitrary_types_allowed=True ) @@ -479,6 +487,10 @@ def _set_env_side_effects(self): metrics=self._metrics, ) + # Async runner for cancellable LLM calls + if self._async_runner is None: + self._async_runner = AsyncRunner(owner_id=self.usage_id) + # Tokenizer if self.custom_tokenizer: self._tokenizer = create_pretrained_tokenizer(self.custom_tokenizer) @@ -585,6 +597,95 @@ def reset_metrics(self) -> None: self._metrics = None self._telemetry = None + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> LLM: + """Custom deepcopy that handles unpicklable thread-local state. + + This method is required because the AsyncRunner contains threading + primitives that cannot be pickled or deepcopied. + + This method is invoked in two scenarios: + 1. When `copy.deepcopy(llm)` is called explicitly + 2. When `llm.model_copy(deep=True)` is called (Pydantic calls __deepcopy__ + internally for deep copies) + + The AsyncRunner is not deepcopyable. Instead, we create a new copy + with a fresh AsyncRunner that will be lazily initialized. + """ + # Create a shallow copy using pydantic's model_copy + new_llm = self.model_copy(deep=False) + + # Deep copy the copyable private attributes + if self._metrics is not None: + new_llm._metrics = copy.deepcopy(self._metrics, memo) + else: + new_llm._metrics = None + + if self._tokenizer is not None: + # Tokenizers may not be deepcopyable, create fresh + new_llm._tokenizer = None + + if self._telemetry is not None: + new_llm._telemetry = copy.deepcopy(self._telemetry, memo) + else: + new_llm._telemetry = None + + new_llm._is_subscription = self._is_subscription + new_llm._litellm_provider = self._litellm_provider + + # Create fresh AsyncRunner for the copy + new_llm._async_runner = AsyncRunner(owner_id=new_llm.usage_id) + + return new_llm + + # ========================================================================= + # Cancellation support (delegates to AsyncRunner) + # ========================================================================= + def cancel(self) -> None: + """Cancel any in-flight LLM call (best effort). + + This method cancels the current LLM call immediately. The cancellation + takes effect at the next await point. + + Thread-safe: can be called from any thread. After cancellation, + the LLM can be used for new calls. + + Example: + >>> # In another thread: + >>> llm.cancel() # Cancels the current LLM call + """ + if self._async_runner is not None: + self._async_runner.cancel() + + def is_cancelled(self) -> bool: + """Check if the current call has been cancelled. + + Returns: + True if there's a current call and it has been cancelled. + """ + if self._async_runner is not None: + return self._async_runner.is_cancelled() + return False + + def close(self) -> None: + """Stop the background event loop and cleanup resources. + + This method should be called when the LLM instance is no longer needed, + especially in long-running applications that create/destroy many LLM + instances to prevent thread leaks. + + After calling close(), the LLM can still be used - the event loop + will be lazily recreated on the next LLM call. + + Example: + >>> llm = LLM(model="gpt-4o") + >>> try: + ... response = llm.completion(messages=[...]) + ... finally: + ... llm.close() # Clean up background thread + """ + if self._async_runner is not None: + self._async_runner.close() + def _handle_error( self, error: Exception, @@ -862,88 +963,20 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: assert self._telemetry is not None self._telemetry.on_request(telemetry_ctx=telemetry_ctx) final_kwargs = {**call_kwargs, **retry_kwargs} - with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - typed_input: ResponseInputParam | str = ( - cast(ResponseInputParam, input_items) if input_items else "" - ) - api_key_value = self._get_litellm_api_key_value() - - ret = litellm_responses( - model=self.model, - input=typed_input, - instructions=instructions, - tools=resp_tools, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - **final_kwargs, - ) - if isinstance(ret, ResponsesAPIResponse): - if user_enable_streaming: - logger.warning( - "Responses streaming was requested, but the provider " - "returned a non-streaming response; no on_token deltas " - "will be emitted." - ) - self._telemetry.on_response(ret) - return ret - - # When stream=True, LiteLLM returns a streaming iterator rather than - # a single ResponsesAPIResponse. Drain the iterator and use the - # completed response. - if final_kwargs.get("stream", False): - if not isinstance(ret, SyncResponsesAPIStreamingIterator): - raise AssertionError( - f"Expected Responses stream iterator, got {type(ret)}" - ) - - stream_callback = on_token if user_enable_streaming else None - for event in ret: - if stream_callback is None: - continue - if isinstance( - event, - ( - OutputTextDeltaEvent, - RefusalDeltaEvent, - ReasoningSummaryTextDeltaEvent, - ), - ): - delta = event.delta - if delta: - stream_callback( - ModelResponseStream( - choices=[ - StreamingChoices( - delta=Delta(content=delta) - ) - ] - ) - ) - - completed_event = ret.completed_response - if completed_event is None: - raise LLMNoResponseError( - "Responses stream finished without a completed response" - ) - if not isinstance(completed_event, ResponseCompletedEvent): - raise LLMNoResponseError( - f"Unexpected completed event: {type(completed_event)}" - ) - - completed_resp = completed_event.response - self._telemetry.on_response(completed_resp) - return completed_resp - - raise AssertionError( - f"Expected ResponsesAPIResponse, got {type(ret)}" - ) + # Run async responses call in background event loop for cancellation + assert self._async_runner is not None + coro = self._async_responses_call( + typed_input=cast(ResponseInputParam, input_items) + if input_items + else "", + instructions=instructions, + resp_tools=resp_tools, + user_enable_streaming=user_enable_streaming, + on_token=on_token, + final_kwargs=final_kwargs, + ) + return self._async_runner.run(coro, "LLM responses call was cancelled") try: resp: ResponsesAPIResponse = _one_attempt() @@ -1013,6 +1046,30 @@ def _transport_call( on_token: TokenCallbackType | None = None, **kwargs, ) -> ModelResponse: + """Execute LLM completion call with cancellation support. + + This method runs the LLM call in a background async event loop, + allowing it to be cancelled via LLM.cancel(). The main thread + blocks waiting for the result. + """ + assert self._async_runner is not None + coro = self._async_transport_call( + messages=messages, + enable_streaming=enable_streaming, + on_token=on_token, + **kwargs, + ) + return self._async_runner.run(coro, "LLM completion call was cancelled") + + async def _async_transport_call( + self, + *, + messages: list[dict[str, Any]], + enable_streaming: bool = False, + on_token: TokenCallbackType | None = None, + **kwargs, + ) -> ModelResponse: + """Async implementation of transport call.""" # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): @@ -1040,8 +1097,8 @@ def _transport_call( ) api_key_value = self._get_litellm_api_key_value() - # Some providers need renames handled in _normalize_call_kwargs. - ret = litellm_completion( + # Use async completion for cancellation support + ret = await litellm_acompletion( model=self.model, api_key=api_key_value, api_base=self.base_url, @@ -1052,12 +1109,18 @@ def _transport_call( messages=messages, **kwargs, ) + if enable_streaming and on_token is not None: + # For streaming, iterate async and check for cancellation + # ret is CustomStreamWrapper when streaming assert isinstance(ret, CustomStreamWrapper) - chunks = [] - for chunk in ret: - on_token(chunk) - chunks.append(chunk) + chunks: list[ModelResponseStream] = [] + async for chunk in ret: + # on_token callback is sync, call it directly + # CustomStreamWrapper yields ModelResponseStream chunks + stream_chunk = cast(ModelResponseStream, chunk) + on_token(stream_chunk) + chunks.append(stream_chunk) ret = litellm.stream_chunk_builder(chunks, messages=messages) assert isinstance(ret, ModelResponse), ( @@ -1065,6 +1128,94 @@ def _transport_call( ) return ret + async def _async_responses_call( + self, + *, + typed_input: ResponseInputParam | str, + instructions: str | None, + resp_tools: list[Any] | None, + user_enable_streaming: bool, + on_token: TokenCallbackType | None, + final_kwargs: dict[str, Any], + ) -> ResponsesAPIResponse: + """Async implementation of responses call.""" + with self._litellm_modify_params_ctx(self.modify_params): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + api_key_value = self._get_litellm_api_key_value() + + ret = await litellm_aresponses( + model=self.model, + input=typed_input, + instructions=instructions, + tools=resp_tools, + api_key=api_key_value, + api_base=self.base_url, + api_version=self.api_version, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + **final_kwargs, + ) + + if isinstance(ret, ResponsesAPIResponse): + if user_enable_streaming: + logger.warning( + "Responses streaming was requested, but the provider " + "returned a non-streaming response; no on_token deltas " + "will be emitted." + ) + assert self._telemetry is not None + self._telemetry.on_response(ret) + return ret + + # When stream=True, LiteLLM returns an async streaming iterator + if final_kwargs.get("stream", False): + if not isinstance(ret, ResponsesAPIStreamingIterator): + raise AssertionError( + f"Expected async Responses stream iterator, got {type(ret)}" + ) + + stream_callback = on_token if user_enable_streaming else None + async for event in ret: + if stream_callback is None: + continue + if isinstance( + event, + ( + OutputTextDeltaEvent, + RefusalDeltaEvent, + ReasoningSummaryTextDeltaEvent, + ), + ): + delta = event.delta + if delta: + stream_callback( + ModelResponseStream( + choices=[ + StreamingChoices(delta=Delta(content=delta)) + ] + ) + ) + + completed_event = ret.completed_response + if completed_event is None: + raise LLMNoResponseError( + "Responses stream finished without a completed response" + ) + if not isinstance(completed_event, ResponseCompletedEvent): + raise LLMNoResponseError( + f"Unexpected completed event: {type(completed_event)}" + ) + + completed_resp = completed_event.response + + assert self._telemetry is not None + self._telemetry.on_response(completed_resp) + return completed_resp + + raise AssertionError(f"Expected ResponsesAPIResponse, got {type(ret)}") + @contextmanager def _litellm_modify_params_ctx(self, flag: bool): old = getattr(litellm, "modify_params", None) diff --git a/openhands-sdk/openhands/sdk/llm/mixins/async_cancellation.py b/openhands-sdk/openhands/sdk/llm/mixins/async_cancellation.py new file mode 100644 index 0000000000..73f4ac9b9f --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/mixins/async_cancellation.py @@ -0,0 +1,165 @@ +"""Async runner for interruptible LLM calls. + +This module provides the AsyncRunner class which manages a background +event loop for running async LLM calls that can be cancelled from any thread. +""" + +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Coroutine +from concurrent.futures import CancelledError as FutureCancelledError, Future +from typing import Any, TypeVar + +from openhands.sdk.llm.exceptions import LLMCancelledError +from openhands.sdk.logger import get_logger + + +logger = get_logger(__name__) + +T = TypeVar("T") + + +class AsyncRunner: + """Manages async execution with cancellation support. + + This class manages a background event loop in a daemon thread, allowing + synchronous callers to run async coroutines while supporting immediate + cancellation via cancel(). + + The event loop is created lazily on first use and can be cleaned up + via close(). After close(), the runner can still be used - the event + loop will be lazily recreated. + + Example: + >>> runner = AsyncRunner(owner_id="my-llm") + >>> result = runner.run(some_async_func(), "Call cancelled") + >>> # From another thread: + >>> runner.cancel() + """ + + def __init__(self, owner_id: str) -> None: + """Initialize the async runner. + + Args: + owner_id: Identifier for logging/debugging (e.g., LLM usage_id). + """ + self._owner_id = owner_id + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._future: Future[Any] | None = None + self._lock = threading.Lock() + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + """Lazily create background event loop thread. + + The event loop runs in a daemon thread and is used to execute async + coroutines. This allows synchronous callers to use async internally + while supporting immediate cancellation. + """ + if self._loop is None: + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread( + target=self._loop.run_forever, + daemon=True, + name=f"async-runner-{self._owner_id}", + ) + self._thread.start() + logger.debug(f"Started async runner thread for {self._owner_id}") + return self._loop + + def run(self, coro: Coroutine[Any, Any, T], cancel_message: str) -> T: + """Run an async coroutine with cancellation support. + + This method submits the coroutine to the background event loop and + blocks until completion. The call can be cancelled via cancel(). + + Args: + coro: The coroutine to execute. + cancel_message: Message for LLMCancelledError if cancelled. + + Returns: + The result of the coroutine. + + Raises: + LLMCancelledError: If cancelled via cancel(). + """ + loop = self._ensure_loop() + future: Future[T] = asyncio.run_coroutine_threadsafe(coro, loop) + + with self._lock: + self._future = future + + try: + return future.result() + except (asyncio.CancelledError, FutureCancelledError): + raise LLMCancelledError(cancel_message) + finally: + with self._lock: + self._future = None + + def cancel(self) -> None: + """Cancel any in-flight call (best effort). + + This method cancels the current call immediately. The cancellation + takes effect at the next await point. + + Thread-safe: can be called from any thread. After cancellation, + the runner can be used for new calls. + """ + with self._lock: + if self._future is not None: + logger.info(f"Cancelling call for {self._owner_id}") + self._future.cancel() + + def is_cancelled(self) -> bool: + """Check if the current call has been cancelled. + + Returns: + True if there's a current call and it has been cancelled. + """ + with self._lock: + if self._future is not None: + return self._future.cancelled() + return False + + def close(self) -> None: + """Stop the background event loop and cleanup resources. + + This method should be called when the runner is no longer needed, + especially in long-running applications to prevent thread leaks. + + After close(), the runner can still be used - the event loop + will be lazily recreated on the next run() call. + """ + # First, cancel any in-flight call (outside the lock to avoid + # holding the lock during join) + future_to_cancel: Future[Any] | None = None + loop_to_stop: asyncio.AbstractEventLoop | None = None + thread_to_join: threading.Thread | None = None + + with self._lock: + future_to_cancel = self._future + self._future = None + loop_to_stop = self._loop + self._loop = None + thread_to_join = self._thread + self._thread = None + + # Perform cleanup outside the lock + if future_to_cancel is not None: + future_to_cancel.cancel() + + if loop_to_stop is not None: + loop_to_stop.call_soon_threadsafe(loop_to_stop.stop) + + if thread_to_join is not None: + thread_to_join.join(timeout=2.0) + if thread_to_join.is_alive(): + logger.warning( + f"Async runner thread for {self._owner_id} did not stop " + "within timeout" + ) + else: + logger.debug(f"Stopped async runner thread for {self._owner_id}") diff --git a/tests/cross/test_agent_loading.py b/tests/cross/test_agent_loading.py index 07fa7e0f85..138d9aa6c7 100644 --- a/tests/cross/test_agent_loading.py +++ b/tests/cross/test_agent_loading.py @@ -422,7 +422,7 @@ def test_conversation_fails_when_agent_type_changes(): ) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_persistence_lifecycle(mock_completion): """Test full conversation persistence lifecycle similar to examples/10_persistence.py.""" # noqa: E501 from tests.conftest import create_mock_litellm_response diff --git a/tests/cross/test_conversation_restore_behavior.py b/tests/cross/test_conversation_restore_behavior.py index 05abb85fc4..bf7b946a78 100644 --- a/tests/cross/test_conversation_restore_behavior.py +++ b/tests/cross/test_conversation_restore_behavior.py @@ -133,7 +133,7 @@ def _agent( return agent_type(**agent_kwargs) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_lifecycle_happy_path(mock_completion): """Baseline: restore should load prior events and allow further execution.""" @@ -195,7 +195,7 @@ def capture_completion(*_args: Any, **kwargs: Any): restored.close() -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_fails_when_removing_tools(mock_completion): """Restore must fail when runtime tools remove a persisted tool.""" @@ -239,7 +239,7 @@ def test_conversation_restore_fails_when_removing_tools(mock_completion): assert "FileEditorTool" in str(exc.value) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_fails_when_adding_tools(mock_completion): """Restore must fail when runtime tools add a new tool.""" @@ -283,7 +283,7 @@ def test_conversation_restore_fails_when_adding_tools(mock_completion): assert "FileEditorTool" in str(exc.value) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_fails_when_agent_class_changes(mock_completion): """Restore must fail when persisted and runtime agent types differ.""" @@ -326,7 +326,7 @@ def test_conversation_restore_fails_when_agent_class_changes(mock_completion): assert "self is of type" in str(exc.value) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_fails_when_default_tools_removed(mock_completion): """Restore must fail if include_default_tools removes a built-in tool.""" @@ -372,7 +372,7 @@ def test_conversation_restore_fails_when_default_tools_removed(mock_completion): assert "think" in str(exc.value) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_fails_when_default_tools_added(mock_completion): """Restore must fail if include_default_tools adds a built-in tool.""" @@ -418,7 +418,7 @@ def test_conversation_restore_fails_when_default_tools_added(mock_completion): assert "think" in str(exc.value) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_conversation_restore_succeeds_when_llm_condenser_and_skills_change( mock_completion, ): @@ -477,7 +477,7 @@ def test_conversation_restore_succeeds_when_llm_condenser_and_skills_change( restored.close() -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_restore_reasoning_effort_none_strips_temperature(mock_completion): """Reasoning models should accept reasoning_effort and ignore temperature/top_p.""" diff --git a/tests/cross/test_hello_world.py b/tests/cross/test_hello_world.py index 19de6ba005..5f04b8211e 100644 --- a/tests/cross/test_hello_world.py +++ b/tests/cross/test_hello_world.py @@ -145,7 +145,7 @@ def create_mock_llm_responses(self): return [first_response, second_response] - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_hello_world_with_real_llm_data(self, mock_completion, fncall_raw_logs): """Test the complete hello world flow with real LLM completion data.""" # Setup real LLM responses from fixtures @@ -269,7 +269,7 @@ def test_hello_world_with_real_llm_data(self, mock_completion, fncall_raw_logs): "Real responses should have content" ) - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_logging_fidelity(self, mock_completion, fncall_raw_logs): """Test mocked LLM completion logging produces same output.""" # Use mock responses for consistent behavior instead of real fixture data @@ -434,7 +434,7 @@ def capture_completion_fidelity(*args, **kwargs): # Mock the completion method with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=capture_completion_fidelity, ): # Create conversation and send a message @@ -527,7 +527,7 @@ def capture_completion_non_func(*args, **kwargs): # Mock the completion method with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=capture_completion_non_func, ): # Create conversation and send a message diff --git a/tests/sdk/agent/test_message_while_finishing.py b/tests/sdk/agent/test_message_while_finishing.py index 45fa9c80bf..48462d603a 100644 --- a/tests/sdk/agent/test_message_while_finishing.py +++ b/tests/sdk/agent/test_message_while_finishing.py @@ -335,7 +335,7 @@ def elapsed_time(): print(f"{elapsed_time()} Test started") with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=self._mock_llm_response, ): # Start the conversation with a natural request diff --git a/tests/sdk/agent/test_non_executable_action_emission.py b/tests/sdk/agent/test_non_executable_action_emission.py index 392009f506..4e7c9ba292 100644 --- a/tests/sdk/agent/test_non_executable_action_emission.py +++ b/tests/sdk/agent/test_non_executable_action_emission.py @@ -68,7 +68,7 @@ def cb(e): conv = Conversation(agent=agent, callbacks=[cb]) with patch( - "openhands.sdk.llm.llm.litellm_completion", side_effect=mock_llm_response + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=mock_llm_response ): conv.send_message(Message(role="user", content=[TextContent(text="go")])) agent.step(conv, on_event=cb) diff --git a/tests/sdk/agent/test_nonexistent_tool_handling.py b/tests/sdk/agent/test_nonexistent_tool_handling.py index f111792168..32b3ca500f 100644 --- a/tests/sdk/agent/test_nonexistent_tool_handling.py +++ b/tests/sdk/agent/test_nonexistent_tool_handling.py @@ -70,7 +70,7 @@ def event_callback(event): conversation = Conversation(agent=agent, callbacks=[event_callback]) with patch( - "openhands.sdk.llm.llm.litellm_completion", side_effect=mock_llm_response + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=mock_llm_response ): # Send a message to start the conversation conversation.send_message( @@ -160,7 +160,7 @@ def event_callback(event): conversation = Conversation(agent=agent, callbacks=[event_callback]) with patch( - "openhands.sdk.llm.llm.litellm_completion", side_effect=mock_llm_response + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=mock_llm_response ): conversation.send_message( Message( @@ -271,7 +271,7 @@ def event_callback(event): conversation = Conversation(agent=agent, callbacks=[event_callback]) with patch( - "openhands.sdk.llm.llm.litellm_completion", side_effect=mock_llm_response + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=mock_llm_response ): conversation.send_message( Message( diff --git a/tests/sdk/agent/test_security_policy_integration.py b/tests/sdk/agent/test_security_policy_integration.py index 1624d210f5..f0694484b5 100644 --- a/tests/sdk/agent/test_security_policy_integration.py +++ b/tests/sdk/agent/test_security_policy_integration.py @@ -306,7 +306,7 @@ def test_security_risk_param_ignored_when_no_analyzer(): # Mock LLM response that includes security_risk=HIGH even though # llm_security_analyzer=False (the LLM might do this if it's well-trained) with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", return_value=_tool_response( "think", '{"thought": "This is a test thought", "security_risk": "HIGH"}', diff --git a/tests/sdk/agent/test_tool_execution_error_handling.py b/tests/sdk/agent/test_tool_execution_error_handling.py index ce22e13966..1508889180 100644 --- a/tests/sdk/agent/test_tool_execution_error_handling.py +++ b/tests/sdk/agent/test_tool_execution_error_handling.py @@ -113,7 +113,7 @@ def event_callback(event): conversation = Conversation(agent=agent, callbacks=[event_callback]) with patch( - "openhands.sdk.llm.llm.litellm_completion", side_effect=mock_llm_response + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=mock_llm_response ): conversation.send_message( Message( @@ -229,7 +229,7 @@ def event_callback(event): conversation = Conversation(agent=agent, callbacks=[event_callback]) with patch( - "openhands.sdk.llm.llm.litellm_completion", side_effect=mock_llm_response + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=mock_llm_response ): conversation.send_message( Message( diff --git a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py index 475e403639..a3ce624dcc 100644 --- a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py +++ b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py @@ -77,6 +77,8 @@ def create_completion_result(content: str) -> LLMResponse: mock_llm.output_cost_per_token = None mock_llm._metrics = None + mock_llm._async_runner = None + mock_llm.usage_id = "test-llm" # Helper method to set mock response content def set_mock_response_content(content: str): diff --git a/tests/sdk/conversation/local/test_confirmation_mode.py b/tests/sdk/conversation/local/test_confirmation_mode.py index 0a8533baca..e38cdeb450 100644 --- a/tests/sdk/conversation/local/test_confirmation_mode.py +++ b/tests/sdk/conversation/local/test_confirmation_mode.py @@ -159,7 +159,7 @@ def _make_pending_action(self) -> None: self.conversation.set_confirmation_policy(AlwaysConfirm()) mock_completion = self._mock_action_once() with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", return_value=mock_completion.return_value, ): self.conversation.send_message( @@ -434,7 +434,7 @@ def test_message_only_in_confirmation_mode_does_not_wait(self): self.conversation.set_confirmation_policy(AlwaysConfirm()) mock_completion = self._mock_message_only("Hello, how can I help you?") with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", return_value=mock_completion.return_value, ): self.conversation.send_message( @@ -470,7 +470,7 @@ def test_action_then_confirm_or_reject(self, should_reject: bool): # Confirm path per your instruction: call run() to execute pending action mock_completion = self._mock_message_only("Task completed successfully!") with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", return_value=mock_completion.return_value, ): self.conversation.run() @@ -520,7 +520,7 @@ def test_single_finish_action_skips_confirmation_entirely(self): # Send a message that should trigger the finish action with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", return_value=mock_completion.return_value, ): self.conversation.send_message( @@ -565,7 +565,7 @@ def test_think_and_finish_action_skips_confirmation_entirely(self): mock_finish = self._mock_finish_action("Analysis complete") with patch( - "openhands.sdk.llm.llm.litellm_completion", + "openhands.sdk.llm.llm.litellm_acompletion", side_effect=[mock_think.return_value, mock_finish.return_value], ): # Kick things off (LLM returns ThinkAction; should execute immediately) diff --git a/tests/sdk/conversation/local/test_conversation_pause_functionality.py b/tests/sdk/conversation/local/test_conversation_pause_functionality.py index 0890a4add7..82d1f7d460 100644 --- a/tests/sdk/conversation/local/test_conversation_pause_functionality.py +++ b/tests/sdk/conversation/local/test_conversation_pause_functionality.py @@ -181,7 +181,7 @@ def test_pause_basic_functionality(self): assert len(pause_events) == 1 assert pause_events[0].source == "user" - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_pause_during_normal_execution(self, mock_completion): """Test pausing before run() starts - pause is reset and agent runs normally.""" # Mock LLM to return a message that finishes execution @@ -228,7 +228,7 @@ def test_pause_during_normal_execution(self, mock_completion): ] assert len(pause_events) == 1 - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_resume_paused_agent(self, mock_completion): """Test pausing before run() - pause is reset and agent runs normally.""" # Mock LLM to return a message that finishes execution @@ -273,7 +273,7 @@ def test_resume_paused_agent(self, mock_completion): ] assert len(agent_messages) == 1 # Agent ran and completed - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_pause_with_confirmation_mode(self, mock_completion): """Test that pause before run() with confirmation mode - pause is reset and agent waits for confirmation.""" # noqa: E501 # Enable confirmation mode @@ -365,7 +365,7 @@ def test_multiple_pause_calls_create_one_event(self): ) @pytest.mark.timeout(3) - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_pause_while_running_continuous_actions(self, mock_completion): step_entered = threading.Event() diff --git a/tests/sdk/conversation/test_base_span_management.py b/tests/sdk/conversation/test_base_span_management.py index 01c1ddb536..ad9ca0debb 100644 --- a/tests/sdk/conversation/test_base_span_management.py +++ b/tests/sdk/conversation/test_base_span_management.py @@ -35,6 +35,9 @@ def id(self) -> UUID: def pause(self) -> None: pass + def interrupt(self) -> None: + pass + def reject_pending_actions(self, reason: str = "User rejected the action") -> None: pass diff --git a/tests/sdk/conversation/test_conversation_interrupt.py b/tests/sdk/conversation/test_conversation_interrupt.py new file mode 100644 index 0000000000..4eebea9130 --- /dev/null +++ b/tests/sdk/conversation/test_conversation_interrupt.py @@ -0,0 +1,214 @@ +"""Tests for conversation interrupt functionality.""" + +import threading +from unittest.mock import patch + +import pytest +from pydantic import SecretStr + +from openhands.sdk import Agent, LocalConversation +from openhands.sdk.conversation.state import ConversationExecutionStatus +from openhands.sdk.event.user_action import InterruptEvent, PauseEvent +from openhands.sdk.llm import LLM + + +@pytest.fixture +def llm(): + """Create a test LLM instance.""" + return LLM( + model="gpt-4o", + api_key=SecretStr("test_key"), + usage_id="test-conversation-llm", + num_retries=0, + ) + + +@pytest.fixture +def agent(llm: LLM): + """Create a test agent.""" + return Agent(llm=llm) + + +def test_interrupt_event_exists(): + """Test that InterruptEvent can be instantiated.""" + event = InterruptEvent() + assert event.source == "user" + assert event.reason == "User requested interrupt" + + +def test_interrupt_event_visualize(): + """Test InterruptEvent visualization.""" + event = InterruptEvent() + viz = event.visualize + + assert "Interrupted" in viz.plain + + +def test_interrupt_event_str(): + """Test InterruptEvent string representation.""" + event = InterruptEvent() + s = str(event) + assert "InterruptEvent" in s + assert "user" in s + + +def test_interrupt_event_custom_reason(): + """Test InterruptEvent with custom reason.""" + event = InterruptEvent(reason="Custom stop reason") + assert event.reason == "Custom stop reason" + + viz = event.visualize + assert "Custom stop reason" in viz.plain + + +def test_pause_event_vs_interrupt_event(): + """Test that PauseEvent and InterruptEvent are distinct.""" + pause = PauseEvent() + interrupt = InterruptEvent() + + assert type(pause).__name__ == "PauseEvent" + assert type(interrupt).__name__ == "InterruptEvent" + + # Different visualization + assert "Paused" in pause.visualize.plain + assert "Interrupted" in interrupt.visualize.plain + + +def test_conversation_has_interrupt_method(agent: Agent, tmp_path): + """Test that LocalConversation has interrupt method.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + assert hasattr(conv, "interrupt") + assert callable(conv.interrupt) + + +def test_conversation_interrupt_cancels_llm(agent: Agent, tmp_path): + """Test that interrupt() calls llm.cancel().""" + # Create conversation + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Mock the LLM's cancel method at class level + with patch("openhands.sdk.llm.LLM.cancel") as mock_cancel: + # Call interrupt + conv.interrupt() + + # Verify cancel was called on the LLM + mock_cancel.assert_called() + + +def test_conversation_interrupt_sets_paused_status(agent: Agent, tmp_path): + """Test that interrupt() sets status to PAUSED.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Initially IDLE + assert conv.state.execution_status == ConversationExecutionStatus.IDLE + + # Call interrupt + conv.interrupt() + + # Should be PAUSED + assert conv.state.execution_status == ConversationExecutionStatus.PAUSED + + +def test_conversation_interrupt_when_running(agent: Agent, tmp_path): + """Test interrupt when conversation is in RUNNING status.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Manually set to running + conv._state.execution_status = ConversationExecutionStatus.RUNNING + + # Call interrupt + conv.interrupt() + + # Should be PAUSED + assert conv.state.execution_status == ConversationExecutionStatus.PAUSED + + +def test_conversation_interrupt_idempotent(agent: Agent, tmp_path): + """Test that multiple interrupt calls don't cause issues.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Call interrupt multiple times + conv.interrupt() + conv.interrupt() + conv.interrupt() + + # Should remain PAUSED + assert conv.state.execution_status == ConversationExecutionStatus.PAUSED + + +def test_conversation_interrupt_cancels_all_llms_in_registry(agent: Agent, tmp_path): + """Test that interrupt cancels LLMs in the registry too.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Add an LLM to the registry using the proper API + extra_llm = LLM( + model="gpt-4o", + api_key=SecretStr("test_key"), + usage_id="extra-llm", + num_retries=0, + ) + conv.llm_registry.add(extra_llm) + + # Mock cancel at class level - both calls go through the same mock + with patch("openhands.sdk.llm.LLM.cancel") as mock_cancel: + # Call interrupt + conv.interrupt() + + # cancel should be called >= 2 times (agent.llm + extra_llm) + assert mock_cancel.call_count >= 2 + + +def test_conversation_interrupt_when_already_paused(agent: Agent, tmp_path): + """Test interrupt when already paused still cancels LLM.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Set to PAUSED + conv._state.execution_status = ConversationExecutionStatus.PAUSED + + # Mock cancel method at class level + with patch("openhands.sdk.llm.LLM.cancel") as mock_cancel: + # Call interrupt - should still cancel LLM but not change status + conv.interrupt() + + # LLM cancel should still be called + mock_cancel.assert_called() + + # Status should remain PAUSED + assert conv.state.execution_status == ConversationExecutionStatus.PAUSED + + +def test_conversation_interrupt_when_finished(agent: Agent, tmp_path): + """Test interrupt when conversation is finished (status doesn't change).""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Set to FINISHED + conv._state.execution_status = ConversationExecutionStatus.FINISHED + + # Mock cancel method at class level + with patch("openhands.sdk.llm.LLM.cancel") as mock_cancel: + # Call interrupt + conv.interrupt() + + # LLM cancel should still be called (in case something is running) + mock_cancel.assert_called() + + # Status should remain FINISHED + assert conv.state.execution_status == ConversationExecutionStatus.FINISHED + + +def test_conversation_interrupt_is_thread_safe(agent: Agent, tmp_path): + """Test that interrupt can be called from multiple threads safely.""" + conv = LocalConversation(agent=agent, workspace=str(tmp_path)) + + # Call interrupt from multiple threads + threads = [] + for _ in range(10): + t = threading.Thread(target=conv.interrupt) + threads.append(t) + t.start() + + for t in threads: + t.join(timeout=2) + + # Should not raise any errors and status should be PAUSED + assert conv.state.execution_status == ConversationExecutionStatus.PAUSED diff --git a/tests/sdk/conversation/test_conversation_stats.py b/tests/sdk/conversation/test_conversation_stats.py index 98d37808ef..be0c774201 100644 --- a/tests/sdk/conversation/test_conversation_stats.py +++ b/tests/sdk/conversation/test_conversation_stats.py @@ -109,7 +109,7 @@ def test_get_metrics_for_usage(conversation_stats): def test_register_llm_with_new_usage(conversation_stats): """Test registering a new LLM usage.""" # Patch the LLM class to avoid actual API calls - with patch("openhands.sdk.llm.llm.litellm_completion"): + with patch("openhands.sdk.llm.llm.litellm_acompletion"): llm = LLM( usage_id="new-service", model="gpt-4o", @@ -140,7 +140,7 @@ def test_register_llm_with_restored_metrics(conversation_stats): conversation_stats.usage_to_metrics = {usage_id: restored_metrics} # Patch the LLM class to avoid actual API calls - with patch("openhands.sdk.llm.llm.litellm_completion"): + with patch("openhands.sdk.llm.llm.litellm_acompletion"): llm = LLM( usage_id=usage_id, model="gpt-4o", @@ -319,7 +319,7 @@ def test_register_llm_with_multiple_restored_usage_ids(conversation_stats): } # Patch the LLM class to avoid actual API calls - with patch("openhands.sdk.llm.llm.litellm_completion"): + with patch("openhands.sdk.llm.llm.litellm_acompletion"): # Register first LLM llm_1 = LLM( usage_id=usage_id_1, diff --git a/tests/sdk/llm/test_api_connection_error_retry.py b/tests/sdk/llm/test_api_connection_error_retry.py index 8bb644d899..155aa1797c 100644 --- a/tests/sdk/llm/test_api_connection_error_retry.py +++ b/tests/sdk/llm/test_api_connection_error_retry.py @@ -40,16 +40,16 @@ def default_config(): ) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_completion_retries_api_connection_error( - mock_litellm_completion, default_config + mock_litellm_acompletion, default_config ): """Test that APIConnectionError is properly retried.""" mock_response = create_mock_response("Retry successful") - # Mock the litellm_completion to first raise an APIConnectionError, + # Mock the litellm_acompletion to first raise an APIConnectionError, # then return a successful response - mock_litellm_completion.side_effect = [ + mock_litellm_acompletion.side_effect = [ APIConnectionError( message="API connection error", llm_provider="test_provider", @@ -74,16 +74,16 @@ def test_completion_retries_api_connection_error( # Verify that the retry was successful assert isinstance(response, LLMResponse) assert response.raw_response == mock_response - assert mock_litellm_completion.call_count == 2 # Initial call + 1 retry + assert mock_litellm_acompletion.call_count == 2 # Initial call + 1 retry -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_completion_max_retries_api_connection_error( - mock_litellm_completion, default_config + mock_litellm_acompletion, default_config ): """Test that APIConnectionError respects max retries and is mapped to SDK error.""" - # Mock the litellm_completion to raise APIConnectionError multiple times - mock_litellm_completion.side_effect = [ + # Mock the litellm_acompletion to raise APIConnectionError multiple times + mock_litellm_acompletion.side_effect = [ APIConnectionError( message="API connection error 1", llm_provider="test_provider", @@ -120,7 +120,7 @@ def test_completion_max_retries_api_connection_error( # Verify that the correct number of retries were attempted # The actual behavior is that it tries num_retries times total - assert mock_litellm_completion.call_count == default_config.num_retries + assert mock_litellm_acompletion.call_count == default_config.num_retries # The exception should contain connection error information assert "API connection error" in str(excinfo.value) @@ -129,11 +129,11 @@ def test_completion_max_retries_api_connection_error( assert isinstance(excinfo.value.__cause__, APIConnectionError) -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_completion_no_retry_on_success(mock_litellm_completion, default_config): +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_completion_no_retry_on_success(mock_litellm_acompletion, default_config): """Test that successful calls don't trigger retries.""" mock_response = create_mock_response("Success on first try") - mock_litellm_completion.return_value = mock_response + mock_litellm_acompletion.return_value = mock_response # Create an LLM instance and call completion llm = LLM( @@ -151,16 +151,16 @@ def test_completion_no_retry_on_success(mock_litellm_completion, default_config) # Verify that no retries were needed assert isinstance(response, LLMResponse) assert response.raw_response == mock_response - assert mock_litellm_completion.call_count == 1 # Only the initial call + assert mock_litellm_acompletion.call_count == 1 # Only the initial call -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_completion_no_retry_on_non_retryable_error( - mock_litellm_completion, default_config + mock_litellm_acompletion, default_config ): """Test that non-retryable errors don't trigger retries.""" # Mock a non-retryable error (e.g., ValueError) - mock_litellm_completion.side_effect = ValueError("Invalid input") + mock_litellm_acompletion.side_effect = ValueError("Invalid input") # Create an LLM instance and call completion llm = LLM( @@ -179,7 +179,7 @@ def test_completion_no_retry_on_non_retryable_error( ) # Verify that no retries were attempted - assert mock_litellm_completion.call_count == 1 # Only the initial call + assert mock_litellm_acompletion.call_count == 1 # Only the initial call assert "Invalid input" in str(excinfo.value) @@ -210,8 +210,8 @@ def test_retry_configuration_validation(): assert llm_custom.retry_multiplier == 2.0 -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_retry_listener_callback(mock_litellm_completion, default_config): +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_retry_listener_callback(mock_litellm_acompletion, default_config): """Test that retry listener callback is called during retries.""" retry_calls = [] @@ -220,7 +220,7 @@ def retry_listener(attempt: int, max_attempts: int, _err: BaseException | None): mock_response = create_mock_response("Success after retry") - mock_litellm_completion.side_effect = [ + mock_litellm_acompletion.side_effect = [ APIConnectionError( message="Connection failed", llm_provider="test_provider", diff --git a/tests/sdk/llm/test_llm.py b/tests/sdk/llm/test_llm.py index 9711d3e681..088b788275 100644 --- a/tests/sdk/llm/test_llm.py +++ b/tests/sdk/llm/test_llm.py @@ -230,7 +230,7 @@ def test_metrics_diff(): assert accumulated_diff["cache_write_tokens"] == 2 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_with_mock(mock_completion): """Test LLM completion with mocked litellm.""" mock_response = create_mock_litellm_response("Test response") @@ -255,7 +255,7 @@ def test_llm_completion_with_mock(mock_completion): mock_completion.assert_called_once() -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_retry_on_rate_limit(mock_completion): """Test that LLM retries on rate limit errors.""" mock_response = create_mock_litellm_response("Success after retry") @@ -318,7 +318,7 @@ def test_llm_token_counting(default_llm): assert token_count >= 0 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_forwards_extra_headers_to_litellm(mock_completion): mock_response = create_mock_litellm_response("ok") mock_completion.return_value = mock_response @@ -341,7 +341,7 @@ def test_llm_forwards_extra_headers_to_litellm(mock_completion): assert kwargs.get("extra_headers") == headers -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_llm_responses_forwards_extra_headers_to_litellm(mock_responses): # Build a minimal, but valid, ResponsesAPIResponse instance per litellm types # Build typed message output using OpenAI types to satisfy litellm schema @@ -388,7 +388,7 @@ def test_llm_responses_forwards_extra_headers_to_litellm(mock_responses): assert kwargs.get("extra_headers") == headers -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_completion_merges_llm_extra_headers_with_extended_thinking_default( mock_completion, ): @@ -419,7 +419,7 @@ def test_completion_merges_llm_extra_headers_with_extended_thinking_default( assert headers.get("X-Trace") == "1" -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_completion_call_time_extra_headers_override_config_and_defaults( mock_completion, ): @@ -454,7 +454,7 @@ def test_completion_call_time_extra_headers_override_config_and_defaults( assert "X-Trace" not in headers -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_responses_call_time_extra_headers_override_config(mock_responses): # Build a minimal valid Responses response msg = ResponseOutputMessage.model_construct( @@ -777,7 +777,7 @@ def test_llm_config_validation(): assert full_llm.max_output_tokens == 1000 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_no_response_error(mock_completion): """Test handling of LLMNoResponseError.""" from litellm.types.utils import ModelResponse, Usage diff --git a/tests/sdk/llm/test_llm_completion.py b/tests/sdk/llm/test_llm_completion.py index de0f482816..d6f9ad9aa1 100644 --- a/tests/sdk/llm/test_llm_completion.py +++ b/tests/sdk/llm/test_llm_completion.py @@ -74,7 +74,7 @@ def default_config(): ) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_basic(mock_completion): """Test basic LLM completion functionality.""" mock_response = create_mock_response("Test response") @@ -119,7 +119,7 @@ def test_llm_streaming_not_supported(default_config): llm.completion(messages=messages, stream=True) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") @patch("openhands.sdk.llm.llm.litellm.stream_chunk_builder") def test_llm_completion_streaming_with_callback(mock_stream_builder, mock_completion): """Test that streaming with on_token callback works correctly.""" @@ -167,9 +167,14 @@ def test_llm_completion_streaming_with_callback(mock_stream_builder, mock_comple object="chat.completion.chunk", ) - # Create a mock stream wrapper + # Create an async iterator for streaming (acompletion returns async iterator) + async def async_stream(): + for chunk in [chunk1, chunk2, chunk3]: + yield chunk + + # Create a mock stream wrapper that supports async iteration mock_stream = MagicMock(spec=CustomStreamWrapper) - mock_stream.__iter__.return_value = iter([chunk1, chunk2, chunk3]) + mock_stream.__aiter__ = lambda self: async_stream().__aiter__() mock_completion.return_value = mock_stream # Mock the stream builder to return a complete response @@ -210,7 +215,7 @@ def on_token(chunk): assert response.message.content[0].text == "Hello world!" -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") @patch("openhands.sdk.llm.llm.litellm.stream_chunk_builder") def test_llm_completion_streaming_with_tools(mock_stream_builder, mock_completion): """Test streaming completion with tool calls.""" @@ -277,9 +282,14 @@ def test_llm_completion_streaming_with_tools(mock_stream_builder, mock_completio object="chat.completion.chunk", ) - # Create mock stream + # Create an async iterator for streaming (acompletion returns async iterator) + async def async_stream(): + for chunk in [chunk1, chunk2, chunk3]: + yield chunk + + # Create mock stream that supports async iteration mock_stream = MagicMock(spec=CustomStreamWrapper) - mock_stream.__iter__.return_value = iter([chunk1, chunk2, chunk3]) + mock_stream.__aiter__ = lambda self: async_stream().__aiter__() mock_completion.return_value = mock_stream # Mock final response with tool call @@ -323,7 +333,7 @@ def on_token(chunk): assert response.message.tool_calls[0].name == "test_tool" -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_with_tools(mock_completion): """Test LLM completion with tools.""" mock_response = create_mock_response("I'll use the tool") @@ -368,7 +378,7 @@ def test_llm_completion_with_tools(mock_completion): mock_completion.assert_called_once() -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_error_handling(mock_completion): """Test LLM completion error handling.""" # Mock an exception @@ -478,7 +488,7 @@ def test_llm_token_usage_tracking(default_config): assert accumulated.completion_tokens >= 5 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_with_custom_params(mock_completion, default_config): """Test LLM completion with custom parameters.""" mock_response = create_mock_response("Custom response") @@ -515,7 +525,7 @@ def test_llm_completion_with_custom_params(mock_completion, default_config): assert call_kwargs.get("top_p") == 0.9 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_non_function_call_mode(mock_completion): """Test LLM completion with non-function call mode (prompt-based tool calling).""" # Create a mock response that looks like a non-function call response @@ -594,7 +604,7 @@ def test_llm_completion_non_function_call_mode(mock_completion): assert len(call_messages) >= len(messages) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_llm_completion_function_call_vs_non_function_call_mode(mock_completion): """Test the difference between function call mode and non-function call mode.""" mock_response = create_mock_response("Test response") diff --git a/tests/sdk/llm/test_llm_fallback.py b/tests/sdk/llm/test_llm_fallback.py index 9758d65a02..cda58e6d1c 100644 --- a/tests/sdk/llm/test_llm_fallback.py +++ b/tests/sdk/llm/test_llm_fallback.py @@ -55,7 +55,7 @@ def _patch_resolve(primary: LLM, fallback_instances: list[LLM]): primary.fallback_strategy._resolved = fallback_instances -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_primary_succeeds_fallback_not_tried(mock_comp): mock_comp.return_value = _get_mock_response("primary ok") @@ -72,7 +72,7 @@ def test_primary_succeeds_fallback_not_tried(mock_comp): assert mock_comp.call_count == 1 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_fallback_succeeds_after_primary_transient_failure(mock_comp): primary_error = APIConnectionError( message="connection reset", llm_provider="openai", model="gpt-4o" @@ -96,7 +96,7 @@ def side_effect(**kwargs): assert content.text == "fallback ok" -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_all_fallbacks_fail_raises_primary_error(mock_comp): mock_comp.side_effect = APIConnectionError( message="down", llm_provider="openai", model="gpt-4o" @@ -114,7 +114,7 @@ def test_all_fallbacks_fail_raises_primary_error(mock_comp): _ = primary.completion(_MSGS) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_non_transient_error_skips_fallback(mock_comp): """A plain Exception is NOT in LLM_FALLBACK_EXCEPTIONS, so fallback should be skipped.""" @@ -132,7 +132,7 @@ def test_non_transient_error_skips_fallback(mock_comp): assert mock_comp.call_count == 1 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_no_fallbacks_configured_normal_error(mock_comp): mock_comp.side_effect = APIConnectionError( message="down", llm_provider="openai", model="gpt-4o" @@ -145,7 +145,7 @@ def test_no_fallbacks_configured_normal_error(mock_comp): _ = primary.completion(_MSGS) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_metrics_merged_from_fallback(mock_comp): primary_error = RateLimitError( message="rate limited", llm_provider="openai", model="gpt-4o" @@ -183,7 +183,7 @@ def side_effect(**kwargs): ) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_second_fallback_succeeds(mock_comp): # Second fallback succeeds after first fallback fails call_count = {"n": 0} @@ -211,7 +211,7 @@ def side_effect(**kwargs): assert call_count["n"] == 3 -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_responses_fallback_succeeds(mock_resp): """Ensure fallback works through the responses() code path too.""" from litellm.types.llms.openai import ResponsesAPIResponse @@ -260,7 +260,7 @@ def side_effect(**kwargs): assert content.text == "fb ok" -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_responses_non_transient_skips_fallback(mock_resp): mock_resp.side_effect = Exception("not transient") @@ -275,7 +275,7 @@ def test_responses_non_transient_skips_fallback(mock_resp): assert mock_resp.call_count == 1 -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_fallback_profiles_resolved_via_store(mock_comp, tmp_path): """Verify that fallback profile names are resolved through LLMProfileStore.""" from openhands.sdk.llm.llm_profile_store import LLMProfileStore diff --git a/tests/sdk/llm/test_llm_interrupt.py b/tests/sdk/llm/test_llm_interrupt.py new file mode 100644 index 0000000000..4caf3914e3 --- /dev/null +++ b/tests/sdk/llm/test_llm_interrupt.py @@ -0,0 +1,451 @@ +"""Tests for LLM cancellation and interrupt functionality.""" + +import asyncio +import threading +import time +from typing import Any +from unittest.mock import patch + +import pytest +from litellm.types.utils import ( + Choices, + Message as LiteLLMMessage, + ModelResponse, + Usage, +) +from pydantic import SecretStr + +from openhands.sdk.llm import LLM, Message, TextContent +from openhands.sdk.llm.exceptions import LLMCancelledError + + +def create_mock_response(content: str = "Test response"): + """Create a properly structured mock ModelResponse.""" + return ModelResponse( + id="test-id", + choices=[ + Choices( + finish_reason="stop", + index=0, + message=LiteLLMMessage(content=content, role="assistant"), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion", + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + +@pytest.fixture +def llm(): + """Create a test LLM instance.""" + return LLM( + model="gpt-4o", + api_key=SecretStr("test_key"), + usage_id="test-interrupt-llm", + num_retries=0, # Disable retries for predictable tests + ) + + +@pytest.fixture +def messages(): + """Create test messages.""" + return [ + Message( + role="system", content=[TextContent(text="You are a helpful assistant")] + ), + Message(role="user", content=[TextContent(text="Hello")]), + ] + + +def test_llm_has_cancel_method(llm: LLM): + """Test that LLM has cancel method.""" + assert hasattr(llm, "cancel") + assert callable(llm.cancel) + + +def test_llm_has_is_cancelled_method(llm: LLM): + """Test that LLM has is_cancelled method.""" + assert hasattr(llm, "is_cancelled") + assert callable(llm.is_cancelled) + + +def test_llm_is_cancelled_returns_false_when_no_task(llm: LLM): + """Test is_cancelled returns False when there's no current task.""" + assert llm.is_cancelled() is False + + +def test_llm_cancel_does_not_raise_when_no_task(llm: LLM): + """Test that cancel doesn't raise when there's no current task.""" + # Should not raise - calling cancel when nothing is running is OK + llm.cancel() + assert llm.is_cancelled() is False + + +def test_llm_has_async_runner(llm: LLM): + """Test that LLM has an AsyncRunner instance.""" + assert llm._async_runner is not None + + +def test_llm_async_runner_loop_created_lazily(llm: LLM): + """Test that async runner's loop is not created until needed.""" + # The runner exists but its loop is not created until first use + runner = llm._async_runner + assert runner is not None + assert runner._loop is None + assert runner._thread is None + + +def test_llm_async_runner_creates_thread_on_use(llm: LLM): + """Test that async runner creates and starts background thread when used.""" + runner = llm._async_runner + assert runner is not None + + # Force the runner to create its loop + loop = runner._ensure_loop() + + assert loop is not None + assert runner._loop is loop + assert runner._thread is not None + assert runner._thread.is_alive() + assert runner._thread.daemon is True + + # Clean up + llm.close() + + +def test_llm_async_runner_reuses_existing_loop(llm: LLM): + """Test that async runner reuses existing loop.""" + runner = llm._async_runner + assert runner is not None + + loop1 = runner._ensure_loop() + loop2 = runner._ensure_loop() + + assert loop1 is loop2 + + # Clean up + llm.close() + + +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_llm_completion_uses_async_internally(mock_acompletion, llm: LLM, messages): + """Test that completion uses async completion internally.""" + mock_response = create_mock_response() + mock_acompletion.return_value = mock_response + + result = llm.completion(messages) + + assert result is not None + mock_acompletion.assert_called_once() + + +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_llm_cancel_during_completion(mock_acompletion, llm: LLM, messages): + """Test that cancel() works during a completion call.""" + # Create an event to coordinate between threads + call_started = threading.Event() + can_finish = threading.Event() + + async def slow_completion(*args, **kwargs): + call_started.set() + # Wait up to 5 seconds for signal or cancellation + for _ in range(50): + if can_finish.is_set(): + return create_mock_response() + await asyncio.sleep(0.1) + return create_mock_response() + + mock_acompletion.side_effect = slow_completion + + result_container: dict[str, Any] = {"result": None, "error": None} + + def run_completion(): + try: + result_container["result"] = llm.completion(messages) + except Exception as e: + result_container["error"] = e + + # Start completion in background thread + thread = threading.Thread(target=run_completion) + thread.start() + + # Wait for the call to start + call_started.wait(timeout=2) + time.sleep(0.1) # Small delay to ensure task is tracked + + # Cancel the call + llm.cancel() + + # Wait for thread to finish + thread.join(timeout=3) + + # Should have raised LLMCancelledError + assert result_container["error"] is not None + assert isinstance(result_container["error"], LLMCancelledError) + + +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_llm_cancel_is_thread_safe(mock_acompletion, messages): + """Test that cancel() can be called from multiple threads safely.""" + llm = LLM( + model="gpt-4o", + api_key=SecretStr("test_key"), + usage_id="test-thread-safe", + num_retries=0, + ) + + mock_acompletion.return_value = create_mock_response() + + # Call cancel from multiple threads concurrently + threads = [] + for i in range(10): + t = threading.Thread(target=llm.cancel) + threads.append(t) + t.start() + + for t in threads: + t.join(timeout=2) + + # Should not raise any errors + + +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_llm_can_be_reused_after_cancel(mock_acompletion, llm: LLM, messages): + """Test that LLM can be used for new calls after cancellation.""" + call_count = 0 + call_started = threading.Event() + + async def slow_then_fast(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + call_started.set() + # First call is slow + await asyncio.sleep(10) + # Second call returns immediately + return create_mock_response(f"Response {call_count}") + + mock_acompletion.side_effect = slow_then_fast + + # First call - will be cancelled + result_container: dict[str, Any] = {"error": None} + + def first_call(): + try: + llm.completion(messages) + except LLMCancelledError as e: + result_container["error"] = e + + thread = threading.Thread(target=first_call) + thread.start() + call_started.wait(timeout=2) + time.sleep(0.1) + llm.cancel() + thread.join(timeout=3) + + assert result_container["error"] is not None + assert isinstance(result_container["error"], LLMCancelledError) + + # Reset mock for second call + mock_acompletion.side_effect = None + mock_acompletion.return_value = create_mock_response("Second response") + + # Second call - should work normally + result = llm.completion(messages) + assert result is not None + # Check the content via the message + assert result.message.content is not None + assert len(result.message.content) > 0 + first_content = result.message.content[0] + # Verify it's a TextContent and contains expected text + assert isinstance(first_content, TextContent) + assert "Second response" in first_content.text + + +def test_llm_cancelled_error_exception(): + """Test LLMCancelledError exception properties.""" + error = LLMCancelledError() + assert str(error) == "LLM call was cancelled" + assert error.message == "LLM call was cancelled" + + custom_error = LLMCancelledError("Custom cancellation message") + assert str(custom_error) == "Custom cancellation message" + assert custom_error.message == "Custom cancellation message" + + +def test_llm_cancelled_error_can_be_caught(): + """Test that LLMCancelledError can be caught as Exception.""" + with pytest.raises(LLMCancelledError): + raise LLMCancelledError("test") + + # Should also be catchable as generic Exception + try: + raise LLMCancelledError("test") + except Exception as e: + assert isinstance(e, LLMCancelledError) + + +# ========================================================================= +# Tests for close() method - Resource Cleanup +# ========================================================================= + + +def test_llm_has_close_method(llm: LLM): + """Test that LLM has close method.""" + assert hasattr(llm, "close") + assert callable(llm.close) + + +def test_llm_close_does_not_raise_when_no_loop(llm: LLM): + """Test that close doesn't raise when there's no background loop.""" + runner = llm._async_runner + assert runner is not None + + # Should not raise - calling close when nothing is started is OK + llm.close() + # Runner's internal loop should be None + assert runner._loop is None + assert runner._thread is None + + +def test_llm_close_stops_event_loop_thread(llm: LLM): + """Test that close() stops the background event loop thread.""" + runner = llm._async_runner + assert runner is not None + + # First, start the event loop via the runner + loop = runner._ensure_loop() + thread = runner._thread + + assert loop is not None + assert thread is not None + assert thread.is_alive() + + # Now close it + llm.close() + + # Thread should be stopped + assert runner._loop is None + assert runner._thread is None + # Give thread a moment to finish + time.sleep(0.1) + assert not thread.is_alive() + + +def test_llm_close_can_be_called_multiple_times(llm: LLM): + """Test that close() can be called multiple times safely.""" + runner = llm._async_runner + assert runner is not None + + # Start the event loop via the runner + runner._ensure_loop() + + # Close multiple times - should not raise + llm.close() + llm.close() + llm.close() + + assert runner._loop is None + assert runner._thread is None + + +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_llm_can_be_reused_after_close(mock_acompletion, llm: LLM, messages): + """Test that LLM can be used for new calls after close().""" + runner = llm._async_runner + assert runner is not None + + mock_acompletion.return_value = create_mock_response() + + # Make a call to start the event loop + result1 = llm.completion(messages) + assert result1 is not None + + # Close the LLM + llm.close() + assert runner._loop is None + assert runner._thread is None + + # Make another call - should work (loop recreated lazily) + result2 = llm.completion(messages) + assert result2 is not None + assert runner._loop is not None # Loop was recreated + + # Clean up + llm.close() + + +def test_llm_close_is_thread_safe(messages): + """Test that close() can be called from multiple threads safely.""" + llm = LLM( + model="gpt-4o", + api_key=SecretStr("test_key"), + usage_id="test-close-thread-safe", + num_retries=0, + ) + + runner = llm._async_runner + assert runner is not None + + # Start the event loop via the runner + runner._ensure_loop() + + # Call close from multiple threads concurrently + threads = [] + for _ in range(10): + t = threading.Thread(target=llm.close) + threads.append(t) + t.start() + + for t in threads: + t.join(timeout=2) + + # Should not raise any errors and should be cleaned up + assert runner._loop is None + assert runner._thread is None + + +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_llm_close_cancels_in_flight_task(mock_acompletion, llm: LLM, messages): + """Test that close() cancels any in-flight task before stopping the loop.""" + call_started = threading.Event() + call_finished = threading.Event() + + async def slow_completion(*args, **kwargs): + call_started.set() + # Wait up to 10 seconds + for _ in range(100): + await asyncio.sleep(0.1) + call_finished.set() + return create_mock_response() + + mock_acompletion.side_effect = slow_completion + + result_container: dict[str, Any] = {"result": None, "error": None} + + def run_completion(): + try: + result_container["result"] = llm.completion(messages) + except Exception as e: + result_container["error"] = e + + # Start completion in background thread + thread = threading.Thread(target=run_completion) + thread.start() + + # Wait for the call to start + call_started.wait(timeout=2) + time.sleep(0.1) # Small delay to ensure task is tracked + + # Close the LLM (should cancel the task) + llm.close() + + # Wait for thread to finish + thread.join(timeout=3) + + # Should have raised LLMCancelledError + assert result_container["error"] is not None + assert isinstance(result_container["error"], LLMCancelledError) + assert not call_finished.is_set() # Call should not have completed normally diff --git a/tests/sdk/llm/test_llm_litellm_extra_body.py b/tests/sdk/llm/test_llm_litellm_extra_body.py index 5bf9af6c1b..c5b608ae4f 100644 --- a/tests/sdk/llm/test_llm_litellm_extra_body.py +++ b/tests/sdk/llm/test_llm_litellm_extra_body.py @@ -23,7 +23,7 @@ def test_completion_forwards_extra_body_for_proxy_models(): ) messages = [Message(role="user", content=[TextContent(text="Hello")])] - with patch("openhands.sdk.llm.llm.litellm_completion") as mock_completion: + with patch("openhands.sdk.llm.llm.litellm_acompletion") as mock_completion: mock_response = ModelResponse( id="test-id", choices=[ @@ -66,7 +66,7 @@ def test_responses_forwards_extra_body_for_all_models(): ) messages = [Message(role="user", content=[TextContent(text="Hello")])] - with patch("openhands.sdk.llm.llm.litellm_responses") as mock_responses: + with patch("openhands.sdk.llm.llm.litellm_aresponses") as mock_responses: mock_response = MagicMock(spec=ResponsesAPIResponse) mock_response.id = "test-id" mock_response.created_at = 1234567890 diff --git a/tests/sdk/llm/test_llm_log_completions_integration.py b/tests/sdk/llm/test_llm_log_completions_integration.py index 7a68ce2b46..ef129b598c 100644 --- a/tests/sdk/llm/test_llm_log_completions_integration.py +++ b/tests/sdk/llm/test_llm_log_completions_integration.py @@ -46,7 +46,7 @@ def test_llm_log_completions_integration_no_warnings(): ) # Mock the litellm completion call - with patch("openhands.sdk.llm.llm.litellm_completion") as mock_completion: + with patch("openhands.sdk.llm.llm.litellm_acompletion") as mock_completion: mock_completion.return_value = mock_response # Capture any warnings @@ -147,7 +147,7 @@ def test_llm_log_completions_with_tool_calls(): ) # Mock the litellm completion call - with patch("openhands.sdk.llm.llm.litellm_completion") as mock_completion: + with patch("openhands.sdk.llm.llm.litellm_acompletion") as mock_completion: mock_completion.return_value = mock_response # Capture any warnings diff --git a/tests/sdk/llm/test_llm_no_response_retry.py b/tests/sdk/llm/test_llm_no_response_retry.py index c8f5809554..5fba75fd4a 100644 --- a/tests/sdk/llm/test_llm_no_response_retry.py +++ b/tests/sdk/llm/test_llm_no_response_retry.py @@ -52,7 +52,7 @@ def base_llm() -> LLM: ) -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_no_response_retries_then_succeeds(mock_completion, base_llm: LLM) -> None: mock_completion.side_effect = [ create_empty_choices_response("empty-1"), @@ -68,7 +68,7 @@ def test_no_response_retries_then_succeeds(mock_completion, base_llm: LLM) -> No assert mock_completion.call_count == 2 # initial + 1 retry -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_no_response_exhausts_retries_bubbles_llm_no_response( mock_completion, base_llm: LLM ) -> None: @@ -87,7 +87,7 @@ def test_no_response_exhausts_retries_bubbles_llm_no_response( assert mock_completion.call_count == base_llm.num_retries -@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion") def test_no_response_retry_bumps_temperature(mock_completion, base_llm: LLM) -> None: # Ensure we start at 0.0 to trigger bump to 1.0 on retry assert base_llm.temperature == 0.0 diff --git a/tests/sdk/llm/test_llm_pricing_passthrough.py b/tests/sdk/llm/test_llm_pricing_passthrough.py index 4836c5e8bc..9eee2c3449 100644 --- a/tests/sdk/llm/test_llm_pricing_passthrough.py +++ b/tests/sdk/llm/test_llm_pricing_passthrough.py @@ -14,7 +14,7 @@ def test_llm_pricing_passthrough_custom_rates(): litellm.cost_calculator.completion_cost. """ with ( - patch("openhands.sdk.llm.llm.litellm_completion") as mock_completion, + patch("openhands.sdk.llm.llm.litellm_acompletion") as mock_completion, patch("openhands.sdk.llm.utils.telemetry.litellm_completion_cost") as mock_cost, ): mock_completion.return_value = create_mock_litellm_response("ok") diff --git a/tests/sdk/llm/test_llm_retry_telemetry.py b/tests/sdk/llm/test_llm_retry_telemetry.py index bf1ab0a774..d545bc8319 100644 --- a/tests/sdk/llm/test_llm_retry_telemetry.py +++ b/tests/sdk/llm/test_llm_retry_telemetry.py @@ -44,8 +44,8 @@ def create_mock_response( ) -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_telemetry_records_only_successful_attempt_latency(mock_litellm_completion): +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_telemetry_records_only_successful_attempt_latency(mock_litellm_acompletion): """ Test that when LLM calls are retried, telemetry only records the latency of the successful attempt, not the cumulative time of all attempts. @@ -59,7 +59,7 @@ def test_telemetry_records_only_successful_attempt_latency(mock_litellm_completi mock_response = create_mock_response("Success after retry") # Simulate 2 failures followed by success - mock_litellm_completion.side_effect = [ + mock_litellm_acompletion.side_effect = [ APIConnectionError( message="Connection failed 1", llm_provider="test_provider", @@ -96,7 +96,7 @@ def test_telemetry_records_only_successful_attempt_latency(mock_litellm_completi # Verify the call succeeded assert response.raw_response == mock_response - assert mock_litellm_completion.call_count == 3 + assert mock_litellm_acompletion.call_count == 3 # Get the metrics to check recorded latency metrics = llm.metrics @@ -129,8 +129,8 @@ def test_telemetry_records_only_successful_attempt_latency(mock_litellm_completi ) -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_telemetry_on_request_called_per_retry(mock_litellm_completion): +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_telemetry_on_request_called_per_retry(mock_litellm_acompletion): """ Test that telemetry.on_request() is called for each retry attempt. @@ -161,7 +161,7 @@ def mock_transport_call_side_effect(*args, **kwargs): ) return mock_response - mock_litellm_completion.side_effect = mock_transport_call_side_effect + mock_litellm_acompletion.side_effect = mock_transport_call_side_effect # Create LLM instance llm = LLM( @@ -195,8 +195,8 @@ def mock_transport_call_side_effect(*args, **kwargs): ) -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_telemetry_metrics_accurate_with_retries(mock_litellm_completion): +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_telemetry_metrics_accurate_with_retries(mock_litellm_acompletion): """ Test that all telemetry metrics (tokens, cost, latency) are accurate when retries occur. @@ -207,7 +207,7 @@ def test_telemetry_metrics_accurate_with_retries(mock_litellm_completion): ) # Simulate one failure then success - mock_litellm_completion.side_effect = [ + mock_litellm_acompletion.side_effect = [ APIConnectionError( message="Connection failed", llm_provider="test_provider", @@ -254,8 +254,8 @@ def test_telemetry_metrics_accurate_with_retries(mock_litellm_completion): assert metrics.response_latencies[0].latency < 0.5 -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_telemetry_no_multiple_records_on_retry(mock_litellm_completion): +@patch("openhands.sdk.llm.llm.litellm_acompletion") +def test_telemetry_no_multiple_records_on_retry(mock_litellm_acompletion): """ Test that telemetry doesn't create multiple records for failed attempts. @@ -264,7 +264,7 @@ def test_telemetry_no_multiple_records_on_retry(mock_litellm_completion): mock_response = create_mock_response("Success") # Simulate multiple failures then success - mock_litellm_completion.side_effect = [ + mock_litellm_acompletion.side_effect = [ APIConnectionError( message="Fail 1", llm_provider="test_provider", model="test_model" ), diff --git a/tests/sdk/llm/test_llm_timeout.py b/tests/sdk/llm/test_llm_timeout.py index 6b6e44c16d..b63e1f7739 100644 --- a/tests/sdk/llm/test_llm_timeout.py +++ b/tests/sdk/llm/test_llm_timeout.py @@ -65,7 +65,7 @@ def test_timeout_accepts_zero(self): class TestLLMTimeoutPassthrough: """Tests that timeout is correctly passed to litellm.""" - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_default_timeout_passed_to_litellm(self, mock_completion): """Test that the default timeout is passed to litellm completion calls.""" from litellm.types.utils import ( @@ -111,7 +111,7 @@ def test_default_timeout_passed_to_litellm(self, mock_completion): f"to litellm, but got {call_kwargs['timeout']}" ) - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_custom_timeout_passed_to_litellm(self, mock_completion): """Test that a custom timeout is passed to litellm completion calls.""" from litellm.types.utils import ( @@ -153,7 +153,7 @@ def test_custom_timeout_passed_to_litellm(self, mock_completion): assert call_kwargs["timeout"] == custom_timeout - @patch("openhands.sdk.llm.llm.litellm_completion") + @patch("openhands.sdk.llm.llm.litellm_acompletion") def test_none_timeout_passed_to_litellm(self, mock_completion): """Test that None timeout is passed to litellm (no timeout).""" from litellm.types.utils import ( diff --git a/tests/sdk/llm/test_responses_parsing_and_kwargs.py b/tests/sdk/llm/test_responses_parsing_and_kwargs.py index 6c3cc110d8..a4f6995f50 100644 --- a/tests/sdk/llm/test_responses_parsing_and_kwargs.py +++ b/tests/sdk/llm/test_responses_parsing_and_kwargs.py @@ -116,7 +116,7 @@ def test_normalize_responses_kwargs_encrypted_reasoning_disabled(): assert "text.output_text" in out["include"] -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_llm_responses_end_to_end(mock_responses_call): # Configure LLM llm = LLM(model="gpt-5-mini") diff --git a/tests/sdk/llm/test_telemetry_policy.py b/tests/sdk/llm/test_telemetry_policy.py index f056302ad6..dbfc3516e3 100644 --- a/tests/sdk/llm/test_telemetry_policy.py +++ b/tests/sdk/llm/test_telemetry_policy.py @@ -14,7 +14,7 @@ def test_chat_forwards_extra_body_for_all_models(): model="cerebras/llama-3.3-70b", usage_id="u1", litellm_extra_body={"k": "v"} ) messages = [Message(role="user", content=[TextContent(text="Hi")])] - with patch("openhands.sdk.llm.llm.litellm_completion") as mock_call: + with patch("openhands.sdk.llm.llm.litellm_acompletion") as mock_call: mock_call.return_value = ModelResponse( id="x", choices=[ @@ -39,7 +39,7 @@ def test_chat_proxy_forwards_extra_body(): eb = {"cluster": "c1", "route": "r1"} llm = LLM(model="litellm_proxy/gpt-4o", usage_id="u1", litellm_extra_body=eb) messages = [Message(role="user", content=[TextContent(text="Hi")])] - with patch("openhands.sdk.llm.llm.litellm_completion") as mock_call: + with patch("openhands.sdk.llm.llm.litellm_acompletion") as mock_call: mock_call.return_value = ModelResponse( id="x", choices=[ @@ -61,7 +61,7 @@ def test_chat_proxy_forwards_extra_body(): # Responses path: same policy -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_responses_forwards_extra_body_for_all_models(mock_responses): llm = LLM( model="cerebras/llama-3.3-70b", usage_id="u1", litellm_extra_body={"k": "v"} @@ -90,7 +90,7 @@ def test_responses_forwards_extra_body_for_all_models(mock_responses): assert kwargs.get("extra_body") == {"k": "v"} -@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses") def test_responses_proxy_forwards_extra_body(mock_responses): eb = {"cluster": "c1", "route": "r1"} llm = LLM(model="litellm_proxy/gpt-4o", usage_id="u1", litellm_extra_body=eb)