diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py
index 2f9962759fc7a..59ac049b13fe9 100644
--- a/libs/langchain_v1/langchain/agents/factory.py
+++ b/libs/langchain_v1/langchain/agents/factory.py
@@ -33,6 +33,7 @@
     ProviderStrategy,
     ProviderStrategyBinding,
     ResponseFormat,
+    StructuredOutputError,
     StructuredOutputValidationError,
     ToolStrategy,
 )
@@ -797,8 +798,16 @@ def _handle_model_output(
             provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
                 effective_response_format.schema_spec
             )
-            structured_response = provider_strategy_binding.parse(output)
-            return {"messages": [output], "structured_response": structured_response}
+            try:
+                structured_response = provider_strategy_binding.parse(output)
+            except Exception as exc:  # noqa: BLE001
+                schema_name = getattr(
+                    effective_response_format.schema_spec.schema, "__name__", "response_format"
+                )
+                validation_error = StructuredOutputValidationError(schema_name, exc, output)
+                raise validation_error
+            else:
+                return {"messages": [output], "structured_response": structured_response}
         return {"messages": [output]}
 
     # Handle structured output with tool strategy
@@ -812,11 +821,11 @@
         ]
 
         if structured_tool_calls:
-            exception: Exception | None = None
+            exception: StructuredOutputError | None = None
             if len(structured_tool_calls) > 1:
                 # Handle multiple structured outputs error
                 tool_names = [tc["name"] for tc in structured_tool_calls]
-                exception = MultipleStructuredOutputsError(tool_names)
+                exception = MultipleStructuredOutputsError(tool_names, output)
                 should_retry, error_message = _handle_structured_output_error(
                     exception, effective_response_format
                 )
@@ -858,7 +867,7 @@
                     "structured_response": structured_response,
                 }
             except Exception as exc:  # noqa: BLE001
-                exception = StructuredOutputValidationError(tool_call["name"], exc)
+                exception = StructuredOutputValidationError(tool_call["name"], exc, output)
                 should_retry, error_message = _handle_structured_output_error(
                     exception, effective_response_format
                 )
diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py
index cd6a2fd9aed31..750386758077f 100644
--- a/libs/langchain_v1/langchain/agents/structured_output.py
+++ b/libs/langchain_v1/langchain/agents/structured_output.py
@@ -34,17 +34,21 @@
 class StructuredOutputError(Exception):
     """Base class for structured output errors."""
 
+    ai_message: AIMessage
+
 
 class MultipleStructuredOutputsError(StructuredOutputError):
     """Raised when model returns multiple structured output tool calls when only one is expected."""
 
-    def __init__(self, tool_names: list[str]) -> None:
+    def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
         """Initialize `MultipleStructuredOutputsError`.
 
         Args:
             tool_names: The names of the tools called for structured output.
+            ai_message: The AI message that contained the invalid multiple tool calls.
         """
         self.tool_names = tool_names
+        self.ai_message = ai_message
         super().__init__(
             "Model incorrectly returned multiple structured responses "
@@ -55,15 +59,17 @@ def __init__(self, tool_names: list[str]) -> None:
 class StructuredOutputValidationError(StructuredOutputError):
     """Raised when structured output tool call arguments fail to parse according to the schema."""
 
-    def __init__(self, tool_name: str, source: Exception) -> None:
+    def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
         """Initialize `StructuredOutputValidationError`.
 
         Args:
             tool_name: The name of the tool that failed.
             source: The exception that occurred.
+            ai_message: The AI message that contained the invalid structured output.
         """
         self.tool_name = tool_name
         self.source = source
+        self.ai_message = ai_message
 
         super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")
diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py
new file mode 100644
index 0000000000000..a04f670ad4a06
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py
@@ -0,0 +1,369 @@
+"""Tests for StructuredOutputRetryMiddleware functionality."""
+
+from collections.abc import Callable
+
+import pytest
+from langchain_core.messages import HumanMessage
+from langchain_core.tools import tool
+from langgraph.checkpoint.memory import InMemorySaver
+from pydantic import BaseModel
+
+from langchain.agents import create_agent
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelRequest,
+    ModelResponse,
+)
+from langchain.agents.structured_output import StructuredOutputError, ToolStrategy
+from tests.unit_tests.agents.model import FakeToolCallingModel
+
+
+class StructuredOutputRetryMiddleware(AgentMiddleware):
+    """Retries model calls when structured output parsing fails."""
+
+    def __init__(self, max_retries: int) -> None:
+        """Initialize the structured output retry middleware.
+
+        Args:
+            max_retries: Maximum number of retry attempts.
+        """
+        self.max_retries = max_retries
+
+    def wrap_model_call(
+        self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
+    ) -> ModelResponse:
+        """Intercept and control model execution via handler callback.
+
+        Args:
+            request: The model request containing messages and configuration.
+            handler: The function to call the model.
+
+        Returns:
+            The model response.
+
+        Raises:
+            StructuredOutputError: If max retries exceeded without success.
+        """
+        for attempt in range(self.max_retries + 1):
+            try:
+                return handler(request)
+            except StructuredOutputError as exc:
+                if attempt == self.max_retries:
+                    raise
+
+                # Include both the AI message and error in a single human message
+                # to maintain valid chat history alternation
+                ai_content = exc.ai_message.content
+                error_message = (
+                    f"Your previous response was:\n{ai_content}\n\n"
+                    f"Error: {exc}. Please try again with a valid response."
+                )
+                request.messages.append(HumanMessage(content=error_message))
+
+        # This should never be reached, but satisfies type checker
+        return handler(request)
+
+
+class WeatherReport(BaseModel):
+    """Weather report schema for testing."""
+
+    temperature: float
+    conditions: str
+
+
+@tool
+def get_weather(city: str) -> str:
+    """Get the weather for a given city.
+
+    Args:
+        city: The city to get weather for.
+
+    Returns:
+        Weather information for the city.
+    """
+    return f"The weather in {city} is sunny and 72 degrees."
+
+
+def test_structured_output_retry_first_attempt_invalid() -> None:
+    """Test structured output retry when first two attempts have invalid output."""
+    # First two attempts have invalid tool arguments, third attempt succeeds
+    # The model will call the WeatherReport structured output tool
+    tool_calls = [
+        # First attempt - invalid: wrong type for temperature
+        [
+            {
+                "name": "WeatherReport",
+                "id": "1",
+                "args": {"temperature": "not-a-float", "conditions": "sunny"},
+            }
+        ],
+        # Second attempt - invalid: missing required field
+        [{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}],
+        # Third attempt - valid
+        [
+            {
+                "name": "WeatherReport",
+                "id": "3",
+                "args": {"temperature": 72.5, "conditions": "sunny"},
+            }
+        ],
+    ]
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+    retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
+
+    agent = create_agent(
+        model=model,
+        tools=[get_weather],
+        middleware=[retry_middleware],
+        response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
+        checkpointer=InMemorySaver(),
+    )
+
+    result = agent.invoke(
+        {"messages": [HumanMessage("What's the weather in Tokyo?")]},
+        {"configurable": {"thread_id": "test"}},
+    )
+
+    # Verify we got a structured response
+    assert "structured_response" in result
+    structured = result["structured_response"]
+    assert isinstance(structured, WeatherReport)
+    assert structured.temperature == 72.5
+    assert structured.conditions == "sunny"
+
+    # Verify the model was called 3 times (initial + 2 retries)
+    assert model.index == 3
+
+
+def test_structured_output_retry_exceeds_max_retries() -> None:
+    """Test structured output retry raises error when max retries exceeded."""
+    # All three attempts return invalid arguments
+    tool_calls = [
+        [
+            {
+                "name": "WeatherReport",
+                "id": "1",
+                "args": {"temperature": "invalid", "conditions": "sunny"},
+            }
+        ],
+        [
+            {
+                "name": "WeatherReport",
+                "id": "2",
+                "args": {"temperature": "also-invalid", "conditions": "cloudy"},
+            }
+        ],
+        [
+            {
+                "name": "WeatherReport",
+                "id": "3",
+                "args": {"temperature": "still-invalid", "conditions": "rainy"},
+            }
+        ],
+    ]
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+    retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
+
+    agent = create_agent(
+        model=model,
+        tools=[get_weather],
+        middleware=[retry_middleware],
+        response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
+        # No checkpointer - we expect this to fail
+    )
+
+    # Should raise StructuredOutputError after exhausting retries
+    with pytest.raises(StructuredOutputError):
+        agent.invoke(
+            {"messages": [HumanMessage("What's the weather in Tokyo?")]},
+        )
+
+    # Verify the model was called 3 times (initial + 2 retries)
+    assert model.index == 3
+
+
+def test_structured_output_retry_succeeds_first_attempt() -> None:
+    """Test structured output retry when first attempt succeeds (no retry needed)."""
+    # First attempt returns valid structured output
+    tool_calls = [
+        [
+            {
+                "name": "WeatherReport",
+                "id": "1",
+                "args": {"temperature": 68.0, "conditions": "cloudy"},
+            }
+        ],
+    ]
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+    retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
+
+    agent = create_agent(
+        model=model,
+        tools=[get_weather],
+        middleware=[retry_middleware],
+        response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
+        checkpointer=InMemorySaver(),
+    )
+
+    result = agent.invoke(
+        {"messages": [HumanMessage("What's the weather in Paris?")]},
+        {"configurable": {"thread_id": "test"}},
+    )
+
+    # Verify we got a structured response
+    assert "structured_response" in result
+    structured = result["structured_response"]
+    assert isinstance(structured, WeatherReport)
+    assert structured.temperature == 68.0
+    assert structured.conditions == "cloudy"
+
+    # Verify the model was called only once
+    assert model.index == 1
+
+
+def test_structured_output_retry_validation_error() -> None:
+    """Test structured output retry with schema validation errors."""
+    # First attempt has wrong type, second has missing field, third succeeds
+    tool_calls = [
+        [
+            {
+                "name": "WeatherReport",
+                "id": "1",
+                "args": {"temperature": "seventy-two", "conditions": "sunny"},
+            }
+        ],
+        [{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}],
+        [
+            {
+                "name": "WeatherReport",
+                "id": "3",
+                "args": {"temperature": 72.5, "conditions": "partly cloudy"},
+            }
+        ],
+    ]
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+    retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
+
+    agent = create_agent(
+        model=model,
+        tools=[get_weather],
+        middleware=[retry_middleware],
+        response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
+        checkpointer=InMemorySaver(),
+    )
+
+    result = agent.invoke(
+        {"messages": [HumanMessage("What's the weather in London?")]},
+        {"configurable": {"thread_id": "test"}},
+    )
+
+    # Verify we got a structured response
+    assert "structured_response" in result
+    structured = result["structured_response"]
+    assert isinstance(structured, WeatherReport)
+    assert structured.temperature == 72.5
+    assert structured.conditions == "partly cloudy"
+
+    # Verify the model was called 3 times
+    assert model.index == 3
+
+
+def test_structured_output_retry_zero_retries() -> None:
+    """Test structured output retry with max_retries=0 (no retries allowed)."""
+    # First attempt returns invalid arguments
+    tool_calls = [
+        [
+            {
+                "name": "WeatherReport",
+                "id": "1",
+                "args": {"temperature": "invalid", "conditions": "sunny"},
+            }
+        ],
+        [
+            {
+                "name": "WeatherReport",
+                "id": "2",
+                "args": {"temperature": 72.5, "conditions": "sunny"},
+            }
+        ],  # Would succeed if retried
+    ]
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+    retry_middleware = StructuredOutputRetryMiddleware(max_retries=0)
+
+    agent = create_agent(
+        model=model,
+        tools=[get_weather],
+        middleware=[retry_middleware],
+        response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
+        checkpointer=InMemorySaver(),
+    )
+
+    # Should fail immediately without retrying
+    with pytest.raises(StructuredOutputError):
+        agent.invoke(
+            {"messages": [HumanMessage("What's the weather in Berlin?")]},
+            {"configurable": {"thread_id": "test"}},
+        )
+
+    # Verify the model was called only once (no retries)
+    assert model.index == 1
+
+
+def test_structured_output_retry_preserves_messages() -> None:
+    """Test structured output retry preserves error feedback in messages."""
+    # First attempt invalid, second succeeds
+    tool_calls = [
+        [
+            {
+                "name": "WeatherReport",
+                "id": "1",
+                "args": {"temperature": "invalid", "conditions": "rainy"},
+            }
+        ],
+        [
+            {
+                "name": "WeatherReport",
+                "id": "2",
+                "args": {"temperature": 75.0, "conditions": "rainy"},
+            }
+        ],
+    ]
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+    retry_middleware = StructuredOutputRetryMiddleware(max_retries=1)
+
+    agent = create_agent(
+        model=model,
+        tools=[get_weather],
+        middleware=[retry_middleware],
+        response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
+        checkpointer=InMemorySaver(),
+    )
+
+    result = agent.invoke(
+        {"messages": [HumanMessage("What's the weather in Seattle?")]},
+        {"configurable": {"thread_id": "test"}},
+    )
+
+    # Verify structured response is correct
+    assert "structured_response" in result
+    structured = result["structured_response"]
+    assert structured.temperature == 75.0
+    assert structured.conditions == "rainy"
+
+    # Verify messages include the retry feedback
+    messages = result["messages"]
+    human_messages = [m for m in messages if isinstance(m, HumanMessage)]
+
+    # Should have at least 2 human messages: initial + retry feedback
+    assert len(human_messages) >= 2
+
+    # The retry feedback message should contain error information
+    retry_message = human_messages[-1]
+    assert "Error:" in retry_message.content
+    assert "Please try again" in retry_message.content
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
index a7963ced16f57..7df5c23463b36 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
@@ -610,6 +610,35 @@ def test_retry_with_custom_string_message(self) -> None:
         )
         assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
 
+    def test_validation_error_with_invalid_response(self) -> None:
+        """Test that StructuredOutputValidationError is raised when tool strategy receives invalid response."""
+        tool_calls = [
+            [
+                {
+                    "name": "WeatherBaseModel",
+                    "id": "1",
+                    "args": {"invalid_field": "wrong_data", "another_bad_field": 123},
+                },
+            ],
+        ]
+
+        model = FakeToolCallingModel(tool_calls=tool_calls)
+
+        agent = create_agent(
+            model,
+            [],
+            response_format=ToolStrategy(
+                WeatherBaseModel,
+                handle_errors=False,  # Disable retry to ensure error is raised
+            ),
+        )
+
+        with pytest.raises(
+            StructuredOutputValidationError,
+            match=".*WeatherBaseModel.*",
+        ):
+            agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
 
 class TestResponseFormatAsProviderStrategy:
     def test_pydantic_model(self) -> None:
@@ -630,6 +659,28 @@ def test_pydantic_model(self) -> None:
         assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
         assert len(response["messages"]) == 4
 
+    def test_validation_error_with_invalid_response(self) -> None:
+        """Test that StructuredOutputValidationError is raised when provider strategy receives invalid response."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+
+        # But we're using WeatherBaseModel which has different field requirements
+        model = FakeToolCallingModel[dict](
+            tool_calls=tool_calls,
+            structured_response={"invalid": "data"},  # Wrong structure
+        )
+
+        agent = create_agent(
+            model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+        )
+
+        with pytest.raises(
+            StructuredOutputValidationError,
+            match=".*WeatherBaseModel.*",
+        ):
+            agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
     def test_dataclass(self) -> None:
         """Test response_format as ProviderStrategy with dataclass."""
         tool_calls = [
diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock
index 41d6066b33e10..714b6ec057f6c 100644
--- a/libs/langchain_v1/uv.lock
+++ b/libs/langchain_v1/uv.lock
@@ -1743,7 +1743,7 @@ wheels = [
 
 [[package]]
 name = "langchain-core"
-version = "1.0.0"
+version = "1.0.1"
 source = { editable = "../core" }
 dependencies = [
     { name = "jsonpatch" },