diff --git a/src/strands_evals/simulation/__init__.py b/src/strands_evals/simulation/__init__.py
index 6a4be0f..3097b0d 100644
--- a/src/strands_evals/simulation/__init__.py
+++ b/src/strands_evals/simulation/__init__.py
@@ -1,6 +1,11 @@
 from .actor_simulator import ActorSimulator
+from .tool_simulator import ToolSimulator
 
 # Alias for backward compatibility
 UserSimulator = ActorSimulator
 
-__all__ = ["ActorSimulator", "UserSimulator"]
+__all__ = [
+    "ActorSimulator",
+    "UserSimulator",
+    "ToolSimulator",
+]
diff --git a/src/strands_evals/simulation/prompt_templates/__init__.py b/src/strands_evals/simulation/prompt_templates/__init__.py
index 0d0771d..6d626d3 100644
--- a/src/strands_evals/simulation/prompt_templates/__init__.py
+++ b/src/strands_evals/simulation/prompt_templates/__init__.py
@@ -1,11 +1,13 @@
-"""Prompt templates for actor simulation."""
+"""Prompt templates for simulation components."""
 
 from .actor_profile_extraction import ACTOR_PROFILE_PROMPT_TEMPLATE
 from .actor_system_prompt import DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE
 from .goal_completion import GOAL_COMPLETION_PROMPT
+from .tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT
 
 __all__ = [
     "ACTOR_PROFILE_PROMPT_TEMPLATE",
     "DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE",
     "GOAL_COMPLETION_PROMPT",
+    "FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT",
 ]
diff --git a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py
new file mode 100644
index 0000000..1feb3db
--- /dev/null
+++ b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py
@@ -0,0 +1,67 @@
+"""
+Prompt templates for function tool response generation in Strands Evals.
+
+This module contains prompt templates used to generate realistic tool responses during
+agent evaluation scenarios. These templates enable LLM-powered simulation of tool
+behavior when actual tools are not available or when consistent, controllable responses
+are needed for evaluation purposes.
+"""
+
+from textwrap import dedent
+
+FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT = dedent(
+    """
+You are simulating a function tool call for agent evaluation. Generate a realistic response based on the function name,
+parameters, and context.
+
+## Function Tool Information
+Tool Name: {tool_name}
+Parameters: {parameters}
+
+## Initial State Context
+{initial_state_description}
+
+## Current State & Previous Tool Responses (for context)
+{previous_responses}
+
+## Instructions
+1. Analyze the function name and parameters to understand what this tool should do
+2. Use the initial state description to understand the starting context and available data
+3. Generate a realistic response that would be returned by such a function
+4. Consider the previous responses to maintain consistency in the simulation
+5. Ensure responses are consistent with the established state and realistic for the domain
+6. Return valid JSON that represents the function's return value
+
+## Response Format
+Return a JSON object that represents what this function would return.
+Examples:
+
+For data retrieval functions:
+```json
+{{
+    "status": "success",
+    "data": {{
+        "result": "retrieved data",
+        "count": 42
+    }}
+}}
+```
+
+For action functions:
+```json
+{{
+    "status": "success",
+    "message": "Action completed successfully",
+    "transaction_id": "txn_12345"
+}}
+```
+
+For calculation functions:
+```json
+{{
+    "result": 156.78,
+    "unit": "meters"
+}}
+```
+
+Generate only valid JSON with no markdown code blocks or additional explanation."""
+)
diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py
new file mode 100644
index 0000000..f2914e2
--- /dev/null
+++ b/src/strands_evals/simulation/tool_simulator.py
@@ -0,0 +1,420 @@
+import json
+import logging
+import warnings
+from collections import defaultdict, deque
+from datetime import datetime
+from typing import Any, Callable
+
+from pydantic import BaseModel
+from strands import Agent
+from strands.agent import AgentResult
+from strands.models.model import Model
+from strands.tools.decorator import DecoratedFunctionTool
+
+from strands_evals.types.simulation.tool import RegisteredTool
+
+from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT
+
+logger = logging.getLogger(__name__)
+
+
+class StateRegistry:
+    """
+    State registry for managing shared state between tool simulators.
+
+    State is organized by state_key to isolate it between different tools or
+    shared-state groups. The registry automatically maintains a bounded cache of
+    tool calls per state key; the maximum number of stored calls is configurable
+    via the max_tool_call_cache_size parameter.
+
+    Attributes:
+        max_tool_call_cache_size: Maximum number of tool calls to store per state key.
+    """
+
+    def __init__(self, max_tool_call_cache_size: int = 20):
+        """
+        Initialize the state registry.
+
+        Creates an empty state dictionary to track tool calls and responses
+        across different simulation sessions. The tool call cache is automatically
+        bounded to prevent excessive memory usage.
+
+        Args:
+            max_tool_call_cache_size: Maximum number of tool calls to store per state key.
+                Older calls are automatically evicted when the limit is exceeded.
+                Default is 20.
+        """
+        self.max_tool_call_cache_size = max_tool_call_cache_size
+        self._states: defaultdict[str, dict[str, Any]] = defaultdict(
+            lambda: {"previous_calls": deque(maxlen=self.max_tool_call_cache_size)}
+        )
+
+    def initialize_state_via_description(self, initial_state_description: str, state_key: str) -> None:
+        """
+        Initialize state based on the provided description.
+
+        This method pre-seeds the state with an initial description that will be
+        included in all subsequent LLM prompts, allowing the simulator to have
+        context about pre-existing data or system state.
+
+        Args:
+            initial_state_description: Description of the initial state (e.g., existing
+                database records, system configuration, etc.).
+            state_key: Key for the state in the registry (typically tool_name or share_state_id).
+        """
+        if state_key not in self._states:
+            self._states[state_key] = {
+                "initial_state": initial_state_description,
+                "previous_calls": deque(maxlen=self.max_tool_call_cache_size),
+            }
+        else:
+            warnings.warn(
+                f"State with key '{state_key}' already initialized. Skipping re-initialization.", stacklevel=2
+            )
+
+    def get_state(self, state_key: str) -> dict[str, Any]:
+        """
+        Get state for a specific tool or shared state group.
+
+        Args:
+            state_key: Key for the state (tool_name or share_state_id).
+
+        Returns:
+            State dictionary containing previous_calls.
+        """
+        if state_key is None:
+            raise ValueError("state_key is required.")
+
+        # Access creates the default state automatically via defaultdict
+        state = self._states[state_key]
+
+        # Convert deque to list for JSON serialization compatibility
+        return {key: list(value) if isinstance(value, deque) else value for key, value in state.items()}
+
+    def cache_tool_call(
+        self,
+        tool_name: str,
+        state_key: str,
+        response_data: Any,
+        parameters: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Cache a tool call under the tool's state key.
+
+        Args:
+            tool_name: Name of the tool being called.
+            state_key: Key for the state (tool_name or share_state_id).
+            response_data: Response from the tool call.
+            parameters: Function parameters.
+
+        Returns:
+            Updated state dictionary.
+        """
+        # Access the actual state storage (not the converted copy)
+        state = self._states[state_key]
+        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+
+        call_record = {
+            "tool_name": tool_name,
+            "response": response_data,
+            "timestamp": timestamp,
+            "parameters": parameters,
+        }
+
+        # Append to the deque; FIFO eviction happens automatically when the cache is full
+        state["previous_calls"].append(call_record)
+        return self.get_state(state_key)
+
+    def clear_state(self, state_key: str) -> None:
+        """
+        Clear state for a specific tool or shared state group.
+
+        Args:
+            state_key: Key for the state to clear.
+        """
+        if state_key in self._states:
+            del self._states[state_key]
+
+
+class ToolSimulator:
+    """
+    Simulates tool behavior with a decorator-based registration system for agent evaluation.
+
+    ToolSimulator provides a decorator for tools and maintains a registry of all
+    registered tools. It can be configured to override tool behavior for simulation
+    purposes, enabling controlled testing scenarios.
+
+    IMPORTANT: This simulator expects functions to be decorated with Strands' @tool decorator first.
+
+    Example usage:
+        simulator = ToolSimulator()
+
+        @simulator.tool()
+        @tool
+        def my_tool(param: str) -> dict:
+            '''Tool description'''
+            pass
+
+    The simulator automatically maintains a bounded cache of tool calls for context.
+    The maximum number of tool calls stored per state key is configurable via the
+    max_tool_call_cache_size parameter (default: 20).
+
+    Attributes:
+        state_registry: Registry for maintaining tool state across calls.
+        function_tool_prompt: Custom prompt template for tool response generation.
+        model: Provider for running inference or model identifier for Bedrock.
+        max_tool_call_cache_size: Maximum number of tool calls to store per state key.
+    """
+
+    def __init__(
+        self,
+        state_registry: StateRegistry | None = None,
+        function_tool_prompt: str | None = None,
+        model: Model | str | None = None,
+        max_tool_call_cache_size: int = 20,
+    ):
+        """
+        Initialize a ToolSimulator instance.
+
+        Args:
+            state_registry: Registry for maintaining tool state. If not provided,
+                a new StateRegistry is created with max_tool_call_cache_size.
+            function_tool_prompt: Optional custom prompt for tool response generation.
+            model: Provider for running inference, or a string naming the Bedrock model id to use.
+            max_tool_call_cache_size: Maximum number of tool calls to store per state key.
+                Only used when creating a new StateRegistry (ignored if state_registry
+                is provided). Older calls are automatically evicted when the limit is
+                exceeded. Default is 20.
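+
+        Example:
+            A minimal sketch (the Bedrock model id below is illustrative, not a
+            library default):
+
+                simulator = ToolSimulator(
+                    model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+                    max_tool_call_cache_size=50,
+                )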
+        """
+        self.model = model
+        self.function_tool_prompt = function_tool_prompt or FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT
+        self.state_registry = state_registry or StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size)
+        self._registered_tools: dict[str, RegisteredTool] = {}
+        self._initialize_shared_states()
+
+    def _initialize_shared_states(self):
+        """Initialize shared states from registered tools' initial descriptions."""
+        for tool_name, registered_tool in self._registered_tools.items():
+            if registered_tool.initial_state_description:
+                state_key = registered_tool.share_state_id or registered_tool.name
+                self.state_registry.initialize_state_via_description(
+                    registered_tool.initial_state_description, state_key
+                )
+                logger.info(f"Initialized state for tool '{tool_name}' with key '{state_key}'")
+
+    def _create_tool_wrapper(self, registered_tool: RegisteredTool):
+        """
+        Create a Strands tool wrapper for simulation.
+
+        Since the registered function is already a DecoratedFunctionTool (from the @tool
+        decorator), we reuse its existing metadata and spec but replace the tool_func
+        with our simulation wrapper.
+        """
+        original_tool = registered_tool.function
+
+        if not isinstance(original_tool, DecoratedFunctionTool):
+            raise TypeError(
+                f"Expected DecoratedFunctionTool, got {type(original_tool).__name__}. "
+                f"Ensure your function is decorated with @tool first."
+            )
+
+        def wrapper(*args, **kwargs):
+            state_key = registered_tool.share_state_id or registered_tool.name
+
+            # Pass the parameters through as structured data; _call_tool JSON-encodes
+            # them exactly once, so the prompt sees plain JSON rather than an
+            # escaped, double-encoded string.
+            parameters = {"args": args, "kwargs": kwargs} if args else kwargs
+
+            input_data = {
+                "tool_name": registered_tool.name,
+                "parameters": parameters,
+            }
+
+            return self._call_tool(registered_tool, input_data, state_key)
+
+        if registered_tool.function:
+            wrapper.__name__ = registered_tool.function.__name__
+            wrapper.__doc__ = registered_tool.function.__doc__
+        else:
+            wrapper.__name__ = registered_tool.name
+            wrapper.__doc__ = f"Simulated {registered_tool.name} tool"
+
+        tool_spec = original_tool.tool_spec.copy()
+        tool_spec["name"] = registered_tool.name
+
+        simulated_tool = DecoratedFunctionTool(
+            tool_name=registered_tool.name,
+            tool_spec=tool_spec,
+            tool_func=wrapper,  # Use our simulation wrapper instead of the original function
+            metadata=original_tool._metadata,  # Reuse existing metadata
+        )
+
+        return simulated_tool
+
+    def _simulate_tool_call(self, prompt: str, structured_output_model=None) -> Any:
+        """Create a single-use simulation agent and generate a tool response."""
+        agent = Agent(
+            system_prompt=prompt,
+            tools=[],
+            model=self.model,
+            callback_handler=None,
+        )
+        return agent(prompt, structured_output_model=structured_output_model)
+
+    def _parse_simulated_response(self, result: AgentResult) -> dict[str, Any]:
+        """Parse the simulation agent's response, extracting JSON when possible and
+        falling back to wrapping the raw text in a "result" field."""
+        response_text = str(result) or "No response"
+        try:
+            response_data = json.loads(response_text)
+        except json.JSONDecodeError:
+            response_data = {"result": response_text}
+        return response_data
+
+    def _call_tool(self, registered_tool: RegisteredTool, input_data: dict[str, Any], state_key: str) -> dict[str, Any]:
+        """Simulate a tool invocation and return the response."""
+        parameters = input_data.get("parameters", {})
+
+        current_state = self.state_registry.get_state(state_key)
+
+        prompt = self.function_tool_prompt.format(
+            tool_name=registered_tool.name,
+            parameters=json.dumps(parameters, indent=2),
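+            # The seeded description and prior call log below give the LLM the
+            # context needed to keep simulated responses self-consistent.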
+            initial_state_description=current_state.get("initial_state", "No initial state provided."),
+            previous_responses=json.dumps(current_state.get("previous_calls", []), indent=2),
+        )
+
+        result = self._simulate_tool_call(prompt, structured_output_model=registered_tool.output_schema)
+
+        response_data = self._parse_simulated_response(result)
+
+        self.state_registry.cache_tool_call(registered_tool.name, state_key, response_data, parameters=parameters)
+
+        return response_data
+
+    def tool(
+        self,
+        name: str | None = None,
+        output_schema: type[BaseModel] | None = None,
+        share_state_id: str | None = None,
+        initial_state_description: str | None = None,
+    ) -> Callable:
+        """
+        Decorator for registering tools with flexible output schemas.
+
+        IMPORTANT: This decorator expects the function to already be decorated with @tool
+        from strands.tools.decorator. When output_schema is not provided, the input_model
+        from the DecoratedFunctionTool's metadata is automatically used as the output_schema.
+
+        Args:
+            name: Optional name for the tool. If None, uses DecoratedFunctionTool.tool_name.
+            output_schema: Optional Pydantic BaseModel for the output schema. If None, uses
+                the input_model from the DecoratedFunctionTool's metadata.
+            share_state_id: Optional shared state ID for sharing state between tools.
+            initial_state_description: Optional initial state description for the tool's context.
+
+        Returns:
+            Decorator function.
+        """
+
+        def decorator(func: Callable) -> Callable:
+            # Validate outside the try block so callers get a TypeError rather
+            # than having it re-wrapped as a RuntimeError.
+            if not isinstance(func, DecoratedFunctionTool):
+                raise TypeError(
+                    f"Expected DecoratedFunctionTool (from @tool decorator), got {type(func).__name__}. "
+                    f"Please ensure your function is decorated with @tool first, then @simulator.tool()."
+                )
+
+            try:
+                tool_name = name or func.tool_name
+
+                final_output_schema = output_schema
+                if (
+                    final_output_schema is None
+                    and hasattr(func, "_metadata")
+                    and hasattr(func._metadata, "input_model")
+                ):
+                    final_output_schema = func._metadata.input_model
+                    logger.info(
+                        f"Using input_model from DecoratedFunctionTool metadata as output_schema for tool '{tool_name}'"
+                    )
+
+                registered_tool = RegisteredTool(
+                    name=tool_name,
+                    function=func,
+                    output_schema=final_output_schema,
+                    initial_state_description=initial_state_description,
+                    share_state_id=share_state_id,
+                )
+                self._registered_tools[tool_name] = registered_tool
+
+                if initial_state_description:
+                    state_key = share_state_id or tool_name
+                    self.state_registry.initialize_state_via_description(initial_state_description, state_key)
+                    logger.info(f"Initialized state for tool '{tool_name}' with key '{state_key}'")
+
+                logger.info(f"Registered tool: {tool_name}")
+
+            except Exception as e:
+                raise RuntimeError(f"Error registering tool {name or getattr(func, '__name__', 'unknown')}: {e}") from e
+
+            return func
+
+        return decorator
+
+    def __getattr__(self, name: str) -> Any:
+        """
+        Allow direct access to registered tools as attributes.
+
+        Args:
+            name: Tool name.
+
+        Returns:
+            Tool callable wrapper.
+
+        Raises:
+            AttributeError: If the tool is not found.
+        """
+        registered_tool = self._registered_tools.get(name)
+        if registered_tool:
+            return self._create_tool_wrapper(registered_tool)
+
+        raise AttributeError(f"Tool '{name}' not found in registered tools")
+
+    def get_tool(self, tool_name: str) -> Callable | None:
+        """
+        Get a tool by name and create a simulation wrapper.
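+
+        The returned wrapper can be attached to a Strands Agent like any other
+        tool. A minimal sketch (the "get_weather" name is illustrative and
+        assumes such a tool was registered on this simulator):
+
+            weather = simulator.get_tool("get_weather")
+            agent = Agent(tools=[weather])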
+
+        Args:
+            tool_name: Name of the tool to retrieve.
+
+        Returns:
+            Tool callable wrapper if found, None otherwise.
+        """
+        registered_tool = self._registered_tools.get(tool_name)
+        if not registered_tool:
+            return None
+
+        return self._create_tool_wrapper(registered_tool)
+
+    def list_tools(self) -> list[str]:
+        """
+        List all registered tool names.
+
+        Returns:
+            List of tool names.
+        """
+        return list(self._registered_tools.keys())
+
+    def clear_tools(self):
+        """Clear all registered tools for this simulator instance."""
+        self._registered_tools.clear()
+        logger.info("Cleared tool registry for this simulator instance")
+
+    def get_state(self, state_key: str) -> dict[str, Any]:
+        """
+        Get state for a specific tool or shared state group.
+
+        Args:
+            state_key: Key for the state (tool_name or share_state_id).
+
+        Returns:
+            State dictionary containing previous_calls.
+        """
+        return self.state_registry.get_state(state_key)
diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py
new file mode 100644
index 0000000..e816aae
--- /dev/null
+++ b/src/strands_evals/types/simulation/tool.py
@@ -0,0 +1,30 @@
+from typing import Callable
+
+from pydantic import BaseModel, Field
+
+
+class RegisteredTool(BaseModel):
+    """
+    Represents a registered function tool in the simulator.
+
+    Attributes:
+        name: Name of the tool for identification and registration.
+        function: Function callable (excluded from serialization).
+        output_schema: Pydantic BaseModel for the output schema (excluded from serialization).
+        initial_state_description: Initial state description for the tool's context.
+        share_state_id: Optional shared state ID for sharing state between tools.
+    """
+
+    name: str = Field(..., description="Name of the tool")
+    function: Callable | None = Field(default=None, description="Function callable", exclude=True)
+    output_schema: type[BaseModel] | None = Field(
+        default=None, description="Pydantic BaseModel for output schema", exclude=True
+    )
+    initial_state_description: str | None = Field(
+        default=None, description="Initial state description for the tool's context"
+    )
+    share_state_id: str | None = Field(
+        default=None, description="Optional shared state ID for sharing state between tools"
+    )
+
+    model_config = {"arbitrary_types_allowed": True}
diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py
new file mode 100644
index 0000000..7583d10
--- /dev/null
+++ b/tests/strands_evals/simulation/test_tool_simulator.py
@@ -0,0 +1,419 @@
+"""Tests for the ToolSimulator class."""
+
+from typing import Any, Dict
+from unittest.mock import MagicMock
+
+import pytest
+from strands import tool
+
+from strands_evals.case import Case
+from strands_evals.simulation.tool_simulator import StateRegistry, ToolSimulator
+
+
+@pytest.fixture
+def sample_case():
+    """Fixture providing a sample test case."""
+    return Case(
+        input="I want to test tool simulation",
+        metadata={"task_description": "Complete tool simulation test"},
+    )
+
+
+@pytest.fixture
+def mock_model():
+    """Fixture providing a mock model for testing."""
+    mock = MagicMock()
+
+    # Mock the structured_output method
+    def mock_structured_output(output_type, messages, system_prompt=None):
+        # Simulate a streaming response
+        yield {"contentBlockDelta": {"text": '{"result": "mocked response"}'}}
+
+    mock.structured_output = mock_structured_output
+    return mock
+
+
+def test_tool_simulator_init():
+    """Test ToolSimulator initialization with explicit parameters."""
+    custom_registry = StateRegistry()
+
+    simulator = ToolSimulator(
+        state_registry=custom_registry,
+        model=None,
+    )
+
+    assert simulator.state_registry is custom_registry
+    assert simulator.model is None  # model is used for LLM inference
+    assert simulator.function_tool_prompt is not None  # check that the prompt template is loaded
+
+
+def test_tool_decorator_registration():
+    """Test tool decorator registration."""
+    simulator = ToolSimulator()
+
+    @simulator.tool()
+    @tool
+    def test_function(x: int, y: str) -> dict:
+        """A sample function for testing."""
+        return {"x": x, "y": y}
+
+    assert "test_function" in simulator._registered_tools
+    registered_tool = simulator._registered_tools["test_function"]
+    assert registered_tool.name == "test_function"
+    assert registered_tool.function == test_function
+
+
+def test_tool_decorator_with_name():
+    """Test the tool decorator with a custom name."""
+    simulator = ToolSimulator()
+
+    @simulator.tool(name="custom_name")
+    @tool
+    def test_function(x: int) -> dict:
+        """A sample function for testing."""
+        return {"x": x}
+
+    assert "custom_name" in simulator._registered_tools
+    registered_tool = simulator._registered_tools["custom_name"]
+    assert registered_tool.name == "custom_name"
+    assert registered_tool.function == test_function
+
+
+def test_tool_simulation(mock_model):
+    """Test tool simulation."""
+    simulator = ToolSimulator(model=mock_model)
+
+    # Register a tool
+    @simulator.tool()
+    @tool
+    def test_func(message: str) -> dict:
+        """Test function that should be simulated."""
+        pass
+
+    # Mock the Agent constructor and its result to avoid real LLM calls
+    mock_agent_instance = MagicMock()
+    mock_result = MagicMock()
+    # Mock __str__ to return the expected JSON string
+    mock_result.__str__ = MagicMock(return_value='{"result": "simulated response"}')
+    mock_agent_instance.return_value = mock_result
+
+    with pytest.MonkeyPatch().context() as m:
+        # Mock the Agent class constructor
+        m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance)
+
+        # Execute the simulated function
+        result = simulator.test_func("Hello, world!")
+
+    assert result == {"result": "simulated response"}
+    assert mock_agent_instance.called
+
+
+def test_list_tools():
+    """Test listing registered tools."""
+    simulator = ToolSimulator()
+
+    @simulator.tool()
+    @tool
+    def func1():
+        pass
+
+    @simulator.tool()
+    @tool
+    def func2():
+        pass
+
+    tools = simulator.list_tools()
+
+    assert set(tools) == {"func1", "func2"}
+
+
+def test_shared_state_registry(mock_model):
+    """Test that tools can share the same state registry."""
+    shared_state_id = "shared_banking_state"
+    initial_state = "Initial banking system state with account balances"
+
+    simulator = ToolSimulator(model=mock_model)
+
+    # Register tools that share the same state
+    @simulator.tool(share_state_id=shared_state_id, initial_state_description=initial_state)
+    @tool
+    def check_balance(account_id: str):
+        """Check account balance."""
+        pass
+
+    @simulator.tool(share_state_id=shared_state_id, initial_state_description=initial_state)
+    @tool
+    def transfer_funds(from_account: str, to_account: str):
+        """Transfer funds between accounts."""
+        pass
+
+    @simulator.tool(share_state_id=shared_state_id, initial_state_description=initial_state)
+    @tool
+    def get_transactions(account_id: str):
+        """Get transaction history."""
+        pass
+
+    # Mock the Agent constructor to avoid real LLM calls
+    mock_agent_instances = []
+    expected_responses = [
+        {"balance": 1000, "currency": "USD"},  # check_balance response
+        {"status": "success", "message": "Transfer completed"},  # transfer_funds response
+        {"transactions": []},  # get_transactions response
+    ]
+
+    def create_mock_agent(**kwargs):
+        mock_agent = MagicMock()
+        mock_result = MagicMock()
+
+        if len(mock_agent_instances) < len(expected_responses):
+            response = expected_responses[len(mock_agent_instances)]
+            import json
+
+            # Mock __str__ to return the JSON string
+            mock_result.__str__ = MagicMock(return_value=json.dumps(response))
+
+        mock_agent.return_value = mock_result
+        mock_agent_instances.append(mock_agent)
+        return mock_agent
+
+    with pytest.MonkeyPatch().context() as m:
+        # Mock the Agent class constructor
+        m.setattr("strands_evals.simulation.tool_simulator.Agent", create_mock_agent)
+
+        # Execute each tool in order
+        balance_result = simulator.check_balance("12345")
+        transfer_result = simulator.transfer_funds("12345", "67890")
+        transactions_result = simulator.get_transactions("12345")
+
+    # Verify results
+    assert balance_result == {"balance": 1000, "currency": "USD"}
+    assert transfer_result == {"status": "success", "message": "Transfer completed"}
+    assert transactions_result == {"transactions": []}
+
+    # Verify all agents were called
+    assert len(mock_agent_instances) == 3
+    for agent in mock_agent_instances:
+        assert agent.called
+
+    # Verify all tools accessed the same shared state
+    shared_state = simulator.state_registry.get_state(shared_state_id)
+    assert "initial_state" in shared_state
+    assert shared_state["initial_state"] == initial_state
+    assert "previous_calls" in shared_state
+    assert len(shared_state["previous_calls"]) == 3
+
+    # Check that all three tool calls are recorded in the shared state
+    tool_names = [call["tool_name"] for call in shared_state["previous_calls"]]
+    assert "check_balance" in tool_names
+    assert "transfer_funds" in tool_names
+    assert "get_transactions" in tool_names
+
+    # Verify each tool call recorded its parameters correctly
+    for call in shared_state["previous_calls"]:
+        assert "parameters" in call
+
+
+def test_cache_tool_call_function():
+    """Test recording a function call in the state registry."""
+    registry = StateRegistry()
+
+    registry.cache_tool_call(
+        tool_name="test_tool",
+        state_key="test_state",
+        response_data={"result": "success"},
+        parameters={"param": "value"},
+    )
+
+    state = registry.get_state("test_state")
+    assert "previous_calls" in state
+    assert len(state["previous_calls"]) == 1
+    call = state["previous_calls"][0]
+    assert call["tool_name"] == "test_tool"
+    assert call["parameters"] == {"param": "value"}
+    assert call["response"] == {"result": "success"}
+
+
+def test_tool_not_found_raises_error():
+    """Test that accessing non-existent tools raises AttributeError."""
+    simulator = ToolSimulator()
+
+    # Accessing a non-existent tool via __getattr__ raises AttributeError
+    with pytest.raises(AttributeError) as excinfo:
+        _ = simulator.nonexistent_tool
+
+    assert "not found in registered tools" in str(excinfo.value)
+
+
+def test_clear_tools():
+    """Test clearing the tool registry for a specific simulator instance."""
+    simulator = ToolSimulator()
+
+    @simulator.tool()
+    @tool
+    def test_func():
+        pass
+
+    assert len(simulator._registered_tools) == 1
+
+    simulator.clear_tools()
+
+    assert len(simulator._registered_tools) == 0
+
+
+def test_attaching_tool_simulator_to_strands_agent():
+    """Test attaching a tool simulator to a Strands agent."""
+    simulator = ToolSimulator()
+
+    # Register a simulated tool
+    @simulator.tool()
+    @tool
+    def test_tool(input_value: str) -> Dict[str, Any]:
+        """Test tool for agent attachment.
+
+        Args:
+            input_value: Input parameter for processing
+        """
+        pass
+
+    # Get the tool wrapper
+    tool_wrapper = simulator.get_tool("test_tool")
+
+    # Create a Strands Agent with the simulated tool
+    from strands import Agent
+
+    agent = Agent(tools=[tool_wrapper])
+
+    # Verify the agent has access to the tool
+    assert "test_tool" in agent.tool_names
+    assert hasattr(agent.tool, "test_tool")
+
+
+def test_get_state_method():
+    """Test the get_state method for direct state access."""
+    simulator = ToolSimulator()
+
+    @simulator.tool(share_state_id="test_state", initial_state_description="Test initial state")
+    @tool
+    def test_tool():
+        pass
+
+    # Test the get_state method
+    state = simulator.get_state("test_state")
+    assert "initial_state" in state
+    assert state["initial_state"] == "Test initial state"
+    assert "previous_calls" in state
+
+
+def test_output_schema_parameter():
+    """Test that the output_schema parameter is accepted and stored."""
+    from pydantic import BaseModel
+
+    class CustomOutput(BaseModel):
+        result: str
+        count: int
+
+    simulator = ToolSimulator()
+
+    @simulator.tool(output_schema=CustomOutput)
+    @tool
+    def test_tool_with_schema():
+        pass
+
+    registered_tool = simulator._registered_tools["test_tool_with_schema"]
+    assert registered_tool.output_schema == CustomOutput
+
+
+def test_automatic_input_model_as_output_schema():
+    """Test that input_model is automatically used as output_schema when no explicit schema is provided."""
+    from pydantic import BaseModel
+
+    simulator = ToolSimulator()
+
+    # Create a tool with typed parameters - this should generate an input_model
+    @simulator.tool()
+    @tool
+    def test_tool_with_parameters(name: str, age: int, active: bool = True) -> dict:
+        """A test tool with typed parameters.
+
+        Args:
+            name: The person's name
+            age: The person's age
+            active: Whether the person is active
+        """
+        pass
+
+    # Check that the tool was registered
+    assert "test_tool_with_parameters" in simulator._registered_tools
+    registered_tool = simulator._registered_tools["test_tool_with_parameters"]
+
+    # Verify that output_schema was automatically set from input_model
+    assert registered_tool.output_schema is not None, "output_schema should be automatically set from input_model"
+
+    # Verify it's a Pydantic BaseModel class
+    assert issubclass(registered_tool.output_schema, BaseModel), "output_schema should be a Pydantic BaseModel"
+
+    # Check that the schema has the expected fields
+    schema = registered_tool.output_schema.model_json_schema()
+    properties = schema.get("properties", {})
+    assert "name" in properties, "Schema should have 'name' field"
+    assert "age" in properties, "Schema should have 'age' field"
+    assert "active" in properties, "Schema should have 'active' field"
+
+
+def test_explicit_output_schema_override():
+    """Test that an explicit output_schema takes precedence over the automatic input_model."""
+    from pydantic import BaseModel
+
+    simulator = ToolSimulator()
+
+    class CustomOutput(BaseModel):
+        result: str
+        count: int
+
+    # Create a tool with an explicit output_schema
+    @simulator.tool(output_schema=CustomOutput)
+    @tool
+    def test_tool_with_explicit_schema(name: str, age: int) -> dict:
+        """A test tool with an explicit output_schema.
+
+        Args:
+            name: The person's name
+            age: The person's age
+        """
+        pass
+
+    # Check that the tool was registered
+    assert "test_tool_with_explicit_schema" in simulator._registered_tools
+    registered_tool = simulator._registered_tools["test_tool_with_explicit_schema"]
+
+    # Verify that the explicit output_schema is used, not the input_model
+    assert registered_tool.output_schema is CustomOutput, "Explicit output_schema should be used"
+
+
+def test_no_parameters_tool_input_model():
+    """Test that a tool with no parameters uses an empty input_model as output_schema."""
+    from pydantic import BaseModel
+
+    simulator = ToolSimulator()
+
+    # Create a tool with no parameters
+    @simulator.tool()
+    @tool
+    def test_tool_no_params() -> dict:
+        """A test tool with no parameters."""
+        pass
+
+    # Check that the tool was registered
+    assert "test_tool_no_params" in simulator._registered_tools
+    registered_tool = simulator._registered_tools["test_tool_no_params"]
+
+    # The output_schema should be the input_model (an empty model for no-param tools)
+    if registered_tool.output_schema is not None:
+        # If there is an output_schema, it should be a BaseModel
+        assert issubclass(registered_tool.output_schema, BaseModel), "output_schema should be a Pydantic BaseModel"
+
+        # Check that it has no required properties (since there are no parameters)
+        schema = registered_tool.output_schema.model_json_schema()
+        properties = schema.get("properties", {})
+        # No parameters should mean no properties or empty properties
+        assert len(properties) == 0, "Tool with no parameters should have empty schema properties"
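+
+
+# Illustrative addition, not part of the original suite: a minimal sketch
+# exercising the bounded tool-call cache documented on StateRegistry.
+def test_tool_call_cache_eviction():
+    """Oldest calls are evicted FIFO once max_tool_call_cache_size is exceeded."""
+    registry = StateRegistry(max_tool_call_cache_size=2)
+
+    for i in range(3):
+        registry.cache_tool_call(
+            tool_name=f"tool_{i}",
+            state_key="evict_state",
+            response_data={"result": i},
+            parameters={"call": i},
+        )
+
+    state = registry.get_state("evict_state")
+    # deque(maxlen=2) keeps only the two most recent calls.
+    assert len(state["previous_calls"]) == 2
+    assert [call["tool_name"] for call in state["previous_calls"]] == ["tool_1", "tool_2"]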