From 86159c2d097eac92a79e4f33b14c4c4d4a498226 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Thu, 22 Jan 2026 00:31:28 +0000 Subject: [PATCH 01/15] init tool simulator pr --- src/strands_evals/simulation/__init__.py | 7 +- .../simulation/prompt_templates/__init__.py | 14 +- .../tool_override_generation.py | 131 +++ .../tool_response_generation.py | 219 +++++ .../simulation/tool_simulator.py | 765 ++++++++++++++++++ src/strands_evals/types/simulation/tool.py | 183 +++++ .../simulation/test_tool_simulator.py | 629 ++++++++++++++ 7 files changed, 1945 insertions(+), 3 deletions(-) create mode 100644 src/strands_evals/simulation/prompt_templates/tool_override_generation.py create mode 100644 src/strands_evals/simulation/prompt_templates/tool_response_generation.py create mode 100644 src/strands_evals/simulation/tool_simulator.py create mode 100644 src/strands_evals/types/simulation/tool.py create mode 100644 tests/strands_evals/simulation/test_tool_simulator.py diff --git a/src/strands_evals/simulation/__init__.py b/src/strands_evals/simulation/__init__.py index 6a4be0f..98c7593 100644 --- a/src/strands_evals/simulation/__init__.py +++ b/src/strands_evals/simulation/__init__.py @@ -1,6 +1,11 @@ from .actor_simulator import ActorSimulator +from .tool_simulator import ToolSimulator # Alias for backward compatibility UserSimulator = ActorSimulator -__all__ = ["ActorSimulator", "UserSimulator"] +__all__ = [ + "ActorSimulator", + "UserSimulator", + "ToolSimulator", +] diff --git a/src/strands_evals/simulation/prompt_templates/__init__.py b/src/strands_evals/simulation/prompt_templates/__init__.py index 0d0771d..ab39053 100644 --- a/src/strands_evals/simulation/prompt_templates/__init__.py +++ b/src/strands_evals/simulation/prompt_templates/__init__.py @@ -1,11 +1,21 @@ -"""Prompt templates for actor simulation.""" +"""Prompt templates for simulation components.""" from .actor_profile_extraction import ACTOR_PROFILE_PROMPT_TEMPLATE from .actor_system_prompt import DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE from .goal_completion import GOAL_COMPLETION_PROMPT +from .tool_override_generation import TOOL_OVERRIDE_GENERATION_PROMPT +from .tool_response_generation import ( + API_TOOL_RESPONSE_GENERATION_PROMPT, + FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT, + MCP_TOOL_RESPONSE_GENERATION_PROMPT, +) __all__ = [ "ACTOR_PROFILE_PROMPT_TEMPLATE", - "DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE", + "DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE", "GOAL_COMPLETION_PROMPT", + "TOOL_OVERRIDE_GENERATION_PROMPT", + "API_TOOL_RESPONSE_GENERATION_PROMPT", + "FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT", + "MCP_TOOL_RESPONSE_GENERATION_PROMPT", ] diff --git a/src/strands_evals/simulation/prompt_templates/tool_override_generation.py b/src/strands_evals/simulation/prompt_templates/tool_override_generation.py new file mode 100644 index 0000000..5583e64 --- /dev/null +++ b/src/strands_evals/simulation/prompt_templates/tool_override_generation.py @@ -0,0 +1,131 @@ +""" +Prompt template for tool override generation in Strands Evals. + +This module contains the prompt template used to analyze test scenarios and determine +optimal tool simulation strategies for agent evaluation workflows. It applies scientific +tool categorization to ensure consistent and appropriate simulation decisions across +different tool types and usage contexts. +""" + +from textwrap import dedent + +TOOL_OVERRIDE_GENERATION_PROMPT = dedent( + """You are an expert at analyzing test scenarios and determining optimal tool simulation strategies for agent evaluation workflows. 
+ +Your primary objective is to apply SCIENTIFIC TOOL CATEGORIZATION and ensure CONSISTENCY in tool simulation decisions. + +## Scenario +{scenario} + +## Available Tools +{tools_json} + +## Scientific Tool Categorization Framework + +Based on comprehensive analysis of MCP servers and tool libraries, tools fall into four primary categories: + +### CATEGORY 1: COMPUTE TOOLS (Default: REAL) +**Characteristics**: Pure computational functions, no side effects, no state changes +**Examples**: Mathematical operations, FFT, string manipulation, date formatting, validation +**Simulation Strategy**: Connect directly via MCP - these are safe and deterministic +**Rationale**: No external dependencies, consistent results, low security risk + +### CATEGORY 2: DATABASE/PERSISTENT STATE TOOLS (Default: SIMULATE) +**Characteristics**: CRUD operations, booking systems, inventory management, resource allocation +**Examples**: create_booking(), update_inventory(), delete_user(), query_orders() +**Simulation Strategy**: MUST use synthetic/dummy databases with relevant test data +**Rationale**: Cannot connect to production DBs; subsequent operations depend on consistent state +**Critical Rule**: If ANY tool modifies a resource, ALL tools operating on that resource MUST be simulated + +### CATEGORY 3: ML MODEL TOOLS (Default: CONTEXT-DEPENDENT) +**Characteristics**: Calls to other ML models, AI services, content generation +**Examples**: image_generator(), text_classifier(), sentiment_analyzer(), llm_call() +**Simulation Strategy**: Evaluate based on scenario requirements and cost considerations +**Rationale**: May need human supervision; consider latency and cost implications + +### CATEGORY 4: SPECIALIZED TOOLS (Default: SIMULATE) +**Characteristics**: External integrations, infrastructure operations, specialized hardware +**Examples**: 3D renderers, CAD functions, game engines, deployment tools, notification services +**Simulation Strategy**: Require specialized support; simulate unless explicitly needed +**Rationale**: Complex dependencies, potential side effects, specialized environments + +## Consistency Rules (CRITICAL) + +**RULE 1 - Resource State Consistency**: +If tool A modifies resource R, then ALL tools B, C, D that operate on resource R MUST have the same simulation decision. +Example: cancel_flight(booking_id) simulated → get_flight_status(booking_id) must also be simulated + +**RULE 2 - Workflow Integrity**: +Tools in the same logical workflow should maintain consistent simulation decisions to preserve end-to-end test validity. + +**RULE 3 - External Service Consistency**: +If one tool calls external service S, related tools calling service S should have consistent simulation decisions. + +## Instructions + +For EACH tool, analyze: + +1. **Category Classification**: Determine which of the 4 categories (1-4) this tool belongs to +2. **Resource Dependencies**: Identify what resources/services this tool operates on +3. **Consistency Impact**: List other tools that must have matching simulation decisions +4. 
**Simulation Decision**: Apply category defaults, then adjust for consistency rules + +## Failure Conditions Specification + +Configure failure simulation with these parameters: + +{{ + "enabled": true, // Whether failure simulation is enabled (boolean) + "error_rate": 0.15, // Error rate between 0.0 and 1.0 (float) + "error_type": "timeout", // Error type (see allowed values below) + "error_message": "Custom error message" // Optional custom error message (string) +}} + +### Examples of Error Types: +- `"timeout"` - Request timeout errors +- `"execution_error"` - General execution failures +- `"network_error"` - Network connectivity issues +- `"authentication_error"` - Authentication failures +- `"authorization_error"` - Permission denied errors +- `"rate_limit_error"` - Rate limiting errors +- `"internal_error"` - Internal system errors + +### Failure Rate Guidelines: +- **0.0** - No failures (disabled) +- **0.01-0.05** - Low failure rate (1-5%) - production-like +- **0.1-0.2** - Medium failure rate (10-20%) - stress testing +- **0.3+** - High failure rate (30%+) - chaos engineering + +## Response Format + +{{ + "scenario_summary": "Brief summary of the scenario and testing objectives", + "resource_groups": {{ + "group_name": {{ + "description": "What this resource group represents", + "tools": ["tool1", "tool2", "tool3"], + "simulation_decision": true, + "rationale": "Why all tools in this group have the same decision" + }} + }}, + "tool_overrides": [ + {{ + "tool_name": "name_of_tool", + "category": 1, + "category_rationale": "Category 1: Compute tool - pure mathematical operation with no side effects", + "resource_dependencies": ["resource_name"], + "consistency_requirements": ["related_tool1", "related_tool2"], + "should_simulate": false, + "failure_conditions": {{ + "enabled": false, + "error_rate": 0.0, + "error_type": "execution_error", + "error_message": "Error message if failure occurs" + }}, + "rationale": "Final decision rationale considering category and consistency rules" + }} + ] +}} + +Generate only valid JSON with no markdown code blocks or additional explanation.""" +) diff --git a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py new file mode 100644 index 0000000..7420edb --- /dev/null +++ b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py @@ -0,0 +1,219 @@ +""" +Prompt templates for tool response generation in Strands Evals. + +This module contains prompt templates used to generate realistic tool responses during +agent evaluation scenarios. These templates enable LLM-powered simulation of tool +behavior when actual tools are not available or when consistent, controllable responses +are needed for evaluation purposes. + +The module provides specialized templates for different tool types: +1. Function tools - Traditional Python function calls with parameters and return values +2. MCP tools - Model Context Protocol tools with structured input/output formats +3. 
API tools - REST API endpoints with HTTP request/response patterns + +Each template guides an LLM to: +- Analyze the tool name, parameters, and context to understand expected behavior +- Generate realistic responses that maintain consistency across the conversation +- Follow appropriate response formats for each tool type +- Consider previous tool responses to maintain state consistency in simulations + +Key Components: +- FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT: Template for Python function tool simulation +- MCP_TOOL_RESPONSE_GENERATION_PROMPT: Template for MCP tool response generation +- API_TOOL_RESPONSE_GENERATION_PROMPT: Template for REST API endpoint simulation + +These templates ensure that simulated tool responses are contextually appropriate, +format-compliant, and maintain the illusion of real tool execution for effective +agent evaluation and testing scenarios. +""" + +from textwrap import dedent + +FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT = dedent( + """ +You are simulating a function tool call for agent evaluation. Generate a realistic response based on the function name, parameters, and context. + +## Function Tool Information +Tool Name: {tool_name} +Parameters: {parameters} + +## Previous Tool Responses (for context) +{previous_responses} + +## Instructions +1. Analyze the function name and parameters to understand what this tool should do +2. Generate a realistic response that would be returned by such a function +3. Consider the previous responses to maintain consistency in the simulation +4. Return valid JSON that represents the function's return value + +## Response Format +Return a JSON object that represents what this function would return. Examples: + +For data retrieval functions: +```json +{{ + "status": "success", + "data": {{ + "result": "retrieved data", + "count": 42 + }} +}} +``` + +For action functions: +```json +{{ + "status": "success", + "message": "Action completed successfully", + "transaction_id": "txn_12345" +}} +``` + +For calculation functions: +```json +{{ + "result": 156.78, + "unit": "meters" +}} +``` + +Generate only valid JSON with no markdown code blocks or additional explanation.""" +) + +MCP_TOOL_RESPONSE_GENERATION_PROMPT = dedent( + """You are simulating an MCP (Model Context Protocol) tool call for agent evaluation. Generate a realistic response based on the tool name, input payload, and context. + +## MCP Tool Information +Tool Name: {tool_name} +Input Payload: {mcp_payload} + +## Previous Tool Responses (for context) +{previous_responses} + +## Instructions +1. Analyze the tool name and input payload to understand what this MCP tool should do +2. Generate a realistic response following MCP response format +3. Consider the previous responses to maintain consistency in the simulation +4. Return valid JSON in MCP response format + +## MCP Response Format +MCP tools return responses in this format: + +For successful operations: +```json +{{ + "content": [ + {{ + "type": "text", + "text": "Operation completed successfully. Retrieved 5 items." 
+ }} + ] +}} +``` + +For data operations: +```json +{{ + "content": [ + {{ + "type": "text", + "text": "Found user profile for john.doe" + }}, + {{ + "type": "resource", + "resource": {{ + "uri": "user://john.doe", + "name": "John Doe Profile", + "mimeType": "application/json" + }} + }} + ] +}} +``` + +For errors: +```json +{{ + "isError": true, + "content": [ + {{ + "type": "text", + "text": "Error: User not found" + }} + ] +}} +``` + +Generate only valid JSON with no markdown code blocks or additional explanation.""" +) + +API_TOOL_RESPONSE_GENERATION_PROMPT = dedent( + """You are simulating an API tool call for agent evaluation. Generate a realistic HTTP response based on the API endpoint, method, payload, and context. + +## API Tool Information +Tool Name: {tool_name} +Path: {path} +Method: {method} +Request Payload: {api_payload} + +## Previous Tool Responses (for context) +{previous_responses} + +## Instructions +1. Analyze the API path, method, and payload to understand what this endpoint should do +2. Generate a realistic HTTP response with appropriate status code and data +3. Consider the previous responses to maintain consistency in the simulation +4. Return valid JSON in HTTP response format + +## HTTP Response Format +API responses should include status codes and appropriate data: + +For successful GET requests: +```json +{{ + "status": 200, + "data": {{ + "id": 123, + "name": "Example Item", + "created_at": "2024-01-15T10:30:00Z" + }} +}} +``` + +For successful POST requests: +```json +{{ + "status": 201, + "data": {{ + "id": 456, + "message": "Resource created successfully" + }} +}} +``` + +For errors: +```json +{{ + "status": 404, + "error": {{ + "type": "not_found", + "title": "Not Found", + "detail": "The requested resource could not be found" + }} +}} +``` + +For validation errors: +```json +{{ + "status": 400, + "error": {{ + "type": "validation_error", + "title": "Bad Request", + "detail": "Missing required field: email" + }} +}} +``` + +Generate only valid JSON with no markdown code blocks or additional explanation.""" +) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py new file mode 100644 index 0000000..27b0963 --- /dev/null +++ b/src/strands_evals/simulation/tool_simulator.py @@ -0,0 +1,765 @@ +import inspect +import json +import logging +import random +from typing import Any, Callable, Dict, List, Optional + +from strands import Agent +from strands.models.bedrock import BedrockModel +from strands.models.model import Model + +from strands_evals.case import Case +from strands_evals.types.simulation.tool import ( + FailureConditions, + RegisteredTool, + StateRegistry, + ToolOverrideConfig, + ToolType, +) + +logger = logging.getLogger(__name__) + + +class ToolSimulator: + """ + Simulates tool behavior with decorator-based registration system for agent evaluation. + + ToolSimulator provides decorator functions for different tool types and maintains + a registry of all registered tools. It can be configured to override tool + behavior for simulation purposes, enabling controlled testing scenarios. + + Attributes: + tool_overrides: Dictionary mapping tool names to override configurations. + system_prompt_template: Template string for system prompts. + model: Provider for running inference or model identifier for Bedrock. + _active_simulators: Dictionary of active tool simulators. + _registered_tools: Class-level registry for all registered tools. 
+ _state_registry: Registry for maintaining tool state across calls. + """ + + # Class-level registry for all registered tools + _registered_tools: Dict[str, RegisteredTool] = {} + _state_registry: Optional[StateRegistry] = None + + def __init__( + self, + tool_overrides: Optional[Dict[str, ToolOverrideConfig]] = None, + state_registry: Optional[StateRegistry] = None, + system_prompt_template: Optional[str] = None, + model: Model | str | None = None, + ): + """ + Initialize a ToolSimulator instance. + + Args: + tool_overrides: Dictionary mapping tool names to ToolOverrideConfig instances + state_registry: Registry for maintaining tool state + system_prompt_template: Template for system prompts + model: Provider for running inference or a string representing the model-id for Bedrock to use + """ + self.tool_overrides = tool_overrides or {} + self.system_prompt_template = system_prompt_template + + # Initialize model following Agent pattern + self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model + + # Set up state registry + if state_registry: + self._state_registry = state_registry + elif self._state_registry is None: + self._state_registry = StateRegistry() + + # Initialize tool simulators for registered tools + self._active_simulators: Dict[str, Any] = {} + self._initialize_simulators() + + def _function_has_implementation(self, func: Callable) -> bool: + """Check if a function has actual implementation or is just an empty stub.""" + try: + import dis + # Get function bytecode + bytecode = list(dis.get_instructions(func)) + + # Check if function only contains simple return patterns + if len(bytecode) <= 3: + load_const_none_count = sum( + 1 for instr in bytecode if instr.opname == "LOAD_CONST" and instr.argval is None + ) + return_count = sum(1 for instr in bytecode if instr.opname == "RETURN_VALUE") + + if load_const_none_count >= 1 and return_count == 1 and len(bytecode) <= 3: + return False + + return True + except Exception: + # If we can't analyze bytecode, assume it's implemented + return True + + def _initialize_simulators(self): + """Initialize simulators for all registered tools with overrides.""" + for tool_name, registered_tool in self._registered_tools.items(): + # Check if we have override config for this tool + override_config = self.tool_overrides.get(tool_name) + + if override_config: + # Create simulator with override configuration + simulator = self._create_tool_simulator(registered_tool, override_config) + self._active_simulators[tool_name] = simulator + else: + # Check if function is an empty stub or has real implementation + if registered_tool.function and self._function_has_implementation(registered_tool.function): + # Use real function for implemented functions + self._active_simulators[tool_name] = registered_tool.function + else: + # Create default simulator for empty stubs + default_config = ToolOverrideConfig() + simulator = self._create_tool_simulator(registered_tool, default_config) + self._active_simulators[tool_name] = simulator + + def _create_tool_simulator(self, registered_tool: RegisteredTool, config: ToolOverrideConfig) -> Any: + """Create a tool simulator instance based on the registered tool type.""" + # Determine state key from tool name or simulator kwargs + state_key = ( + registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) + if registered_tool.simulator_kwargs + else registered_tool.name + ) + + # Create wrapper function that handles the simulation + if 
registered_tool.tool_type == ToolType.FUNCTION: + return self._create_function_simulator_wrapper(registered_tool, registered_tool.tool_type, state_key) + elif registered_tool.tool_type == ToolType.MCP: + return self._create_mcp_simulator_wrapper(registered_tool, registered_tool.tool_type, state_key) + elif registered_tool.tool_type == ToolType.API: + return self._create_api_simulator_wrapper(registered_tool, registered_tool.tool_type, state_key) + else: + raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") + + def _create_function_simulator_wrapper(self, registered_tool: RegisteredTool, tool_type: ToolType, state_key: str) -> Callable: + """Create a wrapper function for function tool simulation.""" + def wrapper(*args, **kwargs): + try: + # Build parameters as expected by simulation + parameters_string = ( + json.dumps({"args": args, "kwargs": kwargs}, indent=2) + if args + else json.dumps(kwargs, indent=2) + ) + + # Get tool behavior configuration from tool overrides + tool_override_config = {} + if registered_tool.name in self.tool_overrides: + override_config = self.tool_overrides[registered_tool.name] + if override_config.failure_conditions: + tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() + else: + tool_override_config["failure_conditions"] = {"enabled": False} + + input_data = { + "tool_name": registered_tool.name, + "parameters": parameters_string, + "tool_override": tool_override_config, + } + + return self._simulate_tool_call(tool_type, state_key, input_data) + except Exception as e: + logger.error(f"Error in function simulation for {registered_tool.name}: {e}") + raise + + # Copy function metadata + if registered_tool.function: + wrapper.__name__ = registered_tool.function.__name__ + try: + wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore + except (ValueError, TypeError): + pass + wrapper.__doc__ = registered_tool.function.__doc__ + else: + wrapper.__name__ = registered_tool.name + + return wrapper + + def _create_mcp_simulator_wrapper(self, registered_tool: RegisteredTool, tool_type: ToolType, state_key: str) -> Callable: + """Create a wrapper function for MCP tool simulation.""" + def wrapper(**params): + try: + # Get tool behavior configuration from tool overrides + tool_override_config = {} + if registered_tool.name in self.tool_overrides: + override_config = self.tool_overrides[registered_tool.name] + if override_config.failure_conditions: + tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() + else: + tool_override_config["failure_conditions"] = {"enabled": False} + + input_data = { + "tool_name": registered_tool.name, + "input_mcp_payload": params, + "tool_override": tool_override_config, + } + + return self._simulate_tool_call(tool_type, state_key, input_data) + except Exception as e: + logger.error(f"Error in MCP simulation for {registered_tool.name}: {e}") + raise + + wrapper.__name__ = registered_tool.name + return wrapper + + def _create_api_simulator_wrapper(self, registered_tool: RegisteredTool, tool_type: ToolType, state_key: str) -> Callable: + """Create a wrapper function for API tool simulation.""" + def wrapper(**kwargs): + try: + # Get tool behavior configuration from tool overrides + tool_override_config = {} + if registered_tool.name in self.tool_overrides: + override_config = self.tool_overrides[registered_tool.name] + if override_config.failure_conditions: + tool_override_config["failure_conditions"] = 
override_config.failure_conditions.model_dump()
+                    else:
+                        tool_override_config["failure_conditions"] = {"enabled": False}
+
+                input_data = {
+                    "tool_name": registered_tool.name,
+                    "user_input_api_payload": kwargs,
+                    "path": registered_tool.api_path or "",
+                    "method": registered_tool.api_method or "GET",
+                    "tool_override": tool_override_config,
+                }
+
+                return self._simulate_tool_call(tool_type, state_key, input_data)
+            except Exception as e:
+                logger.error(f"Error in API simulation for {registered_tool.name}: {e}")
+                raise
+
+        wrapper.__name__ = registered_tool.name
+        return wrapper
+
+    def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: Dict[str, Any]) -> Any:
+        """Simulate a tool invocation and return the response."""
+        # Handle tool behavior configuration
+        tool_override = input_data.get("tool_override", {})
+
+        # Check for failure conditions
+        failure_conditions = tool_override.get("failure_conditions", {})
+        if failure_conditions and failure_conditions.get("enabled", False):
+            error_rate = failure_conditions.get("error_rate", 0.0)
+            if random.random() < error_rate:
+                error_type = failure_conditions.get("error_type", "execution_error")
+                error_message = failure_conditions.get("error_message", "An error occurred")
+
+                if tool_type == ToolType.API:
+                    return self._create_error_response(error_type, error_message)
+                elif tool_type in [ToolType.FUNCTION, ToolType.MCP]:
+                    return {
+                        "status": "error",
+                        "error_type": error_type,
+                        "message": error_message
+                    }
+
+        # Route to appropriate handler based on tool type
+        if tool_type == ToolType.FUNCTION:
+            return self._handle_function_tool(input_data, state_key)
+        elif tool_type == ToolType.MCP:
+            return self._handle_mcp_tool(input_data, state_key)
+        elif tool_type == ToolType.API:
+            return self._handle_api_tool(input_data, state_key)
+        else:
+            return self._create_error_response("unsupported_tool_type", f"Tool type '{tool_type}' not supported")
+
+    def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]:
+        """Handle function tool simulation."""
+        tool_name = input_data.get("tool_name", "")
+        parameters = input_data.get("parameters", {})
+
+        if not tool_name:
+            return {"status": "error", "error_type": "missing_tool_name", "message": "Tool name is required"}
+
+        # Generate response using LLM
+        try:
+            from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT
+
+            prompt = FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT.format(
+                tool_name=tool_name,
+                parameters=parameters if isinstance(parameters, str) else json.dumps(parameters, indent=2) if parameters else "{}",
+                previous_responses=json.dumps(self._state_registry.get_state(state_key), indent=2) or "{}"
+            )
+
+            llm_response = self._generate_llm_response(prompt)
+            response_data = self._parse_llm_response(llm_response)
+
+            # Record the call
+            self._state_registry.record_function_call(tool_name, state_key, parameters, response_data)
+
+            return response_data
+
+        except Exception as e:
+            logger.error(f"Error generating function response: {e}")
+            return {"status": "error", "error_type": "generation_error", "message": str(e)}
+
+    def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]:
+        """Handle MCP tool simulation."""
+        tool_name = input_data.get("tool_name", "")
+        input_mcp_payload = input_data.get("input_mcp_payload", {})
+
+        if not tool_name:
+            return {
+                "isError": True,
+                "content": [{"type": "text", "text": "Tool name is required"}]
+            }
+
+        try:
+            from .prompt_templates.tool_response_generation import 
MCP_TOOL_RESPONSE_GENERATION_PROMPT + + prompt = MCP_TOOL_RESPONSE_GENERATION_PROMPT.format( + tool_name=tool_name, + mcp_payload=json.dumps(input_mcp_payload, indent=2) if input_mcp_payload else "{}", + previous_responses=json.dumps(self._state_registry.get_state(state_key), indent=2) or "{}" + ) + + llm_response = self._generate_llm_response(prompt) + response_data = self._parse_llm_response(llm_response) + + # Record the call + self._state_registry.record_mcp_tool_call(tool_name, state_key, input_mcp_payload, response_data) + + return response_data + + except Exception as e: + logger.error(f"Error generating MCP response: {e}") + return { + "isError": True, + "content": [{"type": "text", "text": f"Error generating response: {str(e)}"}] + } + + def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: + """Handle API tool simulation.""" + tool_name = input_data.get("tool_name", "") + user_input_api_payload = input_data.get("user_input_api_payload", {}) + path = input_data.get("path", "") + method = input_data.get("method", "GET") + + if not tool_name: + return self._create_error_response("missing_tool_name", "Tool name is required", 400) + + try: + from .prompt_templates.tool_response_generation import API_TOOL_RESPONSE_GENERATION_PROMPT + + prompt = API_TOOL_RESPONSE_GENERATION_PROMPT.format( + tool_name=tool_name, + path=path, + method=method, + api_payload=json.dumps(user_input_api_payload, indent=2) if user_input_api_payload else "{}", + previous_responses=json.dumps(self._state_registry.get_state(state_key), indent=2) or "{}" + ) + + llm_response = self._generate_llm_response(prompt) + response_data = self._parse_llm_response(llm_response) + + # Ensure proper API response format + if "status" not in response_data: + response_data = {"status": 200, "data": response_data} + + # Record the call + self._state_registry.record_api_call(tool_name, state_key, path, method, user_input_api_payload, response_data) + + return response_data + + except Exception as e: + logger.error(f"Error generating API response: {e}") + return self._create_error_response("generation_error", str(e), 500) + + def _generate_llm_response(self, prompt: str) -> str: + """ + Generate LLM response using the model for a given prompt. 
+ + Args: + prompt: The prompt string to send to the LLM + + Returns: + Raw LLM response text + + Raises: + Exception: If LLM generation fails + """ + try: + # Create message for model inference + messages = [{"role": "user", "content": [{"text": prompt}]}] + + # Generate response + llm_response = "" + for event in self.model.structured_output(str, messages, system_prompt=self.system_prompt_template): + if hasattr(event, 'get') and event.get("contentBlockDelta"): + delta = event["contentBlockDelta"] + if "text" in delta: + llm_response += delta["text"] + elif hasattr(event, 'get') and event.get("message"): + # Handle final message + content = event["message"].get("content", []) + for block in content: + if "text" in block: + llm_response += block["text"] + elif hasattr(event, 'get') and event.get("output"): + # Handle structured output result + return str(event["output"]) + + return llm_response + + except Exception as e: + logger.error(f"Error generating LLM response: {e}") + raise + + def _parse_llm_response(self, llm_response: str) -> Dict[str, Any]: + """Parse LLM response with fallback handling.""" + try: + return json.loads(llm_response) + except json.JSONDecodeError: + # Try to extract JSON from code blocks + import re + json_matches = re.findall(r'```(?:json)?\s*([\s\S]*?)\s*```', llm_response) + + for json_str in json_matches: + try: + return json.loads(json_str) + except json.JSONDecodeError: + continue + + # Fallback to simple text response + return {"result": llm_response} + + def _create_error_response(self, error_type: str, error_message: str, status_code: int = 400) -> Dict[str, Any]: + """Create standardized error response.""" + return { + "status": status_code, + "error": { + "type": error_type, + "title": self._get_error_title(status_code), + "detail": error_message + } + } + + def _get_error_title(self, status_code: int) -> str: + """Get error title based on status code.""" + error_titles = { + 400: 'Bad Request', + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 429: 'Too Many Requests', + 500: 'Internal Server Error', + 503: 'Service Unavailable' + } + return error_titles.get(status_code, 'Error') + + @classmethod + def function_tool(cls, name: Optional[str] = None, **simulator_kwargs) -> Callable: + """ + Decorator for registering Python function tools. + + Args: + name: Optional name for the tool. If None, uses function.__name__ + **simulator_kwargs: Additional simulator configuration + + Returns: + Decorator function + """ + def decorator(func: Callable) -> Callable: + try: + tool_name = name or func.__name__ + + # Register tool + registered_tool = RegisteredTool( + name=tool_name, + tool_type=ToolType.FUNCTION, + function=func, + simulator_kwargs=simulator_kwargs, + ) + cls._registered_tools[tool_name] = registered_tool + + logger.info(f"Registered function tool: {tool_name}") + + except Exception as e: + logger.error(f"Error registering function tool {name or func.__name__}: {e}") + raise + + return func + + return decorator + + @classmethod + def mcp_tool(cls, name: Optional[str] = None, schema: Optional[Dict[str, Any]] = None, **simulator_kwargs) -> Callable: + """ + Decorator for registering MCP (Model Context Protocol) tools. + + Args: + name: Optional name for the tool. 
If None, uses function.__name__ + schema: MCP tool schema dictionary + **simulator_kwargs: Additional simulator configuration + + Returns: + Decorator function + """ + def decorator(func: Callable) -> Callable: + tool_name = name or func.__name__ + + if schema is None: + raise ValueError(f"MCP schema is required for tool {tool_name}") + + # Register tool + registered_tool = RegisteredTool( + name=tool_name, + tool_type=ToolType.MCP, + function=func, + mcp_schema=schema, + simulator_kwargs=simulator_kwargs, + ) + cls._registered_tools[tool_name] = registered_tool + + logger.info(f"Registered MCP tool: {tool_name}") + return func + + return decorator + + @classmethod + def api_tool( + cls, + name: Optional[str] = None, + path: Optional[str] = None, + method: Optional[str] = None, + schema: Optional[Dict[str, Any]] = None, + **simulator_kwargs, + ) -> Callable: + """ + Decorator for registering API tools. + + Args: + name: Optional name for the tool. If None, uses function.__name__ + path: API endpoint path + method: HTTP method (GET, POST, etc.) + schema: API tool schema dictionary + **simulator_kwargs: Additional simulator configuration + + Returns: + Decorator function + """ + def decorator(func: Callable) -> Callable: + tool_name = name or func.__name__ + + if path is None: + raise ValueError("API path is required") + if method is None: + raise ValueError("HTTP method is required") + + # Register tool + registered_tool = RegisteredTool( + name=tool_name, + tool_type=ToolType.API, + function=func, + api_path=path, + api_method=method, + simulator_kwargs=simulator_kwargs, + ) + cls._registered_tools[tool_name] = registered_tool + + logger.info(f"Registered API tool: {tool_name}") + return func + + return decorator + + @classmethod + def from_case_for_tool_simulator( + cls, + case: Case, + system_prompt_template: Optional[str] = None, + model: Optional[str] = None, + **kwargs, + ) -> "ToolSimulator": + """ + Create a ToolSimulator instance configured for a specific case. + + Args: + case: Case object containing test case information and metadata + system_prompt_template: Template for system prompts + model: Model identifier for LLM-based simulation + **kwargs: Additional configuration options + + Returns: + Configured ToolSimulator instance + """ + tool_overrides = cls._generate_override_from_case(case) + return cls( + tool_overrides=tool_overrides, + system_prompt_template=system_prompt_template, + model=model, + **kwargs, + ) + + @staticmethod + def _generate_override_from_case(case: Case) -> Dict[str, ToolOverrideConfig]: + """Generate tool override configuration from a case using LLM.""" + # Extract scenario description from case + scenario_description = f"Test case: {case.name or 'unnamed'}. Input: {case.input}" + if case.metadata: + scenario_description += f". 
Metadata: {case.metadata}" + + # Create tools list from registered tools + tools_list = [] + for tool_name, registered_tool in ToolSimulator._registered_tools.items(): + tool_info = { + "name": tool_name, + "type": registered_tool.tool_type.value, + "description": ( + getattr(registered_tool.function, "__doc__", "") + if registered_tool.function + else "" + ), + } + + # Add schema information based on tool type + if registered_tool.tool_type == ToolType.FUNCTION and registered_tool.function: + sig = inspect.signature(registered_tool.function) + parameters = {} + for param_name, param in sig.parameters.items(): + param_type = "string" + if param.annotation != inspect.Parameter.empty: + type_map = { + int: "integer", + float: "number", + bool: "boolean", + list: "array", + dict: "object", + str: "string", + } + param_type = type_map.get(param.annotation, "string") + + parameters[param_name] = { + "type": param_type, + "required": param.default == inspect.Parameter.empty, + } + + tool_info["parameters"] = parameters + + elif registered_tool.tool_type == ToolType.MCP and registered_tool.mcp_schema: + tool_info["schema"] = registered_tool.mcp_schema + + elif registered_tool.tool_type == ToolType.API: + tool_info["path"] = registered_tool.api_path + tool_info["method"] = registered_tool.api_method + + tools_list.append(tool_info) + + # If no registered tools, return empty override + if not tools_list: + logger.warning("No registered tools found for override generation") + return {} + + # Generate overrides using LLM prompt + try: + tools_json = json.dumps(tools_list, indent=2) + + # Use the tool override generation prompt + from .prompt_templates.tool_override_generation import TOOL_OVERRIDE_GENERATION_PROMPT + + prompt = TOOL_OVERRIDE_GENERATION_PROMPT.format( + scenario=scenario_description, + tools_json=tools_json, + ) + + # Generate response + agent = Agent(callback_handler=None) + result = agent(prompt) + llm_response = str(result) + + # Parse LLM response + try: + response_data = json.loads(llm_response.strip()) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse LLM response as JSON: {e}") + logger.debug(f"Raw LLM response: {llm_response}") + return {} + + # Convert LLM response to ToolOverrideConfig instances + tool_configs: Dict[str, ToolOverrideConfig] = {} + tool_overrides = response_data.get("tool_overrides", []) + + for override in tool_overrides: + tool_name = override.get("tool_name") + should_simulate = override.get("should_simulate", True) + + if not tool_name or not should_simulate: + continue + + # Add failure conditions using new schema format + failure_conditions = override.get("failure_conditions", {}) + failure_conditions = { + "enabled": failure_conditions.get("enabled", False), + "error_rate": failure_conditions.get("error_rate", 0.0), + "error_type": failure_conditions.get("error_type", "execution_error"), + "error_message": failure_conditions.get("error_message"), + } + + try: + # Create FailureConditions instance + failure_conditions_instance = FailureConditions(**failure_conditions) + + # Create ToolOverrideConfig instance + tool_configs[tool_name] = ToolOverrideConfig( + failure_conditions=failure_conditions_instance, + ) + except Exception as e: + logger.warning(f"Failed to create ToolOverrideConfig for {tool_name}: {e}") + continue + + logger.info(f"Generated overrides for {len(tool_configs)} tools using LLM") + return tool_configs + + except Exception as e: + logger.error(f"Error generating overrides using LLM: {e}") + logger.warning("Falling 
back to empty override configuration") + return {} + + def get_tool(self, tool_name: str) -> Optional[Callable]: + """ + Get a tool by name from the active simulators. + + Args: + tool_name: Name of the tool to retrieve + + Returns: + Tool callable if found, None otherwise + """ + return self._active_simulators.get(tool_name) + + def list_tools(self) -> List[str]: + """ + List all registered tool names. + + Returns: + List of tool names + """ + return list(self._registered_tools.keys()) + + @classmethod + def clear_registry(cls): + """Clear all registered tools. Useful for testing.""" + cls._registered_tools.clear() + cls._state_registry = None + logger.info("Cleared tool registry") + + def __getattr__(self, name: str) -> Any: + """ + Allow direct access to registered tools as attributes. + + Args: + name: Tool name + + Returns: + Tool callable + + Raises: + AttributeError: If tool not found + """ + if name in self._active_simulators: + return self._active_simulators[name] + + raise AttributeError(f"Tool '{name}' not found in active simulators") diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py new file mode 100644 index 0000000..f7bacd7 --- /dev/null +++ b/src/strands_evals/types/simulation/tool.py @@ -0,0 +1,183 @@ +from datetime import datetime +from enum import Enum +from typing import Any, Callable, Dict, Optional + +from pydantic import BaseModel, Field, field_validator, model_validator + + +class ToolType(Enum): + """ + Enumeration of supported tool types for simulation. + + Attributes: + FUNCTION: Python function tools that can be called directly. + MCP: Model Context Protocol tools with structured schemas. + API: REST API endpoints with HTTP methods and paths. + """ + FUNCTION = "function" + MCP = "mcp" + API = "api" + + +class FailureConditions(BaseModel): + """ + Configuration for failure simulation conditions. + + Attributes: + enabled: Whether failure simulation is enabled for the tool. + error_rate: Error rate between 0.0 and 1.0 for random failure injection. + error_type: Type of error to simulate when failures occur. + error_message: Optional custom error message for simulated failures. + """ + + enabled: bool = Field(default=False, description="Whether failure simulation is enabled") + error_rate: float = Field(default=0.0, ge=0.0, le=1.0, description="Error rate between 0.0 and 1.0") + error_type: str = Field(default="execution_error", description="Type of error to simulate") + error_message: Optional[str] = None + + @field_validator("error_rate") + @classmethod + def validate_error_rate(cls, v: float) -> float: + """Validate error rate is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError("Error rate must be between 0.0 and 1.0") + return v + + @model_validator(mode='after') + def validate_enabled_state(self) -> 'FailureConditions': + """Validate that if enabled is True, error_rate is > 0.""" + if self.enabled and self.error_rate == 0.0: + raise ValueError("If failure conditions are enabled, error_rate must be greater than 0") + return self + + +class ToolOverrideConfig(BaseModel): + """ + Configuration for tool override behavior. + + Attributes: + failure_conditions: Configuration for failure simulation conditions. + """ + failure_conditions: FailureConditions = Field(default_factory=FailureConditions, description="Configuration for failure simulation") + + +class RegisteredTool(BaseModel): + """ + Represents a registered tool in the simulator. 
+ + Attributes: + name: Name of the tool for identification and registration. + tool_type: Type of the tool (FUNCTION, MCP, or API). + function: Function callable for FUNCTION type tools (excluded from serialization). + mcp_schema: MCP tool schema dictionary for MCP type tools. + api_path: API endpoint path for API type tools. + api_method: HTTP method for API type tools (GET, POST, etc.). + simulator_kwargs: Additional simulator configuration parameters. + """ + name: str = Field(..., description="Name of the tool") + tool_type: ToolType = Field(..., description="Type of the tool") + function: Optional[Callable] = Field(default=None, description="Function callable", exclude=True) + mcp_schema: Optional[Dict[str, Any]] = Field(default=None, description="MCP tool schema") + api_path: Optional[str] = Field(default=None, description="API endpoint path") + api_method: Optional[str] = Field(default=None, description="HTTP method") + simulator_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional simulator configuration") + + class Config: + arbitrary_types_allowed = True + + +class StateRegistry: + """ + Simple state registry for maintaining tool state across calls. + + Attributes: + _states: Internal dictionary mapping state keys to recorded call history. + """ + + def __init__(self): + """ + Initialize state registry. + + Creates an empty state dictionary to track tool calls and responses + across different simulation sessions. + """ + self._states: Dict[str, Dict[str, Any]] = {} + + def get_state(self, key: str) -> Dict[str, Any]: + """ + Get state for a given key. + + Args: + key: State key to retrieve recorded calls for. + + Returns: + Dictionary containing recorded call history for the key, empty if not found. + """ + return self._states.get(key, {}) + + def record_function_call(self, tool_name: str, state_key: str, parameters: Dict[str, Any], response_data: Any): + """ + Record a function call in state. + + Args: + tool_name: Name of the function tool that was called. + state_key: State key to record the call under. + parameters: Parameters passed to the function. + response_data: Response data returned from the function. + """ + if state_key not in self._states: + self._states[state_key] = {"function_calls": []} + + call_record = { + "tool_name": tool_name, + "parameters": parameters, + "response": response_data, + "timestamp": datetime.now().isoformat() + } + self._states[state_key]["function_calls"].append(call_record) + + def record_mcp_tool_call(self, tool_name: str, state_key: str, input_mcp_payload: Dict[str, Any], response_data: Any): + """ + Record an MCP tool call in state. + + Args: + tool_name: Name of the MCP tool that was called. + state_key: State key to record the call under. + input_mcp_payload: Input payload sent to the MCP tool. + response_data: Response data returned from the MCP tool. + """ + if state_key not in self._states: + self._states[state_key] = {"mcp_calls": []} + + call_record = { + "tool_name": tool_name, + "input": input_mcp_payload, + "response": response_data, + "timestamp": datetime.now().isoformat() + } + self._states[state_key]["mcp_calls"].append(call_record) + + def record_api_call(self, tool_name: str, state_key: str, path: str, method: str, input_data: Dict[str, Any], response: Dict[str, Any]): + """ + Record an API call in state. + + Args: + tool_name: Name of the API tool that was called. + state_key: State key to record the call under. + path: API endpoint path that was called. 
method: HTTP method used for the API call.
+            input_data: Input data sent to the API endpoint.
+            response: Response data returned from the API endpoint.
+        """
+        if state_key not in self._states:
+            self._states[state_key] = {"api_calls": []}
+
+        call_record = {
+            "tool_name": tool_name,
+            "path": path,
+            "method": method,
+            "input": input_data,
+            "response": response,
+            "timestamp": datetime.now().isoformat()
+        }
+        self._states[state_key]["api_calls"].append(call_record)
diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py
new file mode 100644
index 0000000..9e9c0c4
--- /dev/null
+++ b/tests/strands_evals/simulation/test_tool_simulator.py
@@ -0,0 +1,629 @@
+"""Tests for ToolSimulator class."""
+
+import json
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from strands_evals.simulation.tool_simulator import ToolSimulator
+from strands_evals.types.simulation.tool import (
+    FailureConditions,
+    RegisteredTool,
+    StateRegistry,
+    ToolOverrideConfig,
+    ToolType,
+)
+
+
+@pytest.fixture
+def sample_failure_conditions():
+    """Fixture providing sample failure conditions."""
+    return FailureConditions(
+        enabled=True,
+        error_rate=0.5,
+        error_type="timeout_error",
+        error_message="Operation timed out",
+    )
+
+
+@pytest.fixture
+def sample_tool_override_config(sample_failure_conditions):
+    """Fixture providing sample tool override configuration."""
+    return ToolOverrideConfig(
+        failure_conditions=sample_failure_conditions,
+        # ToolOverrideConfig currently only exposes failure_conditions
+    )
+
+
+@pytest.fixture
+def sample_scenario():
+    """Fixture providing a sample scenario dictionary."""
+    return {
+        "name": "Banking Simulation",
+        "description": "Test scenario for banking operations with account balance checks",
+        "metadata": {"domain": "finance", "complexity": "medium"},
+    }
+
+
+@pytest.fixture
+def mock_model():
+    """Fixture providing a mock model for testing."""
+    mock = MagicMock()
+
+    # Mock the async generator for model.generate()
+    async def mock_generate(messages, system_prompt=None):
+        # Simulate streaming response
+        yield {
+            "contentBlockDelta": {
+                "text": '{"result": "mocked response"}'
+            }
+        }
+
+    mock.generate = mock_generate
+    return mock
+
+
+@pytest.fixture(autouse=True)
+def clear_registry():
+    """Clear tool registry before each test."""
+    ToolSimulator.clear_registry()
+    yield
+    ToolSimulator.clear_registry()
+
+
+class TestToolSimulatorInitialization:
+    """Test cases for ToolSimulator initialization."""
+
+    def test_init_with_defaults(self):
+        """Test ToolSimulator initialization with default parameters."""
+        simulator = ToolSimulator()
+
+        assert simulator.tool_overrides == {}
+        assert simulator.list_tools() == []
+        assert simulator.system_prompt_template is None
+        assert simulator.model is not None
+        assert simulator._state_registry is not None
+        assert simulator._active_simulators == {}
+
+    def test_init_with_model_string(self):
+        """Test ToolSimulator initialization with model string."""
+        model_id = "anthropic.claude-3-haiku-20240307-v1:0"
+        simulator = ToolSimulator(model=model_id)
+
+        assert simulator.model is not None
+        # The model should be configured with the provided model_id
+
+    def test_init_with_model_object(self, mock_model):
+        """Test ToolSimulator initialization with model object."""
+        simulator = ToolSimulator(model=mock_model)
+
+        assert simulator.model is mock_model
+
+    def test_init_with_tool_overrides(self, sample_tool_override_config):
+        """Test ToolSimulator 
initialization with tool overrides.""" + tool_overrides = {"test_tool": sample_tool_override_config} + simulator = ToolSimulator(tool_overrides=tool_overrides) + + assert simulator.tool_overrides == tool_overrides + + def test_init_with_custom_state_registry(self): + """Test ToolSimulator initialization with custom state registry.""" + custom_registry = StateRegistry() + simulator = ToolSimulator(state_registry=custom_registry) + + assert simulator._state_registry is custom_registry + + def test_init_with_system_prompt_template(self): + """Test ToolSimulator initialization with system prompt template.""" + template = "You are a helpful assistant simulating tools." + simulator = ToolSimulator(system_prompt_template=template) + + assert simulator.system_prompt_template == template + + +class TestToolDecorators: + """Test cases for tool decorator registration.""" + + def test_function_tool_decorator(self): + """Test function tool decorator registration.""" + @ToolSimulator.function_tool("test_function") + def sample_function(x: int, y: str) -> dict: + """A sample function for testing.""" + return {"x": x, "y": y} + + assert "test_function" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["test_function"] + assert registered_tool.name == "test_function" + assert registered_tool.tool_type == ToolType.FUNCTION + assert registered_tool.function == sample_function + + def test_function_tool_decorator_without_name(self): + """Test function tool decorator uses function name when no name provided.""" + @ToolSimulator.function_tool() + def my_test_function(): + """Test function.""" + pass + + assert "my_test_function" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["my_test_function"] + assert registered_tool.name == "my_test_function" + + def test_mcp_tool_decorator(self): + """Test MCP tool decorator registration.""" + schema = { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"] + } + + @ToolSimulator.mcp_tool("test_mcp", schema=schema) + def sample_mcp_tool(**params): + """A sample MCP tool for testing.""" + return {"content": [{"type": "text", "text": f"Result: {params}"}]} + + assert "test_mcp" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["test_mcp"] + assert registered_tool.name == "test_mcp" + assert registered_tool.tool_type == ToolType.MCP + assert registered_tool.mcp_schema == schema + + def test_mcp_tool_decorator_requires_schema(self): + """Test MCP tool decorator requires schema parameter.""" + with pytest.raises(ValueError, match="MCP schema is required"): + @ToolSimulator.mcp_tool("test_mcp") + def sample_mcp_tool(**params): + pass + + def test_api_tool_decorator(self): + """Test API tool decorator registration.""" + @ToolSimulator.api_tool("test_api", path="/test", method="POST") + def sample_api_tool(**kwargs): + """A sample API tool for testing.""" + return {"status": 200, "data": kwargs} + + assert "test_api" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["test_api"] + assert registered_tool.name == "test_api" + assert registered_tool.tool_type == ToolType.API + assert registered_tool.api_path == "/test" + assert registered_tool.api_method == "POST" + + def test_api_tool_decorator_requires_path(self): + """Test API tool decorator requires path parameter.""" + with pytest.raises(ValueError, match="API path is required"): + @ToolSimulator.api_tool("test_api", method="GET") + def 
sample_api_tool(**kwargs): + pass + + def test_api_tool_decorator_requires_method(self): + """Test API tool decorator requires method parameter.""" + with pytest.raises(ValueError, match="HTTP method is required"): + @ToolSimulator.api_tool("test_api", path="/test") + def sample_api_tool(**kwargs): + pass + + def test_function_tool_with_simulator_kwargs(self): + """Test function tool decorator with simulator kwargs.""" + @ToolSimulator.function_tool("test_function", share_state_id="shared_state") + def sample_function(): + pass + + registered_tool = ToolSimulator._registered_tools["test_function"] + assert registered_tool.simulator_kwargs == {"share_state_id": "shared_state"} + + +class TestToolSimulation: + """Test cases for tool simulation functionality.""" + + @patch('strands._async.run_async') + def test_function_tool_simulation(self, mock_run_async): + """Test function tool simulation.""" + # Setup mock response + mock_run_async.return_value = '{"result": "simulated response"}' + + # Register and create simulator + @ToolSimulator.function_tool("test_function") + def test_func(message: str) -> dict: + """Test function that should be simulated.""" + pass + + simulator = ToolSimulator() + + # Execute simulated function + result = simulator.test_function("Hello, world!") + + assert result == {"result": "simulated response"} + mock_run_async.assert_called_once() + + @patch('strands._async.run_async') + def test_mcp_tool_simulation(self, mock_run_async): + """Test MCP tool simulation.""" + # Setup mock response + mock_run_async.return_value = '{"content": [{"type": "text", "text": "MCP response"}]}' + + # Register and create simulator + schema = {"type": "object", "properties": {"param": {"type": "string"}}} + @ToolSimulator.mcp_tool("test_mcp", schema=schema) + def test_mcp(**params): + """Test MCP tool that should be simulated.""" + pass + + simulator = ToolSimulator() + + # Execute simulated MCP tool + result = simulator.test_mcp(param="test_value") + + assert result == {"content": [{"type": "text", "text": "MCP response"}]} + mock_run_async.assert_called_once() + + @patch('strands._async.run_async') + def test_api_tool_simulation(self, mock_run_async): + """Test API tool simulation.""" + # Setup mock response + mock_run_async.return_value = '{"status": 200, "data": {"key": "value"}}' + + # Register and create simulator + @ToolSimulator.api_tool("test_api", path="/test", method="POST") + def test_api(**kwargs): + """Test API tool that should be simulated.""" + pass + + simulator = ToolSimulator() + + # Execute simulated API tool + result = simulator.test_api(key="value") + + assert result == {"status": 200, "data": {"key": "value"}} + mock_run_async.assert_called_once() + + def test_implemented_function_uses_real_implementation(self): + """Test that functions with real implementations are not simulated.""" + @ToolSimulator.function_tool("implemented_function") + def real_function(x: int) -> dict: + """A function with real implementation.""" + return {"doubled": x * 2} + + simulator = ToolSimulator() + result = simulator.implemented_function(5) + + assert result == {"doubled": 10} + + def test_failure_conditions_trigger_error(self): + """Test that failure conditions trigger errors as expected.""" + # Register tool + @ToolSimulator.function_tool("failing_function") + def test_func(): + pass + + # Create failure conditions with 100% error rate + failure_conditions = FailureConditions( + enabled=True, + error_rate=1.0, + error_type="timeout_error", + error_message="Simulated timeout" + ) + 
tool_overrides = { + "failing_function": ToolOverrideConfig(failure_conditions=failure_conditions) + } + + simulator = ToolSimulator(tool_overrides=tool_overrides) + + # Function should return error due to failure conditions + result = simulator.failing_function() + + assert result["status"] == "error" + assert result["error_type"] == "timeout_error" + assert result["message"] == "Simulated timeout" + + +class TestToolRetrieval: + """Test cases for tool retrieval and listing.""" + + def test_list_tools(self): + """Test listing registered tools.""" + @ToolSimulator.function_tool("func1") + def func1(): + pass + + @ToolSimulator.function_tool("func2") + def func2(): + pass + + simulator = ToolSimulator() + tools = simulator.list_tools() + + assert set(tools) == {"func1", "func2"} + + def test_get_tool(self): + """Test getting tool by name.""" + @ToolSimulator.function_tool("test_function") + def test_func(): + return {"test": "result"} + + simulator = ToolSimulator() + tool = simulator.get_tool("test_function") + + assert tool is not None + assert callable(tool) + + def test_get_nonexistent_tool(self): + """Test getting non-existent tool returns None.""" + simulator = ToolSimulator() + tool = simulator.get_tool("nonexistent_tool") + + assert tool is None + + def test_tool_attribute_access(self): + """Test accessing tools as attributes.""" + @ToolSimulator.function_tool("test_function") + def test_func(): + return {"test": "result"} + + simulator = ToolSimulator() + + # Should be able to access as attribute + assert hasattr(simulator, "test_function") + tool = simulator.test_function + assert callable(tool) + + def test_nonexistent_tool_attribute_raises_error(self): + """Test accessing non-existent tool as attribute raises AttributeError.""" + simulator = ToolSimulator() + + with pytest.raises(AttributeError, match="Tool 'nonexistent' not found"): + _ = simulator.nonexistent + + +class TestFactoryMethods: + """Test cases for factory methods.""" + + @patch('strands._async.run_async') + def test_from_scenario_for_tool_simulator(self, mock_run_async, sample_scenario): + """Test factory method creates simulator from scenario.""" + # Register a test tool first + @ToolSimulator.function_tool("account_balance_check") + def check_balance(account_id: str) -> dict: + """Check account balance.""" + pass + + # Mock LLM response for override generation + mock_override_response = { + "tool_overrides": [{ + "tool_name": "account_balance_check", + "should_simulate": True, + "failure_conditions": { + "enabled": False, + "error_rate": 0.0 + } + }] + } + mock_run_async.return_value = json.dumps(mock_override_response) + + simulator = ToolSimulator.from_scenario_for_tool_simulator( + scenario_dict=sample_scenario, + system_prompt_template="Test template", + model="test-model" + ) + + assert simulator is not None + assert simulator.system_prompt_template == "Test template" + mock_run_async.assert_called_once() + + @patch('strands._async.run_async') + def test_generate_override_from_scenario(self, mock_run_async, sample_scenario): + """Test override generation from scenario.""" + # Register test tools + @ToolSimulator.function_tool("test_function") + def test_func(param: str) -> dict: + """Test function.""" + pass + + # Mock LLM response + mock_response = { + "tool_overrides": [{ + "tool_name": "test_function", + "should_simulate": True, + "failure_conditions": { + "enabled": True, + "error_rate": 0.1, + "error_type": "network_error", + "error_message": "Network timeout" + } + }] + } + mock_run_async.return_value 
= json.dumps(mock_response) + + overrides = ToolSimulator._generate_override_from_scenario(sample_scenario) + + assert "test_function" in overrides + override = overrides["test_function"] + assert override.failure_conditions.enabled is True + assert override.failure_conditions.error_rate == 0.1 + assert override.failure_conditions.error_type == "network_error" + mock_run_async.assert_called_once() + + def test_generate_override_with_no_tools(self, sample_scenario): + """Test override generation with no registered tools.""" + # Clear registry to ensure no tools + ToolSimulator.clear_registry() + + overrides = ToolSimulator._generate_override_from_scenario(sample_scenario) + + assert overrides == {} + + @patch('strands._async.run_async') + def test_generate_override_handles_llm_error(self, mock_run_async, sample_scenario): + """Test override generation handles LLM errors gracefully.""" + # Register a test tool + @ToolSimulator.function_tool("test_function") + def test_func(): + pass + + # Mock LLM to return invalid JSON + mock_run_async.return_value = "invalid json response" + + overrides = ToolSimulator._generate_override_from_scenario(sample_scenario) + + # Should return empty dict on error + assert overrides == {} + + +class TestStateRegistry: + """Test cases for StateRegistry functionality.""" + + def test_state_registry_creation(self): + """Test StateRegistry is created properly.""" + registry = StateRegistry() + + assert registry is not None + assert registry._states == {} + + def test_record_function_call(self): + """Test recording function call in state registry.""" + registry = StateRegistry() + + registry.record_function_call( + tool_name="test_tool", + state_key="test_state", + parameters={"param": "value"}, + response_data={"result": "success"} + ) + + state = registry.get_state("test_state") + assert "function_calls" in state + assert len(state["function_calls"]) == 1 + call = state["function_calls"][0] + assert call["tool_name"] == "test_tool" + assert call["parameters"] == {"param": "value"} + assert call["response"] == {"result": "success"} + + def test_record_mcp_tool_call(self): + """Test recording MCP tool call in state registry.""" + registry = StateRegistry() + + registry.record_mcp_tool_call( + tool_name="mcp_tool", + state_key="mcp_state", + input_mcp_payload={"input": "data"}, + response_data={"content": [{"type": "text", "text": "result"}]} + ) + + state = registry.get_state("mcp_state") + assert "mcp_calls" in state + assert len(state["mcp_calls"]) == 1 + call = state["mcp_calls"][0] + assert call["tool_name"] == "mcp_tool" + assert call["input"] == {"input": "data"} + + def test_record_api_call(self): + """Test recording API call in state registry.""" + registry = StateRegistry() + + registry.record_api_call( + tool_name="api_tool", + state_key="api_state", + path="/test", + method="POST", + input_data={"data": "test"}, + response={"status": 200} + ) + + state = registry.get_state("api_state") + assert "api_calls" in state + assert len(state["api_calls"]) == 1 + call = state["api_calls"][0] + assert call["tool_name"] == "api_tool" + assert call["path"] == "/test" + assert call["method"] == "POST" + + +class TestErrorHandling: + """Test cases for error handling.""" + + def test_parse_llm_response_valid_json(self): + """Test parsing valid JSON response.""" + simulator = ToolSimulator() + + response = simulator._parse_llm_response('{"key": "value"}') + + assert response == {"key": "value"} + + def test_parse_llm_response_json_in_code_block(self): + """Test parsing 
JSON from code blocks.""" + simulator = ToolSimulator() + + llm_text = '```json\n{"key": "value"}\n```' + response = simulator._parse_llm_response(llm_text) + + assert response == {"key": "value"} + + def test_parse_llm_response_invalid_json_fallback(self): + """Test fallback for invalid JSON.""" + simulator = ToolSimulator() + + response = simulator._parse_llm_response("This is not JSON") + + assert response == {"result": "This is not JSON"} + + def test_create_error_response(self): + """Test error response creation.""" + simulator = ToolSimulator() + + error = simulator._create_error_response("test_error", "Test message", 400) + + assert error["status"] == 400 + assert error["error"]["type"] == "test_error" + assert error["error"]["detail"] == "Test message" + assert error["error"]["title"] == "Bad Request" + + def test_get_error_title(self): + """Test error title mapping.""" + simulator = ToolSimulator() + + assert simulator._get_error_title(400) == "Bad Request" + assert simulator._get_error_title(404) == "Not Found" + assert simulator._get_error_title(500) == "Internal Server Error" + assert simulator._get_error_title(999) == "Error" # Unknown status code + + +class TestRegistryManagement: + """Test cases for registry management.""" + + def test_clear_registry(self): + """Test clearing tool registry.""" + @ToolSimulator.function_tool("test_function") + def test_func(): + pass + + assert len(ToolSimulator._registered_tools) == 1 + + ToolSimulator.clear_registry() + + assert len(ToolSimulator._registered_tools) == 0 + assert ToolSimulator._state_registry is None + + def test_function_has_implementation_detection(self): + """Test detection of function implementation.""" + simulator = ToolSimulator() + + # Empty function should be detected as not implemented + def empty_func(): + pass + + def implemented_func(): + return {"result": "value"} + + assert not simulator._function_has_implementation(empty_func) + assert simulator._function_has_implementation(implemented_func) + + def test_function_has_implementation_error_handling(self): + """Test function implementation detection handles errors.""" + simulator = ToolSimulator() + + # Create a mock function that will cause dis.get_instructions to fail + mock_func = MagicMock() + mock_func.__code__ = None + + # Should assume implemented on error + result = simulator._function_has_implementation(mock_func) + assert result is True From be24b1148a984bd687f7b1b656b3389bac5d1475 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Fri, 23 Jan 2026 22:38:21 +0000 Subject: [PATCH 02/15] fix tool init state registry and prompt --- .../tool_response_generation.py | 39 ++- .../simulation/tool_simulator.py | 51 +++- src/strands_evals/types/simulation/tool.py | 237 +++++++++++++----- 3 files changed, 247 insertions(+), 80 deletions(-) diff --git a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py index 7420edb..55b7a52 100644 --- a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py +++ b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py @@ -37,14 +37,19 @@ Tool Name: {tool_name} Parameters: {parameters} -## Previous Tool Responses (for context) +## Initial State Context +{initial_state_description} + +## Current State & Previous Tool Responses (for context) {previous_responses} ## Instructions 1. Analyze the function name and parameters to understand what this tool should do -2. 
Generate a realistic response that would be returned by such a function -3. Consider the previous responses to maintain consistency in the simulation -4. Return valid JSON that represents the function's return value +2. Use the initial state description to understand the starting context and available data +3. Generate a realistic response that would be returned by such a function +4. Consider the previous responses to maintain consistency in the simulation +5. Ensure responses are consistent with the established state and realistic for the domain +6. Return valid JSON that represents the function's return value ## Response Format Return a JSON object that represents what this function would return. Examples: @@ -87,14 +92,19 @@ Tool Name: {tool_name} Input Payload: {mcp_payload} -## Previous Tool Responses (for context) +## Initial State Context +{initial_state_description} + +## Current State & Previous Tool Responses (for context) {previous_responses} ## Instructions 1. Analyze the tool name and input payload to understand what this MCP tool should do -2. Generate a realistic response following MCP response format -3. Consider the previous responses to maintain consistency in the simulation -4. Return valid JSON in MCP response format +2. Use the initial state description to understand the starting context and available data +3. Generate a realistic response following MCP response format +4. Consider the previous responses to maintain consistency in the simulation +5. Ensure responses are consistent with the established state and realistic for the domain +6. Return valid JSON in MCP response format ## MCP Response Format MCP tools return responses in this format: @@ -156,14 +166,19 @@ Method: {method} Request Payload: {api_payload} -## Previous Tool Responses (for context) +## Initial State Context +{initial_state_description} + +## Current State & Previous Tool Responses (for context) {previous_responses} ## Instructions 1. Analyze the API path, method, and payload to understand what this endpoint should do -2. Generate a realistic HTTP response with appropriate status code and data -3. Consider the previous responses to maintain consistency in the simulation -4. Return valid JSON in HTTP response format +2. Use the initial state description to understand the starting context and available data +3. Generate a realistic HTTP response with appropriate status code and data +4. Consider the previous responses to maintain consistency in the simulation +5. Ensure responses are consistent with the established state and realistic for the domain +6. 
Return valid JSON in HTTP response format ## HTTP Response Format API responses should include status codes and appropriate data: diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index 27b0963..aa5acfa 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -71,6 +71,7 @@ def __init__( # Initialize tool simulators for registered tools self._active_simulators: Dict[str, Any] = {} + self._initialize_shared_states() self._initialize_simulators() def _function_has_implementation(self, func: Callable) -> bool: @@ -95,6 +96,24 @@ def _function_has_implementation(self, func: Callable) -> bool: # If we can't analyze bytecode, assume it's implemented return True + def _initialize_shared_states(self): + """Initialize shared states from registered tools' initial descriptions.""" + for tool_name, registered_tool in self._registered_tools.items(): + if registered_tool.initial_state_description: + # Determine state key from tool name or simulator kwargs + state_key = ( + registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) + if registered_tool.simulator_kwargs + else registered_tool.name + ) + + # Initialize state with description + self._state_registry.initialize_state_via_description( + registered_tool.initial_state_description, + state_key + ) + logger.info(f"Initialized state for tool '{tool_name}' with key '{state_key}'") + def _initialize_simulators(self): """Initialize simulators for all registered tools with overrides.""" for tool_name, registered_tool in self._registered_tools.items(): @@ -279,10 +298,15 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D try: from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT + # Get initial state description from state registry + current_state = self._state_registry.get_state(state_key) + initial_state_description = current_state.get("initial_state", "No initial state provided.") + prompt = FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT.format( tool_name=tool_name, parameters=json.dumps(parameters, indent=2) if parameters else "{}", - previous_responses=json.dumps(self._state_registry.get_state(state_key), indent=2) or "{}" + initial_state_description=initial_state_description, + previous_responses=json.dumps(current_state, indent=2) or "{}" ) llm_response = self._generate_llm_response(prompt) @@ -311,10 +335,15 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s try: from .prompt_templates.tool_response_generation import MCP_TOOL_RESPONSE_GENERATION_PROMPT + # Get initial state description from state registry + current_state = self._state_registry.get_state(state_key) + initial_state_description = current_state.get("initial_state", "No initial state provided.") + prompt = MCP_TOOL_RESPONSE_GENERATION_PROMPT.format( tool_name=tool_name, mcp_payload=json.dumps(input_mcp_payload, indent=2) if input_mcp_payload else "{}", - previous_responses=json.dumps(self._state_registry.get_state(state_key), indent=2) or "{}" + initial_state_description=initial_state_description, + previous_responses=json.dumps(current_state, indent=2) or "{}" ) llm_response = self._generate_llm_response(prompt) @@ -345,12 +374,17 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s try: from .prompt_templates.tool_response_generation import API_TOOL_RESPONSE_GENERATION_PROMPT + # Get initial state description from 
state registry + current_state = self._state_registry.get_state(state_key) + initial_state_description = current_state.get("initial_state", "No initial state provided.") + prompt = API_TOOL_RESPONSE_GENERATION_PROMPT.format( tool_name=tool_name, path=path, method=method, api_payload=json.dumps(user_input_api_payload, indent=2) if user_input_api_payload else "{}", - previous_responses=json.dumps(self._state_registry.get_state(state_key), indent=2) or "{}" + initial_state_description=initial_state_description, + previous_responses=json.dumps(current_state, indent=2) or "{}" ) llm_response = self._generate_llm_response(prompt) @@ -452,12 +486,13 @@ def _get_error_title(self, status_code: int) -> str: return error_titles.get(status_code, 'Error') @classmethod - def function_tool(cls, name: Optional[str] = None, **simulator_kwargs) -> Callable: + def function_tool(cls, name: Optional[str] = None, initial_state_description: Optional[str] = None, **simulator_kwargs) -> Callable: """ Decorator for registering Python function tools. Args: name: Optional name for the tool. If None, uses function.__name__ + initial_state_description: Optional initial state description for the tool's context **simulator_kwargs: Additional simulator configuration Returns: @@ -472,6 +507,7 @@ def decorator(func: Callable) -> Callable: name=tool_name, tool_type=ToolType.FUNCTION, function=func, + initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, ) cls._registered_tools[tool_name] = registered_tool @@ -487,13 +523,14 @@ def decorator(func: Callable) -> Callable: return decorator @classmethod - def mcp_tool(cls, name: Optional[str] = None, schema: Optional[Dict[str, Any]] = None, **simulator_kwargs) -> Callable: + def mcp_tool(cls, name: Optional[str] = None, schema: Optional[Dict[str, Any]] = None, initial_state_description: Optional[str] = None, **simulator_kwargs) -> Callable: """ Decorator for registering MCP (Model Context Protocol) tools. Args: name: Optional name for the tool. If None, uses function.__name__ schema: MCP tool schema dictionary + initial_state_description: Optional initial state description for the tool's context **simulator_kwargs: Additional simulator configuration Returns: @@ -511,6 +548,7 @@ def decorator(func: Callable) -> Callable: tool_type=ToolType.MCP, function=func, mcp_schema=schema, + initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, ) cls._registered_tools[tool_name] = registered_tool @@ -527,6 +565,7 @@ def api_tool( path: Optional[str] = None, method: Optional[str] = None, schema: Optional[Dict[str, Any]] = None, + initial_state_description: Optional[str] = None, **simulator_kwargs, ) -> Callable: """ @@ -537,6 +576,7 @@ def api_tool( path: API endpoint path method: HTTP method (GET, POST, etc.) 
schema: API tool schema dictionary + initial_state_description: Optional initial state description for the tool's context **simulator_kwargs: Additional simulator configuration Returns: @@ -557,6 +597,7 @@ def decorator(func: Callable) -> Callable: function=func, api_path=path, api_method=method, + initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, ) cls._registered_tools[tool_name] = registered_tool diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index f7bacd7..f008d35 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -72,6 +72,7 @@ class RegisteredTool(BaseModel): mcp_schema: MCP tool schema dictionary for MCP type tools. api_path: API endpoint path for API type tools. api_method: HTTP method for API type tools (GET, POST, etc.). + initial_state_description: Initial state description for the tool's context. simulator_kwargs: Additional simulator configuration parameters. """ name: str = Field(..., description="Name of the tool") @@ -80,6 +81,7 @@ class RegisteredTool(BaseModel): mcp_schema: Optional[Dict[str, Any]] = Field(default=None, description="MCP tool schema") api_path: Optional[str] = Field(default=None, description="API endpoint path") api_method: Optional[str] = Field(default=None, description="HTTP method") + initial_state_description: Optional[str] = Field(default=None, description="Initial state description for the tool's context") simulator_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional simulator configuration") class Config: @@ -88,10 +90,8 @@ class Config: class StateRegistry: """ - Simple state registry for maintaining tool state across calls. - - Attributes: - _states: Internal dictionary mapping state keys to recorded call history. + State registry for managing shared state between tool simulators. + Organized by state_key to isolate state between different tools or shared state groups. """ def __init__(self): @@ -103,81 +103,192 @@ def __init__(self): """ self._states: Dict[str, Dict[str, Any]] = {} - def get_state(self, key: str) -> Dict[str, Any]: + def initialize_state_via_description(self, initial_state_description: str, state_key: str) -> None: """ - Get state for a given key. + Initialize state based on the provided description. + + This method pre-seeds the state with an initial description that will be + included in all subsequent LLM prompts, allowing the simulator to have + context about pre-existing data or system state. Args: - key: State key to retrieve recorded calls for. + initial_state_description: Description of the initial state (e.g., existing + database records, system configuration, etc.). + state_key: Key for the state in the registry (typically tool_name or share_state_id). + """ + if state_key not in self._states: + self._states[state_key] = { + "initial_state": initial_state_description, + "previous_calls": [], + "user_context": {}, + } + + def get_state(self, state_key: str) -> Dict[str, Any]: + """ + Get state for a specific tool or shared state group. + + Args: + state_key: Key for the state (tool_name or share_state_id). Returns: - Dictionary containing recorded call history for the key, empty if not found. + State dictionary containing previous_calls and user_context. 
""" - return self._states.get(key, {}) - - def record_function_call(self, tool_name: str, state_key: str, parameters: Dict[str, Any], response_data: Any): + if state_key is None: + raise ValueError("Value of state_key is required.") + + if state_key not in self._states: + self._states[state_key] = { + "previous_calls": [], + "user_context": {}, + } + + return dict(self._states[state_key]) + + def record_function_call( + self, + tool_name: str, + state_key: str, + parameters: Dict[str, Any], + response_data: Any, + ) -> Dict[str, Any]: """ - Record a function call in state. + Record a function call in the tool's state history. Args: - tool_name: Name of the function tool that was called. - state_key: State key to record the call under. + tool_name: Name of the function being called. + state_key: Key for the state (tool_name or share_state_id). parameters: Parameters passed to the function. - response_data: Response data returned from the function. + response_data: Response from the function call. + + Returns: + Updated state dictionary. """ - if state_key not in self._states: - self._states[state_key] = {"function_calls": []} - - call_record = { - "tool_name": tool_name, - "parameters": parameters, - "response": response_data, - "timestamp": datetime.now().isoformat() - } - self._states[state_key]["function_calls"].append(call_record) - - def record_mcp_tool_call(self, tool_name: str, state_key: str, input_mcp_payload: Dict[str, Any], response_data: Any): + state = self.get_state(state_key) + date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + state["previous_calls"].append({ + 'tool_name': tool_name, + 'tool_type': 'function', + 'parameters': parameters, + 'response': response_data, + 'timestamp': date_timestamp + }) + + # Keep history manageable + if len(state["previous_calls"]) > 20: + state["previous_calls"] = state["previous_calls"][-20:] + + # Update the stored state + self._states[state_key] = state + + return state + + def record_mcp_tool_call( + self, + tool_name: str, + state_key: str, + input_mcp_payload: Dict[str, Any], + response_data: Any, + ) -> Dict[str, Any]: """ - Record an MCP tool call in state. + Record an MCP tool call in the tool's state history. Args: - tool_name: Name of the MCP tool that was called. - state_key: State key to record the call under. - input_mcp_payload: Input payload sent to the MCP tool. - response_data: Response data returned from the MCP tool. + tool_name: Name of the MCP tool being called. + state_key: Key for the state (tool_name or share_state_id). + input_mcp_payload: Input payload for the MCP tool call. + response_data: Response from the MCP tool call. + + Returns: + Updated state dictionary. 
""" - if state_key not in self._states: - self._states[state_key] = {"mcp_calls": []} - - call_record = { - "tool_name": tool_name, - "input": input_mcp_payload, - "response": response_data, - "timestamp": datetime.now().isoformat() - } - self._states[state_key]["mcp_calls"].append(call_record) - - def record_api_call(self, tool_name: str, state_key: str, path: str, method: str, input_data: Dict[str, Any], response: Dict[str, Any]): + state = self.get_state(state_key) + date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + state["previous_calls"].append({ + 'tool_name': tool_name, + 'tool_type': 'mcp', + 'input_mcp_payload': input_mcp_payload, + 'response': response_data, + 'timestamp': date_timestamp + }) + + # Keep history manageable + if len(state["previous_calls"]) > 20: + state["previous_calls"] = state["previous_calls"][-20:] + + # Update the stored state + self._states[state_key] = state + + return state + + def record_api_call( + self, + tool_name: str, + state_key: str, + path: str, + method: str, + input_data: Any, + response: Any, + ) -> Dict[str, Any]: """ - Record an API call in state. + Record an API call in the tool's state history. Args: - tool_name: Name of the API tool that was called. - state_key: State key to record the call under. - path: API endpoint path that was called. - method: HTTP method used for the API call. - input_data: Input data sent to the API endpoint. - response: Response data returned from the API endpoint. + tool_name: Name of the API tool being called. + state_key: Key for the state (tool_name or share_state_id). + path: API endpoint path. + method: HTTP method. + input_data: Input data for the API call. + response: Response from the API call. + + Returns: + Updated state dictionary. """ - if state_key not in self._states: - self._states[state_key] = {"api_calls": []} - - call_record = { - "tool_name": tool_name, - "path": path, - "method": method, - "input": input_data, - "response": response, - "timestamp": datetime.now().isoformat() - } - self._states[state_key]["api_calls"].append(call_record) + state = self.get_state(state_key) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + state["previous_calls"].append({ + 'tool_name': tool_name, + 'tool_type': 'api', + 'path': path, + 'method': method, + 'input': input_data, + 'response': response, + 'timestamp': timestamp + }) + + # Keep history manageable + if len(state["previous_calls"]) > 20: + state["previous_calls"] = state["previous_calls"][-20:] + + # Update the stored state + self._states[state_key] = state + + return state + + def set_user_context(self, state_key: str, user_context: Dict[str, Any]) -> Dict[str, Any]: + """ + Set user context for a state. + + Args: + state_key: Key for the state (tool_name or share_state_id). + user_context: User context dictionary to store. + + Returns: + Updated state dictionary. + """ + state = self.get_state(state_key) + state["user_context"] = user_context + self._states[state_key] = state + return state + + def clear_state(self, state_key: str) -> None: + """ + Clear state for a specific tool or shared state group. + + Args: + state_key: Key for the state to clear. 
+ """ + if state_key in self._states: + del self._states[state_key] From 1d507f0db57f3cd828dd159f4034ffc3349a1ab5 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Fri, 23 Jan 2026 23:39:19 +0000 Subject: [PATCH 03/15] support static and mock modes --- .../tool_override_generation.py | 3 +- .../simulation/tool_simulator.py | 392 +++++++++++------- src/strands_evals/types/simulation/tool.py | 3 + 3 files changed, 235 insertions(+), 163 deletions(-) diff --git a/src/strands_evals/simulation/prompt_templates/tool_override_generation.py b/src/strands_evals/simulation/prompt_templates/tool_override_generation.py index 5583e64..f61782a 100644 --- a/src/strands_evals/simulation/prompt_templates/tool_override_generation.py +++ b/src/strands_evals/simulation/prompt_templates/tool_override_generation.py @@ -115,14 +115,13 @@ "category_rationale": "Category 1: Compute tool - pure mathematical operation with no side effects", "resource_dependencies": ["resource_name"], "consistency_requirements": ["related_tool1", "related_tool2"], - "should_simulate": false, "failure_conditions": {{ "enabled": false, "error_rate": 0.0, "error_type": "execution_error", "error_message": "Error message if failure occurs" }}, - "rationale": "Final decision rationale considering category and consistency rules" + "rationale": "Simulation configuration rationale focusing on failure conditions and error patterns" }} ] }} diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index aa5acfa..287f426 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -32,7 +32,6 @@ class ToolSimulator: tool_overrides: Dictionary mapping tool names to override configurations. system_prompt_template: Template string for system prompts. model: Provider for running inference or model identifier for Bedrock. - _active_simulators: Dictionary of active tool simulators. _registered_tools: Class-level registry for all registered tools. _state_registry: Registry for maintaining tool state across calls. 
""" @@ -69,10 +68,8 @@ def __init__( elif self._state_registry is None: self._state_registry = StateRegistry() - # Initialize tool simulators for registered tools - self._active_simulators: Dict[str, Any] = {} + # Initialize shared states from registered tools self._initialize_shared_states() - self._initialize_simulators() def _function_has_implementation(self, func: Callable) -> bool: """Check if a function has actual implementation or is just an empty stub.""" @@ -114,148 +111,15 @@ def _initialize_shared_states(self): ) logger.info(f"Initialized state for tool '{tool_name}' with key '{state_key}'") - def _initialize_simulators(self): - """Initialize simulators for all registered tools with overrides.""" - for tool_name, registered_tool in self._registered_tools.items(): - # Check if we have override config for this tool - override_config = self.tool_overrides.get(tool_name) - - if override_config: - # Create simulator with override configuration - simulator = self._create_tool_simulator(registered_tool, override_config) - self._active_simulators[tool_name] = simulator - else: - # Check if function is an empty stub or has real implementation - if registered_tool.function and self._function_has_implementation(registered_tool.function): - # Use real function for implemented functions - self._active_simulators[tool_name] = registered_tool.function - else: - # Create default simulator for empty stubs - default_config = ToolOverrideConfig() - simulator = self._create_tool_simulator(registered_tool, default_config) - self._active_simulators[tool_name] = simulator - - def _create_tool_simulator(self, registered_tool: RegisteredTool, config: ToolOverrideConfig) -> Any: - """Create a tool simulator instance based on the registered tool type.""" - # Determine state key from tool name or simulator kwargs - state_key = ( - registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) - if registered_tool.simulator_kwargs - else registered_tool.name - ) - - # Create wrapper function that handles the simulation - if registered_tool.tool_type == ToolType.FUNCTION: - return self._create_function_simulator_wrapper(registered_tool, registered_tool.tool_type, state_key) - elif registered_tool.tool_type == ToolType.MCP: - return self._create_mcp_simulator_wrapper(registered_tool, registered_tool.tool_type, state_key) - elif registered_tool.tool_type == ToolType.API: - return self._create_api_simulator_wrapper(registered_tool, registered_tool.tool_type, state_key) - else: - raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") - - def _create_function_simulator_wrapper(self, registered_tool: RegisteredTool, tool_type: ToolType, state_key: str) -> Callable: - """Create a wrapper function for function tool simulation.""" - def wrapper(*args, **kwargs): - try: - # Build parameters as expected by simulation - parameters_string = ( - json.dumps({"args": args, "kwargs": kwargs}, indent=2) - if args - else json.dumps(kwargs, indent=2) - ) - - # Get tool behavior configuration from tool overrides - tool_override_config = {} - if registered_tool.name in self.tool_overrides: - override_config = self.tool_overrides[registered_tool.name] - if override_config.failure_conditions: - tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() - else: - tool_override_config["failure_conditions"] = {"enabled": False} - - input_data = { - "tool_name": registered_tool.name, - "parameters": parameters_string, - "tool_override": tool_override_config, - } - - 
return self._simulate_tool_call(tool_type, state_key, input_data) - except Exception as e: - logger.error(f"Error in function simulation for {registered_tool.name}: {e}") - raise - - # Copy function metadata - if registered_tool.function: - wrapper.__name__ = registered_tool.function.__name__ - try: - wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore - except (ValueError, TypeError): - pass - wrapper.__doc__ = registered_tool.function.__doc__ - else: - wrapper.__name__ = registered_tool.name - - return wrapper - - def _create_mcp_simulator_wrapper(self, registered_tool: RegisteredTool, tool_type: ToolType, state_key: str) -> Callable: - """Create a wrapper function for MCP tool simulation.""" - def wrapper(**params): - try: - # Get tool behavior configuration from tool overrides - tool_override_config = {} - if registered_tool.name in self.tool_overrides: - override_config = self.tool_overrides[registered_tool.name] - if override_config.failure_conditions: - tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() - else: - tool_override_config["failure_conditions"] = {"enabled": False} - - input_data = { - "tool_name": registered_tool.name, - "input_mcp_payload": params, - "tool_override": tool_override_config, - } - - return self._simulate_tool_call(tool_type, state_key, input_data) - except Exception as e: - logger.error(f"Error in MCP simulation for {registered_tool.name}: {e}") - raise - - wrapper.__name__ = registered_tool.name - return wrapper - - def _create_api_simulator_wrapper(self, registered_tool: RegisteredTool, tool_type: ToolType, state_key: str) -> Callable: - """Create a wrapper function for API tool simulation.""" - def wrapper(**kwargs): - try: - # Get tool behavior configuration from tool overrides - tool_override_config = {} - if registered_tool.name in self.tool_overrides: - override_config = self.tool_overrides[registered_tool.name] - if override_config.failure_conditions: - tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() - else: - tool_override_config["failure_conditions"] = {"enabled": False} - - input_data = { - "tool_name": registered_tool.name, - "user_input_api_payload": kwargs, - "path": registered_tool.api_path or "", - "method": registered_tool.api_method or "GET", - "tool_override": tool_override_config, - } - - return self._simulate_tool_call(tool_type, state_key, input_data) - except Exception as e: - logger.error(f"Error in API simulation for {registered_tool.name}: {e}") - raise - - wrapper.__name__ = registered_tool.name - return wrapper def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: Dict[str, Any]) -> Any: """Simulate a tool invocation and return the response.""" + tool_name = input_data.get("tool_name", "") + registered_tool = self._registered_tools.get(tool_name) + + if not registered_tool: + return self._create_error_response("tool_not_found", f"Tool '{tool_name}' not found") + # Handle tool behavior configuration tool_override = input_data.get("tool_override", {}) @@ -276,15 +140,23 @@ def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: D "message": error_message } - # Route to appropriate handler based on tool type - if tool_type == ToolType.FUNCTION: - return self._handle_function_tool(input_data, state_key) - elif tool_type == ToolType.MCP: - return self._handle_mcp_tool(input_data, state_key) - elif tool_type == ToolType.API: - return self._handle_api_tool(input_data, 
state_key) + # Handle different simulation modes + if registered_tool.mode == "static": + return self._handle_static_mode(registered_tool, tool_type) + elif registered_tool.mode == "mock": + return self._handle_mock_mode(registered_tool, input_data, state_key, tool_type) + elif registered_tool.mode == "dynamic": + # Route to appropriate handler based on tool type + if tool_type == ToolType.FUNCTION: + return self._handle_function_tool(input_data, state_key) + elif tool_type == ToolType.MCP: + return self._handle_mcp_tool(input_data, state_key) + elif tool_type == ToolType.API: + return self._handle_api_tool(input_data, state_key) + else: + return self._create_error_response("unsupported_tool_type", f"Tool type '{tool_type}' not supported") else: - return self._create_error_response("unsupported_tool_type", f"Tool type '{tool_type}' not supported") + return self._create_error_response("unsupported_mode", f"Simulation mode '{registered_tool.mode}' not supported") def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: """Handle function tool simulation.""" @@ -485,14 +357,96 @@ def _get_error_title(self, status_code: int) -> str: } return error_titles.get(status_code, 'Error') + def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolType) -> Dict[str, Any]: + """Handle static mode simulation - returns predefined static response.""" + if registered_tool.static_response is not None: + return registered_tool.static_response + + # Default static responses for different tool types + if tool_type == ToolType.FUNCTION: + return {"status": "success", "result": f"Static response from {registered_tool.name}"} + elif tool_type == ToolType.MCP: + return { + "isError": False, + "content": [{"type": "text", "text": f"Static response from {registered_tool.name}"}] + } + elif tool_type == ToolType.API: + return {"status": 200, "data": {"message": f"Static response from {registered_tool.name}"}} + else: + return {"status": "error", "message": "Unsupported tool type for static mode"} + + def _handle_mock_mode(self, registered_tool: RegisteredTool, input_data: Dict[str, Any], state_key: str, tool_type: ToolType) -> Dict[str, Any]: + """Handle mock mode simulation - calls custom mock function.""" + if registered_tool.mock_function is not None: + try: + # Extract parameters based on tool type + if tool_type == ToolType.FUNCTION: + parameters = input_data.get("parameters", {}) + if isinstance(parameters, str): + parameters = json.loads(parameters) + + # Call mock function with extracted parameters + if "kwargs" in parameters: + result = registered_tool.mock_function(**parameters["kwargs"]) + elif "args" in parameters: + result = registered_tool.mock_function(*parameters["args"]) + else: + result = registered_tool.mock_function(**parameters) + + elif tool_type == ToolType.MCP: + input_mcp_payload = input_data.get("input_mcp_payload", {}) + result = registered_tool.mock_function(**input_mcp_payload) + + elif tool_type == ToolType.API: + user_input_api_payload = input_data.get("user_input_api_payload", {}) + result = registered_tool.mock_function(**user_input_api_payload) + + else: + return {"status": "error", "message": "Unsupported tool type for mock mode"} + + # Record the call in state registry + tool_name = registered_tool.name + if tool_type == ToolType.FUNCTION: + self._state_registry.record_function_call(tool_name, state_key, parameters, result) + elif tool_type == ToolType.MCP: + self._state_registry.record_mcp_tool_call(tool_name, state_key, 
input_mcp_payload, result) + elif tool_type == ToolType.API: + path = input_data.get("path", "") + method = input_data.get("method", "GET") + self._state_registry.record_api_call(tool_name, state_key, path, method, user_input_api_payload, result) + + return result + + except Exception as e: + logger.error(f"Error calling mock function for {registered_tool.name}: {e}") + if tool_type == ToolType.API: + return self._create_error_response("mock_error", str(e), 500) + else: + return {"status": "error", "error_type": "mock_error", "message": str(e)} + + # Fallback to static mode if no mock function provided + logger.warning(f"No mock function provided for {registered_tool.name}, falling back to static mode") + return self._handle_static_mode(registered_tool, tool_type) + @classmethod - def function_tool(cls, name: Optional[str] = None, initial_state_description: Optional[str] = None, **simulator_kwargs) -> Callable: + def function_tool( + cls, + name: Optional[str] = None, + initial_state_description: Optional[str] = None, + mode: str = "dynamic", + static_response: Optional[Dict[str, Any]] = None, + mock_function: Optional[Callable] = None, + **simulator_kwargs + ) -> Callable: """ Decorator for registering Python function tools. Args: name: Optional name for the tool. If None, uses function.__name__ initial_state_description: Optional initial state description for the tool's context + mode: Simulation mode - "dynamic", "static", or "mock" + static_response: Static response dict for static mode + mock_function: Custom callable for mock mode **simulator_kwargs: Additional simulator configuration Returns: @@ -509,6 +463,9 @@ def decorator(func: Callable) -> Callable: function=func, initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, + mode=mode, + static_response=static_response, + mock_function=mock_function, ) cls._registered_tools[tool_name] = registered_tool @@ -523,7 +480,16 @@ def decorator(func: Callable) -> Callable: return decorator @classmethod - def mcp_tool(cls, name: Optional[str] = None, schema: Optional[Dict[str, Any]] = None, initial_state_description: Optional[str] = None, **simulator_kwargs) -> Callable: + def mcp_tool( + cls, + name: Optional[str] = None, + schema: Optional[Dict[str, Any]] = None, + initial_state_description: Optional[str] = None, + mode: str = "dynamic", + static_response: Optional[Dict[str, Any]] = None, + mock_function: Optional[Callable] = None, + **simulator_kwargs + ) -> Callable: """ Decorator for registering MCP (Model Context Protocol) tools. @@ -531,6 +497,9 @@ def mcp_tool(cls, name: Optional[str] = None, schema: Optional[Dict[str, Any]] = name: Optional name for the tool. 
If None, uses function.__name__ schema: MCP tool schema dictionary initial_state_description: Optional initial state description for the tool's context + mode: Simulation mode - "dynamic", "static", or "mock" + static_response: Static response dict for static mode + mock_function: Custom callable for mock mode **simulator_kwargs: Additional simulator configuration Returns: @@ -550,6 +519,9 @@ def decorator(func: Callable) -> Callable: mcp_schema=schema, initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, + mode=mode, + static_response=static_response, + mock_function=mock_function, ) cls._registered_tools[tool_name] = registered_tool @@ -566,6 +538,9 @@ def api_tool( method: Optional[str] = None, schema: Optional[Dict[str, Any]] = None, initial_state_description: Optional[str] = None, + mode: str = "dynamic", + static_response: Optional[Dict[str, Any]] = None, + mock_function: Optional[Callable] = None, **simulator_kwargs, ) -> Callable: """ @@ -577,6 +552,9 @@ def api_tool( method: HTTP method (GET, POST, etc.) schema: API tool schema dictionary initial_state_description: Optional initial state description for the tool's context + mode: Simulation mode - "dynamic", "static", or "mock" + static_response: Static response dict for static mode + mock_function: Custom callable for mock mode **simulator_kwargs: Additional simulator configuration Returns: @@ -599,6 +577,9 @@ def decorator(func: Callable) -> Callable: api_method=method, initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, + mode=mode, + static_response=static_response, + mock_function=mock_function, ) cls._registered_tools[tool_name] = registered_tool @@ -725,9 +706,8 @@ def _generate_override_from_case(case: Case) -> Dict[str, ToolOverrideConfig]: for override in tool_overrides: tool_name = override.get("tool_name") - should_simulate = override.get("should_simulate", True) - if not tool_name or not should_simulate: + if not tool_name: continue # Add failure conditions using new schema format @@ -761,15 +741,104 @@ def _generate_override_from_case(case: Case) -> Dict[str, ToolOverrideConfig]: def get_tool(self, tool_name: str) -> Optional[Callable]: """ - Get a tool by name from the active simulators. + Get a tool by name and create a simulation wrapper. 
Args: tool_name: Name of the tool to retrieve Returns: - Tool callable if found, None otherwise + Tool callable wrapper if found, None otherwise """ - return self._active_simulators.get(tool_name) + registered_tool = self._registered_tools.get(tool_name) + if not registered_tool: + return None + + return self._create_tool_wrapper(registered_tool) + + def _create_tool_wrapper(self, registered_tool: RegisteredTool) -> Callable: + """Create a wrapper function for direct tool access.""" + def wrapper(*args, **kwargs): + # Determine state key + state_key = ( + registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) + if registered_tool.simulator_kwargs + else registered_tool.name + ) + + # Build input data based on tool type + if registered_tool.tool_type == ToolType.FUNCTION: + parameters_string = ( + json.dumps({"args": args, "kwargs": kwargs}, indent=2) + if args + else json.dumps(kwargs, indent=2) + ) + + # Get tool override configuration + tool_override_config = {} + if registered_tool.name in self.tool_overrides: + override_config = self.tool_overrides[registered_tool.name] + if override_config.failure_conditions: + tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() + else: + tool_override_config["failure_conditions"] = {"enabled": False} + + input_data = { + "tool_name": registered_tool.name, + "parameters": parameters_string, + "tool_override": tool_override_config, + } + + elif registered_tool.tool_type == ToolType.MCP: + # Get tool override configuration + tool_override_config = {} + if registered_tool.name in self.tool_overrides: + override_config = self.tool_overrides[registered_tool.name] + if override_config.failure_conditions: + tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() + else: + tool_override_config["failure_conditions"] = {"enabled": False} + + input_data = { + "tool_name": registered_tool.name, + "input_mcp_payload": kwargs, + "tool_override": tool_override_config, + } + + elif registered_tool.tool_type == ToolType.API: + # Get tool override configuration + tool_override_config = {} + if registered_tool.name in self.tool_overrides: + override_config = self.tool_overrides[registered_tool.name] + if override_config.failure_conditions: + tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() + else: + tool_override_config["failure_conditions"] = {"enabled": False} + + input_data = { + "tool_name": registered_tool.name, + "user_input_api_payload": kwargs, + "path": registered_tool.api_path or "", + "method": registered_tool.api_method or "GET", + "tool_override": tool_override_config, + } + + else: + raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") + + return self._simulate_tool_call(registered_tool.tool_type, state_key, input_data) + + # Copy function metadata + if registered_tool.function: + wrapper.__name__ = registered_tool.function.__name__ + try: + wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore + except (ValueError, TypeError): + pass + wrapper.__doc__ = registered_tool.function.__doc__ + else: + wrapper.__name__ = registered_tool.name + + return wrapper def list_tools(self) -> List[str]: """ @@ -800,7 +869,8 @@ def __getattr__(self, name: str) -> Any: Raises: AttributeError: If tool not found """ - if name in self._active_simulators: - return self._active_simulators[name] + registered_tool = self._registered_tools.get(name) + if registered_tool: + return 
self._create_tool_wrapper(registered_tool) - raise AttributeError(f"Tool '{name}' not found in active simulators") + raise AttributeError(f"Tool '{name}' not found in registered tools") diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index f008d35..ee3e0ec 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -83,6 +83,9 @@ class RegisteredTool(BaseModel): api_method: Optional[str] = Field(default=None, description="HTTP method") initial_state_description: Optional[str] = Field(default=None, description="Initial state description for the tool's context") simulator_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional simulator configuration") + mode: str = Field(default="dynamic", description="Simulation mode: dynamic, static, mock") + static_response: Optional[Dict[str, Any]] = Field(default=None, description="Static response for static mode") + mock_function: Optional[Callable] = Field(default=None, description="Mock function for mock mode", exclude=True) class Config: arbitrary_types_allowed = True From 7b82fc80438ca3d493b881557e3dbd5390751a72 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 27 Jan 2026 00:11:50 +0000 Subject: [PATCH 04/15] unit test tool simulator --- .../simulation/test_tool_simulator.py | 1180 +++++++++-------- 1 file changed, 645 insertions(+), 535 deletions(-) diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index 9e9c0c4..1846dcc 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -1,10 +1,13 @@ """Tests for ToolSimulator class.""" import json -from unittest.mock import AsyncMock, MagicMock, patch +from typing import Any, Dict +from unittest.mock import MagicMock, patch import pytest +from strands import Agent, tool +from strands_evals.case import Case from strands_evals.simulation.tool_simulator import ToolSimulator from strands_evals.types.simulation.tool import ( FailureConditions, @@ -36,13 +39,12 @@ def sample_tool_override_config(sample_failure_conditions): @pytest.fixture -def sample_scenario(): - """Fixture providing a sample scenario dictionary.""" - return { - "name": "Banking Simulation", - "description": "Test scenario for banking operations with account balance checks", - "metadata": {"domain": "finance", "complexity": "medium"}, - } +def sample_case(): + """Fixture providing a sample test case.""" + return Case( + input="I want to test tool simulation", + metadata={"task_description": "Complete tool simulation test"}, + ) @pytest.fixture @@ -50,8 +52,8 @@ def mock_model(): """Fixture providing a mock model for testing.""" mock = MagicMock() - # Mock the async generator for model.generate() - async def mock_generate(messages, system_prompt=None): + # Mock the structured_output method + def mock_structured_output(output_type, messages, system_prompt=None): # Simulate streaming response yield { "contentBlockDelta": { @@ -59,7 +61,7 @@ async def mock_generate(messages, system_prompt=None): } } - mock.generate = mock_generate + mock.structured_output = mock_structured_output return mock @@ -71,559 +73,667 @@ def clear_registry(): ToolSimulator.clear_registry() -class TestToolSimulatorInitialization: - """Test cases for ToolSimulator initialization.""" +def test_tool_simulator_init(sample_tool_override_config): + """Test ToolSimulator initialization with all 
parameters.""" + custom_registry = StateRegistry() + tool_overrides = {"test_tool": sample_tool_override_config} + template = "You are a helpful assistant simulating tools." + + simulator = ToolSimulator( + tool_overrides=tool_overrides, + state_registry=custom_registry, + system_prompt_template=template, + model=None, + ) + + assert simulator.tool_overrides == tool_overrides + assert simulator._state_registry is custom_registry + assert simulator.system_prompt_template == template + assert simulator.model is not None + + +def test_function_tool_decorator_registration(): + """Test function tool decorator registration.""" + @ToolSimulator.function_tool() + def test_function(x: int, y: str) -> dict: + """A sample function for testing.""" + return {"x": x, "y": y} + + assert "test_function" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["test_function"] + assert registered_tool.name == "test_function" + assert registered_tool.tool_type == ToolType.FUNCTION + assert registered_tool.function == test_function + + +def test_mcp_tool_decorator_registration(): + """Test MCP tool decorator registration.""" + schema = { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"] + } - def test_init_with_defaults(self): - """Test ToolSimulator initialization with default parameters.""" - simulator = ToolSimulator() - - assert simulator.tool_overrides == {} - assert simulator.simulator_config == {} - assert simulator.system_prompt_template is None - assert simulator.model is not None - assert simulator._state_registry is not None - assert simulator._active_simulators == {} - - def test_init_with_model_string(self): - """Test ToolSimulator initialization with model string.""" - model_id = "anthropic.claude-3-haiku-20240307-v1:0" - simulator = ToolSimulator(model=model_id) - - assert simulator.model is not None - # The model should be configured with the provided model_id + @ToolSimulator.mcp_tool("test_mcp", schema=schema) + def sample_mcp_tool(**params): + """A sample MCP tool for testing.""" + return {"content": [{"type": "text", "text": f"Result: {params}"}]} + + assert "test_mcp" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["test_mcp"] + assert registered_tool.name == "test_mcp" + assert registered_tool.tool_type == ToolType.MCP + assert registered_tool.mcp_schema == schema + + +def test_api_tool_decorator_registration(): + """Test API tool decorator registration.""" + @ToolSimulator.api_tool("test_api", path="/test", method="POST") + def sample_api_tool(**kwargs): + """A sample API tool for testing.""" + return {"status": 200, "data": kwargs} + + assert "test_api" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["test_api"] + assert registered_tool.name == "test_api" + assert registered_tool.tool_type == ToolType.API + assert registered_tool.api_path == "/test" + assert registered_tool.api_method == "POST" + + +def test_function_tool_simulation(mock_model): + """Test function tool simulation.""" + # Register and create simulator with mock model + @ToolSimulator.function_tool("test_function") + def test_func(message: str) -> dict: + """Test function that should be simulated.""" + pass + + # Mock the structured_output method to return expected JSON + def mock_structured_output(output_type, messages, system_prompt=None): + yield {"output": '{"result": "simulated response"}'} + + mock_model.structured_output = mock_structured_output + simulator = 
ToolSimulator(model=mock_model) + + # Execute simulated function + result = simulator.test_function("Hello, world!") + + assert result == {"result": "simulated response"} - def test_init_with_model_object(self, mock_model): - """Test ToolSimulator initialization with model object.""" - simulator = ToolSimulator(model=mock_model) - - assert simulator.model is mock_model - def test_init_with_tool_overrides(self, sample_tool_override_config): - """Test ToolSimulator initialization with tool overrides.""" - tool_overrides = {"test_tool": sample_tool_override_config} - simulator = ToolSimulator(tool_overrides=tool_overrides) - - assert simulator.tool_overrides == tool_overrides +def test_mcp_tool_simulation(mock_model): + """Test MCP tool simulation.""" + # Register and create simulator with mock model + schema = {"type": "object", "properties": {"param": {"type": "string"}}} + @ToolSimulator.mcp_tool("test_mcp", schema=schema) + def test_mcp(**params): + """Test MCP tool that should be simulated.""" + pass + + # Mock the structured_output method to return expected JSON + def mock_structured_output(output_type, messages, system_prompt=None): + yield {"output": '{"content": [{"type": "text", "text": "MCP response"}]}'} + + mock_model.structured_output = mock_structured_output + simulator = ToolSimulator(model=mock_model) + + # Execute simulated MCP tool + result = simulator.test_mcp(param="test_value") + + assert result == {"content": [{"type": "text", "text": "MCP response"}]} - def test_init_with_custom_state_registry(self): - """Test ToolSimulator initialization with custom state registry.""" - custom_registry = StateRegistry() - simulator = ToolSimulator(state_registry=custom_registry) - - assert simulator._state_registry is custom_registry - def test_init_with_system_prompt_template(self): - """Test ToolSimulator initialization with system prompt template.""" - template = "You are a helpful assistant simulating tools." 
- simulator = ToolSimulator(system_prompt_template=template) - - assert simulator.system_prompt_template == template - - -class TestToolDecorators: - """Test cases for tool decorator registration.""" - - def test_function_tool_decorator(self): - """Test function tool decorator registration.""" - @ToolSimulator.function_tool("test_function") - def sample_function(x: int, y: str) -> dict: - """A sample function for testing.""" - return {"x": x, "y": y} - - assert "test_function" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["test_function"] - assert registered_tool.name == "test_function" - assert registered_tool.tool_type == ToolType.FUNCTION - assert registered_tool.function == sample_function - - def test_function_tool_decorator_without_name(self): - """Test function tool decorator uses function name when no name provided.""" - @ToolSimulator.function_tool() - def my_test_function(): - """Test function.""" - pass - - assert "my_test_function" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["my_test_function"] - assert registered_tool.name == "my_test_function" - - def test_mcp_tool_decorator(self): - """Test MCP tool decorator registration.""" - schema = { - "type": "object", - "properties": {"param": {"type": "string"}}, - "required": ["param"] - } +def test_api_tool_simulation(mock_model): + """Test API tool simulation.""" + # Register and create simulator with mock model + @ToolSimulator.api_tool("test_api", path="/test", method="POST") + def test_api(**kwargs): + """Test API tool that should be simulated.""" + pass + + # Mock the structured_output method to return expected JSON + def mock_structured_output(output_type, messages, system_prompt=None): + yield {"output": '{"status": 200, "data": {"key": "value"}}'} + + mock_model.structured_output = mock_structured_output + simulator = ToolSimulator(model=mock_model) + + # Execute simulated API tool + result = simulator.test_api(key="value") + + assert result == {"status": 200, "data": {"key": "value"}} - @ToolSimulator.mcp_tool("test_mcp", schema=schema) - def sample_mcp_tool(**params): - """A sample MCP tool for testing.""" - return {"content": [{"type": "text", "text": f"Result: {params}"}]} - - assert "test_mcp" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["test_mcp"] - assert registered_tool.name == "test_mcp" - assert registered_tool.tool_type == ToolType.MCP - assert registered_tool.mcp_schema == schema - - def test_mcp_tool_decorator_requires_schema(self): - """Test MCP tool decorator requires schema parameter.""" - with pytest.raises(ValueError, match="MCP schema is required"): - @ToolSimulator.mcp_tool("test_mcp") - def sample_mcp_tool(**params): - pass - - def test_api_tool_decorator(self): - """Test API tool decorator registration.""" - @ToolSimulator.api_tool("test_api", path="/test", method="POST") - def sample_api_tool(**kwargs): - """A sample API tool for testing.""" - return {"status": 200, "data": kwargs} - - assert "test_api" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["test_api"] - assert registered_tool.name == "test_api" - assert registered_tool.tool_type == ToolType.API - assert registered_tool.api_path == "/test" - assert registered_tool.api_method == "POST" - - def test_api_tool_decorator_requires_path(self): - """Test API tool decorator requires path parameter.""" - with pytest.raises(ValueError, match="API path is required"): - 
@ToolSimulator.api_tool("test_api", method="GET") - def sample_api_tool(**kwargs): - pass - - def test_api_tool_decorator_requires_method(self): - """Test API tool decorator requires method parameter.""" - with pytest.raises(ValueError, match="HTTP method is required"): - @ToolSimulator.api_tool("test_api", path="/test") - def sample_api_tool(**kwargs): - pass - - def test_function_tool_with_simulator_kwargs(self): - """Test function tool decorator with simulator kwargs.""" - @ToolSimulator.function_tool("test_function", share_state_id="shared_state") - def sample_function(): - pass - - registered_tool = ToolSimulator._registered_tools["test_function"] - assert registered_tool.simulator_kwargs == {"share_state_id": "shared_state"} - - -class TestToolSimulation: - """Test cases for tool simulation functionality.""" - - @patch('strands._async.run_async') - def test_function_tool_simulation(self, mock_run_async): - """Test function tool simulation.""" - # Setup mock response - mock_run_async.return_value = '{"result": "simulated response"}' - - # Register and create simulator - @ToolSimulator.function_tool("test_function") - def test_func(message: str) -> dict: - """Test function that should be simulated.""" - pass - - simulator = ToolSimulator() - - # Execute simulated function - result = simulator.test_function("Hello, world!") - - assert result == {"result": "simulated response"} - mock_run_async.assert_called_once() - - @patch('strands._async.run_async') - def test_mcp_tool_simulation(self, mock_run_async): - """Test MCP tool simulation.""" - # Setup mock response - mock_run_async.return_value = '{"content": [{"type": "text", "text": "MCP response"}]}' - - # Register and create simulator - schema = {"type": "object", "properties": {"param": {"type": "string"}}} - @ToolSimulator.mcp_tool("test_mcp", schema=schema) - def test_mcp(**params): - """Test MCP tool that should be simulated.""" - pass - - simulator = ToolSimulator() - - # Execute simulated MCP tool - result = simulator.test_mcp(param="test_value") - - assert result == {"content": [{"type": "text", "text": "MCP response"}]} - mock_run_async.assert_called_once() - - @patch('strands._async.run_async') - def test_api_tool_simulation(self, mock_run_async): - """Test API tool simulation.""" - # Setup mock response - mock_run_async.return_value = '{"status": 200, "data": {"key": "value"}}' - - # Register and create simulator - @ToolSimulator.api_tool("test_api", path="/test", method="POST") - def test_api(**kwargs): - """Test API tool that should be simulated.""" - pass - - simulator = ToolSimulator() - - # Execute simulated API tool - result = simulator.test_api(key="value") - - assert result == {"status": 200, "data": {"key": "value"}} - mock_run_async.assert_called_once() - - def test_implemented_function_uses_real_implementation(self): - """Test that functions with real implementations are not simulated.""" - @ToolSimulator.function_tool("implemented_function") - def real_function(x: int) -> dict: - """A function with real implementation.""" - return {"doubled": x * 2} - - simulator = ToolSimulator() - result = simulator.implemented_function(5) - - assert result == {"doubled": 10} - - def test_failure_conditions_trigger_error(self): - """Test that failure conditions trigger errors as expected.""" - # Register tool - @ToolSimulator.function_tool("failing_function") - def test_func(): - pass - - # Create failure conditions with 100% error rate - failure_conditions = FailureConditions( - enabled=True, - error_rate=1.0, - 
error_type="timeout_error", - error_message="Simulated timeout" - ) - tool_overrides = { - "failing_function": ToolOverrideConfig(failure_conditions=failure_conditions) - } - - simulator = ToolSimulator(tool_overrides=tool_overrides) - - # Function should return error due to failure conditions - result = simulator.failing_function() - - assert result["status"] == "error" - assert result["error_type"] == "timeout_error" - assert result["message"] == "Simulated timeout" +def test_failure_conditions_trigger_error(): + """Test that failure conditions trigger errors as expected.""" + # Register tool + @ToolSimulator.function_tool("failing_function") + def test_func(): + pass + + # Create failure conditions with 100% error rate + failure_conditions = FailureConditions( + enabled=True, + error_rate=1.0, + error_type="timeout_error", + error_message="Simulated timeout" + ) + tool_overrides = { + "failing_function": ToolOverrideConfig(failure_conditions=failure_conditions) + } + + simulator = ToolSimulator(tool_overrides=tool_overrides) + + # Function should return error due to failure conditions + result = simulator.failing_function() + + assert result["status"] == "error" + assert result["error_type"] == "timeout_error" + assert result["message"] == "Simulated timeout" -class TestToolRetrieval: - """Test cases for tool retrieval and listing.""" - def test_list_tools(self): - """Test listing registered tools.""" - @ToolSimulator.function_tool("func1") - def func1(): - pass +def test_list_tools(): + """Test listing registered tools.""" + @ToolSimulator.function_tool("func1") + def func1(): + pass - @ToolSimulator.function_tool("func2") - def func2(): - pass + @ToolSimulator.function_tool("func2") + def func2(): + pass - simulator = ToolSimulator() - tools = simulator.list_tools() - - assert set(tools) == {"func1", "func2"} + simulator = ToolSimulator() + tools = simulator.list_tools() + + assert set(tools) == {"func1", "func2"} - def test_get_tool(self): - """Test getting tool by name.""" - @ToolSimulator.function_tool("test_function") - def test_func(): - return {"test": "result"} - - simulator = ToolSimulator() - tool = simulator.get_tool("test_function") - - assert tool is not None - assert callable(tool) - def test_get_nonexistent_tool(self): - """Test getting non-existent tool returns None.""" - simulator = ToolSimulator() - tool = simulator.get_tool("nonexistent_tool") - - assert tool is None +@patch("strands_evals.simulation.tool_simulator.Agent") +def test_from_case_for_tool_simulator(mock_agent_class, sample_case): + """Test factory method creates simulator from case.""" + # Register a test tool first + @ToolSimulator.function_tool("account_balance_check") + def check_balance(account_id: str) -> dict: + """Check account balance.""" + pass + + # Mock agent response for override generation + mock_agent = MagicMock() + mock_override_response = { + "tool_overrides": [{ + "tool_name": "account_balance_check", + "should_simulate": True, + "failure_conditions": { + "enabled": False, + "error_rate": 0.0 + } + }] + } + mock_agent.return_value = json.dumps(mock_override_response) + mock_agent_class.return_value = mock_agent + + simulator = ToolSimulator.from_case_for_tool_simulator( + case=sample_case, + system_prompt_template="Test template", + model="test-model" + ) + + assert simulator is not None + assert simulator.system_prompt_template == "Test template" + + +@patch("strands_evals.simulation.tool_simulator.Agent") +def test_generate_override_from_case(mock_agent_class, sample_case): + """Test 
override generation from case.""" + # Register test tools + @ToolSimulator.function_tool("test_function") + def test_func(param: str) -> dict: + """Test function.""" + pass + + # Mock agent response + mock_agent = MagicMock() + mock_response = { + "tool_overrides": [{ + "tool_name": "test_function", + "should_simulate": True, + "failure_conditions": { + "enabled": True, + "error_rate": 0.1, + "error_type": "network_error", + "error_message": "Network timeout" + } + }] + } + mock_agent.return_value = json.dumps(mock_response) + mock_agent_class.return_value = mock_agent + + overrides = ToolSimulator._generate_override_from_case(sample_case) + + assert "test_function" in overrides + override = overrides["test_function"] + assert override.failure_conditions.enabled is True + assert override.failure_conditions.error_rate == 0.1 + assert override.failure_conditions.error_type == "network_error" - def test_tool_attribute_access(self): - """Test accessing tools as attributes.""" - @ToolSimulator.function_tool("test_function") - def test_func(): - return {"test": "result"} - - simulator = ToolSimulator() - - # Should be able to access as attribute - assert hasattr(simulator, "test_function") - tool = simulator.test_function - assert callable(tool) - - def test_nonexistent_tool_attribute_raises_error(self): - """Test accessing non-existent tool as attribute raises AttributeError.""" - simulator = ToolSimulator() - - with pytest.raises(AttributeError, match="Tool 'nonexistent' not found"): - _ = simulator.nonexistent +def test_shared_state_registry(mock_model): + """Test that function, MCP, and API tools can share the same state registry.""" + shared_state_id = "shared_banking_state" + initial_state = "Initial banking system state with account balances" + + # Register three different tools that share the same state + @ToolSimulator.function_tool( + "check_balance", + initial_state_description=initial_state, + share_state_id=shared_state_id + ) + def check_balance(account_id: str): + """Check account balance.""" + pass + + @ToolSimulator.mcp_tool( + "transfer_funds", + schema={"type": "object", "properties": {"from_account": {"type": "string"}, "to_account": {"type": "string"}}}, + initial_state_description=initial_state, + share_state_id=shared_state_id + ) + def transfer_funds(**params): + """Transfer funds between accounts.""" + pass + + @ToolSimulator.api_tool( + "get_transactions", + path="/transactions", + method="GET", + initial_state_description=initial_state, + share_state_id=shared_state_id + ) + def get_transactions(**kwargs): + """Get transaction history.""" + pass + + # Mock responses for each tool type based on call count + call_count = 0 + def mock_structured_output(output_type, messages, system_prompt=None): + nonlocal call_count + call_count += 1 + if call_count == 1: # First call (check_balance) + yield {"output": '{"balance": 1000, "currency": "USD"}'} + elif call_count == 2: # Second call (transfer_funds) + yield {"output": '{"content": [{"type": "text", "text": "Transfer completed"}]}'} + elif call_count == 3: # Third call (get_transactions) + yield {"output": '{"status": 200, "data": {"transactions": []}}'} + + mock_model.structured_output = mock_structured_output + simulator = ToolSimulator(model=mock_model) + + # Execute each tool in order + balance_result = simulator.check_balance("12345") + transfer_result = simulator.transfer_funds(from_account="12345", to_account="67890") + transactions_result = simulator.get_transactions(account_id="12345") + + # Verify results + assert 
balance_result == {"balance": 1000, "currency": "USD"} + assert transfer_result == {"content": [{"type": "text", "text": "Transfer completed"}]} + assert transactions_result == {"status": 200, "data": {"transactions": []}} + + # Verify all tools accessed the same shared state + shared_state = simulator._state_registry.get_state(shared_state_id) + assert "initial_state" in shared_state + assert shared_state["initial_state"] == initial_state + assert "previous_calls" in shared_state + assert len(shared_state["previous_calls"]) == 3 + + # Check that all three tool calls are recorded in the shared state + tool_names = [call["tool_name"] for call in shared_state["previous_calls"]] + assert "check_balance" in tool_names + assert "transfer_funds" in tool_names + assert "get_transactions" in tool_names + + # Verify each tool type recorded its specific data correctly + function_call = next(call for call in shared_state["previous_calls"] if call["tool_name"] == "check_balance") + assert "parameters" in function_call + + mcp_call = next(call for call in shared_state["previous_calls"] if call["tool_name"] == "transfer_funds") + assert "input_mcp_payload" in mcp_call + + api_call = next(call for call in shared_state["previous_calls"] if call["tool_name"] == "get_transactions") + assert "path" in api_call + assert "method" in api_call -class TestFactoryMethods: - """Test cases for factory methods.""" - @patch('strands._async.run_async') - def test_from_scenario_for_tool_simulator(self, mock_run_async, sample_scenario): - """Test factory method creates simulator from scenario.""" - # Register a test tool first - @ToolSimulator.function_tool("account_balance_check") - def check_balance(account_id: str) -> dict: - """Check account balance.""" - pass - - # Mock LLM response for override generation - mock_override_response = { - "tool_overrides": [{ - "tool_name": "account_balance_check", - "should_simulate": True, - "failure_conditions": { - "enabled": False, - "error_rate": 0.0 - } - }] - } - mock_run_async.return_value = json.dumps(mock_override_response) - - simulator = ToolSimulator.from_scenario_for_tool_simulator( - scenario_dict=sample_scenario, - system_prompt_template="Test template", - model="test-model" - ) - - assert simulator is not None - assert simulator.system_prompt_template == "Test template" - mock_run_async.assert_called_once() - - @patch('strands._async.run_async') - def test_generate_override_from_scenario(self, mock_run_async, sample_scenario): - """Test override generation from scenario.""" - # Register test tools - @ToolSimulator.function_tool("test_function") - def test_func(param: str) -> dict: - """Test function.""" - pass - - # Mock LLM response - mock_response = { - "tool_overrides": [{ - "tool_name": "test_function", - "should_simulate": True, - "failure_conditions": { - "enabled": True, - "error_rate": 0.1, - "error_type": "network_error", - "error_message": "Network timeout" - } - }] - } - mock_run_async.return_value = json.dumps(mock_response) - - overrides = ToolSimulator._generate_override_from_scenario(sample_scenario) - - assert "test_function" in overrides - override = overrides["test_function"] - assert override.failure_conditions.enabled is True - assert override.failure_conditions.error_rate == 0.1 - assert override.failure_conditions.error_type == "network_error" - mock_run_async.assert_called_once() - - def test_generate_override_with_no_tools(self, sample_scenario): - """Test override generation with no registered tools.""" - # Clear registry to ensure no tools - 
ToolSimulator.clear_registry() - - overrides = ToolSimulator._generate_override_from_scenario(sample_scenario) - - assert overrides == {} - - @patch('strands._async.run_async') - def test_generate_override_handles_llm_error(self, mock_run_async, sample_scenario): - """Test override generation handles LLM errors gracefully.""" - # Register a test tool - @ToolSimulator.function_tool("test_function") - def test_func(): - pass - - # Mock LLM to return invalid JSON - mock_run_async.return_value = "invalid json response" - - overrides = ToolSimulator._generate_override_from_scenario(sample_scenario) - - # Should return empty dict on error - assert overrides == {} +def test_record_function_call(): + """Test recording function call in state registry.""" + registry = StateRegistry() + + registry.record_function_call( + tool_name="test_tool", + state_key="test_state", + parameters={"param": "value"}, + response_data={"result": "success"} + ) + + state = registry.get_state("test_state") + assert "previous_calls" in state + assert len(state["previous_calls"]) == 1 + call = state["previous_calls"][0] + assert call["tool_name"] == "test_tool" + assert call["parameters"] == {"param": "value"} + assert call["response"] == {"result": "success"} + + +def test_record_mcp_tool_call(): + """Test recording MCP tool call in state registry.""" + registry = StateRegistry() + + registry.record_mcp_tool_call( + tool_name="mcp_tool", + state_key="mcp_state", + input_mcp_payload={"input": "data"}, + response_data={"content": [{"type": "text", "text": "result"}]} + ) + + state = registry.get_state("mcp_state") + assert "previous_calls" in state + assert len(state["previous_calls"]) == 1 + call = state["previous_calls"][0] + assert call["tool_name"] == "mcp_tool" + assert call["input_mcp_payload"] == {"input": "data"} -class TestStateRegistry: - """Test cases for StateRegistry functionality.""" +def test_record_api_call(): + """Test recording API call in state registry.""" + registry = StateRegistry() + + registry.record_api_call( + tool_name="api_tool", + state_key="api_state", + path="/test", + method="POST", + input_data={"data": "test"}, + response={"status": 200} + ) + + state = registry.get_state("api_state") + assert "previous_calls" in state + assert len(state["previous_calls"]) == 1 + call = state["previous_calls"][0] + assert call["tool_name"] == "api_tool" + assert call["path"] == "/test" + assert call["method"] == "POST" + assert call["input"] == {"data": "test"} + + +def test_parse_llm_response_valid_json(): + """Test parsing valid JSON response.""" + simulator = ToolSimulator() + + response = simulator._parse_llm_response('{"key": "value"}') + + assert response == {"key": "value"} - def test_state_registry_creation(self): - """Test StateRegistry is created properly.""" - registry = StateRegistry() - - assert registry is not None - assert registry._states == {} - def test_record_function_call(self): - """Test recording function call in state registry.""" - registry = StateRegistry() - - registry.record_function_call( - tool_name="test_tool", - state_key="test_state", - parameters={"param": "value"}, - response_data={"result": "success"} - ) - - state = registry.get_state("test_state") - assert "function_calls" in state - assert len(state["function_calls"]) == 1 - call = state["function_calls"][0] - assert call["tool_name"] == "test_tool" - assert call["parameters"] == {"param": "value"} - assert call["response"] == {"result": "success"} - - def test_record_mcp_tool_call(self): - """Test recording MCP tool 
call in state registry.""" - registry = StateRegistry() - - registry.record_mcp_tool_call( - tool_name="mcp_tool", - state_key="mcp_state", - input_mcp_payload={"input": "data"}, - response_data={"content": [{"type": "text", "text": "result"}]} - ) - - state = registry.get_state("mcp_state") - assert "mcp_calls" in state - assert len(state["mcp_calls"]) == 1 - call = state["mcp_calls"][0] - assert call["tool_name"] == "mcp_tool" - assert call["input"] == {"input": "data"} - - def test_record_api_call(self): - """Test recording API call in state registry.""" - registry = StateRegistry() - - registry.record_api_call( - tool_name="api_tool", - state_key="api_state", - path="/test", - method="POST", - input_data={"data": "test"}, - response={"status": 200} - ) - - state = registry.get_state("api_state") - assert "api_calls" in state - assert len(state["api_calls"]) == 1 - call = state["api_calls"][0] - assert call["tool_name"] == "api_tool" - assert call["path"] == "/test" - assert call["method"] == "POST" +def test_parse_llm_response_json_in_code_block(): + """Test parsing JSON from code blocks.""" + simulator = ToolSimulator() + + llm_text = '```json\n{"key": "value"}\n```' + response = simulator._parse_llm_response(llm_text) + + assert response == {"key": "value"} -class TestErrorHandling: - """Test cases for error handling.""" +def test_parse_llm_response_invalid_json_fallback(): + """Test fallback for invalid JSON.""" + simulator = ToolSimulator() + + response = simulator._parse_llm_response("This is not JSON") + + assert response == {"result": "This is not JSON"} - def test_parse_llm_response_valid_json(self): - """Test parsing valid JSON response.""" - simulator = ToolSimulator() - - response = simulator._parse_llm_response('{"key": "value"}') - - assert response == {"key": "value"} - def test_parse_llm_response_json_in_code_block(self): - """Test parsing JSON from code blocks.""" - simulator = ToolSimulator() - - llm_text = '```json\n{"key": "value"}\n```' - response = simulator._parse_llm_response(llm_text) - - assert response == {"key": "value"} +def test_create_error_response(): + """Test error response creation.""" + simulator = ToolSimulator() + + error = simulator._create_error_response("test_error", "Test message", 400) + + assert error["status"] == 400 + assert error["error"]["type"] == "test_error" + assert error["error"]["detail"] == "Test message" + assert error["error"]["title"] == "Bad Request" - def test_parse_llm_response_invalid_json_fallback(self): - """Test fallback for invalid JSON.""" - simulator = ToolSimulator() - - response = simulator._parse_llm_response("This is not JSON") - - assert response == {"result": "This is not JSON"} - def test_create_error_response(self): - """Test error response creation.""" - simulator = ToolSimulator() - - error = simulator._create_error_response("test_error", "Test message", 400) - - assert error["status"] == 400 - assert error["error"]["type"] == "test_error" - assert error["error"]["detail"] == "Test message" - assert error["error"]["title"] == "Bad Request" - - def test_get_error_title(self): - """Test error title mapping.""" - simulator = ToolSimulator() - - assert simulator._get_error_title(400) == "Bad Request" - assert simulator._get_error_title(404) == "Not Found" - assert simulator._get_error_title(500) == "Internal Server Error" - assert simulator._get_error_title(999) == "Error" # Unknown status code +def test_get_error_title(): + """Test error title mapping.""" + simulator = ToolSimulator() + + assert 
simulator._get_error_title(400) == "Bad Request" + assert simulator._get_error_title(404) == "Not Found" + assert simulator._get_error_title(500) == "Internal Server Error" + assert simulator._get_error_title(999) == "Error" # Unknown status code -class TestRegistryManagement: - """Test cases for registry management.""" +def test_clear_registry(): + """Test clearing tool registry.""" + @ToolSimulator.function_tool("test_function") + def test_func(): + pass + + assert len(ToolSimulator._registered_tools) == 1 + + ToolSimulator.clear_registry() + + assert len(ToolSimulator._registered_tools) == 0 + assert ToolSimulator._state_registry is None - def test_clear_registry(self): - """Test clearing tool registry.""" - @ToolSimulator.function_tool("test_function") - def test_func(): - pass - - assert len(ToolSimulator._registered_tools) == 1 - - ToolSimulator.clear_registry() - - assert len(ToolSimulator._registered_tools) == 0 - assert ToolSimulator._state_registry is None - def test_function_has_implementation_detection(self): - """Test detection of function implementation.""" - simulator = ToolSimulator() - - # Empty function should be detected as not implemented - def empty_func(): - pass - - def implemented_func(): - return {"result": "value"} - - assert not simulator._function_has_implementation(empty_func) - assert simulator._function_has_implementation(implemented_func) +def test_function_has_implementation_detection(): + """Test detection of function implementation.""" + simulator = ToolSimulator() + + # Empty function should be detected as not implemented + def empty_func(): + pass + + def implemented_func(): + return {"result": "value"} + + assert not simulator._function_has_implementation(empty_func) + assert simulator._function_has_implementation(implemented_func) - def test_function_has_implementation_error_handling(self): - """Test function implementation detection handles errors.""" - simulator = ToolSimulator() - - # Create a mock function that will cause dis.get_instructions to fail - mock_func = MagicMock() - mock_func.__code__ = None - - # Should assume implemented on error - result = simulator._function_has_implementation(mock_func) - assert result is True + +def test_function_tool_decorator_stacking_with_strands_tool(): + """Test function tool decorator stacking with Strands @tool decorator.""" + # Mock function that handles parameters with **kwargs + def mock_function(**kwargs): + input_value = kwargs.get("input_value", "") + return {"result": f"processed {input_value}"} + + # Define tool with stacked decorators + @tool + @ToolSimulator.function_tool("stacked_function_tool", mode="mock", + mock_function=mock_function) + def stacked_function_tool(input_value: str) -> Dict[str, Any]: + """Test function tool with stacked decorators. 
+ + Args: + input_value: Input parameter for processing + """ + pass + + # Create simulator + simulator = ToolSimulator() + + # Test that the tool is callable and returns expected result + result = simulator.stacked_function_tool(input_value="test_input") + assert result == {"result": "processed test_input"} + + # Verify the tool is registered in ToolSimulator + assert "stacked_function_tool" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["stacked_function_tool"] + assert registered_tool.tool_type == ToolType.FUNCTION + assert registered_tool.mode == "mock" + assert registered_tool.mock_function == mock_function + + # Validate Strands tool creation + assert stacked_function_tool.tool_spec is not None + spec = stacked_function_tool.tool_spec + + # Check basic spec properties + assert spec["name"] == "stacked_function_tool" + assert spec["description"] == "Test function tool with stacked decorators." + + # Check input schema + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert set(schema["required"]) == {"input_value"} + + # Check parameter properties + assert schema["properties"]["input_value"]["type"] == "string" + assert schema["properties"]["input_value"]["description"] == "Input parameter for processing" + + # Make sure these are set properly + assert stacked_function_tool.__wrapped__ is not None + assert stacked_function_tool.__doc__ == stacked_function_tool._tool_func.__doc__ + + +def test_mcp_tool_decorator_stacking_with_strands_tool(): + """Test MCP tool decorator stacking with Strands @tool decorator.""" + # Mock function for MCP tool + def mock_mcp_processor(param1, param2=42): + return { + "content": [ + {"type": "text", "text": f"MCP processed: {param1} with value {param2}"} + ], + "isError": False + } + + schema = { + "type": "object", + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer", "default": 42} + }, + "required": ["param1"] + } + + # Define tool with stacked decorators + @tool + @ToolSimulator.mcp_tool("stacked_mcp_tool", schema=schema, mode="mock", + mock_function=mock_mcp_processor) + def stacked_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: + """Test MCP tool with stacked decorators. + + Args: + param1: First parameter for MCP processing + param2: Second parameter with default value + """ + pass + + # Create simulator + simulator = ToolSimulator() + + # Test that the tool is callable and returns expected result + result = simulator.stacked_mcp_tool(param1="test", param2=100) + expected = { + "content": [{"type": "text", "text": "MCP processed: test with value 100"}], + "isError": False + } + assert result == expected + + # Verify the tool is registered in ToolSimulator + assert "stacked_mcp_tool" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["stacked_mcp_tool"] + assert registered_tool.tool_type == ToolType.MCP + assert registered_tool.mode == "mock" + assert registered_tool.mock_function == mock_mcp_processor + + # Validate Strands tool creation + assert stacked_mcp_tool.tool_spec is not None + spec = stacked_mcp_tool.tool_spec + + # Check basic spec properties + assert spec["name"] == "stacked_mcp_tool" + assert spec["description"] == "Test MCP tool with stacked decorators." 
+ + # Check input schema + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert set(schema["required"]) == {"param1"} + + # Check parameter properties + assert schema["properties"]["param1"]["type"] == "string" + assert schema["properties"]["param2"]["type"] == "integer" + assert schema["properties"]["param1"]["description"] == "First parameter for MCP processing" + assert schema["properties"]["param2"]["description"] == "Second parameter with default value" + + # Make sure these are set properly + assert stacked_mcp_tool.__wrapped__ is not None + assert stacked_mcp_tool.__doc__ == stacked_mcp_tool._tool_func.__doc__ + + +def test_api_tool_decorator_stacking_with_strands_tool(): + """Test API tool decorator stacking with Strands @tool decorator.""" + # Static response for API tool + static_response = { + "status": 200, + "data": { + "message": "API tool working", + "timestamp": "2024-01-01T12:00:00Z", + "endpoint": "/test/api" + } + } + + # Define tool with stacked decorators + @tool + @ToolSimulator.api_tool("stacked_api_tool", path="/test/api", method="GET", + mode="static", static_response=static_response) + def stacked_api_tool(query: str = "") -> Dict[str, Any]: + """Test API tool with stacked decorators. + + Args: + query: Query parameter for API call + """ + pass + + # Create simulator + simulator = ToolSimulator() + + # Test that the tool is callable and returns expected result + result = simulator.stacked_api_tool(query="test_query") + assert result == static_response + + # Verify the tool is registered in ToolSimulator + assert "stacked_api_tool" in ToolSimulator._registered_tools + registered_tool = ToolSimulator._registered_tools["stacked_api_tool"] + assert registered_tool.tool_type == ToolType.API + assert registered_tool.mode == "static" + assert registered_tool.api_path == "/test/api" + assert registered_tool.api_method == "GET" + assert registered_tool.static_response == static_response + + # Validate Strands tool creation + assert stacked_api_tool.tool_spec is not None + spec = stacked_api_tool.tool_spec + + # Check basic spec properties + assert spec["name"] == "stacked_api_tool" + assert spec["description"] == "Test API tool with stacked decorators." 
+ + # Check input schema + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + # query parameter is optional, so required list may be empty or missing + required_fields = set(schema.get("required", [])) + assert required_fields == set() + + # Check parameter properties + assert schema["properties"]["query"]["type"] == "string" + assert schema["properties"]["query"]["description"] == "Query parameter for API call" + + # Make sure these are set properly + assert stacked_api_tool.__wrapped__ is not None + assert stacked_api_tool.__doc__ == stacked_api_tool._tool_func.__doc__ From cd0365a34120fa730948a4ec81666f068b08b63a Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 27 Jan 2026 20:15:38 +0000 Subject: [PATCH 05/15] remove override and simplify pr --- .../simulation/prompt_templates/__init__.py | 4 +- .../tool_override_generation.py | 130 ----- .../tool_response_generation.py | 20 - .../simulation/tool_simulator.py | 477 +++++++++--------- src/strands_evals/types/simulation/tool.py | 251 +-------- .../simulation/test_tool_simulator.py | 152 +----- 6 files changed, 245 insertions(+), 789 deletions(-) delete mode 100644 src/strands_evals/simulation/prompt_templates/tool_override_generation.py diff --git a/src/strands_evals/simulation/prompt_templates/__init__.py b/src/strands_evals/simulation/prompt_templates/__init__.py index ab39053..e65aa51 100644 --- a/src/strands_evals/simulation/prompt_templates/__init__.py +++ b/src/strands_evals/simulation/prompt_templates/__init__.py @@ -3,7 +3,6 @@ from .actor_profile_extraction import ACTOR_PROFILE_PROMPT_TEMPLATE from .actor_system_prompt import DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE from .goal_completion import GOAL_COMPLETION_PROMPT -from .tool_override_generation import TOOL_OVERRIDE_GENERATION_PROMPT from .tool_response_generation import ( API_TOOL_RESPONSE_GENERATION_PROMPT, FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT, @@ -12,9 +11,8 @@ __all__ = [ "ACTOR_PROFILE_PROMPT_TEMPLATE", - "DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE", + "DEFAULT_USER_SIMULATOR_PROMPT_TEMPLATE", "GOAL_COMPLETION_PROMPT", - "TOOL_OVERRIDE_GENERATION_PROMPT", "API_TOOL_RESPONSE_GENERATION_PROMPT", "FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT", "MCP_TOOL_RESPONSE_GENERATION_PROMPT", diff --git a/src/strands_evals/simulation/prompt_templates/tool_override_generation.py b/src/strands_evals/simulation/prompt_templates/tool_override_generation.py deleted file mode 100644 index f61782a..0000000 --- a/src/strands_evals/simulation/prompt_templates/tool_override_generation.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -Prompt template for tool override generation in Strands Evals. - -This module contains the prompt template used to analyze test scenarios and determine -optimal tool simulation strategies for agent evaluation workflows. It applies scientific -tool categorization to ensure consistent and appropriate simulation decisions across -different tool types and usage contexts. -""" - -from textwrap import dedent - -TOOL_OVERRIDE_GENERATION_PROMPT = dedent( - """You are an expert at analyzing test scenarios and determining optimal tool simulation strategies for agent evaluation workflows. - -Your primary objective is to apply SCIENTIFIC TOOL CATEGORIZATION and ensure CONSISTENCY in tool simulation decisions. 
- -## Scenario -{scenario} - -## Available Tools -{tools_json} - -## Scientific Tool Categorization Framework - -Based on comprehensive analysis of MCP servers and tool libraries, tools fall into four primary categories: - -### CATEGORY 1: COMPUTE TOOLS (Default: REAL) -**Characteristics**: Pure computational functions, no side effects, no state changes -**Examples**: Mathematical operations, FFT, string manipulation, date formatting, validation -**Simulation Strategy**: Connect directly via MCP - these are safe and deterministic -**Rationale**: No external dependencies, consistent results, low security risk - -### CATEGORY 2: DATABASE/PERSISTENT STATE TOOLS (Default: SIMULATE) -**Characteristics**: CRUD operations, booking systems, inventory management, resource allocation -**Examples**: create_booking(), update_inventory(), delete_user(), query_orders() -**Simulation Strategy**: MUST use synthetic/dummy databases with relevant test data -**Rationale**: Cannot connect to production DBs; subsequent operations depend on consistent state -**Critical Rule**: If ANY tool modifies a resource, ALL tools operating on that resource MUST be simulated - -### CATEGORY 3: ML MODEL TOOLS (Default: CONTEXT-DEPENDENT) -**Characteristics**: Calls to other ML models, AI services, content generation -**Examples**: image_generator(), text_classifier(), sentiment_analyzer(), llm_call() -**Simulation Strategy**: Evaluate based on scenario requirements and cost considerations -**Rationale**: May need human supervision; consider latency and cost implications - -### CATEGORY 4: SPECIALIZED TOOLS (Default: SIMULATE) -**Characteristics**: External integrations, infrastructure operations, specialized hardware -**Examples**: 3D renderers, CAD functions, game engines, deployment tools, notification services -**Simulation Strategy**: Require specialized support; simulate unless explicitly needed -**Rationale**: Complex dependencies, potential side effects, specialized environments - -## Consistency Rules (CRITICAL) - -**RULE 1 - Resource State Consistency**: -If tool A modifies resource R, then ALL tools B, C, D that operate on resource R MUST have the same simulation decision. -Example: cancel_flight(booking_id) simulated → get_flight_status(booking_id) must also be simulated - -**RULE 2 - Workflow Integrity**: -Tools in the same logical workflow should maintain consistent simulation decisions to preserve end-to-end test validity. - -**RULE 3 - External Service Consistency**: -If one tool calls external service S, related tools calling service S should have consistent simulation decisions. - -## Instructions - -For EACH tool, analyze: - -1. **Category Classification**: Determine which of the 4 categories (1-4) this tool belongs to -2. **Resource Dependencies**: Identify what resources/services this tool operates on -3. **Consistency Impact**: List other tools that must have matching simulation decisions -4. 
**Simulation Decision**: Apply category defaults, then adjust for consistency rules - -## Failure Conditions Specification - -Configure failure simulation with these parameters: - -{{ - "enabled": true, // Whether failure simulation is enabled (boolean) - "error_rate": 0.15, // Error rate between 0.0 and 1.0 (float) - "error_type": "timeout", // Error type (see allowed values below) - "error_message": "Custom error message" // Optional custom error message (string) -}} - -### Examples of Error Types: -- `"timeout"` - Request timeout errors -- `"execution_error"` - General execution failures -- `"network_error"` - Network connectivity issues -- `"authentication_error"` - Authentication failures -- `"authorization_error"` - Permission denied errors -- `"rate_limit_error"` - Rate limiting errors -- `"internal_error"` - Internal system errors - -### Failure Rate Guidelines: -- **0.0** - No failures (disabled) -- **0.01-0.05** - Low failure rate (1-5%) - production-like -- **0.1-0.2** - Medium failure rate (10-20%) - stress testing -- **0.3+** - High failure rate (30%+) - chaos engineering - -## Response Format - -{{ - "scenario_summary": "Brief summary of the scenario and testing objectives", - "resource_groups": {{ - "group_name": {{ - "description": "What this resource group represents", - "tools": ["tool1", "tool2", "tool3"], - "simulation_decision": true, - "rationale": "Why all tools in this group have the same decision" - }} - }}, - "tool_overrides": [ - {{ - "tool_name": "name_of_tool", - "category": 1, - "category_rationale": "Category 1: Compute tool - pure mathematical operation with no side effects", - "resource_dependencies": ["resource_name"], - "consistency_requirements": ["related_tool1", "related_tool2"], - "failure_conditions": {{ - "enabled": false, - "error_rate": 0.0, - "error_type": "execution_error", - "error_message": "Error message if failure occurs" - }}, - "rationale": "Simulation configuration rationale focusing on failure conditions and error patterns" - }} - ] -}} - -Generate only valid JSON with no markdown code blocks or additional explanation.""" -) diff --git a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py index 55b7a52..363ef64 100644 --- a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py +++ b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py @@ -5,26 +5,6 @@ agent evaluation scenarios. These templates enable LLM-powered simulation of tool behavior when actual tools are not available or when consistent, controllable responses are needed for evaluation purposes. - -The module provides specialized templates for different tool types: -1. Function tools - Traditional Python function calls with parameters and return values -2. MCP tools - Model Context Protocol tools with structured input/output formats -3. 
API tools - REST API endpoints with HTTP request/response patterns - -Each template guides an LLM to: -- Analyze the tool name, parameters, and context to understand expected behavior -- Generate realistic responses that maintain consistency across the conversation -- Follow appropriate response formats for each tool type -- Consider previous tool responses to maintain state consistency in simulations - -Key Components: -- FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT: Template for Python function tool simulation -- MCP_TOOL_RESPONSE_GENERATION_PROMPT: Template for MCP tool response generation -- API_TOOL_RESPONSE_GENERATION_PROMPT: Template for REST API endpoint simulation - -These templates ensure that simulated tool responses are contextually appropriate, -format-compliant, and maintain the illusion of real tool execution for effective -agent evaluation and testing scenarios. """ from textwrap import dedent diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index 287f426..b651299 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -1,25 +1,229 @@ import inspect import json import logging -import random +import warnings +from datetime import datetime from typing import Any, Callable, Dict, List, Optional -from strands import Agent from strands.models.bedrock import BedrockModel from strands.models.model import Model -from strands_evals.case import Case from strands_evals.types.simulation.tool import ( - FailureConditions, RegisteredTool, - StateRegistry, - ToolOverrideConfig, ToolType, ) logger = logging.getLogger(__name__) +class StateRegistry: + """ + State registry for managing shared state between tool simulators. + Organized by state_key to isolate state between different tools or shared state groups. + """ + + def __init__(self): + """ + Initialize state registry. + + Creates an empty state dictionary to track tool calls and responses + across different simulation sessions. + """ + self._states: Dict[str, Dict[str, Any]] = {} + + def initialize_state_via_description(self, initial_state_description: str, state_key: str) -> None: + """ + Initialize state based on the provided description. + + This method pre-seeds the state with an initial description that will be + included in all subsequent LLM prompts, allowing the simulator to have + context about pre-existing data or system state. + + Args: + initial_state_description: Description of the initial state (e.g., existing + database records, system configuration, etc.). + state_key: Key for the state in the registry (typically tool_name or share_state_id). + """ + if state_key not in self._states: + self._states[state_key] = { + "initial_state": initial_state_description, + "previous_calls": [], + "user_context": {}, + } + else: + warnings.warn(f"State with key '{state_key}' already initialized. Skipping re-initialization.") + + def get_state(self, state_key: str) -> Dict[str, Any]: + """ + Get state for a specific tool or shared state group. + + Args: + state_key: Key for the state (tool_name or share_state_id). + + Returns: + State dictionary containing previous_calls and user_context. 
+ """ + if state_key is None: + raise ValueError("Value of state_key is required.") + + if state_key not in self._states: + self._states[state_key] = { + "previous_calls": [], + "user_context": {}, + } + + return dict(self._states[state_key]) + + def record_function_call( + self, + tool_name: str, + state_key: str, + parameters: Dict[str, Any], + response_data: Any, + ) -> Dict[str, Any]: + """ + Record a function call in the tool's state history. + + Args: + tool_name: Name of the function being called. + state_key: Key for the state (tool_name or share_state_id). + parameters: Parameters passed to the function. + response_data: Response from the function call. + + Returns: + Updated state dictionary. + """ + state = self.get_state(state_key) + date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + state["previous_calls"].append({ + 'tool_name': tool_name, + 'tool_type': 'function', + 'parameters': parameters, + 'response': response_data, + 'timestamp': date_timestamp + }) + + # Keep history manageable + if len(state["previous_calls"]) > 20: + state["previous_calls"] = state["previous_calls"][-20:] + + # Update the stored state + self._states[state_key] = state + + return state + + def record_mcp_tool_call( + self, + tool_name: str, + state_key: str, + input_mcp_payload: Dict[str, Any], + response_data: Any, + ) -> Dict[str, Any]: + """ + Record an MCP tool call in the tool's state history. + + Args: + tool_name: Name of the MCP tool being called. + state_key: Key for the state (tool_name or share_state_id). + input_mcp_payload: Input payload for the MCP tool call. + response_data: Response from the MCP tool call. + + Returns: + Updated state dictionary. + """ + state = self.get_state(state_key) + date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + state["previous_calls"].append({ + 'tool_name': tool_name, + 'tool_type': 'mcp', + 'input_mcp_payload': input_mcp_payload, + 'response': response_data, + 'timestamp': date_timestamp + }) + + # Keep history manageable + if len(state["previous_calls"]) > 20: + state["previous_calls"] = state["previous_calls"][-20:] + + # Update the stored state + self._states[state_key] = state + + return state + + def record_api_call( + self, + tool_name: str, + state_key: str, + path: str, + method: str, + input_data: Any, + response: Any, + ) -> Dict[str, Any]: + """ + Record an API call in the tool's state history. + + Args: + tool_name: Name of the API tool being called. + state_key: Key for the state (tool_name or share_state_id). + path: API endpoint path. + method: HTTP method. + input_data: Input data for the API call. + response: Response from the API call. + + Returns: + Updated state dictionary. + """ + state = self.get_state(state_key) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + state["previous_calls"].append({ + 'tool_name': tool_name, + 'tool_type': 'api', + 'path': path, + 'method': method, + 'input': input_data, + 'response': response, + 'timestamp': timestamp + }) + + # Keep history manageable + if len(state["previous_calls"]) > 20: + state["previous_calls"] = state["previous_calls"][-20:] + + # Update the stored state + self._states[state_key] = state + + return state + + def set_user_context(self, state_key: str, user_context: Dict[str, Any]) -> Dict[str, Any]: + """ + Set user context for a state. + + Args: + state_key: Key for the state (tool_name or share_state_id). + user_context: User context dictionary to store. + + Returns: + Updated state dictionary. 
+ """ + state = self.get_state(state_key) + state["user_context"] = user_context + self._states[state_key] = state + return state + + def clear_state(self, state_key: str) -> None: + """ + Clear state for a specific tool or shared state group. + + Args: + state_key: Key for the state to clear. + """ + if state_key in self._states: + del self._states[state_key] + + class ToolSimulator: """ Simulates tool behavior with decorator-based registration system for agent evaluation. @@ -42,23 +246,45 @@ class ToolSimulator: def __init__( self, - tool_overrides: Optional[Dict[str, ToolOverrideConfig]] = None, state_registry: Optional[StateRegistry] = None, system_prompt_template: Optional[str] = None, + function_tool_prompt: Optional[str] = None, + mcp_tool_prompt: Optional[str] = None, + api_tool_prompt: Optional[str] = None, model: Model | str | None = None, ): """ Initialize a ToolSimulator instance. Args: - tool_overrides: Dictionary mapping tool names to ToolOverrideConfig instances state_registry: Registry for maintaining tool state system_prompt_template: Template for system prompts + function_tool_prompt: Optional custom prompt for function tool response generation + mcp_tool_prompt: Optional custom prompt for MCP tool response generation + api_tool_prompt: Optional custom prompt for API tool response generation model: Provider for running inference or a string representing the model-id for Bedrock to use """ - self.tool_overrides = tool_overrides or {} self.system_prompt_template = system_prompt_template + # Set custom prompts or use defaults + if function_tool_prompt is None: + from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT + self.function_tool_prompt = FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT + else: + self.function_tool_prompt = function_tool_prompt + + if mcp_tool_prompt is None: + from .prompt_templates.tool_response_generation import MCP_TOOL_RESPONSE_GENERATION_PROMPT + self.mcp_tool_prompt = MCP_TOOL_RESPONSE_GENERATION_PROMPT + else: + self.mcp_tool_prompt = mcp_tool_prompt + + if api_tool_prompt is None: + from .prompt_templates.tool_response_generation import API_TOOL_RESPONSE_GENERATION_PROMPT + self.api_tool_prompt = API_TOOL_RESPONSE_GENERATION_PROMPT + else: + self.api_tool_prompt = api_tool_prompt + # Initialize model following Agent pattern self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model @@ -71,27 +297,6 @@ def __init__( # Initialize shared states from registered tools self._initialize_shared_states() - def _function_has_implementation(self, func: Callable) -> bool: - """Check if a function has actual implementation or is just an empty stub.""" - try: - import dis - # Get function bytecode - bytecode = list(dis.get_instructions(func)) - - # Check if function only contains simple return patterns - if len(bytecode) <= 3: - load_const_none_count = sum( - 1 for instr in bytecode if instr.opname == "LOAD_CONST" and instr.argval is None - ) - return_count = sum(1 for instr in bytecode if instr.opname == "RETURN_VALUE") - - if load_const_none_count >= 1 and return_count == 1 and len(bytecode) <= 3: - return False - - return True - except Exception: - # If we can't analyze bytecode, assume it's implemented - return True def _initialize_shared_states(self): """Initialize shared states from registered tools' initial descriptions.""" @@ -120,26 +325,6 @@ def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: D if not registered_tool: return 
self._create_error_response("tool_not_found", f"Tool '{tool_name}' not found") - # Handle tool behavior configuration - tool_override = input_data.get("tool_override", {}) - - # Check for failure conditions - failure_conditions = tool_override.get("failure_conditions", {}) - if failure_conditions and failure_conditions.get("enabled", False): - error_rate = failure_conditions.get("error_rate", 0.0) - if random.random() < error_rate: - error_type = failure_conditions.get("error_type", "execution_error") - error_message = failure_conditions.get("error_message", "An error occurred") - - if tool_type == ToolType.API: - return self._create_error_response(error_type, error_message) - elif tool_type in [ToolType.FUNCTION, ToolType.MCP]: - return { - "status": "error", - "error_type": error_type, - "message": error_message - } - # Handle different simulation modes if registered_tool.mode == "static": return self._handle_static_mode(registered_tool, tool_type) @@ -168,13 +353,11 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D # Generate response using LLM try: - from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT - # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) initial_state_description = current_state.get("initial_state", "No initial state provided.") - prompt = FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT.format( + prompt = self.function_tool_prompt.format( tool_name=tool_name, parameters=json.dumps(parameters, indent=2) if parameters else "{}", initial_state_description=initial_state_description, @@ -205,13 +388,11 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s } try: - from .prompt_templates.tool_response_generation import MCP_TOOL_RESPONSE_GENERATION_PROMPT - # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) initial_state_description = current_state.get("initial_state", "No initial state provided.") - prompt = MCP_TOOL_RESPONSE_GENERATION_PROMPT.format( + prompt = self.mcp_tool_prompt.format( tool_name=tool_name, mcp_payload=json.dumps(input_mcp_payload, indent=2) if input_mcp_payload else "{}", initial_state_description=initial_state_description, @@ -244,13 +425,11 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s return self._create_error_response("missing_tool_name", "Tool name is required", 400) try: - from .prompt_templates.tool_response_generation import API_TOOL_RESPONSE_GENERATION_PROMPT - # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) initial_state_description = current_state.get("initial_state", "No initial state provided.") - prompt = API_TOOL_RESPONSE_GENERATION_PROMPT.format( + prompt = self.api_tool_prompt.format( tool_name=tool_name, path=path, method=method, @@ -588,156 +767,6 @@ def decorator(func: Callable) -> Callable: return decorator - @classmethod - def from_case_for_tool_simulator( - cls, - case: Case, - system_prompt_template: Optional[str] = None, - model: Optional[str] = None, - **kwargs, - ) -> "ToolSimulator": - """ - Create a ToolSimulator instance configured for a specific case. 
- - Args: - case: Case object containing test case information and metadata - system_prompt_template: Template for system prompts - model: Model identifier for LLM-based simulation - **kwargs: Additional configuration options - - Returns: - Configured ToolSimulator instance - """ - tool_overrides = cls._generate_override_from_case(case) - return cls( - tool_overrides=tool_overrides, - system_prompt_template=system_prompt_template, - model=model, - **kwargs, - ) - - @staticmethod - def _generate_override_from_case(case: Case) -> Dict[str, ToolOverrideConfig]: - """Generate tool override configuration from a case using LLM.""" - # Extract scenario description from case - scenario_description = f"Test case: {case.name or 'unnamed'}. Input: {case.input}" - if case.metadata: - scenario_description += f". Metadata: {case.metadata}" - - # Create tools list from registered tools - tools_list = [] - for tool_name, registered_tool in ToolSimulator._registered_tools.items(): - tool_info = { - "name": tool_name, - "type": registered_tool.tool_type.value, - "description": ( - getattr(registered_tool.function, "__doc__", "") - if registered_tool.function - else "" - ), - } - - # Add schema information based on tool type - if registered_tool.tool_type == ToolType.FUNCTION and registered_tool.function: - sig = inspect.signature(registered_tool.function) - parameters = {} - for param_name, param in sig.parameters.items(): - param_type = "string" - if param.annotation != inspect.Parameter.empty: - type_map = { - int: "integer", - float: "number", - bool: "boolean", - list: "array", - dict: "object", - str: "string", - } - param_type = type_map.get(param.annotation, "string") - - parameters[param_name] = { - "type": param_type, - "required": param.default == inspect.Parameter.empty, - } - - tool_info["parameters"] = parameters - - elif registered_tool.tool_type == ToolType.MCP and registered_tool.mcp_schema: - tool_info["schema"] = registered_tool.mcp_schema - - elif registered_tool.tool_type == ToolType.API: - tool_info["path"] = registered_tool.api_path - tool_info["method"] = registered_tool.api_method - - tools_list.append(tool_info) - - # If no registered tools, return empty override - if not tools_list: - logger.warning("No registered tools found for override generation") - return {} - - # Generate overrides using LLM prompt - try: - tools_json = json.dumps(tools_list, indent=2) - - # Use the tool override generation prompt - from .prompt_templates.tool_override_generation import TOOL_OVERRIDE_GENERATION_PROMPT - - prompt = TOOL_OVERRIDE_GENERATION_PROMPT.format( - scenario=scenario_description, - tools_json=tools_json, - ) - - # Generate response - agent = Agent(callback_handler=None) - result = agent(prompt) - llm_response = str(result) - - # Parse LLM response - try: - response_data = json.loads(llm_response.strip()) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse LLM response as JSON: {e}") - logger.debug(f"Raw LLM response: {llm_response}") - return {} - - # Convert LLM response to ToolOverrideConfig instances - tool_configs: Dict[str, ToolOverrideConfig] = {} - tool_overrides = response_data.get("tool_overrides", []) - - for override in tool_overrides: - tool_name = override.get("tool_name") - - if not tool_name: - continue - - # Add failure conditions using new schema format - failure_conditions = override.get("failure_conditions", {}) - failure_conditions = { - "enabled": failure_conditions.get("enabled", False), - "error_rate": failure_conditions.get("error_rate", 0.0), - 
"error_type": failure_conditions.get("error_type", "execution_error"), - "error_message": failure_conditions.get("error_message"), - } - - try: - # Create FailureConditions instance - failure_conditions_instance = FailureConditions(**failure_conditions) - - # Create ToolOverrideConfig instance - tool_configs[tool_name] = ToolOverrideConfig( - failure_conditions=failure_conditions_instance, - ) - except Exception as e: - logger.warning(f"Failed to create ToolOverrideConfig for {tool_name}: {e}") - continue - - logger.info(f"Generated overrides for {len(tool_configs)} tools using LLM") - return tool_configs - - except Exception as e: - logger.error(f"Error generating overrides using LLM: {e}") - logger.warning("Falling back to empty override configuration") - return {} def get_tool(self, tool_name: str) -> Optional[Callable]: """ @@ -773,53 +802,23 @@ def wrapper(*args, **kwargs): else json.dumps(kwargs, indent=2) ) - # Get tool override configuration - tool_override_config = {} - if registered_tool.name in self.tool_overrides: - override_config = self.tool_overrides[registered_tool.name] - if override_config.failure_conditions: - tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() - else: - tool_override_config["failure_conditions"] = {"enabled": False} - input_data = { "tool_name": registered_tool.name, "parameters": parameters_string, - "tool_override": tool_override_config, } elif registered_tool.tool_type == ToolType.MCP: - # Get tool override configuration - tool_override_config = {} - if registered_tool.name in self.tool_overrides: - override_config = self.tool_overrides[registered_tool.name] - if override_config.failure_conditions: - tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() - else: - tool_override_config["failure_conditions"] = {"enabled": False} - input_data = { "tool_name": registered_tool.name, "input_mcp_payload": kwargs, - "tool_override": tool_override_config, } elif registered_tool.tool_type == ToolType.API: - # Get tool override configuration - tool_override_config = {} - if registered_tool.name in self.tool_overrides: - override_config = self.tool_overrides[registered_tool.name] - if override_config.failure_conditions: - tool_override_config["failure_conditions"] = override_config.failure_conditions.model_dump() - else: - tool_override_config["failure_conditions"] = {"enabled": False} - input_data = { "tool_name": registered_tool.name, "user_input_api_payload": kwargs, "path": registered_tool.api_path or "", "method": registered_tool.api_method or "GET", - "tool_override": tool_override_config, } else: diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index ee3e0ec..000b042 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -1,8 +1,7 @@ -from datetime import datetime from enum import Enum from typing import Any, Callable, Dict, Optional -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field class ToolType(Enum): @@ -19,48 +18,6 @@ class ToolType(Enum): API = "api" -class FailureConditions(BaseModel): - """ - Configuration for failure simulation conditions. - - Attributes: - enabled: Whether failure simulation is enabled for the tool. - error_rate: Error rate between 0.0 and 1.0 for random failure injection. - error_type: Type of error to simulate when failures occur. 
- error_message: Optional custom error message for simulated failures. - """ - - enabled: bool = Field(default=False, description="Whether failure simulation is enabled") - error_rate: float = Field(default=0.0, ge=0.0, le=1.0, description="Error rate between 0.0 and 1.0") - error_type: str = Field(default="execution_error", description="Type of error to simulate") - error_message: Optional[str] = None - - @field_validator("error_rate") - @classmethod - def validate_error_rate(cls, v: float) -> float: - """Validate error rate is between 0 and 1.""" - if not 0.0 <= v <= 1.0: - raise ValueError("Error rate must be between 0.0 and 1.0") - return v - - @model_validator(mode='after') - def validate_enabled_state(self) -> 'FailureConditions': - """Validate that if enabled is True, error_rate is > 0.""" - if self.enabled and self.error_rate == 0.0: - raise ValueError("If failure conditions are enabled, error_rate must be greater than 0") - return self - - -class ToolOverrideConfig(BaseModel): - """ - Configuration for tool override behavior. - - Attributes: - failure_conditions: Configuration for failure simulation conditions. - """ - failure_conditions: FailureConditions = Field(default_factory=FailureConditions, description="Configuration for failure simulation") - - class RegisteredTool(BaseModel): """ Represents a registered tool in the simulator. @@ -89,209 +46,3 @@ class RegisteredTool(BaseModel): class Config: arbitrary_types_allowed = True - - -class StateRegistry: - """ - State registry for managing shared state between tool simulators. - Organized by state_key to isolate state between different tools or shared state groups. - """ - - def __init__(self): - """ - Initialize state registry. - - Creates an empty state dictionary to track tool calls and responses - across different simulation sessions. - """ - self._states: Dict[str, Dict[str, Any]] = {} - - def initialize_state_via_description(self, initial_state_description: str, state_key: str) -> None: - """ - Initialize state based on the provided description. - - This method pre-seeds the state with an initial description that will be - included in all subsequent LLM prompts, allowing the simulator to have - context about pre-existing data or system state. - - Args: - initial_state_description: Description of the initial state (e.g., existing - database records, system configuration, etc.). - state_key: Key for the state in the registry (typically tool_name or share_state_id). - """ - if state_key not in self._states: - self._states[state_key] = { - "initial_state": initial_state_description, - "previous_calls": [], - "user_context": {}, - } - - def get_state(self, state_key: str) -> Dict[str, Any]: - """ - Get state for a specific tool or shared state group. - - Args: - state_key: Key for the state (tool_name or share_state_id). - - Returns: - State dictionary containing previous_calls and user_context. - """ - if state_key is None: - raise ValueError("Value of state_key is required.") - - if state_key not in self._states: - self._states[state_key] = { - "previous_calls": [], - "user_context": {}, - } - - return dict(self._states[state_key]) - - def record_function_call( - self, - tool_name: str, - state_key: str, - parameters: Dict[str, Any], - response_data: Any, - ) -> Dict[str, Any]: - """ - Record a function call in the tool's state history. - - Args: - tool_name: Name of the function being called. - state_key: Key for the state (tool_name or share_state_id). - parameters: Parameters passed to the function. 
- response_data: Response from the function call. - - Returns: - Updated state dictionary. - """ - state = self.get_state(state_key) - date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - - state["previous_calls"].append({ - 'tool_name': tool_name, - 'tool_type': 'function', - 'parameters': parameters, - 'response': response_data, - 'timestamp': date_timestamp - }) - - # Keep history manageable - if len(state["previous_calls"]) > 20: - state["previous_calls"] = state["previous_calls"][-20:] - - # Update the stored state - self._states[state_key] = state - - return state - - def record_mcp_tool_call( - self, - tool_name: str, - state_key: str, - input_mcp_payload: Dict[str, Any], - response_data: Any, - ) -> Dict[str, Any]: - """ - Record an MCP tool call in the tool's state history. - - Args: - tool_name: Name of the MCP tool being called. - state_key: Key for the state (tool_name or share_state_id). - input_mcp_payload: Input payload for the MCP tool call. - response_data: Response from the MCP tool call. - - Returns: - Updated state dictionary. - """ - state = self.get_state(state_key) - date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - - state["previous_calls"].append({ - 'tool_name': tool_name, - 'tool_type': 'mcp', - 'input_mcp_payload': input_mcp_payload, - 'response': response_data, - 'timestamp': date_timestamp - }) - - # Keep history manageable - if len(state["previous_calls"]) > 20: - state["previous_calls"] = state["previous_calls"][-20:] - - # Update the stored state - self._states[state_key] = state - - return state - - def record_api_call( - self, - tool_name: str, - state_key: str, - path: str, - method: str, - input_data: Any, - response: Any, - ) -> Dict[str, Any]: - """ - Record an API call in the tool's state history. - - Args: - tool_name: Name of the API tool being called. - state_key: Key for the state (tool_name or share_state_id). - path: API endpoint path. - method: HTTP method. - input_data: Input data for the API call. - response: Response from the API call. - - Returns: - Updated state dictionary. - """ - state = self.get_state(state_key) - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - - state["previous_calls"].append({ - 'tool_name': tool_name, - 'tool_type': 'api', - 'path': path, - 'method': method, - 'input': input_data, - 'response': response, - 'timestamp': timestamp - }) - - # Keep history manageable - if len(state["previous_calls"]) > 20: - state["previous_calls"] = state["previous_calls"][-20:] - - # Update the stored state - self._states[state_key] = state - - return state - - def set_user_context(self, state_key: str, user_context: Dict[str, Any]) -> Dict[str, Any]: - """ - Set user context for a state. - - Args: - state_key: Key for the state (tool_name or share_state_id). - user_context: User context dictionary to store. - - Returns: - Updated state dictionary. - """ - state = self.get_state(state_key) - state["user_context"] = user_context - self._states[state_key] = state - return state - - def clear_state(self, state_key: str) -> None: - """ - Clear state for a specific tool or shared state group. - - Args: - state_key: Key for the state to clear. 
- """ - if state_key in self._states: - del self._states[state_key] diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index 1846dcc..847341e 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -1,41 +1,14 @@ """Tests for ToolSimulator class.""" -import json from typing import Any, Dict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from strands import Agent, tool +from strands import tool from strands_evals.case import Case -from strands_evals.simulation.tool_simulator import ToolSimulator -from strands_evals.types.simulation.tool import ( - FailureConditions, - RegisteredTool, - StateRegistry, - ToolOverrideConfig, - ToolType, -) - - -@pytest.fixture -def sample_failure_conditions(): - """Fixture providing sample failure conditions.""" - return FailureConditions( - enabled=True, - error_rate=0.5, - error_type="timeout_error", - error_message="Operation timed out", - ) - - -@pytest.fixture -def sample_tool_override_config(sample_failure_conditions): - """Fixture providing sample tool override configuration.""" - return ToolOverrideConfig( - failure_conditions=sample_failure_conditions, - scenario_config={"test_key": "test_value"}, - ) +from strands_evals.simulation.tool_simulator import ToolSimulator, StateRegistry +from strands_evals.types.simulation.tool import ToolType @pytest.fixture @@ -73,20 +46,17 @@ def clear_registry(): ToolSimulator.clear_registry() -def test_tool_simulator_init(sample_tool_override_config): +def test_tool_simulator_init(): """Test ToolSimulator initialization with all parameters.""" custom_registry = StateRegistry() - tool_overrides = {"test_tool": sample_tool_override_config} template = "You are a helpful assistant simulating tools." 
simulator = ToolSimulator( - tool_overrides=tool_overrides, state_registry=custom_registry, system_prompt_template=template, model=None, ) - assert simulator.tool_overrides == tool_overrides assert simulator._state_registry is custom_registry assert simulator.system_prompt_template == template assert simulator.model is not None @@ -205,34 +175,6 @@ def mock_structured_output(output_type, messages, system_prompt=None): assert result == {"status": 200, "data": {"key": "value"}} -def test_failure_conditions_trigger_error(): - """Test that failure conditions trigger errors as expected.""" - # Register tool - @ToolSimulator.function_tool("failing_function") - def test_func(): - pass - - # Create failure conditions with 100% error rate - failure_conditions = FailureConditions( - enabled=True, - error_rate=1.0, - error_type="timeout_error", - error_message="Simulated timeout" - ) - tool_overrides = { - "failing_function": ToolOverrideConfig(failure_conditions=failure_conditions) - } - - simulator = ToolSimulator(tool_overrides=tool_overrides) - - # Function should return error due to failure conditions - result = simulator.failing_function() - - assert result["status"] == "error" - assert result["error_type"] == "timeout_error" - assert result["message"] == "Simulated timeout" - - def test_list_tools(): """Test listing registered tools.""" @ToolSimulator.function_tool("func1") @@ -249,75 +191,6 @@ def func2(): assert set(tools) == {"func1", "func2"} -@patch("strands_evals.simulation.tool_simulator.Agent") -def test_from_case_for_tool_simulator(mock_agent_class, sample_case): - """Test factory method creates simulator from case.""" - # Register a test tool first - @ToolSimulator.function_tool("account_balance_check") - def check_balance(account_id: str) -> dict: - """Check account balance.""" - pass - - # Mock agent response for override generation - mock_agent = MagicMock() - mock_override_response = { - "tool_overrides": [{ - "tool_name": "account_balance_check", - "should_simulate": True, - "failure_conditions": { - "enabled": False, - "error_rate": 0.0 - } - }] - } - mock_agent.return_value = json.dumps(mock_override_response) - mock_agent_class.return_value = mock_agent - - simulator = ToolSimulator.from_case_for_tool_simulator( - case=sample_case, - system_prompt_template="Test template", - model="test-model" - ) - - assert simulator is not None - assert simulator.system_prompt_template == "Test template" - - -@patch("strands_evals.simulation.tool_simulator.Agent") -def test_generate_override_from_case(mock_agent_class, sample_case): - """Test override generation from case.""" - # Register test tools - @ToolSimulator.function_tool("test_function") - def test_func(param: str) -> dict: - """Test function.""" - pass - - # Mock agent response - mock_agent = MagicMock() - mock_response = { - "tool_overrides": [{ - "tool_name": "test_function", - "should_simulate": True, - "failure_conditions": { - "enabled": True, - "error_rate": 0.1, - "error_type": "network_error", - "error_message": "Network timeout" - } - }] - } - mock_agent.return_value = json.dumps(mock_response) - mock_agent_class.return_value = mock_agent - - overrides = ToolSimulator._generate_override_from_case(sample_case) - - assert "test_function" in overrides - override = overrides["test_function"] - assert override.failure_conditions.enabled is True - assert override.failure_conditions.error_rate == 0.1 - assert override.failure_conditions.error_type == "network_error" - - def test_shared_state_registry(mock_model): """Test that 
function, MCP, and API tools can share the same state registry.""" shared_state_id = "shared_banking_state" @@ -530,21 +403,6 @@ def test_func(): assert ToolSimulator._state_registry is None -def test_function_has_implementation_detection(): - """Test detection of function implementation.""" - simulator = ToolSimulator() - - # Empty function should be detected as not implemented - def empty_func(): - pass - - def implemented_func(): - return {"result": "value"} - - assert not simulator._function_has_implementation(empty_func) - assert simulator._function_has_implementation(implemented_func) - - def test_function_tool_decorator_stacking_with_strands_tool(): """Test function tool decorator stacking with Strands @tool decorator.""" # Mock function that handles parameters with **kwargs From 4d57a5393b73ecd4d788ae0338a65ad2b8981e6b Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 27 Jan 2026 22:03:01 +0000 Subject: [PATCH 06/15] replace llm call with agent; simplify error raise --- .../simulation/tool_simulator.py | 554 +++++++----------- src/strands_evals/types/simulation/tool.py | 45 +- .../simulation/test_tool_simulator.py | 253 ++++---- 3 files changed, 400 insertions(+), 452 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index b651299..0454cd2 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -5,12 +5,14 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional -from strands.models.bedrock import BedrockModel +from strands import Agent from strands.models.model import Model from strands_evals.types.simulation.tool import ( RegisteredTool, ToolType, + MCPToolResponse, + APIToolResponse, ) logger = logging.getLogger(__name__) @@ -74,103 +76,23 @@ def get_state(self, state_key: str) -> Dict[str, Any]: return dict(self._states[state_key]) - def record_function_call( + def record_tool_call( self, tool_name: str, state_key: str, - parameters: Dict[str, Any], + tool_type: ToolType, response_data: Any, + **call_data: Any, ) -> Dict[str, Any]: """ - Record a function call in the tool's state history. + Record a tool call in the tool's state history. Args: - tool_name: Name of the function being called. + tool_name: Name of the tool being called. state_key: Key for the state (tool_name or share_state_id). - parameters: Parameters passed to the function. - response_data: Response from the function call. - - Returns: - Updated state dictionary. - """ - state = self.get_state(state_key) - date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - - state["previous_calls"].append({ - 'tool_name': tool_name, - 'tool_type': 'function', - 'parameters': parameters, - 'response': response_data, - 'timestamp': date_timestamp - }) - - # Keep history manageable - if len(state["previous_calls"]) > 20: - state["previous_calls"] = state["previous_calls"][-20:] - - # Update the stored state - self._states[state_key] = state - - return state - - def record_mcp_tool_call( - self, - tool_name: str, - state_key: str, - input_mcp_payload: Dict[str, Any], - response_data: Any, - ) -> Dict[str, Any]: - """ - Record an MCP tool call in the tool's state history. - - Args: - tool_name: Name of the MCP tool being called. - state_key: Key for the state (tool_name or share_state_id). - input_mcp_payload: Input payload for the MCP tool call. - response_data: Response from the MCP tool call. - - Returns: - Updated state dictionary. 
- """ - state = self.get_state(state_key) - date_timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - - state["previous_calls"].append({ - 'tool_name': tool_name, - 'tool_type': 'mcp', - 'input_mcp_payload': input_mcp_payload, - 'response': response_data, - 'timestamp': date_timestamp - }) - - # Keep history manageable - if len(state["previous_calls"]) > 20: - state["previous_calls"] = state["previous_calls"][-20:] - - # Update the stored state - self._states[state_key] = state - - return state - - def record_api_call( - self, - tool_name: str, - state_key: str, - path: str, - method: str, - input_data: Any, - response: Any, - ) -> Dict[str, Any]: - """ - Record an API call in the tool's state history. - - Args: - tool_name: Name of the API tool being called. - state_key: Key for the state (tool_name or share_state_id). - path: API endpoint path. - method: HTTP method. - input_data: Input data for the API call. - response: Response from the API call. + tool_type: Type of the tool (FUNCTION, MCP, or API). + response_data: Response from the tool call. + **call_data: Tool-specific call data (parameters, input_mcp_payload, path, method, input_data, etc.). Returns: Updated state dictionary. @@ -178,15 +100,27 @@ def record_api_call( state = self.get_state(state_key) timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - state["previous_calls"].append({ + # Build call record based on tool type + call_record = { 'tool_name': tool_name, - 'tool_type': 'api', - 'path': path, - 'method': method, - 'input': input_data, - 'response': response, + 'tool_type': tool_type.value, + 'response': response_data, 'timestamp': timestamp - }) + } + + # Add tool-specific fields + if tool_type == ToolType.FUNCTION: + call_record['parameters'] = call_data.get('parameters', {}) + elif tool_type == ToolType.MCP: + call_record['input_mcp_payload'] = call_data.get('input_mcp_payload', {}) + elif tool_type == ToolType.API: + call_record.update({ + 'path': call_data.get('path', ''), + 'method': call_data.get('method', 'GET'), + 'input': call_data.get('input_data', {}) + }) + + state["previous_calls"].append(call_record) # Keep history manageable if len(state["previous_calls"]) > 20: @@ -197,22 +131,6 @@ def record_api_call( return state - def set_user_context(self, state_key: str, user_context: Dict[str, Any]) -> Dict[str, Any]: - """ - Set user context for a state. - - Args: - state_key: Key for the state (tool_name or share_state_id). - user_context: User context dictionary to store. - - Returns: - Updated state dictionary. - """ - state = self.get_state(state_key) - state["user_context"] = user_context - self._states[state_key] = state - return state - def clear_state(self, state_key: str) -> None: """ Clear state for a specific tool or shared state group. @@ -233,8 +151,6 @@ class ToolSimulator: behavior for simulation purposes, enabling controlled testing scenarios. Attributes: - tool_overrides: Dictionary mapping tool names to override configurations. - system_prompt_template: Template string for system prompts. model: Provider for running inference or model identifier for Bedrock. _registered_tools: Class-level registry for all registered tools. _state_registry: Registry for maintaining tool state across calls. 
@@ -247,7 +163,6 @@ class ToolSimulator: def __init__( self, state_registry: Optional[StateRegistry] = None, - system_prompt_template: Optional[str] = None, function_tool_prompt: Optional[str] = None, mcp_tool_prompt: Optional[str] = None, api_tool_prompt: Optional[str] = None, @@ -258,13 +173,13 @@ def __init__( Args: state_registry: Registry for maintaining tool state - system_prompt_template: Template for system prompts function_tool_prompt: Optional custom prompt for function tool response generation mcp_tool_prompt: Optional custom prompt for MCP tool response generation api_tool_prompt: Optional custom prompt for API tool response generation model: Provider for running inference or a string representing the model-id for Bedrock to use """ - self.system_prompt_template = system_prompt_template + # Store model configuration for creating internal agents + self.model_id = model # Set custom prompts or use defaults if function_tool_prompt is None: @@ -284,9 +199,6 @@ def __init__( self.api_tool_prompt = API_TOOL_RESPONSE_GENERATION_PROMPT else: self.api_tool_prompt = api_tool_prompt - - # Initialize model following Agent pattern - self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model # Set up state registry if state_registry: @@ -323,7 +235,7 @@ def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: D registered_tool = self._registered_tools.get(tool_name) if not registered_tool: - return self._create_error_response("tool_not_found", f"Tool '{tool_name}' not found") + raise ValueError(f"Tool '{tool_name}' not registered") # Handle different simulation modes if registered_tool.mode == "static": @@ -339,9 +251,9 @@ def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: D elif tool_type == ToolType.API: return self._handle_api_tool(input_data, state_key) else: - return self._create_error_response("unsupported_tool_type", f"Tool type '{tool_type}' not supported") + raise ValueError(f"Tool type '{tool_type}' not supported") else: - return self._create_error_response("unsupported_mode", f"Simulation mode '{registered_tool.mode}' not supported") + raise ValueError(f"Tool simulation mode '{registered_tool.mode}' not supported") def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: """Handle function tool simulation.""" @@ -351,7 +263,7 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D if not tool_name: return {"status": "error", "error_type": "missing_tool_name", "message": "Tool name is required"} - # Generate response using LLM + # Generate response using structured output try: # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) @@ -364,11 +276,21 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D previous_responses=json.dumps(current_state, indent=2) or "{}" ) - llm_response = self._generate_llm_response(prompt) - response_data = self._parse_llm_response(llm_response) + # Create agent and generate response with structured output + # Use dict for function responses since they vary based on function signature + agent = Agent( + system_prompt=self.function_tool_prompt, + tools=[], + model=self.model_id, + callback_handler=None, + ) + result = agent(prompt, structured_output_model=dict) + response_data = result.structured_output # Record the call - self._state_registry.record_function_call(tool_name, state_key, parameters, 
response_data) + self._state_registry.record_tool_call( + tool_name, state_key, ToolType.FUNCTION, response_data, parameters=parameters + ) return response_data @@ -399,11 +321,20 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s previous_responses=json.dumps(current_state, indent=2) or "{}" ) - llm_response = self._generate_llm_response(prompt) - response_data = self._parse_llm_response(llm_response) + # Create agent and generate response with structured output + agent = Agent( + system_prompt=self.mcp_tool_prompt, + tools=[], + model=self.model_id, + callback_handler=None, + ) + result = agent(prompt, structured_output_model=MCPToolResponse) + response_data = result.structured_output.model_dump() # Record the call - self._state_registry.record_mcp_tool_call(tool_name, state_key, input_mcp_payload, response_data) + self._state_registry.record_tool_call( + tool_name, state_key, ToolType.MCP, response_data, input_mcp_payload=input_mcp_payload + ) return response_data @@ -422,7 +353,7 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s method = input_data.get("method", "GET") if not tool_name: - return self._create_error_response("missing_tool_name", "Tool name is required", 400) + raise ValueError("tool_name is required for API tool simulation") try: # Get initial state description from state registry @@ -438,103 +369,26 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s previous_responses=json.dumps(current_state, indent=2) or "{}" ) - llm_response = self._generate_llm_response(prompt) - response_data = self._parse_llm_response(llm_response) - - # Ensure proper API response format - if "status" not in response_data: - response_data = {"status": 200, "data": response_data} + # Create agent and generate response with structured output + agent = Agent( + system_prompt=self.api_tool_prompt, + tools=[], + model=self.model_id, + callback_handler=None, + ) + result = agent(prompt, structured_output_model=APIToolResponse) + response_data = result.structured_output.model_dump() # Record the call - self._state_registry.record_api_call(tool_name, state_key, path, method, user_input_api_payload, response_data) + self._state_registry.record_tool_call( + tool_name, state_key, ToolType.API, response_data, + path=path, method=method, input_data=user_input_api_payload + ) return response_data except Exception as e: - logger.error(f"Error generating API response: {e}") - return self._create_error_response("generation_error", str(e), 500) - - def _generate_llm_response(self, prompt: str) -> str: - """ - Generate LLM response using the model for a given prompt. 
- - Args: - prompt: The prompt string to send to the LLM - - Returns: - Raw LLM response text - - Raises: - Exception: If LLM generation fails - """ - try: - # Create message for model inference - messages = [{"role": "user", "content": [{"text": prompt}]}] - - # Generate response - llm_response = "" - for event in self.model.structured_output(str, messages, system_prompt=self.system_prompt_template): - if hasattr(event, 'get') and event.get("contentBlockDelta"): - delta = event["contentBlockDelta"] - if "text" in delta: - llm_response += delta["text"] - elif hasattr(event, 'get') and event.get("message"): - # Handle final message - content = event["message"].get("content", []) - for block in content: - if "text" in block: - llm_response += block["text"] - elif hasattr(event, 'get') and event.get("output"): - # Handle structured output result - return str(event["output"]) - - return llm_response - - except Exception as e: - logger.error(f"Error generating LLM response: {e}") - raise - - def _parse_llm_response(self, llm_response: str) -> Dict[str, Any]: - """Parse LLM response with fallback handling.""" - try: - return json.loads(llm_response) - except json.JSONDecodeError: - # Try to extract JSON from code blocks - import re - json_matches = re.findall(r'```(?:json)?\s*([\s\S]*?)\s*```', llm_response) - - for json_str in json_matches: - try: - return json.loads(json_str) - except json.JSONDecodeError: - continue - - # Fallback to simple text response - return {"result": llm_response} - - def _create_error_response(self, error_type: str, error_message: str, status_code: int = 400) -> Dict[str, Any]: - """Create standardized error response.""" - return { - "status": status_code, - "error": { - "type": error_type, - "title": self._get_error_title(status_code), - "detail": error_message - } - } - - def _get_error_title(self, status_code: int) -> str: - """Get error title based on status code.""" - error_titles = { - 400: 'Bad Request', - 401: 'Unauthorized', - 403: 'Forbidden', - 404: 'Not Found', - 429: 'Too Many Requests', - 500: 'Internal Server Error', - 503: 'Service Unavailable' - } - return error_titles.get(status_code, 'Error') + raise RuntimeError(f"Error generating simulated API response: {e}") def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolType) -> Dict[str, Any]: """Handle static mode simulation - returns predefined static response.""" @@ -556,56 +410,138 @@ def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolTy def _handle_mock_mode(self, registered_tool: RegisteredTool, input_data: Dict[str, Any], state_key: str, tool_type: ToolType) -> Dict[str, Any]: """Handle mock mode simulation - calls custom mock function.""" - if registered_tool.mock_function is not None: - try: - # Extract parameters based on tool type - if tool_type == ToolType.FUNCTION: - parameters = input_data.get("parameters", {}) - if isinstance(parameters, str): - parameters = json.loads(parameters) - - # Call mock function with extracted parameters - if "kwargs" in parameters: - result = registered_tool.mock_function(**parameters["kwargs"]) - elif "args" in parameters: - result = registered_tool.mock_function(*parameters["args"]) - else: - result = registered_tool.mock_function(**parameters) - - elif tool_type == ToolType.MCP: - input_mcp_payload = input_data.get("input_mcp_payload", {}) - result = registered_tool.mock_function(**input_mcp_payload) - - elif tool_type == ToolType.API: - user_input_api_payload = input_data.get("user_input_api_payload", {}) - 
result = registered_tool.mock_function(**user_input_api_payload) - + if registered_tool.mock_function is None: + raise ValueError("mock_function is required for tool simulator mock mode") + + try: + # Extract parameters based on tool type + if tool_type == ToolType.FUNCTION: + parameters = input_data.get("parameters", {}) + if isinstance(parameters, str): + parameters = json.loads(parameters) + + # Call mock function with extracted parameters + if "kwargs" in parameters: + result = registered_tool.mock_function(**parameters["kwargs"]) + elif "args" in parameters: + result = registered_tool.mock_function(*parameters["args"]) else: - return {"status": "error", "message": "Unsupported tool type for mock mode"} + result = registered_tool.mock_function(**parameters) - # Record the call in state registry - tool_name = registered_tool.name - if tool_type == ToolType.FUNCTION: - self._state_registry.record_function_call(tool_name, state_key, parameters, result) - elif tool_type == ToolType.MCP: - self._state_registry.record_mcp_tool_call(tool_name, state_key, input_mcp_payload, result) - elif tool_type == ToolType.API: - path = input_data.get("path", "") - method = input_data.get("method", "GET") - self._state_registry.record_api_call(tool_name, state_key, path, method, user_input_api_payload, result) + elif tool_type == ToolType.MCP: + input_mcp_payload = input_data.get("input_mcp_payload", {}) + result = registered_tool.mock_function(**input_mcp_payload) - return result + elif tool_type == ToolType.API: + user_input_api_payload = input_data.get("user_input_api_payload", {}) + result = registered_tool.mock_function(**user_input_api_payload) - except Exception as e: - logger.error(f"Error calling mock function for {registered_tool.name}: {e}") - if tool_type == ToolType.API: - return self._create_error_response("mock_error", str(e), 500) - else: - return {"status": "error", "error_type": "mock_error", "message": str(e)} + else: + return {"status": "error", "message": "Unsupported tool type for mock mode"} + + # Record the call in state registry + tool_name = registered_tool.name + if tool_type == ToolType.FUNCTION: + self._state_registry.record_tool_call( + tool_name, state_key, ToolType.FUNCTION, result, parameters=parameters + ) + elif tool_type == ToolType.MCP: + self._state_registry.record_tool_call( + tool_name, state_key, ToolType.MCP, result, input_mcp_payload=input_mcp_payload + ) + elif tool_type == ToolType.API: + path = input_data.get("path", "") + method = input_data.get("method", "GET") + self._state_registry.record_tool_call( + tool_name, state_key, ToolType.API, result, + path=path, method=method, input_data=user_input_api_payload + ) + + return result + + except Exception as e: + raise RuntimeError(f"Tool simulator mock mode error for {tool_type} tool {registered_tool.name}: {e}") + + def _create_tool_wrapper(self, registered_tool: RegisteredTool) -> Callable: + """Create a wrapper function for direct tool access.""" + def wrapper(*args, **kwargs): + # Determine state key + state_key = ( + registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) + if registered_tool.simulator_kwargs + else registered_tool.name + ) + + # Build input data based on tool type + if registered_tool.tool_type == ToolType.FUNCTION: + parameters_string = ( + json.dumps({"args": args, "kwargs": kwargs}, indent=2) + if args + else json.dumps(kwargs, indent=2) + ) + + input_data = { + "tool_name": registered_tool.name, + "parameters": parameters_string, + } + + elif 
registered_tool.tool_type == ToolType.MCP: + input_data = { + "tool_name": registered_tool.name, + "input_mcp_payload": kwargs, + } + + elif registered_tool.tool_type == ToolType.API: + input_data = { + "tool_name": registered_tool.name, + "user_input_api_payload": kwargs, + "path": registered_tool.api_path or "", + "method": registered_tool.api_method or "GET", + } + + else: + raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") + + return self._simulate_tool_call(registered_tool.tool_type, state_key, input_data) - # Fallback to static mode if no mock function provided - logger.warning(f"No mock function provided for {registered_tool.name}, falling back to static mode") - return self._handle_static_mode(registered_tool, tool_type) + # Copy function metadata + if registered_tool.function: + wrapper.__name__ = registered_tool.function.__name__ + try: + wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore + except (ValueError, TypeError): + pass + wrapper.__doc__ = registered_tool.function.__doc__ + else: + wrapper.__name__ = registered_tool.name + + return wrapper + + def __getattr__(self, name: str) -> Any: + """ + Allow direct access to registered tools as attributes. + + Args: + name: Tool name + + Returns: + Tool callable + + Raises: + AttributeError: If tool not found + """ + registered_tool = self._registered_tools.get(name) + if registered_tool: + return self._create_tool_wrapper(registered_tool) + + raise AttributeError(f"Tool '{name}' not found in registered tools") + + @classmethod + def clear_registry(cls): + """Clear all registered tools. Useful for testing.""" + cls._registered_tools.clear() + cls._state_registry = None + logger.info("Cleared tool registry") @classmethod def function_tool( @@ -651,8 +587,7 @@ def decorator(func: Callable) -> Callable: logger.info(f"Registered function tool: {tool_name}") except Exception as e: - logger.error(f"Error registering function tool {name or func.__name__}: {e}") - raise + raise RuntimeError(f"Error registering function tool {name or func.__name__}: {e}") return func @@ -784,61 +719,6 @@ def get_tool(self, tool_name: str) -> Optional[Callable]: return self._create_tool_wrapper(registered_tool) - def _create_tool_wrapper(self, registered_tool: RegisteredTool) -> Callable: - """Create a wrapper function for direct tool access.""" - def wrapper(*args, **kwargs): - # Determine state key - state_key = ( - registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) - if registered_tool.simulator_kwargs - else registered_tool.name - ) - - # Build input data based on tool type - if registered_tool.tool_type == ToolType.FUNCTION: - parameters_string = ( - json.dumps({"args": args, "kwargs": kwargs}, indent=2) - if args - else json.dumps(kwargs, indent=2) - ) - - input_data = { - "tool_name": registered_tool.name, - "parameters": parameters_string, - } - - elif registered_tool.tool_type == ToolType.MCP: - input_data = { - "tool_name": registered_tool.name, - "input_mcp_payload": kwargs, - } - - elif registered_tool.tool_type == ToolType.API: - input_data = { - "tool_name": registered_tool.name, - "user_input_api_payload": kwargs, - "path": registered_tool.api_path or "", - "method": registered_tool.api_method or "GET", - } - - else: - raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") - - return self._simulate_tool_call(registered_tool.tool_type, state_key, input_data) - - # Copy function metadata - if registered_tool.function: - wrapper.__name__ = 
registered_tool.function.__name__ - try: - wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore - except (ValueError, TypeError): - pass - wrapper.__doc__ = registered_tool.function.__doc__ - else: - wrapper.__name__ = registered_tool.name - - return wrapper - def list_tools(self) -> List[str]: """ List all registered tool names. @@ -847,29 +727,3 @@ def list_tools(self) -> List[str]: List of tool names """ return list(self._registered_tools.keys()) - - @classmethod - def clear_registry(cls): - """Clear all registered tools. Useful for testing.""" - cls._registered_tools.clear() - cls._state_registry = None - logger.info("Cleared tool registry") - - def __getattr__(self, name: str) -> Any: - """ - Allow direct access to registered tools as attributes. - - Args: - name: Tool name - - Returns: - Tool callable - - Raises: - AttributeError: If tool not found - """ - registered_tool = self._registered_tools.get(name) - if registered_tool: - return self._create_tool_wrapper(registered_tool) - - raise AttributeError(f"Tool '{name}' not found in registered tools") diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index 000b042..dceee92 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -44,5 +44,44 @@ class RegisteredTool(BaseModel): static_response: Optional[Dict[str, Any]] = Field(default=None, description="Static response for static mode") mock_function: Optional[Callable] = Field(default=None, description="Mock function for mock mode", exclude=True) - class Config: - arbitrary_types_allowed = True + model_config = {"arbitrary_types_allowed": True} + + +# Tool Response Models for Structured Output + +class MCPContentItem(BaseModel): + """Individual content item in MCP response.""" + type: str = Field(..., description="Type of content (text, resource, etc.)") + text: Optional[str] = Field(default=None, description="Text content") + resource: Optional[Dict[str, Any]] = Field(default=None, description="Resource information") + + +class MCPToolResponse(BaseModel): + """ + Response model for MCP tool simulation using structured output. + + Follows the MCP response format with content array and optional error flag. + """ + content: List[MCPContentItem] = Field(..., description="Array of content items") + isError: Optional[bool] = Field(default=False, description="Whether this response represents an error") + + +class APIErrorDetail(BaseModel): + """Error detail structure for API responses.""" + type: str = Field(..., description="Error type identifier") + title: str = Field(..., description="Human-readable error title") + detail: str = Field(..., description="Detailed error description") + + +class APIToolResponse(BaseModel): + """ + Response model for API tool simulation using structured output. + + Follows HTTP response format with status code and optional data or error. 
+ """ + status: int = Field(..., description="HTTP status code") + data: Optional[Any] = Field(default=None, description="Response data for successful requests") + error: Optional[APIErrorDetail] = Field(default=None, description="Error details for failed requests") + + # Allow additional fields for flexibility + model_config = {"extra": "allow"} diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index 847341e..b035855 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -49,17 +49,17 @@ def clear_registry(): def test_tool_simulator_init(): """Test ToolSimulator initialization with all parameters.""" custom_registry = StateRegistry() - template = "You are a helpful assistant simulating tools." simulator = ToolSimulator( state_registry=custom_registry, - system_prompt_template=template, model=None, ) assert simulator._state_registry is custom_registry - assert simulator.system_prompt_template == template - assert simulator.model is not None + assert simulator.model_id is None # model_id is now used instead of system_prompt_template + assert simulator.function_tool_prompt is not None # Check that prompt templates are loaded + assert simulator.mcp_tool_prompt is not None + assert simulator.api_tool_prompt is not None def test_function_tool_decorator_registration(): @@ -119,17 +119,24 @@ def test_func(message: str) -> dict: """Test function that should be simulated.""" pass - # Mock the structured_output method to return expected JSON - def mock_structured_output(output_type, messages, system_prompt=None): - yield {"output": '{"result": "simulated response"}'} - - mock_model.structured_output = mock_structured_output simulator = ToolSimulator(model=mock_model) - # Execute simulated function - result = simulator.test_function("Hello, world!") + # Mock the Agent constructor and its result to avoid real LLM calls + mock_agent_instance = MagicMock() + mock_result = MagicMock() + mock_result.structured_output = {"result": "simulated response"} + mock_agent_instance.return_value = mock_result - assert result == {"result": "simulated response"} + with pytest.MonkeyPatch().context() as m: + # Mock the Agent class constructor + from strands_evals.simulation.tool_simulator import Agent + m.setattr('strands_evals.simulation.tool_simulator.Agent', lambda **kwargs: mock_agent_instance) + + # Execute simulated function + result = simulator.test_function("Hello, world!") + + assert result == {"result": "simulated response"} + assert mock_agent_instance.called def test_mcp_tool_simulation(mock_model): @@ -141,17 +148,24 @@ def test_mcp(**params): """Test MCP tool that should be simulated.""" pass - # Mock the structured_output method to return expected JSON - def mock_structured_output(output_type, messages, system_prompt=None): - yield {"output": '{"content": [{"type": "text", "text": "MCP response"}]}'} - - mock_model.structured_output = mock_structured_output simulator = ToolSimulator(model=mock_model) - # Execute simulated MCP tool - result = simulator.test_mcp(param="test_value") + # Mock the Agent constructor and its result to avoid real LLM calls + mock_agent_instance = MagicMock() + mock_result = MagicMock() + mock_result.structured_output.model_dump.return_value = {"content": [{"type": "text", "text": "MCP response"}]} + mock_agent_instance.return_value = mock_result - assert result == {"content": [{"type": "text", "text": "MCP response"}]} + with 
pytest.MonkeyPatch().context() as m: + # Mock the Agent class constructor + from strands_evals.simulation.tool_simulator import Agent + m.setattr('strands_evals.simulation.tool_simulator.Agent', lambda **kwargs: mock_agent_instance) + + # Execute simulated MCP tool + result = simulator.test_mcp(param="test_value") + + assert result == {"content": [{"type": "text", "text": "MCP response"}]} + assert mock_agent_instance.called def test_api_tool_simulation(mock_model): @@ -162,17 +176,24 @@ def test_api(**kwargs): """Test API tool that should be simulated.""" pass - # Mock the structured_output method to return expected JSON - def mock_structured_output(output_type, messages, system_prompt=None): - yield {"output": '{"status": 200, "data": {"key": "value"}}'} - - mock_model.structured_output = mock_structured_output simulator = ToolSimulator(model=mock_model) - # Execute simulated API tool - result = simulator.test_api(key="value") + # Mock the Agent constructor and its result to avoid real LLM calls + mock_agent_instance = MagicMock() + mock_result = MagicMock() + mock_result.structured_output.model_dump.return_value = {"status": 200, "data": {"key": "value"}} + mock_agent_instance.return_value = mock_result - assert result == {"status": 200, "data": {"key": "value"}} + with pytest.MonkeyPatch().context() as m: + # Mock the Agent class constructor + from strands_evals.simulation.tool_simulator import Agent + m.setattr('strands_evals.simulation.tool_simulator.Agent', lambda **kwargs: mock_agent_instance) + + # Execute simulated API tool + result = simulator.test_api(key="value") + + assert result == {"status": 200, "data": {"key": "value"}} + assert mock_agent_instance.called def test_list_tools(): @@ -227,30 +248,57 @@ def get_transactions(**kwargs): """Get transaction history.""" pass - # Mock responses for each tool type based on call count - call_count = 0 - def mock_structured_output(output_type, messages, system_prompt=None): - nonlocal call_count - call_count += 1 - if call_count == 1: # First call (check_balance) - yield {"output": '{"balance": 1000, "currency": "USD"}'} - elif call_count == 2: # Second call (transfer_funds) - yield {"output": '{"content": [{"type": "text", "text": "Transfer completed"}]}'} - elif call_count == 3: # Third call (get_transactions) - yield {"output": '{"status": 200, "data": {"transactions": []}}'} - - mock_model.structured_output = mock_structured_output simulator = ToolSimulator(model=mock_model) - # Execute each tool in order - balance_result = simulator.check_balance("12345") - transfer_result = simulator.transfer_funds(from_account="12345", to_account="67890") - transactions_result = simulator.get_transactions(account_id="12345") - - # Verify results - assert balance_result == {"balance": 1000, "currency": "USD"} - assert transfer_result == {"content": [{"type": "text", "text": "Transfer completed"}]} - assert transactions_result == {"status": 200, "data": {"transactions": []}} + # Mock the Agent constructor to avoid real LLM calls + mock_agent_instances = [] + expected_responses = [ + {"balance": 1000, "currency": "USD"}, # Function response + {"content": [{"type": "text", "text": "Transfer completed"}]}, # MCP response + {"status": 200, "data": {"transactions": []}} # API response + ] + + def create_mock_agent(**kwargs): + mock_agent = MagicMock() + mock_result = MagicMock() + + if len(mock_agent_instances) < len(expected_responses): + response = expected_responses[len(mock_agent_instances)] + if 'content' in response: + # MCP response needs 
.model_dump() + mock_result.structured_output.model_dump.return_value = response + else: + # Function and API responses - function uses dict directly, API uses .model_dump() + if 'balance' in response: + # Function response - use direct structured_output + mock_result.structured_output = response + else: + # API response - use .model_dump() + mock_result.structured_output.model_dump.return_value = response + + mock_agent.return_value = mock_result + mock_agent_instances.append(mock_agent) + return mock_agent + + with pytest.MonkeyPatch().context() as m: + # Mock the Agent class constructor + from strands_evals.simulation.tool_simulator import Agent + m.setattr('strands_evals.simulation.tool_simulator.Agent', create_mock_agent) + + # Execute each tool in order + balance_result = simulator.check_balance("12345") + transfer_result = simulator.transfer_funds(from_account="12345", to_account="67890") + transactions_result = simulator.get_transactions(account_id="12345") + + # Verify results + assert balance_result == {"balance": 1000, "currency": "USD"} + assert transfer_result == {"content": [{"type": "text", "text": "Transfer completed"}]} + assert transactions_result == {"status": 200, "data": {"transactions": []}} + + # Verify all agents were called + assert len(mock_agent_instances) == 3 + for agent in mock_agent_instances: + assert agent.called # Verify all tools accessed the same shared state shared_state = simulator._state_registry.get_state(shared_state_id) @@ -277,15 +325,16 @@ def mock_structured_output(output_type, messages, system_prompt=None): assert "method" in api_call -def test_record_function_call(): - """Test recording function call in state registry.""" +def test_record_tool_call_function(): + """Test recording function call in state registry using unified method.""" registry = StateRegistry() - registry.record_function_call( + registry.record_tool_call( tool_name="test_tool", state_key="test_state", - parameters={"param": "value"}, - response_data={"result": "success"} + tool_type=ToolType.FUNCTION, + response_data={"result": "success"}, + parameters={"param": "value"} ) state = registry.get_state("test_state") @@ -293,19 +342,21 @@ def test_record_function_call(): assert len(state["previous_calls"]) == 1 call = state["previous_calls"][0] assert call["tool_name"] == "test_tool" + assert call["tool_type"] == "function" assert call["parameters"] == {"param": "value"} assert call["response"] == {"result": "success"} -def test_record_mcp_tool_call(): - """Test recording MCP tool call in state registry.""" +def test_record_tool_call_mcp(): + """Test recording MCP tool call in state registry using unified method.""" registry = StateRegistry() - registry.record_mcp_tool_call( + registry.record_tool_call( tool_name="mcp_tool", state_key="mcp_state", - input_mcp_payload={"input": "data"}, - response_data={"content": [{"type": "text", "text": "result"}]} + tool_type=ToolType.MCP, + response_data={"content": [{"type": "text", "text": "result"}]}, + input_mcp_payload={"input": "data"} ) state = registry.get_state("mcp_state") @@ -313,20 +364,23 @@ def test_record_mcp_tool_call(): assert len(state["previous_calls"]) == 1 call = state["previous_calls"][0] assert call["tool_name"] == "mcp_tool" + assert call["tool_type"] == "mcp" assert call["input_mcp_payload"] == {"input": "data"} + assert call["response"] == {"content": [{"type": "text", "text": "result"}]} -def test_record_api_call(): - """Test recording API call in state registry.""" +def test_record_tool_call_api(): + """Test 
recording API call in state registry using unified method.""" registry = StateRegistry() - registry.record_api_call( + registry.record_tool_call( tool_name="api_tool", state_key="api_state", + tool_type=ToolType.API, + response_data={"status": 200}, path="/test", method="POST", - input_data={"data": "test"}, - response={"status": 200} + input_data={"data": "test"} ) state = registry.get_state("api_state") @@ -334,59 +388,60 @@ def test_record_api_call(): assert len(state["previous_calls"]) == 1 call = state["previous_calls"][0] assert call["tool_name"] == "api_tool" + assert call["tool_type"] == "api" assert call["path"] == "/test" assert call["method"] == "POST" assert call["input"] == {"data": "test"} + assert call["response"] == {"status": 200} -def test_parse_llm_response_valid_json(): - """Test parsing valid JSON response.""" +def test_tool_not_found_raises_error(): + """Test that accessing non-existent tools raises ValueError.""" simulator = ToolSimulator() - response = simulator._parse_llm_response('{"key": "value"}') + # Test that accessing a non-existent tool via _simulate_tool_call raises ValueError + with pytest.raises(ValueError) as excinfo: + simulator._simulate_tool_call( + tool_type=ToolType.FUNCTION, + state_key="test", + input_data={"tool_name": "nonexistent_tool"} + ) - assert response == {"key": "value"} + assert "not registered" in str(excinfo.value) -def test_parse_llm_response_json_in_code_block(): - """Test parsing JSON from code blocks.""" +def test_api_tool_missing_name_raises_error(): + """Test that API tool simulation raises ValueError when tool_name is missing.""" simulator = ToolSimulator() - llm_text = '```json\n{"key": "value"}\n```' - response = simulator._parse_llm_response(llm_text) + with pytest.raises(ValueError) as excinfo: + simulator._handle_api_tool( + input_data={"tool_name": ""}, # Empty tool name + state_key="test" + ) - assert response == {"key": "value"} + assert "tool_name is required for API tool simulation" in str(excinfo.value) -def test_parse_llm_response_invalid_json_fallback(): - """Test fallback for invalid JSON.""" - simulator = ToolSimulator() - - response = simulator._parse_llm_response("This is not JSON") +def test_mock_mode_missing_function_raises_error(): + """Test that mock mode raises ValueError when mock_function is missing.""" + # Register a tool without mock_function but with mock mode + @ToolSimulator.function_tool("test_mock_tool", mode="mock") + def test_mock_tool(): + pass - assert response == {"result": "This is not JSON"} - - -def test_create_error_response(): - """Test error response creation.""" simulator = ToolSimulator() + registered_tool = ToolSimulator._registered_tools["test_mock_tool"] - error = simulator._create_error_response("test_error", "Test message", 400) - - assert error["status"] == 400 - assert error["error"]["type"] == "test_error" - assert error["error"]["detail"] == "Test message" - assert error["error"]["title"] == "Bad Request" - - -def test_get_error_title(): - """Test error title mapping.""" - simulator = ToolSimulator() + with pytest.raises(ValueError) as excinfo: + simulator._handle_mock_mode( + registered_tool=registered_tool, + input_data={"tool_name": "test_mock_tool", "parameters": {}}, + state_key="test", + tool_type=ToolType.FUNCTION + ) - assert simulator._get_error_title(400) == "Bad Request" - assert simulator._get_error_title(404) == "Not Found" - assert simulator._get_error_title(500) == "Internal Server Error" - assert simulator._get_error_title(999) == "Error" # Unknown status code 
+ assert "mock_function is required for tool simulator mock mode" in str(excinfo.value) def test_clear_registry(): From 15d3fcd81a11dcede7b4660190938c3363e17ba4 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Wed, 28 Jan 2026 01:35:36 +0000 Subject: [PATCH 07/15] refactor and address mypy errors --- src/strands_evals/simulation/__init__.py | 2 +- .../tool_response_generation.py | 9 +- .../simulation/tool_simulator.py | 406 ++++++++++-------- src/strands_evals/types/simulation/tool.py | 26 +- .../simulation/test_tool_simulator.py | 276 ++++++------ 5 files changed, 372 insertions(+), 347 deletions(-) diff --git a/src/strands_evals/simulation/__init__.py b/src/strands_evals/simulation/__init__.py index 98c7593..3097b0d 100644 --- a/src/strands_evals/simulation/__init__.py +++ b/src/strands_evals/simulation/__init__.py @@ -5,7 +5,7 @@ UserSimulator = ActorSimulator __all__ = [ - "ActorSimulator", + "ActorSimulator", "UserSimulator", "ToolSimulator", ] diff --git a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py index 363ef64..8b5aa5c 100644 --- a/src/strands_evals/simulation/prompt_templates/tool_response_generation.py +++ b/src/strands_evals/simulation/prompt_templates/tool_response_generation.py @@ -11,7 +11,8 @@ FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT = dedent( """ -You are simulating a function tool call for agent evaluation. Generate a realistic response based on the function name, parameters, and context. +You are simulating a function tool call for agent evaluation. Generate a realistic response based on the function name, +parameters, and context. ## Function Tool Information Tool Name: {tool_name} @@ -66,7 +67,8 @@ ) MCP_TOOL_RESPONSE_GENERATION_PROMPT = dedent( - """You are simulating an MCP (Model Context Protocol) tool call for agent evaluation. Generate a realistic response based on the tool name, input payload, and context. + """You are simulating an MCP (Model Context Protocol) tool call for agent evaluation. Generate a realistic response +based on the tool name, input payload, and context. ## MCP Tool Information Tool Name: {tool_name} @@ -138,7 +140,8 @@ ) API_TOOL_RESPONSE_GENERATION_PROMPT = dedent( - """You are simulating an API tool call for agent evaluation. Generate a realistic HTTP response based on the API endpoint, method, payload, and context. + """You are simulating an API tool call for agent evaluation. Generate a realistic HTTP response based on the API +endpoint, method, payload, and context. ## API Tool Information Tool Name: {tool_name} diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index 0454cd2..f1bde51 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -9,10 +9,10 @@ from strands.models.model import Model from strands_evals.types.simulation.tool import ( - RegisteredTool, - ToolType, - MCPToolResponse, APIToolResponse, + MCPToolResponse, + RegisteredTool, + ToolType, ) logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class StateRegistry: State registry for managing shared state between tool simulators. Organized by state_key to isolate state between different tools or shared state groups. """ - + def __init__(self): """ Initialize state registry. @@ -32,7 +32,7 @@ def __init__(self): across different simulation sessions. 
""" self._states: Dict[str, Dict[str, Any]] = {} - + def initialize_state_via_description(self, initial_state_description: str, state_key: str) -> None: """ Initialize state based on the provided description. @@ -53,7 +53,9 @@ def initialize_state_via_description(self, initial_state_description: str, state "user_context": {}, } else: - warnings.warn(f"State with key '{state_key}' already initialized. Skipping re-initialization.") + warnings.warn( + f"State with key '{state_key}' already initialized. Skipping re-initialization.", stacklevel=2 + ) def get_state(self, state_key: str) -> Dict[str, Any]: """ @@ -102,23 +104,25 @@ def record_tool_call( # Build call record based on tool type call_record = { - 'tool_name': tool_name, - 'tool_type': tool_type.value, - 'response': response_data, - 'timestamp': timestamp + "tool_name": tool_name, + "tool_type": tool_type.value, + "response": response_data, + "timestamp": timestamp, } - + # Add tool-specific fields if tool_type == ToolType.FUNCTION: - call_record['parameters'] = call_data.get('parameters', {}) + call_record["parameters"] = call_data.get("parameters", {}) elif tool_type == ToolType.MCP: - call_record['input_mcp_payload'] = call_data.get('input_mcp_payload', {}) + call_record["input_mcp_payload"] = call_data.get("input_mcp_payload", {}) elif tool_type == ToolType.API: - call_record.update({ - 'path': call_data.get('path', ''), - 'method': call_data.get('method', 'GET'), - 'input': call_data.get('input_data', {}) - }) + call_record.update( + { + "path": call_data.get("path", ""), + "method": call_data.get("method", "GET"), + "input": call_data.get("input_data", {}), + } + ) state["previous_calls"].append(call_record) @@ -180,22 +184,25 @@ def __init__( """ # Store model configuration for creating internal agents self.model_id = model - + # Set custom prompts or use defaults if function_tool_prompt is None: from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT + self.function_tool_prompt = FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT else: self.function_tool_prompt = function_tool_prompt - + if mcp_tool_prompt is None: from .prompt_templates.tool_response_generation import MCP_TOOL_RESPONSE_GENERATION_PROMPT + self.mcp_tool_prompt = MCP_TOOL_RESPONSE_GENERATION_PROMPT else: self.mcp_tool_prompt = mcp_tool_prompt - + if api_tool_prompt is None: from .prompt_templates.tool_response_generation import API_TOOL_RESPONSE_GENERATION_PROMPT + self.api_tool_prompt = API_TOOL_RESPONSE_GENERATION_PROMPT else: self.api_tool_prompt = api_tool_prompt @@ -209,7 +216,6 @@ def __init__( # Initialize shared states from registered tools self._initialize_shared_states() - def _initialize_shared_states(self): """Initialize shared states from registered tools' initial descriptions.""" for tool_name, registered_tool in self._registered_tools.items(): @@ -220,23 +226,94 @@ def _initialize_shared_states(self): if registered_tool.simulator_kwargs else registered_tool.name ) - + # Initialize state with description self._state_registry.initialize_state_via_description( - registered_tool.initial_state_description, - state_key + registered_tool.initial_state_description, state_key ) logger.info(f"Initialized state for tool '{tool_name}' with key '{state_key}'") + def __getattr__(self, name: str) -> Any: + """ + Allow direct access to registered tools as attributes. 
+ + Args: + name: Tool name + + Returns: + Tool callable + + Raises: + AttributeError: If tool not found + """ + registered_tool = self._registered_tools.get(name) + if registered_tool: + return self._create_tool_wrapper(registered_tool) + + raise AttributeError(f"Tool '{name}' not found in registered tools") + + def _create_tool_wrapper(self, registered_tool: RegisteredTool) -> Callable: + """Create a wrapper function for direct tool access.""" + + def wrapper(*args, **kwargs): + # Determine state key + state_key = ( + registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) + if registered_tool.simulator_kwargs + else registered_tool.name + ) + + # Build input data based on tool type + if registered_tool.tool_type == ToolType.FUNCTION: + parameters_string = ( + json.dumps({"args": args, "kwargs": kwargs}, indent=2) if args else json.dumps(kwargs, indent=2) + ) + + input_data = { + "tool_name": registered_tool.name, + "parameters": parameters_string, + } + + elif registered_tool.tool_type == ToolType.MCP: + input_data = { + "tool_name": registered_tool.name, + "input_mcp_payload": kwargs, + } + + elif registered_tool.tool_type == ToolType.API: + input_data = { + "tool_name": registered_tool.name, + "user_input_api_payload": kwargs, + "path": registered_tool.api_path or "", + "method": registered_tool.api_method or "GET", + } + + else: + raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") + + return self._simulate_tool_call(registered_tool.tool_type, state_key, input_data) + + # Copy function metadata + if registered_tool.function: + wrapper.__name__ = registered_tool.function.__name__ + try: + wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore + except (ValueError, TypeError): + pass + wrapper.__doc__ = registered_tool.function.__doc__ + else: + wrapper.__name__ = registered_tool.name + + return wrapper def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: Dict[str, Any]) -> Any: """Simulate a tool invocation and return the response.""" tool_name = input_data.get("tool_name", "") registered_tool = self._registered_tools.get(tool_name) - + if not registered_tool: raise ValueError(f"Tool '{tool_name}' not registered") - + # Handle different simulation modes if registered_tool.mode == "static": return self._handle_static_mode(registered_tool, tool_type) @@ -254,73 +331,82 @@ def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: D raise ValueError(f"Tool type '{tool_type}' not supported") else: raise ValueError(f"Tool simulation mode '{registered_tool.mode}' not supported") - + def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: """Handle function tool simulation.""" tool_name = input_data.get("tool_name", "") parameters = input_data.get("parameters", {}) - + if not tool_name: - return {"status": "error", "error_type": "missing_tool_name", "message": "Tool name is required"} - - # Generate response using structured output + raise ValueError("Tool name is required") + + if not self._state_registry: + raise RuntimeError("State registry is not initialized") + try: # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) initial_state_description = current_state.get("initial_state", "No initial state provided.") - + prompt = self.function_tool_prompt.format( tool_name=tool_name, parameters=json.dumps(parameters, indent=2) if parameters else "{}", 
initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2) or "{}" + previous_responses=json.dumps(current_state, indent=2) or "{}", ) - + # Create agent and generate response with structured output - # Use dict for function responses since they vary based on function signature agent = Agent( system_prompt=self.function_tool_prompt, tools=[], model=self.model_id, callback_handler=None, ) - result = agent(prompt, structured_output_model=dict) - response_data = result.structured_output - + result = agent(prompt, structured_output_model=None) + + # Parse JSON response for function tools since they vary based on function signature + if result.response and isinstance(result.response, str): + try: + response_data = json.loads(result.response) + except json.JSONDecodeError: + response_data = {"result": result.response} + else: + response_data = {"result": str(result.response) if result.response else "No response"} + # Record the call self._state_registry.record_tool_call( tool_name, state_key, ToolType.FUNCTION, response_data, parameters=parameters ) - + return response_data - + except Exception as e: logger.error(f"Error generating function response: {e}") - return {"status": "error", "error_type": "generation_error", "message": str(e)} - + raise RuntimeError(f"Error generating function response: {e}") from e + def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: """Handle MCP tool simulation.""" tool_name = input_data.get("tool_name", "") input_mcp_payload = input_data.get("input_mcp_payload", {}) - + if not tool_name: - return { - "isError": True, - "content": [{"type": "text", "text": "Tool name is required"}] - } - + return {"isError": True, "content": [{"type": "text", "text": "Tool name is required"}]} + + if not self._state_registry: + raise RuntimeError("State registry is not initialized") + try: # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) initial_state_description = current_state.get("initial_state", "No initial state provided.") - + prompt = self.mcp_tool_prompt.format( tool_name=tool_name, mcp_payload=json.dumps(input_mcp_payload, indent=2) if input_mcp_payload else "{}", initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2) or "{}" + previous_responses=json.dumps(current_state, indent=2) or "{}", ) - + # Create agent and generate response with structured output agent = Agent( system_prompt=self.mcp_tool_prompt, @@ -329,46 +415,52 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s callback_handler=None, ) result = agent(prompt, structured_output_model=MCPToolResponse) - response_data = result.structured_output.model_dump() - + if result.structured_output: + response_data = result.structured_output.model_dump() + else: + response_data = { + "isError": True, + "content": [{"type": "text", "text": "No structured output received"}], + } + # Record the call self._state_registry.record_tool_call( tool_name, state_key, ToolType.MCP, response_data, input_mcp_payload=input_mcp_payload ) - + return response_data - + except Exception as e: logger.error(f"Error generating MCP response: {e}") - return { - "isError": True, - "content": [{"type": "text", "text": f"Error generating response: {str(e)}"}] - } - + return {"isError": True, "content": [{"type": "text", "text": f"Error generating response: {str(e)}"}]} + def _handle_api_tool(self, input_data: 
Dict[str, Any], state_key: str) -> Dict[str, Any]: """Handle API tool simulation.""" tool_name = input_data.get("tool_name", "") user_input_api_payload = input_data.get("user_input_api_payload", {}) path = input_data.get("path", "") method = input_data.get("method", "GET") - + if not tool_name: raise ValueError("tool_name is required for API tool simulation") - + + if not self._state_registry: + raise RuntimeError("State registry is not initialized") + try: # Get initial state description from state registry current_state = self._state_registry.get_state(state_key) initial_state_description = current_state.get("initial_state", "No initial state provided.") - + prompt = self.api_tool_prompt.format( tool_name=tool_name, path=path, method=method, api_payload=json.dumps(user_input_api_payload, indent=2) if user_input_api_payload else "{}", initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2) or "{}" + previous_responses=json.dumps(current_state, indent=2) or "{}", ) - + # Create agent and generate response with structured output agent = Agent( system_prompt=self.api_tool_prompt, @@ -377,38 +469,44 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s callback_handler=None, ) result = agent(prompt, structured_output_model=APIToolResponse) - response_data = result.structured_output.model_dump() - + if result.structured_output: + response_data = result.structured_output.model_dump() + else: + response_data = { + "status": 500, + "error": { + "type": "internal_error", + "title": "Internal Error", + "detail": "No structured output received", + }, + } + # Record the call self._state_registry.record_tool_call( - tool_name, state_key, ToolType.API, response_data, - path=path, method=method, input_data=user_input_api_payload + tool_name, + state_key, + ToolType.API, + response_data, + path=path, + method=method, + input_data=user_input_api_payload, ) - + return response_data - + except Exception as e: - raise RuntimeError(f"Error generating simulated API response: {e}") + raise RuntimeError(f"Error generating simulated API response: {e}") from e def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolType) -> Dict[str, Any]: """Handle static mode simulation - returns predefined static response.""" - if registered_tool.static_response is not None: - return registered_tool.static_response - - # Default static responses for different tool types - if tool_type == ToolType.FUNCTION: - return {"status": "success", "result": f"Static response from {registered_tool.name}"} - elif tool_type == ToolType.MCP: - return { - "isError": False, - "content": [{"type": "text", "text": f"Static response from {registered_tool.name}"}] - } - elif tool_type == ToolType.API: - return {"status": 200, "data": {"message": f"Static response from {registered_tool.name}"}} - else: - return {"status": "error", "message": "Unsupported tool type for static mode"} + if registered_tool.static_response is None: + raise ValueError(f"Static response is required for tool '{registered_tool.name}' in static mode") - def _handle_mock_mode(self, registered_tool: RegisteredTool, input_data: Dict[str, Any], state_key: str, tool_type: ToolType) -> Dict[str, Any]: + return registered_tool.static_response + + def _handle_mock_mode( + self, registered_tool: RegisteredTool, input_data: Dict[str, Any], state_key: str, tool_type: ToolType + ) -> Dict[str, Any]: """Handle mock mode simulation - calls custom mock function.""" if 
registered_tool.mock_function is None: raise ValueError("mock_function is required for tool simulator mock mode") @@ -419,7 +517,7 @@ def _handle_mock_mode(self, registered_tool: RegisteredTool, input_data: Dict[st parameters = input_data.get("parameters", {}) if isinstance(parameters, str): parameters = json.loads(parameters) - + # Call mock function with extracted parameters if "kwargs" in parameters: result = registered_tool.mock_function(**parameters["kwargs"]) @@ -427,19 +525,22 @@ def _handle_mock_mode(self, registered_tool: RegisteredTool, input_data: Dict[st result = registered_tool.mock_function(*parameters["args"]) else: result = registered_tool.mock_function(**parameters) - + elif tool_type == ToolType.MCP: input_mcp_payload = input_data.get("input_mcp_payload", {}) result = registered_tool.mock_function(**input_mcp_payload) - + elif tool_type == ToolType.API: user_input_api_payload = input_data.get("user_input_api_payload", {}) result = registered_tool.mock_function(**user_input_api_payload) - + else: - return {"status": "error", "message": "Unsupported tool type for mock mode"} - + raise ValueError(f"Unsupported tool type '{tool_type}' for mock mode") + # Record the call in state registry + if not self._state_registry: + raise RuntimeError("State registry is not initialized") + tool_name = registered_tool.name if tool_type == ToolType.FUNCTION: self._state_registry.record_tool_call( @@ -453,88 +554,21 @@ def _handle_mock_mode(self, registered_tool: RegisteredTool, input_data: Dict[st path = input_data.get("path", "") method = input_data.get("method", "GET") self._state_registry.record_tool_call( - tool_name, state_key, ToolType.API, result, - path=path, method=method, input_data=user_input_api_payload + tool_name, + state_key, + ToolType.API, + result, + path=path, + method=method, + input_data=user_input_api_payload, ) - - return result - - except Exception as e: - raise RuntimeError(f"Tool simulator mock mode error for {tool_type} tool {registered_tool.name}: {e}") - def _create_tool_wrapper(self, registered_tool: RegisteredTool) -> Callable: - """Create a wrapper function for direct tool access.""" - def wrapper(*args, **kwargs): - # Determine state key - state_key = ( - registered_tool.simulator_kwargs.get("share_state_id", registered_tool.name) - if registered_tool.simulator_kwargs - else registered_tool.name - ) - - # Build input data based on tool type - if registered_tool.tool_type == ToolType.FUNCTION: - parameters_string = ( - json.dumps({"args": args, "kwargs": kwargs}, indent=2) - if args - else json.dumps(kwargs, indent=2) - ) - - input_data = { - "tool_name": registered_tool.name, - "parameters": parameters_string, - } - - elif registered_tool.tool_type == ToolType.MCP: - input_data = { - "tool_name": registered_tool.name, - "input_mcp_payload": kwargs, - } - - elif registered_tool.tool_type == ToolType.API: - input_data = { - "tool_name": registered_tool.name, - "user_input_api_payload": kwargs, - "path": registered_tool.api_path or "", - "method": registered_tool.api_method or "GET", - } - - else: - raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") - - return self._simulate_tool_call(registered_tool.tool_type, state_key, input_data) - - # Copy function metadata - if registered_tool.function: - wrapper.__name__ = registered_tool.function.__name__ - try: - wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore - except (ValueError, TypeError): - pass - wrapper.__doc__ = registered_tool.function.__doc__ - else: - 
wrapper.__name__ = registered_tool.name - - return wrapper - - def __getattr__(self, name: str) -> Any: - """ - Allow direct access to registered tools as attributes. - - Args: - name: Tool name - - Returns: - Tool callable - - Raises: - AttributeError: If tool not found - """ - registered_tool = self._registered_tools.get(name) - if registered_tool: - return self._create_tool_wrapper(registered_tool) + return result - raise AttributeError(f"Tool '{name}' not found in registered tools") + except Exception as e: + raise RuntimeError( + f"Tool simulator mock mode error for {tool_type} tool {registered_tool.name}: {e}" + ) from e @classmethod def clear_registry(cls): @@ -545,13 +579,13 @@ def clear_registry(cls): @classmethod def function_tool( - cls, - name: Optional[str] = None, - initial_state_description: Optional[str] = None, + cls, + name: Optional[str] = None, + initial_state_description: Optional[str] = None, mode: str = "dynamic", static_response: Optional[Dict[str, Any]] = None, mock_function: Optional[Callable] = None, - **simulator_kwargs + **simulator_kwargs, ) -> Callable: """ Decorator for registering Python function tools. @@ -567,6 +601,7 @@ def function_tool( Returns: Decorator function """ + def decorator(func: Callable) -> Callable: try: tool_name = name or func.__name__ @@ -587,7 +622,7 @@ def decorator(func: Callable) -> Callable: logger.info(f"Registered function tool: {tool_name}") except Exception as e: - raise RuntimeError(f"Error registering function tool {name or func.__name__}: {e}") + raise RuntimeError(f"Error registering function tool {name or func.__name__}: {e}") from e return func @@ -595,14 +630,14 @@ def decorator(func: Callable) -> Callable: @classmethod def mcp_tool( - cls, - name: Optional[str] = None, - schema: Optional[Dict[str, Any]] = None, + cls, + name: Optional[str] = None, + schema: Optional[Dict[str, Any]] = None, initial_state_description: Optional[str] = None, mode: str = "dynamic", static_response: Optional[Dict[str, Any]] = None, mock_function: Optional[Callable] = None, - **simulator_kwargs + **simulator_kwargs, ) -> Callable: """ Decorator for registering MCP (Model Context Protocol) tools. @@ -619,6 +654,7 @@ def mcp_tool( Returns: Decorator function """ + def decorator(func: Callable) -> Callable: tool_name = name or func.__name__ @@ -674,6 +710,7 @@ def api_tool( Returns: Decorator function """ + def decorator(func: Callable) -> Callable: tool_name = name or func.__name__ @@ -702,7 +739,6 @@ def decorator(func: Callable) -> Callable: return decorator - def get_tool(self, tool_name: str) -> Optional[Callable]: """ Get a tool by name and create a simulation wrapper. @@ -716,7 +752,7 @@ def get_tool(self, tool_name: str) -> Optional[Callable]: registered_tool = self._registered_tools.get(tool_name) if not registered_tool: return None - + return self._create_tool_wrapper(registered_tool) def list_tools(self) -> List[str]: diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index dceee92..d1e7449 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional from pydantic import BaseModel, Field @@ -7,12 +7,13 @@ class ToolType(Enum): """ Enumeration of supported tool types for simulation. - + Attributes: FUNCTION: Python function tools that can be called directly. 
MCP: Model Context Protocol tools with structured schemas. API: REST API endpoints with HTTP methods and paths. """ + FUNCTION = "function" MCP = "mcp" API = "api" @@ -32,14 +33,19 @@ class RegisteredTool(BaseModel): initial_state_description: Initial state description for the tool's context. simulator_kwargs: Additional simulator configuration parameters. """ + name: str = Field(..., description="Name of the tool") tool_type: ToolType = Field(..., description="Type of the tool") function: Optional[Callable] = Field(default=None, description="Function callable", exclude=True) mcp_schema: Optional[Dict[str, Any]] = Field(default=None, description="MCP tool schema") api_path: Optional[str] = Field(default=None, description="API endpoint path") api_method: Optional[str] = Field(default=None, description="HTTP method") - initial_state_description: Optional[str] = Field(default=None, description="Initial state description for the tool's context") - simulator_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional simulator configuration") + initial_state_description: Optional[str] = Field( + default=None, description="Initial state description for the tool's context" + ) + simulator_kwargs: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Additional simulator configuration" + ) mode: str = Field(default="dynamic", description="Simulation mode: dynamic, static, mock") static_response: Optional[Dict[str, Any]] = Field(default=None, description="Static response for static mode") mock_function: Optional[Callable] = Field(default=None, description="Mock function for mock mode", exclude=True) @@ -47,10 +53,9 @@ class RegisteredTool(BaseModel): model_config = {"arbitrary_types_allowed": True} -# Tool Response Models for Structured Output - class MCPContentItem(BaseModel): """Individual content item in MCP response.""" + type: str = Field(..., description="Type of content (text, resource, etc.)") text: Optional[str] = Field(default=None, description="Text content") resource: Optional[Dict[str, Any]] = Field(default=None, description="Resource information") @@ -59,15 +64,17 @@ class MCPContentItem(BaseModel): class MCPToolResponse(BaseModel): """ Response model for MCP tool simulation using structured output. - + Follows the MCP response format with content array and optional error flag. """ + content: List[MCPContentItem] = Field(..., description="Array of content items") isError: Optional[bool] = Field(default=False, description="Whether this response represents an error") class APIErrorDetail(BaseModel): """Error detail structure for API responses.""" + type: str = Field(..., description="Error type identifier") title: str = Field(..., description="Human-readable error title") detail: str = Field(..., description="Detailed error description") @@ -76,12 +83,13 @@ class APIErrorDetail(BaseModel): class APIToolResponse(BaseModel): """ Response model for API tool simulation using structured output. - + Follows HTTP response format with status code and optional data or error. 
""" + status: int = Field(..., description="HTTP status code") data: Optional[Any] = Field(default=None, description="Response data for successful requests") error: Optional[APIErrorDetail] = Field(default=None, description="Error details for failed requests") - + # Allow additional fields for flexibility model_config = {"extra": "allow"} diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index b035855..1c79665 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -7,7 +7,7 @@ from strands import tool from strands_evals.case import Case -from strands_evals.simulation.tool_simulator import ToolSimulator, StateRegistry +from strands_evals.simulation.tool_simulator import StateRegistry, ToolSimulator from strands_evals.types.simulation.tool import ToolType @@ -24,16 +24,12 @@ def sample_case(): def mock_model(): """Fixture providing a mock model for testing.""" mock = MagicMock() - + # Mock the structured_output method def mock_structured_output(output_type, messages, system_prompt=None): # Simulate streaming response - yield { - "contentBlockDelta": { - "text": '{"result": "mocked response"}' - } - } - + yield {"contentBlockDelta": {"text": '{"result": "mocked response"}'}} + mock.structured_output = mock_structured_output return mock @@ -49,12 +45,12 @@ def clear_registry(): def test_tool_simulator_init(): """Test ToolSimulator initialization with all parameters.""" custom_registry = StateRegistry() - + simulator = ToolSimulator( state_registry=custom_registry, model=None, ) - + assert simulator._state_registry is custom_registry assert simulator.model_id is None # model_id is now used instead of system_prompt_template assert simulator.function_tool_prompt is not None # Check that prompt templates are loaded @@ -64,6 +60,7 @@ def test_tool_simulator_init(): def test_function_tool_decorator_registration(): """Test function tool decorator registration.""" + @ToolSimulator.function_tool() def test_function(x: int, y: str) -> dict: """A sample function for testing.""" @@ -78,11 +75,7 @@ def test_function(x: int, y: str) -> dict: def test_mcp_tool_decorator_registration(): """Test MCP tool decorator registration.""" - schema = { - "type": "object", - "properties": {"param": {"type": "string"}}, - "required": ["param"] - } + schema = {"type": "object", "properties": {"param": {"type": "string"}}, "required": ["param"]} @ToolSimulator.mcp_tool("test_mcp", schema=schema) def sample_mcp_tool(**params): @@ -98,6 +91,7 @@ def sample_mcp_tool(**params): def test_api_tool_decorator_registration(): """Test API tool decorator registration.""" + @ToolSimulator.api_tool("test_api", path="/test", method="POST") def sample_api_tool(**kwargs): """A sample API tool for testing.""" @@ -113,28 +107,28 @@ def sample_api_tool(**kwargs): def test_function_tool_simulation(mock_model): """Test function tool simulation.""" + # Register and create simulator with mock model @ToolSimulator.function_tool("test_function") def test_func(message: str) -> dict: """Test function that should be simulated.""" pass - + simulator = ToolSimulator(model=mock_model) - + # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() mock_result.structured_output = {"result": "simulated response"} mock_agent_instance.return_value = mock_result - + with pytest.MonkeyPatch().context() as m: # Mock the Agent class constructor - 
from strands_evals.simulation.tool_simulator import Agent - m.setattr('strands_evals.simulation.tool_simulator.Agent', lambda **kwargs: mock_agent_instance) - + m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance) + # Execute simulated function result = simulator.test_function("Hello, world!") - + assert result == {"result": "simulated response"} assert mock_agent_instance.called @@ -143,72 +137,73 @@ def test_mcp_tool_simulation(mock_model): """Test MCP tool simulation.""" # Register and create simulator with mock model schema = {"type": "object", "properties": {"param": {"type": "string"}}} + @ToolSimulator.mcp_tool("test_mcp", schema=schema) def test_mcp(**params): """Test MCP tool that should be simulated.""" pass - + simulator = ToolSimulator(model=mock_model) - + # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() mock_result.structured_output.model_dump.return_value = {"content": [{"type": "text", "text": "MCP response"}]} mock_agent_instance.return_value = mock_result - + with pytest.MonkeyPatch().context() as m: # Mock the Agent class constructor - from strands_evals.simulation.tool_simulator import Agent - m.setattr('strands_evals.simulation.tool_simulator.Agent', lambda **kwargs: mock_agent_instance) - + m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance) + # Execute simulated MCP tool result = simulator.test_mcp(param="test_value") - + assert result == {"content": [{"type": "text", "text": "MCP response"}]} assert mock_agent_instance.called def test_api_tool_simulation(mock_model): """Test API tool simulation.""" + # Register and create simulator with mock model @ToolSimulator.api_tool("test_api", path="/test", method="POST") def test_api(**kwargs): """Test API tool that should be simulated.""" pass - + simulator = ToolSimulator(model=mock_model) - + # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() mock_result.structured_output.model_dump.return_value = {"status": 200, "data": {"key": "value"}} mock_agent_instance.return_value = mock_result - + with pytest.MonkeyPatch().context() as m: # Mock the Agent class constructor - from strands_evals.simulation.tool_simulator import Agent - m.setattr('strands_evals.simulation.tool_simulator.Agent', lambda **kwargs: mock_agent_instance) - + m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance) + # Execute simulated API tool result = simulator.test_api(key="value") - + assert result == {"status": 200, "data": {"key": "value"}} assert mock_agent_instance.called def test_list_tools(): """Test listing registered tools.""" + @ToolSimulator.function_tool("func1") def func1(): pass - @ToolSimulator.function_tool("func2") + @ToolSimulator.function_tool("func2") def func2(): pass simulator = ToolSimulator() tools = simulator.list_tools() - + assert set(tools) == {"func1", "func2"} @@ -216,110 +211,107 @@ def test_shared_state_registry(mock_model): """Test that function, MCP, and API tools can share the same state registry.""" shared_state_id = "shared_banking_state" initial_state = "Initial banking system state with account balances" - + # Register three different tools that share the same state @ToolSimulator.function_tool( - "check_balance", - initial_state_description=initial_state, - share_state_id=shared_state_id + "check_balance", 
initial_state_description=initial_state, share_state_id=shared_state_id ) def check_balance(account_id: str): """Check account balance.""" pass - + @ToolSimulator.mcp_tool( - "transfer_funds", + "transfer_funds", schema={"type": "object", "properties": {"from_account": {"type": "string"}, "to_account": {"type": "string"}}}, initial_state_description=initial_state, - share_state_id=shared_state_id + share_state_id=shared_state_id, ) def transfer_funds(**params): """Transfer funds between accounts.""" pass - + @ToolSimulator.api_tool( "get_transactions", - path="/transactions", + path="/transactions", method="GET", initial_state_description=initial_state, - share_state_id=shared_state_id + share_state_id=shared_state_id, ) def get_transactions(**kwargs): """Get transaction history.""" pass - + simulator = ToolSimulator(model=mock_model) - + # Mock the Agent constructor to avoid real LLM calls mock_agent_instances = [] expected_responses = [ {"balance": 1000, "currency": "USD"}, # Function response {"content": [{"type": "text", "text": "Transfer completed"}]}, # MCP response - {"status": 200, "data": {"transactions": []}} # API response + {"status": 200, "data": {"transactions": []}}, # API response ] - + def create_mock_agent(**kwargs): mock_agent = MagicMock() mock_result = MagicMock() - + if len(mock_agent_instances) < len(expected_responses): response = expected_responses[len(mock_agent_instances)] - if 'content' in response: + if "content" in response: # MCP response needs .model_dump() mock_result.structured_output.model_dump.return_value = response else: # Function and API responses - function uses dict directly, API uses .model_dump() - if 'balance' in response: + if "balance" in response: # Function response - use direct structured_output mock_result.structured_output = response else: # API response - use .model_dump() mock_result.structured_output.model_dump.return_value = response - + mock_agent.return_value = mock_result mock_agent_instances.append(mock_agent) return mock_agent - + with pytest.MonkeyPatch().context() as m: # Mock the Agent class constructor - from strands_evals.simulation.tool_simulator import Agent - m.setattr('strands_evals.simulation.tool_simulator.Agent', create_mock_agent) - + m.setattr("strands_evals.simulation.tool_simulator.Agent", create_mock_agent) + # Execute each tool in order balance_result = simulator.check_balance("12345") transfer_result = simulator.transfer_funds(from_account="12345", to_account="67890") transactions_result = simulator.get_transactions(account_id="12345") - + # Verify results assert balance_result == {"balance": 1000, "currency": "USD"} assert transfer_result == {"content": [{"type": "text", "text": "Transfer completed"}]} assert transactions_result == {"status": 200, "data": {"transactions": []}} - + # Verify all agents were called assert len(mock_agent_instances) == 3 for agent in mock_agent_instances: assert agent.called - + # Verify all tools accessed the same shared state shared_state = simulator._state_registry.get_state(shared_state_id) assert "initial_state" in shared_state assert shared_state["initial_state"] == initial_state assert "previous_calls" in shared_state assert len(shared_state["previous_calls"]) == 3 - + # Check that all three tool calls are recorded in the shared state tool_names = [call["tool_name"] for call in shared_state["previous_calls"]] assert "check_balance" in tool_names - assert "transfer_funds" in tool_names + assert "transfer_funds" in tool_names assert "get_transactions" in tool_names - + # 
Verify each tool type recorded its specific data correctly function_call = next(call for call in shared_state["previous_calls"] if call["tool_name"] == "check_balance") assert "parameters" in function_call - + mcp_call = next(call for call in shared_state["previous_calls"] if call["tool_name"] == "transfer_funds") assert "input_mcp_payload" in mcp_call - + api_call = next(call for call in shared_state["previous_calls"] if call["tool_name"] == "get_transactions") assert "path" in api_call assert "method" in api_call @@ -328,15 +320,15 @@ def create_mock_agent(**kwargs): def test_record_tool_call_function(): """Test recording function call in state registry using unified method.""" registry = StateRegistry() - + registry.record_tool_call( tool_name="test_tool", state_key="test_state", tool_type=ToolType.FUNCTION, response_data={"result": "success"}, - parameters={"param": "value"} + parameters={"param": "value"}, ) - + state = registry.get_state("test_state") assert "previous_calls" in state assert len(state["previous_calls"]) == 1 @@ -350,15 +342,15 @@ def test_record_tool_call_function(): def test_record_tool_call_mcp(): """Test recording MCP tool call in state registry using unified method.""" registry = StateRegistry() - + registry.record_tool_call( tool_name="mcp_tool", state_key="mcp_state", tool_type=ToolType.MCP, response_data={"content": [{"type": "text", "text": "result"}]}, - input_mcp_payload={"input": "data"} + input_mcp_payload={"input": "data"}, ) - + state = registry.get_state("mcp_state") assert "previous_calls" in state assert len(state["previous_calls"]) == 1 @@ -372,7 +364,7 @@ def test_record_tool_call_mcp(): def test_record_tool_call_api(): """Test recording API call in state registry using unified method.""" registry = StateRegistry() - + registry.record_tool_call( tool_name="api_tool", state_key="api_state", @@ -380,9 +372,9 @@ def test_record_tool_call_api(): response_data={"status": 200}, path="/test", method="POST", - input_data={"data": "test"} + input_data={"data": "test"}, ) - + state = registry.get_state("api_state") assert "previous_calls" in state assert len(state["previous_calls"]) == 1 @@ -398,116 +390,116 @@ def test_record_tool_call_api(): def test_tool_not_found_raises_error(): """Test that accessing non-existent tools raises ValueError.""" simulator = ToolSimulator() - + # Test that accessing a non-existent tool via _simulate_tool_call raises ValueError with pytest.raises(ValueError) as excinfo: simulator._simulate_tool_call( - tool_type=ToolType.FUNCTION, - state_key="test", - input_data={"tool_name": "nonexistent_tool"} + tool_type=ToolType.FUNCTION, state_key="test", input_data={"tool_name": "nonexistent_tool"} ) - + assert "not registered" in str(excinfo.value) def test_api_tool_missing_name_raises_error(): """Test that API tool simulation raises ValueError when tool_name is missing.""" simulator = ToolSimulator() - + with pytest.raises(ValueError) as excinfo: simulator._handle_api_tool( input_data={"tool_name": ""}, # Empty tool name - state_key="test" + state_key="test", ) - + assert "tool_name is required for API tool simulation" in str(excinfo.value) def test_mock_mode_missing_function_raises_error(): """Test that mock mode raises ValueError when mock_function is missing.""" + # Register a tool without mock_function but with mock mode @ToolSimulator.function_tool("test_mock_tool", mode="mock") def test_mock_tool(): pass - + simulator = ToolSimulator() registered_tool = ToolSimulator._registered_tools["test_mock_tool"] - + with 
pytest.raises(ValueError) as excinfo: simulator._handle_mock_mode( registered_tool=registered_tool, input_data={"tool_name": "test_mock_tool", "parameters": {}}, state_key="test", - tool_type=ToolType.FUNCTION + tool_type=ToolType.FUNCTION, ) - + assert "mock_function is required for tool simulator mock mode" in str(excinfo.value) def test_clear_registry(): """Test clearing tool registry.""" + @ToolSimulator.function_tool("test_function") def test_func(): pass - + assert len(ToolSimulator._registered_tools) == 1 - + ToolSimulator.clear_registry() - + assert len(ToolSimulator._registered_tools) == 0 assert ToolSimulator._state_registry is None def test_function_tool_decorator_stacking_with_strands_tool(): """Test function tool decorator stacking with Strands @tool decorator.""" + # Mock function that handles parameters with **kwargs def mock_function(**kwargs): input_value = kwargs.get("input_value", "") return {"result": f"processed {input_value}"} - + # Define tool with stacked decorators @tool - @ToolSimulator.function_tool("stacked_function_tool", mode="mock", - mock_function=mock_function) + @ToolSimulator.function_tool("stacked_function_tool", mode="mock", mock_function=mock_function) def stacked_function_tool(input_value: str) -> Dict[str, Any]: """Test function tool with stacked decorators. - + Args: input_value: Input parameter for processing """ pass - + # Create simulator simulator = ToolSimulator() - + # Test that the tool is callable and returns expected result result = simulator.stacked_function_tool(input_value="test_input") assert result == {"result": "processed test_input"} - + # Verify the tool is registered in ToolSimulator assert "stacked_function_tool" in ToolSimulator._registered_tools registered_tool = ToolSimulator._registered_tools["stacked_function_tool"] assert registered_tool.tool_type == ToolType.FUNCTION assert registered_tool.mode == "mock" assert registered_tool.mock_function == mock_function - + # Validate Strands tool creation assert stacked_function_tool.tool_spec is not None spec = stacked_function_tool.tool_spec - + # Check basic spec properties assert spec["name"] == "stacked_function_tool" assert spec["description"] == "Test function tool with stacked decorators." 
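
For orientation while reading these tests, a minimal usage sketch of the registration decorators and the static/mock simulation modes they exercise; the tool names, payloads, and the lambda mock below are illustrative placeholders, not part of this patch:

    from strands_evals.simulation.tool_simulator import ToolSimulator

    # Mock mode: the decorated body is never run; mock_function supplies the result.
    @ToolSimulator.function_tool(
        "lookup_order",
        mode="mock",
        mock_function=lambda order_id: {"order_id": order_id, "status": "shipped"},
    )
    def lookup_order(order_id: str) -> dict:
        """Look up an order by id (simulated)."""
        pass

    # Static mode: static_response is required and returned verbatim.
    @ToolSimulator.api_tool(
        "health_check",
        path="/health",
        method="GET",
        mode="static",
        static_response={"status": 200, "data": {"ok": True}},
    )
    def health_check(**kwargs) -> dict:
        """Simulated health endpoint."""
        pass

    simulator = ToolSimulator()  # dynamic mode would additionally need a model
    simulator.lookup_order(order_id="A-1")  # -> {"order_id": "A-1", "status": "shipped"}
    simulator.health_check()                # -> {"status": 200, "data": {"ok": True}}
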
- + # Check input schema schema = spec["inputSchema"]["json"] assert schema["type"] == "object" assert set(schema["required"]) == {"input_value"} - + # Check parameter properties assert schema["properties"]["input_value"]["type"] == "string" assert schema["properties"]["input_value"]["description"] == "Input parameter for processing" - + # Make sure these are set properly assert stacked_function_tool.__wrapped__ is not None assert stacked_function_tool.__doc__ == stacked_function_tool._tool_func.__doc__ @@ -515,74 +507,63 @@ def stacked_function_tool(input_value: str) -> Dict[str, Any]: def test_mcp_tool_decorator_stacking_with_strands_tool(): """Test MCP tool decorator stacking with Strands @tool decorator.""" + # Mock function for MCP tool def mock_mcp_processor(param1, param2=42): - return { - "content": [ - {"type": "text", "text": f"MCP processed: {param1} with value {param2}"} - ], - "isError": False - } - + return {"content": [{"type": "text", "text": f"MCP processed: {param1} with value {param2}"}], "isError": False} + schema = { "type": "object", - "properties": { - "param1": {"type": "string"}, - "param2": {"type": "integer", "default": 42} - }, - "required": ["param1"] + "properties": {"param1": {"type": "string"}, "param2": {"type": "integer", "default": 42}}, + "required": ["param1"], } - + # Define tool with stacked decorators @tool - @ToolSimulator.mcp_tool("stacked_mcp_tool", schema=schema, mode="mock", - mock_function=mock_mcp_processor) + @ToolSimulator.mcp_tool("stacked_mcp_tool", schema=schema, mode="mock", mock_function=mock_mcp_processor) def stacked_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: """Test MCP tool with stacked decorators. - + Args: param1: First parameter for MCP processing param2: Second parameter with default value """ pass - + # Create simulator simulator = ToolSimulator() - + # Test that the tool is callable and returns expected result result = simulator.stacked_mcp_tool(param1="test", param2=100) - expected = { - "content": [{"type": "text", "text": "MCP processed: test with value 100"}], - "isError": False - } + expected = {"content": [{"type": "text", "text": "MCP processed: test with value 100"}], "isError": False} assert result == expected - + # Verify the tool is registered in ToolSimulator assert "stacked_mcp_tool" in ToolSimulator._registered_tools registered_tool = ToolSimulator._registered_tools["stacked_mcp_tool"] assert registered_tool.tool_type == ToolType.MCP assert registered_tool.mode == "mock" assert registered_tool.mock_function == mock_mcp_processor - + # Validate Strands tool creation assert stacked_mcp_tool.tool_spec is not None spec = stacked_mcp_tool.tool_spec - + # Check basic spec properties assert spec["name"] == "stacked_mcp_tool" assert spec["description"] == "Test MCP tool with stacked decorators." 
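
The .model_dump() values asserted in these tests come from the response models defined in types/simulation/tool.py; a small sketch of how they are constructed and serialized (field values are illustrative):

    from strands_evals.types.simulation.tool import (
        APIErrorDetail,
        APIToolResponse,
        MCPContentItem,
        MCPToolResponse,
    )

    mcp = MCPToolResponse(content=[MCPContentItem(type="text", text="Transfer completed")])
    mcp.model_dump()
    # {"content": [{"type": "text", "text": "Transfer completed", "resource": None}], "isError": False}

    api_error = APIToolResponse(
        status=404,
        error=APIErrorDetail(type="not_found", title="Not Found", detail="Unknown account id"),
    )
    api_error.model_dump()["status"]  # 404; "data" stays None and "error" carries the detail fields
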
- + # Check input schema schema = spec["inputSchema"]["json"] assert schema["type"] == "object" assert set(schema["required"]) == {"param1"} - + # Check parameter properties assert schema["properties"]["param1"]["type"] == "string" assert schema["properties"]["param2"]["type"] == "integer" assert schema["properties"]["param1"]["description"] == "First parameter for MCP processing" assert schema["properties"]["param2"]["description"] == "Second parameter with default value" - + # Make sure these are set properly assert stacked_mcp_tool.__wrapped__ is not None assert stacked_mcp_tool.__doc__ == stacked_mcp_tool._tool_func.__doc__ @@ -593,32 +574,29 @@ def test_api_tool_decorator_stacking_with_strands_tool(): # Static response for API tool static_response = { "status": 200, - "data": { - "message": "API tool working", - "timestamp": "2024-01-01T12:00:00Z", - "endpoint": "/test/api" - } + "data": {"message": "API tool working", "timestamp": "2024-01-01T12:00:00Z", "endpoint": "/test/api"}, } - + # Define tool with stacked decorators @tool - @ToolSimulator.api_tool("stacked_api_tool", path="/test/api", method="GET", - mode="static", static_response=static_response) + @ToolSimulator.api_tool( + "stacked_api_tool", path="/test/api", method="GET", mode="static", static_response=static_response + ) def stacked_api_tool(query: str = "") -> Dict[str, Any]: """Test API tool with stacked decorators. - + Args: query: Query parameter for API call """ pass - + # Create simulator simulator = ToolSimulator() - + # Test that the tool is callable and returns expected result result = simulator.stacked_api_tool(query="test_query") assert result == static_response - + # Verify the tool is registered in ToolSimulator assert "stacked_api_tool" in ToolSimulator._registered_tools registered_tool = ToolSimulator._registered_tools["stacked_api_tool"] @@ -627,26 +605,26 @@ def stacked_api_tool(query: str = "") -> Dict[str, Any]: assert registered_tool.api_path == "/test/api" assert registered_tool.api_method == "GET" assert registered_tool.static_response == static_response - + # Validate Strands tool creation assert stacked_api_tool.tool_spec is not None spec = stacked_api_tool.tool_spec - + # Check basic spec properties assert spec["name"] == "stacked_api_tool" assert spec["description"] == "Test API tool with stacked decorators." 
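
Complementing the shared-state test earlier in this module, a compact sketch of how a share_state_id groups calls from different tool types under one StateRegistry entry; the key, account names, and amounts are illustrative:

    from strands_evals.simulation.tool_simulator import StateRegistry
    from strands_evals.types.simulation.tool import ToolType

    registry = StateRegistry()
    registry.initialize_state_via_description("Two accounts: A=100, B=50", "banking")

    # Both records land under the same "banking" key, so later simulations see the history.
    registry.record_tool_call(
        "check_balance", "banking", ToolType.FUNCTION, {"balance": 100}, parameters={"account": "A"}
    )
    registry.record_tool_call(
        "get_transactions", "banking", ToolType.API, {"status": 200, "data": {"transactions": []}},
        path="/transactions", method="GET", input_data={"account": "A"},
    )

    calls = registry.get_state("banking")["previous_calls"]
    assert [c["tool_name"] for c in calls] == ["check_balance", "get_transactions"]
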
- + # Check input schema schema = spec["inputSchema"]["json"] assert schema["type"] == "object" # query parameter is optional, so required list may be empty or missing required_fields = set(schema.get("required", [])) assert required_fields == set() - + # Check parameter properties assert schema["properties"]["query"]["type"] == "string" assert schema["properties"]["query"]["description"] == "Query parameter for API call" - + # Make sure these are set properly assert stacked_api_tool.__wrapped__ is not None assert stacked_api_tool.__doc__ == stacked_api_tool._tool_func.__doc__ From 9701bbee3f185f6b0b1ec627746522dbe926dff6 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Thu, 29 Jan 2026 18:49:29 +0000 Subject: [PATCH 08/15] fix tool simulator integration with strands tool decorator --- .../simulation/tool_simulator.py | 103 +++++++++++++++--- 1 file changed, 85 insertions(+), 18 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index f1bde51..c469065 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -50,7 +50,6 @@ def initialize_state_via_description(self, initial_state_description: str, state self._states[state_key] = { "initial_state": initial_state_description, "previous_calls": [], - "user_context": {}, } else: warnings.warn( @@ -65,16 +64,13 @@ def get_state(self, state_key: str) -> Dict[str, Any]: state_key: Key for the state (tool_name or share_state_id). Returns: - State dictionary containing previous_calls and user_context. + State dictionary containing previous_calls. """ if state_key is None: raise ValueError("Value of state_key is required.") if state_key not in self._states: - self._states[state_key] = { - "previous_calls": [], - "user_context": {}, - } + self._states[state_key] = {"previous_calls": []} return dict(self._states[state_key]) @@ -163,6 +159,7 @@ class ToolSimulator: # Class-level registry for all registered tools _registered_tools: Dict[str, RegisteredTool] = {} _state_registry: Optional[StateRegistry] = None + _global_instance: Optional["ToolSimulator"] = None def __init__( self, @@ -171,6 +168,7 @@ def __init__( mcp_tool_prompt: Optional[str] = None, api_tool_prompt: Optional[str] = None, model: Model | str | None = None, + framework: str = "strands", ): """ Initialize a ToolSimulator instance. 
@@ -181,7 +179,10 @@ def __init__( mcp_tool_prompt: Optional custom prompt for MCP tool response generation api_tool_prompt: Optional custom prompt for API tool response generation model: Provider for running inference or a string representing the model-id for Bedrock to use + framework: Agent framework to use (default: "strands") """ + # Store framework selection + self.framework = framework # Store model configuration for creating internal agents self.model_id = model @@ -216,6 +217,10 @@ def __init__( # Initialize shared states from registered tools self._initialize_shared_states() + # Set as global instance if none exists + if ToolSimulator._global_instance is None: + ToolSimulator._global_instance = self + def _initialize_shared_states(self): """Initialize shared states from registered tools' initial descriptions.""" for tool_name, registered_tool in self._registered_tools.items(): @@ -252,8 +257,8 @@ def __getattr__(self, name: str) -> Any: raise AttributeError(f"Tool '{name}' not found in registered tools") - def _create_tool_wrapper(self, registered_tool: RegisteredTool) -> Callable: - """Create a wrapper function for direct tool access.""" + def _create_tool_wrapper(self, registered_tool: RegisteredTool): + """Create a framework-compatible tool wrapper.""" def wrapper(*args, **kwargs): # Determine state key @@ -296,21 +301,71 @@ def wrapper(*args, **kwargs): # Copy function metadata if registered_tool.function: wrapper.__name__ = registered_tool.function.__name__ - try: - wrapper.__signature__ = inspect.signature(registered_tool.function) # type: ignore - except (ValueError, TypeError): - pass wrapper.__doc__ = registered_tool.function.__doc__ else: wrapper.__name__ = registered_tool.name + wrapper.__doc__ = f"Simulated {registered_tool.name} tool" + + # Use framework-specific method to create the tool wrapper + if self.framework == "strands": + return self._create_strands_tool_wrapper(registered_tool, wrapper) + else: + raise ValueError(f"Framework '{self.framework}' is not supported. 
Only 'strands' is currently supported.") + + def _create_strands_tool_wrapper(self, registered_tool: RegisteredTool, wrapper: Callable): + """Create a Strands-specific DecoratedFunctionTool wrapper.""" + from strands.tools.decorator import DecoratedFunctionTool, FunctionToolMetadata - return wrapper + # Create tool spec based on function signature and docstring + tool_description = wrapper.__doc__ or f"Simulated {registered_tool.name} tool" + + # Build input schema from function signature + input_schema: Dict[str, Any] = {"type": "object", "properties": {}} + if registered_tool.function: + try: + sig = inspect.signature(registered_tool.function) + for param_name, param in sig.parameters.items(): + if param.annotation != inspect.Parameter.empty: + param_type = ( + str(param.annotation).replace("", "").replace("typing.", "") + ) + if "str" in param_type.lower(): + input_schema["properties"][param_name] = {"type": "string"} + elif "int" in param_type.lower(): + input_schema["properties"][param_name] = {"type": "integer"} + elif "float" in param_type.lower(): + input_schema["properties"][param_name] = {"type": "number"} + elif "bool" in param_type.lower(): + input_schema["properties"][param_name] = {"type": "boolean"} + else: + input_schema["properties"][param_name] = {"type": "object"} + else: + input_schema["properties"][param_name] = {"type": "string"} # default + except Exception: + pass # fallback to empty schema + + # Create Strands tool's FunctionToolMetadata object and DecoratedFunctionTool instance + metadata = FunctionToolMetadata(registered_tool.function or wrapper) + + # Extract tool_spec from metadata; override with our custom description if needed + extracted_tool_spec = metadata.extract_metadata() + if tool_description != extracted_tool_spec.get("description"): + extracted_tool_spec["description"] = tool_description + extracted_tool_spec["name"] = registered_tool.name + + decorated_tool = DecoratedFunctionTool( + tool_name=registered_tool.name, + tool_spec=extracted_tool_spec, + tool_func=wrapper, # Always use wrapper to ensure simulation logic is executed + metadata=metadata, + ) + + return decorated_tool def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: Dict[str, Any]) -> Any: """Simulate a tool invocation and return the response.""" tool_name = input_data.get("tool_name", "") registered_tool = self._registered_tools.get(tool_name) - if not registered_tool: raise ValueError(f"Tool '{tool_name}' not registered") @@ -365,13 +420,17 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D result = agent(prompt, structured_output_model=None) # Parse JSON response for function tools since they vary based on function signature - if result.response and isinstance(result.response, str): + response_text = ( + getattr(result, "response", None) or str(result.content) if hasattr(result, "content") else str(result) + ) + + if response_text and isinstance(response_text, str): try: - response_data = json.loads(result.response) + response_data = json.loads(response_text) except json.JSONDecodeError: - response_data = {"result": result.response} + response_data = {"result": response_text} else: - response_data = {"result": str(result.response) if result.response else "No response"} + response_data = {"result": response_text or "No response"} # Record the call self._state_registry.record_tool_call( @@ -575,8 +634,16 @@ def clear_registry(cls): """Clear all registered tools. 
Useful for testing.""" cls._registered_tools.clear() cls._state_registry = None + cls._global_instance = None logger.info("Cleared tool registry") + @classmethod + def _get_instance(cls) -> "ToolSimulator": + """Get the global ToolSimulator instance.""" + if cls._global_instance is None: + cls._global_instance = cls() + return cls._global_instance + @classmethod def function_tool( cls, From e6702f598f667deb7cc769781de6c86cc988fc30 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Thu, 29 Jan 2026 21:53:39 +0000 Subject: [PATCH 09/15] update test --- .../simulation/test_tool_simulator.py | 194 ++++++------------ 1 file changed, 58 insertions(+), 136 deletions(-) diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index 1c79665..c9bb3fc 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -450,64 +450,39 @@ def test_func(): assert ToolSimulator._state_registry is None -def test_function_tool_decorator_stacking_with_strands_tool(): - """Test function tool decorator stacking with Strands @tool decorator.""" - - # Mock function that handles parameters with **kwargs - def mock_function(**kwargs): - input_value = kwargs.get("input_value", "") +def test_attaching_function_tool_simulator_to_strands_agent(): + """Test attaching function tool simulator to Strands agent.""" + + # Mock function that handles parameters + def mock_function(input_value): return {"result": f"processed {input_value}"} - # Define tool with stacked decorators - @tool - @ToolSimulator.function_tool("stacked_function_tool", mode="mock", mock_function=mock_function) - def stacked_function_tool(input_value: str) -> Dict[str, Any]: - """Test function tool with stacked decorators. + # Register a function tool simulator + @ToolSimulator.function_tool("test_function_tool", mode="mock", mock_function=mock_function) + def test_function_tool(input_value: str) -> Dict[str, Any]: + """Test function tool for agent attachment. Args: input_value: Input parameter for processing """ pass - # Create simulator + # Create simulator and get the tool simulator = ToolSimulator() - - # Test that the tool is callable and returns expected result - result = simulator.stacked_function_tool(input_value="test_input") - assert result == {"result": "processed test_input"} - - # Verify the tool is registered in ToolSimulator - assert "stacked_function_tool" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["stacked_function_tool"] - assert registered_tool.tool_type == ToolType.FUNCTION - assert registered_tool.mode == "mock" - assert registered_tool.mock_function == mock_function - - # Validate Strands tool creation - assert stacked_function_tool.tool_spec is not None - spec = stacked_function_tool.tool_spec - - # Check basic spec properties - assert spec["name"] == "stacked_function_tool" - assert spec["description"] == "Test function tool with stacked decorators." 
- - # Check input schema - schema = spec["inputSchema"]["json"] - assert schema["type"] == "object" - assert set(schema["required"]) == {"input_value"} - - # Check parameter properties - assert schema["properties"]["input_value"]["type"] == "string" - assert schema["properties"]["input_value"]["description"] == "Input parameter for processing" - - # Make sure these are set properly - assert stacked_function_tool.__wrapped__ is not None - assert stacked_function_tool.__doc__ == stacked_function_tool._tool_func.__doc__ - - -def test_mcp_tool_decorator_stacking_with_strands_tool(): - """Test MCP tool decorator stacking with Strands @tool decorator.""" - + tool_wrapper = simulator.get_tool("test_function_tool") + + # Create a Strands Agent with the tool simulator + from strands import Agent + agent = Agent(tools=[tool_wrapper]) + + # Verify the agent has access to the tool + assert "test_function_tool" in agent.tool_names + assert hasattr(agent.tool, "test_function_tool") + + +def test_attaching_mcp_tool_simulator_to_strands_agent(): + """Test attaching MCP tool simulator to Strands agent.""" + # Mock function for MCP tool def mock_mcp_processor(param1, param2=42): return {"content": [{"type": "text", "text": f"MCP processed: {param1} with value {param2}"}], "isError": False} @@ -518,11 +493,10 @@ def mock_mcp_processor(param1, param2=42): "required": ["param1"], } - # Define tool with stacked decorators - @tool - @ToolSimulator.mcp_tool("stacked_mcp_tool", schema=schema, mode="mock", mock_function=mock_mcp_processor) - def stacked_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: - """Test MCP tool with stacked decorators. + # Register an MCP tool simulator + @ToolSimulator.mcp_tool("test_mcp_tool", schema=schema, mode="mock", mock_function=mock_mcp_processor) + def test_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: + """Test MCP tool for agent attachment. Args: param1: First parameter for MCP processing @@ -530,101 +504,49 @@ def stacked_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: """ pass - # Create simulator + # Create simulator and get the tool simulator = ToolSimulator() - - # Test that the tool is callable and returns expected result - result = simulator.stacked_mcp_tool(param1="test", param2=100) - expected = {"content": [{"type": "text", "text": "MCP processed: test with value 100"}], "isError": False} - assert result == expected - - # Verify the tool is registered in ToolSimulator - assert "stacked_mcp_tool" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["stacked_mcp_tool"] - assert registered_tool.tool_type == ToolType.MCP - assert registered_tool.mode == "mock" - assert registered_tool.mock_function == mock_mcp_processor - - # Validate Strands tool creation - assert stacked_mcp_tool.tool_spec is not None - spec = stacked_mcp_tool.tool_spec - - # Check basic spec properties - assert spec["name"] == "stacked_mcp_tool" - assert spec["description"] == "Test MCP tool with stacked decorators." 
- - # Check input schema - schema = spec["inputSchema"]["json"] - assert schema["type"] == "object" - assert set(schema["required"]) == {"param1"} - - # Check parameter properties - assert schema["properties"]["param1"]["type"] == "string" - assert schema["properties"]["param2"]["type"] == "integer" - assert schema["properties"]["param1"]["description"] == "First parameter for MCP processing" - assert schema["properties"]["param2"]["description"] == "Second parameter with default value" - - # Make sure these are set properly - assert stacked_mcp_tool.__wrapped__ is not None - assert stacked_mcp_tool.__doc__ == stacked_mcp_tool._tool_func.__doc__ - - -def test_api_tool_decorator_stacking_with_strands_tool(): - """Test API tool decorator stacking with Strands @tool decorator.""" + tool_wrapper = simulator.get_tool("test_mcp_tool") + + # Create a Strands Agent with the tool simulator + from strands import Agent + agent = Agent(tools=[tool_wrapper]) + + # Verify the agent has access to the tool + assert "test_mcp_tool" in agent.tool_names + assert hasattr(agent.tool, "test_mcp_tool") + + +def test_attaching_api_tool_simulator_to_strands_agent(): + """Test attaching API tool simulator to Strands agent.""" + # Static response for API tool static_response = { "status": 200, "data": {"message": "API tool working", "timestamp": "2024-01-01T12:00:00Z", "endpoint": "/test/api"}, } - # Define tool with stacked decorators - @tool + # Register an API tool simulator @ToolSimulator.api_tool( - "stacked_api_tool", path="/test/api", method="GET", mode="static", static_response=static_response + "test_api_tool", path="/test/api", method="GET", mode="static", static_response=static_response ) - def stacked_api_tool(query: str = "") -> Dict[str, Any]: - """Test API tool with stacked decorators. + def test_api_tool(query: str = "") -> Dict[str, Any]: + """Test API tool for agent attachment. Args: query: Query parameter for API call """ pass - # Create simulator + # Create simulator and get the tool simulator = ToolSimulator() - - # Test that the tool is callable and returns expected result - result = simulator.stacked_api_tool(query="test_query") - assert result == static_response - - # Verify the tool is registered in ToolSimulator - assert "stacked_api_tool" in ToolSimulator._registered_tools - registered_tool = ToolSimulator._registered_tools["stacked_api_tool"] - assert registered_tool.tool_type == ToolType.API - assert registered_tool.mode == "static" - assert registered_tool.api_path == "/test/api" - assert registered_tool.api_method == "GET" - assert registered_tool.static_response == static_response - - # Validate Strands tool creation - assert stacked_api_tool.tool_spec is not None - spec = stacked_api_tool.tool_spec - - # Check basic spec properties - assert spec["name"] == "stacked_api_tool" - assert spec["description"] == "Test API tool with stacked decorators." 
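
These attachment tests reflect the integration path introduced in the previous commit: get_tool() returns a Strands DecoratedFunctionTool that wraps the simulation logic and can be passed straight to an Agent. A minimal sketch of that flow; the tool name and mock are illustrative, and merely constructing the agent makes no model call:

    from strands import Agent
    from strands_evals.simulation.tool_simulator import ToolSimulator

    @ToolSimulator.function_tool(
        "get_weather", mode="mock",
        mock_function=lambda city: {"city": city, "forecast": "sunny"},
    )
    def get_weather(city: str) -> dict:
        """Return a simulated weather forecast for a city."""
        pass

    simulator = ToolSimulator()
    weather_tool = simulator.get_tool("get_weather")  # DecoratedFunctionTool wrapping the simulator

    agent = Agent(tools=[weather_tool])
    assert "get_weather" in agent.tool_names
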
- - # Check input schema - schema = spec["inputSchema"]["json"] - assert schema["type"] == "object" - # query parameter is optional, so required list may be empty or missing - required_fields = set(schema.get("required", [])) - assert required_fields == set() - - # Check parameter properties - assert schema["properties"]["query"]["type"] == "string" - assert schema["properties"]["query"]["description"] == "Query parameter for API call" - - # Make sure these are set properly - assert stacked_api_tool.__wrapped__ is not None - assert stacked_api_tool.__doc__ == stacked_api_tool._tool_func.__doc__ + tool_wrapper = simulator.get_tool("test_api_tool") + + # Create a Strands Agent with the tool simulator + from strands import Agent + agent = Agent(tools=[tool_wrapper]) + + # Verify the agent has access to the tool + assert "test_api_tool" in agent.tool_names + assert hasattr(agent.tool, "test_api_tool") + \ No newline at end of file From 367c637fe37bb4f00cff098dec15eb3bbf525a0b Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Thu, 29 Jan 2026 23:52:32 +0000 Subject: [PATCH 10/15] fix test --- .../simulation/test_tool_simulator.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index c9bb3fc..b98f80a 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -4,7 +4,6 @@ from unittest.mock import MagicMock import pytest -from strands import tool from strands_evals.case import Case from strands_evals.simulation.tool_simulator import StateRegistry, ToolSimulator @@ -119,7 +118,10 @@ def test_func(message: str) -> dict: # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() - mock_result.structured_output = {"result": "simulated response"} + # For function tools, set structured_output to None so it falls back to parsing response text + mock_result.structured_output = None + # Mock the response attribute to return parsable JSON string + mock_result.response = '{"result": "simulated response"}' mock_agent_instance.return_value = mock_result with pytest.MonkeyPatch().context() as m: @@ -260,14 +262,15 @@ def create_mock_agent(**kwargs): if "content" in response: # MCP response needs .model_dump() mock_result.structured_output.model_dump.return_value = response + elif "balance" in response: + # Function response - set structured_output to None and provide JSON string in response + mock_result.structured_output = None + import json + + mock_result.response = json.dumps(response) else: - # Function and API responses - function uses dict directly, API uses .model_dump() - if "balance" in response: - # Function response - use direct structured_output - mock_result.structured_output = response - else: - # API response - use .model_dump() - mock_result.structured_output.model_dump.return_value = response + # API response - use .model_dump() + mock_result.structured_output.model_dump.return_value = response mock_agent.return_value = mock_result mock_agent_instances.append(mock_agent) @@ -452,7 +455,7 @@ def test_func(): def test_attaching_function_tool_simulator_to_strands_agent(): """Test attaching function tool simulator to Strands agent.""" - + # Mock function that handles parameters def mock_function(input_value): return {"result": f"processed {input_value}"} @@ -470,11 +473,12 @@ def test_function_tool(input_value: 
str) -> Dict[str, Any]: # Create simulator and get the tool simulator = ToolSimulator() tool_wrapper = simulator.get_tool("test_function_tool") - + # Create a Strands Agent with the tool simulator from strands import Agent + agent = Agent(tools=[tool_wrapper]) - + # Verify the agent has access to the tool assert "test_function_tool" in agent.tool_names assert hasattr(agent.tool, "test_function_tool") @@ -482,7 +486,7 @@ def test_function_tool(input_value: str) -> Dict[str, Any]: def test_attaching_mcp_tool_simulator_to_strands_agent(): """Test attaching MCP tool simulator to Strands agent.""" - + # Mock function for MCP tool def mock_mcp_processor(param1, param2=42): return {"content": [{"type": "text", "text": f"MCP processed: {param1} with value {param2}"}], "isError": False} @@ -507,11 +511,12 @@ def test_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: # Create simulator and get the tool simulator = ToolSimulator() tool_wrapper = simulator.get_tool("test_mcp_tool") - + # Create a Strands Agent with the tool simulator from strands import Agent + agent = Agent(tools=[tool_wrapper]) - + # Verify the agent has access to the tool assert "test_mcp_tool" in agent.tool_names assert hasattr(agent.tool, "test_mcp_tool") @@ -519,7 +524,7 @@ def test_mcp_tool(param1: str, param2: int = 42) -> Dict[str, Any]: def test_attaching_api_tool_simulator_to_strands_agent(): """Test attaching API tool simulator to Strands agent.""" - + # Static response for API tool static_response = { "status": 200, @@ -541,12 +546,12 @@ def test_api_tool(query: str = "") -> Dict[str, Any]: # Create simulator and get the tool simulator = ToolSimulator() tool_wrapper = simulator.get_tool("test_api_tool") - + # Create a Strands Agent with the tool simulator from strands import Agent + agent = Agent(tools=[tool_wrapper]) - + # Verify the agent has access to the tool assert "test_api_tool" in agent.tool_names assert hasattr(agent.tool, "test_api_tool") - \ No newline at end of file From 087ce6e1595de2226530d97ffde6c02deed1e012 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Mon, 2 Feb 2026 19:05:35 +0000 Subject: [PATCH 11/15] utilize built-in collection types; improve readability --- .../simulation/tool_simulator.py | 78 ++++++++++++------- .../simulation/test_tool_simulator.py | 12 +-- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index c469065..4646a66 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -2,6 +2,7 @@ import json import logging import warnings +from collections import defaultdict, deque from datetime import datetime from typing import Any, Callable, Dict, List, Optional @@ -22,16 +23,28 @@ class StateRegistry: """ State registry for managing shared state between tool simulators. Organized by state_key to isolate state between different tools or shared state groups. + + The registry automatically maintains a bounded cache of tool calls per state key. + The maximum number of tool calls stored is configurable via max_tool_call_cache_size parameter. """ - def __init__(self): + def __init__(self, max_tool_call_cache_size: int = 20): """ Initialize state registry. Creates an empty state dictionary to track tool calls and responses - across different simulation sessions. + across different simulation sessions. Tool call cache is automatically + bounded to prevent excessive memory usage. 
+ + Args: + max_tool_call_cache_size: Maximum number of tool calls to store per state key. + Older calls are automatically evicted when limit is exceeded. + Default is 20. """ - self._states: Dict[str, Dict[str, Any]] = {} + self._max_tool_call_cache_size = max_tool_call_cache_size + self._states: defaultdict[str, Dict[str, Any]] = defaultdict( + lambda: {"previous_calls": deque(maxlen=self._max_tool_call_cache_size)} + ) def initialize_state_via_description(self, initial_state_description: str, state_key: str) -> None: """ @@ -49,7 +62,7 @@ def initialize_state_via_description(self, initial_state_description: str, state if state_key not in self._states: self._states[state_key] = { "initial_state": initial_state_description, - "previous_calls": [], + "previous_calls": deque(maxlen=self._max_tool_call_cache_size), } else: warnings.warn( @@ -69,12 +82,13 @@ def get_state(self, state_key: str) -> Dict[str, Any]: if state_key is None: raise ValueError("Value of state_key is required.") - if state_key not in self._states: - self._states[state_key] = {"previous_calls": []} + # Access will create the default state automatically due to defaultdict + state = self._states[state_key] - return dict(self._states[state_key]) + # Convert deque to list for JSON serialization compatibility + return {key: list(value) if isinstance(value, deque) else value for key, value in state.items()} - def record_tool_call( + def cache_tool_call( self, tool_name: str, state_key: str, @@ -83,7 +97,7 @@ def record_tool_call( **call_data: Any, ) -> Dict[str, Any]: """ - Record a tool call in the tool's state history. + Cache a tool call in the tool's state key. Args: tool_name: Name of the tool being called. @@ -95,7 +109,8 @@ def record_tool_call( Returns: Updated state dictionary. """ - state = self.get_state(state_key) + # Access the actual state storage (not converted copy) + state = self._states[state_key] timestamp = datetime.now().strftime("%Y%m%d%H%M%S") # Build call record based on tool type @@ -120,16 +135,11 @@ def record_tool_call( } ) + # Append to deque with automatic FIFO eviction when cache is full state["previous_calls"].append(call_record) - # Keep history manageable - if len(state["previous_calls"]) > 20: - state["previous_calls"] = state["previous_calls"][-20:] - - # Update the stored state - self._states[state_key] = state - - return state + # Return converted state for external use + return self.get_state(state_key) def clear_state(self, state_key: str) -> None: """ @@ -150,6 +160,10 @@ class ToolSimulator: a registry of all registered tools. It can be configured to override tool behavior for simulation purposes, enabling controlled testing scenarios. + The simulator automatically maintains a bounded cache of tool calls for context. + The maximum number of tool calls stored per state key is configurable via + max_tool_call_cache_size parameter (default: 20). + Attributes: model: Provider for running inference or model identifier for Bedrock. _registered_tools: Class-level registry for all registered tools. @@ -169,17 +183,23 @@ def __init__( api_tool_prompt: Optional[str] = None, model: Model | str | None = None, framework: str = "strands", + max_tool_call_cache_size: int = 20, ): """ Initialize a ToolSimulator instance. Args: - state_registry: Registry for maintaining tool state + state_registry: Registry for maintaining tool state. If not provided, + a new StateRegistry will be created with max_tool_call_cache_size. 
function_tool_prompt: Optional custom prompt for function tool response generation mcp_tool_prompt: Optional custom prompt for MCP tool response generation api_tool_prompt: Optional custom prompt for API tool response generation model: Provider for running inference or a string representing the model-id for Bedrock to use framework: Agent framework to use (default: "strands") + max_tool_call_cache_size: Maximum number of tool calls to store per state key. + Only used when creating a new StateRegistry (ignored if state_registry + is provided). Older calls are automatically evicted when limit is exceeded. + Default is 20. """ # Store framework selection self.framework = framework @@ -212,7 +232,7 @@ def __init__( if state_registry: self._state_registry = state_registry elif self._state_registry is None: - self._state_registry = StateRegistry() + self._state_registry = StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size) # Initialize shared states from registered tools self._initialize_shared_states() @@ -432,8 +452,8 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D else: response_data = {"result": response_text or "No response"} - # Record the call - self._state_registry.record_tool_call( + # Cache the call + self._state_registry.cache_tool_call( tool_name, state_key, ToolType.FUNCTION, response_data, parameters=parameters ) @@ -482,8 +502,8 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s "content": [{"type": "text", "text": "No structured output received"}], } - # Record the call - self._state_registry.record_tool_call( + # Cache the call + self._state_registry.cache_tool_call( tool_name, state_key, ToolType.MCP, response_data, input_mcp_payload=input_mcp_payload ) @@ -540,8 +560,8 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s }, } - # Record the call - self._state_registry.record_tool_call( + # Cache the call + self._state_registry.cache_tool_call( tool_name, state_key, ToolType.API, @@ -602,17 +622,17 @@ def _handle_mock_mode( tool_name = registered_tool.name if tool_type == ToolType.FUNCTION: - self._state_registry.record_tool_call( + self._state_registry.cache_tool_call( tool_name, state_key, ToolType.FUNCTION, result, parameters=parameters ) elif tool_type == ToolType.MCP: - self._state_registry.record_tool_call( + self._state_registry.cache_tool_call( tool_name, state_key, ToolType.MCP, result, input_mcp_payload=input_mcp_payload ) elif tool_type == ToolType.API: path = input_data.get("path", "") method = input_data.get("method", "GET") - self._state_registry.record_tool_call( + self._state_registry.cache_tool_call( tool_name, state_key, ToolType.API, diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index b98f80a..cfa52b5 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -320,11 +320,11 @@ def create_mock_agent(**kwargs): assert "method" in api_call -def test_record_tool_call_function(): +def test_cache_tool_call_function(): """Test recording function call in state registry using unified method.""" registry = StateRegistry() - registry.record_tool_call( + registry.cache_tool_call( tool_name="test_tool", state_key="test_state", tool_type=ToolType.FUNCTION, @@ -342,11 +342,11 @@ def test_record_tool_call_function(): assert call["response"] == {"result": "success"} -def test_record_tool_call_mcp(): +def 
test_cache_tool_call_mcp(): """Test recording MCP tool call in state registry using unified method.""" registry = StateRegistry() - registry.record_tool_call( + registry.cache_tool_call( tool_name="mcp_tool", state_key="mcp_state", tool_type=ToolType.MCP, @@ -364,11 +364,11 @@ def test_record_tool_call_mcp(): assert call["response"] == {"content": [{"type": "text", "text": "result"}]} -def test_record_tool_call_api(): +def test_cache_tool_call_api(): """Test recording API call in state registry using unified method.""" registry = StateRegistry() - registry.record_tool_call( + registry.cache_tool_call( tool_name="api_tool", state_key="api_state", tool_type=ToolType.API, From c084aea342cf57fbff867de58d83abde8d64a8c0 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 3 Feb 2026 17:45:41 +0000 Subject: [PATCH 12/15] clean loading and init --- .../simulation/tool_simulator.py | 42 +++++++------------ .../simulation/test_tool_simulator.py | 2 +- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index 4646a66..8ee5514 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -16,6 +16,12 @@ ToolType, ) +from .prompt_templates.tool_response_generation import ( + API_TOOL_RESPONSE_GENERATION_PROMPT, + FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT, + MCP_TOOL_RESPONSE_GENERATION_PROMPT, +) + logger = logging.getLogger(__name__) @@ -204,35 +210,15 @@ def __init__( # Store framework selection self.framework = framework # Store model configuration for creating internal agents - self.model_id = model + self.model = model # Set custom prompts or use defaults - if function_tool_prompt is None: - from .prompt_templates.tool_response_generation import FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT - - self.function_tool_prompt = FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT - else: - self.function_tool_prompt = function_tool_prompt - - if mcp_tool_prompt is None: - from .prompt_templates.tool_response_generation import MCP_TOOL_RESPONSE_GENERATION_PROMPT - - self.mcp_tool_prompt = MCP_TOOL_RESPONSE_GENERATION_PROMPT - else: - self.mcp_tool_prompt = mcp_tool_prompt - - if api_tool_prompt is None: - from .prompt_templates.tool_response_generation import API_TOOL_RESPONSE_GENERATION_PROMPT - - self.api_tool_prompt = API_TOOL_RESPONSE_GENERATION_PROMPT - else: - self.api_tool_prompt = api_tool_prompt + self.function_tool_prompt = function_tool_prompt or FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT + self.mcp_tool_prompt = mcp_tool_prompt or MCP_TOOL_RESPONSE_GENERATION_PROMPT + self.api_tool_prompt = api_tool_prompt or API_TOOL_RESPONSE_GENERATION_PROMPT # Set up state registry - if state_registry: - self._state_registry = state_registry - elif self._state_registry is None: - self._state_registry = StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size) + self._state_registry = state_registry or StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size) # Initialize shared states from registered tools self._initialize_shared_states() @@ -434,7 +420,7 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D agent = Agent( system_prompt=self.function_tool_prompt, tools=[], - model=self.model_id, + model=self.model, callback_handler=None, ) result = agent(prompt, structured_output_model=None) @@ -490,7 +476,7 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s agent = Agent( 
system_prompt=self.mcp_tool_prompt, tools=[], - model=self.model_id, + model=self.model, callback_handler=None, ) result = agent(prompt, structured_output_model=MCPToolResponse) @@ -544,7 +530,7 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s agent = Agent( system_prompt=self.api_tool_prompt, tools=[], - model=self.model_id, + model=self.model, callback_handler=None, ) result = agent(prompt, structured_output_model=APIToolResponse) diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index cfa52b5..74cf5cb 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -51,7 +51,7 @@ def test_tool_simulator_init(): ) assert simulator._state_registry is custom_registry - assert simulator.model_id is None # model_id is now used instead of system_prompt_template + assert simulator.model is None # model is now used for LLM inference assert simulator.function_tool_prompt is not None # Check that prompt templates are loaded assert simulator.mcp_tool_prompt is not None assert simulator.api_tool_prompt is not None From 33c1b4040c44b16f7499bd0ef818b7a9f0f165cf Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 3 Feb 2026 23:13:20 +0000 Subject: [PATCH 13/15] use python3.10 types and match case, simplify response validation --- .../simulation/tool_simulator.py | 152 +++++++++--------- src/strands_evals/types/simulation/tool.py | 30 ++-- .../simulation/test_tool_simulator.py | 28 ++-- 3 files changed, 101 insertions(+), 109 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index 8ee5514..d43b2d2 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -4,7 +4,7 @@ import warnings from collections import defaultdict, deque from datetime import datetime -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable from strands import Agent from strands.models.model import Model @@ -48,7 +48,7 @@ def __init__(self, max_tool_call_cache_size: int = 20): Default is 20. """ self._max_tool_call_cache_size = max_tool_call_cache_size - self._states: defaultdict[str, Dict[str, Any]] = defaultdict( + self._states: defaultdict[str, dict[str, Any]] = defaultdict( lambda: {"previous_calls": deque(maxlen=self._max_tool_call_cache_size)} ) @@ -75,7 +75,7 @@ def initialize_state_via_description(self, initial_state_description: str, state f"State with key '{state_key}' already initialized. Skipping re-initialization.", stacklevel=2 ) - def get_state(self, state_key: str) -> Dict[str, Any]: + def get_state(self, state_key: str) -> dict[str, Any]: """ Get state for a specific tool or shared state group. @@ -101,7 +101,7 @@ def cache_tool_call( tool_type: ToolType, response_data: Any, **call_data: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Cache a tool call in the tool's state key. 
@@ -177,16 +177,16 @@ class ToolSimulator: """ # Class-level registry for all registered tools - _registered_tools: Dict[str, RegisteredTool] = {} - _state_registry: Optional[StateRegistry] = None - _global_instance: Optional["ToolSimulator"] = None + _registered_tools: dict[str, RegisteredTool] = {} + _state_registry: StateRegistry | None = None + _global_instance: "ToolSimulator" | None = None def __init__( self, - state_registry: Optional[StateRegistry] = None, - function_tool_prompt: Optional[str] = None, - mcp_tool_prompt: Optional[str] = None, - api_tool_prompt: Optional[str] = None, + state_registry: StateRegistry | None = None, + function_tool_prompt: str | None = None, + mcp_tool_prompt: str | None = None, + api_tool_prompt: str | None = None, model: Model | str | None = None, framework: str = "strands", max_tool_call_cache_size: int = 20, @@ -326,7 +326,7 @@ def _create_strands_tool_wrapper(self, registered_tool: RegisteredTool, wrapper: tool_description = wrapper.__doc__ or f"Simulated {registered_tool.name} tool" # Build input schema from function signature - input_schema: Dict[str, Any] = {"type": "object", "properties": {}} + input_schema: dict[str, Any] = {"type": "object", "properties": {}} if registered_tool.function: try: sig = inspect.signature(registered_tool.function) @@ -368,7 +368,7 @@ def _create_strands_tool_wrapper(self, registered_tool: RegisteredTool, wrapper: return decorated_tool - def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: Dict[str, Any]) -> Any: + def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: dict[str, Any]) -> Any: """Simulate a tool invocation and return the response.""" tool_name = input_data.get("tool_name", "") registered_tool = self._registered_tools.get(tool_name) @@ -376,24 +376,26 @@ def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: D raise ValueError(f"Tool '{tool_name}' not registered") # Handle different simulation modes - if registered_tool.mode == "static": - return self._handle_static_mode(registered_tool, tool_type) - elif registered_tool.mode == "mock": - return self._handle_mock_mode(registered_tool, input_data, state_key, tool_type) - elif registered_tool.mode == "dynamic": - # Route to appropriate handler based on tool type - if tool_type == ToolType.FUNCTION: - return self._handle_function_tool(input_data, state_key) - elif tool_type == ToolType.MCP: - return self._handle_mcp_tool(input_data, state_key) - elif tool_type == ToolType.API: - return self._handle_api_tool(input_data, state_key) - else: - raise ValueError(f"Tool type '{tool_type}' not supported") - else: - raise ValueError(f"Tool simulation mode '{registered_tool.mode}' not supported") - - def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: + match registered_tool.mode: + case "static": + return self._handle_static_mode(registered_tool, tool_type) + case "mock": + return self._handle_mock_mode(registered_tool, input_data, state_key, tool_type) + case "dynamic": + # Route to appropriate handler based on tool type + match tool_type: + case ToolType.FUNCTION: + return self._handle_function_tool(input_data, state_key) + case ToolType.MCP: + return self._handle_mcp_tool(input_data, state_key) + case ToolType.API: + return self._handle_api_tool(input_data, state_key) + case _: + raise ValueError(f"Tool type '{tool_type}' not supported") + case _: + raise ValueError(f"Tool simulation mode '{registered_tool.mode}' not supported") + + def 
_handle_function_tool(self, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: """Handle function tool simulation.""" tool_name = input_data.get("tool_name", "") parameters = input_data.get("parameters", {}) @@ -411,9 +413,9 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D prompt = self.function_tool_prompt.format( tool_name=tool_name, - parameters=json.dumps(parameters, indent=2) if parameters else "{}", + parameters=json.dumps(parameters, indent=2), initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2) or "{}", + previous_responses=json.dumps(current_state, indent=2), ) # Create agent and generate response with structured output @@ -426,17 +428,11 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D result = agent(prompt, structured_output_model=None) # Parse JSON response for function tools since they vary based on function signature - response_text = ( - getattr(result, "response", None) or str(result.content) if hasattr(result, "content") else str(result) - ) - - if response_text and isinstance(response_text, str): - try: - response_data = json.loads(response_text) - except json.JSONDecodeError: - response_data = {"result": response_text} - else: - response_data = {"result": response_text or "No response"} + response_text = str(result) or "No response" + try: + response_data = json.loads(response_text) + except json.JSONDecodeError: + response_data = {"result": response_text} # Cache the call self._state_registry.cache_tool_call( @@ -449,7 +445,7 @@ def _handle_function_tool(self, input_data: Dict[str, Any], state_key: str) -> D logger.error(f"Error generating function response: {e}") raise RuntimeError(f"Error generating function response: {e}") from e - def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, Any]: + def _handle_mcp_tool(self, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: """Handle MCP tool simulation.""" tool_name = input_data.get("tool_name", "") input_mcp_payload = input_data.get("input_mcp_payload", {}) @@ -467,9 +463,9 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s prompt = self.mcp_tool_prompt.format( tool_name=tool_name, - mcp_payload=json.dumps(input_mcp_payload, indent=2) if input_mcp_payload else "{}", + mcp_payload=json.dumps(input_mcp_payload, indent=2), initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2) or "{}", + previous_responses=json.dumps(current_state, indent=2), ) # Create agent and generate response with structured output @@ -480,9 +476,11 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s callback_handler=None, ) result = agent(prompt, structured_output_model=MCPToolResponse) - if result.structured_output: - response_data = result.structured_output.model_dump() - else: + + response_text = str(result) or "No response" + try: + response_data = json.loads(response_text) + except json.JSONDecodeError: response_data = { "isError": True, "content": [{"type": "text", "text": "No structured output received"}], @@ -499,7 +497,7 @@ def _handle_mcp_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s logger.error(f"Error generating MCP response: {e}") return {"isError": True, "content": [{"type": "text", "text": f"Error generating response: {str(e)}"}]} - def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[str, 
Any]: + def _handle_api_tool(self, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: """Handle API tool simulation.""" tool_name = input_data.get("tool_name", "") user_input_api_payload = input_data.get("user_input_api_payload", {}) @@ -523,7 +521,7 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s method=method, api_payload=json.dumps(user_input_api_payload, indent=2) if user_input_api_payload else "{}", initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2) or "{}", + previous_responses=json.dumps(current_state, indent=2), ) # Create agent and generate response with structured output @@ -534,9 +532,11 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s callback_handler=None, ) result = agent(prompt, structured_output_model=APIToolResponse) - if result.structured_output: - response_data = result.structured_output.model_dump() - else: + + response_text = str(result) or "No response" + try: + response_data = json.loads(response_text) + except json.JSONDecodeError: response_data = { "status": 500, "error": { @@ -562,7 +562,7 @@ def _handle_api_tool(self, input_data: Dict[str, Any], state_key: str) -> Dict[s except Exception as e: raise RuntimeError(f"Error generating simulated API response: {e}") from e - def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolType) -> Dict[str, Any]: + def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolType) -> dict[str, Any]: """Handle static mode simulation - returns predefined static response.""" if registered_tool.static_response is None: raise ValueError(f"Static response is required for tool '{registered_tool.name}' in static mode") @@ -570,8 +570,8 @@ def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolTy return registered_tool.static_response def _handle_mock_mode( - self, registered_tool: RegisteredTool, input_data: Dict[str, Any], state_key: str, tool_type: ToolType - ) -> Dict[str, Any]: + self, registered_tool: RegisteredTool, input_data: dict[str, Any], state_key: str, tool_type: ToolType + ) -> dict[str, Any]: """Handle mock mode simulation - calls custom mock function.""" if registered_tool.mock_function is None: raise ValueError("mock_function is required for tool simulator mock mode") @@ -653,11 +653,11 @@ def _get_instance(cls) -> "ToolSimulator": @classmethod def function_tool( cls, - name: Optional[str] = None, - initial_state_description: Optional[str] = None, + name: str | None = None, + initial_state_description: str | None = None, mode: str = "dynamic", - static_response: Optional[Dict[str, Any]] = None, - mock_function: Optional[Callable] = None, + static_response: dict[str, Any] | None = None, + mock_function: Callable | None = None, **simulator_kwargs, ) -> Callable: """ @@ -704,12 +704,12 @@ def decorator(func: Callable) -> Callable: @classmethod def mcp_tool( cls, - name: Optional[str] = None, - schema: Optional[Dict[str, Any]] = None, - initial_state_description: Optional[str] = None, + name: str | None = None, + schema: dict[str, Any] | None = None, + initial_state_description: str | None = None, mode: str = "dynamic", - static_response: Optional[Dict[str, Any]] = None, - mock_function: Optional[Callable] = None, + static_response: dict[str, Any] | None = None, + mock_function: Callable | None = None, **simulator_kwargs, ) -> Callable: """ @@ -756,14 +756,14 @@ def decorator(func: Callable) -> Callable: @classmethod def 
api_tool( cls, - name: Optional[str] = None, - path: Optional[str] = None, - method: Optional[str] = None, - schema: Optional[Dict[str, Any]] = None, - initial_state_description: Optional[str] = None, + name: str | None = None, + path: str | None = None, + method: str | None = None, + schema: dict[str, Any] | None = None, + initial_state_description: str | None = None, mode: str = "dynamic", - static_response: Optional[Dict[str, Any]] = None, - mock_function: Optional[Callable] = None, + static_response: dict[str, Any] | None = None, + mock_function: Callable | None = None, **simulator_kwargs, ) -> Callable: """ @@ -812,7 +812,7 @@ def decorator(func: Callable) -> Callable: return decorator - def get_tool(self, tool_name: str) -> Optional[Callable]: + def get_tool(self, tool_name: str) -> Callable | None: """ Get a tool by name and create a simulation wrapper. @@ -828,7 +828,7 @@ def get_tool(self, tool_name: str) -> Optional[Callable]: return self._create_tool_wrapper(registered_tool) - def list_tools(self) -> List[str]: + def list_tools(self) -> list[str]: """ List all registered tool names. diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index d1e7449..7b719a2 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable from pydantic import BaseModel, Field @@ -36,19 +36,19 @@ class RegisteredTool(BaseModel): name: str = Field(..., description="Name of the tool") tool_type: ToolType = Field(..., description="Type of the tool") - function: Optional[Callable] = Field(default=None, description="Function callable", exclude=True) - mcp_schema: Optional[Dict[str, Any]] = Field(default=None, description="MCP tool schema") - api_path: Optional[str] = Field(default=None, description="API endpoint path") - api_method: Optional[str] = Field(default=None, description="HTTP method") - initial_state_description: Optional[str] = Field( + function: Callable | None = Field(default=None, description="Function callable", exclude=True) + mcp_schema: dict[str, Any] | None = Field(default=None, description="MCP tool schema") + api_path: str | None = Field(default=None, description="API endpoint path") + api_method: str | None = Field(default=None, description="HTTP method") + initial_state_description: str | None = Field( default=None, description="Initial state description for the tool's context" ) - simulator_kwargs: Optional[Dict[str, Any]] = Field( + simulator_kwargs: dict[str, Any] | None = Field( default_factory=dict, description="Additional simulator configuration" ) mode: str = Field(default="dynamic", description="Simulation mode: dynamic, static, mock") - static_response: Optional[Dict[str, Any]] = Field(default=None, description="Static response for static mode") - mock_function: Optional[Callable] = Field(default=None, description="Mock function for mock mode", exclude=True) + static_response: dict[str, Any] | None = Field(default=None, description="Static response for static mode") + mock_function: Callable | None = Field(default=None, description="Mock function for mock mode", exclude=True) model_config = {"arbitrary_types_allowed": True} @@ -57,8 +57,8 @@ class MCPContentItem(BaseModel): """Individual content item in MCP response.""" type: str = Field(..., description="Type of content (text, resource, etc.)") - text: Optional[str] = Field(default=None, description="Text 
content") - resource: Optional[Dict[str, Any]] = Field(default=None, description="Resource information") + text: str | None = Field(default=None, description="Text content") + resource: dict[str, Any] | None = Field(default=None, description="Resource information") class MCPToolResponse(BaseModel): @@ -68,8 +68,8 @@ class MCPToolResponse(BaseModel): Follows the MCP response format with content array and optional error flag. """ - content: List[MCPContentItem] = Field(..., description="Array of content items") - isError: Optional[bool] = Field(default=False, description="Whether this response represents an error") + content: list[MCPContentItem] = Field(..., description="Array of content items") + isError: bool | None = Field(default=False, description="Whether this response represents an error") class APIErrorDetail(BaseModel): @@ -88,8 +88,8 @@ class APIToolResponse(BaseModel): """ status: int = Field(..., description="HTTP status code") - data: Optional[Any] = Field(default=None, description="Response data for successful requests") - error: Optional[APIErrorDetail] = Field(default=None, description="Error details for failed requests") + data: Any | None = Field(default=None, description="Response data for successful requests") + error: APIErrorDetail | None = Field(default=None, description="Error details for failed requests") # Allow additional fields for flexibility model_config = {"extra": "allow"} diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index 74cf5cb..6f0581c 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -118,10 +118,8 @@ def test_func(message: str) -> dict: # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() - # For function tools, set structured_output to None so it falls back to parsing response text - mock_result.structured_output = None - # Mock the response attribute to return parsable JSON string - mock_result.response = '{"result": "simulated response"}' + # Mock __str__ method to return expected JSON string + mock_result.__str__ = MagicMock(return_value='{"result": "simulated response"}') mock_agent_instance.return_value = mock_result with pytest.MonkeyPatch().context() as m: @@ -150,7 +148,8 @@ def test_mcp(**params): # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() - mock_result.structured_output.model_dump.return_value = {"content": [{"type": "text", "text": "MCP response"}]} + # Mock __str__ method to return expected JSON string + mock_result.__str__ = MagicMock(return_value='{"content": [{"type": "text", "text": "MCP response"}]}') mock_agent_instance.return_value = mock_result with pytest.MonkeyPatch().context() as m: @@ -178,7 +177,8 @@ def test_api(**kwargs): # Mock the Agent constructor and its result to avoid real LLM calls mock_agent_instance = MagicMock() mock_result = MagicMock() - mock_result.structured_output.model_dump.return_value = {"status": 200, "data": {"key": "value"}} + # Mock __str__ method to return expected JSON string + mock_result.__str__ = MagicMock(return_value='{"status": 200, "data": {"key": "value"}}') mock_agent_instance.return_value = mock_result with pytest.MonkeyPatch().context() as m: @@ -259,18 +259,10 @@ def create_mock_agent(**kwargs): if len(mock_agent_instances) < len(expected_responses): response = 
expected_responses[len(mock_agent_instances)] - if "content" in response: - # MCP response needs .model_dump() - mock_result.structured_output.model_dump.return_value = response - elif "balance" in response: - # Function response - set structured_output to None and provide JSON string in response - mock_result.structured_output = None - import json - - mock_result.response = json.dumps(response) - else: - # API response - use .model_dump() - mock_result.structured_output.model_dump.return_value = response + import json + + # Simplified approach: Mock __str__ method to return JSON string for all tool types + mock_result.__str__ = MagicMock(return_value=json.dumps(response)) mock_agent.return_value = mock_result mock_agent_instances.append(mock_agent) From f8c5cd358216c844b19e43c77b9abeb827f06c95 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 3 Feb 2026 23:53:51 +0000 Subject: [PATCH 14/15] fix forward referencing with Union --- src/strands_evals/simulation/tool_simulator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index d43b2d2..f8ab958 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -4,7 +4,7 @@ import warnings from collections import defaultdict, deque from datetime import datetime -from typing import Any, Callable +from typing import Any, Callable, Union from strands import Agent from strands.models.model import Model @@ -179,7 +179,7 @@ class ToolSimulator: # Class-level registry for all registered tools _registered_tools: dict[str, RegisteredTool] = {} _state_registry: StateRegistry | None = None - _global_instance: "ToolSimulator" | None = None + _global_instance: Union["ToolSimulator", None] = None def __init__( self, From 1303d57c592bfbb84b19301624910041f579ba35 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Wed, 4 Feb 2026 23:04:21 +0000 Subject: [PATCH 15/15] simplify util function argument passing and validation --- .../simulation/tool_simulator.py | 307 +++++++----------- src/strands_evals/types/simulation/tool.py | 19 +- .../simulation/test_tool_simulator.py | 24 +- 3 files changed, 145 insertions(+), 205 deletions(-) diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index f8ab958..82208c0 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -4,15 +4,17 @@ import warnings from collections import defaultdict, deque from datetime import datetime -from typing import Any, Callable, Union +from typing import Any, Callable from strands import Agent +from strands.agent import AgentResult from strands.models.model import Model from strands_evals.types.simulation.tool import ( APIToolResponse, MCPToolResponse, RegisteredTool, + ToolSimulationMode, ToolType, ) @@ -178,8 +180,8 @@ class ToolSimulator: # Class-level registry for all registered tools _registered_tools: dict[str, RegisteredTool] = {} - _state_registry: StateRegistry | None = None - _global_instance: Union["ToolSimulator", None] = None + _state_registry: StateRegistry + _global_instance: "ToolSimulator | None" = None def __init__( self, @@ -302,7 +304,7 @@ def wrapper(*args, **kwargs): else: raise ValueError(f"Unsupported tool type: {registered_tool.tool_type}") - return self._simulate_tool_call(registered_tool.tool_type, state_key, input_data) + return self._call(registered_tool, input_data, state_key) # 
Copy function metadata if registered_tool.function: @@ -368,201 +370,123 @@ def _create_strands_tool_wrapper(self, registered_tool: RegisteredTool, wrapper: return decorated_tool - def _simulate_tool_call(self, tool_type: ToolType, state_key: str, input_data: dict[str, Any]) -> Any: - """Simulate a tool invocation and return the response.""" - tool_name = input_data.get("tool_name", "") - registered_tool = self._registered_tools.get(tool_name) - if not registered_tool: - raise ValueError(f"Tool '{tool_name}' not registered") + def _simulate_tool_call(self, prompt: str, structured_output_model=None) -> Any: + """Tool simulation agent creation and response generation.""" + agent = Agent( + system_prompt=prompt, + tools=[], + model=self.model, + callback_handler=None, + ) + return agent(prompt, structured_output_model=structured_output_model) + def _parse_simulated_response(self, result: AgentResult) -> dict[str, Any]: + """Parse tool simulation agent response, trying to extract JSON first, falling back to wrapping in result.""" + response_text = str(result) or "No response" + try: + response_data = json.loads(response_text) + except json.JSONDecodeError: + response_data = {"result": response_text} + return response_data + + def _call(self, registered_tool: RegisteredTool, input_data: dict[str, Any], state_key: str) -> Any: + """Simulate a tool invocation and return the response.""" # Handle different simulation modes match registered_tool.mode: - case "static": - return self._handle_static_mode(registered_tool, tool_type) - case "mock": - return self._handle_mock_mode(registered_tool, input_data, state_key, tool_type) - case "dynamic": + case ToolSimulationMode.STATIC: + return self._handle_static_mode(registered_tool) + case ToolSimulationMode.MOCK: + return self._handle_mock_mode(registered_tool, input_data, state_key) + case ToolSimulationMode.DYNAMIC: # Route to appropriate handler based on tool type - match tool_type: + match registered_tool.tool_type: case ToolType.FUNCTION: - return self._handle_function_tool(input_data, state_key) + return self._handle_function_tool(registered_tool.name, input_data, state_key) case ToolType.MCP: - return self._handle_mcp_tool(input_data, state_key) + return self._handle_mcp_tool(registered_tool.name, input_data, state_key) case ToolType.API: - return self._handle_api_tool(input_data, state_key) - case _: - raise ValueError(f"Tool type '{tool_type}' not supported") - case _: - raise ValueError(f"Tool simulation mode '{registered_tool.mode}' not supported") + return self._handle_api_tool(registered_tool.name, input_data, state_key) - def _handle_function_tool(self, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: + def _handle_function_tool(self, tool_name: str, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: """Handle function tool simulation.""" - tool_name = input_data.get("tool_name", "") parameters = input_data.get("parameters", {}) - if not tool_name: - raise ValueError("Tool name is required") - - if not self._state_registry: - raise RuntimeError("State registry is not initialized") + current_state = self._state_registry.get_state(state_key) - try: - # Get initial state description from state registry - current_state = self._state_registry.get_state(state_key) - initial_state_description = current_state.get("initial_state", "No initial state provided.") - - prompt = self.function_tool_prompt.format( - tool_name=tool_name, - parameters=json.dumps(parameters, indent=2), - 
initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2), - ) + prompt = self.function_tool_prompt.format( + tool_name=tool_name, + parameters=json.dumps(parameters, indent=2), + initial_state_description=current_state.get("initial_state", "No initial state provided."), + previous_responses=json.dumps(current_state.get("previous_calls", []), indent=2), + ) - # Create agent and generate response with structured output - agent = Agent( - system_prompt=self.function_tool_prompt, - tools=[], - model=self.model, - callback_handler=None, - ) - result = agent(prompt, structured_output_model=None) + result = self._simulate_tool_call(prompt, structured_output_model=None) - # Parse JSON response for function tools since they vary based on function signature - response_text = str(result) or "No response" - try: - response_data = json.loads(response_text) - except json.JSONDecodeError: - response_data = {"result": response_text} + response_data = self._parse_simulated_response(result) - # Cache the call - self._state_registry.cache_tool_call( - tool_name, state_key, ToolType.FUNCTION, response_data, parameters=parameters - ) - - return response_data + self._state_registry.cache_tool_call( + tool_name, state_key, ToolType.FUNCTION, response_data, parameters=parameters + ) - except Exception as e: - logger.error(f"Error generating function response: {e}") - raise RuntimeError(f"Error generating function response: {e}") from e + return response_data - def _handle_mcp_tool(self, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: + def _handle_mcp_tool(self, tool_name: str, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: """Handle MCP tool simulation.""" - tool_name = input_data.get("tool_name", "") input_mcp_payload = input_data.get("input_mcp_payload", {}) - if not tool_name: - return {"isError": True, "content": [{"type": "text", "text": "Tool name is required"}]} - - if not self._state_registry: - raise RuntimeError("State registry is not initialized") - - try: - # Get initial state description from state registry - current_state = self._state_registry.get_state(state_key) - initial_state_description = current_state.get("initial_state", "No initial state provided.") + current_state = self._state_registry.get_state(state_key) - prompt = self.mcp_tool_prompt.format( - tool_name=tool_name, - mcp_payload=json.dumps(input_mcp_payload, indent=2), - initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2), - ) + prompt = self.mcp_tool_prompt.format( + tool_name=tool_name, + mcp_payload=json.dumps(input_mcp_payload, indent=2), + initial_state_description=current_state.get("initial_state", "No initial state provided."), + previous_responses=json.dumps(current_state.get("previous_calls", []), indent=2), + ) - # Create agent and generate response with structured output - agent = Agent( - system_prompt=self.mcp_tool_prompt, - tools=[], - model=self.model, - callback_handler=None, - ) - result = agent(prompt, structured_output_model=MCPToolResponse) + result = self._simulate_tool_call(prompt, structured_output_model=MCPToolResponse) - response_text = str(result) or "No response" - try: - response_data = json.loads(response_text) - except json.JSONDecodeError: - response_data = { - "isError": True, - "content": [{"type": "text", "text": "No structured output received"}], - } + response_data = self._parse_simulated_response(result) - # Cache the call - 
self._state_registry.cache_tool_call( - tool_name, state_key, ToolType.MCP, response_data, input_mcp_payload=input_mcp_payload - ) - - return response_data + self._state_registry.cache_tool_call( + tool_name, state_key, ToolType.MCP, response_data, input_mcp_payload=input_mcp_payload + ) - except Exception as e: - logger.error(f"Error generating MCP response: {e}") - return {"isError": True, "content": [{"type": "text", "text": f"Error generating response: {str(e)}"}]} + return response_data - def _handle_api_tool(self, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: + def _handle_api_tool(self, tool_name: str, input_data: dict[str, Any], state_key: str) -> dict[str, Any]: """Handle API tool simulation.""" - tool_name = input_data.get("tool_name", "") user_input_api_payload = input_data.get("user_input_api_payload", {}) path = input_data.get("path", "") - method = input_data.get("method", "GET") - - if not tool_name: - raise ValueError("tool_name is required for API tool simulation") + method = input_data.get("method", "GET").upper() # Normalize HTTP method to uppercase - if not self._state_registry: - raise RuntimeError("State registry is not initialized") - - try: - # Get initial state description from state registry - current_state = self._state_registry.get_state(state_key) - initial_state_description = current_state.get("initial_state", "No initial state provided.") - - prompt = self.api_tool_prompt.format( - tool_name=tool_name, - path=path, - method=method, - api_payload=json.dumps(user_input_api_payload, indent=2) if user_input_api_payload else "{}", - initial_state_description=initial_state_description, - previous_responses=json.dumps(current_state, indent=2), - ) + current_state = self._state_registry.get_state(state_key) - # Create agent and generate response with structured output - agent = Agent( - system_prompt=self.api_tool_prompt, - tools=[], - model=self.model, - callback_handler=None, - ) - result = agent(prompt, structured_output_model=APIToolResponse) + prompt = self.api_tool_prompt.format( + tool_name=tool_name, + path=path, + method=method, + api_payload=json.dumps(user_input_api_payload, indent=2) if user_input_api_payload else "{}", + initial_state_description=current_state.get("initial_state", "No initial state provided."), + previous_responses=json.dumps(current_state.get("previous_calls", []), indent=2), + ) - response_text = str(result) or "No response" - try: - response_data = json.loads(response_text) - except json.JSONDecodeError: - response_data = { - "status": 500, - "error": { - "type": "internal_error", - "title": "Internal Error", - "detail": "No structured output received", - }, - } + result = self._simulate_tool_call(prompt, structured_output_model=APIToolResponse) - # Cache the call - self._state_registry.cache_tool_call( - tool_name, - state_key, - ToolType.API, - response_data, - path=path, - method=method, - input_data=user_input_api_payload, - ) + response_data = self._parse_simulated_response(result) - return response_data + self._state_registry.cache_tool_call( + tool_name, + state_key, + ToolType.API, + response_data, + path=path, + method=method, + input_data=user_input_api_payload, + ) - except Exception as e: - raise RuntimeError(f"Error generating simulated API response: {e}") from e + return response_data - def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolType) -> dict[str, Any]: + def _handle_static_mode(self, registered_tool: RegisteredTool) -> dict[str, Any]: """Handle static mode simulation - 
returns predefined static response.""" if registered_tool.static_response is None: raise ValueError(f"Static response is required for tool '{registered_tool.name}' in static mode") @@ -570,7 +494,7 @@ def _handle_static_mode(self, registered_tool: RegisteredTool, tool_type: ToolTy return registered_tool.static_response def _handle_mock_mode( - self, registered_tool: RegisteredTool, input_data: dict[str, Any], state_key: str, tool_type: ToolType + self, registered_tool: RegisteredTool, input_data: dict[str, Any], state_key: str ) -> dict[str, Any]: """Handle mock mode simulation - calls custom mock function.""" if registered_tool.mock_function is None: @@ -578,7 +502,7 @@ def _handle_mock_mode( try: # Extract parameters based on tool type - if tool_type == ToolType.FUNCTION: + if registered_tool.tool_type == ToolType.FUNCTION: parameters = input_data.get("parameters", {}) if isinstance(parameters, str): parameters = json.loads(parameters) @@ -591,31 +515,28 @@ def _handle_mock_mode( else: result = registered_tool.mock_function(**parameters) - elif tool_type == ToolType.MCP: + elif registered_tool.tool_type == ToolType.MCP: input_mcp_payload = input_data.get("input_mcp_payload", {}) result = registered_tool.mock_function(**input_mcp_payload) - elif tool_type == ToolType.API: + elif registered_tool.tool_type == ToolType.API: user_input_api_payload = input_data.get("user_input_api_payload", {}) result = registered_tool.mock_function(**user_input_api_payload) else: - raise ValueError(f"Unsupported tool type '{tool_type}' for mock mode") + raise ValueError(f"Unsupported tool type '{registered_tool.tool_type}' for mock mode") # Record the call in state registry - if not self._state_registry: - raise RuntimeError("State registry is not initialized") - tool_name = registered_tool.name - if tool_type == ToolType.FUNCTION: + if registered_tool.tool_type == ToolType.FUNCTION: self._state_registry.cache_tool_call( tool_name, state_key, ToolType.FUNCTION, result, parameters=parameters ) - elif tool_type == ToolType.MCP: + elif registered_tool.tool_type == ToolType.MCP: self._state_registry.cache_tool_call( tool_name, state_key, ToolType.MCP, result, input_mcp_payload=input_mcp_payload ) - elif tool_type == ToolType.API: + elif registered_tool.tool_type == ToolType.API: path = input_data.get("path", "") method = input_data.get("method", "GET") self._state_registry.cache_tool_call( @@ -632,7 +553,7 @@ def _handle_mock_mode( except Exception as e: raise RuntimeError( - f"Tool simulator mock mode error for {tool_type} tool {registered_tool.name}: {e}" + f"Tool simulator mock mode error for {registered_tool.tool_type} tool {registered_tool.name}: {e}" ) from e @classmethod @@ -655,7 +576,7 @@ def function_tool( cls, name: str | None = None, initial_state_description: str | None = None, - mode: str = "dynamic", + mode: ToolSimulationMode | str = ToolSimulationMode.DYNAMIC, static_response: dict[str, Any] | None = None, mock_function: Callable | None = None, **simulator_kwargs, @@ -666,7 +587,7 @@ def function_tool( Args: name: Optional name for the tool. 
If None, uses function.__name__ initial_state_description: Optional initial state description for the tool's context - mode: Simulation mode - "dynamic", "static", or "mock" + mode: Simulation mode - ToolSimulationMode enum or "dynamic", "static", "mock" string static_response: Static response dict for static mode mock_function: Custom callable for mock mode **simulator_kwargs: Additional simulator configuration @@ -679,6 +600,12 @@ def decorator(func: Callable) -> Callable: try: tool_name = name or func.__name__ + # Convert string mode to enum for backward compatibility + if isinstance(mode, str): + mode_enum = ToolSimulationMode(mode) + else: + mode_enum = mode + # Register tool registered_tool = RegisteredTool( name=tool_name, @@ -686,7 +613,7 @@ def decorator(func: Callable) -> Callable: function=func, initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, - mode=mode, + mode=mode_enum, static_response=static_response, mock_function=mock_function, ) @@ -707,7 +634,7 @@ def mcp_tool( name: str | None = None, schema: dict[str, Any] | None = None, initial_state_description: str | None = None, - mode: str = "dynamic", + mode: ToolSimulationMode | str = ToolSimulationMode.DYNAMIC, static_response: dict[str, Any] | None = None, mock_function: Callable | None = None, **simulator_kwargs, @@ -719,7 +646,7 @@ def mcp_tool( name: Optional name for the tool. If None, uses function.__name__ schema: MCP tool schema dictionary initial_state_description: Optional initial state description for the tool's context - mode: Simulation mode - "dynamic", "static", or "mock" + mode: Simulation mode - ToolSimulationMode enum or "dynamic", "static", "mock" string static_response: Static response dict for static mode mock_function: Custom callable for mock mode **simulator_kwargs: Additional simulator configuration @@ -734,6 +661,12 @@ def decorator(func: Callable) -> Callable: if schema is None: raise ValueError(f"MCP schema is required for tool {tool_name}") + # Convert string mode to enum for backward compatibility + if isinstance(mode, str): + mode_enum = ToolSimulationMode(mode) + else: + mode_enum = mode + # Register tool registered_tool = RegisteredTool( name=tool_name, @@ -742,7 +675,7 @@ def decorator(func: Callable) -> Callable: mcp_schema=schema, initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, - mode=mode, + mode=mode_enum, static_response=static_response, mock_function=mock_function, ) @@ -761,7 +694,7 @@ def api_tool( method: str | None = None, schema: dict[str, Any] | None = None, initial_state_description: str | None = None, - mode: str = "dynamic", + mode: ToolSimulationMode | str = ToolSimulationMode.DYNAMIC, static_response: dict[str, Any] | None = None, mock_function: Callable | None = None, **simulator_kwargs, @@ -775,7 +708,7 @@ def api_tool( method: HTTP method (GET, POST, etc.) 
schema: API tool schema dictionary initial_state_description: Optional initial state description for the tool's context - mode: Simulation mode - "dynamic", "static", or "mock" + mode: Simulation mode - ToolSimulationMode enum or "dynamic", "static", "mock" string static_response: Static response dict for static mode mock_function: Custom callable for mock mode **simulator_kwargs: Additional simulator configuration @@ -792,6 +725,12 @@ def decorator(func: Callable) -> Callable: if method is None: raise ValueError("HTTP method is required") + # Convert string mode to enum for backward compatibility + if isinstance(mode, str): + mode_enum = ToolSimulationMode(mode) + else: + mode_enum = mode + # Register tool registered_tool = RegisteredTool( name=tool_name, @@ -801,7 +740,7 @@ def decorator(func: Callable) -> Callable: api_method=method, initial_state_description=initial_state_description, simulator_kwargs=simulator_kwargs, - mode=mode, + mode=mode_enum, static_response=static_response, mock_function=mock_function, ) diff --git a/src/strands_evals/types/simulation/tool.py b/src/strands_evals/types/simulation/tool.py index 7b719a2..d454699 100644 --- a/src/strands_evals/types/simulation/tool.py +++ b/src/strands_evals/types/simulation/tool.py @@ -19,6 +19,21 @@ class ToolType(Enum): API = "api" +class ToolSimulationMode(Enum): + """ + Enumeration of supported simulation modes. + + Attributes: + DYNAMIC: Generate responses using LLM based on tool context and history. + STATIC: Return predefined static responses. + MOCK: Call custom mock functions for controlled behavior. + """ + + DYNAMIC = "dynamic" + STATIC = "static" + MOCK = "mock" + + class RegisteredTool(BaseModel): """ Represents a registered tool in the simulator. @@ -46,7 +61,9 @@ class RegisteredTool(BaseModel): simulator_kwargs: dict[str, Any] | None = Field( default_factory=dict, description="Additional simulator configuration" ) - mode: str = Field(default="dynamic", description="Simulation mode: dynamic, static, mock") + mode: ToolSimulationMode = Field( + default=ToolSimulationMode.DYNAMIC, description="Simulation mode: dynamic, static, mock" + ) static_response: dict[str, Any] | None = Field(default=None, description="Static response for static mode") mock_function: Callable | None = Field(default=None, description="Mock function for mock mode", exclude=True) diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index 6f0581c..907059d 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -386,26 +386,11 @@ def test_tool_not_found_raises_error(): """Test that accessing non-existent tools raises ValueError.""" simulator = ToolSimulator() - # Test that accessing a non-existent tool via _simulate_tool_call raises ValueError - with pytest.raises(ValueError) as excinfo: - simulator._simulate_tool_call( - tool_type=ToolType.FUNCTION, state_key="test", input_data={"tool_name": "nonexistent_tool"} - ) - - assert "not registered" in str(excinfo.value) - - -def test_api_tool_missing_name_raises_error(): - """Test that API tool simulation raises ValueError when tool_name is missing.""" - simulator = ToolSimulator() - - with pytest.raises(ValueError) as excinfo: - simulator._handle_api_tool( - input_data={"tool_name": ""}, # Empty tool name - state_key="test", - ) + # Test that accessing a non-existent tool via __getattr__ raises AttributeError + with pytest.raises(AttributeError) as excinfo: + _ 
= simulator.nonexistent_tool - assert "tool_name is required for API tool simulation" in str(excinfo.value) + assert "not found in registered tools" in str(excinfo.value) def test_mock_mode_missing_function_raises_error(): @@ -424,7 +409,6 @@ def test_mock_tool(): registered_tool=registered_tool, input_data={"tool_name": "test_mock_tool", "parameters": {}}, state_key="test", - tool_type=ToolType.FUNCTION, ) assert "mock_function is required for tool simulator mock mode" in str(excinfo.value)
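
Usage sketch for reviewers (illustrative only, not part of the diff): the three simulation modes introduced in this patch are easiest to compare side by side. The snippet below is a minimal sketch of the intended decorator usage inferred from the hunks above; the import paths, the registration flow, and the get_weather/fake_weather names are assumptions rather than code from this PR.

    from strands_evals.simulation import ToolSimulator            # assumed import path
    from strands_evals.types.simulation.tool import ToolSimulationMode  # assumed import path

    # Dynamic mode (default): the simulator generates a plausible response
    # with the LLM, based on the tool's context and call history.
    @ToolSimulator.function_tool(name="get_weather")
    def get_weather(city: str) -> dict:
        """Look up current weather for a city."""
        ...

    # Static mode: every call returns the predefined payload; leaving
    # static_response unset raises a ValueError at simulation time.
    @ToolSimulator.function_tool(
        name="get_weather_static",
        mode=ToolSimulationMode.STATIC,  # the plain string "static" is also accepted
        static_response={"temperature": 72, "condition": "sunny"},
    )
    def get_weather_static(city: str) -> dict:
        ...

    # Mock mode: the simulator forwards the extracted tool parameters to the
    # supplied callable, so tests can script deterministic behavior.
    def fake_weather(city: str) -> dict:
        return {"temperature": 10, "condition": "rain", "city": city}

    @ToolSimulator.function_tool(
        name="get_weather_mock",
        mode="mock",
        mock_function=fake_weather,
    )
    def get_weather_mock(city: str) -> dict:
        ...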