diff --git a/openhands-sdk/openhands/sdk/conversation/compliance/__init__.py b/openhands-sdk/openhands/sdk/conversation/compliance/__init__.py new file mode 100644 index 0000000000..d27fbcdc23 --- /dev/null +++ b/openhands-sdk/openhands/sdk/conversation/compliance/__init__.py @@ -0,0 +1,25 @@ +"""API Compliance monitoring for conversation events. + +This module provides an APIComplianceMonitor that detects and rejects violations +of LLM API requirements in the event stream. Violating events are logged and +rejected (not added to conversation). Future versions may support reconciliation +strategies. + +The monitor enforces valid tool-call sequences: +- When tool calls are pending, only matching observations are allowed +- Messages cannot interleave with pending tool calls +- Tool results must reference known tool_call_ids +""" + +from openhands.sdk.conversation.compliance.base import ( + ComplianceState, + ComplianceViolation, +) +from openhands.sdk.conversation.compliance.monitor import APIComplianceMonitor + + +__all__ = [ + "APIComplianceMonitor", + "ComplianceState", + "ComplianceViolation", +] diff --git a/openhands-sdk/openhands/sdk/conversation/compliance/base.py b/openhands-sdk/openhands/sdk/conversation/compliance/base.py new file mode 100644 index 0000000000..57350f5614 --- /dev/null +++ b/openhands-sdk/openhands/sdk/conversation/compliance/base.py @@ -0,0 +1,41 @@ +"""Base classes for API compliance monitoring.""" + +from dataclasses import dataclass, field + +from openhands.sdk.event.types import EventID, ToolCallID + + +@dataclass +class ComplianceViolation: + """Represents an API compliance violation. + + Attributes: + property_name: Name of the property that was violated. + event_id: ID of the event that caused the violation. + description: Human-readable description of the violation. + context: Optional additional context (e.g., related tool_call_ids). 
+ """ + + property_name: str + event_id: EventID + description: str + context: dict[str, object] | None = None + + +@dataclass +class ComplianceState: + """Shared state for tracking API compliance. + + Tracks the tool call lifecycle to detect violations: + - pending_tool_call_ids: Actions awaiting results + - completed_tool_call_ids: Actions that have received results + + Attributes: + pending_tool_call_ids: Tool calls that have been made but not yet + received results. Maps tool_call_id to the ActionEvent id. + completed_tool_call_ids: Tool calls that have received results. + Used to detect duplicate results. + """ + + pending_tool_call_ids: dict[ToolCallID, EventID] = field(default_factory=dict) + completed_tool_call_ids: set[ToolCallID] = field(default_factory=set) diff --git a/openhands-sdk/openhands/sdk/conversation/compliance/monitor.py b/openhands-sdk/openhands/sdk/conversation/compliance/monitor.py new file mode 100644 index 0000000000..aafa91ecdb --- /dev/null +++ b/openhands-sdk/openhands/sdk/conversation/compliance/monitor.py @@ -0,0 +1,172 @@ +"""API Compliance Monitor that checks events before adding to conversation.""" + +from openhands.sdk.conversation.compliance.base import ( + ComplianceState, + ComplianceViolation, +) +from openhands.sdk.event import ( + ActionEvent, + LLMConvertibleEvent, + MessageEvent, + ObservationBaseEvent, +) +from openhands.sdk.logger import get_logger + + +logger = get_logger(__name__) + + +class APIComplianceMonitor: + """Monitors events for API compliance violations. + + Enforces valid tool-call sequences by checking what events are allowed + given current state. The key invariant: when tool calls are pending, + only matching observations are allowed. + + State machine: + - IDLE (no pending calls): Messages and new actions allowed + - TOOL_CALLING (pending calls): Only matching observations allowed + Currently operates in rejection mode: violating events are logged and rejected + (not added to the conversation). 
+ + Attributes: + state: Compliance state tracking pending/completed tool calls. + """ + + def __init__(self) -> None: + """Initialize the compliance monitor.""" + self.state = ComplianceState() + + def _check_tool_call_sequence( + self, event: LLMConvertibleEvent + ) -> ComplianceViolation | None: + """Check if an event violates the tool-call sequence property. + + The rule is simple: if we have pending tool calls, only matching + observations are allowed. This covers all 8 API compliance patterns: + + - a01 (unmatched_tool_use): Message while calls pending + - a02 (unmatched_tool_result): Result with unknown ID + - a03 (interleaved_user_msg): User message while calls pending + - a04 (interleaved_asst_msg): Assistant message while calls pending + - a05 (duplicate_tool_call_id): Result for already-completed ID + - a06 (wrong_tool_call_id): Result with wrong ID (completed or unknown) + - a07 (parallel_missing_result): Message before all parallel results + - a08 (parallel_wrong_order): Result before action (unknown ID) + + Args: + event: The event to check. + + Returns: + A ComplianceViolation if the event violates the property, None otherwise. 
+ """ + # Actions are always allowed - they start or continue a tool-call batch + if isinstance(event, ActionEvent): + return None + + # Messages require no pending tool calls + if isinstance(event, MessageEvent): + if self.state.pending_tool_call_ids: + pending_ids = list(self.state.pending_tool_call_ids.keys()) + return ComplianceViolation( + property_name="interleaved_message", + event_id=event.id, + description=( + f"Message interleaved with {len(pending_ids)} pending " + f"tool call(s)" + ), + context={"pending_tool_call_ids": pending_ids}, + ) + return None + + # Observations must match a known tool_call_id + if isinstance(event, ObservationBaseEvent): + tool_call_id = event.tool_call_id + + # Check for valid match (pending) + if tool_call_id in self.state.pending_tool_call_ids: + return None # Valid - completes a pending call + + # Check for duplicate (already completed) + if tool_call_id in self.state.completed_tool_call_ids: + return ComplianceViolation( + property_name="duplicate_tool_result", + event_id=event.id, + description=( + f"Duplicate tool result for tool_call_id: {tool_call_id}" + ), + context={"tool_call_id": tool_call_id}, + ) + + # Unknown ID - orphan result (covers a02, a06, a08) + return ComplianceViolation( + property_name="unmatched_tool_result", + event_id=event.id, + description=( + f"Tool result references unknown tool_call_id: {tool_call_id}" + ), + context={"tool_call_id": tool_call_id}, + ) + + return None + + def _update_state(self, event: LLMConvertibleEvent) -> None: + """Update compliance state after processing an event. 
+ + Tracks the tool-call lifecycle: + - ActionEvent: Add to pending + - ObservationBaseEvent: Move from pending to completed + """ + if isinstance(event, ActionEvent): + self.state.pending_tool_call_ids[event.tool_call_id] = event.id + elif isinstance(event, ObservationBaseEvent): + # Move from pending to completed (if it was pending) + self.state.pending_tool_call_ids.pop(event.tool_call_id, None) + self.state.completed_tool_call_ids.add(event.tool_call_id) + + def process_event(self, event: LLMConvertibleEvent) -> list[ComplianceViolation]: + """Check an event for violations and update state. + + Fail-closed semantics: if checking crashes, the event is treated as + violating (state is not updated). Only events that pass checking + without violations have their state updated. + + Args: + event: The event to process. + + Returns: + List of violations detected (empty if compliant). + """ + violations: list[ComplianceViolation] = [] + check_failed = False + + try: + violation = self._check_tool_call_sequence(event) + if violation is not None: + violations.append(violation) + logger.warning( + "API compliance violation detected: %s - %s (event_id=%s)", + violation.property_name, + violation.description, + violation.event_id, + ) + except Exception as e: + logger.exception( + "Error checking compliance for event %s: %s", + event.id, + e, + ) + check_failed = True + + # Only update state if check succeeded and no violations + if not check_failed and not violations: + try: + self._update_state(event) + except Exception as e: + logger.exception( + "Error updating compliance state for event %s: %s", + event.id, + e, + ) + + return violations diff --git a/openhands-sdk/openhands/sdk/conversation/events_list_base.py b/openhands-sdk/openhands/sdk/conversation/events_list_base.py index 0f5f77239c..784fe7f086 100644 --- a/openhands-sdk/openhands/sdk/conversation/events_list_base.py +++ b/openhands-sdk/openhands/sdk/conversation/events_list_base.py @@ -2,6 +2,7 @@ from 
collections.abc import Sequence from openhands.sdk.event import Event +from openhands.sdk.event.types import EventID class EventsListBase(Sequence[Event], ABC): @@ -15,3 +16,8 @@ class EventsListBase(Sequence[Event], ABC): def append(self, event: Event) -> None: """Add a new event to the list.""" ... + + @abstractmethod + def get_index(self, event_id: EventID) -> int: + """Return the integer index for a given event_id.""" + ... diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 229c023955..95caf4b0fd 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -178,12 +178,16 @@ def __init__( cipher=cipher, ) - # Default callback: persist every event to state + # Default callback: persist every event to state with compliance checking def _default_callback(e): # This callback runs while holding the conversation state's lock # (see BaseConversation.compose_callbacks usage inside `with self._state:` # regions), so updating state here is thread-safe. - self._state.events.append(e) + # + # Use add_event() to check API compliance before appending. + # Events with violations are logged and rejected (not added to event log). + # The return value is ignored, so the rest of this callback still runs. + self._state.add_event(e) # Track user MessageEvent IDs here so hook callbacks (which may # synthesize or alter user messages) are captured in one place. 
if isinstance(e, MessageEvent) and e.source == "user": diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py index 32762bb2e4..895a03dfe7 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py @@ -352,6 +352,14 @@ def append(self, event: Event) -> None: """Add a new event to the list (for compatibility with EventLog interface).""" self.add_event(event) + def get_index(self, event_id: str) -> int: + """Return the integer index for a given event_id.""" + with self._lock: + for idx, event in enumerate(self._cached_events): + if event.id == event_id: + return idx + raise KeyError(f"Unknown event_id: {event_id}") + def create_default_callback(self) -> ConversationCallbackType: """Create a default callback that adds events to this list.""" diff --git a/openhands-sdk/openhands/sdk/conversation/state.py b/openhands-sdk/openhands/sdk/conversation/state.py index 8cef2c04b2..62b93fcdcd 100644 --- a/openhands-sdk/openhands/sdk/conversation/state.py +++ b/openhands-sdk/openhands/sdk/conversation/state.py @@ -1,4 +1,6 @@ # state.py +from __future__ import annotations + import json from collections.abc import Sequence from enum import Enum @@ -8,13 +10,20 @@ from pydantic import Field, PrivateAttr, model_validator from openhands.sdk.agent.base import AgentBase +from openhands.sdk.conversation.compliance import APIComplianceMonitor from openhands.sdk.conversation.conversation_stats import ConversationStats from openhands.sdk.conversation.event_store import EventLog +from openhands.sdk.conversation.events_list_base import EventsListBase from openhands.sdk.conversation.fifo_lock import FIFOLock from openhands.sdk.conversation.persistence_const import BASE_STATE, EVENTS_DIR from openhands.sdk.conversation.secret_registry import SecretRegistry from openhands.sdk.conversation.types import 
ConversationCallbackType, ConversationID -from openhands.sdk.event import ActionEvent, ObservationEvent, UserRejectObservation +from openhands.sdk.event import ( + ActionEvent, + LLMConvertibleEvent, + ObservationEvent, + UserRejectObservation, +) from openhands.sdk.event.base import Event from openhands.sdk.event.types import EventID from openhands.sdk.io import FileStore, InMemoryFileStore, LocalFileStore @@ -177,6 +186,9 @@ class ConversationState(OpenHandsModel): _lock: FIFOLock = PrivateAttr( default_factory=FIFOLock ) # FIFO lock for thread safety + _compliance_monitor: APIComplianceMonitor | None = PrivateAttr( + default=None + ) # API compliance monitor (lazy-initialized) @model_validator(mode="before") @classmethod @@ -202,7 +214,12 @@ def _handle_legacy_fields(cls, data: Any) -> Any: return data @property - def events(self) -> EventLog: + def events(self) -> EventsListBase: + """Read-only view of conversation events. + + Returns events as EventsListBase to discourage direct mutation. + Use add_event() to add new events with compliance monitoring. + """ return self._events @property @@ -244,7 +261,7 @@ def _save_base_state(self, fs: FileStore) -> None: # ===== Factory: open-or-create (no load/save methods needed) ===== @classmethod def create( - cls: type["ConversationState"], + cls: type[ConversationState], id: ConversationID, agent: AgentBase, workspace: BaseWorkspace, @@ -252,7 +269,7 @@ def create( max_iterations: int = 500, stuck_detection: bool = True, cipher: Cipher | None = None, - ) -> "ConversationState": + ) -> ConversationState: """Create a new conversation state or resume from persistence. This factory method handles both new conversation creation and resumption @@ -506,3 +523,49 @@ def owned(self) -> bool: Return True if the lock is currently held by the calling thread. 
""" return self._lock.owned() + + # ===== API Compliance Monitoring ===== + + @property + def compliance_monitor(self) -> APIComplianceMonitor: + """Get or create the API compliance monitor. + + The monitor is lazily initialized on first access. + """ + if self._compliance_monitor is None: + self._compliance_monitor = APIComplianceMonitor() + return self._compliance_monitor + + def add_event(self, event: Event) -> bool: + """Add an event to the conversation, checking for API compliance. + + This is the only supported way to add events to the conversation. + Do not mutate the events list directly (e.g., via ``state.events.append()``), + as this bypasses compliance monitoring and may cause silent failures. + + For LLMConvertibleEvent instances, the event is checked against API + compliance properties. Events with violations are logged and rejected + (not added to the event log). + + Args: + event: The event to add to the conversation. + + Returns: + True if the event was added, False if rejected due to violations. + """ + # Check for compliance violations only for LLM-convertible events + if isinstance(event, LLMConvertibleEvent): + try: + violations = self.compliance_monitor.process_event(event) + if violations: + # Reject events with violations + return False + except Exception as e: + logger.exception( + "Error checking compliance for event %s: %s", event.id, e + ) + # Fail-closed: reject event if compliance check crashes + return False + + self._events.append(event) + return True diff --git a/tests/integration/api_compliance/patterns.py b/tests/integration/api_compliance/patterns.py new file mode 100644 index 0000000000..d60688b095 --- /dev/null +++ b/tests/integration/api_compliance/patterns.py @@ -0,0 +1,462 @@ +"""API Compliance Pattern Definitions. + +This module defines the 8 malformed message patterns (a01-a08) that test +API compliance. Each pattern is a list of Message objects representing +a malformed conversation sequence. 
+ +These patterns are used by: +1. Integration tests that verify how LLM APIs respond to malformed input +2. Unit tests that verify the APIComplianceMonitor catches these violations +""" + +from dataclasses import dataclass + +from openhands.sdk.llm import Message, MessageToolCall, TextContent + + +@dataclass +class CompliancePattern: + """A compliance test pattern with metadata.""" + + name: str + description: str + messages: list[Message] + expected_violation: str # The violation property_name we expect + + +# ============================================================================= +# Pattern a01: Unmatched tool_use +# ============================================================================= +A01_UNMATCHED_TOOL_USE = CompliancePattern( + name="unmatched_tool_use", + description=( + "Conversation where an assistant message contains a tool_use " + "(tool_calls), but no tool_result follows before the next user message." + ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="List the files in the current directory.")], + ), + # Assistant message with tool_use + Message( + role="assistant", + content=[TextContent(text="I'll list the files for you.")], + tool_calls=[ + MessageToolCall( + id="call_abc123", + name="terminal", + arguments='{"command": "ls -la"}', + origin="completion", + ) + ], + ), + # NOTE: No tool_result follows! Directly another user message. 
+ Message( + role="user", + content=[TextContent(text="What was the result?")], + ), + ], + expected_violation="interleaved_message", +) + + +# ============================================================================= +# Pattern a02: Unmatched tool_result +# ============================================================================= +A02_UNMATCHED_TOOL_RESULT = CompliancePattern( + name="unmatched_tool_result", + description=( + "Conversation where a tool_result message references a tool_call_id " + "that doesn't exist in any prior assistant message's tool_calls." + ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="List the files in the current directory.")], + ), + # Assistant message WITHOUT tool_use + Message( + role="assistant", + content=[TextContent(text="I can help you list files. What directory?")], + ), + # Tool result that references a non-existent tool_call_id + Message( + role="tool", + content=[TextContent(text="file1.txt\nfile2.txt\nfile3.txt")], + tool_call_id="call_nonexistent_xyz", + name="terminal", + ), + ], + expected_violation="unmatched_tool_result", +) + + +# ============================================================================= +# Pattern a03: Interleaved user message +# ============================================================================= +A03_INTERLEAVED_USER_MSG = CompliancePattern( + name="interleaved_user_message", + description=( + "Conversation where a user message appears between a tool_use " + "(in assistant message) and its corresponding tool_result." 
+ ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="List the files in the current directory.")], + ), + # Assistant message with tool_use + Message( + role="assistant", + content=[TextContent(text="I'll list the files for you.")], + tool_calls=[ + MessageToolCall( + id="call_abc123", + name="terminal", + arguments='{"command": "ls -la"}', + origin="completion", + ) + ], + ), + # INTERLEAVED: User message before tool_result + Message( + role="user", + content=[TextContent(text="Actually, can you also show hidden files?")], + ), + # Tool result comes AFTER the interleaved user message + Message( + role="tool", + content=[TextContent(text="file1.txt\nfile2.txt")], + tool_call_id="call_abc123", + name="terminal", + ), + ], + expected_violation="interleaved_message", +) + + +# ============================================================================= +# Pattern a04: Interleaved assistant message +# ============================================================================= +A04_INTERLEAVED_ASST_MSG = CompliancePattern( + name="interleaved_assistant_message", + description=( + "Conversation where an assistant message (without tool_calls) appears " + "between a tool_use and its corresponding tool_result." 
+ ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="List the files in the current directory.")], + ), + # First assistant message with tool_use + Message( + role="assistant", + content=[TextContent(text="I'll list the files for you.")], + tool_calls=[ + MessageToolCall( + id="call_abc123", + name="terminal", + arguments='{"command": "ls -la"}', + origin="completion", + ) + ], + ), + # INTERLEAVED: Another assistant message without tool_calls + Message( + role="assistant", + content=[TextContent(text="The command is running...")], + ), + # Tool result comes AFTER the interleaved assistant message + Message( + role="tool", + content=[TextContent(text="file1.txt\nfile2.txt")], + tool_call_id="call_abc123", + name="terminal", + ), + ], + expected_violation="interleaved_message", +) + + +# ============================================================================= +# Pattern a05: Duplicate tool_call_id +# ============================================================================= +A05_DUPLICATE_TOOL_CALL_ID = CompliancePattern( + name="duplicate_tool_call_id", + description=( + "Conversation where two tool_result messages have the same tool_call_id, " + "meaning multiple results are provided for a single tool_use." 
+ ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="List the files in the current directory.")], + ), + # Assistant message with tool_use + Message( + role="assistant", + content=[TextContent(text="I'll list the files for you.")], + tool_calls=[ + MessageToolCall( + id="call_abc123", + name="terminal", + arguments='{"command": "ls -la"}', + origin="completion", + ) + ], + ), + # First tool result (correct) + Message( + role="tool", + content=[TextContent(text="file1.txt\nfile2.txt")], + tool_call_id="call_abc123", + name="terminal", + ), + # Some intervening messages + Message( + role="user", + content=[TextContent(text="Thanks! Now what?")], + ), + Message( + role="assistant", + content=[ + TextContent(text="You're welcome! Let me know if you need anything.") + ], + ), + Message( + role="user", + content=[TextContent(text="Actually, show me the files again.")], + ), + # DUPLICATE: Second tool result with SAME tool_call_id + Message( + role="tool", + content=[TextContent(text="file1.txt\nfile2.txt\nfile3.txt")], + tool_call_id="call_abc123", # Same ID as before! + name="terminal", + ), + ], + expected_violation="duplicate_tool_result", +) + + +# ============================================================================= +# Pattern a06: Wrong tool_call_id +# ============================================================================= +A06_WRONG_TOOL_CALL_ID = CompliancePattern( + name="wrong_tool_call_id", + description=( + "Conversation where a tool_result references the wrong tool_call_id " + "(one that has already been completed)." 
+ ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="Run two commands: ls and pwd")], + ), + # First assistant message with tool_use (id=A) + Message( + role="assistant", + content=[TextContent(text="I'll run ls first.")], + tool_calls=[ + MessageToolCall( + id="call_A_ls", + name="terminal", + arguments='{"command": "ls"}', + origin="completion", + ) + ], + ), + # First tool result - CORRECT + Message( + role="tool", + content=[TextContent(text="file1.txt\nfile2.txt")], + tool_call_id="call_A_ls", + name="terminal", + ), + # Second assistant message with tool_use (id=B) + Message( + role="assistant", + content=[TextContent(text="Now I'll run pwd.")], + tool_calls=[ + MessageToolCall( + id="call_B_pwd", + name="terminal", + arguments='{"command": "pwd"}', + origin="completion", + ) + ], + ), + # Second tool result - WRONG ID (references first tool_use which is done) + Message( + role="tool", + content=[TextContent(text="/home/user/project")], + tool_call_id="call_A_ls", # Wrong! Should be call_B_pwd + name="terminal", + ), + ], + expected_violation="duplicate_tool_result", # It's a duplicate of completed ID +) + + +# ============================================================================= +# Pattern a07: Parallel missing result +# ============================================================================= +A07_PARALLEL_MISSING_RESULT = CompliancePattern( + name="parallel_missing_result", + description=( + "Conversation where an assistant message contains multiple parallel " + "tool_calls, but only some of them have corresponding tool_results." 
+ ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[ + TextContent(text="Get the weather in San Francisco, Tokyo, and Paris.") + ], + ), + # Assistant message with THREE parallel tool_calls + Message( + role="assistant", + content=[TextContent(text="I'll check the weather in all three cities.")], + tool_calls=[ + MessageToolCall( + id="call_sf", + name="terminal", + arguments='{"command": "weather sf"}', + origin="completion", + ), + MessageToolCall( + id="call_tokyo", + name="terminal", + arguments='{"command": "weather tokyo"}', + origin="completion", + ), + MessageToolCall( + id="call_paris", + name="terminal", + arguments='{"command": "weather paris"}', + origin="completion", + ), + ], + ), + # Tool result for SF - provided + Message( + role="tool", + content=[TextContent(text="San Francisco: 65°F, Sunny")], + tool_call_id="call_sf", + name="terminal", + ), + # Tool result for Tokyo - provided + Message( + role="tool", + content=[TextContent(text="Tokyo: 72°F, Cloudy")], + tool_call_id="call_tokyo", + name="terminal", + ), + # NOTE: Tool result for Paris is MISSING! + # Next user message arrives before Paris result + Message( + role="user", + content=[TextContent(text="What about Paris?")], + ), + ], + expected_violation="interleaved_message", +) + + +# ============================================================================= +# Pattern a08: Parallel wrong order +# ============================================================================= +A08_PARALLEL_WRONG_ORDER = CompliancePattern( + name="parallel_wrong_order", + description=( + "Conversation where tool_results appear before the assistant message " + "that contains the corresponding tool_calls." 
+ ), + messages=[ + Message( + role="system", + content=[TextContent(text="You are a helpful assistant.")], + ), + Message( + role="user", + content=[TextContent(text="Check the weather in SF and Tokyo.")], + ), + # Tool results appear FIRST (wrong!) + Message( + role="tool", + content=[TextContent(text="San Francisco: 65°F, Sunny")], + tool_call_id="call_sf", + name="terminal", + ), + Message( + role="tool", + content=[TextContent(text="Tokyo: 72°F, Cloudy")], + tool_call_id="call_tokyo", + name="terminal", + ), + # Assistant message with tool_calls comes AFTER tool_results + Message( + role="assistant", + content=[TextContent(text="I'll check both cities.")], + tool_calls=[ + MessageToolCall( + id="call_sf", + name="terminal", + arguments='{"command": "weather sf"}', + origin="completion", + ), + MessageToolCall( + id="call_tokyo", + name="terminal", + arguments='{"command": "weather tokyo"}', + origin="completion", + ), + ], + ), + ], + expected_violation="unmatched_tool_result", +) + + +# All patterns for iteration +ALL_COMPLIANCE_PATTERNS = [ + A01_UNMATCHED_TOOL_USE, + A02_UNMATCHED_TOOL_RESULT, + A03_INTERLEAVED_USER_MSG, + A04_INTERLEAVED_ASST_MSG, + A05_DUPLICATE_TOOL_CALL_ID, + A06_WRONG_TOOL_CALL_ID, + A07_PARALLEL_MISSING_RESULT, + A08_PARALLEL_WRONG_ORDER, +] diff --git a/tests/integration/tests/a01_unmatched_tool_use.py b/tests/integration/tests/a01_unmatched_tool_use.py index 490b6c20fa..5c7b521cf9 100644 --- a/tests/integration/tests/a01_unmatched_tool_use.py +++ b/tests/integration/tests/a01_unmatched_tool_use.py @@ -4,26 +4,14 @@ Tests how different LLM APIs respond when a tool_use message is sent without a corresponding tool_result. - Pattern: [system] → [user] → [assistant with tool_use] → [user message] → API CALL ↑ No tool_result! 
""" -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "unmatched_tool_use" -DESCRIPTION = """ -Sends a conversation where an assistant message contains a tool_use (tool_calls), -but no tool_result (tool message) follows before the next user message. - -This pattern can occur when: -- ObservationEvent is delayed or lost -- User message arrives before observation is recorded -- Event sync issues during conversation resume -""" +from tests.integration.api_compliance.patterns import A01_UNMATCHED_TOOL_USE class UnmatchedToolUseTest(BaseAPIComplianceTest): @@ -31,39 +19,12 @@ class UnmatchedToolUseTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A01_UNMATCHED_TOOL_USE.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A01_UNMATCHED_TOOL_USE.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with unmatched tool_use.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="List the files in the current directory.")], - ), - # Assistant message with tool_use - Message( - role="assistant", - content=[TextContent(text="I'll list the files for you.")], - tool_calls=[ - MessageToolCall( - id="call_abc123", - name="terminal", - arguments='{"command": "ls -la"}', - origin="completion", - ) - ], - ), - # NOTE: No tool_result follows! Directly another user message. 
- Message( - role="user", - content=[TextContent(text="What was the result?")], - ), - ] + return A01_UNMATCHED_TOOL_USE.messages diff --git a/tests/integration/tests/a02_unmatched_tool_result.py b/tests/integration/tests/a02_unmatched_tool_result.py index 4a54c9587e..feb656198d 100644 --- a/tests/integration/tests/a02_unmatched_tool_result.py +++ b/tests/integration/tests/a02_unmatched_tool_result.py @@ -9,20 +9,9 @@ ↑ References non-existent ID! """ -from openhands.sdk.llm import Message, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "unmatched_tool_result" -DESCRIPTION = """ -Sends a conversation where a tool_result message references a tool_call_id -that doesn't exist in any prior assistant message's tool_calls. - -This pattern can occur when: -- tool_call_id is corrupted during serialization -- Tool results are sent for the wrong conversation -- Event ordering issues cause mismatched IDs -""" +from tests.integration.api_compliance.patterns import A02_UNMATCHED_TOOL_RESULT class UnmatchedToolResultTest(BaseAPIComplianceTest): @@ -30,35 +19,12 @@ class UnmatchedToolResultTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A02_UNMATCHED_TOOL_RESULT.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A02_UNMATCHED_TOOL_RESULT.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with unmatched tool_result.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="List the files in the current directory.")], - ), - # Assistant message WITHOUT tool_use - Message( - role="assistant", - content=[ - TextContent(text="I can help you list files. 
What directory?") - ], - ), - # Tool result that references a non-existent tool_call_id - Message( - role="tool", - content=[TextContent(text="file1.txt\nfile2.txt\nfile3.txt")], - tool_call_id="call_nonexistent_xyz", - name="terminal", - ), - ] + return A02_UNMATCHED_TOOL_RESULT.messages diff --git a/tests/integration/tests/a03_interleaved_user_msg.py b/tests/integration/tests/a03_interleaved_user_msg.py index ed75a914cd..50ad9e87df 100644 --- a/tests/integration/tests/a03_interleaved_user_msg.py +++ b/tests/integration/tests/a03_interleaved_user_msg.py @@ -4,26 +4,14 @@ Tests how different LLM APIs respond when a user message appears between tool_use and tool_result. - Pattern: [assistant with tool_use] → [user message] → [tool_result] ↑ Inserted between tool_use and tool_result! """ -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "interleaved_user_message" -DESCRIPTION = """ -Sends a conversation where a user message appears between a tool_use -(in assistant message) and its corresponding tool_result (tool message). 
- -This pattern can occur when: -- User sends message via send_message() during pending tool execution -- Events are appended to the event list in incorrect order -- Async message delivery causes race conditions -""" +from tests.integration.api_compliance.patterns import A03_INTERLEAVED_USER_MSG class InterleavedUserMessageTest(BaseAPIComplianceTest): @@ -31,46 +19,12 @@ class InterleavedUserMessageTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A03_INTERLEAVED_USER_MSG.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A03_INTERLEAVED_USER_MSG.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with interleaved user message.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="List the files in the current directory.")], - ), - # Assistant message with tool_use - Message( - role="assistant", - content=[TextContent(text="I'll list the files for you.")], - tool_calls=[ - MessageToolCall( - id="call_abc123", - name="terminal", - arguments='{"command": "ls -la"}', - origin="completion", - ) - ], - ), - # INTERLEAVED: User message before tool_result - Message( - role="user", - content=[TextContent(text="Actually, can you also show hidden files?")], - ), - # Tool result comes AFTER the interleaved user message - Message( - role="tool", - content=[TextContent(text="file1.txt\nfile2.txt")], - tool_call_id="call_abc123", - name="terminal", - ), - ] + return A03_INTERLEAVED_USER_MSG.messages diff --git a/tests/integration/tests/a04_interleaved_asst_msg.py b/tests/integration/tests/a04_interleaved_asst_msg.py index b59ce40302..49eb248437 100644 --- a/tests/integration/tests/a04_interleaved_asst_msg.py +++ b/tests/integration/tests/a04_interleaved_asst_msg.py @@ -9,20 +9,9 @@ ↑ Another assistant turn before tool_result! 
""" -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "interleaved_assistant_message" -DESCRIPTION = """ -Sends a conversation where an assistant message (without tool_calls) appears -between a tool_use and its corresponding tool_result. - -This pattern might occur in edge cases with: -- Malformed condensation that inserts summary messages incorrectly -- Manual event manipulation -- Corrupted conversation history -""" +from tests.integration.api_compliance.patterns import A04_INTERLEAVED_ASST_MSG class InterleavedAssistantMessageTest(BaseAPIComplianceTest): @@ -30,46 +19,12 @@ class InterleavedAssistantMessageTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A04_INTERLEAVED_ASST_MSG.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A04_INTERLEAVED_ASST_MSG.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with interleaved assistant message.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="List the files in the current directory.")], - ), - # First assistant message with tool_use - Message( - role="assistant", - content=[TextContent(text="I'll list the files for you.")], - tool_calls=[ - MessageToolCall( - id="call_abc123", - name="terminal", - arguments='{"command": "ls -la"}', - origin="completion", - ) - ], - ), - # INTERLEAVED: Another assistant message without tool_calls - Message( - role="assistant", - content=[TextContent(text="The command is running...")], - ), - # Tool result comes AFTER the interleaved assistant message - Message( - role="tool", - content=[TextContent(text="file1.txt\nfile2.txt")], - tool_call_id="call_abc123", - name="terminal", - ), - ] + return 
A04_INTERLEAVED_ASST_MSG.messages diff --git a/tests/integration/tests/a05_duplicate_tool_call_id.py b/tests/integration/tests/a05_duplicate_tool_call_id.py index 7e8db00410..49ab29589a 100644 --- a/tests/integration/tests/a05_duplicate_tool_call_id.py +++ b/tests/integration/tests/a05_duplicate_tool_call_id.py @@ -4,26 +4,14 @@ Tests how different LLM APIs respond when multiple tool_result messages have the same tool_call_id. - Pattern: [assistant with tool_use id=X] → [tool_result id=X] → ... → [tool_result id=X] ↑ Duplicate! """ -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "duplicate_tool_call_id" -DESCRIPTION = """ -Sends a conversation where two tool_result messages have the same tool_call_id, -meaning multiple results are provided for a single tool_use. - -This pattern can occur when: -- Conversation is resumed and duplicate ObservationEvent is created -- Event sync issues during conversation restore -- get_unmatched_actions() incorrectly identifies action as unmatched -""" +from tests.integration.api_compliance.patterns import A05_DUPLICATE_TOOL_CALL_ID class DuplicateToolCallIdTest(BaseAPIComplianceTest): @@ -31,65 +19,12 @@ class DuplicateToolCallIdTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A05_DUPLICATE_TOOL_CALL_ID.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A05_DUPLICATE_TOOL_CALL_ID.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with duplicate tool_call_id.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="List the files in the current directory.")], - ), - # Assistant message with tool_use - Message( - role="assistant", - 
content=[TextContent(text="I'll list the files for you.")], - tool_calls=[ - MessageToolCall( - id="call_abc123", - name="terminal", - arguments='{"command": "ls -la"}', - origin="completion", - ) - ], - ), - # First tool result (correct) - Message( - role="tool", - content=[TextContent(text="file1.txt\nfile2.txt")], - tool_call_id="call_abc123", - name="terminal", - ), - # Some intervening messages (simulating conversation continuation) - Message( - role="user", - content=[TextContent(text="Thanks! Now what?")], - ), - Message( - role="assistant", - content=[ - TextContent( - text="You're welcome! Let me know if you need anything else." - ) - ], - ), - Message( - role="user", - content=[TextContent(text="Actually, show me the files again.")], - ), - # DUPLICATE: Second tool result with SAME tool_call_id - Message( - role="tool", - content=[TextContent(text="file1.txt\nfile2.txt\nfile3.txt")], - tool_call_id="call_abc123", # Same ID as before! - name="terminal", - ), - ] + return A05_DUPLICATE_TOOL_CALL_ID.messages diff --git a/tests/integration/tests/a06_wrong_tool_call_id.py b/tests/integration/tests/a06_wrong_tool_call_id.py index 68263a2c9b..db6cdd3849 100644 --- a/tests/integration/tests/a06_wrong_tool_call_id.py +++ b/tests/integration/tests/a06_wrong_tool_call_id.py @@ -2,27 +2,16 @@ API Compliance Test: Wrong tool_call_id Tests how different LLM APIs respond when a tool_result references the wrong -tool_call_id (swapped with another tool_use's ID). +tool_call_id (one that has already been completed). Pattern: - [assistant with tool_use id=A] → [assistant with tool_use id=B] → - [tool_result id=B] → [tool_result id=A] ← IDs swapped! + [assistant with tool_use id=A] → [tool_result id=A] → + [assistant with tool_use id=B] → [tool_result id=A] ← References completed ID! 
""" -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "wrong_tool_call_id" -DESCRIPTION = """ -Sends a conversation where tool_results are provided but with swapped IDs, -so each tool_result references the wrong tool_use. - -This pattern might occur with: -- ID corruption during serialization -- Race conditions in parallel tool execution -- Manual event manipulation errors -""" +from tests.integration.api_compliance.patterns import A06_WRONG_TOOL_CALL_ID class WrongToolCallIdTest(BaseAPIComplianceTest): @@ -30,61 +19,12 @@ class WrongToolCallIdTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A06_WRONG_TOOL_CALL_ID.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A06_WRONG_TOOL_CALL_ID.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with swapped tool_call_ids.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="Run two commands: ls and pwd")], - ), - # First assistant message with tool_use (id=A) - Message( - role="assistant", - content=[TextContent(text="I'll run ls first.")], - tool_calls=[ - MessageToolCall( - id="call_A_ls", - name="terminal", - arguments='{"command": "ls"}', - origin="completion", - ) - ], - ), - # First tool result - CORRECT - Message( - role="tool", - content=[TextContent(text="file1.txt\nfile2.txt")], - tool_call_id="call_A_ls", - name="terminal", - ), - # Second assistant message with tool_use (id=B) - Message( - role="assistant", - content=[TextContent(text="Now I'll run pwd.")], - tool_calls=[ - MessageToolCall( - id="call_B_pwd", - name="terminal", - arguments='{"command": "pwd"}', - origin="completion", - ) - ], - ), - # Second tool result - WRONG ID 
(references first tool_use) - Message( - role="tool", - content=[TextContent(text="/home/user/project")], - tool_call_id="call_A_ls", # Wrong! Should be call_B_pwd - name="terminal", - ), - ] + return A06_WRONG_TOOL_CALL_ID.messages diff --git a/tests/integration/tests/a07_parallel_missing_result.py b/tests/integration/tests/a07_parallel_missing_result.py index 1e5e4ef72a..23b40fb356 100644 --- a/tests/integration/tests/a07_parallel_missing_result.py +++ b/tests/integration/tests/a07_parallel_missing_result.py @@ -9,20 +9,9 @@ ↑ Missing result for C! """ -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "parallel_missing_result" -DESCRIPTION = """ -Sends a conversation where an assistant message contains multiple parallel -tool_calls, but only some of them have corresponding tool_results. - -This pattern can occur when: -- Partial tool execution failure -- Event loss for some observations -- Timeout causes some results to be missing -""" +from tests.integration.api_compliance.patterns import A07_PARALLEL_MISSING_RESULT class ParallelMissingResultTest(BaseAPIComplianceTest): @@ -30,72 +19,12 @@ class ParallelMissingResultTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A07_PARALLEL_MISSING_RESULT.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A07_PARALLEL_MISSING_RESULT.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with parallel tool calls missing a result.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[ - TextContent( - text="Get the weather in San Francisco, Tokyo, and Paris." 
- ) - ], - ), - # Assistant message with THREE parallel tool_calls - Message( - role="assistant", - content=[ - TextContent(text="I'll check the weather in all three cities.") - ], - tool_calls=[ - MessageToolCall( - id="call_sf", - name="terminal", - arguments='{"command": "weather sf"}', - origin="completion", - ), - MessageToolCall( - id="call_tokyo", - name="terminal", - arguments='{"command": "weather tokyo"}', - origin="completion", - ), - MessageToolCall( - id="call_paris", - name="terminal", - arguments='{"command": "weather paris"}', - origin="completion", - ), - ], - ), - # Tool result for SF - provided - Message( - role="tool", - content=[TextContent(text="San Francisco: 65°F, Sunny")], - tool_call_id="call_sf", - name="terminal", - ), - # Tool result for Tokyo - provided - Message( - role="tool", - content=[TextContent(text="Tokyo: 72°F, Cloudy")], - tool_call_id="call_tokyo", - name="terminal", - ), - # NOTE: Tool result for Paris is MISSING! - # Next user message arrives before Paris result - Message( - role="user", - content=[TextContent(text="What about Paris?")], - ), - ] + return A07_PARALLEL_MISSING_RESULT.messages diff --git a/tests/integration/tests/a08_parallel_wrong_order.py b/tests/integration/tests/a08_parallel_wrong_order.py index f5257618e4..45ba982e0a 100644 --- a/tests/integration/tests/a08_parallel_wrong_order.py +++ b/tests/integration/tests/a08_parallel_wrong_order.py @@ -9,20 +9,9 @@ ↑ Results before the tool_calls! """ -from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.llm import Message from tests.integration.api_compliance.base import BaseAPIComplianceTest - - -PATTERN_NAME = "parallel_wrong_order" -DESCRIPTION = """ -Sends a conversation where tool_results appear before the assistant message -that contains the corresponding tool_calls. This is a severe ordering violation. 
- -This pattern might occur with: -- Severe event ordering bugs -- Manual conversation manipulation -- Corrupted event stream -""" +from tests.integration.api_compliance.patterns import A08_PARALLEL_WRONG_ORDER class ParallelWrongOrderTest(BaseAPIComplianceTest): @@ -30,53 +19,12 @@ class ParallelWrongOrderTest(BaseAPIComplianceTest): @property def pattern_name(self) -> str: - return PATTERN_NAME + return A08_PARALLEL_WRONG_ORDER.name @property def pattern_description(self) -> str: - return DESCRIPTION + return A08_PARALLEL_WRONG_ORDER.description def build_malformed_messages(self) -> list[Message]: """Build message sequence with tool results before tool calls.""" - return [ - Message( - role="system", - content=[TextContent(text="You are a helpful assistant.")], - ), - Message( - role="user", - content=[TextContent(text="Check the weather in SF and Tokyo.")], - ), - # Tool results appear FIRST (wrong!) - Message( - role="tool", - content=[TextContent(text="San Francisco: 65°F, Sunny")], - tool_call_id="call_sf", - name="terminal", - ), - Message( - role="tool", - content=[TextContent(text="Tokyo: 72°F, Cloudy")], - tool_call_id="call_tokyo", - name="terminal", - ), - # Assistant message with tool_calls comes AFTER tool_results - Message( - role="assistant", - content=[TextContent(text="I'll check both cities.")], - tool_calls=[ - MessageToolCall( - id="call_sf", - name="terminal", - arguments='{"command": "weather sf"}', - origin="completion", - ), - MessageToolCall( - id="call_tokyo", - name="terminal", - arguments='{"command": "weather tokyo"}', - origin="completion", - ), - ], - ), - ] + return A08_PARALLEL_WRONG_ORDER.messages diff --git a/tests/sdk/conversation/compliance/__init__.py b/tests/sdk/conversation/compliance/__init__.py new file mode 100644 index 0000000000..4c378bbe9f --- /dev/null +++ b/tests/sdk/conversation/compliance/__init__.py @@ -0,0 +1 @@ +"""Tests for API compliance monitoring.""" diff --git a/tests/sdk/conversation/compliance/conftest.py 
b/tests/sdk/conversation/compliance/conftest.py new file mode 100644 index 0000000000..28ab335ca2 --- /dev/null +++ b/tests/sdk/conversation/compliance/conftest.py @@ -0,0 +1,116 @@ +"""Fixtures for API compliance tests.""" + +import uuid + +import pytest + +from openhands.sdk.event import ActionEvent, MessageEvent, ObservationEvent +from openhands.sdk.event.types import ToolCallID +from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.tool import Observation + + +class SimpleObservation(Observation): + """Simple observation for testing.""" + + result: str + + @property + def to_llm_content(self) -> list[TextContent]: + return [TextContent(text=self.result)] + + +def make_action_event( + tool_call_id: ToolCallID | None = None, + tool_name: str = "terminal", +) -> ActionEvent: + """Create an ActionEvent for testing.""" + call_id = tool_call_id or f"call_{uuid.uuid4().hex[:8]}" + event_id = str(uuid.uuid4()) + return ActionEvent( + id=event_id, + source="agent", + thought=[TextContent(text="Let me do this")], + tool_name=tool_name, + tool_call_id=call_id, + tool_call=MessageToolCall( + id=call_id, + name=tool_name, + arguments='{"command": "ls"}', + origin="completion", + ), + llm_response_id=str(uuid.uuid4()), + ) + + +def make_observation_event( + action_event: ActionEvent, + result: str = "output", +) -> ObservationEvent: + """Create an ObservationEvent matching an ActionEvent.""" + return ObservationEvent( + id=str(uuid.uuid4()), + source="environment", + tool_name=action_event.tool_name, + tool_call_id=action_event.tool_call_id, + action_id=action_event.id, + observation=SimpleObservation(result=result), + ) + + +def make_orphan_observation_event( + tool_call_id: ToolCallID, + tool_name: str = "terminal", + action_id: str | None = None, +) -> ObservationEvent: + """Create an ObservationEvent with no matching action.""" + return ObservationEvent( + id=str(uuid.uuid4()), + source="environment", + tool_name=tool_name, + 
tool_call_id=tool_call_id, + action_id=action_id or str(uuid.uuid4()), + observation=SimpleObservation(result="orphan result"), + ) + + +def make_user_message_event(text: str = "Hello") -> MessageEvent: + """Create a user MessageEvent.""" + return MessageEvent( + id=str(uuid.uuid4()), + source="user", + llm_message=Message( + role="user", + content=[TextContent(text=text)], + ), + ) + + +def make_assistant_message_event(text: str = "I'll help") -> MessageEvent: + """Create an assistant MessageEvent.""" + return MessageEvent( + id=str(uuid.uuid4()), + source="agent", + llm_message=Message( + role="assistant", + content=[TextContent(text=text)], + ), + ) + + +@pytest.fixture +def action_event(): + """A single action event.""" + return make_action_event() + + +@pytest.fixture +def user_message_event(): + """A user message event.""" + return make_user_message_event() + + +@pytest.fixture +def assistant_message_event(): + """An assistant message event.""" + return make_assistant_message_event() diff --git a/tests/sdk/conversation/compliance/test_monitor.py b/tests/sdk/conversation/compliance/test_monitor.py new file mode 100644 index 0000000000..fe4ba1b483 --- /dev/null +++ b/tests/sdk/conversation/compliance/test_monitor.py @@ -0,0 +1,174 @@ +"""Tests for the APIComplianceMonitor.""" + +from unittest.mock import patch + +from openhands.sdk.conversation.compliance import APIComplianceMonitor +from tests.sdk.conversation.compliance.conftest import ( + make_action_event, + make_observation_event, + make_orphan_observation_event, + make_user_message_event, +) + + +def test_monitor_no_violations_normal_flow(): + """Normal conversation flow should have no violations.""" + monitor = APIComplianceMonitor() + all_violations: list = [] + + # Normal flow: action -> observation -> message + action = make_action_event(tool_call_id="call_1") + violations = monitor.process_event(action) + all_violations.extend(violations) + assert len(violations) == 0 + + obs = 
make_observation_event(action) + violations = monitor.process_event(obs) + all_violations.extend(violations) + assert len(violations) == 0 + + user_msg = make_user_message_event() + violations = monitor.process_event(user_msg) + all_violations.extend(violations) + assert len(violations) == 0 + + assert len(all_violations) == 0 + + +def test_monitor_detects_interleaved_message(): + """Monitor should detect interleaved message violation.""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_1") + monitor.process_event(action) + + # User message before observation - violation + user_msg = make_user_message_event() + violations = monitor.process_event(user_msg) + + assert len(violations) == 1 + assert violations[0].property_name == "interleaved_message" + + +def test_monitor_detects_orphan_observation(): + """Monitor should detect orphan observation as single violation.""" + monitor = APIComplianceMonitor() + + # Orphan observation (unknown tool_call_id) + orphan = make_orphan_observation_event(tool_call_id="call_unknown") + violations = monitor.process_event(orphan) + + # Should detect exactly one violation (unmatched_tool_result) + assert len(violations) == 1 + assert violations[0].property_name == "unmatched_tool_result" + + +def test_monitor_returns_violations_per_call(): + """Monitor returns violations for each call, caller can accumulate.""" + monitor = APIComplianceMonitor() + all_violations: list = [] + + # First violation + action = make_action_event(tool_call_id="call_1") + all_violations.extend(monitor.process_event(action)) + violations = monitor.process_event(make_user_message_event()) # interleaved + all_violations.extend(violations) + + initial_count = len(all_violations) + assert initial_count > 0 + + # Second violation + violations = monitor.process_event( + make_orphan_observation_event(tool_call_id="unknown") + ) + all_violations.extend(violations) + + assert len(all_violations) > initial_count + + +def 
test_monitor_state_persists_across_events(): + """Monitor state should persist correctly across events.""" + monitor = APIComplianceMonitor() + + # Add action + action1 = make_action_event(tool_call_id="call_1") + monitor.process_event(action1) + + assert "call_1" in monitor.state.pending_tool_call_ids + + # Add observation + obs1 = make_observation_event(action1) + monitor.process_event(obs1) + + assert "call_1" not in monitor.state.pending_tool_call_ids + assert "call_1" in monitor.state.completed_tool_call_ids + + +def test_monitor_parallel_tool_calls(): + """Monitor should handle parallel tool calls correctly.""" + monitor = APIComplianceMonitor() + + # Three parallel actions + action1 = make_action_event(tool_call_id="call_sf") + action2 = make_action_event(tool_call_id="call_tokyo") + action3 = make_action_event(tool_call_id="call_paris") + + for action in [action1, action2, action3]: + monitor.process_event(action) + + assert len(monitor.state.pending_tool_call_ids) == 3 + + # Two results arrive + monitor.process_event(make_observation_event(action1)) + monitor.process_event(make_observation_event(action2)) + + assert len(monitor.state.pending_tool_call_ids) == 1 + assert "call_paris" in monitor.state.pending_tool_call_ids + + # User message with one pending - violation + violations = monitor.process_event(make_user_message_event()) + assert len(violations) == 1 + assert "call_paris" in str(violations[0].context) + + +def test_monitor_handles_check_exception_gracefully(): + """Monitor should handle exceptions in check gracefully (fail-closed). + + If _check_tool_call_sequence raises an exception, it should be caught + and logged, not crash the monitor. State should NOT be updated (fail-closed). 
+ """ + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_check_err") + monitor.process_event(action) + initial_pending = len(monitor.state.pending_tool_call_ids) + + # Simulate another action where check fails + with patch.object( + monitor, "_check_tool_call_sequence", side_effect=ValueError("Oops!") + ): + # Should not raise - the exception should be caught and logged + violations = monitor.process_event(make_action_event(tool_call_id="call_2")) + + # The monitor should continue working despite the error + assert violations == [] + + # Fail-closed: state should NOT be updated when check fails + assert len(monitor.state.pending_tool_call_ids) == initial_pending + assert "call_2" not in monitor.state.pending_tool_call_ids + + +def test_monitor_handles_update_exception_gracefully(): + """Monitor should handle exceptions in state update gracefully. + + If _update_state raises an exception, it should be caught and logged, + not crash the monitor. + """ + monitor = APIComplianceMonitor() + + with patch.object(monitor, "_update_state", side_effect=ValueError("Oops!")): + # Should not raise - the exception should be caught and logged + violations = monitor.process_event(make_action_event()) + + # The monitor should continue working + assert violations == [] diff --git a/tests/sdk/conversation/compliance/test_pattern_detection.py b/tests/sdk/conversation/compliance/test_pattern_detection.py new file mode 100644 index 0000000000..260dcb06c3 --- /dev/null +++ b/tests/sdk/conversation/compliance/test_pattern_detection.py @@ -0,0 +1,326 @@ +"""Tests that verify APIComplianceMonitor catches all 8 API compliance patterns. + +These tests convert the Message-based patterns from the API compliance tests +into Event sequences and verify the monitor detects the expected violations. + +This provides fast unit test coverage without requiring LLM API calls. 
+""" + +import uuid + +import pytest + +from openhands.sdk.conversation.compliance import APIComplianceMonitor +from openhands.sdk.event import ActionEvent, MessageEvent, ObservationEvent +from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.tool import Observation +from tests.integration.api_compliance.patterns import ( + A01_UNMATCHED_TOOL_USE, + A02_UNMATCHED_TOOL_RESULT, + A03_INTERLEAVED_USER_MSG, + A04_INTERLEAVED_ASST_MSG, + A05_DUPLICATE_TOOL_CALL_ID, + A06_WRONG_TOOL_CALL_ID, + A07_PARALLEL_MISSING_RESULT, + A08_PARALLEL_WRONG_ORDER, + ALL_COMPLIANCE_PATTERNS, + CompliancePattern, +) + + +class SimpleObservation(Observation): + """Simple observation for testing.""" + + result: str + + @property + def to_llm_content(self) -> list[TextContent]: + return [TextContent(text=self.result)] + + +def message_to_event( + msg: Message, + pending_tool_calls: dict[str, MessageToolCall], +) -> MessageEvent | ActionEvent | ObservationEvent | None: + """Convert a Message to the appropriate Event type. + + Args: + msg: The Message to convert. + pending_tool_calls: Dict tracking tool_call_id -> MessageToolCall for + pending actions. Updated in-place when we see tool_calls. + + Returns: + The corresponding Event, or None for system messages. 
+ """ + if msg.role == "system": + # Skip system messages - they don't become events + return None + + if msg.role == "user": + return MessageEvent( + id=str(uuid.uuid4()), + source="user", + llm_message=msg, + ) + + if msg.role == "assistant": + # Check if this assistant message has tool_calls + if msg.tool_calls: + # Convert to ActionEvent(s) - we'll just return the first one + # and track all tool_calls + for tc in msg.tool_calls: + pending_tool_calls[tc.id] = tc + + # Return an ActionEvent for the first tool_call + first_tc = msg.tool_calls[0] + thought_text = "" + if msg.content and isinstance(msg.content[0], TextContent): + thought_text = msg.content[0].text + return ActionEvent( + id=str(uuid.uuid4()), + source="agent", + thought=[TextContent(text=thought_text)], + tool_name=first_tc.name, + tool_call_id=first_tc.id, + tool_call=first_tc, + llm_response_id=str(uuid.uuid4()), + ) + else: + # Regular assistant message + return MessageEvent( + id=str(uuid.uuid4()), + source="agent", + llm_message=msg, + ) + + if msg.role == "tool": + # Tool result -> ObservationEvent + tool_call_id = msg.tool_call_id + assert tool_call_id is not None, "Tool message must have tool_call_id" + + # Get result text + result_text = "" + if msg.content: + for content in msg.content: + if isinstance(content, TextContent): + result_text = content.text + break + + return ObservationEvent( + id=str(uuid.uuid4()), + source="environment", + tool_name=msg.name or "terminal", + tool_call_id=tool_call_id, + action_id=str(uuid.uuid4()), # We don't track this precisely + observation=SimpleObservation(result=result_text), + ) + + return None + + +def convert_pattern_to_events( + pattern: CompliancePattern, +) -> list[MessageEvent | ActionEvent | ObservationEvent]: + """Convert a compliance pattern's messages to events. + + For patterns with multiple parallel tool_calls, we generate one ActionEvent + per tool_call (not just one per assistant message). 
+ + Returns: + List of events representing the pattern. + """ + events: list[MessageEvent | ActionEvent | ObservationEvent] = [] + pending_tool_calls: dict[str, MessageToolCall] = {} + + for msg in pattern.messages: + if msg.role == "assistant" and msg.tool_calls and len(msg.tool_calls) > 1: + # Handle parallel tool calls - create ActionEvent for each + thought_text = "" + if msg.content and isinstance(msg.content[0], TextContent): + thought_text = msg.content[0].text + for tc in msg.tool_calls: + pending_tool_calls[tc.id] = tc + event = ActionEvent( + id=str(uuid.uuid4()), + source="agent", + thought=[TextContent(text=thought_text)], + tool_name=tc.name, + tool_call_id=tc.id, + tool_call=tc, + llm_response_id=str(uuid.uuid4()), + ) + events.append(event) + else: + event = message_to_event(msg, pending_tool_calls) + if event is not None: + events.append(event) + + return events + + +def run_pattern_through_monitor( + pattern: CompliancePattern, +) -> tuple[list[str], int]: + """Run a pattern through the monitor and collect violations. 
+ + Returns: + Tuple of (list of violation property_names, index of first violation) + """ + monitor = APIComplianceMonitor() + events = convert_pattern_to_events(pattern) + + all_violations: list[str] = [] + first_violation_idx = -1 + + for i, event in enumerate(events): + violations = monitor.process_event(event) + for v in violations: + if first_violation_idx == -1: + first_violation_idx = i + all_violations.append(v.property_name) + + return all_violations, first_violation_idx + + +# ============================================================================= +# Parametrized test for all patterns +# ============================================================================= + + +@pytest.mark.parametrize( + "pattern", + ALL_COMPLIANCE_PATTERNS, + ids=[p.name for p in ALL_COMPLIANCE_PATTERNS], +) +def test_monitor_detects_pattern(pattern: CompliancePattern): + """Verify the monitor detects the expected violation for each pattern.""" + violations, _ = run_pattern_through_monitor(pattern) + + assert len(violations) > 0, f"Pattern '{pattern.name}' should trigger a violation" + assert pattern.expected_violation in violations, ( + f"Pattern '{pattern.name}' should trigger '{pattern.expected_violation}', " + f"but got: {violations}" + ) + + +# ============================================================================= +# Individual pattern tests with detailed assertions +# ============================================================================= + + +def test_a01_unmatched_tool_use(): + """Pattern a01: User message while tool call is pending.""" + violations, idx = run_pattern_through_monitor(A01_UNMATCHED_TOOL_USE) + + assert "interleaved_message" in violations + # Violation should occur when user message arrives (3rd event: after action) + assert idx == 2 # 0=user, 1=action, 2=user (violation) + + +def test_a02_unmatched_tool_result(): + """Pattern a02: Tool result with unknown tool_call_id.""" + violations, idx = 
run_pattern_through_monitor(A02_UNMATCHED_TOOL_RESULT) + + assert "unmatched_tool_result" in violations + # Violation should occur when orphan tool result arrives + assert idx == 2 # 0=user, 1=assistant, 2=tool (violation) + + +def test_a03_interleaved_user_msg(): + """Pattern a03: User message between tool_use and tool_result.""" + violations, idx = run_pattern_through_monitor(A03_INTERLEAVED_USER_MSG) + + assert "interleaved_message" in violations + # Violation on interleaved user message + assert idx == 2 # 0=user, 1=action, 2=user (violation), 3=tool + + +def test_a04_interleaved_asst_msg(): + """Pattern a04: Assistant message between tool_use and tool_result.""" + violations, idx = run_pattern_through_monitor(A04_INTERLEAVED_ASST_MSG) + + assert "interleaved_message" in violations + # Violation on interleaved assistant message + assert idx == 2 # 0=user, 1=action, 2=assistant (violation), 3=tool + + +def test_a05_duplicate_tool_call_id(): + """Pattern a05: Second tool_result with same tool_call_id.""" + violations, idx = run_pattern_through_monitor(A05_DUPLICATE_TOOL_CALL_ID) + + assert "duplicate_tool_result" in violations + + +def test_a06_wrong_tool_call_id(): + """Pattern a06: Tool result references wrong (already completed) ID.""" + violations, idx = run_pattern_through_monitor(A06_WRONG_TOOL_CALL_ID) + + # This results in a duplicate because call_A_ls was already completed + assert "duplicate_tool_result" in violations + + +def test_a07_parallel_missing_result(): + """Pattern a07: User message while parallel tool calls are pending.""" + violations, idx = run_pattern_through_monitor(A07_PARALLEL_MISSING_RESULT) + + assert "interleaved_message" in violations + + +def test_a08_parallel_wrong_order(): + """Pattern a08: Tool results arrive before tool_calls.""" + violations, idx = run_pattern_through_monitor(A08_PARALLEL_WRONG_ORDER) + + assert "unmatched_tool_result" in violations + # First violation should be on first tool result (unknown ID at that point) 
+ assert idx == 1 # 0=user, 1=tool (violation) + + +# ============================================================================= +# Sanity check: valid sequences should have no violations +# ============================================================================= + + +def test_valid_sequence_no_violations(): + """A valid tool-call sequence should have no violations.""" + monitor = APIComplianceMonitor() + + # Create a valid sequence: user -> action -> observation -> user + user1 = MessageEvent( + id=str(uuid.uuid4()), + source="user", + llm_message=Message(role="user", content=[TextContent(text="List files")]), + ) + tool_call = MessageToolCall( + id="call_valid", + name="terminal", + arguments='{"command": "ls"}', + origin="completion", + ) + action = ActionEvent( + id=str(uuid.uuid4()), + source="agent", + thought=[TextContent(text="I'll list files")], + tool_name="terminal", + tool_call_id="call_valid", + tool_call=tool_call, + llm_response_id=str(uuid.uuid4()), + ) + observation = ObservationEvent( + id=str(uuid.uuid4()), + source="environment", + tool_name="terminal", + tool_call_id="call_valid", + action_id=action.id, + observation=SimpleObservation(result="file1.txt\nfile2.txt"), + ) + user2 = MessageEvent( + id=str(uuid.uuid4()), + source="user", + llm_message=Message(role="user", content=[TextContent(text="Thanks!")]), + ) + + all_violations = [] + for event in [user1, action, observation, user2]: + all_violations.extend(monitor.process_event(event)) + + assert len(all_violations) == 0, f"Valid sequence had violations: {all_violations}" diff --git a/tests/sdk/conversation/compliance/test_properties.py b/tests/sdk/conversation/compliance/test_properties.py new file mode 100644 index 0000000000..15513f1a65 --- /dev/null +++ b/tests/sdk/conversation/compliance/test_properties.py @@ -0,0 +1,236 @@ +"""Tests for API compliance monitoring. 
+ +These tests verify the monitor detects all 8 API compliance patterns from +tests/integration/tests/a*.py: + +- a01: Unmatched tool_use (message while calls pending) +- a02: Unmatched tool_result (result with unknown ID) +- a03: Interleaved user message (user message while calls pending) +- a04: Interleaved assistant message (assistant message while calls pending) +- a05: Duplicate tool_call_id (result for already-completed ID) +- a06: Wrong tool_call_id (result with wrong/unknown ID) +- a07: Parallel missing result (message before all parallel results) +- a08: Parallel wrong order (result before action) +""" + +from openhands.sdk.conversation.compliance import APIComplianceMonitor +from tests.sdk.conversation.compliance.conftest import ( + make_action_event, + make_assistant_message_event, + make_observation_event, + make_orphan_observation_event, + make_user_message_event, +) + + +# ============================================================================= +# Interleaved Message Violations (a01, a03, a04, a07) +# ============================================================================= + + +def test_no_violation_message_when_no_pending_actions(): + """User message is fine when no tool calls are pending.""" + monitor = APIComplianceMonitor() + + user_msg = make_user_message_event() + violations = monitor.process_event(user_msg) + + assert len(violations) == 0 + + +def test_violation_user_message_with_pending_action(): + """User message while action is pending violates the property (a01/a03).""" + monitor = APIComplianceMonitor() + + # Add an action (now pending) + action = make_action_event(tool_call_id="call_123") + monitor.process_event(action) + + # User message before observation - violation + user_msg = make_user_message_event() + violations = monitor.process_event(user_msg) + + assert len(violations) == 1 + assert violations[0].property_name == "interleaved_message" + assert "pending" in violations[0].description.lower() + assert "call_123" in 
str(violations[0].context) + + +def test_violation_assistant_message_with_pending_action(): + """Assistant message while action is pending violates the property (a04).""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_456") + monitor.process_event(action) + + assistant_msg = make_assistant_message_event() + violations = monitor.process_event(assistant_msg) + + assert len(violations) == 1 + assert violations[0].property_name == "interleaved_message" + + +def test_violation_parallel_missing_result(): + """User message with partial parallel results violates property (a07).""" + monitor = APIComplianceMonitor() + + # Three parallel actions + action1 = make_action_event(tool_call_id="call_sf") + action2 = make_action_event(tool_call_id="call_tokyo") + action3 = make_action_event(tool_call_id="call_paris") + for action in [action1, action2, action3]: + monitor.process_event(action) + + # Two results arrive + monitor.process_event(make_observation_event(action1)) + monitor.process_event(make_observation_event(action2)) + # call_paris is still pending + + user_msg = make_user_message_event("What about Paris?") + violations = monitor.process_event(user_msg) + + assert len(violations) == 1 + assert "call_paris" in str(violations[0].context) + + +def test_no_violation_for_action_events(): + """ActionEvent itself doesn't trigger violation (always allowed).""" + monitor = APIComplianceMonitor() + + # Even with pending actions, a new action is fine + action1 = make_action_event(tool_call_id="call_existing") + monitor.process_event(action1) + + action2 = make_action_event(tool_call_id="call_new") + violations = monitor.process_event(action2) + + assert len(violations) == 0 + + +def test_no_violation_for_matching_observation(): + """ObservationEvent matching a pending action is allowed.""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_789") + monitor.process_event(action) + + obs = 
make_observation_event(action) + violations = monitor.process_event(obs) + + assert len(violations) == 0 + + +# ============================================================================= +# Unmatched Tool Result Violations (a02, a06, a08) +# ============================================================================= + + +def test_no_violation_when_action_exists(): + """Tool result is fine when its tool_call_id is pending.""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_valid") + monitor.process_event(action) + + obs = make_observation_event(action) + violations = monitor.process_event(obs) + + assert len(violations) == 0 + + +def test_violation_unknown_tool_call_id(): + """Tool result with unknown tool_call_id violates property (a02/a06/a08).""" + monitor = APIComplianceMonitor() + + # No actions have been seen + orphan_obs = make_orphan_observation_event(tool_call_id="call_unknown") + violations = monitor.process_event(orphan_obs) + + assert len(violations) == 1 + assert violations[0].property_name == "unmatched_tool_result" + assert "call_unknown" in violations[0].description + + +def test_violation_wrong_tool_call_id(): + """Tool result referencing wrong tool_call_id violates property (a06).""" + monitor = APIComplianceMonitor() + + # We have action with call_correct, but result references call_wrong + action = make_action_event(tool_call_id="call_correct") + monitor.process_event(action) + + orphan_obs = make_orphan_observation_event(tool_call_id="call_wrong") + violations = monitor.process_event(orphan_obs) + + assert len(violations) == 1 + assert violations[0].property_name == "unmatched_tool_result" + assert "call_wrong" in violations[0].description + + +# ============================================================================= +# Duplicate Tool Result Violations (a05) +# ============================================================================= + + +def test_no_violation_first_result(): + """First tool 
result for a tool_call_id is fine.""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_first") + monitor.process_event(action) + + obs = make_observation_event(action) + violations = monitor.process_event(obs) + + assert len(violations) == 0 + + +def test_violation_duplicate_result(): + """Second tool result for same tool_call_id violates property (a05).""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_duplicate") + monitor.process_event(action) + + # First result - fine + obs1 = make_observation_event(action) + violations = monitor.process_event(obs1) + assert len(violations) == 0 + + # Second result for same ID - violation + obs2 = make_orphan_observation_event(tool_call_id="call_duplicate") + violations = monitor.process_event(obs2) + + assert len(violations) == 1 + assert violations[0].property_name == "duplicate_tool_result" + assert "call_duplicate" in violations[0].description + + +# ============================================================================= +# State Update Tests +# ============================================================================= + + +def test_state_update_action_adds_pending(): + """Adding an action should update pending_tool_call_ids.""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_new") + monitor.process_event(action) + + assert "call_new" in monitor.state.pending_tool_call_ids + + +def test_state_update_observation_resolves_pending(): + """Adding an observation should move from pending to completed.""" + monitor = APIComplianceMonitor() + + action = make_action_event(tool_call_id="call_resolve") + monitor.process_event(action) + assert "call_resolve" in monitor.state.pending_tool_call_ids + + obs = make_observation_event(action) + monitor.process_event(obs) + + assert "call_resolve" not in monitor.state.pending_tool_call_ids + assert "call_resolve" in monitor.state.completed_tool_call_ids diff --git 
a/tests/sdk/conversation/compliance/test_state_integration.py b/tests/sdk/conversation/compliance/test_state_integration.py new file mode 100644 index 0000000000..d373d47794 --- /dev/null +++ b/tests/sdk/conversation/compliance/test_state_integration.py @@ -0,0 +1,154 @@ +"""Tests for ConversationState.add_event() integration with compliance monitoring.""" + +import uuid + +import pytest +from pydantic import SecretStr + +from openhands.sdk import LLM, Agent +from openhands.sdk.conversation.state import ConversationState +from openhands.sdk.workspace import LocalWorkspace +from tests.sdk.conversation.compliance.conftest import ( + make_action_event, + make_observation_event, + make_user_message_event, +) + + +@pytest.fixture +def temp_workspace(tmp_path): + """Create a temporary workspace for testing.""" + workspace_dir = tmp_path / "workspace" + workspace_dir.mkdir() + return LocalWorkspace(working_dir=workspace_dir) + + +@pytest.fixture +def mock_agent(): + """Create a mock agent for testing.""" + llm = LLM(model="mock-model", api_key=SecretStr("fake-key")) + return Agent(llm=llm) + + +@pytest.fixture +def conversation_state(temp_workspace, mock_agent): + """Create a ConversationState for testing.""" + return ConversationState.create( + id=uuid.uuid4(), + agent=mock_agent, + workspace=temp_workspace, + persistence_dir=None, # In-memory for testing + ) + + +def test_add_event_appends_to_event_log(conversation_state): + """add_event() should append events to the event log and return True.""" + initial_count = len(conversation_state.events) + + user_msg = make_user_message_event("Hello") + result = conversation_state.add_event(user_msg) + + assert result is True + assert len(conversation_state.events) == initial_count + 1 + assert conversation_state.events[-1].id == user_msg.id + + +def test_add_event_lazy_creates_monitor(conversation_state): + """compliance_monitor should be lazily initialized.""" + # Initially None + assert conversation_state._compliance_monitor 
is None + + # Access via property triggers creation + monitor = conversation_state.compliance_monitor + + assert monitor is not None + assert conversation_state._compliance_monitor is monitor + + +def test_add_event_checks_compliance(conversation_state, caplog): + """add_event() should check compliance and log violations.""" + import logging + + # Add an action + action = make_action_event(tool_call_id="call_1") + conversation_state.add_event(action) + + # User message while action pending should create violation + user_msg = make_user_message_event() + + with caplog.at_level(logging.WARNING): + conversation_state.add_event(user_msg) + + # Should have logged violation + assert "interleaved_message" in caplog.text + assert "API compliance violation detected" in caplog.text + + +def test_add_event_normal_flow_no_violations(conversation_state, caplog): + """Normal conversation flow should have no violations.""" + import logging + + # Normal flow: action -> observation -> user message + action = make_action_event(tool_call_id="call_1") + conversation_state.add_event(action) + + obs = make_observation_event(action) + conversation_state.add_event(obs) + + user_msg = make_user_message_event() + + with caplog.at_level(logging.WARNING): + conversation_state.add_event(user_msg) + + # No violations logged + assert "API compliance violation detected" not in caplog.text + + +def test_add_event_rejects_on_violation(conversation_state, caplog): + """Events with violations should be rejected (not added) and return False.""" + import logging + + action = make_action_event(tool_call_id="call_1") + result = conversation_state.add_event(action) + assert result is True + + initial_count = len(conversation_state.events) + + # User message while action pending - violation + user_msg = make_user_message_event() + + with caplog.at_level(logging.WARNING): + result = conversation_state.add_event(user_msg) + + # Should return False for rejected event + assert result is False + + # Event 
should NOT be in the log + assert len(conversation_state.events) == initial_count + assert conversation_state.events[-1].id == action.id + + # Violation should be logged + assert "API compliance violation detected" in caplog.text + + +def test_add_event_tracks_state_correctly(conversation_state): + """add_event() should correctly update compliance state.""" + action = make_action_event(tool_call_id="call_track") + conversation_state.add_event(action) + + monitor = conversation_state.compliance_monitor + assert "call_track" in monitor.state.pending_tool_call_ids + + obs = make_observation_event(action) + conversation_state.add_event(obs) + + assert "call_track" not in monitor.state.pending_tool_call_ids + assert "call_track" in monitor.state.completed_tool_call_ids + + +def test_compliance_monitor_property_returns_same_instance(conversation_state): + """compliance_monitor property should return the same instance each time.""" + monitor1 = conversation_state.compliance_monitor + monitor2 = conversation_state.compliance_monitor + + assert monitor1 is monitor2