From c7fef8324ae4707a7fae13f5ad48843ae29ea3cb Mon Sep 17 00:00:00 2001 From: Lev Neiman Date: Wed, 11 Mar 2026 10:42:58 -0700 Subject: [PATCH 1/5] feat: add langgraph tool node integration --- sdks/python/pyproject.toml | 2 + .../integrations/langgraph/__init__.py | 19 + .../integrations/langgraph/tool_node.py | 450 ++++++++++++++++++ sdks/python/tests/test_langgraph_init.py | 55 +++ sdks/python/tests/test_langgraph_tool_node.py | 320 +++++++++++++ 5 files changed, 846 insertions(+) create mode 100644 sdks/python/src/agent_control/integrations/langgraph/__init__.py create mode 100644 sdks/python/src/agent_control/integrations/langgraph/tool_node.py create mode 100644 sdks/python/tests/test_langgraph_init.py create mode 100644 sdks/python/tests/test_langgraph_tool_node.py diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index c5c1f203..587b1236 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -35,6 +35,7 @@ Documentation = "https://github.com/yourusername/agent-control#readme" Repository = "https://github.com/yourusername/agent-control" [project.optional-dependencies] +langgraph = ["langgraph>=0.2.0"] strands-agents = ["strands-agents>=1.26.0"] galileo = ["agent-control-evaluator-galileo>=3.0.0"] @@ -48,6 +49,7 @@ dev = [ "agent-control-models", "agent-control-engine", "agent-control-evaluators", + "langgraph>=0.2.0", # For langgraph integration tests "strands-agents>=1.26.0", # For strands integration tests ] diff --git a/sdks/python/src/agent_control/integrations/langgraph/__init__.py b/sdks/python/src/agent_control/integrations/langgraph/__init__.py new file mode 100644 index 00000000..84c317d8 --- /dev/null +++ b/sdks/python/src/agent_control/integrations/langgraph/__init__.py @@ -0,0 +1,19 @@ +"""LangGraph integration for Agent Control.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .tool_node import create_controlled_tool_node + +__all__ = ["create_controlled_tool_node"] + + +def __getattr__(name: str) -> Any: + """Lazy import to avoid import errors when langgraph is not installed.""" + if name == "create_controlled_tool_node": + from .tool_node import create_controlled_tool_node + + return create_controlled_tool_node + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/sdks/python/src/agent_control/integrations/langgraph/tool_node.py b/sdks/python/src/agent_control/integrations/langgraph/tool_node.py new file mode 100644 index 00000000..efa2979e --- /dev/null +++ b/sdks/python/src/agent_control/integrations/langgraph/tool_node.py @@ -0,0 +1,450 @@ +"""Agent Control ToolNode integration for LangGraph.""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +import threading +from collections.abc import Awaitable, Callable, Coroutine, Sequence +from dataclasses import dataclass, field +from typing import Any, Literal, TypeVar, cast + +from agent_control_models import Agent, EvaluationResult + +import agent_control +from agent_control import AgentControlClient, agents +from agent_control._state import state as sdk_state +from agent_control.validation import ensure_agent_name + +try: + from langchain_core.messages import ToolMessage + from langchain_core.tools import BaseTool + from langgraph.prebuilt import ToolNode + from langgraph.types import Command +except Exception as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "LangGraph integration requires langgraph. " + "Install with: agent-control-sdk[langgraph]." + ) from exc + +logger = logging.getLogger(__name__) + +T = TypeVar("T") +ToolHandlerResult = ToolMessage | Command[Any] +ToolHandler = Callable[[Any], ToolHandlerResult | Awaitable[ToolHandlerResult]] +_TOOLNODE_DEFAULT = object() + + +def _run_coro_in_new_loop[T](coro: Coroutine[Any, Any, T]) -> T: + """Run a coroutine on a dedicated event loop in the current thread.""" + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + finally: + loop.close() + asyncio.set_event_loop(None) + + +def _run_coro_sync[T](coro: Coroutine[Any, Any, T]) -> T: + """Run an async coroutine from a synchronous context.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + result_container: list[T | None] = [None] + exception_container: list[Exception | None] = [None] + + def run_in_thread() -> None: + try: + result_container[0] = _run_coro_in_new_loop(coro) + except Exception as exc: # pragma: no cover - thread path is hard to force + exception_container[0] = exc + + thread = threading.Thread(target=run_in_thread, daemon=True) + thread.start() + thread.join(timeout=10) + + if exception_container[0] is not None: + raise exception_container[0] + if result_container[0] is None: + raise RuntimeError("Timed out while waiting for LangGraph integration coroutine.") + return result_container[0] + + +def _tool_block_prefix(tool_name: str) -> str: + return f"Tool '{tool_name}' was blocked due to security policy." + + +def _tool_eval_error_prefix(tool_name: str) -> str: + return f"Tool '{tool_name}' was blocked because policy evaluation failed." + + +def _build_error_tool_message( + *, + tool_name: str, + tool_call_id: str, + content: str, +) -> ToolMessage: + """Build an error ToolMessage preserving the original tool identifiers.""" + return ToolMessage( + content=content, + name=tool_name, + tool_call_id=tool_call_id, + status="error", + ) + + +def _append_detail(prefix: str, detail: str | None) -> str: + """Append optional detail to a message prefix.""" + if detail is None: + return prefix + normalized = detail.strip() + if not normalized: + return prefix + return f"{prefix} {normalized}" + + +def _safe_tool_schema(tool: BaseTool, getter_name: str) -> dict[str, Any] | None: + """Extract a JSON schema from a tool getter, logging and falling back to None.""" + getter = getattr(tool, getter_name, None) + if not callable(getter): + return None + + try: + schema = getter() + except Exception as exc: + logger.warning( + "Failed to extract %s for tool '%s': %s", + getter_name, + tool.name, + exc, + ) + return None + + if not isinstance(schema, dict): + logger.warning( + "Ignoring non-dict %s for tool '%s': %r", + getter_name, + tool.name, + schema, + ) + return None + + return schema + + +def _normalize_tool_output(output: ToolHandlerResult) -> Any: + """Normalize LangGraph tool output into evaluation-friendly payloads.""" + if isinstance(output, ToolMessage): + return output.content + if isinstance(output, Command): + command_payload = { + "graph": output.graph, + "update": output.update, + "resume": output.resume, + "goto": output.goto, + } + return {key: value for key, value in command_payload.items() if value is not None} + return str(output) + + +def _first_result_message(result: EvaluationResult) -> str | None: + """Return the first available control result message from matches.""" + for match in result.matches or []: + message = match.result.message + if message: + return message + return result.reason + + +def _result_to_tool_message( + *, + result: EvaluationResult, + tool_name: str, + tool_call_id: str, +) -> ToolMessage | None: + """Convert an evaluation result into a blocking ToolMessage when required.""" + if result.errors: + error_names = ", ".join(error.control_name for error in result.errors) + content = _append_detail(_tool_eval_error_prefix(tool_name), f"Errors: {error_names}") + return _build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + matches = result.matches or [] + deny_match = next((match for match in matches if match.action == "deny"), None) + if deny_match is not None: + content = _append_detail(_tool_block_prefix(tool_name), deny_match.result.message) + return _build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + steer_match = next((match for match in matches if match.action == "steer"), None) + if steer_match is not None: + guidance = None + if steer_match.steering_context is not None: + guidance = steer_match.steering_context.message + if not guidance: + guidance = steer_match.result.message or result.reason + content = _tool_block_prefix(tool_name) + if guidance: + content = f"{content} Guidance: {guidance}" + return _build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + if not result.is_safe: + content = _append_detail(_tool_block_prefix(tool_name), _first_result_message(result)) + return _build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + return None + + +@dataclass +class _LangGraphControlState: + """Holds tool registration and evaluation state for a controlled ToolNode.""" + + tool_node: ToolNode + agent: Agent + observed_tool_names: set[str] = field(default_factory=set) + lock: threading.Lock = field(default_factory=threading.Lock) + + def _current_tool_names(self) -> set[str]: + return set(self.tool_node.tools_by_name) + + def _normalized_steps(self) -> list[dict[str, Any]]: + steps: list[dict[str, Any]] = [] + for tool in self.tool_node.tools_by_name.values(): + description = tool.description.strip() if tool.description else None + step: dict[str, Any] = { + "type": "tool", + "name": tool.name, + } + if description: + step["description"] = description + input_schema = _safe_tool_schema(tool, "get_input_jsonschema") + if input_schema is not None: + step["input_schema"] = input_schema + output_schema = _safe_tool_schema(tool, "get_output_jsonschema") + if output_schema is not None: + step["output_schema"] = output_schema + steps.append(step) + return steps + + async def _register_async(self) -> None: + if sdk_state.server_url is None: + raise RuntimeError( + "Agent Control server URL is not configured. " + "Call agent_control.init()." + ) + + async with AgentControlClient( + base_url=sdk_state.server_url, + api_key=sdk_state.api_key, + ) as client: + await agents.register_agent( + client, + self.agent, + steps=self._normalized_steps(), + conflict_mode="overwrite", + ) + + async def _register_and_refresh_async(self) -> None: + await self._register_async() + await agent_control.refresh_controls_async() + + def _register_and_refresh_sync(self) -> None: + _run_coro_sync(self._register_async()) + agent_control.refresh_controls() + + async def ensure_registered_async(self) -> None: + current_tool_names = self._current_tool_names() + with self.lock: + should_register = current_tool_names != self.observed_tool_names + + if not should_register: + return + + await self._register_and_refresh_async() + with self.lock: + self.observed_tool_names = current_tool_names + + def ensure_registered_sync(self) -> None: + current_tool_names = self._current_tool_names() + with self.lock: + should_register = current_tool_names != self.observed_tool_names + + if not should_register: + return + + self._register_and_refresh_sync() + with self.lock: + self.observed_tool_names = current_tool_names + + async def evaluate_request( + self, + *, + tool_name: str, + tool_call_id: str, + tool_input: dict[str, Any], + stage: Literal["pre", "post"], + output: Any = None, + ) -> ToolMessage | None: + try: + result = await agent_control.evaluate_controls( + step_name=tool_name, + input=tool_input, + output=output, + step_type="tool", + stage=stage, + agent_name=self.agent.agent_name, + ) + except Exception as exc: + logger.error( + "Agent Control evaluation failed for tool '%s' (%s stage): %s", + tool_name, + stage, + exc, + exc_info=True, + ) + return _build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=_append_detail(_tool_eval_error_prefix(tool_name), str(exc)), + ) + + return _result_to_tool_message( + result=result, + tool_name=tool_name, + tool_call_id=tool_call_id, + ) + + async def run_with_controls( + self, + request: Any, + handler: ToolHandler, + ) -> ToolHandlerResult: + tool_name = request.tool_call["name"] + tool_call_id = str(request.tool_call.get("id", tool_name)) + raw_args = request.tool_call.get("args", {}) + tool_input = raw_args if isinstance(raw_args, dict) else {"value": raw_args} + + pre_message = await self.evaluate_request( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_input=tool_input, + stage="pre", + ) + if pre_message is not None: + return pre_message + + response = handler(request) + if inspect.isawaitable(response): + response = await cast(Awaitable[ToolHandlerResult], response) + + post_message = await self.evaluate_request( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_input=tool_input, + stage="post", + output=_normalize_tool_output(cast(ToolHandlerResult, response)), + ) + if post_message is not None: + return post_message + + return cast(ToolHandlerResult, response) + + +def create_controlled_tool_node( + tools: Sequence[BaseTool | Callable[..., Any]], + *, + agent_name: str | None = None, + name: str = "tools", + tags: list[str] | None = None, + handle_tool_errors: Any = _TOOLNODE_DEFAULT, + messages_key: str = "messages", +) -> ToolNode: + """Create a ToolNode configured with Agent Control policy enforcement.""" + current_agent = agent_control.current_agent() + if current_agent is None: + raise RuntimeError("Agent Control is not initialized. Call agent_control.init() first.") + + resolved_agent = current_agent + if agent_name is not None: + resolved_agent = current_agent.model_copy( + update={"agent_name": ensure_agent_name(agent_name)} + ) + + state_ref: dict[str, _LangGraphControlState] = {} + + def wrap_tool_call( + request: Any, + handler: Callable[[Any], ToolHandlerResult], + ) -> ToolHandlerResult: + state = state_ref["state"] + try: + state.ensure_registered_sync() + except Exception as exc: + logger.error( + "LangGraph ToolNode re-registration failed; " + "continuing without blocking tool '%s': %s", + request.tool_call["name"], + exc, + exc_info=True, + ) + return _run_coro_sync(state.run_with_controls(request, handler)) + + async def awrap_tool_call( + request: Any, + handler: Callable[[Any], Awaitable[ToolHandlerResult]], + ) -> ToolHandlerResult: + state = state_ref["state"] + try: + await state.ensure_registered_async() + except Exception as exc: + logger.error( + "LangGraph ToolNode async re-registration failed; " + "continuing without blocking tool '%s': %s", + request.tool_call["name"], + exc, + exc_info=True, + ) + return await state.run_with_controls(request, handler) + + tool_node_kwargs: dict[str, Any] = { + "name": name, + "tags": tags, + "messages_key": messages_key, + "wrap_tool_call": wrap_tool_call, + "awrap_tool_call": awrap_tool_call, + } + if handle_tool_errors is not _TOOLNODE_DEFAULT: + tool_node_kwargs["handle_tool_errors"] = handle_tool_errors + + tool_node = ToolNode(tools, **tool_node_kwargs) + state = _LangGraphControlState(tool_node=tool_node, agent=resolved_agent) + state_ref["state"] = state + + try: + state.ensure_registered_sync() + except Exception as exc: + logger.error( + "Initial LangGraph ToolNode registration failed; continuing without blocking tools: %s", + exc, + exc_info=True, + ) + + return tool_node diff --git a/sdks/python/tests/test_langgraph_init.py b/sdks/python/tests/test_langgraph_init.py new file mode 100644 index 00000000..d08d5897 --- /dev/null +++ b/sdks/python/tests/test_langgraph_init.py @@ -0,0 +1,55 @@ +"""Unit tests for LangGraph integration __init__.py.""" + +from __future__ import annotations + +import sys +from unittest.mock import patch + +import pytest + + +def test_langgraph_init_exports(): + """Test that __init__.py exports the expected factory.""" + from agent_control.integrations.langgraph import create_controlled_tool_node + + assert callable(create_controlled_tool_node) + + +def test_langgraph_init_all(): + """Test that __all__ contains expected exports.""" + import agent_control.integrations.langgraph as langgraph_module + + assert hasattr(langgraph_module, "__all__") + assert langgraph_module.__all__ == ["create_controlled_tool_node"] + + +def test_lazy_import_create_controlled_tool_node(): + """Test lazy import of create_controlled_tool_node via __getattr__.""" + from agent_control.integrations.langgraph import create_controlled_tool_node + + assert create_controlled_tool_node.__name__ == "create_controlled_tool_node" + + +def test_missing_langgraph_dependency_raises_runtime_error(): + """Test that missing optional deps surface a helpful install message.""" + import agent_control.integrations.langgraph as langgraph_module + + sys.modules.pop("agent_control.integrations.langgraph.tool_node", None) + real_import = __import__ + + def fake_import(name: str, *args: object, **kwargs: object): + if name.startswith("langgraph") or name.startswith("langchain_core"): + raise ModuleNotFoundError(name) + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=fake_import): + with pytest.raises(RuntimeError, match=r"agent-control-sdk\[langgraph\]"): + langgraph_module.__getattr__("create_controlled_tool_node") + + +def test_invalid_attribute_raises_error(): + """Test that accessing invalid attribute raises AttributeError.""" + import agent_control.integrations.langgraph as langgraph_module + + with pytest.raises(AttributeError, match="has no attribute 'InvalidFactory'"): + _ = langgraph_module.InvalidFactory # type: ignore[attr-defined] diff --git a/sdks/python/tests/test_langgraph_tool_node.py b/sdks/python/tests/test_langgraph_tool_node.py new file mode 100644 index 00000000..84ea7434 --- /dev/null +++ b/sdks/python/tests/test_langgraph_tool_node.py @@ -0,0 +1,320 @@ +"""Unit tests for LangGraph ToolNode integration.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from agent_control_models import Agent, ControlMatch, EvaluationResult, EvaluatorResult, SteeringContext +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest, ToolRuntime + +from agent_control.integrations.langgraph.tool_node import create_controlled_tool_node + + +class _DummyClient: + """Minimal async client used to patch AgentControlClient in tests.""" + + def __init__(self, *, base_url: str, api_key: str | None) -> None: + self.base_url = base_url + self.api_key = api_key + + async def __aenter__(self) -> "_DummyClient": + return self + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: + return None + + +@tool +def echo_tool(text: str) -> str: + """Echo text.""" + return text + + +@tool +def other_tool(text: str) -> str: + """Return a second tool result.""" + return f"other:{text}" + + +def _tool_request(tool_obj: Any, *, text: str = "hello", tool_call_id: str = "call-1") -> ToolCallRequest: + runtime = ToolRuntime( + state={}, + context=None, + config={}, + stream_writer=lambda *_args, **_kwargs: None, + tool_call_id=tool_call_id, + store=None, + ) + tool_call = { + "name": tool_obj.name, + "args": {"text": text}, + "id": tool_call_id, + "type": "tool_call", + } + return ToolCallRequest(tool_call=tool_call, tool=tool_obj, state={}, runtime=runtime) + + +def _safe_result() -> EvaluationResult: + return EvaluationResult(is_safe=True, confidence=1.0, reason=None, matches=None, errors=None) + + +def _deny_result(message: str = "Not allowed") -> EvaluationResult: + return EvaluationResult( + is_safe=False, + confidence=1.0, + reason=None, + matches=[ + ControlMatch( + control_id=1, + control_name="deny-control", + action="deny", + result=EvaluatorResult(matched=True, confidence=1.0, message=message), + ) + ], + errors=None, + ) + + +def _steer_result(message: str = "Adjust the tool call") -> EvaluationResult: + return EvaluationResult( + is_safe=False, + confidence=1.0, + reason=None, + matches=[ + ControlMatch( + control_id=2, + control_name="steer-control", + action="steer", + result=EvaluatorResult(matched=True, confidence=1.0, message="Steer required"), + steering_context=SteeringContext(message=message), + ) + ], + errors=None, + ) + + +def _error_result() -> EvaluationResult: + return EvaluationResult( + is_safe=True, + confidence=1.0, + reason=None, + matches=None, + errors=[ + ControlMatch( + control_id=3, + control_name="error-control", + action="log", + result=EvaluatorResult( + matched=False, + confidence=0.0, + message="broken", + error="broken", + ), + ) + ], + ) + + +@pytest.fixture +def langgraph_env(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: + """Patch Agent Control environment and registration dependencies.""" + from agent_control.integrations.langgraph import tool_node as module + + current_agent = Agent( + agent_name="test-agent-123", + agent_description="Test agent", + agent_version="1.0.0", + ) + register_agent = AsyncMock(return_value={}) + refresh_controls = MagicMock(return_value=[]) + refresh_controls_async = AsyncMock(return_value=[]) + evaluate_controls = AsyncMock(return_value=_safe_result()) + + monkeypatch.setattr(module.agent_control, "current_agent", lambda: current_agent) + monkeypatch.setattr(module.sdk_state, "server_url", "http://example.test") + monkeypatch.setattr(module.sdk_state, "api_key", "secret") + monkeypatch.setattr(module, "AgentControlClient", _DummyClient) + monkeypatch.setattr(module.agents, "register_agent", register_agent) + monkeypatch.setattr(module.agent_control, "refresh_controls", refresh_controls) + monkeypatch.setattr(module.agent_control, "refresh_controls_async", refresh_controls_async) + monkeypatch.setattr(module.agent_control, "evaluate_controls", evaluate_controls) + + return { + "module": module, + "agent": current_agent, + "register_agent": register_agent, + "refresh_controls": refresh_controls, + "refresh_controls_async": refresh_controls_async, + "evaluate_controls": evaluate_controls, + } + + +def test_initial_registration_sends_full_tool_list(langgraph_env: dict[str, Any]) -> None: + node = create_controlled_tool_node([echo_tool, other_tool]) + + assert node.tools_by_name.keys() == {"echo_tool", "other_tool"} + register_agent = langgraph_env["register_agent"] + assert register_agent.await_count == 1 + + _, kwargs = register_agent.await_args + assert kwargs["conflict_mode"] == "overwrite" + assert {step["name"] for step in kwargs["steps"]} == {"echo_tool", "other_tool"} + assert all(step["type"] == "tool" for step in kwargs["steps"]) + + +def test_successful_registration_triggers_public_refresh(langgraph_env: dict[str, Any]) -> None: + create_controlled_tool_node([echo_tool]) + + langgraph_env["refresh_controls"].assert_called_once_with() + + +def test_sync_wrapper_allows_tool_output(langgraph_env: dict[str, Any]) -> None: + node = create_controlled_tool_node([echo_tool]) + request = _tool_request(node.tools_by_name["echo_tool"]) + handler = MagicMock( + return_value=ToolMessage( + content="allowed", + name="echo_tool", + tool_call_id="call-1", + ) + ) + + response = node._wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.content == "allowed" + assert langgraph_env["evaluate_controls"].await_count == 2 + + +@pytest.mark.asyncio +async def test_async_wrapper_allows_tool_output(langgraph_env: dict[str, Any]) -> None: + node = create_controlled_tool_node([echo_tool]) + request = _tool_request(node.tools_by_name["echo_tool"]) + + async def handler(_: Any) -> ToolMessage: + return ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + + response = await node._awrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.content == "allowed" + assert langgraph_env["evaluate_controls"].await_count == 2 + + +def test_deny_returns_error_tool_message(langgraph_env: dict[str, Any]) -> None: + langgraph_env["evaluate_controls"].return_value = _deny_result("Query blocked") + node = create_controlled_tool_node([echo_tool]) + request = _tool_request(node.tools_by_name["echo_tool"]) + handler = MagicMock() + + response = node._wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert response.tool_call_id == "call-1" + assert "blocked due to security policy" in str(response.content) + assert "Query blocked" in str(response.content) + handler.assert_not_called() + + +@pytest.mark.asyncio +async def test_steer_returns_error_tool_message_with_guidance( + langgraph_env: dict[str, Any] +) -> None: + langgraph_env["evaluate_controls"].return_value = _steer_result("Try a safer query") + node = create_controlled_tool_node([echo_tool]) + request = _tool_request(node.tools_by_name["echo_tool"]) + + async def handler(_: Any) -> ToolMessage: + return ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + + response = await node._awrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert "Guidance: Try a safer query" in str(response.content) + + +def test_evaluation_errors_fail_closed(langgraph_env: dict[str, Any]) -> None: + langgraph_env["evaluate_controls"].return_value = _error_result() + node = create_controlled_tool_node([echo_tool]) + request = _tool_request(node.tools_by_name["echo_tool"]) + handler = MagicMock() + + response = node._wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert "policy evaluation failed" in str(response.content) + handler.assert_not_called() + + +def test_evaluation_exception_fails_closed(langgraph_env: dict[str, Any]) -> None: + langgraph_env["evaluate_controls"].side_effect = RuntimeError("boom") + node = create_controlled_tool_node([echo_tool]) + request = _tool_request(node.tools_by_name["echo_tool"]) + handler = MagicMock() + + response = node._wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert "policy evaluation failed" in str(response.content) + handler.assert_not_called() + + +def test_resync_failure_is_fail_open_and_logged( + langgraph_env: dict[str, Any], + caplog: pytest.LogCaptureFixture, +) -> None: + register_agent = langgraph_env["register_agent"] + register_agent.side_effect = [{}, RuntimeError("resync failed")] + node = create_controlled_tool_node([echo_tool]) + node.tools_by_name["other_tool"] = other_tool + request = _tool_request(node.tools_by_name["echo_tool"]) + handler = MagicMock( + return_value=ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + ) + + response = node._wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.content == "allowed" + assert register_agent.await_count == 2 + assert "re-registration failed" in caplog.text + + +def test_tool_name_change_causes_full_reregistration(langgraph_env: dict[str, Any]) -> None: + node = create_controlled_tool_node([echo_tool]) + node.tools_by_name["other_tool"] = other_tool + request = _tool_request(node.tools_by_name["echo_tool"]) + handler = MagicMock( + return_value=ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + ) + + node._wrap_tool_call(request, handler) + + register_agent = langgraph_env["register_agent"] + assert register_agent.await_count == 2 + _, kwargs = register_agent.await_args + assert {step["name"] for step in kwargs["steps"]} == {"echo_tool", "other_tool"} + + +def test_omitted_agent_name_resolves_from_current_agent(langgraph_env: dict[str, Any]) -> None: + create_controlled_tool_node([echo_tool]) + + args, _ = langgraph_env["register_agent"].await_args + assert args[1].agent_name == langgraph_env["agent"].agent_name + + +def test_missing_initialized_agent_raises(monkeypatch: pytest.MonkeyPatch) -> None: + from agent_control.integrations.langgraph import tool_node as module + + monkeypatch.setattr(module.agent_control, "current_agent", lambda: None) + + with pytest.raises(RuntimeError, match="Call agent_control.init"): + create_controlled_tool_node([echo_tool]) From 682aaeccb4e1dbb0e8577f5d157efa6fda853b07 Mon Sep 17 00:00:00 2001 From: Lev Neiman Date: Wed, 11 Mar 2026 12:40:39 -0700 Subject: [PATCH 2/5] docs: add langgraph integration smoke example --- .../langgraph_toolnode_integration_smoke.py | 136 ++++++++++++++++++ .../setup_langgraph_toolnode_controls.py | 108 ++++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 examples/langchain/langgraph_toolnode_integration_smoke.py create mode 100644 examples/langchain/setup_langgraph_toolnode_controls.py diff --git a/examples/langchain/langgraph_toolnode_integration_smoke.py b/examples/langchain/langgraph_toolnode_integration_smoke.py new file mode 100644 index 00000000..cfc26f6a --- /dev/null +++ b/examples/langchain/langgraph_toolnode_integration_smoke.py @@ -0,0 +1,136 @@ +"""Minimal LangGraph smoke test for the Agent Control ToolNode integration. + +This example proves the LangGraph ToolNode wrapper path works without using +``@control()`` on the underlying tool implementation. + +Run: + cd examples/langchain + uv run setup_langgraph_toolnode_controls.py + uv run langgraph_toolnode_integration_smoke.py + +Prerequisite: + Start the Agent Control server first (`cd server && make run`). +""" + +from __future__ import annotations + +import asyncio +import os +import re +from typing import Annotated, TypedDict + +import agent_control +from agent_control.integrations.langgraph import create_controlled_tool_node +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.tools import tool +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages + +AGENT_NAME = "langgraph-toolnode-smoke" +AGENT_DESCRIPTION = "Minimal LangGraph ToolNode smoke test using Agent Control integration" + + +class AgentState(TypedDict): + """LangGraph state object.""" + + messages: Annotated[list[BaseMessage], add_messages] + + +@tool("get_weather") +async def get_weather(city: str) -> dict[str, str]: + """Return a deterministic weather response for a city.""" + return { + "city": city, + "forecast": { + "seattle": "Rainy and 53F", + "tehran": "Sunny and 75F", + "tokyo": "Clear and 61F", + }.get(city.lower(), "Partly cloudy and 68F"), + } + + +def _extract_city(user_text: str) -> str: + """Extract a city token from a simple prompt.""" + match = re.search(r"(?:for|in)\s+([A-Za-z][A-Za-z\\s-]*)", user_text) + if match: + return match.group(1).strip().rstrip("?.!") + return user_text.strip().split()[-1].rstrip("?.!") + + +def _build_graph(): + """Build a deterministic graph that always routes to the weather tool.""" + tool_node = create_controlled_tool_node([get_weather]) + + def planner(state: AgentState) -> dict[str, list[AIMessage]]: + user_text = str(state["messages"][-1].content) + city = _extract_city(user_text) + tool_call = { + "name": "get_weather", + "args": {"city": city}, + "id": f"call-weather-{city.lower().replace(' ', '-')}", + "type": "tool_call", + } + return {"messages": [AIMessage(content="", tool_calls=[tool_call])]} # type: ignore[arg-type] + + def finalize(state: AgentState) -> dict[str, list[AIMessage]]: + tool_message = next( + message for message in reversed(state["messages"]) if isinstance(message, ToolMessage) + ) + status = getattr(tool_message, "status", None) or "success" + return { + "messages": [ + AIMessage( + content=( + f"Tool `{tool_message.name}` finished with status `{status}`: " + f"{tool_message.content}" + ) + ) + ] + } + + graph = StateGraph(AgentState) + graph.add_node("planner", planner) + graph.add_node("tools", tool_node) + graph.add_node("finalize", finalize) + graph.add_edge(START, "planner") + graph.add_edge("planner", "tools") + graph.add_edge("tools", "finalize") + graph.add_edge("finalize", END) + return graph.compile() + + +async def main() -> None: + """Run the example with one allowed and one blocked scenario.""" + agent_control.init( + agent_name=AGENT_NAME, + agent_description=AGENT_DESCRIPTION, + server_url=os.getenv("AGENT_CONTROL_URL"), + ) + + app = _build_graph() + scenarios = [ + "What is the weather in Seattle?", + "What is the weather in Tehran?", + ] + + print("Running LangGraph ToolNode integration smoke test...") + for prompt in scenarios: + print("=" * 80) + print(f"User: {prompt}") + result = await app.ainvoke({"messages": [HumanMessage(content=prompt)]}) + + final_message = result["messages"][-1] + print(final_message.content) + + tool_message = next( + message for message in reversed(result["messages"]) if isinstance(message, ToolMessage) + ) + print( + f"Raw tool message -> name={tool_message.name!r}, " + f"status={getattr(tool_message, 'status', None)!r}, " + f"content={tool_message.content!r}" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/langchain/setup_langgraph_toolnode_controls.py b/examples/langchain/setup_langgraph_toolnode_controls.py new file mode 100644 index 00000000..5ed63f50 --- /dev/null +++ b/examples/langchain/setup_langgraph_toolnode_controls.py @@ -0,0 +1,108 @@ +"""Create controls for the LangGraph ToolNode integration smoke test. + +This script prepares a single direct agent control for the +``langgraph_toolnode_integration_smoke.py`` example. + +Run: + cd examples/langchain + uv run setup_langgraph_toolnode_controls.py +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +import httpx +from agent_control import Agent, AgentControlClient, agents, controls + +AGENT_NAME = "langgraph-toolnode-smoke" +SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") + +CONTROL_SPECS: list[tuple[str, dict[str, Any]]] = [ + ( + "langgraph-toolnode-block-city", + { + "description": "Block restricted cities before the get_weather tool runs.", + "enabled": True, + "execution": "server", + "scope": { + "step_types": ["tool"], + "step_names": ["get_weather"], + "stages": ["pre"], + }, + "selector": {"path": "input.city"}, + "evaluator": { + "name": "list", + "config": { + "values": ["Tehran", "Pyongyang"], + "logic": "any", + "match_on": "match", + "match_mode": "exact", + "case_sensitive": False, + }, + }, + "action": { + "decision": "deny", + "message": "That city is blocked by policy.", + }, + }, + ), +] + + +async def _ensure_control( + client: AgentControlClient, + name: str, + data: dict[str, Any], +) -> int: + """Create a control or update the existing definition.""" + try: + result = await controls.create_control(client, name=name, data=data) + return int(result["control_id"]) + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 409: + raise + + control_list = await controls.list_controls(client, name=name, limit=1) + existing = control_list.get("controls", []) + if not existing: + raise RuntimeError(f"Control '{name}' already exists but could not be listed.") + + control_id = int(existing[0]["id"]) + await controls.set_control_data(client, control_id, data) + return control_id + + +async def main() -> None: + """Register the example agent and ensure its controls exist.""" + async with AgentControlClient(base_url=SERVER_URL) as client: + await client.health_check() + + agent = Agent( + agent_name=AGENT_NAME, + agent_description="LangGraph ToolNode smoke test using Agent Control integration", + ) + await agents.register_agent(client, agent, steps=[]) + + control_ids: list[int] = [] + for control_name, control_data in CONTROL_SPECS: + control_id = await _ensure_control(client, control_name, control_data) + control_ids.append(control_id) + print(f"Prepared control: {control_name} ({control_id})") + + for control_id in control_ids: + try: + await agents.add_agent_control(client, AGENT_NAME, control_id) + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 409: + raise + + print() + print("LangGraph ToolNode smoke test is ready.") + print("Run: uv run langgraph_toolnode_integration_smoke.py") + + +if __name__ == "__main__": + asyncio.run(main()) From 66e199358c0467842d1162149e0170cbcb1e5ed9 Mon Sep 17 00:00:00 2001 From: Lev Neiman Date: Wed, 11 Mar 2026 13:13:37 -0700 Subject: [PATCH 3/5] feat: add langchain middleware integration --- sdks/python/pyproject.toml | 2 + .../integrations/_tool_controls.py | 376 +++++++++++++++++ .../integrations/langchain/__init__.py | 19 + .../integrations/langchain/middleware.py | 120 ++++++ .../integrations/langgraph/tool_node.py | 387 +----------------- sdks/python/tests/test_langchain_init.py | 55 +++ .../python/tests/test_langchain_middleware.py | 297 ++++++++++++++ sdks/python/tests/test_langgraph_tool_node.py | 22 +- 8 files changed, 900 insertions(+), 378 deletions(-) create mode 100644 sdks/python/src/agent_control/integrations/_tool_controls.py create mode 100644 sdks/python/src/agent_control/integrations/langchain/__init__.py create mode 100644 sdks/python/src/agent_control/integrations/langchain/middleware.py create mode 100644 sdks/python/tests/test_langchain_init.py create mode 100644 sdks/python/tests/test_langchain_middleware.py diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 587b1236..6b520bac 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -35,6 +35,7 @@ Documentation = "https://github.com/yourusername/agent-control#readme" Repository = "https://github.com/yourusername/agent-control" [project.optional-dependencies] +langchain = ["langchain>=1.0.0"] langgraph = ["langgraph>=0.2.0"] strands-agents = ["strands-agents>=1.26.0"] galileo = ["agent-control-evaluator-galileo>=3.0.0"] @@ -49,6 +50,7 @@ dev = [ "agent-control-models", "agent-control-engine", "agent-control-evaluators", + "langchain>=1.0.0", # For langchain middleware tests "langgraph>=0.2.0", # For langgraph integration tests "strands-agents>=1.26.0", # For strands integration tests ] diff --git a/sdks/python/src/agent_control/integrations/_tool_controls.py b/sdks/python/src/agent_control/integrations/_tool_controls.py new file mode 100644 index 00000000..256ed9fb --- /dev/null +++ b/sdks/python/src/agent_control/integrations/_tool_controls.py @@ -0,0 +1,376 @@ +"""Shared tool-control helpers for optional framework integrations.""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +import threading +from collections.abc import Awaitable, Callable, Coroutine, Sequence +from dataclasses import dataclass, field +from typing import Any, Literal, TypeVar, cast + +from agent_control_models import Agent, EvaluationResult +from langchain_core.messages import ToolMessage +from langchain_core.tools import BaseTool + +import agent_control +from agent_control import AgentControlClient, agents +from agent_control._state import state as sdk_state +from agent_control.validation import ensure_agent_name + +try: # pragma: no cover - optional dependency + from langgraph.types import Command as _ImportedLangGraphCommand +except Exception: # pragma: no cover - optional dependency + _LangGraphCommand: type[Any] | None = None +else: # pragma: no cover - optional dependency + _LangGraphCommand = cast(type[Any], _ImportedLangGraphCommand) + +logger = logging.getLogger(__name__) + +T = TypeVar("T") +ToolHandlerResult = ToolMessage | Any +ToolHandler = Callable[[Any], ToolHandlerResult | Awaitable[ToolHandlerResult]] + + +def resolve_agent(agent_name: str | None) -> Agent: + """Resolve the initialized agent, optionally overriding the agent name.""" + current_agent = agent_control.current_agent() + if current_agent is None: + raise RuntimeError("Agent Control is not initialized. Call agent_control.init() first.") + + if agent_name is None: + return current_agent + + return current_agent.model_copy(update={"agent_name": ensure_agent_name(agent_name)}) + + +def run_coro_in_new_loop[T](coro: Coroutine[Any, Any, T]) -> T: + """Run a coroutine on a dedicated event loop in the current thread.""" + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + finally: + loop.close() + asyncio.set_event_loop(None) + + +def run_coro_sync[T](coro: Coroutine[Any, Any, T]) -> T: + """Run an async coroutine from a synchronous context.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + result_container: list[T | None] = [None] + exception_container: list[Exception | None] = [None] + + def run_in_thread() -> None: + try: + result_container[0] = run_coro_in_new_loop(coro) + except Exception as exc: # pragma: no cover - thread path is hard to force + exception_container[0] = exc + + thread = threading.Thread(target=run_in_thread, daemon=True) + thread.start() + thread.join(timeout=10) + + if exception_container[0] is not None: + raise exception_container[0] + if result_container[0] is None: + raise RuntimeError("Timed out while waiting for Agent Control integration coroutine.") + return result_container[0] + + +def _tool_block_prefix(tool_name: str) -> str: + return f"Tool '{tool_name}' was blocked due to security policy." + + +def _tool_eval_error_prefix(tool_name: str) -> str: + return f"Tool '{tool_name}' was blocked because policy evaluation failed." + + +def build_error_tool_message( + *, + tool_name: str, + tool_call_id: str, + content: str, +) -> ToolMessage: + """Build an error ToolMessage preserving the original tool identifiers.""" + return ToolMessage( + content=content, + name=tool_name, + tool_call_id=tool_call_id, + status="error", + ) + + +def _append_detail(prefix: str, detail: str | None) -> str: + if detail is None: + return prefix + normalized = detail.strip() + if not normalized: + return prefix + return f"{prefix} {normalized}" + + +def _safe_tool_schema(tool: BaseTool, getter_name: str) -> dict[str, Any] | None: + getter = getattr(tool, getter_name, None) + if not callable(getter): + return None + + try: + schema = getter() + except Exception as exc: + logger.warning( + "Failed to extract %s for tool '%s': %s", + getter_name, + tool.name, + exc, + ) + return None + + if not isinstance(schema, dict): + logger.warning( + "Ignoring non-dict %s for tool '%s': %r", + getter_name, + tool.name, + schema, + ) + return None + + return schema + + +def _normalize_tool_output(output: ToolHandlerResult) -> Any: + if isinstance(output, ToolMessage): + return output.content + if _LangGraphCommand is not None and isinstance(output, _LangGraphCommand): + command_payload = { + "graph": output.graph, + "update": output.update, + "resume": output.resume, + "goto": output.goto, + } + return {key: value for key, value in command_payload.items() if value is not None} + return str(output) + + +def _first_result_message(result: EvaluationResult) -> str | None: + for match in result.matches or []: + message = match.result.message + if message: + return message + return result.reason + + +def result_to_tool_message( + *, + result: EvaluationResult, + tool_name: str, + tool_call_id: str, +) -> ToolMessage | None: + """Convert an evaluation result into a blocking ToolMessage when required.""" + if result.errors: + error_names = ", ".join(error.control_name for error in result.errors) + content = _append_detail(_tool_eval_error_prefix(tool_name), f"Errors: {error_names}") + return build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + matches = result.matches or [] + deny_match = next((match for match in matches if match.action == "deny"), None) + if deny_match is not None: + content = _append_detail(_tool_block_prefix(tool_name), deny_match.result.message) + return build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + steer_match = next((match for match in matches if match.action == "steer"), None) + if steer_match is not None: + guidance = None + if steer_match.steering_context is not None: + guidance = steer_match.steering_context.message + if not guidance: + guidance = steer_match.result.message or result.reason + content = _tool_block_prefix(tool_name) + if guidance: + content = f"{content} Guidance: {guidance}" + return build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + if not result.is_safe: + content = _append_detail(_tool_block_prefix(tool_name), _first_result_message(result)) + return build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=content, + ) + + return None + + +@dataclass +class ToolControlState: + """Track framework tool registration state and enforce tool evaluations.""" + + agent: Agent + get_tools: Callable[[], Sequence[BaseTool]] + observed_tool_names: set[str] = field(default_factory=set) + lock: threading.Lock = field(default_factory=threading.Lock) + + def _current_tools(self) -> list[BaseTool]: + return list(self.get_tools()) + + def _current_tool_names(self) -> set[str]: + return {tool.name for tool in self._current_tools()} + + def _normalized_steps(self) -> list[dict[str, Any]]: + steps: list[dict[str, Any]] = [] + for tool in self._current_tools(): + description = tool.description.strip() if tool.description else None + step: dict[str, Any] = { + "type": "tool", + "name": tool.name, + } + if description: + step["description"] = description + input_schema = _safe_tool_schema(tool, "get_input_jsonschema") + if input_schema is not None: + step["input_schema"] = input_schema + output_schema = _safe_tool_schema(tool, "get_output_jsonschema") + if output_schema is not None: + step["output_schema"] = output_schema + steps.append(step) + return steps + + async def _register_async(self) -> None: + if sdk_state.server_url is None: + raise RuntimeError( + "Agent Control server URL is not configured. " + "Call agent_control.init()." + ) + + async with AgentControlClient( + base_url=sdk_state.server_url, + api_key=sdk_state.api_key, + ) as client: + await agents.register_agent( + client, + self.agent, + steps=self._normalized_steps(), + conflict_mode="overwrite", + ) + + async def _register_and_refresh_async(self) -> None: + await self._register_async() + await agent_control.refresh_controls_async() + + def _register_and_refresh_sync(self) -> None: + run_coro_sync(self._register_async()) + agent_control.refresh_controls() + + async def ensure_registered_async(self) -> None: + current_tool_names = self._current_tool_names() + with self.lock: + should_register = current_tool_names != self.observed_tool_names + + if not should_register: + return + + await self._register_and_refresh_async() + with self.lock: + self.observed_tool_names = current_tool_names + + def ensure_registered_sync(self) -> None: + current_tool_names = self._current_tool_names() + with self.lock: + should_register = current_tool_names != self.observed_tool_names + + if not should_register: + return + + self._register_and_refresh_sync() + with self.lock: + self.observed_tool_names = current_tool_names + + async def evaluate_request( + self, + *, + tool_name: str, + tool_call_id: str, + tool_input: dict[str, Any], + stage: Literal["pre", "post"], + output: Any = None, + ) -> ToolMessage | None: + try: + result = await agent_control.evaluate_controls( + step_name=tool_name, + input=tool_input, + output=output, + step_type="tool", + stage=stage, + agent_name=self.agent.agent_name, + ) + except Exception as exc: + logger.error( + "Agent Control evaluation failed for tool '%s' (%s stage): %s", + tool_name, + stage, + exc, + exc_info=True, + ) + return build_error_tool_message( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=_append_detail(_tool_eval_error_prefix(tool_name), str(exc)), + ) + + return result_to_tool_message( + result=result, + tool_name=tool_name, + tool_call_id=tool_call_id, + ) + + async def run_with_controls( + self, + request: Any, + handler: ToolHandler, + ) -> ToolHandlerResult: + tool_name = request.tool_call["name"] + tool_call_id = str(request.tool_call.get("id", tool_name)) + raw_args = request.tool_call.get("args", {}) + tool_input = raw_args if isinstance(raw_args, dict) else {"value": raw_args} + + pre_message = await self.evaluate_request( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_input=tool_input, + stage="pre", + ) + if pre_message is not None: + return pre_message + + response = handler(request) + if inspect.isawaitable(response): + response = await cast(Awaitable[ToolHandlerResult], response) + + post_message = await self.evaluate_request( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_input=tool_input, + stage="post", + output=_normalize_tool_output(cast(ToolHandlerResult, response)), + ) + if post_message is not None: + return post_message + + return cast(ToolHandlerResult, response) diff --git a/sdks/python/src/agent_control/integrations/langchain/__init__.py b/sdks/python/src/agent_control/integrations/langchain/__init__.py new file mode 100644 index 00000000..fd71dfa7 --- /dev/null +++ b/sdks/python/src/agent_control/integrations/langchain/__init__.py @@ -0,0 +1,19 @@ +"""LangChain integration for Agent Control.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .middleware import AgentControlMiddleware + +__all__ = ["AgentControlMiddleware"] + + +def __getattr__(name: str) -> Any: + """Lazy import to avoid import errors when langchain is not installed.""" + if name == "AgentControlMiddleware": + from .middleware import AgentControlMiddleware + + return AgentControlMiddleware + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/sdks/python/src/agent_control/integrations/langchain/middleware.py b/sdks/python/src/agent_control/integrations/langchain/middleware.py new file mode 100644 index 00000000..feba46af --- /dev/null +++ b/sdks/python/src/agent_control/integrations/langchain/middleware.py @@ -0,0 +1,120 @@ +"""Agent Control middleware integration for LangChain agents.""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable, Sequence +from typing import Any + +from agent_control.integrations._tool_controls import ToolControlState, resolve_agent, run_coro_sync + +try: + from langchain.agents.middleware import ( + AgentMiddleware, + ExtendedModelResponse, + ModelRequest, + ModelResponse, + ToolCallRequest, + ) + from langchain_core.messages import AIMessage + from langchain_core.tools import BaseTool + from langgraph.types import Command +except Exception as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "LangChain integration requires langchain. " + "Install with: agent-control-sdk[langchain]." + ) from exc + +logger = logging.getLogger(__name__) + + +class AgentControlMiddleware(AgentMiddleware[Any, Any, Any]): + """LangChain middleware that enforces Agent Control policies on tool calls.""" + + tools: Sequence[BaseTool] = () + + def __init__(self, *, agent_name: str | None = None) -> None: + self._tools_by_name: dict[str, BaseTool] = {} + self._state = ToolControlState( + agent=resolve_agent(agent_name), + get_tools=lambda: list(self._tools_by_name.values()), + ) + + def _capture_tools(self, tools: Sequence[BaseTool | dict[str, Any]]) -> None: + for tool in tools: + if isinstance(tool, BaseTool): + self._tools_by_name[tool.name] = tool + + def _capture_request_tool(self, request: ToolCallRequest) -> None: + if request.tool is not None: + self._tools_by_name[request.tool.name] = request.tool + + def wrap_model_call( + self, + request: ModelRequest[Any], + handler: Callable[[ModelRequest[Any]], ModelResponse[Any]], + ) -> ModelResponse[Any] | AIMessage | ExtendedModelResponse[Any]: + self._capture_tools(request.tools) + try: + self._state.ensure_registered_sync() + except Exception as exc: + logger.error( + "LangChain middleware re-registration failed; " + "continuing without blocking model execution: %s", + exc, + exc_info=True, + ) + return handler(request) + + async def awrap_model_call( + self, + request: ModelRequest[Any], + handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]], + ) -> ModelResponse[Any] | AIMessage | ExtendedModelResponse[Any]: + self._capture_tools(request.tools) + try: + await self._state.ensure_registered_async() + except Exception as exc: + logger.error( + "LangChain middleware async re-registration failed; " + "continuing without blocking model execution: %s", + exc, + exc_info=True, + ) + return await handler(request) + + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Any], + ) -> Any: + self._capture_request_tool(request) + try: + self._state.ensure_registered_sync() + except Exception as exc: + logger.error( + "LangChain middleware re-registration failed; " + "continuing without blocking tool '%s': %s", + request.tool_call["name"], + exc, + exc_info=True, + ) + return run_coro_sync(self._state.run_with_controls(request, handler)) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[Command[Any] | Any]], + ) -> Any: + self._capture_request_tool(request) + try: + await self._state.ensure_registered_async() + except Exception as exc: + logger.error( + "LangChain middleware async re-registration failed; " + "continuing without blocking tool '%s': %s", + request.tool_call["name"], + exc, + exc_info=True, + ) + return await self._state.run_with_controls(request, handler) diff --git a/sdks/python/src/agent_control/integrations/langgraph/tool_node.py b/sdks/python/src/agent_control/integrations/langgraph/tool_node.py index efa2979e..036af8f6 100644 --- a/sdks/python/src/agent_control/integrations/langgraph/tool_node.py +++ b/sdks/python/src/agent_control/integrations/langgraph/tool_node.py @@ -2,26 +2,19 @@ from __future__ import annotations -import asyncio -import inspect import logging -import threading -from collections.abc import Awaitable, Callable, Coroutine, Sequence -from dataclasses import dataclass, field -from typing import Any, Literal, TypeVar, cast +from collections.abc import Awaitable, Callable, Sequence +from typing import Any -from agent_control_models import Agent, EvaluationResult - -import agent_control -from agent_control import AgentControlClient, agents -from agent_control._state import state as sdk_state -from agent_control.validation import ensure_agent_name +from agent_control.integrations._tool_controls import ( + ToolControlState, + resolve_agent, + run_coro_sync, +) try: - from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool from langgraph.prebuilt import ToolNode - from langgraph.types import Command except Exception as exc: # pragma: no cover - optional dependency raise RuntimeError( "LangGraph integration requires langgraph. " @@ -30,344 +23,9 @@ logger = logging.getLogger(__name__) -T = TypeVar("T") -ToolHandlerResult = ToolMessage | Command[Any] -ToolHandler = Callable[[Any], ToolHandlerResult | Awaitable[ToolHandlerResult]] _TOOLNODE_DEFAULT = object() -def _run_coro_in_new_loop[T](coro: Coroutine[Any, Any, T]) -> T: - """Run a coroutine on a dedicated event loop in the current thread.""" - loop = asyncio.new_event_loop() - try: - asyncio.set_event_loop(loop) - return loop.run_until_complete(coro) - finally: - loop.close() - asyncio.set_event_loop(None) - - -def _run_coro_sync[T](coro: Coroutine[Any, Any, T]) -> T: - """Run an async coroutine from a synchronous context.""" - try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(coro) - - result_container: list[T | None] = [None] - exception_container: list[Exception | None] = [None] - - def run_in_thread() -> None: - try: - result_container[0] = _run_coro_in_new_loop(coro) - except Exception as exc: # pragma: no cover - thread path is hard to force - exception_container[0] = exc - - thread = threading.Thread(target=run_in_thread, daemon=True) - thread.start() - thread.join(timeout=10) - - if exception_container[0] is not None: - raise exception_container[0] - if result_container[0] is None: - raise RuntimeError("Timed out while waiting for LangGraph integration coroutine.") - return result_container[0] - - -def _tool_block_prefix(tool_name: str) -> str: - return f"Tool '{tool_name}' was blocked due to security policy." - - -def _tool_eval_error_prefix(tool_name: str) -> str: - return f"Tool '{tool_name}' was blocked because policy evaluation failed." - - -def _build_error_tool_message( - *, - tool_name: str, - tool_call_id: str, - content: str, -) -> ToolMessage: - """Build an error ToolMessage preserving the original tool identifiers.""" - return ToolMessage( - content=content, - name=tool_name, - tool_call_id=tool_call_id, - status="error", - ) - - -def _append_detail(prefix: str, detail: str | None) -> str: - """Append optional detail to a message prefix.""" - if detail is None: - return prefix - normalized = detail.strip() - if not normalized: - return prefix - return f"{prefix} {normalized}" - - -def _safe_tool_schema(tool: BaseTool, getter_name: str) -> dict[str, Any] | None: - """Extract a JSON schema from a tool getter, logging and falling back to None.""" - getter = getattr(tool, getter_name, None) - if not callable(getter): - return None - - try: - schema = getter() - except Exception as exc: - logger.warning( - "Failed to extract %s for tool '%s': %s", - getter_name, - tool.name, - exc, - ) - return None - - if not isinstance(schema, dict): - logger.warning( - "Ignoring non-dict %s for tool '%s': %r", - getter_name, - tool.name, - schema, - ) - return None - - return schema - - -def _normalize_tool_output(output: ToolHandlerResult) -> Any: - """Normalize LangGraph tool output into evaluation-friendly payloads.""" - if isinstance(output, ToolMessage): - return output.content - if isinstance(output, Command): - command_payload = { - "graph": output.graph, - "update": output.update, - "resume": output.resume, - "goto": output.goto, - } - return {key: value for key, value in command_payload.items() if value is not None} - return str(output) - - -def _first_result_message(result: EvaluationResult) -> str | None: - """Return the first available control result message from matches.""" - for match in result.matches or []: - message = match.result.message - if message: - return message - return result.reason - - -def _result_to_tool_message( - *, - result: EvaluationResult, - tool_name: str, - tool_call_id: str, -) -> ToolMessage | None: - """Convert an evaluation result into a blocking ToolMessage when required.""" - if result.errors: - error_names = ", ".join(error.control_name for error in result.errors) - content = _append_detail(_tool_eval_error_prefix(tool_name), f"Errors: {error_names}") - return _build_error_tool_message( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=content, - ) - - matches = result.matches or [] - deny_match = next((match for match in matches if match.action == "deny"), None) - if deny_match is not None: - content = _append_detail(_tool_block_prefix(tool_name), deny_match.result.message) - return _build_error_tool_message( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=content, - ) - - steer_match = next((match for match in matches if match.action == "steer"), None) - if steer_match is not None: - guidance = None - if steer_match.steering_context is not None: - guidance = steer_match.steering_context.message - if not guidance: - guidance = steer_match.result.message or result.reason - content = _tool_block_prefix(tool_name) - if guidance: - content = f"{content} Guidance: {guidance}" - return _build_error_tool_message( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=content, - ) - - if not result.is_safe: - content = _append_detail(_tool_block_prefix(tool_name), _first_result_message(result)) - return _build_error_tool_message( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=content, - ) - - return None - - -@dataclass -class _LangGraphControlState: - """Holds tool registration and evaluation state for a controlled ToolNode.""" - - tool_node: ToolNode - agent: Agent - observed_tool_names: set[str] = field(default_factory=set) - lock: threading.Lock = field(default_factory=threading.Lock) - - def _current_tool_names(self) -> set[str]: - return set(self.tool_node.tools_by_name) - - def _normalized_steps(self) -> list[dict[str, Any]]: - steps: list[dict[str, Any]] = [] - for tool in self.tool_node.tools_by_name.values(): - description = tool.description.strip() if tool.description else None - step: dict[str, Any] = { - "type": "tool", - "name": tool.name, - } - if description: - step["description"] = description - input_schema = _safe_tool_schema(tool, "get_input_jsonschema") - if input_schema is not None: - step["input_schema"] = input_schema - output_schema = _safe_tool_schema(tool, "get_output_jsonschema") - if output_schema is not None: - step["output_schema"] = output_schema - steps.append(step) - return steps - - async def _register_async(self) -> None: - if sdk_state.server_url is None: - raise RuntimeError( - "Agent Control server URL is not configured. " - "Call agent_control.init()." - ) - - async with AgentControlClient( - base_url=sdk_state.server_url, - api_key=sdk_state.api_key, - ) as client: - await agents.register_agent( - client, - self.agent, - steps=self._normalized_steps(), - conflict_mode="overwrite", - ) - - async def _register_and_refresh_async(self) -> None: - await self._register_async() - await agent_control.refresh_controls_async() - - def _register_and_refresh_sync(self) -> None: - _run_coro_sync(self._register_async()) - agent_control.refresh_controls() - - async def ensure_registered_async(self) -> None: - current_tool_names = self._current_tool_names() - with self.lock: - should_register = current_tool_names != self.observed_tool_names - - if not should_register: - return - - await self._register_and_refresh_async() - with self.lock: - self.observed_tool_names = current_tool_names - - def ensure_registered_sync(self) -> None: - current_tool_names = self._current_tool_names() - with self.lock: - should_register = current_tool_names != self.observed_tool_names - - if not should_register: - return - - self._register_and_refresh_sync() - with self.lock: - self.observed_tool_names = current_tool_names - - async def evaluate_request( - self, - *, - tool_name: str, - tool_call_id: str, - tool_input: dict[str, Any], - stage: Literal["pre", "post"], - output: Any = None, - ) -> ToolMessage | None: - try: - result = await agent_control.evaluate_controls( - step_name=tool_name, - input=tool_input, - output=output, - step_type="tool", - stage=stage, - agent_name=self.agent.agent_name, - ) - except Exception as exc: - logger.error( - "Agent Control evaluation failed for tool '%s' (%s stage): %s", - tool_name, - stage, - exc, - exc_info=True, - ) - return _build_error_tool_message( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=_append_detail(_tool_eval_error_prefix(tool_name), str(exc)), - ) - - return _result_to_tool_message( - result=result, - tool_name=tool_name, - tool_call_id=tool_call_id, - ) - - async def run_with_controls( - self, - request: Any, - handler: ToolHandler, - ) -> ToolHandlerResult: - tool_name = request.tool_call["name"] - tool_call_id = str(request.tool_call.get("id", tool_name)) - raw_args = request.tool_call.get("args", {}) - tool_input = raw_args if isinstance(raw_args, dict) else {"value": raw_args} - - pre_message = await self.evaluate_request( - tool_name=tool_name, - tool_call_id=tool_call_id, - tool_input=tool_input, - stage="pre", - ) - if pre_message is not None: - return pre_message - - response = handler(request) - if inspect.isawaitable(response): - response = await cast(Awaitable[ToolHandlerResult], response) - - post_message = await self.evaluate_request( - tool_name=tool_name, - tool_call_id=tool_call_id, - tool_input=tool_input, - stage="post", - output=_normalize_tool_output(cast(ToolHandlerResult, response)), - ) - if post_message is not None: - return post_message - - return cast(ToolHandlerResult, response) - - def create_controlled_tool_node( tools: Sequence[BaseTool | Callable[..., Any]], *, @@ -378,23 +36,18 @@ def create_controlled_tool_node( messages_key: str = "messages", ) -> ToolNode: """Create a ToolNode configured with Agent Control policy enforcement.""" - current_agent = agent_control.current_agent() - if current_agent is None: - raise RuntimeError("Agent Control is not initialized. Call agent_control.init() first.") + resolved_agent = resolve_agent(agent_name) + tool_node_ref: dict[str, ToolNode] = {} - resolved_agent = current_agent - if agent_name is not None: - resolved_agent = current_agent.model_copy( - update={"agent_name": ensure_agent_name(agent_name)} - ) - - state_ref: dict[str, _LangGraphControlState] = {} + state = ToolControlState( + agent=resolved_agent, + get_tools=lambda: list(tool_node_ref["tool_node"].tools_by_name.values()), + ) def wrap_tool_call( request: Any, - handler: Callable[[Any], ToolHandlerResult], - ) -> ToolHandlerResult: - state = state_ref["state"] + handler: Callable[[Any], Any], + ) -> Any: try: state.ensure_registered_sync() except Exception as exc: @@ -405,13 +58,12 @@ def wrap_tool_call( exc, exc_info=True, ) - return _run_coro_sync(state.run_with_controls(request, handler)) + return run_coro_sync(state.run_with_controls(request, handler)) async def awrap_tool_call( request: Any, - handler: Callable[[Any], Awaitable[ToolHandlerResult]], - ) -> ToolHandlerResult: - state = state_ref["state"] + handler: Callable[[Any], Awaitable[Any]], + ) -> Any: try: await state.ensure_registered_async() except Exception as exc: @@ -435,8 +87,7 @@ async def awrap_tool_call( tool_node_kwargs["handle_tool_errors"] = handle_tool_errors tool_node = ToolNode(tools, **tool_node_kwargs) - state = _LangGraphControlState(tool_node=tool_node, agent=resolved_agent) - state_ref["state"] = state + tool_node_ref["tool_node"] = tool_node try: state.ensure_registered_sync() diff --git a/sdks/python/tests/test_langchain_init.py b/sdks/python/tests/test_langchain_init.py new file mode 100644 index 00000000..1547de8b --- /dev/null +++ b/sdks/python/tests/test_langchain_init.py @@ -0,0 +1,55 @@ +"""Unit tests for LangChain integration __init__.py.""" + +from __future__ import annotations + +import sys +from unittest.mock import patch + +import pytest + + +def test_langchain_init_exports() -> None: + """Test that __init__.py exports the expected middleware.""" + from agent_control.integrations.langchain import AgentControlMiddleware + + assert isinstance(AgentControlMiddleware, type) + + +def test_langchain_init_all() -> None: + """Test that __all__ contains expected exports.""" + import agent_control.integrations.langchain as langchain_module + + assert hasattr(langchain_module, "__all__") + assert langchain_module.__all__ == ["AgentControlMiddleware"] + + +def test_lazy_import_agent_control_middleware() -> None: + """Test lazy import of AgentControlMiddleware via __getattr__.""" + from agent_control.integrations.langchain import AgentControlMiddleware + + assert AgentControlMiddleware.__name__ == "AgentControlMiddleware" + + +def test_missing_langchain_dependency_raises_runtime_error() -> None: + """Test that missing optional deps surface a helpful install message.""" + import agent_control.integrations.langchain as langchain_module + + sys.modules.pop("agent_control.integrations.langchain.middleware", None) + real_import = __import__ + + def fake_import(name: str, *args: object, **kwargs: object): + if name.startswith("langchain") or name.startswith("langchain_core"): + raise ModuleNotFoundError(name) + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=fake_import): + with pytest.raises(RuntimeError, match=r"agent-control-sdk\[langchain\]"): + langchain_module.__getattr__("AgentControlMiddleware") + + +def test_invalid_attribute_raises_error() -> None: + """Test that accessing invalid attribute raises AttributeError.""" + import agent_control.integrations.langchain as langchain_module + + with pytest.raises(AttributeError, match="has no attribute 'InvalidMiddleware'"): + _ = langchain_module.InvalidMiddleware # type: ignore[attr-defined] diff --git a/sdks/python/tests/test_langchain_middleware.py b/sdks/python/tests/test_langchain_middleware.py new file mode 100644 index 00000000..9537e401 --- /dev/null +++ b/sdks/python/tests/test_langchain_middleware.py @@ -0,0 +1,297 @@ +"""Unit tests for LangChain AgentMiddleware integration.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from agent_control_models import Agent, ControlMatch, EvaluationResult, EvaluatorResult, SteeringContext +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest, ToolRuntime + +from agent_control.integrations.langchain.middleware import AgentControlMiddleware + + +class _DummyClient: + """Minimal async client used to patch AgentControlClient in tests.""" + + def __init__(self, *, base_url: str, api_key: str | None) -> None: + self.base_url = base_url + self.api_key = api_key + + async def __aenter__(self) -> "_DummyClient": + return self + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: + return None + + +@tool +def echo_tool(text: str) -> str: + """Echo text.""" + return text + + +@tool +def other_tool(text: str) -> str: + """Return a second tool result.""" + return f"other:{text}" + + +def _tool_request( + tool_obj: Any, + *, + text: str = "hello", + tool_call_id: str = "call-1", +) -> ToolCallRequest: + runtime = ToolRuntime( + state={}, + context=None, + config={}, + stream_writer=lambda *_args, **_kwargs: None, + tool_call_id=tool_call_id, + store=None, + ) + tool_call = { + "name": tool_obj.name, + "args": {"text": text}, + "id": tool_call_id, + "type": "tool_call", + } + return ToolCallRequest(tool_call=tool_call, tool=tool_obj, state={}, runtime=runtime) + + +def _model_request(tools: list[Any]) -> Any: + return SimpleNamespace(tools=tools) + + +def _safe_result() -> EvaluationResult: + return EvaluationResult(is_safe=True, confidence=1.0, reason=None, matches=None, errors=None) + + +def _deny_result(message: str = "Not allowed") -> EvaluationResult: + return EvaluationResult( + is_safe=False, + confidence=1.0, + reason=None, + matches=[ + ControlMatch( + control_id=1, + control_name="deny-control", + action="deny", + result=EvaluatorResult(matched=True, confidence=1.0, message=message), + ) + ], + errors=None, + ) + + +def _steer_result(message: str = "Adjust the tool call") -> EvaluationResult: + return EvaluationResult( + is_safe=False, + confidence=1.0, + reason=None, + matches=[ + ControlMatch( + control_id=2, + control_name="steer-control", + action="steer", + result=EvaluatorResult(matched=True, confidence=1.0, message="Steer required"), + steering_context=SteeringContext(message=message), + ) + ], + errors=None, + ) + + +def _error_result() -> EvaluationResult: + return EvaluationResult( + is_safe=True, + confidence=1.0, + reason=None, + matches=None, + errors=[ + ControlMatch( + control_id=3, + control_name="error-control", + action="log", + result=EvaluatorResult( + matched=False, + confidence=0.0, + message="broken", + error="broken", + ), + ) + ], + ) + + +@pytest.fixture +def langchain_env(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: + """Patch Agent Control environment and registration dependencies.""" + from agent_control.integrations import _tool_controls as shared + + current_agent = Agent( + agent_name="test-agent-123", + agent_description="Test agent", + agent_version="1.0.0", + ) + register_agent = AsyncMock(return_value={}) + refresh_controls = MagicMock(return_value=[]) + refresh_controls_async = AsyncMock(return_value=[]) + evaluate_controls = AsyncMock(return_value=_safe_result()) + + monkeypatch.setattr(shared.agent_control, "current_agent", lambda: current_agent) + monkeypatch.setattr(shared.sdk_state, "server_url", "http://example.test") + monkeypatch.setattr(shared.sdk_state, "api_key", "secret") + monkeypatch.setattr(shared, "AgentControlClient", _DummyClient) + monkeypatch.setattr(shared.agents, "register_agent", register_agent) + monkeypatch.setattr(shared.agent_control, "refresh_controls", refresh_controls) + monkeypatch.setattr(shared.agent_control, "refresh_controls_async", refresh_controls_async) + monkeypatch.setattr(shared.agent_control, "evaluate_controls", evaluate_controls) + + return { + "agent": current_agent, + "register_agent": register_agent, + "refresh_controls": refresh_controls, + "refresh_controls_async": refresh_controls_async, + "evaluate_controls": evaluate_controls, + } + + +def test_model_wrapper_registers_full_tool_list(langchain_env: dict[str, Any]) -> None: + middleware = AgentControlMiddleware() + request = _model_request([echo_tool, other_tool]) + handler = MagicMock(return_value="model-response") + + response = middleware.wrap_model_call(request, handler) + + assert response == "model-response" + register_agent = langchain_env["register_agent"] + assert register_agent.await_count == 1 + _, kwargs = register_agent.await_args + assert kwargs["conflict_mode"] == "overwrite" + assert {step["name"] for step in kwargs["steps"]} == {"echo_tool", "other_tool"} + langchain_env["refresh_controls"].assert_called_once_with() + + +def test_sync_tool_wrapper_allows_tool_output(langchain_env: dict[str, Any]) -> None: + middleware = AgentControlMiddleware() + middleware.wrap_model_call(_model_request([echo_tool]), MagicMock(return_value="ok")) + request = _tool_request(echo_tool) + handler = MagicMock( + return_value=ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + ) + + response = middleware.wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.content == "allowed" + assert langchain_env["evaluate_controls"].await_count == 2 + + +@pytest.mark.asyncio +async def test_async_tool_wrapper_allows_tool_output(langchain_env: dict[str, Any]) -> None: + middleware = AgentControlMiddleware() + await middleware.awrap_model_call( + _model_request([echo_tool]), + AsyncMock(return_value="ok"), + ) + request = _tool_request(echo_tool) + + async def handler(_: Any) -> ToolMessage: + return ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + + response = await middleware.awrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.content == "allowed" + assert langchain_env["evaluate_controls"].await_count == 2 + + +def test_deny_returns_error_tool_message(langchain_env: dict[str, Any]) -> None: + langchain_env["evaluate_controls"].return_value = _deny_result("Query blocked") + middleware = AgentControlMiddleware() + middleware.wrap_model_call(_model_request([echo_tool]), MagicMock(return_value="ok")) + request = _tool_request(echo_tool) + handler = MagicMock() + + response = middleware.wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert response.tool_call_id == "call-1" + assert "blocked due to security policy" in str(response.content) + assert "Query blocked" in str(response.content) + handler.assert_not_called() + + +@pytest.mark.asyncio +async def test_steer_returns_error_tool_message_with_guidance( + langchain_env: dict[str, Any] +) -> None: + langchain_env["evaluate_controls"].return_value = _steer_result("Try a safer query") + middleware = AgentControlMiddleware() + await middleware.awrap_model_call(_model_request([echo_tool]), AsyncMock(return_value="ok")) + request = _tool_request(echo_tool) + + async def handler(_: Any) -> ToolMessage: + return ToolMessage(content="allowed", name="echo_tool", tool_call_id="call-1") + + response = await middleware.awrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert "Guidance: Try a safer query" in str(response.content) + + +def test_evaluation_errors_fail_closed(langchain_env: dict[str, Any]) -> None: + langchain_env["evaluate_controls"].return_value = _error_result() + middleware = AgentControlMiddleware() + middleware.wrap_model_call(_model_request([echo_tool]), MagicMock(return_value="ok")) + request = _tool_request(echo_tool) + handler = MagicMock() + + response = middleware.wrap_tool_call(request, handler) + + assert isinstance(response, ToolMessage) + assert response.status == "error" + assert "policy evaluation failed" in str(response.content) + handler.assert_not_called() + + +def test_model_resync_failure_is_fail_open_and_logged( + langchain_env: dict[str, Any], + caplog: pytest.LogCaptureFixture, +) -> None: + register_agent = langchain_env["register_agent"] + register_agent.side_effect = [{}, RuntimeError("resync failed")] + middleware = AgentControlMiddleware() + middleware.wrap_model_call(_model_request([echo_tool]), MagicMock(return_value="ok")) + handler = MagicMock(return_value="model-response") + + response = middleware.wrap_model_call(_model_request([echo_tool, other_tool]), handler) + + assert response == "model-response" + assert register_agent.await_count == 2 + assert "re-registration failed" in caplog.text + + +def test_omitted_agent_name_resolves_from_current_agent(langchain_env: dict[str, Any]) -> None: + middleware = AgentControlMiddleware() + middleware.wrap_model_call(_model_request([echo_tool]), MagicMock(return_value="ok")) + + args, _ = langchain_env["register_agent"].await_args + assert args[1].agent_name == langchain_env["agent"].agent_name + + +def test_missing_initialized_agent_raises(monkeypatch: pytest.MonkeyPatch) -> None: + from agent_control.integrations import _tool_controls as shared + + monkeypatch.setattr(shared.agent_control, "current_agent", lambda: None) + + with pytest.raises(RuntimeError, match="Call agent_control.init"): + AgentControlMiddleware() diff --git a/sdks/python/tests/test_langgraph_tool_node.py b/sdks/python/tests/test_langgraph_tool_node.py index 84ea7434..c11a0e64 100644 --- a/sdks/python/tests/test_langgraph_tool_node.py +++ b/sdks/python/tests/test_langgraph_tool_node.py @@ -122,6 +122,7 @@ def _error_result() -> EvaluationResult: @pytest.fixture def langgraph_env(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: """Patch Agent Control environment and registration dependencies.""" + from agent_control.integrations import _tool_controls as shared from agent_control.integrations.langgraph import tool_node as module current_agent = Agent( @@ -134,17 +135,18 @@ def langgraph_env(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: refresh_controls_async = AsyncMock(return_value=[]) evaluate_controls = AsyncMock(return_value=_safe_result()) - monkeypatch.setattr(module.agent_control, "current_agent", lambda: current_agent) - monkeypatch.setattr(module.sdk_state, "server_url", "http://example.test") - monkeypatch.setattr(module.sdk_state, "api_key", "secret") - monkeypatch.setattr(module, "AgentControlClient", _DummyClient) - monkeypatch.setattr(module.agents, "register_agent", register_agent) - monkeypatch.setattr(module.agent_control, "refresh_controls", refresh_controls) - monkeypatch.setattr(module.agent_control, "refresh_controls_async", refresh_controls_async) - monkeypatch.setattr(module.agent_control, "evaluate_controls", evaluate_controls) + monkeypatch.setattr(shared.agent_control, "current_agent", lambda: current_agent) + monkeypatch.setattr(shared.sdk_state, "server_url", "http://example.test") + monkeypatch.setattr(shared.sdk_state, "api_key", "secret") + monkeypatch.setattr(shared, "AgentControlClient", _DummyClient) + monkeypatch.setattr(shared.agents, "register_agent", register_agent) + monkeypatch.setattr(shared.agent_control, "refresh_controls", refresh_controls) + monkeypatch.setattr(shared.agent_control, "refresh_controls_async", refresh_controls_async) + monkeypatch.setattr(shared.agent_control, "evaluate_controls", evaluate_controls) return { "module": module, + "shared": shared, "agent": current_agent, "register_agent": register_agent, "refresh_controls": refresh_controls, @@ -312,9 +314,9 @@ def test_omitted_agent_name_resolves_from_current_agent(langgraph_env: dict[str, def test_missing_initialized_agent_raises(monkeypatch: pytest.MonkeyPatch) -> None: - from agent_control.integrations.langgraph import tool_node as module + from agent_control.integrations import _tool_controls as shared - monkeypatch.setattr(module.agent_control, "current_agent", lambda: None) + monkeypatch.setattr(shared.agent_control, "current_agent", lambda: None) with pytest.raises(RuntimeError, match="Call agent_control.init"): create_controlled_tool_node([echo_tool]) From 308f19e224863ea0f675579fc9eb8cec56eb9264 Mon Sep 17 00:00:00 2001 From: Lev Neiman Date: Wed, 11 Mar 2026 13:19:02 -0700 Subject: [PATCH 4/5] docs: add langchain middleware smoke example --- .../langchain/langchain_middleware_smoke.py | 102 +++++++++++++++++ .../setup_langchain_middleware_controls.py | 108 ++++++++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 examples/langchain/langchain_middleware_smoke.py create mode 100644 examples/langchain/setup_langchain_middleware_controls.py diff --git a/examples/langchain/langchain_middleware_smoke.py b/examples/langchain/langchain_middleware_smoke.py new file mode 100644 index 00000000..71fb9a51 --- /dev/null +++ b/examples/langchain/langchain_middleware_smoke.py @@ -0,0 +1,102 @@ +"""Minimal LangChain middleware smoke test for Agent Control. + +This example proves the LangChain agent-factory middleware path works with +``AgentControlMiddleware`` and protects tool calls without using ``@control()``. + +Run: + cd examples/langchain + uv run setup_langchain_middleware_controls.py + uv run langchain_middleware_smoke.py + +Prerequisite: + Start the Agent Control server first (`cd server && make run`). +""" + +from __future__ import annotations + +import os + +import agent_control +from agent_control.integrations.langchain import AgentControlMiddleware +from langchain.agents import create_agent +from langchain_core.language_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.tools import tool + +AGENT_NAME = "langchain-middleware-smoke" +AGENT_DESCRIPTION = "Minimal LangChain middleware smoke test using Agent Control" + + +@tool("get_weather") +def get_weather(city: str) -> dict[str, str]: + """Return a deterministic weather response for a city.""" + return { + "city": city, + "forecast": { + "seattle": "Rainy and 53F", + "tehran": "Sunny and 75F", + "tokyo": "Clear and 61F", + }.get(city.lower(), "Partly cloudy and 68F"), + } + + +def _build_agent_for_city(city: str): + tool_call_id = f"call-weather-{city.lower().replace(' ', '-')}" + model = FakeMessagesListChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "get_weather", + "args": {"city": city}, + "id": tool_call_id, + "type": "tool_call", + } + ], + ), + AIMessage(content=f"Finished processing weather lookup for {city}."), + ] + ) + return create_agent( + model=model, + tools=[get_weather], + middleware=[AgentControlMiddleware()], + system_prompt="Always use the available weather tool before replying.", + name="agent-control-langchain-middleware-smoke", + ) + + +def _run_scenario(prompt: str, city: str) -> None: + app = _build_agent_for_city(city) + result = app.invoke({"messages": [HumanMessage(content=prompt)]}) + + print("=" * 80) + print(f"User: {prompt}") + print(f"Final response: {result['messages'][-1].content}") + + tool_message = next( + message for message in reversed(result["messages"]) if isinstance(message, ToolMessage) + ) + print( + f"Raw tool message -> name={tool_message.name!r}, " + f"status={getattr(tool_message, 'status', None)!r}, " + f"content={tool_message.content!r}" + ) + + +def main() -> None: + """Run the example with one allowed and one blocked scenario.""" + agent_control.init( + agent_name=AGENT_NAME, + agent_description=AGENT_DESCRIPTION, + server_url=os.getenv("AGENT_CONTROL_URL"), + ) + + print("Running LangChain middleware smoke test...") + _run_scenario("What is the weather in Seattle?", "Seattle") + _run_scenario("What is the weather in Tehran?", "Tehran") + + +if __name__ == "__main__": + main() diff --git a/examples/langchain/setup_langchain_middleware_controls.py b/examples/langchain/setup_langchain_middleware_controls.py new file mode 100644 index 00000000..292b0a95 --- /dev/null +++ b/examples/langchain/setup_langchain_middleware_controls.py @@ -0,0 +1,108 @@ +"""Create controls for the LangChain middleware smoke test. + +This script prepares a single direct agent control for the +``langchain_middleware_smoke.py`` example. + +Run: + cd examples/langchain + uv run setup_langchain_middleware_controls.py +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +import httpx +from agent_control import Agent, AgentControlClient, agents, controls + +AGENT_NAME = "langchain-middleware-smoke" +SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") + +CONTROL_SPECS: list[tuple[str, dict[str, Any]]] = [ + ( + "langchain-middleware-block-city", + { + "description": "Block restricted cities before the get_weather tool runs.", + "enabled": True, + "execution": "server", + "scope": { + "step_types": ["tool"], + "step_names": ["get_weather"], + "stages": ["pre"], + }, + "selector": {"path": "input.city"}, + "evaluator": { + "name": "list", + "config": { + "values": ["Tehran", "Pyongyang"], + "logic": "any", + "match_on": "match", + "match_mode": "exact", + "case_sensitive": False, + }, + }, + "action": { + "decision": "deny", + "message": "That city is blocked by policy.", + }, + }, + ), +] + + +async def _ensure_control( + client: AgentControlClient, + name: str, + data: dict[str, Any], +) -> int: + """Create a control or update the existing definition.""" + try: + result = await controls.create_control(client, name=name, data=data) + return int(result["control_id"]) + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 409: + raise + + control_list = await controls.list_controls(client, name=name, limit=1) + existing = control_list.get("controls", []) + if not existing: + raise RuntimeError(f"Control '{name}' already exists but could not be listed.") + + control_id = int(existing[0]["id"]) + await controls.set_control_data(client, control_id, data) + return control_id + + +async def main() -> None: + """Register the example agent and ensure its controls exist.""" + async with AgentControlClient(base_url=SERVER_URL) as client: + await client.health_check() + + agent = Agent( + agent_name=AGENT_NAME, + agent_description="LangChain middleware smoke test using Agent Control", + ) + await agents.register_agent(client, agent, steps=[]) + + control_ids: list[int] = [] + for control_name, control_data in CONTROL_SPECS: + control_id = await _ensure_control(client, control_name, control_data) + control_ids.append(control_id) + print(f"Prepared control: {control_name} ({control_id})") + + for control_id in control_ids: + try: + await agents.add_agent_control(client, AGENT_NAME, control_id) + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 409: + raise + + print() + print("LangChain middleware smoke test is ready.") + print("Run: uv run langchain_middleware_smoke.py") + + +if __name__ == "__main__": + asyncio.run(main()) From 35b04ef28017b24aac7570653e42d882fb4810a0 Mon Sep 17 00:00:00 2001 From: Lev Neiman Date: Wed, 11 Mar 2026 18:33:53 -0700 Subject: [PATCH 5/5] docs: clarify example tool usage --- examples/langchain/langchain_middleware_smoke.py | 2 ++ examples/langchain/langgraph_toolnode_integration_smoke.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/langchain/langchain_middleware_smoke.py b/examples/langchain/langchain_middleware_smoke.py index 71fb9a51..dba76b12 100644 --- a/examples/langchain/langchain_middleware_smoke.py +++ b/examples/langchain/langchain_middleware_smoke.py @@ -2,6 +2,8 @@ This example proves the LangChain agent-factory middleware path works with ``AgentControlMiddleware`` and protects tool calls without using ``@control()``. +It passes a plain ``@tool``-decorated Python function directly to +``langchain.create_agent(...)``. Run: cd examples/langchain diff --git a/examples/langchain/langgraph_toolnode_integration_smoke.py b/examples/langchain/langgraph_toolnode_integration_smoke.py index cfc26f6a..4306b980 100644 --- a/examples/langchain/langgraph_toolnode_integration_smoke.py +++ b/examples/langchain/langgraph_toolnode_integration_smoke.py @@ -1,7 +1,9 @@ """Minimal LangGraph smoke test for the Agent Control ToolNode integration. This example proves the LangGraph ToolNode wrapper path works without using -``@control()`` on the underlying tool implementation. +``@control()`` on the underlying tool implementation. It passes a plain +``@tool``-decorated Python function directly to the Agent Control LangGraph +integration. Run: cd examples/langchain