diff --git a/libs/langchain_v1/Makefile b/libs/langchain_v1/Makefile
index 7df0cec386f29..032ffeeee64b2 100644
--- a/libs/langchain_v1/Makefile
+++ b/libs/langchain_v1/Makefile
@@ -28,7 +28,7 @@ coverage: $(TEST_FILE)
 
 test:
-	make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
+	make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
 	EXIT_CODE=$$?; \
 	make stop_services; \
 	exit $$EXIT_CODE
 
diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py
index 2f9962759fc7a..e0d81cd294d4a 100644
--- a/libs/langchain_v1/langchain/agents/factory.py
+++ b/libs/langchain_v1/langchain/agents/factory.py
@@ -11,6 +11,7 @@
 from langgraph._internal._runnable import RunnableCallable
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
+from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
 from langgraph.runtime import Runtime  # noqa: TC002
 from langgraph.types import Command, Send
 from langgraph.typing import ContextT  # noqa: TC002
@@ -37,7 +38,6 @@
     ToolStrategy,
 )
 from langchain.chat_models import init_chat_model
-from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Sequence
@@ -48,7 +48,7 @@
     from langgraph.store.base import BaseStore
     from langgraph.types import Checkpointer
 
-    from langchain.tools.tool_node import ToolCallRequest, ToolCallWrapper
+    from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
 
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
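The changes in this file, and in the middleware modules further down, are import relocations: the langchain-local `_ToolNode` is replaced by the public `ToolNode` now imported from `langgraph.prebuilt.tool_node`, and the interceptor types are re-exported for middleware authors from `langchain.agents.middleware.types`. A minimal sketch of the post-PR import surface (illustrative only, not part of the patch; every path below appears in the hunks of this diff):

    # Runtime imports: the tool-execution node and its Send-API payload type
    from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode

    # Typing imports for middleware code (re-exported by langchain.agents.middleware.types)
    from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper

    # The legacy path keeps resolving via the backwards-compat shim added below
    from langchain.tools.tool_node import InjectedState, ToolRuntime
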
@@ -675,7 +675,7 @@ def check_weather(location: str) -> str:
     awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
 
     # Setup tools
-    tool_node: _ToolNode | None = None
+    tool_node: ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
     built_in_tools = [t for t in tools if isinstance(t, dict)]
     regular_tools = [t for t in tools if not isinstance(t, dict)]
@@ -685,7 +685,7 @@ def check_weather(location: str) -> str:
 
     # Only create ToolNode if we have client-side tools
     tool_node = (
-        _ToolNode(
+        ToolNode(
             tools=available_tools,
             wrap_tool_call=wrap_tool_call_wrapper,
             awrap_tool_call=awrap_tool_call_wrapper,
@@ -1491,7 +1491,7 @@ def model_to_model(
 
 def _make_tools_to_model_edge(
     *,
-    tool_node: _ToolNode,
+    tool_node: ToolNode,
     model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
     end_destination: str,
diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py
index f91f760f2ef9f..563ef2a2c39f7 100644
--- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py
+++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py
@@ -45,7 +45,7 @@
     from langgraph.runtime import Runtime
     from langgraph.types import Command
 
-    from langchain.tools.tool_node import ToolCallRequest
+    from langchain.agents.middleware.types import ToolCallRequest
 
 LOGGER = logging.getLogger(__name__)
 _DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py
index 1b6ee202167e6..019045c577ac3 100644
--- a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py
+++ b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py
@@ -15,8 +15,8 @@
     from langgraph.types import Command
 
+    from langchain.agents.middleware.types import ToolCallRequest
     from langchain.tools import BaseTool
-    from langchain.tools.tool_node import ToolCallRequest
 
 
 class LLMToolEmulator(AgentMiddleware):
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py
index 7fb319a46d0f6..361158b0c6187 100644
--- a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py
+++ b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py
@@ -16,8 +16,8 @@
     from langgraph.types import Command
 
+    from langchain.agents.middleware.types import ToolCallRequest
     from langchain.tools import BaseTool
-    from langchain.tools.tool_node import ToolCallRequest
 
 
 class ToolRetryMiddleware(AgentMiddleware):
diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py
index a100c65ce179f..c9bfa2ee74fc0 100644
--- a/libs/langchain_v1/langchain/agents/middleware/types.py
+++ b/libs/langchain_v1/langchain/agents/middleware/types.py
@@ -19,14 +19,13 @@
 if TYPE_CHECKING:
     from collections.abc import Awaitable
 
-    from langchain.tools.tool_node import ToolCallRequest
-
 # Needed as top level import for Pydantic schema generation on AgentState
 from typing import TypeAlias
 
 from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
 from langgraph.graph.message import add_messages
+from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
 from langgraph.types import Command  # noqa: TC002
 from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
@@ -45,6 +44,8 @@
     "ModelRequest",
     "ModelResponse",
     "OmitFromSchema",
+    "ToolCallRequest",
+    "ToolCallWrapper",
     "after_agent",
     "after_model",
     "before_agent",
diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py
index 5a685f6fbb9cb..4474c8ba935fb 100644
--- a/libs/langchain_v1/langchain/tools/tool_node.py
+++ b/libs/langchain_v1/langchain/tools/tool_node.py
@@ -1,1786 +1,20 @@
-"""Tool execution node for LangGraph workflows.
+"""Utils file included for backwards compat imports."""
 
-This module provides prebuilt functionality for executing tools in LangGraph.
-
-Tools are functions that models can call to interact with external systems,
-APIs, databases, or perform computations.
-
-The module implements design patterns for:
-- Parallel execution of multiple tool calls for efficiency
-- Robust error handling with customizable error messages
-- State injection for tools that need access to graph state
-- Store injection for tools that need persistent storage
-- Command-based state updates for advanced control flow
-
-Key Components:
-    `ToolNode`: Main class for executing tools in LangGraph workflows
-    `InjectedState`: Annotation for injecting graph state into tools
-    `InjectedStore`: Annotation for injecting persistent store into tools
-    `tools_condition`: Utility function for conditional routing based on tool calls
-
-Typical Usage:
-    ```python
-    from langchain_core.tools import tool
-    from langchain.tools import ToolNode
-
-
-    @tool
-    def my_tool(x: int) -> str:
-        return f"Result: {x}"
-
-
-    tool_node = ToolNode([my_tool])
-    ```
-"""
-
-from __future__ import annotations
-
-import asyncio
-import inspect
-import json
-from collections.abc import Awaitable, Callable
-from copy import copy, deepcopy
-from dataclasses import dataclass, replace
-from types import UnionType
-from typing import (
-    TYPE_CHECKING,
-    Annotated,
-    Any,
-    Generic,
-    Literal,
-    TypedDict,
-    Union,
-    cast,
-    get_args,
-    get_origin,
-    get_type_hints,
-)
-
-from langchain_core.messages import (
-    AIMessage,
-    AnyMessage,
-    RemoveMessage,
-    ToolCall,
-    ToolMessage,
-    convert_to_messages,
-)
-from langchain_core.runnables.config import (
-    RunnableConfig,
-    get_config_list,
-    get_executor_for_config,
-)
-from langchain_core.tools import BaseTool, InjectedToolArg
-from langchain_core.tools import tool as create_tool
-from langchain_core.tools.base import (
-    TOOL_MESSAGE_BLOCK_TYPES,
-    ToolException,
-    _DirectlyInjectedToolArg,
-    get_all_basemodel_annotations,
-)
-from langgraph._internal._runnable import RunnableCallable
-from langgraph.errors import GraphBubbleUp
-from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from langgraph.store.base import BaseStore  # noqa: TC002
-from langgraph.types import Command, Send, StreamWriter
-from pydantic import BaseModel, ValidationError
-from typing_extensions import TypeVar, Unpack
-
-if TYPE_CHECKING:
-    from collections.abc import Sequence
-
-    from langgraph.runtime import Runtime
-    from pydantic_core import ErrorDetails
-
-# right now we use a dict as the default, can change this to AgentState, but depends
-# on if this lives in LangChain or LangGraph... ideally would have some typed
-# messages key
-StateT = TypeVar("StateT", default=dict)
-ContextT = TypeVar("ContextT", default=None)
-
-INVALID_TOOL_NAME_ERROR_TEMPLATE = (
-    "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
+from langgraph.prebuilt import InjectedState, InjectedStore, ToolRuntime +from langgraph.prebuilt.tool_node import ( + ToolCallRequest, + ToolCallWithContext, + ToolCallWrapper, ) -TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." -TOOL_EXECUTION_ERROR_TEMPLATE = ( - "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n" - " {error}\n" - " Please fix the error and try again." +from langgraph.prebuilt.tool_node import ( + ToolNode as _ToolNode, # noqa: F401 ) -TOOL_INVOCATION_ERROR_TEMPLATE = ( - "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n" - " {error}\n" - " Please fix the error and try again." -) - - -class _ToolCallRequestOverrides(TypedDict, total=False): - """Possible overrides for ToolCallRequest.override() method.""" - - tool_call: ToolCall - - -@dataclass() -class ToolCallRequest: - """Tool execution request passed to tool call interceptors. - - Attributes: - tool_call: Tool call dict with name, args, and id from model output. - tool: BaseTool instance to be invoked, or None if tool is not - registered with the `ToolNode`. When tool is `None`, interceptors can - handle the request without validation. If the interceptor calls `execute()`, - validation will occur and raise an error for unregistered tools. - state: Agent state (`dict`, `list`, or `BaseModel`). - runtime: LangGraph runtime context (optional, `None` if outside graph). - """ - - tool_call: ToolCall - tool: BaseTool | None - state: Any - runtime: ToolRuntime - - def override(self, **overrides: Unpack[_ToolCallRequestOverrides]) -> ToolCallRequest: - """Replace the request with a new request with the given overrides. - - Returns a new `ToolCallRequest` instance with the specified attributes replaced. - This follows an immutable pattern, leaving the original request unchanged. - - Args: - **overrides: Keyword arguments for attributes to override. Supported keys: - - tool_call: Tool call dict with name, args, and id - - Returns: - New ToolCallRequest instance with specified overrides applied. - - Examples: - ```python - # Modify tool call arguments without mutating original - modified_call = {**request.tool_call, "args": {"value": 10}} - new_request = request.override(tool_call=modified_call) - - # Override multiple attributes - new_request = request.override(tool_call=modified_call, state=new_state) - ``` - """ - return replace(self, **overrides) - - -ToolCallWrapper = Callable[ - [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]], - ToolMessage | Command, -] -"""Wrapper for tool call execution with multi-call support. - -Wrapper receives: - request: ToolCallRequest with tool_call, tool, state, and runtime. - execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES). - -Returns: - ToolMessage or Command (the final result). -The execute callable can be invoked multiple times for retry logic, -with potentially modified requests each time. Each call to execute -is independent and stateless. - -!!! note - When implementing middleware for `create_agent`, use - `AgentMiddleware.wrap_tool_call` which provides properly typed - state parameter for better type safety. 
- -Examples: - Passthrough (execute once): - - def handler(request, execute): - return execute(request) - - Modify request before execution: - - ```python - def handler(request, execute): - request.tool_call["args"]["value"] *= 2 - return execute(request) - ``` - - Retry on error (execute multiple times): - - ```python - def handler(request, execute): - for attempt in range(3): - try: - result = execute(request) - if is_valid(result): - return result - except Exception: - if attempt == 2: - raise - return result - ``` - - Conditional retry based on response: - - ```python - def handler(request, execute): - for attempt in range(3): - result = execute(request) - if isinstance(result, ToolMessage) and result.status != "error": - return result - if attempt < 2: - continue - return result - ``` - - Cache/short-circuit without calling execute: - - ```python - def handler(request, execute): - if cached := get_cache(request): - return ToolMessage(content=cached, tool_call_id=request.tool_call["id"]) - result = execute(request) - save_cache(request, result) - return result - ``` -""" - -AsyncToolCallWrapper = Callable[ - [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]], - Awaitable[ToolMessage | Command], +__all__ = [ + "InjectedState", + "InjectedStore", + "ToolCallRequest", + "ToolCallWithContext", + "ToolCallWrapper", + "ToolRuntime", ] -"""Async wrapper for tool call execution with multi-call support.""" - - -class ToolCallWithContext(TypedDict): - """ToolCall with additional context for graph state. - - This is an internal data structure meant to help the `ToolNode` accept - tool calls with additional context (e.g. state) when dispatched using the - Send API. - - The Send API is used in create_agent to distribute tool calls in parallel - and support human-in-the-loop workflows where graph execution may be paused - for an indefinite time. - """ - - tool_call: ToolCall - __type: Literal["tool_call_with_context"] - """Type to parameterize the payload. - - Using "__" as a prefix to be defensive against potential name collisions with - regular user state. - """ - state: Any - """The state is provided as additional context.""" - - -def msg_content_output(output: Any) -> str | list[dict]: - """Convert tool output to `ToolMessage` content format. - - Handles `str`, `list[dict]` (content blocks), and arbitrary objects by attempting - JSON serialization with fallback to str(). - - Args: - output: Tool execution output of any type. - - Returns: - String or list of content blocks suitable for `ToolMessage.content`. - """ - if isinstance(output, str) or ( - isinstance(output, list) - and all(isinstance(x, dict) and x.get("type") in TOOL_MESSAGE_BLOCK_TYPES for x in output) - ): - return output - # Technically a list of strings is also valid message content, but it's - # not currently well tested that all chat models support this. - # And for backwards compatibility we want to make sure we don't break - # any existing ToolNode usage. - try: - return json.dumps(output, ensure_ascii=False) - except Exception: # noqa: BLE001 - return str(output) - - -class ToolInvocationError(ToolException): - """An error occurred while invoking a tool due to invalid arguments. - - This exception is only raised when invoking a tool using the `ToolNode`! - """ - - def __init__( - self, - tool_name: str, - source: ValidationError, - tool_kwargs: dict[str, Any], - filtered_errors: list[ErrorDetails] | None = None, - ) -> None: - """Initialize the ToolInvocationError. 
- - Args: - tool_name: The name of the tool that failed. - source: The exception that occurred. - tool_kwargs: The keyword arguments that were passed to the tool. - filtered_errors: Optional list of filtered validation errors excluding - injected arguments. - """ - # Format error display based on filtered errors if provided - if filtered_errors is not None: - # Manually format the filtered errors without URLs or fancy formatting - error_str_parts = [] - for error in filtered_errors: - loc_str = ".".join(str(loc) for loc in error.get("loc", ())) - msg = error.get("msg", "Unknown error") - error_str_parts.append(f"{loc_str}: {msg}") - error_display_str = "\n".join(error_str_parts) - else: - error_display_str = str(source) - - self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format( - tool_name=tool_name, tool_kwargs=tool_kwargs, error=error_display_str - ) - self.tool_name = tool_name - self.tool_kwargs = tool_kwargs - self.source = source - self.filtered_errors = filtered_errors - super().__init__(self.message) - - -def _default_handle_tool_errors(e: Exception) -> str: - """Default error handler for tool errors. - - If the tool is a tool invocation error, return its message. - Otherwise, raise the error. - """ - if isinstance(e, ToolInvocationError): - return e.message - raise e - - -def _handle_tool_error( - e: Exception, - *, - flag: bool | str | Callable[..., str] | type[Exception] | tuple[type[Exception], ...], -) -> str: - """Generate error message content based on exception handling configuration. - - This function centralizes error message generation logic, supporting different - error handling strategies configured via the `ToolNode`'s `handle_tool_errors` - parameter. - - Args: - e: The exception that occurred during tool execution. - flag: Configuration for how to handle the error. Can be: - - bool: If `True`, use default error template - - str: Use this string as the error message - - Callable: Call this function with the exception to get error message - - tuple: Not used in this context (handled by caller) - - Returns: - A string containing the error message to include in the `ToolMessage`. - - Raises: - ValueError: If flag is not one of the supported types. - - !!! note - The tuple case is handled by the caller through exception type checking, - not by this function directly. - """ - if isinstance(flag, (bool, tuple)) or (isinstance(flag, type) and issubclass(flag, Exception)): - content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) - elif isinstance(flag, str): - content = flag - elif callable(flag): - content = flag(e) # type: ignore [assignment, call-arg] - else: - msg = ( - f"Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {flag}" - ) - raise ValueError(msg) - return content - - -def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]: - """Infer exception types handled by a custom error handler function. - - This function analyzes the type annotations of a custom error handler to determine - which exception types it's designed to handle. This enables type-safe error handling - where only specific exceptions are caught and processed by the handler. - - Args: - handler: A callable that takes an exception and returns an error message string. - The first parameter (after self/cls if present) should be type-annotated - with the exception type(s) to handle. - - Returns: - A tuple of exception types that the handler can process. 
Returns (Exception,) - if no specific type information is available for backward compatibility. - - Raises: - ValueError: If the handler's annotation contains non-Exception types or - if Union types contain non-Exception types. - - !!! note - This function supports both single exception types and Union types for - handlers that need to handle multiple exception types differently. - """ - sig = inspect.signature(handler) - params = list(sig.parameters.values()) - if params: - # If it's a method, the first argument is typically 'self' or 'cls' - if params[0].name in ["self", "cls"] and len(params) == 2: - first_param = params[1] - else: - first_param = params[0] - - type_hints = get_type_hints(handler) - if first_param.name in type_hints: - origin = get_origin(first_param.annotation) - if origin in [Union, UnionType]: - args = get_args(first_param.annotation) - if all(issubclass(arg, Exception) for arg in args): - return tuple(args) - msg = ( - "All types in the error handler error annotation must be " - "Exception types. For example, " - "`def custom_handler(e: Union[ValueError, TypeError])`. " - f"Got '{first_param.annotation}' instead." - ) - raise ValueError(msg) - - exception_type = type_hints[first_param.name] - if Exception in exception_type.__mro__: - return (exception_type,) - msg = ( - f"Arbitrary types are not supported in the error handler " - f"signature. Please annotate the error with either a " - f"specific Exception type or a union of Exception types. " - "For example, `def custom_handler(e: ValueError)` or " - "`def custom_handler(e: Union[ValueError, TypeError])`. " - f"Got '{exception_type}' instead." - ) - raise ValueError(msg) - - # If no type information is available, return (Exception,) - # for backwards compatibility. - return (Exception,) - - -def _filter_validation_errors( - validation_error: ValidationError, - tool_to_state_args: dict[str, str | None], - tool_to_store_arg: str | None, - tool_to_runtime_arg: str | None, -) -> list[ErrorDetails]: - """Filter validation errors to only include LLM-controlled arguments. - - When a tool invocation fails validation, only errors for arguments that the LLM - controls should be included in error messages. This ensures the LLM receives - focused, actionable feedback about parameters it can actually fix. System-injected - arguments (state, store, runtime) are filtered out since the LLM has no control - over them. - - This function also removes injected argument values from the `input` field in error - details, ensuring that only LLM-provided arguments appear in error messages. - - Args: - validation_error: The Pydantic ValidationError raised during tool invocation. - tool_to_state_args: Mapping of state argument names to state field names. - tool_to_store_arg: Name of the store argument, if any. - tool_to_runtime_arg: Name of the runtime argument, if any. - - Returns: - List of ErrorDetails containing only errors for LLM-controlled arguments, - with system-injected argument values removed from the input field. 
- """ - injected_args = set(tool_to_state_args.keys()) - if tool_to_store_arg: - injected_args.add(tool_to_store_arg) - if tool_to_runtime_arg: - injected_args.add(tool_to_runtime_arg) - - filtered_errors: list[ErrorDetails] = [] - for error in validation_error.errors(): - # Check if error location contains any injected argument - # error['loc'] is a tuple like ('field_name',) or ('field_name', 'nested_field') - if error["loc"] and error["loc"][0] not in injected_args: - # Create a copy of the error dict to avoid mutating the original - error_copy: dict[str, Any] = {**error} - - # Remove injected arguments from input_value if it's a dict - if isinstance(error_copy.get("input"), dict): - input_dict = error_copy["input"] - input_copy = {k: v for k, v in input_dict.items() if k not in injected_args} - error_copy["input"] = input_copy - - # Cast is safe because ErrorDetails is a TypedDict compatible with this structure - filtered_errors.append(error_copy) # type: ignore[arg-type] - - return filtered_errors - - -class _ToolNode(RunnableCallable): - """A node for executing tools in LangGraph workflows. - - Handles tool execution patterns including function calls, state injection, - persistent storage, and control flow. Manages parallel execution, - error handling. - - Input Formats: - 1. Graph state with `messages` key that has a list of messages: - - Common representation for agentic workflows - - Supports custom messages key via `messages_key` parameter - - 2. **Message List**: `[AIMessage(..., tool_calls=[...])]` - - List of messages with tool calls in the last AIMessage - - 3. **Direct Tool Calls**: `[{"name": "tool", "args": {...}, "id": "1", "type": "tool_call"}]` - - Bypasses message parsing for direct tool execution - - For programmatic tool invocation and testing - - Output Formats: - Output format depends on input type and tool behavior: - - **For Regular tools**: - - Dict input → `{"messages": [ToolMessage(...)]}` - - List input → `[ToolMessage(...)]` - - **For Command tools**: - - Returns `[Command(...)]` or mixed list with regular tool outputs - - Commands can update state, trigger navigation, or send messages - - Args: - tools: A sequence of tools that can be invoked by this node. Supports: - - **BaseTool instances**: Tools with schemas and metadata - - **Plain functions**: Automatically converted to tools with inferred schemas - name: The name identifier for this node in the graph. Used for debugging - and visualization. Defaults to "tools". - tags: Optional metadata tags to associate with the node for filtering - and organization. Defaults to `None`. - handle_tool_errors: Configuration for error handling during tool execution. - Supports multiple strategies: - - - **True**: Catch all errors and return a ToolMessage with the default - error template containing the exception details. - - **str**: Catch all errors and return a ToolMessage with this custom - error message string. - - **type[Exception]**: Only catch exceptions with the specified type and - return the default error message for it. - - **tuple[type[Exception], ...]**: Only catch exceptions with the specified - types and return default error messages for them. - - **Callable[..., str]**: Catch exceptions matching the callable's signature - and return the string result of calling it with the exception. - - **False**: Disable error handling entirely, allowing exceptions to - propagate. 
- - Defaults to a callable that: - - catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message - - ignores tool execution errors (they will be re-raised) - - messages_key: The key in the state dictionary that contains the message list. - This same key will be used for the output `ToolMessage` objects. - Defaults to "messages". - Allows custom state schemas with different message field names. - - Examples: - Basic usage: - - ```python - from langchain.tools import ToolNode - from langchain_core.tools import tool - - @tool - def calculator(a: int, b: int) -> int: - \"\"\"Add two numbers.\"\"\" - return a + b - - tool_node = ToolNode([calculator]) - ``` - - State injection: - - ```python - from typing_extensions import Annotated - from langchain.tools import InjectedState - - @tool - def context_tool(query: str, state: Annotated[dict, InjectedState]) -> str: - \"\"\"Some tool that uses state.\"\"\" - return f"Query: {query}, Messages: {len(state['messages'])}" - - tool_node = ToolNode([context_tool]) - ``` - - Error handling: - - ```python - def handle_errors(e: ValueError) -> str: - return "Invalid input provided" - - - tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors) - ``` - """ # noqa: E501 - - name: str = "tools" - - def __init__( - self, - tools: Sequence[BaseTool | Callable], - *, - name: str = "tools", - tags: list[str] | None = None, - handle_tool_errors: bool - | str - | Callable[..., str] - | type[Exception] - | tuple[type[Exception], ...] = _default_handle_tool_errors, - messages_key: str = "messages", - wrap_tool_call: ToolCallWrapper | None = None, - awrap_tool_call: AsyncToolCallWrapper | None = None, - ) -> None: - """Initialize `ToolNode` with tools and configuration. - - Args: - tools: Sequence of tools to make available for execution. - name: Node name for graph identification. - tags: Optional metadata tags. - handle_tool_errors: Error handling configuration. - messages_key: State key containing messages. - wrap_tool_call: Sync wrapper function to intercept tool execution. Receives - ToolCallRequest and execute callable, returns ToolMessage or Command. - Enables retries, caching, request modification, and control flow. - awrap_tool_call: Async wrapper function to intercept tool execution. - If not provided, falls back to wrap_tool_call for async execution. 
- """ - super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) - self._tools_by_name: dict[str, BaseTool] = {} - self._tool_to_state_args: dict[str, dict[str, str | None]] = {} - self._tool_to_store_arg: dict[str, str | None] = {} - self._tool_to_runtime_arg: dict[str, str | None] = {} - self._handle_tool_errors = handle_tool_errors - self._messages_key = messages_key - self._wrap_tool_call = wrap_tool_call - self._awrap_tool_call = awrap_tool_call - for tool in tools: - if not isinstance(tool, BaseTool): - tool_ = create_tool(cast("type[BaseTool]", tool)) - else: - tool_ = tool - self._tools_by_name[tool_.name] = tool_ - self._tool_to_state_args[tool_.name] = _get_state_args(tool_) - self._tool_to_store_arg[tool_.name] = _get_store_arg(tool_) - self._tool_to_runtime_arg[tool_.name] = _get_runtime_arg(tool_) - - @property - def tools_by_name(self) -> dict[str, BaseTool]: - """Mapping from tool name to BaseTool instance.""" - return self._tools_by_name - - def _func( - self, - input: list[AnyMessage] | dict[str, Any] | BaseModel, - config: RunnableConfig, - runtime: Runtime, - ) -> Any: - tool_calls, input_type = self._parse_input(input) - config_list = get_config_list(config, len(tool_calls)) - - # Construct ToolRuntime instances at the top level for each tool call - tool_runtimes = [] - for call, cfg in zip(tool_calls, config_list, strict=False): - state = self._extract_state(input) - tool_runtime = ToolRuntime( - state=state, - tool_call_id=call["id"], - config=cfg, - context=runtime.context, - store=runtime.store, - stream_writer=runtime.stream_writer, - ) - tool_runtimes.append(tool_runtime) - - # Pass original tool calls without injection - input_types = [input_type] * len(tool_calls) - with get_executor_for_config(config) as executor: - outputs = list(executor.map(self._run_one, tool_calls, input_types, tool_runtimes)) - - return self._combine_tool_outputs(outputs, input_type) - - async def _afunc( - self, - input: list[AnyMessage] | dict[str, Any] | BaseModel, - config: RunnableConfig, - runtime: Runtime, - ) -> Any: - tool_calls, input_type = self._parse_input(input) - config_list = get_config_list(config, len(tool_calls)) - - # Construct ToolRuntime instances at the top level for each tool call - tool_runtimes = [] - for call, cfg in zip(tool_calls, config_list, strict=False): - state = self._extract_state(input) - tool_runtime = ToolRuntime( - state=state, - tool_call_id=call["id"], - config=cfg, - context=runtime.context, - store=runtime.store, - stream_writer=runtime.stream_writer, - ) - tool_runtimes.append(tool_runtime) - - # Pass original tool calls without injection - coros = [] - for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False): - coros.append(self._arun_one(call, input_type, tool_runtime)) # type: ignore[arg-type] - outputs = await asyncio.gather(*coros) - - return self._combine_tool_outputs(outputs, input_type) - - def _combine_tool_outputs( - self, - outputs: list[ToolMessage | Command], - input_type: Literal["list", "dict", "tool_calls"], - ) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]: - # preserve existing behavior for non-command tool outputs for backwards - # compatibility - if not any(isinstance(output, Command) for output in outputs): - # TypedDict, pydantic, dataclass, etc. 
should all be able to load from dict - return outputs if input_type == "list" else {self._messages_key: outputs} # type: ignore[return-value, return-value] - - # LangGraph will automatically handle list of Command and non-command node - # updates - combined_outputs: list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]] = [] - - # combine all parent commands with goto into a single parent command - parent_command: Command | None = None - for output in outputs: - if isinstance(output, Command): - if ( - output.graph is Command.PARENT - and isinstance(output.goto, list) - and all(isinstance(send, Send) for send in output.goto) - ): - if parent_command: - parent_command = replace( - parent_command, - goto=cast("list[Send]", parent_command.goto) + output.goto, - ) - else: - parent_command = Command(graph=Command.PARENT, goto=output.goto) - else: - combined_outputs.append(output) - else: - combined_outputs.append( - [output] if input_type == "list" else {self._messages_key: [output]} - ) - - if parent_command: - combined_outputs.append(parent_command) - return combined_outputs - - def _execute_tool_sync( - self, - request: ToolCallRequest, - input_type: Literal["list", "dict", "tool_calls"], - config: RunnableConfig, - ) -> ToolMessage | Command: - """Execute tool call with configured error handling. - - Args: - request: Tool execution request. - input_type: Input format. - config: Runnable configuration. - - Returns: - ToolMessage or Command. - - Raises: - Exception: If tool fails and handle_tool_errors is False. - """ - call = request.tool_call - tool = request.tool - - # Validate tool exists when we actually need to execute it - if tool is None: - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message - # This should never happen if validation works correctly - msg = f"Tool {call['name']} is not registered with ToolNode" - raise TypeError(msg) - - # Inject state, store, and runtime right before invocation - injected_call = self._inject_tool_args(call, request.runtime) - call_args = {**injected_call, "type": "tool_call"} - - try: - try: - response = tool.invoke(call_args, config) - except ValidationError as exc: - # Filter out errors for injected arguments - filtered_errors = _filter_validation_errors( - exc, - self._tool_to_state_args.get(call["name"], {}), - self._tool_to_store_arg.get(call["name"]), - self._tool_to_runtime_arg.get(call["name"]), - ) - # Use original call["args"] without injected values for error reporting - raise ToolInvocationError(call["name"], exc, call["args"], filtered_errors) from exc - - # GraphInterrupt is a special exception that will always be raised. - # It can be triggered in the following scenarios, - # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation - # most commonly: - # (1) a GraphInterrupt is raised inside a tool - # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool - # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph - # called as a tool - # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) - except GraphBubbleUp: - raise - except Exception as e: - # Determine which exception types are handled - handled_types: tuple[type[Exception], ...] 
- if isinstance(self._handle_tool_errors, type) and issubclass( - self._handle_tool_errors, Exception - ): - handled_types = (self._handle_tool_errors,) - elif isinstance(self._handle_tool_errors, tuple): - handled_types = self._handle_tool_errors - elif callable(self._handle_tool_errors) and not isinstance( - self._handle_tool_errors, type - ): - handled_types = _infer_handled_types(self._handle_tool_errors) - else: - # default behavior is catching all exceptions - handled_types = (Exception,) - - # Check if this error should be handled - if not self._handle_tool_errors or not isinstance(e, handled_types): - raise - - # Error is handled - create error ToolMessage - content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( - content=content, - name=call["name"], - tool_call_id=call["id"], - status="error", - ) - - # Process successful response - if isinstance(response, Command): - # Validate Command before returning to handler - return self._validate_tool_command(response, request.tool_call, input_type) - if isinstance(response, ToolMessage): - response.content = cast("str | list", msg_content_output(response.content)) - return response - - msg = f"Tool {call['name']} returned unexpected type: {type(response)}" - raise TypeError(msg) - - def _run_one( - self, - call: ToolCall, - input_type: Literal["list", "dict", "tool_calls"], - tool_runtime: ToolRuntime, - ) -> ToolMessage | Command: - """Execute single tool call with wrap_tool_call wrapper if configured. - - Args: - call: Tool call dict. - input_type: Input format. - tool_runtime: Tool runtime. - - Returns: - ToolMessage or Command. - """ - # Validation is deferred to _execute_tool_sync to allow interceptors - # to short-circuit requests for unregistered tools - tool = self.tools_by_name.get(call["name"]) - - # Create the tool request with state and runtime - tool_request = ToolCallRequest( - tool_call=call, - tool=tool, - state=tool_runtime.state, - runtime=tool_runtime, - ) - - config = tool_runtime.config - - if self._wrap_tool_call is None: - # No wrapper - execute directly - return self._execute_tool_sync(tool_request, input_type, config) - - # Define execute callable that can be called multiple times - def execute(req: ToolCallRequest) -> ToolMessage | Command: - """Execute tool with given request. Can be called multiple times.""" - return self._execute_tool_sync(req, input_type, config) - - # Call wrapper with request and execute callable - try: - return self._wrap_tool_call(tool_request, execute) - except Exception as e: - # Wrapper threw an exception - if not self._handle_tool_errors: - raise - # Convert to error message - content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( - content=content, - name=tool_request.tool_call["name"], - tool_call_id=tool_request.tool_call["id"], - status="error", - ) - - async def _execute_tool_async( - self, - request: ToolCallRequest, - input_type: Literal["list", "dict", "tool_calls"], - config: RunnableConfig, - ) -> ToolMessage | Command: - """Execute tool call asynchronously with configured error handling. - - Args: - request: Tool execution request. - input_type: Input format. - config: Runnable configuration. - - Returns: - ToolMessage or Command. - - Raises: - Exception: If tool fails and handle_tool_errors is False. 
- """ - call = request.tool_call - tool = request.tool - - # Validate tool exists when we actually need to execute it - if tool is None: - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message - # This should never happen if validation works correctly - msg = f"Tool {call['name']} is not registered with ToolNode" - raise TypeError(msg) - - # Inject state, store, and runtime right before invocation - injected_call = self._inject_tool_args(call, request.runtime) - call_args = {**injected_call, "type": "tool_call"} - - try: - try: - response = await tool.ainvoke(call_args, config) - except ValidationError as exc: - # Filter out errors for injected arguments - filtered_errors = _filter_validation_errors( - exc, - self._tool_to_state_args.get(call["name"], {}), - self._tool_to_store_arg.get(call["name"]), - self._tool_to_runtime_arg.get(call["name"]), - ) - # Use original call["args"] without injected values for error reporting - raise ToolInvocationError(call["name"], exc, call["args"], filtered_errors) from exc - - # GraphInterrupt is a special exception that will always be raised. - # It can be triggered in the following scenarios, - # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation - # most commonly: - # (1) a GraphInterrupt is raised inside a tool - # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool - # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph - # called as a tool - # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) - except GraphBubbleUp: - raise - except Exception as e: - # Determine which exception types are handled - handled_types: tuple[type[Exception], ...] - if isinstance(self._handle_tool_errors, type) and issubclass( - self._handle_tool_errors, Exception - ): - handled_types = (self._handle_tool_errors,) - elif isinstance(self._handle_tool_errors, tuple): - handled_types = self._handle_tool_errors - elif callable(self._handle_tool_errors) and not isinstance( - self._handle_tool_errors, type - ): - handled_types = _infer_handled_types(self._handle_tool_errors) - else: - # default behavior is catching all exceptions - handled_types = (Exception,) - - # Check if this error should be handled - if not self._handle_tool_errors or not isinstance(e, handled_types): - raise - - # Error is handled - create error ToolMessage - content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( - content=content, - name=call["name"], - tool_call_id=call["id"], - status="error", - ) - - # Process successful response - if isinstance(response, Command): - # Validate Command before returning to handler - return self._validate_tool_command(response, request.tool_call, input_type) - if isinstance(response, ToolMessage): - response.content = cast("str | list", msg_content_output(response.content)) - return response - - msg = f"Tool {call['name']} returned unexpected type: {type(response)}" - raise TypeError(msg) - - async def _arun_one( - self, - call: ToolCall, - input_type: Literal["list", "dict", "tool_calls"], - tool_runtime: ToolRuntime, - ) -> ToolMessage | Command: - """Execute single tool call asynchronously with awrap_tool_call wrapper if configured. - - Args: - call: Tool call dict. - input_type: Input format. - tool_runtime: Tool runtime. - - Returns: - ToolMessage or Command. 
- """ - # Validation is deferred to _execute_tool_async to allow interceptors - # to short-circuit requests for unregistered tools - tool = self.tools_by_name.get(call["name"]) - - # Create the tool request with state and runtime - tool_request = ToolCallRequest( - tool_call=call, - tool=tool, - state=tool_runtime.state, - runtime=tool_runtime, - ) - - config = tool_runtime.config - - if self._awrap_tool_call is None and self._wrap_tool_call is None: - # No wrapper - execute directly - return await self._execute_tool_async(tool_request, input_type, config) - - # Define async execute callable that can be called multiple times - async def execute(req: ToolCallRequest) -> ToolMessage | Command: - """Execute tool with given request. Can be called multiple times.""" - return await self._execute_tool_async(req, input_type, config) - - def _sync_execute(req: ToolCallRequest) -> ToolMessage | Command: - """Sync execute fallback for sync wrapper.""" - return self._execute_tool_sync(req, input_type, config) - - # Call wrapper with request and execute callable - try: - if self._awrap_tool_call is not None: - return await self._awrap_tool_call(tool_request, execute) - # None check was performed above already - self._wrap_tool_call = cast("ToolCallWrapper", self._wrap_tool_call) - return self._wrap_tool_call(tool_request, _sync_execute) - except Exception as e: - # Wrapper threw an exception - if not self._handle_tool_errors: - raise - # Convert to error message - content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( - content=content, - name=tool_request.tool_call["name"], - tool_call_id=tool_request.tool_call["id"], - status="error", - ) - - def _parse_input( - self, - input: list[AnyMessage] | dict[str, Any] | BaseModel, - ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]: - input_type: Literal["list", "dict", "tool_calls"] - if isinstance(input, list): - if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call": - input_type = "tool_calls" - tool_calls = cast("list[ToolCall]", input) - return tool_calls, input_type - input_type = "list" - messages = input - elif isinstance(input, dict) and input.get("__type") == "tool_call_with_context": - # Handle ToolCallWithContext from Send API - # mypy will not be able to type narrow correctly since the signature - # for input contains dict[str, Any]. We'd need to narrow dict[str, Any] - # before we can apply correct typing. 
- input_with_ctx = cast("ToolCallWithContext", input) - input_type = "tool_calls" - return [input_with_ctx["tool_call"]], input_type - elif isinstance(input, dict) and (messages := input.get(self._messages_key, [])): - input_type = "dict" - elif messages := getattr(input, self._messages_key, []): - # Assume dataclass-like state that can coerce from dict - input_type = "dict" - else: - msg = "No message found in input" - raise ValueError(msg) - - try: - latest_ai_message = next(m for m in reversed(messages) if isinstance(m, AIMessage)) - except StopIteration: - msg = "No AIMessage found in input" - raise ValueError(msg) - - tool_calls = list(latest_ai_message.tool_calls) - return tool_calls, input_type - - def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None: - requested_tool = call["name"] - if requested_tool not in self.tools_by_name: - all_tool_names = list(self.tools_by_name.keys()) - content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( - requested_tool=requested_tool, - available_tools=", ".join(all_tool_names), - ) - return ToolMessage( - content, name=requested_tool, tool_call_id=call["id"], status="error" - ) - return None - - def _extract_state( - self, input: list[AnyMessage] | dict[str, Any] | BaseModel - ) -> list[AnyMessage] | dict[str, Any] | BaseModel: - """Extract state from input, handling ToolCallWithContext if present. - - Args: - input: The input which may be raw state or ToolCallWithContext. - - Returns: - The actual state to pass to wrap_tool_call wrappers. - """ - if isinstance(input, dict) and input.get("__type") == "tool_call_with_context": - return input["state"] - return input - - def _inject_state( - self, - tool_call: ToolCall, - state: list[AnyMessage] | dict[str, Any] | BaseModel, - ) -> ToolCall: - state_args = self._tool_to_state_args[tool_call["name"]] - - if state_args and isinstance(state, list): - required_fields = list(state_args.values()) - if ( - len(required_fields) == 1 and required_fields[0] == self._messages_key - ) or required_fields[0] is None: - state = {self._messages_key: state} - else: - err_msg = ( - f"Invalid input to ToolNode. Tool {tool_call['name']} requires " - f"graph state dict as input." - ) - if any(state_field for state_field in state_args.values()): - required_fields_str = ", ".join(f for f in required_fields if f) - err_msg += f" State should contain fields {required_fields_str}." - raise ValueError(err_msg) - - if isinstance(state, dict): - tool_state_args = { - tool_arg: state[state_field] if state_field else state - for tool_arg, state_field in state_args.items() - } - else: - tool_state_args = { - tool_arg: getattr(state, state_field) if state_field else state - for tool_arg, state_field in state_args.items() - } - - tool_call["args"] = { - **tool_call["args"], - **tool_state_args, - } - return tool_call - - def _inject_store(self, tool_call: ToolCall, store: BaseStore | None) -> ToolCall: - store_arg = self._tool_to_store_arg[tool_call["name"]] - if not store_arg: - return tool_call - - if store is None: - msg = ( - "Cannot inject store into tools with InjectedStore annotations - " - "please compile your graph with a store." - ) - raise ValueError(msg) - - tool_call["args"] = { - **tool_call["args"], - store_arg: store, - } - return tool_call - - def _inject_runtime(self, tool_call: ToolCall, tool_runtime: ToolRuntime) -> ToolCall: - """Inject ToolRuntime into tool call arguments. - - Args: - tool_call: The tool call to inject runtime into. - tool_runtime: The ToolRuntime instance to inject. 
- - Returns: - The tool call with runtime injected if needed. - """ - runtime_arg = self._tool_to_runtime_arg.get(tool_call["name"]) - if not runtime_arg: - return tool_call - - tool_call["args"] = { - **tool_call["args"], - runtime_arg: tool_runtime, - } - return tool_call - - def _inject_tool_args( - self, - tool_call: ToolCall, - tool_runtime: ToolRuntime, - ) -> ToolCall: - """Inject graph state, store, and runtime into tool call arguments. - - This is an internal method that enables tools to access graph context that - should not be controlled by the model. Tools can declare dependencies on graph - state, persistent storage, or runtime context using InjectedState, InjectedStore, - and ToolRuntime annotations. This method automatically identifies these - dependencies and injects the appropriate values. - - The injection process preserves the original tool call structure while adding - the necessary context arguments. This allows tools to be both model-callable - and context-aware without exposing internal state management to the model. - - Args: - tool_call: The tool call dictionary to augment with injected arguments. - Must contain 'name', 'args', 'id', and 'type' fields. - tool_runtime: The ToolRuntime instance containing all runtime context - (state, config, store, context, stream_writer) to inject into tools. - - Returns: - A new ToolCall dictionary with the same structure as the input but with - additional arguments injected based on the tool's annotation requirements. - - Raises: - ValueError: If a tool requires store injection but no store is provided, - or if state injection requirements cannot be satisfied. - - !!! note - This method is called automatically during tool execution. It should not - be called from outside the `ToolNode`. - """ - if tool_call["name"] not in self.tools_by_name: - return tool_call - - tool_call_copy: ToolCall = copy(tool_call) - tool_call_with_state = self._inject_state(tool_call_copy, tool_runtime.state) - tool_call_with_store = self._inject_store(tool_call_with_state, tool_runtime.store) - return self._inject_runtime(tool_call_with_store, tool_runtime) - - def _validate_tool_command( - self, - command: Command, - call: ToolCall, - input_type: Literal["list", "dict", "tool_calls"], - ) -> Command: - if isinstance(command.update, dict): - # input type is dict when ToolNode is invoked with a dict input - # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]}) - if input_type not in ("dict", "tool_calls"): - msg = ( - "Tools can provide a dict in Command.update only when using dict " - f"with '{self._messages_key}' key as ToolNode input, " - f"got: {command.update} for tool '{call['name']}'" - ) - raise ValueError(msg) - - updated_command = deepcopy(command) - state_update = cast("dict[str, Any]", updated_command.update) or {} - messages_update = state_update.get(self._messages_key, []) - elif isinstance(command.update, list): - # Input type is list when ToolNode is invoked with a list input - # (e.g. 
[AIMessage(..., tool_calls=[...])]) - if input_type != "list": - msg = ( - "Tools can provide a list of messages in Command.update " - "only when using list of messages as ToolNode input, " - f"got: {command.update} for tool '{call['name']}'" - ) - raise ValueError(msg) - - updated_command = deepcopy(command) - messages_update = updated_command.update - else: - return command - - # convert to message objects if updates are in a dict format - messages_update = convert_to_messages(messages_update) - - # no validation needed if all messages are being removed - if messages_update == [RemoveMessage(id=REMOVE_ALL_MESSAGES)]: - return updated_command - - has_matching_tool_message = False - for message in messages_update: - if not isinstance(message, ToolMessage): - continue - - if message.tool_call_id == call["id"]: - message.name = call["name"] - has_matching_tool_message = True - - # validate that we always have a ToolMessage matching the tool call in - # Command.update if command is sent to the CURRENT graph - if updated_command.graph is None and not has_matching_tool_message: - example_update = ( - '`Command(update={"messages": ' - '[ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`' - if input_type == "dict" - else "`Command(update=" - '[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`' - ) - msg = ( - "Expected to have a matching ToolMessage in Command.update " - f"for tool '{call['name']}', got: {messages_update}. " - "Every tool call (LLM requesting to call a tool) " - "in the message history MUST have a corresponding ToolMessage. " - f"You can fix it by modifying the tool to return {example_update}." - ) - raise ValueError(msg) - return updated_command - - -def tools_condition( - state: list[AnyMessage] | dict[str, Any] | BaseModel, - messages_key: str = "messages", -) -> Literal["tools", "__end__"]: - """Conditional routing function for tool-calling workflows. - - This utility function implements the standard conditional logic for ReAct-style - agents: if the last AI message contains tool calls, route to the tool execution - node; otherwise, end the workflow. This pattern is fundamental to most tool-calling - agent architectures. - - The function handles multiple state formats commonly used in LangGraph applications, - making it flexible for different graph designs while maintaining consistent behavior. - - Args: - state: The current graph state to examine for tool calls. Supported formats: - - Dictionary containing a messages key (for StateGraph) - - BaseModel instance with a messages attribute - messages_key: The key or attribute name containing the message list in the state. - This allows customization for graphs using different state schemas. - Defaults to "messages". - - Returns: - Either "tools" if tool calls are present in the last AI message, or "__end__" - to terminate the workflow. These are the standard routing destinations for - tool-calling conditional edges. - - Raises: - ValueError: If no messages can be found in the provided state format. 
- - Example: - Basic usage in a ReAct agent: - - ```python - from langgraph.graph import StateGraph - from langchain.tools import ToolNode - from langchain.tools.tool_node import tools_condition - from typing_extensions import TypedDict - - - class State(TypedDict): - messages: list - - - graph = StateGraph(State) - graph.add_node("llm", call_model) - graph.add_node("tools", ToolNode([my_tool])) - graph.add_conditional_edges( - "llm", - tools_condition, # Routes to "tools" or "__end__" - {"tools": "tools", "__end__": "__end__"}, - ) - ``` - - Custom messages key: - - ```python - def custom_condition(state): - return tools_condition(state, messages_key="chat_history") - ``` - - !!! note - This function is designed to work seamlessly with `ToolNode` and standard - LangGraph patterns. It expects the last message to be an `AIMessage` when - tool calls are present, which is the standard output format for tool-calling - language models. - """ - if isinstance(state, list): - ai_message = state[-1] - elif (isinstance(state, dict) and (messages := state.get(messages_key, []))) or ( - messages := getattr(state, messages_key, []) - ): - ai_message = messages[-1] - else: - msg = f"No messages found in input state to tool_edge: {state}" - raise ValueError(msg) - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return "__end__" - - -@dataclass -class ToolRuntime(_DirectlyInjectedToolArg, Generic[ContextT, StateT]): - """Runtime context automatically injected into tools. - - When a tool function has a parameter named `tool_runtime` with type hint - `ToolRuntime`, the tool execution system will automatically inject an instance - containing: - - - `state`: The current graph state - - `tool_call_id`: The ID of the current tool call - - `config`: `RunnableConfig` for the current execution - - `context`: Runtime context (from langgraph `Runtime`) - - `store`: `BaseStore` instance for persistent storage (from langgraph `Runtime`) - - `stream_writer`: `StreamWriter` for streaming output (from langgraph `Runtime`) - - No `Annotated` wrapper is needed - just use `runtime: ToolRuntime` - as a parameter. - - Example: - ```python - from langchain_core.tools import tool - from langchain.tools import ToolRuntime - - @tool - def my_tool(x: int, runtime: ToolRuntime) -> str: - \"\"\"Tool that accesses runtime context.\"\"\" - # Access state - messages = tool_runtime.state["messages"] - - # Access tool_call_id - print(f"Tool call ID: {tool_runtime.tool_call_id}") - - # Access config - print(f"Run ID: {tool_runtime.config.get('run_id')}") - - # Access runtime context - user_id = tool_runtime.context.get("user_id") - - # Access store - tool_runtime.store.put(("metrics",), "count", 1) - - # Stream output - tool_runtime.stream_writer.write("Processing...") - - return f"Processed {x}" - ``` - - !!! note - This is a marker class used for type checking and detection. - The actual runtime object will be constructed during tool execution. - """ - - state: StateT - context: ContextT - config: RunnableConfig - stream_writer: StreamWriter - tool_call_id: str | None - store: BaseStore | None - - -class InjectedState(InjectedToolArg): - """Annotation for injecting graph state into tool arguments. - - This annotation enables tools to access graph state without exposing state - management details to the language model. Tools annotated with `InjectedState` - receive state data automatically during execution while remaining invisible - to the model's tool-calling interface. 
- - Args: - field: Optional key to extract from the state dictionary. If `None`, the entire - state is injected. If specified, only that field's value is injected. - This allows tools to request specific state components rather than - processing the full state structure. - - Example: - ```python - from typing import List - from typing_extensions import Annotated, TypedDict - - from langchain_core.messages import BaseMessage, AIMessage - from langchain.tools import InjectedState, ToolNode, tool - - - class AgentState(TypedDict): - messages: List[BaseMessage] - foo: str - - - @tool - def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str: - '''Do something with state.''' - if len(state["messages"]) > 2: - return state["foo"] + str(x) - else: - return "not enough messages" - - - @tool - def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str: - '''Do something else with state.''' - return foo + str(x + 1) - - - node = ToolNode([state_tool, foo_tool]) - - tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"} - tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"} - state = { - "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])], - "foo": "bar", - } - node.invoke(state) - ``` - - ```python - [ - ToolMessage(content="not enough messages", name="state_tool", tool_call_id="1"), - ToolMessage(content="bar2", name="foo_tool", tool_call_id="2"), - ] - ``` - - !!! note - - `InjectedState` arguments are automatically excluded from tool schemas - presented to language models - - `ToolNode` handles the injection process during execution - - Tools can mix regular arguments (controlled by the model) with injected - arguments (controlled by the system) - - State injection occurs after the model generates tool calls but before - tool execution - """ - - def __init__(self, field: str | None = None) -> None: - """Initialize the `InjectedState` annotation.""" - self.field = field - - -class InjectedStore(InjectedToolArg): - """Annotation for injecting persistent store into tool arguments. - - This annotation enables tools to access LangGraph's persistent storage system - without exposing storage details to the language model. Tools annotated with - InjectedStore receive the store instance automatically during execution while - remaining invisible to the model's tool-calling interface. - - The store provides persistent, cross-session data storage that tools can use - for maintaining context, user preferences, or any other data that needs to - persist beyond individual workflow executions. - - !!! 
warning - `InjectedStore` annotation requires `langchain-core >= 0.3.8` - - Example: - ```python - from typing_extensions import Annotated - from langgraph.store.memory import InMemoryStore - from langchain.tools import InjectedStore, ToolNode, tool - - @tool - def save_preference( - key: str, - value: str, - store: Annotated[Any, InjectedStore()] - ) -> str: - \"\"\"Save user preference to persistent storage.\"\"\" - store.put(("preferences",), key, value) - return f"Saved {key} = {value}" - - @tool - def get_preference( - key: str, - store: Annotated[Any, InjectedStore()] - ) -> str: - \"\"\"Retrieve user preference from persistent storage.\"\"\" - result = store.get(("preferences",), key) - return result.value if result else "Not found" - ``` - - Usage with `ToolNode` and graph compilation: - - ```python - from langgraph.graph import StateGraph - from langgraph.store.memory import InMemoryStore - - store = InMemoryStore() - tool_node = ToolNode([save_preference, get_preference]) - - graph = StateGraph(State) - graph.add_node("tools", tool_node) - compiled_graph = graph.compile(store=store) # Store is injected automatically - ``` - - Cross-session persistence: - - ```python - # First session - result1 = graph.invoke({"messages": [HumanMessage("Save my favorite color as blue")]}) - - # Later session - data persists - result2 = graph.invoke({"messages": [HumanMessage("What's my favorite color?")]}) - ``` - - !!! note - - `InjectedStore` arguments are automatically excluded from tool schemas - presented to language models - - The store instance is automatically injected by `ToolNode` during execution - - Tools can access namespaced storage using the store's get/put methods - - Store injection requires the graph to be compiled with a store instance - - Multiple tools can share the same store instance for data consistency - """ - - -def _is_injection( - type_arg: Any, - injection_type: type[InjectedState | InjectedStore | ToolRuntime], -) -> bool: - """Check if a type argument represents an injection annotation. - - This utility function determines whether a type annotation indicates that - an argument should be injected with state or store data. It handles both - direct annotations and nested annotations within Union or Annotated types. - - Args: - type_arg: The type argument to check for injection annotations. - injection_type: The injection type to look for (InjectedState or InjectedStore). - - Returns: - True if the type argument contains the specified injection annotation. - """ - if isinstance(type_arg, injection_type) or ( - isinstance(type_arg, type) and issubclass(type_arg, injection_type) - ): - return True - origin_ = get_origin(type_arg) - if origin_ is Union or origin_ is Annotated: - return any(_is_injection(ta, injection_type) for ta in get_args(type_arg)) - return False - - -def _get_state_args(tool: BaseTool) -> dict[str, str | None]: - """Extract state injection mappings from tool annotations. - - This function analyzes a tool's input schema to identify arguments that should - be injected with graph state. It processes InjectedState annotations to build - a mapping of tool argument names to state field names. - - Args: - tool: The tool to analyze for state injection requirements. - - Returns: - A dictionary mapping tool argument names to state field names. If a field - name is None, the entire state should be injected for that argument. 
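For illustration, a minimal sketch of the mapping described above. The `lookup` tool below is hypothetical and only mirrors the documented return shape of this internal helper, using both `InjectedState` annotation styles:

```python
from typing import Annotated

from typing_extensions import TypedDict

from langchain.tools import InjectedState, tool


class State(TypedDict):
    messages: list
    foo: str


@tool
def lookup(
    x: int,
    state: Annotated[State, InjectedState],     # whole state injected
    foo: Annotated[str, InjectedState("foo")],  # only state["foo"] injected
) -> str:
    """Hypothetical tool used only to illustrate the mapping."""
    return f"{foo}:{x}"


# Per the Returns description above, the extracted mapping for `lookup` would be:
#   {"state": None, "foo": "foo"}
# None means "inject the entire state"; a string names the state field to inject.
```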
- """ - full_schema = tool.get_input_schema() - tool_args_to_state_fields: dict = {} - - for name, type_ in get_all_basemodel_annotations(full_schema).items(): - injections = [ - type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedState) - ] - if len(injections) > 1: - msg = ( - "A tool argument should not be annotated with InjectedState more than " - f"once. Received arg {name} with annotations {injections}." - ) - raise ValueError(msg) - if len(injections) == 1: - injection = injections[0] - if isinstance(injection, InjectedState) and injection.field: - tool_args_to_state_fields[name] = injection.field - else: - tool_args_to_state_fields[name] = None - else: - pass - return tool_args_to_state_fields - - -def _get_store_arg(tool: BaseTool) -> str | None: - """Extract store injection argument from tool annotations. - - This function analyzes a tool's input schema to identify the argument that - should be injected with the graph store. Only one store argument is supported - per tool. - - Args: - tool: The tool to analyze for store injection requirements. - - Returns: - The name of the argument that should receive the store injection, or None - if no store injection is required. - - Raises: - ValueError: If a tool argument has multiple InjectedStore annotations. - """ - full_schema = tool.get_input_schema() - for name, type_ in get_all_basemodel_annotations(full_schema).items(): - injections = [ - type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedStore) - ] - if len(injections) > 1: - msg = ( - "A tool argument should not be annotated with InjectedStore more than " - f"once. Received arg {name} with annotations {injections}." - ) - raise ValueError(msg) - if len(injections) == 1: - return name - - return None - - -def _get_runtime_arg(tool: BaseTool) -> str | None: - """Extract runtime injection argument from tool annotations. - - This function analyzes a tool's input schema to identify the argument that - should be injected with the ToolRuntime instance. Only one runtime argument - is supported per tool. - - Args: - tool: The tool to analyze for runtime injection requirements. - - Returns: - The name of the argument that should receive the runtime injection, or None - if no runtime injection is required. - - Raises: - ValueError: If a tool argument has multiple ToolRuntime annotations. - """ - full_schema = tool.get_input_schema() - for name, type_ in get_all_basemodel_annotations(full_schema).items(): - # Check if the parameter name is "runtime" (regardless of type) - if name == "runtime": - return name - # Check if the type itself is ToolRuntime (direct usage) - if _is_injection(type_, ToolRuntime): - return name - # Check if ToolRuntime is in Annotated args - injections = [ - type_arg for type_arg in get_args(type_) if _is_injection(type_arg, ToolRuntime) - ] - if len(injections) > 1: - msg = ( - "A tool argument should not be annotated with ToolRuntime more than " - f"once. Received arg {name} with annotations {injections}." 
- ) - raise ValueError(msg) - if len(injections) == 1: - return name - - return None diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index f0bf36049e895..7c76e4e417634 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -8,7 +8,7 @@ license = { text = "MIT" } requires-python = ">=3.10.0,<4.0.0" dependencies = [ "langchain-core>=1.0.0,<2.0.0", - "langgraph>=1.0.0,<1.1.0", + "langgraph>=1.0.2,<1.1.0", "pydantic>=2.7.4,<3.0.0", ] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_override_methods.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_override_methods.py index a73ae488fc48a..8005b40fdbd2e 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_override_methods.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_override_methods.py @@ -6,7 +6,7 @@ from langchain_core.tools import BaseTool from langchain.agents.middleware.types import ModelRequest -from langchain.tools.tool_node import ToolCallRequest +from langchain.agents.middleware.types import ToolCallRequest class TestModelRequestOverride: diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_wrap_tool_call_decorator.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_wrap_tool_call_decorator.py index b602c43208e12..ca92bc65b7837 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_wrap_tool_call_decorator.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_wrap_tool_call_decorator.py @@ -13,7 +13,7 @@ from langchain.agents.factory import create_agent from langchain.agents.middleware.types import wrap_tool_call -from langchain.tools.tool_node import ToolCallRequest +from langchain.agents.middleware.types import ToolCallRequest from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py new file mode 100644 index 0000000000000..537dd9e9544cd --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py @@ -0,0 +1,370 @@ +import sys +import pytest +from typing import Annotated + +from langchain.agents import AgentState, create_agent +from langchain.tools import InjectedState, tool as dec_tool +from .model import FakeToolCallingModel +from langchain_core.messages import HumanMessage +from langgraph.prebuilt import InjectedStore, ToolRuntime +from langgraph.store.base import BaseStore +from langgraph.store.memory import InMemoryStore + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" +) +def test_tool_invocation_error_excludes_injected_state() -> None: + """Test that tool invocation errors only include LLM-controllable arguments. + When a tool has InjectedState parameters and the LLM makes an incorrect + invocation (e.g., missing required arguments), the error message should only + contain the arguments from the tool call that the LLM controls. This ensures + the LLM receives relevant context to correct its mistakes, without being + distracted by system-injected parameters it has no control over. + This test uses create_agent to ensure the behavior works in a full agent context. 
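As a rough sketch of the message shape the assertions in this test target (the wording follows the tool-invocation error template seen elsewhere in this diff; the concrete values are illustrative, not asserted verbatim):

```python
from langchain_core.messages import ToolMessage

# Illustrative only: the error surfaced to the model echoes the LLM-supplied
# kwargs ({'wrong_arg': 'value'}) and never the injected state fields or values.
expected_shape = ToolMessage(
    content=(
        "Error invoking tool 'tool_with_injected_state' with kwargs "
        "{'wrong_arg': 'value'} with error:\n..."
    ),
    name="tool_with_injected_state",
    tool_call_id="call_1",
    status="error",
)
```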
+ """ + + # Define a custom state schema with injected data + class TestState(AgentState): + secret_data: str # Example of state data not controlled by LLM + + @dec_tool + def tool_with_injected_state( + some_val: int, + state: Annotated[TestState, InjectedState], + ) -> str: + """Tool that uses injected state.""" + return f"some_val: {some_val}" + + # Create a fake model that makes an incorrect tool call (missing 'some_val') + # Then returns no tool calls on the second iteration to end the loop + model = FakeToolCallingModel( + tool_calls=[ + [ + { + "name": "tool_with_injected_state", + "args": {"wrong_arg": "value"}, # Missing required 'some_val' + "id": "call_1", + } + ], + [], # No tool calls on second iteration to end the loop + ] + ) + + # Create an agent with the tool and custom state schema + agent = create_agent( + model=model, + tools=[tool_with_injected_state], + state_schema=TestState, + ) + + # Invoke the agent with injected state data + result = agent.invoke( + { + "messages": [HumanMessage("Test message")], + "secret_data": "sensitive_secret_123", + } + ) + + # Find the tool error message + tool_messages = [m for m in result["messages"] if m.type == "tool"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.status == "error" + + # The error message should contain only the LLM-provided args (wrong_arg) + # and NOT the system-injected state (secret_data) + assert "{'wrong_arg': 'value'}" in tool_message.content + assert "secret_data" not in tool_message.content + assert "sensitive_secret_123" not in tool_message.content + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" +) +async def test_tool_invocation_error_excludes_injected_state_async() -> None: + """Test that async tool invocation errors only include LLM-controllable arguments. + This test verifies that the async execution path (_execute_tool_async and _arun_one) + properly filters validation errors to exclude system-injected arguments, ensuring + the LLM receives only relevant context for correction. 
+ """ + + # Define a custom state schema + class TestState(AgentState): + internal_data: str + + @dec_tool + async def async_tool_with_injected_state( + query: str, + max_results: int, + state: Annotated[TestState, InjectedState], + ) -> str: + """Async tool that uses injected state.""" + return f"query: {query}, max_results: {max_results}" + + # Create a fake model that makes an incorrect tool call + # - query has wrong type (int instead of str) + # - max_results is missing + model = FakeToolCallingModel( + tool_calls=[ + [ + { + "name": "async_tool_with_injected_state", + "args": {"query": 999}, # Wrong type, missing max_results + "id": "call_async_1", + } + ], + [], # End the loop + ] + ) + + # Create an agent with the async tool + agent = create_agent( + model=model, + tools=[async_tool_with_injected_state], + state_schema=TestState, + ) + + # Invoke with state data + result = await agent.ainvoke( + { + "messages": [HumanMessage("Test async")], + "internal_data": "secret_internal_value_xyz", + } + ) + + # Find the tool error message + tool_messages = [m for m in result["messages"] if m.type == "tool"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.status == "error" + + # Verify error mentions LLM-controlled parameters only + content = tool_message.content + assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)" + assert "max_results" in content.lower(), "Error should mention 'max_results' (LLM-controlled)" + + # Verify system-injected state does not appear in the validation errors + # This keeps the error focused on what the LLM can actually fix + assert "internal_data" not in content, ( + "Error should NOT mention 'internal_data' (system-injected field)" + ) + assert "secret_internal_value" not in content, ( + "Error should NOT contain system-injected state values" + ) + + # Verify only LLM-controlled parameters are in the error list + # Should see "query" and "max_results" errors, but not "state" + lines = content.split("\n") + error_lines = [line.strip() for line in lines if line.strip()] + # Find lines that look like field names (single words at start of line) + field_errors = [ + line + for line in error_lines + if line + and not line.startswith("input") + and not line.startswith("field") + and not line.startswith("error") + and not line.startswith("please") + and len(line.split()) <= 2 + ] + # Verify system-injected 'state' is not in the field error list + assert not any("state" == field.lower() for field in field_errors), ( + "The field 'state' (system-injected) should not appear in validation errors" + ) + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" +) +async def test_create_agent_error_content_with_multiple_params() -> None: + """Test that error messages only include LLM-controlled parameter errors. + Uses create_agent to verify that when a tool with both LLM-controlled + and system-injected parameters receives invalid arguments, the error message: + 1. Contains details about LLM-controlled parameter errors (query, limit) + 2. Does NOT contain system-injected parameter names (state, store, runtime) + 3. Does NOT contain values from system-injected parameters + 4. Properly formats the validation errors for LLM correction + This ensures the LLM receives focused, actionable feedback. 
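A minimal sketch of why points 2 and 3 hold: injected parameters never reach the model in the first place, so they cannot appear in its tool calls or in the validation errors about those calls. The `search` tool and the use of `convert_to_openai_tool` here are illustrative assumptions, not part of this PR:

```python
from typing import Annotated

from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain.tools import InjectedState, tool


@tool
def search(query: str, limit: int, state: Annotated[dict, InjectedState]) -> str:
    """Hypothetical tool mixing LLM-controlled and injected parameters."""
    return f"{query}:{limit}"


schema = convert_to_openai_tool(search)
# The model-facing schema keeps only the LLM-controlled parameters:
#   schema["function"]["parameters"]["properties"] -> {"query": ..., "limit": ...}
# `state` is stripped, so the model can neither supply it nor be blamed for it.
```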
+ """ + + class TestState(AgentState): + user_id: str + api_key: str + session_data: dict + + @dec_tool + def complex_tool( + query: str, + limit: int, + state: Annotated[TestState, InjectedState], + store: Annotated[BaseStore, InjectedStore()], + runtime: ToolRuntime, + ) -> str: + """A complex tool with multiple injected and non-injected parameters. + Args: + query: The search query string. + limit: Maximum number of results to return. + state: The graph state (injected). + store: The persistent store (injected). + runtime: The tool runtime context (injected). + """ + # Access injected params to verify they work in normal execution + user = state.get("user_id", "unknown") + return f"Results for '{query}' (limit={limit}, user={user})" + + # Create a model that makes an incorrect tool call with multiple errors: + # - query is wrong type (int instead of str) + # - limit is missing + # Then returns no tool calls to end the loop + model = FakeToolCallingModel( + tool_calls=[ + [ + { + "name": "complex_tool", + "args": { + "query": 12345, # Wrong type - should be str + # "limit" is missing - required field + }, + "id": "call_complex_1", + } + ], + [], # No tool calls on second iteration to end the loop + ] + ) + + # Create an agent with the complex tool and custom state + # Need to provide a store since the tool uses InjectedStore + agent = create_agent( + model=model, + tools=[complex_tool], + state_schema=TestState, + store=InMemoryStore(), + ) + + # Invoke with sensitive data in state + result = agent.invoke( + { + "messages": [HumanMessage("Search for something")], + "user_id": "user_12345", + "api_key": "sk-secret-key-abc123xyz", + "session_data": {"token": "secret_session_token"}, + } + ) + + # Find the tool error message + tool_messages = [m for m in result["messages"] if m.type == "tool"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.status == "error" + assert tool_message.tool_call_id == "call_complex_1" + + content = tool_message.content + + # Verify error mentions LLM-controlled parameter issues + assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)" + assert "limit" in content.lower(), "Error should mention 'limit' (LLM-controlled)" + + # Should indicate validation errors occurred + assert "validation error" in content.lower() or "error" in content.lower(), ( + "Error should indicate validation occurred" + ) + + # Verify NO system-injected parameter names appear in error + # These are not controlled by the LLM and should be excluded + assert "state" not in content.lower(), "Error should NOT mention 'state' (system-injected)" + assert "store" not in content.lower(), "Error should NOT mention 'store' (system-injected)" + assert "runtime" not in content.lower(), "Error should NOT mention 'runtime' (system-injected)" + + # Verify NO values from system-injected parameters appear in error + # The LLM doesn't control these, so they shouldn't distract from the actual issues + assert "user_12345" not in content, "Error should NOT contain user_id value (from state)" + assert "sk-secret-key" not in content, "Error should NOT contain api_key value (from state)" + assert "secret_session_token" not in content, ( + "Error should NOT contain session_data value (from state)" + ) + + # Verify the LLM's original tool call args are present + # The error should show what the LLM actually provided to help it correct the mistake + assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)" + + # 
Check error is well-formatted + assert "complex_tool" in content, "Error should mention the tool name" + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" +) +async def test_create_agent_error_only_model_controllable_params() -> None: + """Test that errors only include LLM-controllable parameter issues. + Focused test ensuring that validation errors for LLM-controlled parameters + are clearly reported, while system-injected parameters remain completely + absent from error messages. This provides focused feedback to the LLM. + """ + + class StateWithSecrets(AgentState): + password: str # Example of data not controlled by LLM + + @dec_tool + def secure_tool( + username: str, + email: str, + state: Annotated[StateWithSecrets, InjectedState], + ) -> str: + """Tool that validates user credentials. + Args: + username: The username (3-20 chars). + email: The email address. + state: State with password (system-injected). + """ + return f"Validated {username} with email {email}" + + # LLM provides invalid username (too short) and invalid email + model = FakeToolCallingModel( + tool_calls=[ + [ + { + "name": "secure_tool", + "args": { + "username": "ab", # Too short (needs 3-20) + "email": "not-an-email", # Invalid format + }, + "id": "call_secure_1", + } + ], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[secure_tool], + state_schema=StateWithSecrets, + ) + + result = agent.invoke( + { + "messages": [HumanMessage("Create account")], + "password": "super_secret_password_12345", + } + ) + + tool_messages = [m for m in result["messages"] if m.type == "tool"] + assert len(tool_messages) == 1 + content = tool_messages[0].content + + # The error should mention LLM-controlled parameters + # Note: Pydantic's default validation may or may not catch format issues, + # but the parameters themselves should be present in error messages + assert "username" in content.lower() or "email" in content.lower(), ( + "Error should mention at least one LLM-controlled parameter" + ) + + # Password is system-injected and should not appear + # The LLM doesn't control it, so it shouldn't distract from the actual errors + assert "password" not in content.lower(), ( + "Error should NOT mention 'password' (system-injected parameter)" + ) + assert "super_secret_password" not in content, ( + "Error should NOT contain password value (from system-injected state)" + ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_tools.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_tools.py index 01edf609b0449..3bccc7ba34229 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_tools.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_tools.py @@ -6,7 +6,7 @@ from langchain.agents.factory import create_agent from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest -from langchain.tools.tool_node import _ToolNode +from langgraph.prebuilt.tool_node import ToolNode from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.tools import tool from .model import FakeToolCallingModel @@ -326,9 +326,9 @@ def some_tool(input: str) -> str: """Some tool.""" return "result" - tool_node = _ToolNode([some_tool]) + tool_node = ToolNode([some_tool]) - with pytest.raises(TypeError, match="'_ToolNode' object is not iterable"): + with pytest.raises(TypeError, match="'ToolNode' object is not iterable"): create_agent( model=FakeToolCallingModel(), 
tools=tool_node, # type: ignore[arg-type] diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py index f9dbd4278caee..e936562ac0395 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py @@ -10,7 +10,7 @@ from langchain.agents.factory import create_agent from langchain.agents.middleware.types import AgentMiddleware -from langchain.tools.tool_node import ToolCallRequest +from langchain.agents.middleware.types import ToolCallRequest from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_sync_async_tool_wrapper_composition.py b/libs/langchain_v1/tests/unit_tests/agents/test_sync_async_tool_wrapper_composition.py index 3d7b6e43f0e34..c565719d3f221 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_sync_async_tool_wrapper_composition.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_sync_async_tool_wrapper_composition.py @@ -13,7 +13,7 @@ from langchain.agents.factory import create_agent from langchain.agents.middleware.types import AgentMiddleware, wrap_tool_call -from langchain.tools.tool_node import ToolCallRequest +from langchain.agents.middleware.types import ToolCallRequest from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py deleted file mode 100644 index 20200d7cb1223..0000000000000 --- a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py +++ /dev/null @@ -1,1716 +0,0 @@ -import contextlib -import dataclasses -import json -import sys -from functools import partial -from typing import ( - Annotated, - Any, - NoReturn, - TypeVar, -) -from unittest.mock import Mock -from langchain.agents import create_agent -from langchain.agents.middleware.types import AgentState - -import pytest -from langchain_core.messages import ( - AIMessage, - AnyMessage, - HumanMessage, - RemoveMessage, - ToolCall, - ToolMessage, -) -from langchain_core.runnables.config import RunnableConfig -from langchain_core.tools import BaseTool, ToolException -from langchain_core.tools import tool as dec_tool -from langgraph.config import get_stream_writer -from langgraph.errors import GraphBubbleUp, GraphInterrupt -from langgraph.graph import START, MessagesState, StateGraph -from langgraph.graph.message import REMOVE_ALL_MESSAGES, add_messages -from langgraph.store.base import BaseStore -from langgraph.store.memory import InMemoryStore -from langgraph.types import Command, Send -from pydantic import BaseModel -from pydantic.v1 import BaseModel as BaseModelV1 -from typing_extensions import TypedDict - -from langchain.tools import ( - InjectedState, - InjectedStore, -) -from langchain.tools.tool_node import _ToolNode -from langchain.tools.tool_node import TOOL_CALL_ERROR_TEMPLATE, ToolInvocationError, tools_condition - -from .messages import _AnyIdHumanMessage, _AnyIdToolMessage -from .model import FakeToolCallingModel - -pytestmark = pytest.mark.anyio - - -def _create_mock_runtime(store: BaseStore | None = None) -> Mock: - """Create a mock Runtime object for testing ToolNode outside of graph context. 
- - This helper is needed because ToolNode._func expects a Runtime parameter - which is injected by RunnableCallable from config["configurable"]["__pregel_runtime"]. - When testing ToolNode directly (outside a graph), we need to provide this manually. - """ - mock_runtime = Mock() - mock_runtime.store = store - mock_runtime.context = None - mock_runtime.stream_writer = lambda *args, **kwargs: None - return mock_runtime - - -def _create_config_with_runtime(store: BaseStore | None = None) -> RunnableConfig: - """Create a RunnableConfig with mock Runtime for testing ToolNode. - - Returns: - RunnableConfig with __pregel_runtime in configurable dict. - """ - return {"configurable": {"__pregel_runtime": _create_mock_runtime(store)}} - - -def tool1(some_val: int, some_other_val: str) -> str: - """Tool 1 docstring.""" - if some_val == 0: - msg = "Test error" - raise ValueError(msg) - return f"{some_val} - {some_other_val}" - - -async def tool2(some_val: int, some_other_val: str) -> str: - """Tool 2 docstring.""" - if some_val == 0: - msg = "Test error" - raise ToolException(msg) - return f"tool2: {some_val} - {some_other_val}" - - -async def tool3(some_val: int, some_other_val: str) -> str: - """Tool 3 docstring.""" - return [ - {"key_1": some_val, "key_2": "foo"}, - {"key_1": some_other_val, "key_2": "baz"}, - ] - - -async def tool4(some_val: int, some_other_val: str) -> str: - """Tool 4 docstring.""" - return [ - {"type": "image_url", "image_url": {"url": "abdc"}}, - ] - - -@dec_tool -def tool5(some_val: int) -> NoReturn: - """Tool 5 docstring.""" - msg = "Test error" - raise ToolException(msg) - - -tool5.handle_tool_error = "foo" - - -async def test_tool_node() -> None: - """Test tool node.""" - result = _ToolNode([tool1]).invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 1, "some_other_val": "foo"}, - "id": "some 0", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message: ToolMessage = result["messages"][-1] - assert tool_message.type == "tool" - assert tool_message.content == "1 - foo" - assert tool_message.tool_call_id == "some 0" - - result2 = await _ToolNode([tool2]).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool2", - "args": {"some_val": 2, "some_other_val": "bar"}, - "id": "some 1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message: ToolMessage = result2["messages"][-1] - assert tool_message.type == "tool" - assert tool_message.content == "tool2: 2 - bar" - - # list of dicts tool content - result3 = await _ToolNode([tool3]).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool3", - "args": {"some_val": 2, "some_other_val": "bar"}, - "id": "some 2", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - tool_message: ToolMessage = result3["messages"][-1] - assert tool_message.type == "tool" - assert ( - tool_message.content == '[{"key_1": 2, "key_2": "foo"}, {"key_1": "bar", "key_2": "baz"}]' - ) - assert tool_message.tool_call_id == "some 2" - - # list of content blocks tool content - result4 = await _ToolNode([tool4]).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool4", - "args": {"some_val": 2, "some_other_val": "bar"}, - "id": "some 3", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - tool_message: ToolMessage = result4["messages"][-1] - assert tool_message.type == "tool" - assert tool_message.content == [{"type": "image_url", 
"image_url": {"url": "abdc"}}] - assert tool_message.tool_call_id == "some 3" - - -async def test_tool_node_tool_call_input() -> None: - # Single tool call - tool_call_1 = { - "name": "tool1", - "args": {"some_val": 1, "some_other_val": "foo"}, - "id": "some 0", - "type": "tool_call", - } - result = _ToolNode([tool1]).invoke([tool_call_1], config=_create_config_with_runtime()) - assert result["messages"] == [ - ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"), - ] - - # Multiple tool calls - tool_call_2 = { - "name": "tool1", - "args": {"some_val": 2, "some_other_val": "bar"}, - "id": "some 1", - "type": "tool_call", - } - result = _ToolNode([tool1]).invoke( - [tool_call_1, tool_call_2], config=_create_config_with_runtime() - ) - assert result["messages"] == [ - ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"), - ToolMessage(content="2 - bar", tool_call_id="some 1", name="tool1"), - ] - - # Test with unknown tool - tool_call_3 = tool_call_1.copy() - tool_call_3["name"] = "tool2" - result = _ToolNode([tool1]).invoke( - [tool_call_1, tool_call_3], config=_create_config_with_runtime() - ) - assert result["messages"] == [ - ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"), - ToolMessage( - content="Error: tool2 is not a valid tool, try one of [tool1].", - name="tool2", - tool_call_id="some 0", - status="error", - ), - ] - - -def test_tool_node_error_handling_default_invocation() -> None: - tn = _ToolNode([tool1]) - result = tn.invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"invalid": 0, "args": "foo"}, - "id": "some id", - }, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - assert all(m.type == "tool" for m in result["messages"]) - assert all(m.status == "error" for m in result["messages"]) - assert ( - "Error invoking tool 'tool1' with kwargs {'invalid': 0, 'args': 'foo'} with error:\n" - in result["messages"][0].content - ) - - -def test_tool_node_error_handling_default_exception() -> None: - tn = _ToolNode([tool1]) - with pytest.raises(ValueError): - tn.invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0, "some_other_val": "foo"}, - "id": "some id", - }, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - -@pytest.mark.skipif( - sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" -) -def test_tool_invocation_error_excludes_injected_state() -> None: - """Test that tool invocation errors only include LLM-controllable arguments. - - When a tool has InjectedState parameters and the LLM makes an incorrect - invocation (e.g., missing required arguments), the error message should only - contain the arguments from the tool call that the LLM controls. This ensures - the LLM receives relevant context to correct its mistakes, without being - distracted by system-injected parameters it has no control over. - - This test uses create_agent to ensure the behavior works in a full agent context. 
- """ - - # Define a custom state schema with injected data - class TestState(AgentState): - secret_data: str # Example of state data not controlled by LLM - - @dec_tool - def tool_with_injected_state( - some_val: int, - state: Annotated[TestState, InjectedState], - ) -> str: - """Tool that uses injected state.""" - return f"some_val: {some_val}" - - # Create a fake model that makes an incorrect tool call (missing 'some_val') - # Then returns no tool calls on the second iteration to end the loop - model = FakeToolCallingModel( - tool_calls=[ - [ - { - "name": "tool_with_injected_state", - "args": {"wrong_arg": "value"}, # Missing required 'some_val' - "id": "call_1", - } - ], - [], # No tool calls on second iteration to end the loop - ] - ) - - # Create an agent with the tool and custom state schema - agent = create_agent( - model=model, - tools=[tool_with_injected_state], - state_schema=TestState, - ) - - # Invoke the agent with injected state data - result = agent.invoke( - { - "messages": [HumanMessage("Test message")], - "secret_data": "sensitive_secret_123", - } - ) - - # Find the tool error message - tool_messages = [m for m in result["messages"] if m.type == "tool"] - assert len(tool_messages) == 1 - tool_message = tool_messages[0] - assert tool_message.status == "error" - - # The error message should contain only the LLM-provided args (wrong_arg) - # and NOT the system-injected state (secret_data) - assert "{'wrong_arg': 'value'}" in tool_message.content - assert "secret_data" not in tool_message.content - assert "sensitive_secret_123" not in tool_message.content - - -@pytest.mark.skipif( - sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" -) -async def test_tool_invocation_error_excludes_injected_state_async() -> None: - """Test that async tool invocation errors only include LLM-controllable arguments. - - This test verifies that the async execution path (_execute_tool_async and _arun_one) - properly filters validation errors to exclude system-injected arguments, ensuring - the LLM receives only relevant context for correction. 
- """ - - # Define a custom state schema - class TestState(AgentState): - internal_data: str - - @dec_tool - async def async_tool_with_injected_state( - query: str, - max_results: int, - state: Annotated[TestState, InjectedState], - ) -> str: - """Async tool that uses injected state.""" - return f"query: {query}, max_results: {max_results}" - - # Create a fake model that makes an incorrect tool call - # - query has wrong type (int instead of str) - # - max_results is missing - model = FakeToolCallingModel( - tool_calls=[ - [ - { - "name": "async_tool_with_injected_state", - "args": {"query": 999}, # Wrong type, missing max_results - "id": "call_async_1", - } - ], - [], # End the loop - ] - ) - - # Create an agent with the async tool - agent = create_agent( - model=model, - tools=[async_tool_with_injected_state], - state_schema=TestState, - ) - - # Invoke with state data - result = await agent.ainvoke( - { - "messages": [HumanMessage("Test async")], - "internal_data": "secret_internal_value_xyz", - } - ) - - # Find the tool error message - tool_messages = [m for m in result["messages"] if m.type == "tool"] - assert len(tool_messages) == 1 - tool_message = tool_messages[0] - assert tool_message.status == "error" - - # Verify error mentions LLM-controlled parameters only - content = tool_message.content - assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)" - assert "max_results" in content.lower(), "Error should mention 'max_results' (LLM-controlled)" - - # Verify system-injected state does not appear in the validation errors - # This keeps the error focused on what the LLM can actually fix - assert "internal_data" not in content, ( - "Error should NOT mention 'internal_data' (system-injected field)" - ) - assert "secret_internal_value" not in content, ( - "Error should NOT contain system-injected state values" - ) - - # Verify only LLM-controlled parameters are in the error list - # Should see "query" and "max_results" errors, but not "state" - lines = content.split("\n") - error_lines = [line.strip() for line in lines if line.strip()] - # Find lines that look like field names (single words at start of line) - field_errors = [ - line - for line in error_lines - if line - and not line.startswith("input") - and not line.startswith("field") - and not line.startswith("error") - and not line.startswith("please") - and len(line.split()) <= 2 - ] - # Verify system-injected 'state' is not in the field error list - assert not any("state" == field.lower() for field in field_errors), ( - "The field 'state' (system-injected) should not appear in validation errors" - ) - - -async def test_tool_node_error_handling() -> None: - def handle_all(e: ValueError | ToolException | ToolInvocationError): - return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) - - # test catching all exceptions, via: - # - handle_tool_errors = True - # - passing a tuple of all exceptions - # - passing a callable with all exceptions in the signature - for handle_tool_errors in ( - True, - (ValueError, ToolException, ToolInvocationError), - handle_all, - ): - result_error = await _ToolNode( - [tool1, tool2, tool3], handle_tool_errors=handle_tool_errors - ).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0, "some_other_val": "foo"}, - "id": "some id", - }, - { - "name": "tool2", - "args": {"some_val": 0, "some_other_val": "bar"}, - "id": "some other id", - }, - { - "name": "tool3", - "args": {"some_val": 0}, - "id": "another id", - }, - ], - 
) - ] - }, - config=_create_config_with_runtime(), - ) - - assert all(m.type == "tool" for m in result_error["messages"]) - assert all(m.status == "error" for m in result_error["messages"]) - assert ( - result_error["messages"][0].content - == f"Error: {ValueError('Test error')!r}\n Please fix your mistakes." - ) - assert ( - result_error["messages"][1].content - == f"Error: {ToolException('Test error')!r}\n Please fix your mistakes." - ) - # Check that the validation error contains the field name - assert "some_other_val" in result_error["messages"][2].content - - assert result_error["messages"][0].tool_call_id == "some id" - assert result_error["messages"][1].tool_call_id == "some other id" - assert result_error["messages"][2].tool_call_id == "another id" - - -async def test_tool_node_error_handling_callable() -> None: - def handle_value_error(e: ValueError) -> str: - return "Value error" - - def handle_tool_exception(e: ToolException) -> str: - return "Tool exception" - - for handle_tool_errors in ("Value error", handle_value_error): - result_error = await _ToolNode([tool1], handle_tool_errors=handle_tool_errors).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0, "some_other_val": "foo"}, - "id": "some id", - }, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - tool_message: ToolMessage = result_error["messages"][-1] - assert tool_message.type == "tool" - assert tool_message.status == "error" - assert tool_message.content == "Value error" - - # test raising for an unhandled exception, via: - # - passing a tuple of all exceptions - # - passing a callable with all exceptions in the signature - for handle_tool_errors in ((ValueError,), handle_value_error): - with pytest.raises(ToolException) as exc_info: - await _ToolNode([tool1, tool2], handle_tool_errors=handle_tool_errors).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0, "some_other_val": "foo"}, - "id": "some id", - }, - { - "name": "tool2", - "args": {"some_val": 0, "some_other_val": "bar"}, - "id": "some other id", - }, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - assert str(exc_info.value) == "Test error" - - for handle_tool_errors in ((ToolException,), handle_tool_exception): - with pytest.raises(ValueError) as exc_info: - await _ToolNode([tool1, tool2], handle_tool_errors=handle_tool_errors).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0, "some_other_val": "foo"}, - "id": "some id", - }, - { - "name": "tool2", - "args": {"some_val": 0, "some_other_val": "bar"}, - "id": "some other id", - }, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - assert str(exc_info.value) == "Test error" - - -async def test_tool_node_handle_tool_errors_false() -> None: - with pytest.raises(ValueError) as exc_info: - _ToolNode([tool1], handle_tool_errors=False).invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0, "some_other_val": "foo"}, - "id": "some id", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - assert str(exc_info.value) == "Test error" - - with pytest.raises(ToolException): - await _ToolNode([tool2], handle_tool_errors=False).ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool2", - "args": {"some_val": 0, "some_other_val": "bar"}, - "id": "some id", - } - ], - ) - ] - }, - 
config=_create_config_with_runtime(), - ) - - assert str(exc_info.value) == "Test error" - - # test validation errors get raised if handle_tool_errors is False - with pytest.raises(ToolInvocationError): - _ToolNode([tool1], handle_tool_errors=False).invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool1", - "args": {"some_val": 0}, - "id": "some id", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - -def test_tool_node_individual_tool_error_handling() -> None: - # test error handling on individual tools (and that it overrides overall error handling!) - result_individual_tool_error_handler = _ToolNode([tool5], handle_tool_errors="bar").invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool5", - "args": {"some_val": 0}, - "id": "some 0", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message: ToolMessage = result_individual_tool_error_handler["messages"][-1] - assert tool_message.type == "tool" - assert tool_message.status == "error" - assert tool_message.content == "foo" - assert tool_message.tool_call_id == "some 0" - - -def test_tool_node_incorrect_tool_name() -> None: - result_incorrect_name = _ToolNode([tool1, tool2]).invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool3", - "args": {"some_val": 1, "some_other_val": "foo"}, - "id": "some 0", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message: ToolMessage = result_incorrect_name["messages"][-1] - assert tool_message.type == "tool" - assert tool_message.status == "error" - assert tool_message.content == "Error: tool3 is not a valid tool, try one of [tool1, tool2]." - assert tool_message.tool_call_id == "some 0" - - -def test_tool_node_node_interrupt() -> None: - def tool_interrupt(some_val: int) -> None: - """Tool docstring.""" - msg = "foo" - raise GraphBubbleUp(msg) - - def handle(e: GraphInterrupt) -> str: - return "handled" - - for handle_tool_errors in (True, (GraphBubbleUp,), "handled", handle, False): - node = _ToolNode([tool_interrupt], handle_tool_errors=handle_tool_errors) - with pytest.raises(GraphBubbleUp) as exc_info: - node.invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "tool_interrupt", - "args": {"some_val": 0}, - "id": "some 0", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - assert exc_info.value == "foo" - - -@pytest.mark.parametrize("input_type", ["dict", "tool_calls"]) -async def test_tool_node_command(input_type: str) -> None: - from langchain_core.tools.base import InjectedToolCallId - - @dec_tool - def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): - """Transfer to Bob""" - return Command( - update={ - "messages": [ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)] - }, - goto="bob", - graph=Command.PARENT, - ) - - @dec_tool - async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): - """Transfer to Bob""" - return Command( - update={ - "messages": [ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)] - }, - goto="bob", - graph=Command.PARENT, - ) - - class CustomToolSchema(BaseModel): - tool_call_id: Annotated[str, InjectedToolCallId] - - class MyCustomTool(BaseTool): - def _run(*args: Any, **kwargs: Any): - return Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id=kwargs["tool_call_id"], - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ) - - async def 
_arun(*args: Any, **kwargs: Any): - return Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id=kwargs["tool_call_id"], - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ) - - custom_tool = MyCustomTool( - name="custom_transfer_to_bob", - description="Transfer to bob", - args_schema=CustomToolSchema, - ) - async_custom_tool = MyCustomTool( - name="async_custom_transfer_to_bob", - description="Transfer to bob", - args_schema=CustomToolSchema, - ) - - # test mixing regular tools and tools returning commands - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - tool_calls = [ - {"args": {"a": 1, "b": 2}, "id": "1", "name": "add", "type": "tool_call"}, - {"args": {}, "id": "2", "name": "transfer_to_bob", "type": "tool_call"}, - ] - if input_type == "dict": - input_ = {"messages": [AIMessage("", tool_calls=tool_calls)]} - elif input_type == "tool_calls": - input_ = tool_calls - result = _ToolNode([add, transfer_to_bob]).invoke(input_, config=_create_config_with_runtime()) - - assert result == [ - { - "messages": [ - ToolMessage( - content="3", - tool_call_id="1", - name="add", - ) - ] - }, - Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id="2", - name="transfer_to_bob", - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ), - ] - - # test tools returning commands - - # test sync tools - for tool in [transfer_to_bob, custom_tool]: - result = _ToolNode([tool]).invoke( - {"messages": [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]}, - config=_create_config_with_runtime(), - ) - assert result == [ - Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id="1", - name=tool.name, - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ) - ] - - # test async tools - for tool in [async_transfer_to_bob, async_custom_tool]: - result = await _ToolNode([tool]).ainvoke( - {"messages": [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]}, - config=_create_config_with_runtime(), - ) - assert result == [ - Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id="1", - name=tool.name, - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ) - ] - - # test multiple commands - result = _ToolNode([transfer_to_bob, custom_tool]).invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[ - {"args": {}, "id": "1", "name": "transfer_to_bob"}, - {"args": {}, "id": "2", "name": "custom_transfer_to_bob"}, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - assert result == [ - Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id="1", - name="transfer_to_bob", - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ), - Command( - update={ - "messages": [ - ToolMessage( - content="Transferred to Bob", - tool_call_id="2", - name="custom_transfer_to_bob", - ) - ] - }, - goto="bob", - graph=Command.PARENT, - ), - ] - - # test validation (mismatch between input type and command.update type) - with pytest.raises(ValueError): - - @dec_tool - def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): - """My tool""" - return Command(update=[ToolMessage(content="foo", tool_call_id=tool_call_id)]) - - _ToolNode([list_update_tool]).invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # test validation 
(missing tool message in the update for current graph) - with pytest.raises(ValueError): - - @dec_tool - def no_update_tool(): - """My tool""" - return Command(update={"messages": []}) - - _ToolNode([no_update_tool]).invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # test validation (tool message with a wrong tool call ID) - with pytest.raises(ValueError): - - @dec_tool - def mismatching_tool_call_id_tool(): - """My tool""" - return Command(update={"messages": [ToolMessage(content="foo", tool_call_id="2")]}) - - _ToolNode([mismatching_tool_call_id_tool]).invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[ - { - "args": {}, - "id": "1", - "name": "mismatching_tool_call_id_tool", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # test validation (missing tool message in the update for parent graph is OK) - @dec_tool - def node_update_parent_tool(): - """No update""" - return Command(update={"messages": []}, graph=Command.PARENT) - - assert _ToolNode([node_update_parent_tool]).invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}], - ) - ] - }, - config=_create_config_with_runtime(), - ) == [Command(update={"messages": []}, graph=Command.PARENT)] - - -async def test_tool_node_command_list_input() -> None: - from langchain_core.tools.base import InjectedToolCallId - - @dec_tool - def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): - """Transfer to Bob""" - return Command( - update=[ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)], - goto="bob", - graph=Command.PARENT, - ) - - @dec_tool - async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): - """Transfer to Bob""" - return Command( - update=[ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)], - goto="bob", - graph=Command.PARENT, - ) - - class CustomToolSchema(BaseModel): - tool_call_id: Annotated[str, InjectedToolCallId] - - class MyCustomTool(BaseTool): - def _run(*args: Any, **kwargs: Any): - return Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id=kwargs["tool_call_id"], - ) - ], - goto="bob", - graph=Command.PARENT, - ) - - async def _arun(*args: Any, **kwargs: Any): - return Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id=kwargs["tool_call_id"], - ) - ], - goto="bob", - graph=Command.PARENT, - ) - - custom_tool = MyCustomTool( - name="custom_transfer_to_bob", - description="Transfer to bob", - args_schema=CustomToolSchema, - ) - async_custom_tool = MyCustomTool( - name="async_custom_transfer_to_bob", - description="Transfer to bob", - args_schema=CustomToolSchema, - ) - - # test mixing regular tools and tools returning commands - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - result = _ToolNode([add, transfer_to_bob]).invoke( - [ - AIMessage( - "", - tool_calls=[ - {"args": {"a": 1, "b": 2}, "id": "1", "name": "add"}, - {"args": {}, "id": "2", "name": "transfer_to_bob"}, - ], - ) - ], - config=_create_config_with_runtime(), - ) - - assert result == [ - [ - ToolMessage( - content="3", - tool_call_id="1", - name="add", - ) - ], - Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id="2", - name="transfer_to_bob", - ) - ], - goto="bob", - graph=Command.PARENT, - ), - ] - - # test tools returning commands - - # test 
sync tools - for tool in [transfer_to_bob, custom_tool]: - result = _ToolNode([tool]).invoke( - [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])], - config=_create_config_with_runtime(), - ) - assert result == [ - Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id="1", - name=tool.name, - ) - ], - goto="bob", - graph=Command.PARENT, - ) - ] - - # test async tools - for tool in [async_transfer_to_bob, async_custom_tool]: - result = await _ToolNode([tool]).ainvoke( - [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])], - config=_create_config_with_runtime(), - ) - assert result == [ - Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id="1", - name=tool.name, - ) - ], - goto="bob", - graph=Command.PARENT, - ) - ] - - # test multiple commands - result = _ToolNode([transfer_to_bob, custom_tool]).invoke( - [ - AIMessage( - "", - tool_calls=[ - {"args": {}, "id": "1", "name": "transfer_to_bob"}, - {"args": {}, "id": "2", "name": "custom_transfer_to_bob"}, - ], - ) - ], - config=_create_config_with_runtime(), - ) - assert result == [ - Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id="1", - name="transfer_to_bob", - ) - ], - goto="bob", - graph=Command.PARENT, - ), - Command( - update=[ - ToolMessage( - content="Transferred to Bob", - tool_call_id="2", - name="custom_transfer_to_bob", - ) - ], - goto="bob", - graph=Command.PARENT, - ), - ] - - # test validation (mismatch between input type and command.update type) - with pytest.raises(ValueError): - - @dec_tool - def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): - """My tool""" - return Command( - update={"messages": [ToolMessage(content="foo", tool_call_id=tool_call_id)]} - ) - - _ToolNode([list_update_tool]).invoke( - [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}], - ) - ], - config=_create_config_with_runtime(), - ) - - # test validation (missing tool message in the update for current graph) - with pytest.raises(ValueError): - - @dec_tool - def no_update_tool(): - """My tool""" - return Command(update=[]) - - _ToolNode([no_update_tool]).invoke( - [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}], - ) - ], - config=_create_config_with_runtime(), - ) - - # test validation (tool message with a wrong tool call ID) - with pytest.raises(ValueError): - - @dec_tool - def mismatching_tool_call_id_tool(): - """My tool""" - return Command(update=[ToolMessage(content="foo", tool_call_id="2")]) - - _ToolNode([mismatching_tool_call_id_tool]).invoke( - [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "mismatching_tool_call_id_tool"}], - ) - ], - config=_create_config_with_runtime(), - ) - - # test validation (missing tool message in the update for parent graph is OK) - @dec_tool - def node_update_parent_tool(): - """No update""" - return Command(update=[], graph=Command.PARENT) - - assert _ToolNode([node_update_parent_tool]).invoke( - [ - AIMessage( - "", - tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}], - ) - ], - config=_create_config_with_runtime(), - ) == [Command(update=[], graph=Command.PARENT)] - - -def test_tool_node_parent_command_with_send() -> None: - from langchain_core.tools.base import InjectedToolCallId - - @dec_tool - def transfer_to_alice(tool_call_id: Annotated[str, InjectedToolCallId]): - """Transfer to Alice""" - return Command( - goto=[ - Send( - "alice", - { - 
"messages": [ - ToolMessage( - content="Transferred to Alice", - name="transfer_to_alice", - tool_call_id=tool_call_id, - ) - ] - }, - ) - ], - graph=Command.PARENT, - ) - - @dec_tool - def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): - """Transfer to Bob""" - return Command( - goto=[ - Send( - "bob", - { - "messages": [ - ToolMessage( - content="Transferred to Bob", - name="transfer_to_bob", - tool_call_id=tool_call_id, - ) - ] - }, - ) - ], - graph=Command.PARENT, - ) - - tool_calls = [ - {"args": {}, "id": "1", "name": "transfer_to_alice", "type": "tool_call"}, - {"args": {}, "id": "2", "name": "transfer_to_bob", "type": "tool_call"}, - ] - - result = _ToolNode([transfer_to_alice, transfer_to_bob]).invoke( - [AIMessage("", tool_calls=tool_calls)], - config=_create_config_with_runtime(), - ) - - assert result == [ - Command( - goto=[ - Send( - "alice", - { - "messages": [ - ToolMessage( - content="Transferred to Alice", - name="transfer_to_alice", - tool_call_id="1", - ) - ] - }, - ), - Send( - "bob", - { - "messages": [ - ToolMessage( - content="Transferred to Bob", - name="transfer_to_bob", - tool_call_id="2", - ) - ] - }, - ), - ], - graph=Command.PARENT, - ) - ] - - -async def test_tool_node_command_remove_all_messages() -> None: - from langchain_core.tools.base import InjectedToolCallId - - @dec_tool - def remove_all_messages_tool(tool_call_id: Annotated[str, InjectedToolCallId]): - """A tool that removes all messages.""" - return Command(update={"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]}) - - tool_node = _ToolNode([remove_all_messages_tool]) - tool_call = { - "name": "remove_all_messages_tool", - "args": {}, - "id": "tool_call_123", - } - result = await tool_node.ainvoke( - {"messages": [AIMessage(content="", tool_calls=[tool_call])]}, - config=_create_config_with_runtime(), - ) - - assert isinstance(result, list) - assert len(result) == 1 - command = result[0] - assert isinstance(command, Command) - assert command.update == {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]} - - -class _InjectStateSchema(TypedDict): - messages: list - foo: str - - -class _InjectedStatePydanticV2Schema(BaseModel): - messages: list - foo: str - - -@dataclasses.dataclass -class _InjectedStateDataclassSchema: - messages: list - foo: str - - -_INJECTED_STATE_SCHEMAS = [ - _InjectStateSchema, - _InjectedStatePydanticV2Schema, - _InjectedStateDataclassSchema, -] - -if sys.version_info < (3, 14): - - class _InjectedStatePydanticSchema(BaseModelV1): - messages: list - foo: str - - _INJECTED_STATE_SCHEMAS.append(_InjectedStatePydanticSchema) - -T = TypeVar("T") - - -@pytest.mark.parametrize("schema_", _INJECTED_STATE_SCHEMAS) -def test_tool_node_inject_state(schema_: type[T]) -> None: - def tool1(some_val: int, state: Annotated[T, InjectedState]) -> str: - """Tool 1 docstring.""" - if isinstance(state, dict): - return state["foo"] - return state.foo - - def tool2(some_val: int, state: Annotated[T, InjectedState()]) -> str: - """Tool 2 docstring.""" - if isinstance(state, dict): - return state["foo"] - return state.foo - - def tool3( - some_val: int, - foo: Annotated[str, InjectedState("foo")], - msgs: Annotated[list[AnyMessage], InjectedState("messages")], - ) -> str: - """Tool 1 docstring.""" - return foo - - def tool4(some_val: int, msgs: Annotated[list[AnyMessage], InjectedState("messages")]) -> str: - """Tool 1 docstring.""" - return msgs[0].content - - node = _ToolNode([tool1, tool2, tool3, tool4], handle_tool_errors=True) - for tool_name in ("tool1", "tool2", 
"tool3"): - tool_call = { - "name": tool_name, - "args": {"some_val": 1}, - "id": "some 0", - "type": "tool_call", - } - msg = AIMessage("hi?", tool_calls=[tool_call]) - result = node.invoke( - schema_(messages=[msg], foo="bar"), config=_create_config_with_runtime() - ) - tool_message = result["messages"][-1] - assert tool_message.content == "bar", f"Failed for tool={tool_name}" - - if tool_name == "tool3": - failure_input = None - with contextlib.suppress(Exception): - failure_input = schema_(messages=[msg], notfoo="bar") - if failure_input is not None: - with pytest.raises(KeyError): - node.invoke(failure_input, config=_create_config_with_runtime()) - - with pytest.raises(ValueError): - node.invoke([msg], config=_create_config_with_runtime()) - else: - failure_input = None - try: - failure_input = schema_(messages=[msg], notfoo="bar") - except Exception: - # We'd get a validation error from pydantic state and wouldn't make it to the node - # anyway - pass - if failure_input is not None: - messages_ = node.invoke(failure_input, config=_create_config_with_runtime()) - tool_message = messages_["messages"][-1] - assert "KeyError" in tool_message.content - tool_message = node.invoke([msg], config=_create_config_with_runtime())[-1] - assert "KeyError" in tool_message.content - - tool_call = { - "name": "tool4", - "args": {"some_val": 1}, - "id": "some 0", - "type": "tool_call", - } - msg = AIMessage("hi?", tool_calls=[tool_call]) - result = node.invoke(schema_(messages=[msg], foo=""), config=_create_config_with_runtime()) - tool_message = result["messages"][-1] - assert tool_message.content == "hi?" - - result = node.invoke([msg], config=_create_config_with_runtime()) - tool_message = result[-1] - assert tool_message.content == "hi?" - - -def test_tool_node_inject_store() -> None: - store = InMemoryStore() - namespace = ("test",) - - def tool1(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str: - """Tool 1 docstring.""" - store_val = store.get(namespace, "test_key").value["foo"] - return f"Some val: {some_val}, store val: {store_val}" - - def tool2(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str: - """Tool 2 docstring.""" - store_val = store.get(namespace, "test_key").value["foo"] - return f"Some val: {some_val}, store val: {store_val}" - - def tool3( - some_val: int, - bar: Annotated[str, InjectedState("bar")], - store: Annotated[BaseStore, InjectedStore()], - ) -> str: - """Tool 3 docstring.""" - store_val = store.get(namespace, "test_key").value["foo"] - return f"Some val: {some_val}, store val: {store_val}, state val: {bar}" - - node = _ToolNode([tool1, tool2, tool3], handle_tool_errors=True) - store.put(namespace, "test_key", {"foo": "bar"}) - - class State(MessagesState): - bar: str - - builder = StateGraph(State) - builder.add_node("tools", node) - builder.add_edge(START, "tools") - graph = builder.compile(store=store) - - for tool_name in ("tool1", "tool2"): - tool_call = { - "name": tool_name, - "args": {"some_val": 1}, - "id": "some 0", - "type": "tool_call", - } - msg = AIMessage("hi?", tool_calls=[tool_call]) - node_result = node.invoke( - {"messages": [msg]}, config=_create_config_with_runtime(store=store) - ) - graph_result = graph.invoke({"messages": [msg]}) - for result in (node_result, graph_result): - result["messages"][-1] - tool_message = result["messages"][-1] - assert tool_message.content == "Some val: 1, store val: bar", ( - f"Failed for tool={tool_name}" - ) - - tool_call = { - "name": "tool3", - "args": {"some_val": 1}, - "id": 
"some 0", - "type": "tool_call", - } - msg = AIMessage("hi?", tool_calls=[tool_call]) - node_result = node.invoke( - {"messages": [msg], "bar": "baz"}, config=_create_config_with_runtime(store=store) - ) - graph_result = graph.invoke({"messages": [msg], "bar": "baz"}) - for result in (node_result, graph_result): - result["messages"][-1] - tool_message = result["messages"][-1] - assert tool_message.content == "Some val: 1, store val: bar, state val: baz", ( - f"Failed for tool={tool_name}" - ) - - # test injected store without passing store to compiled graph - failing_graph = builder.compile() - with pytest.raises(ValueError): - failing_graph.invoke({"messages": [msg], "bar": "baz"}) - - -def test_tool_node_ensure_utf8() -> None: - @dec_tool - def get_day_list(days: list[str]) -> list[str]: - """choose days""" - return days - - data = ["星期一", "水曜日", "목요일", "Friday"] - tools = [get_day_list] - tool_calls = [ToolCall(name=get_day_list.name, args={"days": data}, id="test_id")] - outputs: list[ToolMessage] = _ToolNode(tools).invoke( - [AIMessage(content="", tool_calls=tool_calls)], - config=_create_config_with_runtime(), - ) - assert outputs[0].content == json.dumps(data, ensure_ascii=False) - - -def test_tool_node_messages_key() -> None: - @dec_tool - def add(a: int, b: int) -> int: - """Adds a and b.""" - return a + b - - model = FakeToolCallingModel( - tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]] - ) - - class State(TypedDict): - subgraph_messages: Annotated[list[AnyMessage], add_messages] - - def call_model(state: State) -> dict[str, Any]: - response = model.invoke(state["subgraph_messages"]) - model.tool_calls = [] - return {"subgraph_messages": response} - - builder = StateGraph(State) - builder.add_node("agent", call_model) - builder.add_node("tools", _ToolNode([add], messages_key="subgraph_messages")) - builder.add_conditional_edges( - "agent", partial(tools_condition, messages_key="subgraph_messages") - ) - builder.add_edge(START, "agent") - builder.add_edge("tools", "agent") - - graph = builder.compile() - result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]}) - assert result["subgraph_messages"] == [ - _AnyIdHumanMessage(content="hi"), - AIMessage( - content="hi", - id="0", - tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")], - ), - _AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"), - AIMessage(content="hi-hi-3", id="1"), - ] - - -def test_tool_node_stream_writer() -> None: - @dec_tool - def streaming_tool(x: int) -> str: - """Do something with writer.""" - my_writer = get_stream_writer() - for value in ["foo", "bar", "baz"]: - my_writer({"custom_tool_value": value}) - - return x - - tool_node = _ToolNode([streaming_tool]) - graph = ( - StateGraph(MessagesState).add_node("tools", tool_node).add_edge(START, "tools").compile() - ) - - tool_call = { - "name": "streaming_tool", - "args": {"x": 1}, - "id": "1", - "type": "tool_call", - } - inputs = { - "messages": [AIMessage("", tool_calls=[tool_call])], - } - - assert list(graph.stream(inputs, stream_mode="custom")) == [ - {"custom_tool_value": "foo"}, - {"custom_tool_value": "bar"}, - {"custom_tool_value": "baz"}, - ] - assert list(graph.stream(inputs, stream_mode=["custom", "updates"])) == [ - ("custom", {"custom_tool_value": "foo"}), - ("custom", {"custom_tool_value": "bar"}), - ("custom", {"custom_tool_value": "baz"}), - ( - "updates", - { - "tools": { - "messages": [ - _AnyIdToolMessage( - content="1", - name="streaming_tool", - 
tool_call_id="1", - ), - ], - }, - }, - ), - ] diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_interceptor_unregistered.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_interceptor_unregistered.py deleted file mode 100644 index bbf3060b3cb9b..0000000000000 --- a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_interceptor_unregistered.py +++ /dev/null @@ -1,571 +0,0 @@ -"""Test tool node interceptor handling of unregistered tools.""" - -from collections.abc import Awaitable, Callable -from unittest.mock import Mock - -import pytest -from langchain_core.messages import AIMessage, ToolMessage -from langchain_core.runnables.config import RunnableConfig -from langchain_core.tools import tool as dec_tool -from langgraph.store.base import BaseStore -from langgraph.types import Command - -from langchain.tools.tool_node import ToolCallRequest, _ToolNode - -pytestmark = pytest.mark.anyio - - -def _create_mock_runtime(store: BaseStore | None = None) -> Mock: - """Create a mock Runtime object for testing ToolNode outside of graph context. - - This helper is needed because ToolNode._func expects a Runtime parameter - which is injected by RunnableCallable from config["configurable"]["__pregel_runtime"]. - When testing ToolNode directly (outside a graph), we need to provide this manually. - """ - mock_runtime = Mock() - mock_runtime.store = store - mock_runtime.context = None - mock_runtime.stream_writer = lambda *args, **kwargs: None - return mock_runtime - - -def _create_config_with_runtime(store: BaseStore | None = None) -> RunnableConfig: - """Create a RunnableConfig with mock Runtime for testing ToolNode. - - Returns: - RunnableConfig with __pregel_runtime in configurable dict. - """ - return {"configurable": {"__pregel_runtime": _create_mock_runtime(store)}} - - -@dec_tool -def registered_tool(x: int) -> str: - """A registered tool.""" - return f"Result: {x}" - - -def test_interceptor_can_handle_unregistered_tool_sync() -> None: - """Test that interceptor can handle requests for unregistered tools (sync).""" - - def interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Intercept and handle unregistered tools.""" - if request.tool_call["name"] == "unregistered_tool": - # Short-circuit without calling execute for unregistered tool - return ToolMessage( - content="Handled by interceptor", - tool_call_id=request.tool_call["id"], - name="unregistered_tool", - ) - # Pass through for registered tools - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=interceptor) - - # Test registered tool works normally - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "registered_tool", - "args": {"x": 42}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - assert result[0].content == "Result: 42" - assert result[0].tool_call_id == "1" - - # Test unregistered tool is intercepted and handled - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "unregistered_tool", - "args": {"x": 99}, - "id": "2", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - assert result[0].content == "Handled by interceptor" - assert result[0].tool_call_id == "2" - assert result[0].name == "unregistered_tool" - - -async def test_interceptor_can_handle_unregistered_tool_async() -> None: - """Test that interceptor can handle requests for unregistered tools 
(async).""" - - async def async_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], - ) -> ToolMessage | Command: - """Intercept and handle unregistered tools.""" - if request.tool_call["name"] == "unregistered_tool": - # Short-circuit without calling execute for unregistered tool - return ToolMessage( - content="Handled by async interceptor", - tool_call_id=request.tool_call["id"], - name="unregistered_tool", - ) - # Pass through for registered tools - return await execute(request) - - node = _ToolNode([registered_tool], awrap_tool_call=async_interceptor) - - # Test registered tool works normally - result = await node.ainvoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "registered_tool", - "args": {"x": 42}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - assert result[0].content == "Result: 42" - assert result[0].tool_call_id == "1" - - # Test unregistered tool is intercepted and handled - result = await node.ainvoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "unregistered_tool", - "args": {"x": 99}, - "id": "2", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - assert result[0].content == "Handled by async interceptor" - assert result[0].tool_call_id == "2" - assert result[0].name == "unregistered_tool" - - -def test_unregistered_tool_error_when_interceptor_calls_execute() -> None: - """Test that unregistered tools error if interceptor tries to execute them.""" - - def bad_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Interceptor that tries to execute unregistered tool.""" - # This should fail validation when execute is called - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=bad_interceptor) - - # Registered tool should still work - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "registered_tool", - "args": {"x": 42}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - assert result[0].content == "Result: 42" - - # Unregistered tool should error when interceptor calls execute - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "unregistered_tool", - "args": {"x": 99}, - "id": "2", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - # Should get validation error message - assert result[0].status == "error" - assert "is not a valid tool" in result[0].content - assert result[0].tool_call_id == "2" - - -def test_interceptor_handles_mix_of_registered_and_unregistered() -> None: - """Test interceptor handling mix of registered and unregistered tools.""" - - def selective_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handle unregistered tools, pass through registered ones.""" - if request.tool_call["name"] == "magic_tool": - return ToolMessage( - content=f"Magic result: {request.tool_call['args'].get('value', 0) * 2}", - tool_call_id=request.tool_call["id"], - name="magic_tool", - ) - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=selective_interceptor) - - # Test multiple tool calls - mix of registered and unregistered - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "registered_tool", - "args": {"x": 10}, - "id": "1", - 
"type": "tool_call", - }, - { - "name": "magic_tool", - "args": {"value": 5}, - "id": "2", - "type": "tool_call", - }, - { - "name": "registered_tool", - "args": {"x": 20}, - "id": "3", - "type": "tool_call", - }, - ], - ) - ], - config=_create_config_with_runtime(), - ) - - # All tools should execute successfully - assert len(result) == 3 - assert result[0].content == "Result: 10" - assert result[0].tool_call_id == "1" - assert result[1].content == "Magic result: 10" - assert result[1].tool_call_id == "2" - assert result[2].content == "Result: 20" - assert result[2].tool_call_id == "3" - - -def test_interceptor_command_for_unregistered_tool() -> None: - """Test interceptor returning Command for unregistered tool.""" - - def command_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Return Command for unregistered tools.""" - if request.tool_call["name"] == "routing_tool": - return Command( - update=[ - ToolMessage( - content="Routing to special handler", - tool_call_id=request.tool_call["id"], - name="routing_tool", - ) - ], - goto="special_node", - ) - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=command_interceptor) - - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "routing_tool", - "args": {}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - # Should get Command back - assert len(result) == 1 - assert isinstance(result[0], Command) - assert result[0].goto == "special_node" - assert result[0].update is not None - assert len(result[0].update) == 1 - assert result[0].update[0].content == "Routing to special handler" - - -def test_interceptor_exception_with_unregistered_tool() -> None: - """Test that interceptor exceptions are caught by error handling.""" - - def failing_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Interceptor that throws exception for unregistered tools.""" - if request.tool_call["name"] == "bad_tool": - msg = "Interceptor failed" - raise ValueError(msg) - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=failing_interceptor, handle_tool_errors=True) - - # Interceptor exception should be caught and converted to error message - result = node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "bad_tool", - "args": {}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - assert len(result) == 1 - assert result[0].status == "error" - assert "Interceptor failed" in result[0].content - assert result[0].tool_call_id == "1" - - # Test that exception is raised when handle_tool_errors is False - node_no_handling = _ToolNode( - [registered_tool], wrap_tool_call=failing_interceptor, handle_tool_errors=False - ) - - with pytest.raises(ValueError, match="Interceptor failed"): - node_no_handling.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "bad_tool", - "args": {}, - "id": "2", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - -async def test_async_interceptor_exception_with_unregistered_tool() -> None: - """Test that async interceptor exceptions are caught by error handling.""" - - async def failing_async_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], - ) -> ToolMessage | Command: - 
"""Async interceptor that throws exception for unregistered tools.""" - if request.tool_call["name"] == "bad_async_tool": - msg = "Async interceptor failed" - raise RuntimeError(msg) - return await execute(request) - - node = _ToolNode( - [registered_tool], awrap_tool_call=failing_async_interceptor, handle_tool_errors=True - ) - - # Interceptor exception should be caught and converted to error message - result = await node.ainvoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "bad_async_tool", - "args": {}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - assert len(result) == 1 - assert result[0].status == "error" - assert "Async interceptor failed" in result[0].content - assert result[0].tool_call_id == "1" - - # Test that exception is raised when handle_tool_errors is False - node_no_handling = _ToolNode( - [registered_tool], awrap_tool_call=failing_async_interceptor, handle_tool_errors=False - ) - - with pytest.raises(RuntimeError, match="Async interceptor failed"): - await node_no_handling.ainvoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "bad_async_tool", - "args": {}, - "id": "2", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - -def test_interceptor_with_dict_input_format() -> None: - """Test that interceptor works with dict input format.""" - - def interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Intercept unregistered tools with dict input.""" - if request.tool_call["name"] == "dict_tool": - return ToolMessage( - content="Handled dict input", - tool_call_id=request.tool_call["id"], - name="dict_tool", - ) - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=interceptor) - - # Test with dict input format - result = node.invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[ - { - "name": "dict_tool", - "args": {"value": 5}, - "id": "1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should return dict format output - assert isinstance(result, dict) - assert "messages" in result - assert len(result["messages"]) == 1 - assert result["messages"][0].content == "Handled dict input" - assert result["messages"][0].tool_call_id == "1" - - -def test_interceptor_verifies_tool_is_none_for_unregistered() -> None: - """Test that request.tool is None for unregistered tools.""" - - captured_requests: list[ToolCallRequest] = [] - - def capturing_interceptor( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Capture request to verify tool field.""" - captured_requests.append(request) - if request.tool is None: - # Tool is unregistered - return ToolMessage( - content=f"Unregistered: {request.tool_call['name']}", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - # Tool is registered - return execute(request) - - node = _ToolNode([registered_tool], wrap_tool_call=capturing_interceptor) - - # Test unregistered tool - node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "unknown_tool", - "args": {}, - "id": "1", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - assert len(captured_requests) == 1 - assert captured_requests[0].tool is None - assert captured_requests[0].tool_call["name"] == "unknown_tool" - - # Clear and test registered tool - captured_requests.clear() - 
node.invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "name": "registered_tool", - "args": {"x": 10}, - "id": "2", - "type": "tool_call", - } - ], - ) - ], - config=_create_config_with_runtime(), - ) - - assert len(captured_requests) == 1 - assert captured_requests[0].tool is not None - assert captured_requests[0].tool.name == "registered_tool" diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_validation_error_filtering.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_validation_error_filtering.py deleted file mode 100644 index 42e22aaf679a5..0000000000000 --- a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_validation_error_filtering.py +++ /dev/null @@ -1,678 +0,0 @@ -"""Unit tests for ValidationError filtering in ToolNode. - -This module tests that validation errors are filtered to only include arguments -that the LLM controls. Injected arguments (InjectedState, InjectedStore, -ToolRuntime) are automatically provided by the system and should not appear in -validation error messages. This ensures the LLM receives focused, actionable -feedback about the parameters it can actually control, improving error correction -and reducing confusion from irrelevant system implementation details. -""" - -import sys -from typing import Annotated -from unittest.mock import Mock - -import pytest -from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.runnables.config import RunnableConfig -from langchain_core.tools import tool as dec_tool -from langgraph.store.base import BaseStore -from langgraph.store.memory import InMemoryStore -from pydantic import BaseModel -from typing_extensions import TypedDict - -from langchain.agents import create_agent -from langchain.agents.middleware.types import AgentState -from langchain.tools import InjectedState, InjectedStore -from langchain.tools.tool_node import ToolInvocationError, ToolRuntime, _ToolNode - -from .model import FakeToolCallingModel - -pytestmark = pytest.mark.anyio - - -def _create_mock_runtime(store: BaseStore | None = None) -> Mock: - """Create a mock Runtime object for testing ToolNode outside of graph context.""" - mock_runtime = Mock() - mock_runtime.store = store - mock_runtime.context = None - mock_runtime.stream_writer = lambda *args, **kwargs: None - return mock_runtime - - -def _create_config_with_runtime(store: BaseStore | None = None) -> RunnableConfig: - """Create a RunnableConfig with mock Runtime for testing ToolNode.""" - return {"configurable": {"__pregel_runtime": _create_mock_runtime(store)}} - - -async def test_filter_injected_state_validation_errors() -> None: - """Test that validation errors for InjectedState arguments are filtered out. - - InjectedState parameters are not controlled by the LLM, so any validation - errors related to them should not appear in error messages. This ensures - the LLM receives only actionable feedback about its own tool call arguments. - """ - - @dec_tool - def my_tool( - value: int, - state: Annotated[dict, InjectedState], - ) -> str: - """Tool that uses injected state. - - Args: - value: An integer value. - state: The graph state (injected). 
- """ - return f"value={value}, messages={len(state.get('messages', []))}" - - tool_node = _ToolNode([my_tool]) - - # Call with invalid 'value' argument (should be int, not str) - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {"value": "not_an_int"}, # Invalid type - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should get a ToolMessage with error - assert len(result["messages"]) == 1 - tool_message = result["messages"][0] - assert tool_message.status == "error" - assert tool_message.tool_call_id == "call_1" - - # Error should mention 'value' but NOT 'state' (which is injected) - assert "value" in tool_message.content - assert "state" not in tool_message.content.lower() - - -async def test_filter_injected_store_validation_errors() -> None: - """Test that validation errors for InjectedStore arguments are filtered out. - - InjectedStore parameters are not controlled by the LLM, so any validation - errors related to them should not appear in error messages. This keeps - error feedback focused on LLM-controllable parameters. - """ - - @dec_tool - def my_tool( - key: str, - store: Annotated[BaseStore, InjectedStore()], - ) -> str: - """Tool that uses injected store. - - Args: - key: A key to look up. - store: The persistent store (injected). - """ - return f"key={key}" - - tool_node = _ToolNode([my_tool]) - - # Call with invalid 'key' argument (missing required argument) - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {}, # Missing 'key' - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(store=InMemoryStore()), - ) - - # Should get a ToolMessage with error - assert len(result["messages"]) == 1 - tool_message = result["messages"][0] - assert tool_message.status == "error" - - # Error should mention 'key' is required - assert "key" in tool_message.content.lower() - # The error should be about 'key' field specifically (not about store field) - # Note: 'store' might appear in input_value representation, but the validation - # error itself should only be for 'key' - assert ( - "field required" in tool_message.content.lower() - or "missing" in tool_message.content.lower() - ) - - -async def test_filter_tool_runtime_validation_errors() -> None: - """Test that validation errors for ToolRuntime arguments are filtered out. - - ToolRuntime parameters are not controlled by the LLM, so any validation - errors related to them should not appear in error messages. This ensures - the LLM only sees errors for parameters it can fix. - """ - - @dec_tool - def my_tool( - query: str, - runtime: ToolRuntime, - ) -> str: - """Tool that uses ToolRuntime. - - Args: - query: A query string. - runtime: The tool runtime context (injected). 
- """ - return f"query={query}" - - tool_node = _ToolNode([my_tool]) - - # Call with invalid 'query' argument (wrong type) - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {"query": 123}, # Should be str, not int - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should get a ToolMessage with error - assert len(result["messages"]) == 1 - tool_message = result["messages"][0] - assert tool_message.status == "error" - - # Error should mention 'query' but NOT 'runtime' (which is injected) - assert "query" in tool_message.content.lower() - assert "runtime" not in tool_message.content.lower() - - -async def test_filter_multiple_injected_args() -> None: - """Test filtering when a tool has multiple injected arguments. - - When a tool uses multiple injected parameters (state, store, runtime), none of - them should appear in validation error messages since they're all system-provided - and not controlled by the LLM. Only LLM-controllable parameter errors should appear. - """ - - @dec_tool - def my_tool( - value: int, - state: Annotated[dict, InjectedState], - store: Annotated[BaseStore, InjectedStore()], - runtime: ToolRuntime, - ) -> str: - """Tool with multiple injected arguments. - - Args: - value: An integer value. - state: The graph state (injected). - store: The persistent store (injected). - runtime: The tool runtime context (injected). - """ - return f"value={value}" - - tool_node = _ToolNode([my_tool]) - - # Call with invalid 'value' - injected args should be filtered from error - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {"value": "not_an_int"}, - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(store=InMemoryStore()), - ) - - tool_message = result["messages"][0] - assert tool_message.status == "error" - - # Only 'value' error should be reported - assert "value" in tool_message.content - # None of the injected args should appear in error - assert "state" not in tool_message.content.lower() - assert "store" not in tool_message.content.lower() - assert "runtime" not in tool_message.content.lower() - - -async def test_no_filtering_when_all_errors_are_model_args() -> None: - """Test that validation errors for LLM-controlled arguments are preserved. - - When validation fails for arguments the LLM controls, those errors should - be fully reported to help the LLM correct its tool calls. This ensures - the LLM receives complete feedback about all issues it can fix. - """ - - @dec_tool - def my_tool( - value1: int, - value2: str, - state: Annotated[dict, InjectedState], - ) -> str: - """Tool with both regular and injected arguments. - - Args: - value1: First value. - value2: Second value. - state: The graph state (injected). 
- """ - return f"value1={value1}, value2={value2}" - - tool_node = _ToolNode([my_tool]) - - # Call with invalid arguments for BOTH non-injected parameters - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": { - "value1": "not_an_int", # Invalid - "value2": 456, # Invalid (should be str) - }, - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][0] - assert tool_message.status == "error" - - # Both errors should be present - assert "value1" in tool_message.content - assert "value2" in tool_message.content - # Injected state should not appear - assert "state" not in tool_message.content.lower() - - -async def test_validation_error_with_no_injected_args() -> None: - """Test that tools without injected arguments show all validation errors. - - For tools that only have LLM-controlled parameters, all validation errors - should be reported since everything is under the LLM's control and can be - corrected by the LLM in subsequent tool calls. - """ - - @dec_tool - def my_tool(value1: int, value2: str) -> str: - """Regular tool without injected arguments. - - Args: - value1: First value. - value2: Second value. - """ - return f"{value1} {value2}" - - tool_node = _ToolNode([my_tool]) - - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {"value1": "invalid", "value2": 123}, - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][0] - assert tool_message.status == "error" - - # Both errors should be present since there are no injected args to filter - assert "value1" in tool_message.content - assert "value2" in tool_message.content - - -async def test_tool_invocation_error_without_handle_errors() -> None: - """Test that ToolInvocationError contains only LLM-controlled parameter errors. - - When handle_tool_errors is False, the raised ToolInvocationError should still - filter out system-injected arguments from the error details, ensuring that - error messages focus on what the LLM can control. - """ - - @dec_tool - def my_tool( - value: int, - state: Annotated[dict, InjectedState], - ) -> str: - """Tool with injected state. - - Args: - value: An integer value. - state: The graph state (injected). - """ - return f"value={value}" - - tool_node = _ToolNode([my_tool], handle_tool_errors=False) - - # Should raise ToolInvocationError with filtered errors - with pytest.raises(ToolInvocationError) as exc_info: - await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {"value": "not_an_int"}, - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - error = exc_info.value - assert error.tool_name == "my_tool" - assert error.filtered_errors is not None - assert len(error.filtered_errors) > 0 - - # Filtered errors should only contain 'value' error, not 'state' - error_locs = [err["loc"] for err in error.filtered_errors] - assert any("value" in str(loc) for loc in error_locs) - assert not any("state" in str(loc) for loc in error_locs) - - -async def test_sync_tool_validation_error_filtering() -> None: - """Test that error filtering works for sync tools. 
- - Error filtering should work identically for both sync and async tool execution, - excluding injected arguments from validation error messages. - """ - - @dec_tool - def my_tool( - value: int, - state: Annotated[dict, InjectedState], - ) -> str: - """Sync tool with injected state. - - Args: - value: An integer value. - state: The graph state (injected). - """ - return f"value={value}" - - tool_node = _ToolNode([my_tool]) - - # Test sync invocation - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "hi?", - tool_calls=[ - { - "name": "my_tool", - "args": {"value": "not_an_int"}, - "id": "call_1", - "type": "tool_call", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][0] - assert tool_message.status == "error" - assert "value" in tool_message.content - assert "state" not in tool_message.content.lower() - - -@pytest.mark.skipif( - sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" -) -async def test_create_agent_error_content_with_multiple_params() -> None: - """Test that error messages only include LLM-controlled parameter errors. - - Uses create_agent to verify that when a tool with both LLM-controlled - and system-injected parameters receives invalid arguments, the error message: - 1. Contains details about LLM-controlled parameter errors (query, limit) - 2. Does NOT contain system-injected parameter names (state, store, runtime) - 3. Does NOT contain values from system-injected parameters - 4. Properly formats the validation errors for LLM correction - - This ensures the LLM receives focused, actionable feedback. - """ - - class TestState(AgentState): - user_id: str - api_key: str - session_data: dict - - @dec_tool - def complex_tool( - query: str, - limit: int, - state: Annotated[TestState, InjectedState], - store: Annotated[BaseStore, InjectedStore()], - runtime: ToolRuntime, - ) -> str: - """A complex tool with multiple injected and non-injected parameters. - - Args: - query: The search query string. - limit: Maximum number of results to return. - state: The graph state (injected). - store: The persistent store (injected). - runtime: The tool runtime context (injected). 
- """ - # Access injected params to verify they work in normal execution - user = state.get("user_id", "unknown") - return f"Results for '{query}' (limit={limit}, user={user})" - - # Create a model that makes an incorrect tool call with multiple errors: - # - query is wrong type (int instead of str) - # - limit is missing - # Then returns no tool calls to end the loop - model = FakeToolCallingModel( - tool_calls=[ - [ - { - "name": "complex_tool", - "args": { - "query": 12345, # Wrong type - should be str - # "limit" is missing - required field - }, - "id": "call_complex_1", - } - ], - [], # No tool calls on second iteration to end the loop - ] - ) - - # Create an agent with the complex tool and custom state - # Need to provide a store since the tool uses InjectedStore - agent = create_agent( - model=model, - tools=[complex_tool], - state_schema=TestState, - store=InMemoryStore(), - ) - - # Invoke with sensitive data in state - result = agent.invoke( - { - "messages": [HumanMessage("Search for something")], - "user_id": "user_12345", - "api_key": "sk-secret-key-abc123xyz", - "session_data": {"token": "secret_session_token"}, - } - ) - - # Find the tool error message - tool_messages = [m for m in result["messages"] if m.type == "tool"] - assert len(tool_messages) == 1 - tool_message = tool_messages[0] - assert tool_message.status == "error" - assert tool_message.tool_call_id == "call_complex_1" - - content = tool_message.content - - # Verify error mentions LLM-controlled parameter issues - assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)" - assert "limit" in content.lower(), "Error should mention 'limit' (LLM-controlled)" - - # Should indicate validation errors occurred - assert "validation error" in content.lower() or "error" in content.lower(), ( - "Error should indicate validation occurred" - ) - - # Verify NO system-injected parameter names appear in error - # These are not controlled by the LLM and should be excluded - assert "state" not in content.lower(), "Error should NOT mention 'state' (system-injected)" - assert "store" not in content.lower(), "Error should NOT mention 'store' (system-injected)" - assert "runtime" not in content.lower(), "Error should NOT mention 'runtime' (system-injected)" - - # Verify NO values from system-injected parameters appear in error - # The LLM doesn't control these, so they shouldn't distract from the actual issues - assert "user_12345" not in content, "Error should NOT contain user_id value (from state)" - assert "sk-secret-key" not in content, "Error should NOT contain api_key value (from state)" - assert "secret_session_token" not in content, ( - "Error should NOT contain session_data value (from state)" - ) - - # Verify the LLM's original tool call args are present - # The error should show what the LLM actually provided to help it correct the mistake - assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)" - - # Check error is well-formatted - assert "complex_tool" in content, "Error should mention the tool name" - - -@pytest.mark.skipif( - sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" -) -async def test_create_agent_error_only_model_controllable_params() -> None: - """Test that errors only include LLM-controllable parameter issues. - - Focused test ensuring that validation errors for LLM-controlled parameters - are clearly reported, while system-injected parameters remain completely - absent from error messages. 
This provides focused feedback to the LLM. - """ - - class StateWithSecrets(AgentState): - password: str # Example of data not controlled by LLM - - @dec_tool - def secure_tool( - username: str, - email: str, - state: Annotated[StateWithSecrets, InjectedState], - ) -> str: - """Tool that validates user credentials. - - Args: - username: The username (3-20 chars). - email: The email address. - state: State with password (system-injected). - """ - return f"Validated {username} with email {email}" - - # LLM provides invalid username (too short) and invalid email - model = FakeToolCallingModel( - tool_calls=[ - [ - { - "name": "secure_tool", - "args": { - "username": "ab", # Too short (needs 3-20) - "email": "not-an-email", # Invalid format - }, - "id": "call_secure_1", - } - ], - [], - ] - ) - - agent = create_agent( - model=model, - tools=[secure_tool], - state_schema=StateWithSecrets, - ) - - result = agent.invoke( - { - "messages": [HumanMessage("Create account")], - "password": "super_secret_password_12345", - } - ) - - tool_messages = [m for m in result["messages"] if m.type == "tool"] - assert len(tool_messages) == 1 - content = tool_messages[0].content - - # The error should mention LLM-controlled parameters - # Note: Pydantic's default validation may or may not catch format issues, - # but the parameters themselves should be present in error messages - assert "username" in content.lower() or "email" in content.lower(), ( - "Error should mention at least one LLM-controlled parameter" - ) - - # Password is system-injected and should not appear - # The LLM doesn't control it, so it shouldn't distract from the actual errors - assert "password" not in content.lower(), ( - "Error should NOT mention 'password' (system-injected parameter)" - ) - assert "super_secret_password" not in content, ( - "Error should NOT contain password value (from system-injected state)" - ) diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py deleted file mode 100644 index fca75964f08bf..0000000000000 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ /dev/null @@ -1,1287 +0,0 @@ -"""Unit tests for tool call interceptor in ToolNode.""" - -from collections.abc import Callable -from unittest.mock import Mock - -import pytest -from langchain_core.messages import AIMessage, ToolCall, ToolMessage -from langchain_core.runnables import RunnableConfig -from langchain_core.tools import tool -from langgraph.store.base import BaseStore -from langgraph.types import Command - -from langchain.tools.tool_node import ( - ToolCallRequest, - _ToolNode, -) - -pytestmark = pytest.mark.anyio - - -def _create_mock_runtime(store: BaseStore | None = None) -> Mock: - mock_runtime = Mock() - mock_runtime.store = store - mock_runtime.context = None - mock_runtime.stream_writer = lambda _: None - return mock_runtime - - -def _create_config_with_runtime(store: BaseStore | None = None) -> RunnableConfig: - return {"configurable": {"__pregel_runtime": _create_mock_runtime(store)}} - - -@tool -def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - -@tool -def failing_tool(a: int) -> int: - """A tool that always fails.""" - msg = f"This tool always fails (input: {a})" - raise ValueError(msg) - - -@tool -def command_tool(goto: str) -> Command: - """A tool that returns a Command.""" - return Command(goto=goto) - - -def test_passthrough_handler() -> None: - """Test a simple passthrough handler that doesn't modify 
anything.""" - - def passthrough_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Simple passthrough handler.""" - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=passthrough_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.content == "3" - assert tool_message.tool_call_id == "call_1" - assert tool_message.status != "error" - - -async def test_passthrough_handler_async() -> None: - """Test passthrough handler with async tool.""" - - def passthrough_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Simple passthrough handler.""" - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=passthrough_handler) - - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 2, "b": 3}, - "id": "call_2", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.content == "5" - assert tool_message.tool_call_id == "call_2" - - -def test_modify_arguments() -> None: - """Test handler that modifies tool arguments before execution.""" - - def modify_args_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that doubles the input arguments.""" - # Modify the arguments - request.tool_call["args"]["a"] *= 2 - request.tool_call["args"]["b"] *= 2 - - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=modify_args_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_3", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - # Original args were (1, 2), doubled to (2, 4), so result is 6 - assert tool_message.content == "6" - - -def test_handler_validation_no_return() -> None: - """Test that handler must return a result.""" - - def handler_with_explicit_none( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that executes and returns result.""" - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=handler_with_explicit_none) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_6", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - assert isinstance(result, dict) - messages = result["messages"] - assert len(messages) == 1 - assert isinstance(messages[0], ToolMessage) - assert messages[0].content == "3" - - -def test_handler_validation_no_yield() -> None: - """Test that handler that doesn't call execute returns None (bad behavior).""" - - def bad_handler( - _request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that 
doesn't call execute - will cause type error.""" - # Don't call execute, just return None (invalid) - return None # type: ignore[return-value] - - tool_node = _ToolNode([add], wrap_tool_call=bad_handler) - - # This will return None wrapped in messages - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_7", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Result contains None in messages (bad handler behavior) - assert isinstance(result, dict) - assert result["messages"][0] is None - - -def test_handler_with_handle_tool_errors_true() -> None: - """Test that handle_tool_errors=True works with on_tool_call handler.""" - - def passthrough_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Simple passthrough handler.""" - message = execute(request) - # When handle_tool_errors=True, errors should be converted to error messages - assert isinstance(message, ToolMessage) - assert message.status == "error" - return message - - tool_node = _ToolNode( - [failing_tool], wrap_tool_call=passthrough_handler, handle_tool_errors=True - ) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "failing", - tool_calls=[ - { - "name": "failing_tool", - "args": {"a": 1}, - "id": "call_9", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.status == "error" - - -def test_multiple_tool_calls_with_handler() -> None: - """Test handler with multiple tool calls in one message.""" - call_count = 0 - - def counting_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that counts calls.""" - nonlocal call_count - call_count += 1 - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=counting_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding multiple", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_10", - }, - { - "name": "add", - "args": {"a": 3, "b": 4}, - "id": "call_11", - }, - { - "name": "add", - "args": {"a": 5, "b": 6}, - "id": "call_12", - }, - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Handler should be called once for each tool call - assert call_count == 3 - - # Verify all results - messages = result["messages"] - assert len(messages) == 3 - assert all(isinstance(m, ToolMessage) for m in messages) - assert messages[0].content == "3" - assert messages[1].content == "7" - assert messages[2].content == "11" - - -def test_tool_call_request_dataclass() -> None: - """Test ToolCallRequest dataclass.""" - tool_call: ToolCall = {"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1"} - state: dict = {"messages": []} - runtime = None - - request = ToolCallRequest(tool_call=tool_call, tool=add, state=state, runtime=runtime) # type: ignore[arg-type] - - assert request.tool_call == tool_call - assert request.tool == add - assert request.state == state - assert request.runtime is None - assert request.tool_call["name"] == "add" - - -async def test_handler_with_async_execution() -> None: - """Test handler works correctly with async tool execution.""" - - @tool - def async_add(a: int, b: int) -> int: - """Async add two numbers.""" - return a + b - - def modifying_handler( - request: 
ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that modifies arguments.""" - # Add 10 to both arguments - request.tool_call["args"]["a"] += 10 - request.tool_call["args"]["b"] += 10 - return execute(request) - - tool_node = _ToolNode([async_add], wrap_tool_call=modifying_handler) - - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "async_add", - "args": {"a": 1, "b": 2}, - "id": "call_13", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - # Original: 1 + 2 = 3, with modifications: 11 + 12 = 23 - assert tool_message.content == "23" - - -def test_short_circuit_with_tool_message() -> None: - """Test handler that returns ToolMessage to short-circuit tool execution.""" - - def short_circuit_handler( - request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns cached result without executing tool.""" - # Return a ToolMessage directly instead of calling execute - return ToolMessage( - content="cached_result", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - - tool_node = _ToolNode([add], wrap_tool_call=short_circuit_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_16", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.content == "cached_result" - assert tool_message.tool_call_id == "call_16" - assert tool_message.name == "add" - - -async def test_short_circuit_with_tool_message_async() -> None: - """Test async handler that returns ToolMessage to short-circuit tool execution.""" - - def short_circuit_handler( - request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns cached result without executing tool.""" - return ToolMessage( - content="async_cached_result", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - - tool_node = _ToolNode([add], wrap_tool_call=short_circuit_handler) - - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 2, "b": 3}, - "id": "call_17", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.content == "async_cached_result" - assert tool_message.tool_call_id == "call_17" - - -def test_conditional_short_circuit() -> None: - """Test handler that conditionally short-circuits based on request.""" - call_count = {"count": 0} - - def conditional_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that caches even numbers, executes odd.""" - call_count["count"] += 1 - a = request.tool_call["args"]["a"] - - if a % 2 == 0: - # Even: use cached result - return ToolMessage( - content=f"cached_{a}", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - # Odd: execute normally - return execute(request) - - tool_node = _ToolNode([add], 
wrap_tool_call=conditional_handler) - - # Test with even number (should be cached) - result1 = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 2, "b": 3}, - "id": "call_18", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message1 = result1["messages"][-1] - assert tool_message1.content == "cached_2" - - # Test with odd number (should execute) - result2 = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 3, "b": 4}, - "id": "call_19", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message2 = result2["messages"][-1] - assert tool_message2.content == "7" # Actual execution: 3 + 4 - - -def test_direct_return_tool_message() -> None: - """Test handler that returns ToolMessage directly without calling execute.""" - - def direct_return_handler( - request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns ToolMessage directly.""" - # Return ToolMessage directly instead of calling execute - return ToolMessage( - content="direct_return", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - - tool_node = _ToolNode([add], wrap_tool_call=direct_return_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_21", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.content == "direct_return" - assert tool_message.tool_call_id == "call_21" - assert tool_message.name == "add" - - -async def test_direct_return_tool_message_async() -> None: - """Test async handler that returns ToolMessage directly without calling execute.""" - - def direct_return_handler( - request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns ToolMessage directly.""" - return ToolMessage( - content="async_direct_return", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - - tool_node = _ToolNode([add], wrap_tool_call=direct_return_handler) - - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 2, "b": 3}, - "id": "call_22", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message = result["messages"][-1] - assert isinstance(tool_message, ToolMessage) - assert tool_message.content == "async_direct_return" - assert tool_message.tool_call_id == "call_22" - - -def test_conditional_direct_return() -> None: - """Test handler that conditionally returns ToolMessage directly or executes tool.""" - - def conditional_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns cached or executes based on condition.""" - a = request.tool_call["args"]["a"] - - if a == 0: - # Return ToolMessage directly for zero - return ToolMessage( - content="zero_cached", - tool_call_id=request.tool_call["id"], - name=request.tool_call["name"], - ) - # Execute tool normally - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=conditional_handler) - - # Test with zero (should return directly) - result1 = 
tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 0, "b": 5}, - "id": "call_23", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message1 = result1["messages"][-1] - assert tool_message1.content == "zero_cached" - - # Test with non-zero (should execute) - result2 = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 3, "b": 4}, - "id": "call_24", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - tool_message2 = result2["messages"][-1] - assert tool_message2.content == "7" # Actual execution: 3 + 4 - - -def test_handler_can_throw_exception() -> None: - """Test that a handler can throw an exception to signal error.""" - - def throwing_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that throws an exception after receiving response.""" - response = execute(request) - # Check response and throw if invalid - if isinstance(response, ToolMessage): - msg = "Handler rejected the response" - raise TypeError(msg) - return response - - tool_node = _ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=True) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_exc_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should get error message due to handle_tool_errors=True - messages = result["messages"] - assert len(messages) == 1 - assert isinstance(messages[0], ToolMessage) - assert messages[0].status == "error" - assert "Handler rejected the response" in messages[0].content - - -def test_handler_throw_without_handle_errors() -> None: - """Test that exception propagates when handle_tool_errors=False.""" - - def throwing_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that throws an exception.""" - execute(request) - msg = "Handler error" - raise ValueError(msg) - - tool_node = _ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=False) - - with pytest.raises(ValueError, match="Handler error"): - tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_exc_2", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - -def test_retry_middleware_with_exception() -> None: - """Test retry middleware pattern that can call execute multiple times.""" - attempt_count = {"count": 0} - - def retry_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that can retry by calling execute multiple times.""" - max_retries = 3 - - for _attempt in range(max_retries): - attempt_count["count"] += 1 - response = execute(request) - - # Simulate checking for retriable errors - # In real use case, would check response.status or content - if isinstance(response, ToolMessage): - # For this test, just succeed immediately - return response - - # If we exhausted retries, return last response - return response - - tool_node = _ToolNode([add], wrap_tool_call=retry_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_exc_3", - } - ], - 
) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should succeed after 1 attempt - assert attempt_count["count"] == 1 - messages = result["messages"] - assert len(messages) == 1 - assert isinstance(messages[0], ToolMessage) - assert messages[0].content == "3" - - -async def test_async_handler_can_throw_exception() -> None: - """Test that async execution also supports exception throwing.""" - - def throwing_handler( - _request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that throws an exception before calling execute.""" - # Throw exception before executing (to avoid async/await complications) - msg = "Async handler rejected the request" - raise ValueError(msg) - - tool_node = _ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=True) - - result = await tool_node.ainvoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_exc_4", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should get error message due to handle_tool_errors=True - messages = result["messages"] - assert len(messages) == 1 - assert isinstance(messages[0], ToolMessage) - assert messages[0].status == "error" - assert "Async handler rejected the request" in messages[0].content - - -def test_handler_cannot_yield_multiple_tool_messages() -> None: - """Test that handler can only return once (not applicable to handler pattern).""" - # With handler pattern, you can only return once by definition - # This test is no longer relevant - handlers naturally return once - # Keep test for compatibility but with simple passthrough - - def single_return_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns once (as all handlers do).""" - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=single_return_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_multi_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should succeed - handlers can only return once - assert isinstance(result, dict) - assert len(result["messages"]) == 1 - - -def test_handler_cannot_yield_request_after_tool_message() -> None: - """Test that handler pattern doesn't allow multiple returns (not applicable).""" - # With handler pattern, you can only return once - # This test is no longer relevant - - def single_return_handler( - request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns cached result.""" - # Return cached result (short-circuit) - return ToolMessage("cached", tool_call_id=request.tool_call["id"], name="add") - - tool_node = _ToolNode([add], wrap_tool_call=single_return_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_confused_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should succeed with cached result - assert isinstance(result, dict) - assert result["messages"][0].content == "cached" - - -def test_handler_can_short_circuit_with_command() -> None: - """Test that handler can short-circuit by returning Command.""" - - def command_handler( - _request: ToolCallRequest, - 
_execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that short-circuits with Command.""" - # Short-circuit with Command instead of executing tool - return Command(goto="end") - - tool_node = _ToolNode([add], wrap_tool_call=command_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_cmd_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should get Command in result list - assert isinstance(result, list) - assert len(result) == 1 - assert isinstance(result[0], Command) - assert result[0].goto == "end" - - -def test_handler_cannot_yield_multiple_commands() -> None: - """Test that handler can only return once (not applicable to handler pattern).""" - # With handler pattern, you can only return once - # This test is no longer relevant - - def single_command_handler( - _request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns Command once.""" - return Command(goto="step1") - - tool_node = _ToolNode([add], wrap_tool_call=single_command_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_multicmd_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should succeed - handlers naturally return once - assert isinstance(result, list) - assert len(result) == 1 - assert isinstance(result[0], Command) - assert result[0].goto == "step1" - - -def test_handler_cannot_yield_request_after_command() -> None: - """Test that handler can only return once (not applicable to handler pattern).""" - # With handler pattern, you can only return once - # This test is no longer relevant - - def command_handler( - _request: ToolCallRequest, - _execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that returns Command.""" - return Command(goto="somewhere") - - tool_node = _ToolNode([add], wrap_tool_call=command_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "adding", - tool_calls=[ - { - "name": "add", - "args": {"a": 1, "b": 2}, - "id": "call_cmdreq_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Should succeed with Command - assert isinstance(result, list) - assert len(result) == 1 - assert isinstance(result[0], Command) - assert result[0].goto == "somewhere" - - -def test_tool_returning_command_sent_to_handler() -> None: - """Test that when tool returns Command, it's sent to handler.""" - received_commands = [] - - def command_inspector_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that inspects Command returned by tool.""" - result = execute(request) - # Should receive Command from tool - if isinstance(result, Command): - received_commands.append(result) - return result - - tool_node = _ToolNode([command_tool], wrap_tool_call=command_inspector_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "navigating", - tool_calls=[ - { - "name": "command_tool", - "args": {"goto": "next_step"}, - "id": "call_cmdtool_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Handler should have received the Command - assert len(received_commands) == 1 - assert received_commands[0].goto == 
"next_step" - - # Final result should be the Command in result list - assert isinstance(result, list) - assert len(result) == 1 - assert isinstance(result[0], Command) - assert result[0].goto == "next_step" - - -def test_handler_can_modify_command_from_tool() -> None: - """Test that handler can inspect and modify Command from tool.""" - - def command_modifier_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that modifies Command returned by tool.""" - result = execute(request) - # Modify the Command - if isinstance(result, Command): - return Command(goto=f"modified_{result.goto}") - return result - - tool_node = _ToolNode([command_tool], wrap_tool_call=command_modifier_handler) - - result = tool_node.invoke( - { - "messages": [ - AIMessage( - "navigating", - tool_calls=[ - { - "name": "command_tool", - "args": {"goto": "original"}, - "id": "call_cmdmod_1", - } - ], - ) - ] - }, - config=_create_config_with_runtime(), - ) - - # Final result should be the modified Command in result list - assert isinstance(result, list) - assert len(result) == 1 - assert isinstance(result[0], Command) - assert result[0].goto == "modified_original" - - -def test_state_extraction_with_dict_input() -> None: - """Test that state is correctly passed when input is a dict.""" - state_seen = [] - - def state_inspector_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that records the state it receives.""" - state_seen.append(request.state) - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=state_inspector_handler) - - input_state = { - "messages": [ - AIMessage( - "test", - tool_calls=[{"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1"}], - ) - ], - "other_field": "value", - } - - tool_node.invoke(input_state, config=_create_config_with_runtime()) - - # State should be the dict we passed in - assert len(state_seen) == 1 - assert state_seen[0] == input_state - assert isinstance(state_seen[0], dict) - assert "messages" in state_seen[0] - assert "other_field" in state_seen[0] - assert "__type" not in state_seen[0] - - -def test_state_extraction_with_list_input() -> None: - """Test that state is correctly passed when input is a list.""" - state_seen = [] - - def state_inspector_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that records the state it receives.""" - state_seen.append(request.state) - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=state_inspector_handler) - - input_state = [ - AIMessage( - "test", - tool_calls=[{"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1"}], - ) - ] - - tool_node.invoke(input_state, config=_create_config_with_runtime()) - - # State should be the list we passed in - assert len(state_seen) == 1 - assert state_seen[0] == input_state - assert isinstance(state_seen[0], list) - - -def test_state_extraction_with_tool_call_with_context() -> None: - """Test that state is correctly extracted from ToolCallWithContext. - - This tests the scenario where ToolNode is invoked via the Send API in - create_agent, which wraps the tool call with additional context including - the graph state. 
- """ - state_seen = [] - - def state_inspector_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that records the state it receives.""" - state_seen.append(request.state) - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=state_inspector_handler) - - # Simulate ToolCallWithContext as used by create_agent with Send API - actual_state = { - "messages": [AIMessage("test")], - "thread_model_call_count": 1, - "run_model_call_count": 1, - "custom_field": "custom_value", - } - - tool_call_with_context = { - "__type": "tool_call_with_context", - "tool_call": {"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1", "type": "tool_call"}, - "state": actual_state, - } - - tool_node.invoke(tool_call_with_context, config=_create_config_with_runtime()) - - # State should be the extracted state from ToolCallWithContext, not the wrapper - assert len(state_seen) == 1 - assert state_seen[0] == actual_state - assert isinstance(state_seen[0], dict) - assert "messages" in state_seen[0] - assert "thread_model_call_count" in state_seen[0] - assert "custom_field" in state_seen[0] - # Most importantly, __type should NOT be in the extracted state - assert "__type" not in state_seen[0] - # And tool_call should not be in the state - assert "tool_call" not in state_seen[0] - - -async def test_state_extraction_with_tool_call_with_context_async() -> None: - """Test that state is correctly extracted from ToolCallWithContext in async mode.""" - state_seen = [] - - def state_inspector_handler( - request: ToolCallRequest, - execute: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Handler that records the state it receives.""" - state_seen.append(request.state) - return execute(request) - - tool_node = _ToolNode([add], wrap_tool_call=state_inspector_handler) - - # Simulate ToolCallWithContext as used by create_agent with Send API - actual_state = { - "messages": [AIMessage("test")], - "thread_model_call_count": 1, - "run_model_call_count": 1, - } - - tool_call_with_context = { - "__type": "tool_call_with_context", - "tool_call": {"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1", "type": "tool_call"}, - "state": actual_state, - } - - await tool_node.ainvoke(tool_call_with_context, config=_create_config_with_runtime()) - - # State should be the extracted state from ToolCallWithContext - assert len(state_seen) == 1 - assert state_seen[0] == actual_state - assert "__type" not in state_seen[0] - assert "tool_call" not in state_seen[0] diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index b9bc5cb1741aa..91d5435922348 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -1653,7 +1653,7 @@ requires-dist = [ { name = "langchain-perplexity", marker = "extra == 'perplexity'" }, { name = "langchain-together", marker = "extra == 'together'" }, { name = "langchain-xai", marker = "extra == 'xai'" }, - { name = "langgraph", specifier = ">=1.0.0,<1.1.0" }, + { name = "langgraph", specifier = ">=1.0.2,<1.1.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, ] provides-extras = ["community", "anthropic", "openai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", 
"perplexity"] @@ -2123,7 +2123,7 @@ wheels = [ [[package]] name = "langgraph" -version = "1.0.0" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, @@ -2133,35 +2133,35 @@ dependencies = [ { name = "pydantic" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/57/f7/7ae10f1832ab1a6a402f451e54d6dab277e28e7d4e4204e070c7897ca71c/langgraph-1.0.0.tar.gz", hash = "sha256:5f83ed0e9bbcc37635bc49cbc9b3d9306605fa07504f955b7a871ed715f9964c", size = 472835, upload-time = "2025-10-17T20:23:38.263Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/25/18e6e056ee1a8af64fcab441b4a3f2e158399935b08f148c7718fc42ecdb/langgraph-1.0.2.tar.gz", hash = "sha256:dae1af08d6025cb1fcaed68f502c01af7d634d9044787c853a46c791cfc52f67", size = 482660, upload-time = "2025-10-29T18:38:28.374Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/07/42/6f6d0fe4eb661b06da8e6c59e58044e9e4221fdbffdcacae864557de961e/langgraph-1.0.0-py3-none-any.whl", hash = "sha256:4d478781832a1bc67e06c3eb571412ec47d7c57a5467d1f3775adf0e9dd4042c", size = 155416, upload-time = "2025-10-17T20:23:36.978Z" }, + { url = "https://files.pythonhosted.org/packages/d7/b1/9f4912e13d4ed691f2685c8a4b764b5a9237a30cca0c5782bc213d9f0a9a/langgraph-1.0.2-py3-none-any.whl", hash = "sha256:b3d56b8c01de857b5fb1da107e8eab6e30512a377685eeedb4f76456724c9729", size = 156751, upload-time = "2025-10-29T18:38:26.577Z" }, ] [[package]] name = "langgraph-checkpoint" -version = "2.1.1" +version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, { name = "ormsgpack" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/73/3e/d00eb2b56c3846a0cabd2e5aa71c17a95f882d4f799a6ffe96a19b55eba9/langgraph_checkpoint-2.1.1.tar.gz", hash = "sha256:72038c0f9e22260cb9bff1f3ebe5eb06d940b7ee5c1e4765019269d4f21cf92d", size = 136256, upload-time = "2025-07-17T13:07:52.411Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/cb/2a6dad2f0a14317580cc122e2a60e7f0ecabb50aaa6dc5b7a6a2c94cead7/langgraph_checkpoint-3.0.0.tar.gz", hash = "sha256:f738695ad938878d8f4775d907d9629e9fcd345b1950196effb08f088c52369e", size = 132132, upload-time = "2025-10-20T18:35:49.132Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/dd/64686797b0927fb18b290044be12ae9d4df01670dce6bb2498d5ab65cb24/langgraph_checkpoint-2.1.1-py3-none-any.whl", hash = "sha256:5a779134fd28134a9a83d078be4450bbf0e0c79fdf5e992549658899e6fc5ea7", size = 43925, upload-time = "2025-07-17T13:07:51.023Z" }, + { url = "https://files.pythonhosted.org/packages/85/2a/2efe0b5a72c41e3a936c81c5f5d8693987a1b260287ff1bbebaae1b7b888/langgraph_checkpoint-3.0.0-py3-none-any.whl", hash = "sha256:560beb83e629784ab689212a3d60834fb3196b4bbe1d6ac18e5cad5d85d46010", size = 46060, upload-time = "2025-10-20T18:35:48.255Z" }, ] [[package]] name = "langgraph-prebuilt" -version = "1.0.0" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, { name = "langgraph-checkpoint" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/2d/934b1129e217216a0dfaf0f7df0a10cedf2dfafe6cc8e1ee238cafaaa4a7/langgraph_prebuilt-1.0.0.tar.gz", hash = "sha256:eb75dad9aca0137451ca0395aa8541a665b3f60979480b0431d626fd195dcda2", size = 119927, upload-time = "2025-10-17T20:15:21.429Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/33/2f/b940590436e07b3450fe6d791aad5e581363ad536c4f1771e3ba46530268/langgraph_prebuilt-1.0.2.tar.gz", hash = "sha256:9896dbabf04f086eb59df4294f54ab5bdb21cd78e27e0a10e695dffd1cc6097d", size = 142075, upload-time = "2025-10-29T18:29:00.401Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/33/2e/ffa698eedc4c355168a9207ee598b2cc74ede92ce2b55c3469ea06978b6e/langgraph_prebuilt-1.0.0-py3-none-any.whl", hash = "sha256:ceaae4c5cee8c1f9b6468f76c114cafebb748aed0c93483b7c450e5a89de9c61", size = 28455, upload-time = "2025-10-17T20:15:20.043Z" }, + { url = "https://files.pythonhosted.org/packages/27/2f/9a7d00d4afa036e65294059c7c912002fb72ba5dbbd5c2a871ca06360278/langgraph_prebuilt-1.0.2-py3-none-any.whl", hash = "sha256:d9499f7c449fb637ee7b87e3f6a3b74095f4202053c74d33894bd839ea4c57c7", size = 34286, upload-time = "2025-10-29T18:28:59.26Z" }, ] [[package]] diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py b/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py index 8d2ff27520c97..b72dd1a62c854 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Sequence - from langchain.tools.tool_node import ToolCallRequest + from langchain.agents.middleware.types import ToolCallRequest # Tool type constants TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728" diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py index 191a36234923a..61184b1a037a4 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py @@ -6,8 +6,11 @@ from typing import Any, Literal from langchain.agents.middleware.shell_tool import ShellToolMiddleware -from langchain.agents.middleware.types import ModelRequest, ModelResponse -from langchain.tools.tool_node import ToolCallRequest +from langchain.agents.middleware.types import ( + ModelRequest, + ModelResponse, + ToolCallRequest, +) from langchain_core.messages import ToolMessage from langgraph.types import Command diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py index 9a1c04cbad5d0..ca47a241dfafb 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py @@ -9,7 +9,7 @@ "anthropic", reason="Anthropic SDK is required for Claude middleware tests" ) -from langchain.tools.tool_node import ToolCallRequest +from langchain.agents.middleware.types import ToolCallRequest from langchain_core.messages import ToolMessage from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware diff --git a/libs/partners/anthropic/uv.lock b/libs/partners/anthropic/uv.lock index 25e50c6823710..22499e5b6562e 100644 --- a/libs/partners/anthropic/uv.lock +++ b/libs/partners/anthropic/uv.lock @@ -274,7 +274,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -477,7 +477,7 @@ wheels = [ [[package]] name = "langchain" -version = "1.0.0rc2" +version = "1.0.2" source = { editable = "../../langchain_v1" } dependencies = [ { name = "langchain-core" }, @@ -496,16 +496,17 @@ requires-dist = [ { name = "langchain-google-genai", marker = "extra == 'google-genai'" }, { name = "langchain-google-vertexai", marker = "extra == 'google-vertexai'" }, { name = "langchain-groq", marker = "extra == 'groq'" }, + { name = "langchain-huggingface", marker = "extra == 'huggingface'" }, { name = "langchain-mistralai", marker = "extra == 'mistralai'" }, { name = "langchain-ollama", marker = "extra == 'ollama'" }, { name = "langchain-openai", marker = "extra == 'openai'", editable = "../openai" }, { name = "langchain-perplexity", marker = "extra == 'perplexity'" }, { name = "langchain-together", marker = "extra == 'together'" }, { name = "langchain-xai", marker = "extra == 'xai'" }, - { name = "langgraph", specifier = ">=1.0.0a4,<2.0.0" }, + { name = "langgraph", specifier = ">=1.0.2,<1.1.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, ] -provides-extras = ["community", "anthropic", "openai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "groq", "aws", "deepseek", "xai", "perplexity"] +provides-extras = ["community", "anthropic", "openai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"] [package.metadata.requires-dev] lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13.0" }] @@ -621,7 +622,7 @@ typing = [ [[package]] name = "langchain-core" -version = "1.0.0" +version = "1.0.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -724,7 +725,7 @@ typing = [ [[package]] name = "langgraph" -version = "1.0.0rc1" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, @@ -734,35 +735,35 @@ dependencies = [ { name = "pydantic" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/6b/27863bc2197fe2cae5ae3241a79e1868b98f8dfa03991eea4f607dba177a/langgraph-1.0.0rc1.tar.gz", hash = "sha256:0acc0eddbed6b353334a93de6943bb49820054cf14e1ca7dab0a91ac7add1ce2", size = 466052, upload-time = "2025-10-17T00:56:12.222Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/25/18e6e056ee1a8af64fcab441b4a3f2e158399935b08f148c7718fc42ecdb/langgraph-1.0.2.tar.gz", hash = "sha256:dae1af08d6025cb1fcaed68f502c01af7d634d9044787c853a46c791cfc52f67", size = 482660, upload-time = "2025-10-29T18:38:28.374Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/5b/7d4aaf30f8c08d14ec16a49e89530c05ab9ccd7c00a54da6bd8adabefd26/langgraph-1.0.0rc1-py3-none-any.whl", hash = "sha256:9d84da21ae8bcc5b05dfa2e63396eb642d39c54d670406b7319810ede0c5ab26", size = 155229, upload-time = "2025-10-17T00:56:10.956Z" }, + { url = "https://files.pythonhosted.org/packages/d7/b1/9f4912e13d4ed691f2685c8a4b764b5a9237a30cca0c5782bc213d9f0a9a/langgraph-1.0.2-py3-none-any.whl", hash = "sha256:b3d56b8c01de857b5fb1da107e8eab6e30512a377685eeedb4f76456724c9729", size = 156751, upload-time = "2025-10-29T18:38:26.577Z" }, ] [[package]] name = "langgraph-checkpoint" 
-version = "2.1.2" +version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, { name = "ormsgpack" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/29/83/6404f6ed23a91d7bc63d7df902d144548434237d017820ceaa8d014035f2/langgraph_checkpoint-2.1.2.tar.gz", hash = "sha256:112e9d067a6eff8937caf198421b1ffba8d9207193f14ac6f89930c1260c06f9", size = 142420, upload-time = "2025-10-07T17:45:17.129Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/cb/2a6dad2f0a14317580cc122e2a60e7f0ecabb50aaa6dc5b7a6a2c94cead7/langgraph_checkpoint-3.0.0.tar.gz", hash = "sha256:f738695ad938878d8f4775d907d9629e9fcd345b1950196effb08f088c52369e", size = 132132, upload-time = "2025-10-20T18:35:49.132Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/f2/06bf5addf8ee664291e1b9ffa1f28fc9d97e59806dc7de5aea9844cbf335/langgraph_checkpoint-2.1.2-py3-none-any.whl", hash = "sha256:911ebffb069fd01775d4b5184c04aaafc2962fcdf50cf49d524cd4367c4d0c60", size = 45763, upload-time = "2025-10-07T17:45:16.19Z" }, + { url = "https://files.pythonhosted.org/packages/85/2a/2efe0b5a72c41e3a936c81c5f5d8693987a1b260287ff1bbebaae1b7b888/langgraph_checkpoint-3.0.0-py3-none-any.whl", hash = "sha256:560beb83e629784ab689212a3d60834fb3196b4bbe1d6ac18e5cad5d85d46010", size = 46060, upload-time = "2025-10-20T18:35:48.255Z" }, ] [[package]] name = "langgraph-prebuilt" -version = "0.7.0rc1" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, { name = "langgraph-checkpoint" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1d/d3/6474eecd1cc95cead25fe0f1717f0b76e1f6edbc8631fc773f6bf7ac03ad/langgraph_prebuilt-0.7.0rc1.tar.gz", hash = "sha256:23f2c1c0a3f0c643a45f90d99ac951a5e1d1be3e711ae10d91b03e34c05b306b", size = 114860, upload-time = "2025-10-17T00:51:56.719Z" } +sdist = { url = "https://files.pythonhosted.org/packages/33/2f/b940590436e07b3450fe6d791aad5e581363ad536c4f1771e3ba46530268/langgraph_prebuilt-1.0.2.tar.gz", hash = "sha256:9896dbabf04f086eb59df4294f54ab5bdb21cd78e27e0a10e695dffd1cc6097d", size = 142075, upload-time = "2025-10-29T18:29:00.401Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/70/6a46ebd63304028617b1de1377a159dedd95f0ce2e68a6d8b50c3378448c/langgraph_prebuilt-0.7.0rc1-py3-none-any.whl", hash = "sha256:7a2032683f1cab23d19f11f89805bc84b668502c491b6f041afb6932c8870e70", size = 28387, upload-time = "2025-10-17T00:51:55.75Z" }, + { url = "https://files.pythonhosted.org/packages/27/2f/9a7d00d4afa036e65294059c7c912002fb72ba5dbbd5c2a871ca06360278/langgraph_prebuilt-1.0.2-py3-none-any.whl", hash = "sha256:d9499f7c449fb637ee7b87e3f6a3b74095f4202053c74d33894bd839ea4c57c7", size = 34286, upload-time = "2025-10-29T18:28:59.26Z" }, ] [[package]]