Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 72 additions & 12 deletions sdks/python/src/agent_control/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,61 @@ class _ControlAdapter:
control: "ControlDefinition"


def _get_applicable_controls(
controls: list[_ControlAdapter],
request: EvaluationRequest,
*,
context: Literal["sdk", "server"],
) -> list[_ControlAdapter]:
"""Return parsed controls that apply to this request in the given context."""
applicable_controls = ControlEngine(
controls,
context=context,
).get_applicable_controls(request)
return cast(list[_ControlAdapter], applicable_controls)


def _has_applicable_prefiltered_server_controls(
server_control_payloads: list[dict[str, Any]],
request: EvaluationRequest,
) -> bool:
"""Return whether any partitioned server control applies to this request.

The caller is responsible for partitioning raw control payloads by
``execution`` before calling this helper. This function only inspects the
server-control subset and does not re-check ``execution`` itself.

If any server control payload cannot be parsed locally, this returns True so
the SDK still defers to the server for authoritative handling.
"""
parsed_server_controls: list[_ControlAdapter] = []

for control in server_control_payloads:
try:
control_def = ControlDefinition.model_validate(control["control"])
parsed_server_controls.append(
_ControlAdapter(
id=control["id"],
name=control["name"],
control=control_def,
)
)
except Exception:
# Preserve existing fail-open behavior for malformed server controls.
return True

if not parsed_server_controls:
return False

return bool(
_get_applicable_controls(
parsed_server_controls,
request,
context="server",
)
)


def _merge_results(
local_result: "EvaluationResponse",
server_result: "EvaluationResponse",
Expand Down Expand Up @@ -212,15 +267,15 @@ async def check_evaluation_with_local(
# Partition controls by local flag
local_controls: list[_ControlAdapter] = []
parse_errors: list[ControlMatch] = []
has_server_controls = False
server_control_payloads: list[dict[str, Any]] = []

for control in controls:
control_data = control.get("control", {})
execution = control_data.get("execution", "server")
is_local = execution == "sdk"

if not is_local:
has_server_controls = True
server_control_payloads.append(control)
continue

try:
Expand Down Expand Up @@ -272,6 +327,12 @@ async def check_evaluation_with_local(
)
)

request = EvaluationRequest(
agent_name=normalized_name,
step=step,
stage=stage,
)

def _with_parse_errors(result: EvaluationResult) -> EvaluationResult:
if not parse_errors:
return result
Expand All @@ -285,21 +346,20 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult:
non_matches=result.non_matches,
)

request = EvaluationRequest(
agent_name=normalized_name,
step=step,
stage=stage,
)

local_result: EvaluationResponse | None = None
if local_controls:
engine = ControlEngine(local_controls, context="sdk")
applicable_local_controls = _get_applicable_controls(
local_controls,
request,
context="sdk",
)
if applicable_local_controls:
engine = ControlEngine(applicable_local_controls, context="sdk")
local_result = await engine.process(request)

_emit_local_events(
local_result,
request,
local_controls,
applicable_local_controls,
trace_id,
span_id,
agent_name=event_agent_name,
Expand All @@ -317,7 +377,7 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult:
)
)

if has_server_controls:
if _has_applicable_prefiltered_server_controls(server_control_payloads, request):
request_payload = request.model_dump(mode="json", exclude_none=True)
headers: dict[str, str] = {}
if trace_id:
Expand Down
184 changes: 169 additions & 15 deletions sdks/python/tests/test_local_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,21 @@
"""

from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agent_control.client import AgentControlClient
from agent_control.evaluation import (
_merge_results,
check_evaluation_with_local,
)
from agent_control_models import (
ControlMatch,
EvaluationResponse,
EvaluationResult,
EvaluatorResult,
Step,
)

from agent_control.client import AgentControlClient
from agent_control.evaluation import (
_merge_results,
check_evaluation_with_local,
)


# =============================================================================
# Test Fixtures
# =============================================================================
Expand Down Expand Up @@ -55,26 +51,36 @@ def make_control_dict(
control_id: int,
name: str,
*,
enabled: bool = True,
execution: str = "server",
evaluator: str = "regex",
pattern: str = r"test",
action: str = "deny",
step_type: str = "llm",
stage: str = "pre",
step_names: list[str] | None = None,
step_name_regex: str | None = None,
path: str | None = None,
) -> dict[str, Any]:
"""Create a control dict like what initAgent returns."""
# Default path based on payload type
if path is None:
path = "input"

scope: dict[str, Any] = {"step_types": [step_type], "stages": [stage]}
if step_names is not None:
scope["step_names"] = step_names
if step_name_regex is not None:
scope["step_name_regex"] = step_name_regex

return {
"id": control_id,
"name": name,
"control": {
"description": f"Test control {name}",
"enabled": True,
"enabled": enabled,
"execution": execution,
"scope": {"step_types": [step_type], "stages": [stage]},
"scope": scope,
"selector": {"path": path},
"evaluator": {
"name": evaluator,
Expand All @@ -85,6 +91,15 @@ def make_control_dict(
}


NON_APPLICABLE_CONTROL_CASES = [
pytest.param({"enabled": False}, id="disabled"),
pytest.param({"stage": "post"}, id="stage_mismatch"),
pytest.param({"step_type": "tool"}, id="step_type_mismatch"),
pytest.param({"step_names": ["send_email"]}, id="step_name_mismatch"),
pytest.param({"step_name_regex": r"^send_.*$"}, id="step_name_regex_mismatch"),
]


# =============================================================================
# Test: _merge_results
# =============================================================================
Expand Down Expand Up @@ -254,6 +269,146 @@ async def test_server_only_controls_calls_server(self, agent_name, llm_payload):

assert result.is_safe is True

@pytest.mark.asyncio
@pytest.mark.parametrize(
"control_kwargs",
NON_APPLICABLE_CONTROL_CASES,
)
async def test_non_applicable_server_controls_do_not_call_server(
self,
agent_name,
llm_payload,
control_kwargs,
):
"""Given: Only server controls that do not apply to this invocation.

When: check_evaluation_with_local is called.
Then: The SDK skips the server evaluation request entirely.
"""
controls = [
make_control_dict(
1,
"server_ctrl",
execution="server",
**control_kwargs,
),
]

client = MagicMock(spec=AgentControlClient)
mock_response = MagicMock()
mock_response.json.return_value = {"is_safe": True, "confidence": 1.0}
mock_response.raise_for_status = MagicMock()
client.http_client = AsyncMock()
client.http_client.post = AsyncMock(return_value=mock_response)

result = await check_evaluation_with_local(
client=client,
agent_name=agent_name,
step=llm_payload,
stage="pre",
controls=controls,
)

client.http_client.post.assert_not_called()
assert result.is_safe is True
assert result.confidence == 1.0
assert result.matches is None

@pytest.mark.asyncio
@pytest.mark.parametrize(
"control_kwargs",
NON_APPLICABLE_CONTROL_CASES,
)
async def test_non_applicable_local_controls_skip_local_evaluation(
self,
agent_name,
llm_payload,
control_kwargs,
):
"""Given: Only local controls that do not apply to this invocation.

When: check_evaluation_with_local is called.
Then: The SDK skips local evaluation and returns a no-op safe result.
"""
controls = [
make_control_dict(
1,
"local_ctrl",
execution="sdk",
**control_kwargs,
),
]

client = MagicMock(spec=AgentControlClient)
client.http_client = AsyncMock()
client.http_client.post = AsyncMock()

with patch(
"agent_control.evaluation.ControlEngine.process",
side_effect=AssertionError("local evaluation should not run"),
) as mock_process:
result = await check_evaluation_with_local(
client=client,
agent_name=agent_name,
step=llm_payload,
stage="pre",
controls=controls,
)

mock_process.assert_not_called()
client.http_client.post.assert_not_called()
assert result.is_safe is True
assert result.confidence == 1.0
assert result.matches is None

@pytest.mark.asyncio
async def test_non_applicable_local_controls_skip_local_but_still_call_server(
self,
agent_name,
llm_payload,
):
"""Given: A non-applicable local control and an applicable server control.

When: check_evaluation_with_local is called.
Then: Local evaluation is skipped, but the applicable server control still runs.
"""
controls = [
make_control_dict(
1,
"local_ctrl",
execution="sdk",
step_names=["send_email"],
),
make_control_dict(
2,
"server_ctrl",
execution="server",
),
]

client = MagicMock(spec=AgentControlClient)
mock_response = MagicMock()
mock_response.json.return_value = {"is_safe": True, "confidence": 1.0}
mock_response.raise_for_status = MagicMock()
client.http_client = AsyncMock()
client.http_client.post = AsyncMock(return_value=mock_response)

with patch(
"agent_control.evaluation.ControlEngine.process",
side_effect=AssertionError("local evaluation should not run"),
) as mock_process:
result = await check_evaluation_with_local(
client=client,
agent_name=agent_name,
step=llm_payload,
stage="pre",
controls=controls,
)

mock_process.assert_not_called()
client.http_client.post.assert_called_once()
assert result.is_safe is True

@pytest.mark.asyncio
async def test_local_deny_short_circuits(self, agent_name, llm_payload):
"""Local deny should return immediately without calling server."""
Expand Down Expand Up @@ -634,7 +789,8 @@ async def test_local_control_with_agent_scoped_evaluator_raises(self, agent_name
async def test_server_control_with_missing_evaluator_allowed(self, agent_name, llm_payload):
"""Test that server control with unavailable evaluator is allowed (server handles it).

Given: A server control (execution="server") referencing an evaluator that doesn't exist locally
Given: A server control (execution="server") referencing an evaluator that
doesn't exist locally
When: check_evaluation_with_local is called
Then: No error, server is called to handle it
"""
Expand Down Expand Up @@ -755,8 +911,6 @@ async def test_local_evaluation_includes_steering_context(self, agent_name, llm_
Then: Response includes steering_context field in matches
Coverage: Lines 275, 280, 298-301 in control_decorators.py
"""
from agent_control_models.controls import SteeringContext

controls = [
{
"id": 1,
Expand Down
Loading