diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e1c80960..02c45e7c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,7 +62,7 @@ jobs: - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: - files: coverage-engine.xml,coverage-server.xml,coverage-sdk.xml + files: coverage-models.xml,coverage-engine.xml,coverage-server.xml,coverage-sdk.xml fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} diff --git a/Makefile b/Makefile index 40f3ec5f..b858a3ed 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help sync openapi-spec openapi-spec-check test test-extras test-all test-models test-sdk lint lint-fix typecheck check build build-models build-server build-sdk publish publish-models publish-server publish-sdk hooks-install hooks-uninstall prepush evaluators-test evaluators-lint evaluators-lint-fix evaluators-typecheck evaluators-build galileo-test galileo-lint galileo-lint-fix galileo-typecheck galileo-build sdk-ts-generate sdk-ts-overlay-test sdk-ts-name-check sdk-ts-generate-check sdk-ts-build sdk-ts-test sdk-ts-lint sdk-ts-typecheck sdk-ts-release-check sdk-ts-publish-dry-run sdk-ts-publish +.PHONY: help sync openapi-spec openapi-spec-check test test-extras test-all models-test test-models test-sdk lint lint-fix typecheck check build build-models build-server build-sdk publish publish-models publish-server publish-sdk hooks-install hooks-uninstall prepush evaluators-test evaluators-lint evaluators-lint-fix evaluators-typecheck evaluators-build galileo-test galileo-lint galileo-lint-fix galileo-typecheck galileo-build sdk-ts-generate sdk-ts-overlay-test sdk-ts-name-check sdk-ts-generate-check sdk-ts-build sdk-ts-test sdk-ts-lint sdk-ts-typecheck sdk-ts-release-check sdk-ts-publish-dry-run sdk-ts-publish # Workspace package names PACK_MODELS := agent-control-models @@ -31,7 +31,8 @@ help: @echo " make openapi-spec-check - verify OpenAPI generation succeeds" @echo "" @echo "Test:" - @echo " make test - run tests for core packages (server, engine, sdk, evaluators)" + @echo " make test - run tests for core packages (models, server, engine, sdk, evaluators)" + @echo " make models-test - run shared model tests with coverage" @echo " make test-extras - run tests for contrib evaluators (galileo, etc.)" @echo " make test-all - run all tests (core + extras)" @echo " make sdk-ts-test - run TypeScript SDK tests" @@ -81,7 +82,12 @@ openapi-spec-check: openapi-spec # Test # --------------------------- -test: server-test engine-test sdk-test evaluators-test +test: models-test server-test engine-test sdk-test evaluators-test + +models-test: + cd $(MODELS_DIR) && uv run pytest --cov=src --cov-report=xml:../coverage-models.xml -q + +test-models: models-test # Run tests for contrib evaluators (not included in default test target) test-extras: galileo-test diff --git a/README.md b/README.md index dfd528fe..a57c2263 100644 --- a/README.md +++ b/README.md @@ -235,10 +235,12 @@ async def setup(): "enabled": True, "execution": "server", "scope": {"stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": {"pattern": r"\b\d{3}-\d{2}-\d{4}\b"}, + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\b\d{3}-\d{2}-\d{4}\b"}, + }, }, "action": {"decision": "deny"}, }, @@ -257,6 +259,8 @@ async def setup(): asyncio.run(setup()) ``` +Controls now store leaf `selector` and `evaluator` definitions under `condition`, which also enables composite `and`, `or`, and `not` trees. + Now, test your agent: ```bash diff --git a/engine/src/agent_control_engine/core.py b/engine/src/agent_control_engine/core.py index c109cc80..b89da3cb 100644 --- a/engine/src/agent_control_engine/core.py +++ b/engine/src/agent_control_engine/core.py @@ -14,6 +14,7 @@ import re2 from agent_control_evaluators import get_evaluator_instance from agent_control_models import ( + ConditionNode, ControlDefinition, ControlMatch, EvaluationRequest, @@ -54,11 +55,18 @@ class _EvalTask: """Internal container for evaluation task context.""" item: ControlWithIdentity - data: Any task: asyncio.Task[None] | None = None result: EvaluatorResult | None = None +@dataclass +class _ConditionEvaluation: + """Internal result for recursive condition evaluation.""" + + result: EvaluatorResult + trace: dict[str, Any] + + class ControlEngine: """Executes controls against requests with parallel evaluation. @@ -79,6 +87,276 @@ def __init__( self.controls = controls self.context = context + @staticmethod + def _truncated_message(message: str | None) -> str | None: + """Truncate long evaluator messages in condition traces.""" + if not message: + return None + if len(message) <= 200: + return message + return f"{message[:197]}..." + + @staticmethod + def _format_exception(error: BaseException) -> str: + """Format exceptions consistently for result.error fields.""" + return f"{type(error).__name__}: {error}" + + @staticmethod + def _build_error_result( + error: str, + *, + message_prefix: str = "Evaluation failed", + ) -> EvaluatorResult: + """Create a failed evaluator result from an internal error string.""" + return EvaluatorResult( + matched=False, + confidence=0.0, + message=f"{message_prefix}: {error}", + error=error, + ) + + def _skipped_trace(self, node: ConditionNode, reason: str) -> dict[str, Any]: + """Build an unevaluated trace subtree for short-circuited branches.""" + trace: dict[str, Any] = { + "type": node.kind(), + "evaluated": False, + "matched": None, + "short_circuit_reason": reason, + } + if node.is_leaf(): + leaf_parts = node.leaf_parts() + if leaf_parts is None: + raise ValueError("Leaf condition must contain selector and evaluator") + selector, evaluator = leaf_parts + trace["selector_path"] = selector.path + trace["evaluator_name"] = evaluator.name + trace["confidence"] = None + trace["error"] = None + return trace + + trace["children"] = [ + self._skipped_trace(child, reason) for child in node.children_in_order() + ] + return trace + + async def _evaluate_leaf( + self, + item: ControlWithIdentity, + node: ConditionNode, + request: EvaluationRequest, + semaphore: asyncio.Semaphore, + ) -> _ConditionEvaluation: + """Evaluate a leaf selector/evaluator pair. + + The shared semaphore limits concurrent leaf evaluator executions across + the entire engine run. Composite conditions evaluate serially, so a + single control only holds one semaphore slot at a time, but multi-leaf + controls may acquire and release that shared slot more than once while + traversing their tree. + """ + leaf_parts = node.leaf_parts() + if leaf_parts is None: + raise ValueError("Leaf condition must contain selector and evaluator") + selector, evaluator_spec = leaf_parts + + selector_path = selector.path or "*" + data = select_data(request.step, selector_path) + + try: + async with semaphore: + evaluator = get_evaluator_instance(evaluator_spec) + timeout = evaluator.get_timeout_seconds() + if timeout <= 0: + timeout = DEFAULT_EVALUATOR_TIMEOUT + + result = await asyncio.wait_for( + evaluator.evaluate(data), + timeout=timeout, + ) + except TimeoutError: + error_msg = f"TimeoutError: Evaluator exceeded {timeout}s timeout" + logger.warning( + "Evaluator timeout for control '%s' (evaluator: %s): %s", + item.name, + evaluator_spec.name, + error_msg, + exc_info=True, + ) + result = self._build_error_result(error_msg) + except Exception as e: + error_msg = self._format_exception(e) + logger.error( + "Evaluator error for control '%s' (evaluator: %s): %s", + item.name, + evaluator_spec.name, + error_msg, + exc_info=True, + ) + result = self._build_error_result(error_msg) + + trace = { + "type": "leaf", + "evaluated": True, + "matched": result.matched, + "selector_path": selector_path, + "evaluator_name": evaluator_spec.name, + "confidence": result.confidence, + "error": result.error, + "message": self._truncated_message(result.message), + } + metadata = dict(result.metadata or {}) + metadata["condition_trace"] = trace + return _ConditionEvaluation( + result=result.model_copy(update={"metadata": metadata}), + trace=trace, + ) + + def _build_composite_result( + self, + *, + matched: bool, + confidence: float, + trace: dict[str, Any], + error: str | None = None, + ) -> EvaluatorResult: + """Create a composite evaluator result with a condition trace.""" + if error is not None: + return EvaluatorResult( + matched=False, + confidence=0.0, + message=f"Condition evaluation failed: {error}", + metadata={"condition_trace": trace}, + error=error, + ) + + message = "Condition tree matched" if matched else "Condition tree did not match" + return EvaluatorResult( + matched=matched, + confidence=confidence, + message=message, + metadata={"condition_trace": trace}, + ) + + async def _evaluate_condition( + self, + item: ControlWithIdentity, + node: ConditionNode, + request: EvaluationRequest, + semaphore: asyncio.Semaphore, + ) -> _ConditionEvaluation: + """Evaluate a recursive condition tree.""" + if node.is_leaf(): + return await self._evaluate_leaf(item, node, request, semaphore) + + kind = node.kind() + children = node.children_in_order() + child_evaluations: list[_ConditionEvaluation] = [] + + if kind == "not": + child_eval = await self._evaluate_condition(item, children[0], request, semaphore) + trace = { + "type": "not", + "evaluated": True, + "matched": None if child_eval.result.error else (not child_eval.result.matched), + "children": [child_eval.trace], + } + if child_eval.result.error: + return _ConditionEvaluation( + result=self._build_composite_result( + matched=False, + confidence=0.0, + trace=trace, + error=child_eval.result.error, + ), + trace=trace, + ) + + result = self._build_composite_result( + matched=not child_eval.result.matched, + confidence=child_eval.result.confidence, + trace=trace, + ) + return _ConditionEvaluation(result=result, trace=trace) + + for index, child in enumerate(children): + child_eval = await self._evaluate_condition(item, child, request, semaphore) + child_evaluations.append(child_eval) + + if child_eval.result.error: + remaining = children[index + 1 :] + trace = { + "type": kind, + "evaluated": True, + "matched": False, + "children": [ + evaluation.trace for evaluation in child_evaluations + ] + + [self._skipped_trace(rest, "error") for rest in remaining], + "short_circuit_reason": "error", + } + return _ConditionEvaluation( + result=self._build_composite_result( + matched=False, + confidence=0.0, + trace=trace, + error=child_eval.result.error, + ), + trace=trace, + ) + + should_short_circuit = ( + kind == "and" and not child_eval.result.matched + ) or (kind == "or" and child_eval.result.matched) + if should_short_circuit: + remaining = children[index + 1 :] + matched = child_eval.result.matched if kind == "or" else False + trace = { + "type": kind, + "evaluated": True, + "matched": matched, + "children": [ + evaluation.trace for evaluation in child_evaluations + ] + + [ + self._skipped_trace( + rest, + "or_matched" if kind == "or" else "and_failed", + ) + for rest in remaining + ], + "short_circuit_reason": ( + "or_matched" if kind == "or" else "and_failed" + ), + } + confidence = min( + evaluation.result.confidence for evaluation in child_evaluations + ) + result = self._build_composite_result( + matched=matched, + confidence=confidence, + trace=trace, + ) + return _ConditionEvaluation(result=result, trace=trace) + + confidence = min(evaluation.result.confidence for evaluation in child_evaluations) + matched = all( + evaluation.result.matched for evaluation in child_evaluations + ) if kind == "and" else any( + evaluation.result.matched for evaluation in child_evaluations + ) + trace = { + "type": kind, + "evaluated": True, + "matched": matched, + "children": [evaluation.trace for evaluation in child_evaluations], + } + result = self._build_composite_result( + matched=matched, + confidence=confidence, + trace=trace, + ) + return _ConditionEvaluation(result=result, trace=trace) + def get_applicable_controls( self, request: EvaluationRequest, @@ -169,73 +447,45 @@ async def process(self, request: EvaluationRequest) -> EvaluationResponse: ) # Prepare evaluation tasks - eval_tasks: list[_EvalTask] = [] - for item in applicable: - control_def = item.control - sel_path = control_def.selector.path or "*" - data = select_data(request.step, sel_path) - eval_tasks.append(_EvalTask(item=item, data=data)) + eval_tasks: list[_EvalTask] = [_EvalTask(item=item) for item in applicable] # Run evaluations in parallel with cancel-on-deny matches: list[ControlMatch] = [] is_safe = True deny_found = asyncio.Event() + # The concurrency cap applies to visited leaf evaluator executions, not + # whole top-level controls. Composite trees are still walked serially. semaphore = asyncio.Semaphore(MAX_CONCURRENT_EVALUATIONS) async def evaluate_control(eval_task: _EvalTask) -> None: """Evaluate a single control, respecting cancellation and timeout.""" - async with semaphore: - try: - evaluator = get_evaluator_instance(eval_task.item.control.evaluator) - # Use evaluator's timeout or fall back to default - timeout = evaluator.get_timeout_seconds() - if timeout <= 0: - timeout = DEFAULT_EVALUATOR_TIMEOUT - - eval_task.result = await asyncio.wait_for( - evaluator.evaluate(eval_task.data), - timeout=timeout, - ) + try: + evaluation = await self._evaluate_condition( + eval_task.item, + eval_task.item.control.condition, + request, + semaphore, + ) + eval_task.result = evaluation.result - # Signal if this is a deny match - only deny should trigger cancellation - # to preserve deny-first semantics - if ( - eval_task.result.matched - and eval_task.item.control.action.decision == "deny" - ): - deny_found.set() - except asyncio.CancelledError: - # Task was cancelled due to another deny - that's OK - raise - except TimeoutError: - # Evaluator timed out - error_msg = f"TimeoutError: Evaluator exceeded {timeout}s timeout" - logger.warning( - f"Evaluator timeout for control '{eval_task.item.name}' " - f"(evaluator: {eval_task.item.control.evaluator.name}): {error_msg}", - exc_info=True, - ) - eval_task.result = EvaluatorResult( - matched=False, - confidence=0.0, - message=f"Evaluation failed: {error_msg}", - error=error_msg, - ) - except Exception as e: - # Evaluation error - fail open but mark as error - # The error field signals to callers that this was not a real evaluation - error_msg = f"{type(e).__name__}: {e}" - logger.error( - f"Evaluator error for control '{eval_task.item.name}' " - f"(evaluator: {eval_task.item.control.evaluator.name}): {error_msg}", - exc_info=True, - ) - eval_task.result = EvaluatorResult( - matched=False, - confidence=0.0, - message=f"Evaluation failed: {error_msg}", - error=error_msg, - ) + if ( + eval_task.result.matched + and eval_task.item.control.action.decision == "deny" + ): + deny_found.set() + except asyncio.CancelledError: + raise + except Exception as error: + error_msg = self._format_exception(error) + logger.exception( + "Unexpected condition evaluation error for control '%s': %s", + eval_task.item.name, + error_msg, + ) + eval_task.result = self._build_error_result( + error_msg, + message_prefix="Condition evaluation failed", + ) # Create and start all tasks for eval_task in eval_tasks: diff --git a/engine/tests/test_core.py b/engine/tests/test_core.py index d5e7418f..186ad243 100644 --- a/engine/tests/test_core.py +++ b/engine/tests/test_core.py @@ -17,13 +17,11 @@ from agent_control_models import ( ControlAction, ControlDefinition, - ControlScope, - ControlSelector, EvaluationRequest, EvaluatorResult, EvaluatorSpec, - Step, SteeringContext, + Step, ) from pydantic import BaseModel @@ -210,11 +208,13 @@ def make_control( enabled=True, execution=execution, scope=scope, - selector=selector or {"path": "*"}, - evaluator=EvaluatorSpec( - name=evaluator, - config={"value": config_value}, - ), + condition={ + "selector": selector or {"path": "*"}, + "evaluator": EvaluatorSpec( + name=evaluator, + config={"value": config_value}, + ), + }, action=( ControlAction(decision=action, steering_context=steering_context) if steering_context @@ -632,6 +632,41 @@ async def test_error_with_log_action_fails_open(self): # Error should be captured assert result.errors is not None + @pytest.mark.asyncio + async def test_unexpected_condition_error_is_captured_as_control_error( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Unexpected traversal failures should surface as normal control errors.""" + # Given: a deny control whose condition traversal raises unexpectedly + controls = [ + make_control(1, "unexpected_error", "test-allow", action="deny", config_value="ok"), + ] + engine = ControlEngine(controls) + + async def raise_unexpected(*_args: object, **_kwargs: object) -> object: + raise RuntimeError("unexpected traversal bug") + + monkeypatch.setattr(engine, "_evaluate_condition", raise_unexpected) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: the control is reported as an error instead of being silently dropped + assert result.is_safe is False + assert result.confidence == 0.0 + assert result.matches is None + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].control_name == "unexpected_error" + assert result.errors[0].result.error is not None + assert "unexpected traversal bug" in result.errors[0].result.error + @pytest.mark.asyncio async def test_missing_evaluator_error_sets_error_field(self): """Test that missing evaluator error sets error field in result. @@ -1063,8 +1098,10 @@ def test_invalid_step_name_regex_rejected(self): enabled=True, execution="server", scope={"step_types": ["tool"], "stages": ["pre"], "step_name_regex": "("}, - selector={"path": "input"}, - evaluator=EvaluatorSpec(name="test-allow", config={"value": "x"}), + condition={ + "selector": {"path": "input"}, + "evaluator": EvaluatorSpec(name="test-allow", config={"value": "x"}), + }, action={"decision": "log"}, ) @@ -1100,11 +1137,13 @@ async def test_evaluator_timeout_is_enforced(self): enabled=True, execution="server", scope={"step_types": ["llm"], "stages": ["pre"]}, - selector={"path": "input"}, - evaluator=EvaluatorSpec( - name="test-timeout", - config={"value": "t1", "timeout_ms": 100}, - ), + condition={ + "selector": {"path": "input"}, + "evaluator": EvaluatorSpec( + name="test-timeout", + config={"value": "t1", "timeout_ms": 100}, + ), + }, action={"decision": "deny"}, ), ) @@ -1158,11 +1197,13 @@ async def test_timeout_does_not_affect_fast_evaluators(self): enabled=True, execution="server", scope={"step_types": ["llm"], "stages": ["pre"]}, - selector={"path": "input"}, - evaluator=EvaluatorSpec( - name="test-timeout", - config={"value": "slow", "timeout_ms": 100}, - ), + condition={ + "selector": {"path": "input"}, + "evaluator": EvaluatorSpec( + name="test-timeout", + config={"value": "slow", "timeout_ms": 100}, + ), + }, action={"decision": "log"}, # Log, not deny - so fails open ), ), @@ -1266,6 +1307,317 @@ async def evaluate(self, data: Any) -> EvaluatorResult: assert _max_concurrent <= 2, f"Expected max 2 concurrent, got {_max_concurrent}" +# ============================================================================= +# Test: Recursive Condition Trees +# ============================================================================= + + +class TestConditionTrees: + """Tests for recursive condition evaluation and trace metadata.""" + + @pytest.fixture(autouse=True) + def register_error_evaluator(self): + """Register ErrorEvaluator for these tests.""" + try: + register_evaluator(ErrorEvaluator) + except ValueError: + pass + + @pytest.mark.asyncio + async def test_or_short_circuit_records_skipped_trace(self): + """A matching OR child should short-circuit later children and mark them skipped.""" + # Given: an OR tree whose first child matches + controls = [ + MockControlWithIdentity( + id=1, + name="or_short_circuit", + control=ControlDefinition( + description="Short-circuit OR", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "or": [ + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-deny", "config": {"value": "match"}}, + }, + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-slow", "config": {"value": "skip"}}, + }, + ] + }, + action={"decision": "log"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: later children are short-circuited and marked skipped in the trace + assert result.matches is not None + assert len(result.matches) == 1 + assert "slow:skip:start" not in _execution_log + + trace = result.matches[0].result.metadata["condition_trace"] + assert trace["type"] == "or" + assert trace["matched"] is True + assert trace["short_circuit_reason"] == "or_matched" + assert trace["children"][0]["evaluated"] is True + assert trace["children"][0]["matched"] is True + assert trace["children"][1]["evaluated"] is False + assert trace["children"][1]["matched"] is None + assert trace["children"][1]["short_circuit_reason"] == "or_matched" + + @pytest.mark.asyncio + async def test_and_condition_all_children_match_records_full_trace(self): + """A fully-evaluated AND tree should record every child and produce a match.""" + # Given: an AND tree where every leaf evaluator matches + controls = [ + MockControlWithIdentity( + id=1, + name="and_all_match", + control=ControlDefinition( + description="All AND children match", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "and": [ + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-deny", "config": {"value": "first"}}, + }, + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-deny", "config": {"value": "second"}}, + }, + ] + }, + action={"decision": "log"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: the control matches and every child appears as evaluated in the trace + assert result.matches is not None + assert len(result.matches) == 1 + trace = result.matches[0].result.metadata["condition_trace"] + assert trace["type"] == "and" + assert trace["matched"] is True + assert "short_circuit_reason" not in trace + assert len(trace["children"]) == 2 + assert all(child["evaluated"] is True for child in trace["children"]) + assert "deny:first:end" in _execution_log + assert "deny:second:end" in _execution_log + + @pytest.mark.asyncio + async def test_not_condition_inverts_child_result(self): + """NOT should invert the child match result while preserving trace structure.""" + # Given: a NOT tree whose child returns non-match + controls = [ + MockControlWithIdentity( + id=1, + name="not_condition", + control=ControlDefinition( + description="Invert non-match", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "not": { + "selector": {"path": "input"}, + "evaluator": {"name": "test-allow", "config": {"value": "child"}}, + } + }, + action={"decision": "log"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: the NOT node inverts the child result and preserves trace structure + assert result.matches is not None + trace = result.matches[0].result.metadata["condition_trace"] + assert trace["type"] == "not" + assert trace["matched"] is True + assert len(trace["children"]) == 1 + assert trace["children"][0]["type"] == "leaf" + assert trace["children"][0]["matched"] is False + + @pytest.mark.asyncio + async def test_not_condition_propagates_child_error_trace(self): + """NOT should surface child evaluator failures as composite errors.""" + # Given: a NOT tree whose child evaluator raises an error + controls = [ + MockControlWithIdentity( + id=1, + name="not_error", + control=ControlDefinition( + description="Invert errored child", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "not": { + "selector": {"path": "input"}, + "evaluator": {"name": "test-error", "config": {"value": "boom"}}, + } + }, + action={"decision": "log"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: the composite returns an error result and preserves the child trace + assert result.errors is not None + assert len(result.errors) == 1 + trace = result.errors[0].result.metadata["condition_trace"] + assert trace["type"] == "not" + assert trace["matched"] is None + assert len(trace["children"]) == 1 + assert trace["children"][0]["evaluated"] is True + assert "Intentional error from boom" in trace["children"][0]["error"] + + @pytest.mark.asyncio + async def test_or_condition_all_children_non_match_records_full_trace(self): + """A fully-evaluated OR tree should record every child and produce a non-match.""" + # Given: an OR tree where every leaf evaluator returns non-match + controls = [ + MockControlWithIdentity( + id=1, + name="or_all_non_match", + control=ControlDefinition( + description="All OR children miss", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "or": [ + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-allow", "config": {"value": "first"}}, + }, + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-allow", "config": {"value": "second"}}, + }, + ] + }, + action={"decision": "log"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: the control is recorded as a non-match and every child is evaluated + assert result.non_matches is not None + assert len(result.non_matches) == 1 + trace = result.non_matches[0].result.metadata["condition_trace"] + assert trace["type"] == "or" + assert trace["matched"] is False + assert "short_circuit_reason" not in trace + assert len(trace["children"]) == 2 + assert all(child["evaluated"] is True for child in trace["children"]) + assert "allow:first:end" in _execution_log + assert "allow:second:end" in _execution_log + + @pytest.mark.asyncio + async def test_and_error_records_skipped_children_in_trace(self): + """Errors in composite conditions should preserve trace context for skipped branches.""" + # Given: an AND tree whose first child evaluator errors + controls = [ + MockControlWithIdentity( + id=1, + name="and_error", + control=ControlDefinition( + description="Error in AND", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "and": [ + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-error", "config": {"value": "boom"}}, + }, + { + "selector": {"path": "input"}, + "evaluator": {"name": "test-slow", "config": {"value": "skip"}}, + }, + ] + }, + action={"decision": "log"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing the request + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="test", output=None), + stage="pre", + ) + result = await engine.process(request) + + # Then: the trace preserves the erroring child and marks remaining children skipped + assert result.errors is not None + assert len(result.errors) == 1 + assert "slow:skip:start" not in _execution_log + + trace = result.errors[0].result.metadata["condition_trace"] + assert trace["type"] == "and" + assert trace["matched"] is False + assert trace["short_circuit_reason"] == "error" + assert trace["children"][0]["evaluated"] is True + assert "Intentional error from boom" in trace["children"][0]["error"] + assert trace["children"][1]["evaluated"] is False + assert trace["children"][1]["short_circuit_reason"] == "error" + + # ============================================================================= # Test: Context Filtering (local vs server) # ============================================================================= @@ -1304,11 +1656,13 @@ def make_control_with_execution( enabled=True, execution=execution, scope=scope, - selector={"path": path}, - evaluator=EvaluatorSpec( - name=evaluator, - config={"value": config_value}, - ), + condition={ + "selector": {"path": path}, + "evaluator": EvaluatorSpec( + name=evaluator, + config={"value": config_value}, + ), + }, action={"decision": action}, ), ) @@ -1802,10 +2156,12 @@ class MockControl: "step_types": ["llm"], "step_name_regex": "[invalid(regex", # Invalid regex pattern }, - "selector": {"path": "input"}, - "evaluator": { - "name": "test-allow", - "config": {"value": "test"}, + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "test-allow", + "config": {"value": "test"}, + }, }, "action": {"decision": "deny"}, } diff --git a/evaluators/contrib/cisco/README.md b/evaluators/contrib/cisco/README.md index 6461d5b9..6e5867bf 100644 --- a/evaluators/contrib/cisco/README.md +++ b/evaluators/contrib/cisco/README.md @@ -79,15 +79,17 @@ Example using `messages_strategy: "history"` (for inputs that already have a `me "enabled": true, "execution": "server", "scope": { "step_types": ["llm"], "stages": ["pre", "post"] }, - "selector": { "path": "input" }, - "evaluator": { - "name": "cisco.ai_defense", - "config": { - "api_key_env": "AI_DEFENSE_API_KEY", - "region": "us", - "timeout_ms": 15000, - "on_error": "allow", - "messages_strategy": "history" + "condition": { + "selector": { "path": "input" }, + "evaluator": { + "name": "cisco.ai_defense", + "config": { + "api_key_env": "AI_DEFENSE_API_KEY", + "region": "us", + "timeout_ms": 15000, + "on_error": "allow", + "messages_strategy": "history" + } } }, "action": { "decision": "deny" }, @@ -101,16 +103,18 @@ Example using `messages_strategy: "history"` (for inputs that already have a `me "enabled": true, "execution": "server", "scope": { "step_types": ["llm"], "stages": ["pre", "post"] }, - "selector": { "path": "input" }, - "evaluator": { - "name": "cisco.ai_defense", - "config": { - "api_key_env": "AI_DEFENSE_API_KEY", - "region": "us", - "timeout_ms": 15000, - "on_error": "allow", - "messages_strategy": "single", - "payload_field": "input" + "condition": { + "selector": { "path": "input" }, + "evaluator": { + "name": "cisco.ai_defense", + "config": { + "api_key_env": "AI_DEFENSE_API_KEY", + "region": "us", + "timeout_ms": 15000, + "on_error": "allow", + "messages_strategy": "single", + "payload_field": "input" + } } }, "action": { "decision": "deny" }, diff --git a/evaluators/contrib/cisco/pyproject.toml b/evaluators/contrib/cisco/pyproject.toml index e53b7d89..a9fc1091 100644 --- a/evaluators/contrib/cisco/pyproject.toml +++ b/evaluators/contrib/cisco/pyproject.toml @@ -41,4 +41,3 @@ select = ["E", "F", "I"] [tool.uv.sources] agent-control-evaluators = { path = "../../builtin", editable = true } agent-control-models = { path = "../../../models", editable = true } - diff --git a/examples/agent_control_demo/setup_controls.py b/examples/agent_control_demo/setup_controls.py index 77b56e72..09811fcb 100644 --- a/examples/agent_control_demo/setup_controls.py +++ b/examples/agent_control_demo/setup_controls.py @@ -119,13 +119,15 @@ async def create_regex_control(client: AgentControlClient) -> int: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, # Check AFTER - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b", # SSN pattern - "flags": [] - } + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b", # SSN pattern + "flags": [] + } + }, }, "action": {"decision": "deny"}, "tags": ["pii", "ssn", "output-filter"] @@ -133,7 +135,7 @@ async def create_regex_control(client: AgentControlClient) -> int: print(f"Creating control: block-ssn-output") print(f" Type: Regex") - print(f" Pattern: {control_definition['evaluator']['config']['pattern']}") + print(f" Pattern: {control_definition['condition']['evaluator']['config']['pattern']}") print(f" Stages: {', '.join(control_definition['scope']['stages'])}") print(f" Action: {control_definition['action']['decision']}") @@ -151,16 +153,18 @@ async def create_list_control(client: AgentControlClient) -> int: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, # Check BEFORE - "selector": {"path": "input"}, - "evaluator": { - "name": "list", - "config": { - "values": ["DROP", "DELETE", "TRUNCATE", "ALTER", "GRANT"], - "logic": "any", # Block if ANY keyword is found - "match_on": "match", - "match_mode": "contains", # Substring/keyword matching - "case_sensitive": False - } + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "list", + "config": { + "values": ["DROP", "DELETE", "TRUNCATE", "ALTER", "GRANT"], + "logic": "any", # Block if ANY keyword is found + "match_on": "match", + "match_mode": "contains", # Substring/keyword matching + "case_sensitive": False + } + }, }, "action": {"decision": "deny"}, "tags": ["sql-injection", "input-filter", "security"] @@ -168,8 +172,8 @@ async def create_list_control(client: AgentControlClient) -> int: print(f"Creating control: block-dangerous-sql") print(f" Type: List") - print(f" Values: {control_definition['evaluator']['config']['values']}") - print(f" Logic: {control_definition['evaluator']['config']['logic']}") + print(f" Values: {control_definition['condition']['evaluator']['config']['values']}") + print(f" Logic: {control_definition['condition']['evaluator']['config']['logic']}") print(f" Stages: {', '.join(control_definition['scope']['stages'])}") print(f" Action: {control_definition['action']['decision']}") @@ -223,7 +227,12 @@ async def list_agent_controls(client: AgentControlClient, agent_name: str) -> li print(f" ID: {ctrl.get('id')}") ctrl_def = ctrl.get("control", {}) print(f" Enabled: {ctrl_def.get('enabled', True)}") - print(f" Type: {ctrl_def.get('evaluator', {}).get('type', 'unknown')}") + evaluator_name = ( + ctrl_def.get("condition", {}) + .get("evaluator", {}) + .get("name", "unknown") + ) + print(f" Evaluator: {evaluator_name}") scope = ctrl_def.get("scope", {}) or {} stages = scope.get("stages", []) stage_label = ", ".join(stages) if stages else "unknown" @@ -250,19 +259,21 @@ async def update_control(client: AgentControlClient, control_id: int) -> None: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "list", - "config": { - "values": [ - "DROP", "DELETE", "TRUNCATE", "ALTER", "GRANT", - "REVOKE", "EXECUTE", "SHUTDOWN", "BACKUP" # More keywords! - ], - "logic": "any", - "match_on": "match", - "match_mode": "contains", # Substring/keyword matching - "case_sensitive": False - } + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "list", + "config": { + "values": [ + "DROP", "DELETE", "TRUNCATE", "ALTER", "GRANT", + "REVOKE", "EXECUTE", "SHUTDOWN", "BACKUP" # More keywords! + ], + "logic": "any", + "match_on": "match", + "match_mode": "contains", # Substring/keyword matching + "case_sensitive": False + } + }, }, "action": {"decision": "deny"}, "tags": ["sql-injection", "input-filter", "security", "updated"] @@ -298,8 +309,10 @@ async def get_control_data(client: AgentControlClient, control_id: int) -> dict: print(f"✓ Retrieved control {control_id}:") print(f" Description: {data.get('description', 'N/A')}") - print(f" Evaluator Type: {data.get('evaluator', {}).get('type', 'N/A')}") - print(f" Values: {data.get('evaluator', {}).get('config', {}).get('values', [])}") + condition = data.get("condition", {}) + evaluator = condition.get("evaluator", {}) + print(f" Evaluator: {evaluator.get('name', 'N/A')}") + print(f" Values: {evaluator.get('config', {}).get('values', [])}") print(f" Tags: {data.get('tags', [])}") return data diff --git a/examples/agent_control_demo/update_controls.py b/examples/agent_control_demo/update_controls.py index 57d16606..5afcd63b 100644 --- a/examples/agent_control_demo/update_controls.py +++ b/examples/agent_control_demo/update_controls.py @@ -58,13 +58,15 @@ async def allow_ssn(client: AgentControlClient, control_id: int) -> None: "enabled": False, # Disabled = SSNs allowed "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b", - "flags": [] - } + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b", + "flags": [] + } + }, }, "action": {"decision": "deny"}, "tags": ["pii", "ssn", "output-filter", "disabled"] @@ -98,13 +100,15 @@ async def block_ssn(client: AgentControlClient, control_id: int) -> None: "enabled": True, # Enabled = SSNs blocked "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b", - "flags": [] - } + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b", + "flags": [] + } + }, }, "action": {"decision": "deny"}, "tags": ["pii", "ssn", "output-filter"] diff --git a/examples/cisco_ai_defense/setup_ai_defense_controls.py b/examples/cisco_ai_defense/setup_ai_defense_controls.py index 1097e62c..c9cb6f4e 100644 --- a/examples/cisco_ai_defense/setup_ai_defense_controls.py +++ b/examples/cisco_ai_defense/setup_ai_defense_controls.py @@ -104,8 +104,13 @@ async def main() -> int: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": {"name": EVALUATOR_NAME, "config": {**base_config, "payload_field": "input"}}, + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": EVALUATOR_NAME, + "config": {**base_config, "payload_field": "input"}, + }, + }, "action": {"decision": "deny"}, "tags": ["ai_defense", "security", "safety", "privacy"], } @@ -115,8 +120,13 @@ async def main() -> int: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": {"name": EVALUATOR_NAME, "config": {**base_config, "payload_field": "output"}}, + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": EVALUATOR_NAME, + "config": {**base_config, "payload_field": "output"}, + }, + }, "action": {"decision": "deny"}, "tags": ["ai_defense", "security", "safety", "privacy"], } diff --git a/examples/crewai/setup_content_controls.py b/examples/crewai/setup_content_controls.py index 763f2f18..197ff8e5 100644 --- a/examples/crewai/setup_content_controls.py +++ b/examples/crewai/setup_content_controls.py @@ -44,15 +44,17 @@ async def setup_content_controls(): "step_names": ["handle_ticket"], "stages": ["pre"] # Check input before processing }, - "selector": { - "path": "input.ticket" - }, - "evaluator": { - "name": "regex", - "config": { - # Block requests for other users' data, admin access, passwords - "pattern": r"(?i)(show\s+me|what\s+is|give\s+me|tell\s+me).*(other\s+user|another\s+user|user\s+\w+|admin|password|credential|account\s+\d+|all\s+orders|other\s+customer)" - } + "condition": { + "selector": { + "path": "input.ticket" + }, + "evaluator": { + "name": "regex", + "config": { + # Block requests for other users' data, admin access, passwords + "pattern": r"(?i)(show\s+me|what\s+is|give\s+me|tell\s+me).*(other\s+user|another\s+user|user\s+\w+|admin|password|credential|account\s+\d+|all\s+orders|other\s+customer)" + } + }, }, "action": {"decision": "deny"} } @@ -90,15 +92,17 @@ async def setup_content_controls(): "step_names": ["handle_ticket"], "stages": ["post"] # Check output after generation }, - "selector": { - "path": "output" - }, - "evaluator": { - "name": "regex", - "config": { - # Block SSN, credit cards, emails, phone numbers - "pattern": r"(?:\b\d{3}-\d{2}-\d{4}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}|\b\d{3}[-.]?\d{3}[-.]?\d{4}\b)" - } + "condition": { + "selector": { + "path": "output" + }, + "evaluator": { + "name": "regex", + "config": { + # Block SSN, credit cards, emails, phone numbers + "pattern": r"(?:\b\d{3}-\d{2}-\d{4}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}|\b\d{3}[-.]?\d{3}[-.]?\d{4}\b)" + } + }, }, "action": {"decision": "deny"} } @@ -136,15 +140,17 @@ async def setup_content_controls(): "step_names": ["validate_final_output"], "stages": ["post"] # Check output after validation function }, - "selector": { - "path": "output" - }, - "evaluator": { - "name": "regex", - "config": { - # Block SSN, credit cards, emails, phone numbers - "pattern": r"(?:\b\d{3}-\d{2}-\d{4}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}|\b\d{3}[-.]?\d{3}[-.]?\d{4}\b)" - } + "condition": { + "selector": { + "path": "output" + }, + "evaluator": { + "name": "regex", + "config": { + # Block SSN, credit cards, emails, phone numbers + "pattern": r"(?:\b\d{3}-\d{2}-\d{4}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}|\b\d{3}[-.]?\d{3}[-.]?\d{4}\b)" + } + }, }, "action": {"decision": "deny"} } diff --git a/examples/customer_support_agent/run_demo.py b/examples/customer_support_agent/run_demo.py index a0d7ca41..a352afba 100644 --- a/examples/customer_support_agent/run_demo.py +++ b/examples/customer_support_agent/run_demo.py @@ -14,7 +14,7 @@ /test-pii Test PII detection (if control configured) /test-injection Test prompt injection detection (if control configured) /test-multispan Test multi-span traces (2-3 spans per request) - /test-tool-controls Test tool-specific controls (tool_names, tool_name_regex) + /test-tool-controls Test tool-specific controls (step_names, step_name_regex) /lookup Look up a customer (e.g., /lookup C001) /search Search knowledge base (e.g., /search refund) /ticket [priority] Create a test ticket (e.g., /ticket high) @@ -141,7 +141,7 @@ def print_header(): print() print("Features demonstrated:") print(" • Multi-span traces (2-3 spans per request)") - print(" • ControlSelector options: path, tool_names, tool_name_regex") + print(" • Control targeting: selector paths plus scope.step_names / scope.step_name_regex") print(" • Various control types: LLM calls and tool calls") print() print("Commands:") @@ -166,7 +166,7 @@ def print_help(): print(" /test-pii Test PII detection control") print(" /test-injection Test prompt injection control") print(" /test-multispan Test multi-span traces (2-3 spans per request)") - print(" /test-tool-controls Test tool-specific controls (tool_names, tool_name_regex)") + print(" /test-tool-controls Test tool-specific controls (step_names, step_name_regex)") print() print("Tools (single span each):") print(" /lookup Look up customer (e.g., /lookup C001)") @@ -441,21 +441,21 @@ async def run_multispan_tests(agent: CustomerSupportAgent): async def run_tool_control_tests(agent: CustomerSupportAgent): - """Run tests to verify tool-specific controls (tool_names, tool_name_regex).""" + """Run tests to verify tool-specific controls (step_names, step_name_regex).""" print() print("-" * 50) print("Running Tool-Specific Control Tests") print("-" * 50) print() - print("These tests exercise controls using different ControlSelector options:") - print(" • tool_names: exact tool name match (e.g., 'lookup_customer')") - print(" • tool_name_regex: pattern match (e.g., 'search|lookup')") - print(" • path: argument paths (e.g., 'arguments.priority')") + print("These tests exercise controls using different scope/selector combinations:") + print(" • scope.step_names: exact tool name match (e.g., 'lookup_customer')") + print(" • scope.step_name_regex: pattern match (e.g., 'search|lookup')") + print(" • selector paths: input fields such as 'input.priority'") print() - # Test 1: SQL injection in customer lookup (tool_names: lookup_customer) + # Test 1: SQL injection in customer lookup (scope.step_names: lookup_customer) print("Test 1: SQL injection attempt in customer lookup") - print(" Control: block-sql-injection-customer-lookup (tool_names: exact match)") + print(" Control: block-sql-injection-customer-lookup (scope.step_names: exact match)") print(" Query: SELECT * FROM users --") response = await agent.lookup("SELECT * FROM users --") print(f" Result: {response}") @@ -468,9 +468,9 @@ async def run_tool_control_tests(agent: CustomerSupportAgent): print(f" Result: {response}") print() - # Test 3: Profanity in search (tool_name_regex: search|lookup) + # Test 3: Profanity in search (scope.step_name_regex: search|lookup) print("Test 3: Inappropriate content in search") - print(" Control: block-profanity-in-search (tool_name_regex: pattern match)") + print(" Control: block-profanity-in-search (scope.step_name_regex: pattern match)") print(" Query: badword") response = await agent.search("badword") print(f" Result: {response}") @@ -485,7 +485,7 @@ async def run_tool_control_tests(agent: CustomerSupportAgent): # Test 5: High priority ticket (warn-high-priority-ticket) print("Test 5: High priority ticket creation") - print(" Control: warn-high-priority-ticket (tool_name_regex + path: arguments.priority)") + print(" Control: warn-high-priority-ticket (scope.step_name_regex + selector path: input.priority)") print(" Priority: critical") response = await agent.create_support_ticket( subject="Urgent issue", @@ -508,7 +508,7 @@ async def run_tool_control_tests(agent: CustomerSupportAgent): # Test 7: PII in ticket description print("Test 7: Email in ticket description") - print(" Control: block-pii-in-ticket-description (tool_names + path: arguments.description)") + print(" Control: block-pii-in-ticket-description (scope.step_names + selector path: input.description)") response = await agent.create_support_ticket( subject="Contact request", description="Please email me at secret@company.com", diff --git a/examples/customer_support_agent/setup_demo_controls.py b/examples/customer_support_agent/setup_demo_controls.py index b61dc604..7e8c0538 100644 --- a/examples/customer_support_agent/setup_demo_controls.py +++ b/examples/customer_support_agent/setup_demo_controls.py @@ -29,7 +29,7 @@ ] # Demo controls to create -# Demonstrates various ControlSelector options: path, tool_names, tool_name_regex +# Demonstrates selector paths plus scope.step_names / scope.step_name_regex targeting. DEMO_CONTROLS = [ # ========================================================================== # LLM CALL CONTROLS (using 'path' selector) @@ -42,15 +42,14 @@ "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": {"pattern": r"\d{3}-\d{2}-\d{4}"}, - }, - "action": { - "decision": "deny", - "message": "Response contains SSN pattern - blocked for PII protection", + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\d{3}-\d{2}-\d{4}"}, + }, }, + "action": {"decision": "deny"}, }, }, { @@ -61,17 +60,16 @@ "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"(?i)(ignore.{0,20}(previous|prior|above).{0,20}instructions|you are now|system:|forget everything|disregard)" + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"(?i)(ignore.{0,20}(previous|prior|above).{0,20}instructions|you are now|system:|forget everything|disregard)" + }, }, }, - "action": { - "decision": "deny", - "message": "Potential prompt injection detected - request blocked", - }, + "action": {"decision": "deny"}, }, }, { @@ -82,93 +80,97 @@ "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": {"pattern": r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b"}, - }, - "action": { - "decision": "deny", - "message": "Credit card number detected - please don't share payment info", + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b"}, + }, }, + "action": {"decision": "deny"}, }, }, # ========================================================================== - # TOOL CALL CONTROLS - using 'tool_names' (exact match) + # TOOL CALL CONTROLS - using scope.step_names (exact match) # ========================================================================== { "name": "block-sql-injection-customer-lookup", - "description": "Blocks SQL injection in customer lookup (tool_names: exact match)", + "description": "Blocks SQL injection in customer lookup (scope.step_names: exact match)", "definition": { - "description": "Blocks SQL injection in customer lookup (tool_names: exact match)", + "description": "Blocks SQL injection in customer lookup (scope.step_names: exact match)", "enabled": True, "execution": "server", - "scope": {"step_types": ["tool"], "stages": ["pre"]}, - "selector": { - "path": "input.query", - "tool_names": ["lookup_customer"], # Only applies to this exact tool + "scope": { + "step_types": ["tool"], + "step_names": ["lookup_customer"], + "stages": ["pre"], }, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"(?i)(select|insert|update|delete|drop|union|--|;)" + "condition": { + "selector": { + "path": "input.query", + }, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"(?i)(select|insert|update|delete|drop|union|--|;)" + }, }, }, - "action": { - "decision": "deny", - "message": "Potential SQL injection in customer lookup query", - }, + "action": {"decision": "deny"}, }, }, { "name": "log-ticket-creation", - "description": "Logs all ticket creation for audit (tool_names: exact match)", + "description": "Logs all ticket creation for audit (scope.step_names: exact match)", "definition": { - "description": "Logs all ticket creation for audit (tool_names: exact match)", + "description": "Logs all ticket creation for audit (scope.step_names: exact match)", "enabled": True, "execution": "server", - "scope": {"step_types": ["tool"], "stages": ["pre"]}, - "selector": { - "path": "*", # Log entire payload - "tool_names": ["create_ticket"], - }, - "evaluator": { - "name": "regex", - "config": {"pattern": r".*"}, # Always matches + "scope": { + "step_types": ["tool"], + "step_names": ["create_ticket"], + "stages": ["pre"], }, - "action": { - "decision": "log", - "message": "Ticket creation logged for audit", + "condition": { + "selector": { + "path": "*", # Log entire payload + }, + "evaluator": { + "name": "regex", + "config": {"pattern": r".*"}, # Always matches + }, }, + "action": {"decision": "log"}, }, }, # ========================================================================== - # TOOL CALL CONTROLS - using 'tool_name_regex' (pattern match) + # TOOL CALL CONTROLS - using scope.step_name_regex (pattern match) # ========================================================================== { "name": "block-profanity-in-search", - "description": "Blocks profanity in any search/lookup tool (tool_name_regex: pattern)", + "description": "Blocks profanity in any search/lookup tool (scope.step_name_regex: pattern)", "definition": { - "description": "Blocks profanity in any search/lookup tool (tool_name_regex: pattern)", + "description": "Blocks profanity in any search/lookup tool (scope.step_name_regex: pattern)", "enabled": True, "execution": "server", - "scope": {"step_types": ["tool"], "stages": ["pre"]}, - "selector": { - "path": "input.query", - # Applies to any tool containing 'search' or 'lookup' - "tool_name_regex": r"(search|lookup)", + "scope": { + "step_types": ["tool"], + "step_name_regex": r"(search|lookup)", + "stages": ["pre"], }, - "evaluator": { - "name": "regex", - "config": { - # Simple profanity pattern for demo - "pattern": r"(?i)\b(badword|offensive|inappropriate)\b" + "condition": { + "selector": { + "path": "input.query", + }, + "evaluator": { + "name": "regex", + "config": { + # Simple profanity pattern for demo + "pattern": r"(?i)\b(badword|offensive|inappropriate)\b" + }, }, }, - "action": { - "decision": "deny", - "message": "Inappropriate content detected in search query", - }, + "action": {"decision": "deny"}, }, }, { @@ -178,25 +180,27 @@ "description": "Warns on high priority tickets (path: input.priority)", "enabled": True, "execution": "server", - "scope": {"step_types": ["tool"], "stages": ["pre"]}, - "selector": { - "path": "input.priority", - "tool_name_regex": r".*ticket.*", # Any tool with 'ticket' in name + "scope": { + "step_types": ["tool"], + "step_name_regex": r".*ticket.*", + "stages": ["pre"], }, - "evaluator": { - "name": "list", - "config": { - "values": ["high", "critical", "urgent"], - "logic": "any", - "match_on": "match", - "match_mode": "exact", - "case_sensitive": False, + "condition": { + "selector": { + "path": "input.priority", + }, + "evaluator": { + "name": "list", + "config": { + "values": ["high", "critical", "urgent"], + "logic": "any", + "match_on": "match", + "match_mode": "exact", + "case_sensitive": False, + }, }, }, - "action": { - "decision": "warn", - "message": "High priority ticket requires manager approval", - }, + "action": {"decision": "warn"}, }, }, # ========================================================================== @@ -209,22 +213,24 @@ "description": "Blocks PII in ticket descriptions (path: input.description)", "enabled": True, "execution": "server", - "scope": {"step_types": ["tool"], "stages": ["pre"]}, - "selector": { - "path": "input.description", - "tool_names": ["create_ticket"], + "scope": { + "step_types": ["tool"], + "step_names": ["create_ticket"], + "stages": ["pre"], }, - "evaluator": { - "name": "regex", - "config": { - # Email pattern - "pattern": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + "condition": { + "selector": { + "path": "input.description", + }, + "evaluator": { + "name": "regex", + "config": { + # Email pattern + "pattern": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + }, }, }, - "action": { - "decision": "warn", - "message": "Email address detected in ticket - consider removing PII", - }, + "action": {"decision": "warn"}, }, }, ] diff --git a/examples/deepeval/setup_controls.py b/examples/deepeval/setup_controls.py index 9d25b7fb..86caad58 100755 --- a/examples/deepeval/setup_controls.py +++ b/examples/deepeval/setup_controls.py @@ -56,27 +56,26 @@ "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {}, - "evaluator": { - "name": "deepeval-geval", - "config": { - "name": "Coherence", - "criteria": ( - "Evaluate whether the response is coherent, logically consistent, " - "and well-structured. Check for contradictions and flow of ideas. " - "The response should make logical sense and not contain contradictory statements." - ), - "evaluation_params": ["input", "actual_output"], - "threshold": 0.6, - "model": "gpt-4o", - "strict_mode": False, - "verbose_mode": False, + "condition": { + "selector": {"path": "*"}, + "evaluator": { + "name": "deepeval-geval", + "config": { + "name": "Coherence", + "criteria": ( + "Evaluate whether the response is coherent, logically consistent, " + "and well-structured. Check for contradictions and flow of ideas. " + "The response should make logical sense and not contain contradictory statements." + ), + "evaluation_params": ["input", "actual_output"], + "threshold": 0.6, + "model": "gpt-4o", + "strict_mode": False, + "verbose_mode": False, + }, }, }, - "action": { - "decision": "deny", - "message": "Response failed coherence check - please reformulate", - }, + "action": {"decision": "deny"}, }, }, { @@ -87,26 +86,25 @@ "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {}, - "evaluator": { - "name": "deepeval-geval", - "config": { - "name": "Relevance", - "criteria": ( - "Determine whether the actual output is relevant and directly addresses " - "the input query. Check if it stays on topic and provides useful information " - "that answers the question asked." - ), - "evaluation_params": ["input", "actual_output"], - "threshold": 0.5, - "model": "gpt-4o", - "strict_mode": False, + "condition": { + "selector": {"path": "*"}, + "evaluator": { + "name": "deepeval-geval", + "config": { + "name": "Relevance", + "criteria": ( + "Determine whether the actual output is relevant and directly addresses " + "the input query. Check if it stays on topic and provides useful information " + "that answers the question asked." + ), + "evaluation_params": ["input", "actual_output"], + "threshold": 0.5, + "model": "gpt-4o", + "strict_mode": False, + }, }, }, - "action": { - "decision": "deny", - "message": "Response is not relevant to the question - please provide a relevant answer", - }, + "action": {"decision": "deny"}, }, }, { @@ -117,26 +115,25 @@ "enabled": False, # Disabled by default - enable when you have expected outputs "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "*"}, - "evaluator": { - "name": "deepeval-geval", - "config": { - "name": "Correctness", - "evaluation_steps": [ - "Check whether facts in actual output contradict expected output", - "Heavily penalize omission of critical details", - "Minor wording differences are acceptable", - "Focus on factual accuracy, not style", - ], - "evaluation_params": ["actual_output", "expected_output"], - "threshold": 0.8, - "model": "gpt-4o", + "condition": { + "selector": {"path": "*"}, + "evaluator": { + "name": "deepeval-geval", + "config": { + "name": "Correctness", + "evaluation_steps": [ + "Check whether facts in actual output contradict expected output", + "Heavily penalize omission of critical details", + "Minor wording differences are acceptable", + "Focus on factual accuracy, not style", + ], + "evaluation_params": ["actual_output", "expected_output"], + "threshold": 0.8, + "model": "gpt-4o", + }, }, }, - "action": { - "decision": "warn", - "message": "Response may contain factual errors - review carefully", - }, + "action": {"decision": "warn"}, }, }, ] diff --git a/examples/google_adk_callbacks/setup_controls.py b/examples/google_adk_callbacks/setup_controls.py index d5a94b8a..0aea1754 100644 --- a/examples/google_adk_callbacks/setup_controls.py +++ b/examples/google_adk_callbacks/setup_controls.py @@ -21,20 +21,19 @@ "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": ( - r"(?i)(ignore.{0,20}(previous|prior|above).{0,20}instructions" - r"|system:|forget everything|reveal secrets)" - ) + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": ( + r"(?i)(ignore.{0,20}(previous|prior|above).{0,20}instructions" + r"|system:|forget everything|reveal secrets)" + ) + }, }, }, - "action": { - "decision": "deny", - "message": "Prompt injection attempt detected.", - }, + "action": {"decision": "deny"}, }, ), ( @@ -48,21 +47,20 @@ "step_names": ["get_current_time", "get_weather"], "stages": ["pre"], }, - "selector": {"path": "input.city"}, - "evaluator": { - "name": "list", - "config": { - "values": ["Pyongyang", "Tehran", "Damascus"], - "logic": "any", - "match_on": "match", - "match_mode": "exact", - "case_sensitive": False, + "condition": { + "selector": {"path": "input.city"}, + "evaluator": { + "name": "list", + "config": { + "values": ["Pyongyang", "Tehran", "Damascus"], + "logic": "any", + "match_on": "match", + "match_mode": "exact", + "case_sensitive": False, + }, }, }, - "action": { - "decision": "deny", - "message": "That city is blocked by policy.", - }, + "action": {"decision": "deny"}, }, ), ( @@ -76,17 +74,16 @@ "step_names": ["get_current_time", "get_weather"], "stages": ["post"], }, - "selector": {"path": "output.note"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"support@internal\.example|123-45-6789", + "condition": { + "selector": {"path": "output.note"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"support@internal\.example|123-45-6789", + }, }, }, - "action": { - "decision": "deny", - "message": "Tool output exposed internal contact data.", - }, + "action": {"decision": "deny"}, }, ), ] diff --git a/examples/google_adk_decorator/setup_controls.py b/examples/google_adk_decorator/setup_controls.py index 772e9c9b..f3c68e66 100644 --- a/examples/google_adk_decorator/setup_controls.py +++ b/examples/google_adk_decorator/setup_controls.py @@ -29,21 +29,20 @@ def _control_specs(execution: str) -> list[tuple[str, dict[str, Any]]]: "step_names": ["get_current_time", "get_weather"], "stages": ["pre"], }, - "selector": {"path": "input.city"}, - "evaluator": { - "name": "list", - "config": { - "values": ["Pyongyang", "Tehran", "Damascus"], - "logic": "any", - "match_on": "match", - "match_mode": "exact", - "case_sensitive": False, + "condition": { + "selector": {"path": "input.city"}, + "evaluator": { + "name": "list", + "config": { + "values": ["Pyongyang", "Tehran", "Damascus"], + "logic": "any", + "match_on": "match", + "match_mode": "exact", + "case_sensitive": False, + }, }, }, - "action": { - "decision": "deny", - "message": "That city is blocked by policy.", - }, + "action": {"decision": "deny"}, }, ), ( @@ -57,17 +56,16 @@ def _control_specs(execution: str) -> list[tuple[str, dict[str, Any]]]: "step_names": ["get_current_time", "get_weather"], "stages": ["post"], }, - "selector": {"path": "output.note"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"support@internal\.example|123-45-6789", + "condition": { + "selector": {"path": "output.note"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"support@internal\.example|123-45-6789", + }, }, }, - "action": { - "decision": "deny", - "message": "Tool output exposed internal contact data.", - }, + "action": {"decision": "deny"}, }, ), ] diff --git a/examples/google_adk_plugin/setup_controls.py b/examples/google_adk_plugin/setup_controls.py index d75b1d32..3bfeb483 100644 --- a/examples/google_adk_plugin/setup_controls.py +++ b/examples/google_adk_plugin/setup_controls.py @@ -29,20 +29,19 @@ def _control_specs(execution: str) -> list[tuple[str, dict[str, Any]]]: "enabled": True, "execution": execution, "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": ( - r"(?i)(ignore.{0,20}(previous|prior|above).{0,20}instructions" - r"|system:|forget everything|reveal secrets)" - ) + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": ( + r"(?i)(ignore.{0,20}(previous|prior|above).{0,20}instructions" + r"|system:|forget everything|reveal secrets)" + ) + }, }, }, - "action": { - "decision": "deny", - "message": "Prompt injection attempt detected.", - }, + "action": {"decision": "deny"}, }, ), ( @@ -56,21 +55,20 @@ def _control_specs(execution: str) -> list[tuple[str, dict[str, Any]]]: "step_names": TOOL_STEP_NAMES, "stages": ["pre"], }, - "selector": {"path": "input.city"}, - "evaluator": { - "name": "list", - "config": { - "values": ["Pyongyang", "Tehran", "Damascus"], - "logic": "any", - "match_on": "match", - "match_mode": "exact", - "case_sensitive": False, + "condition": { + "selector": {"path": "input.city"}, + "evaluator": { + "name": "list", + "config": { + "values": ["Pyongyang", "Tehran", "Damascus"], + "logic": "any", + "match_on": "match", + "match_mode": "exact", + "case_sensitive": False, + }, }, }, - "action": { - "decision": "deny", - "message": "That city is blocked by policy.", - }, + "action": {"decision": "deny"}, }, ), ( @@ -84,17 +82,16 @@ def _control_specs(execution: str) -> list[tuple[str, dict[str, Any]]]: "step_names": TOOL_STEP_NAMES, "stages": ["post"], }, - "selector": {"path": "output.note"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"support@internal\.example|123-45-6789", + "condition": { + "selector": {"path": "output.note"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"support@internal\.example|123-45-6789", + }, }, }, - "action": { - "decision": "deny", - "message": "Tool output exposed internal contact data.", - }, + "action": {"decision": "deny"}, }, ), ] diff --git a/examples/langchain/setup_sql_controls.py b/examples/langchain/setup_sql_controls.py index e989318f..c4e4d428 100644 --- a/examples/langchain/setup_sql_controls.py +++ b/examples/langchain/setup_sql_controls.py @@ -74,17 +74,19 @@ async def setup_sql_controls(): "step_names": ["sql_db_query"], "stages": ["pre"] }, - "selector": { - "path": "input.query" - }, - "evaluator": { - "name": "sql", - "config": { - "blocked_operations": ["DROP", "DELETE", "TRUNCATE", "ALTER", "GRANT"], - "allow_multi_statements": False, - "require_limit": True, - "max_limit": 100 - } + "condition": { + "selector": { + "path": "input.query" + }, + "evaluator": { + "name": "sql", + "config": { + "blocked_operations": ["DROP", "DELETE", "TRUNCATE", "ALTER", "GRANT"], + "allow_multi_statements": False, + "require_limit": True, + "max_limit": 100 + } + }, }, "action": {"decision": "deny"} } diff --git a/examples/steer_action_demo/setup_controls.py b/examples/steer_action_demo/setup_controls.py index f6d8a0a2..af13c80f 100644 --- a/examples/steer_action_demo/setup_controls.py +++ b/examples/steer_action_demo/setup_controls.py @@ -49,17 +49,19 @@ async def setup_banking_controls(): "step_names": ["process_wire_transfer"], "stages": ["pre"] }, - "selector": { - "path": "input.destination_country" - }, - "evaluator": { - "name": "list", - "config": { - "values": ["north korea", "iran", "syria", "cuba", "crimea"], - "logic": "any", - "match_mode": "contains", - "case_sensitive": False - } + "condition": { + "selector": { + "path": "input.destination_country" + }, + "evaluator": { + "name": "list", + "config": { + "values": ["north korea", "iran", "syria", "cuba", "crimea"], + "logic": "any", + "match_mode": "contains", + "case_sensitive": False + } + }, }, "action": { "decision": "deny" @@ -76,19 +78,21 @@ async def setup_banking_controls(): "step_names": ["process_wire_transfer"], "stages": ["pre"] }, - "selector": { - "path": "input" - }, - "evaluator": { - "name": "json", - "config": { - "field_constraints": { - "fraud_score": { - "type": "number", - "max": 0.8 + "condition": { + "selector": { + "path": "input" + }, + "evaluator": { + "name": "json", + "config": { + "field_constraints": { + "fraud_score": { + "type": "number", + "max": 0.8 + } } } - } + }, }, "action": { "decision": "deny" @@ -109,16 +113,20 @@ async def setup_banking_controls(): "step_names": ["process_wire_transfer"], "stages": ["pre"] }, - "selector": { - "path": "input.recipient" - }, - "evaluator": { - "name": "list", - "config": { - "values": ["John Smith", "Acme Corp", "Global Suppliers Inc"], - "match_type": "not_in", - "case_sensitive": False - } + "condition": { + "selector": { + "path": "input.recipient" + }, + "evaluator": { + "name": "list", + "config": { + "values": ["John Smith", "Acme Corp", "Global Suppliers Inc"], + "logic": "any", + "match_on": "no_match", + "match_mode": "exact", + "case_sensitive": False + } + }, }, "action": { "decision": "warn" @@ -139,20 +147,22 @@ async def setup_banking_controls(): "step_names": ["process_wire_transfer"], "stages": ["pre"] }, - "selector": { - "path": "input" - }, - "evaluator": { - "name": "json", - "config": { - "json_schema": { - "type": "object", - "oneOf": [ - {"properties": {"amount": {"type": "number", "exclusiveMaximum": 10000}}}, - {"properties": {"amount": {"type": "number", "minimum": 10000}, "verified_2fa": {"const": True}}} - ] + "condition": { + "selector": { + "path": "input" + }, + "evaluator": { + "name": "json", + "config": { + "json_schema": { + "type": "object", + "oneOf": [ + {"properties": {"amount": {"type": "number", "exclusiveMaximum": 10000}}}, + {"properties": {"amount": {"type": "number", "minimum": 10000}, "verified_2fa": {"const": True}}} + ] + } } - } + }, }, "action": { "decision": "steer", @@ -172,20 +182,22 @@ async def setup_banking_controls(): "step_names": ["process_wire_transfer"], "stages": ["pre"] }, - "selector": { - "path": "input" - }, - "evaluator": { - "name": "json", - "config": { - "json_schema": { - "type": "object", - "oneOf": [ - {"properties": {"amount": {"type": "number", "exclusiveMaximum": 10000}}}, - {"properties": {"amount": {"type": "number", "minimum": 10000}, "manager_approved": {"const": True}}} - ] + "condition": { + "selector": { + "path": "input" + }, + "evaluator": { + "name": "json", + "config": { + "json_schema": { + "type": "object", + "oneOf": [ + {"properties": {"amount": {"type": "number", "exclusiveMaximum": 10000}}}, + {"properties": {"amount": {"type": "number", "minimum": 10000}, "manager_approved": {"const": True}}} + ] + } } - } + }, }, "action": { "decision": "steer", diff --git a/examples/strands_agents/interactive_demo/interactive_support_demo.py b/examples/strands_agents/interactive_demo/interactive_support_demo.py index ef7e3024..c36183d8 100644 --- a/examples/strands_agents/interactive_demo/interactive_support_demo.py +++ b/examples/strands_agents/interactive_demo/interactive_support_demo.py @@ -493,8 +493,9 @@ def render_sidebar(): # Display key information st.markdown(f"**ID:** `{control_id}`") - if "evaluator" in data: - evaluator = data["evaluator"] + condition = data.get("condition", {}) + if "evaluator" in condition: + evaluator = condition["evaluator"] st.markdown(f"**Evaluator:** `{evaluator.get('name', 'N/A')}`") if "config" in evaluator: diff --git a/examples/strands_agents/interactive_demo/setup_interactive_controls.py b/examples/strands_agents/interactive_demo/setup_interactive_controls.py index 18822ef4..cf1517aa 100644 --- a/examples/strands_agents/interactive_demo/setup_interactive_controls.py +++ b/examples/strands_agents/interactive_demo/setup_interactive_controls.py @@ -47,14 +47,16 @@ "step_names": ["check_before_invocation", "check_before_model"], "stages": ["pre"] }, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" - } + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + } + }, }, - "action": {"decision": "deny", "message": "PII detected in user input"}, + "action": {"decision": "deny"}, "tags": ["pii", "security"] } }, @@ -69,14 +71,16 @@ "step_names": ["check_before_invocation", "check_before_model"], "stages": ["pre"] }, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"(\bDROP\s+TABLE\b|\bDROP\s+DATABASE\b|--;)" - } + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"(\bDROP\s+TABLE\b|\bDROP\s+DATABASE\b|--;)" + } + }, }, - "action": {"decision": "deny", "message": "Potentially malicious SQL patterns detected"}, + "action": {"decision": "deny"}, "tags": ["security"] } }, @@ -93,14 +97,16 @@ "step_names": ["check_after_model"], "stages": ["post"] }, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" - } + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b|\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + } + }, }, - "action": {"decision": "deny", "message": "PII detected in agent response"}, + "action": {"decision": "deny"}, "tags": ["pii", "security"] } }, @@ -117,14 +123,16 @@ "step_types": ["tool"], # Applies to ALL tools "stages": ["pre"] }, - "selector": {"path": "input"}, # Check entire tool input - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b" - } + "condition": { + "selector": {"path": "input"}, # Check entire tool input + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b|\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b" + } + }, }, - "action": {"decision": "deny", "message": "PII not allowed in tool inputs"}, + "action": {"decision": "deny"}, "tags": ["pii", "security", "tools"] } } diff --git a/examples/strands_agents/steering_demo/setup_email_controls.py b/examples/strands_agents/steering_demo/setup_email_controls.py index e86cef8a..5debaa17 100644 --- a/examples/strands_agents/steering_demo/setup_email_controls.py +++ b/examples/strands_agents/steering_demo/setup_email_controls.py @@ -43,16 +43,17 @@ "step_types": ["llm"], "stages": ["post"] }, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"(\b\d{9,12}\b)|(\d{3}[-\s]?\d{2}[-\s]?\d{4})|(\$[\d,]+\d{3,})" - } + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"(\b\d{9,12}\b)|(\d{3}[-\s]?\d{2}[-\s]?\d{4})|(\$[\d,]+\d{3,})" + } + }, }, "action": { "decision": "steer", - "message": "PII detected in draft - redact before sending", "steering_context": { "message": """⚠️ FINANCIAL PII DETECTED - Apply redactions before sending email: @@ -97,18 +98,17 @@ "stages": ["pre"], "step_names": ["send_monthly_account_summary"] }, - "selector": {"path": "input.summary_text"}, - "evaluator": { - "name": "regex", - "config": { - # Match patterns like: api_key, password, secret, token - "pattern": r"(api[_-]?key|password|secret|token|credential)[\s:=]+['\"]?[\w\-]{8,}" - } - }, - "action": { - "decision": "deny", - "message": "Credentials detected in email - BLOCKED for security" + "condition": { + "selector": {"path": "input.summary_text"}, + "evaluator": { + "name": "regex", + "config": { + # Match patterns like: api_key, password, secret, token + "pattern": r"(api[_-]?key|password|secret|token|credential)[\s:=]+['\"]?[\w\-]{8,}" + } + }, }, + "action": {"decision": "deny"}, "tags": ["credentials", "secrets", "critical", "deny"] } }, @@ -126,18 +126,17 @@ "stages": ["pre"], "step_names": ["send_monthly_account_summary"] }, - "selector": {"path": "input.summary_text"}, - "evaluator": { - "name": "regex", - "config": { - # Match database names, server paths - "pattern": r"(database|db_|server|localhost|127\.0\.0\.1|/var/|/etc/|C:\\\\)" - } - }, - "action": { - "decision": "deny", - "message": "Internal system info detected in email - BLOCKED for security" + "condition": { + "selector": {"path": "input.summary_text"}, + "evaluator": { + "name": "regex", + "config": { + # Match database names, server paths + "pattern": r"(database|db_|server|localhost|127\.0\.0\.1|/var/|/etc/|C:\\\\)" + } + }, }, + "action": {"decision": "deny"}, "tags": ["internal-info", "security", "deny"] } }, diff --git a/models/pyproject.toml b/models/pyproject.toml index 9d448330..7d77632a 100644 --- a/models/pyproject.toml +++ b/models/pyproject.toml @@ -16,6 +16,7 @@ license = {text = "Apache-2.0"} [dependency-groups] dev = [ "pytest>=8.0.0", + "pytest-cov>=4.0.0", "ruff>=0.1.0", "mypy>=1.8.0", ] diff --git a/models/src/agent_control_models/__init__.py b/models/src/agent_control_models/__init__.py index 77ee5615..09cd3ffb 100644 --- a/models/src/agent_control_models/__init__.py +++ b/models/src/agent_control_models/__init__.py @@ -18,6 +18,7 @@ StepSchema, ) from .controls import ( + ConditionNode, ControlAction, ControlDefinition, ControlMatch, @@ -103,6 +104,7 @@ "EvaluationResult", # Controls "ControlDefinition", + "ConditionNode", "ControlAction", "ControlMatch", "ControlScope", diff --git a/models/src/agent_control_models/controls.py b/models/src/agent_control_models/controls.py index 62b3ad2e..872f8df7 100644 --- a/models/src/agent_control_models/controls.py +++ b/models/src/agent_control_models/controls.py @@ -1,10 +1,13 @@ """Control definition models for agent protection.""" +from __future__ import annotations + +from collections.abc import Iterator from typing import Any, Literal, Self from uuid import uuid4 import re2 -from pydantic import Field, ValidationInfo, field_validator, model_validator +from pydantic import ConfigDict, Field, ValidationInfo, field_validator, model_validator from .base import BaseModel @@ -208,6 +211,9 @@ def validate_evaluator_config(self) -> Self: return self +type ConditionLeafParts = tuple[ControlSelector, EvaluatorSpec] + + class SteeringContext(BaseModel): """Steering context for steer actions. @@ -260,6 +266,169 @@ class ControlAction(BaseModel): ) +MAX_CONDITION_DEPTH = 6 + + +class ConditionNode(BaseModel): + """Recursive boolean condition tree for control evaluation.""" + + selector: ControlSelector | None = Field( + default=None, + description="Leaf selector. Must be provided together with evaluator.", + ) + evaluator: EvaluatorSpec | None = Field( + default=None, + description="Leaf evaluator. Must be provided together with selector.", + ) + and_: list[ConditionNode] | None = Field( + default=None, + alias="and", + serialization_alias="and", + description="Logical AND over child conditions.", + ) + or_: list[ConditionNode] | None = Field( + default=None, + alias="or", + serialization_alias="or", + description="Logical OR over child conditions.", + ) + not_: ConditionNode | None = Field( + default=None, + alias="not", + serialization_alias="not", + description="Logical NOT over a single child condition.", + ) + + model_config = ConfigDict( + populate_by_name=True, + use_enum_values=True, + validate_assignment=True, + extra="ignore", + serialize_by_alias=True, + ) + + @model_validator(mode="after") + def validate_shape(self) -> Self: + """Ensure each node is exactly one of leaf/and/or/not.""" + has_selector = self.selector is not None + has_evaluator = self.evaluator is not None + has_leaf = has_selector and has_evaluator + if has_selector != has_evaluator: + raise ValueError("Leaf condition requires both selector and evaluator") + + populated = sum( + 1 + for present in ( + has_leaf, + self.and_ is not None, + self.or_ is not None, + self.not_ is not None, + ) + if present + ) + if populated != 1: + raise ValueError("Condition node must contain exactly one of leaf, and, or, not") + + if self.and_ is not None and len(self.and_) == 0: + raise ValueError("'and' must contain at least one child condition") + if self.or_ is not None and len(self.or_) == 0: + raise ValueError("'or' must contain at least one child condition") + + return self + + def kind(self) -> Literal["leaf", "and", "or", "not"]: + """Return the logical node type.""" + if self.is_leaf(): + return "leaf" + if self.and_ is not None: + return "and" + if self.or_ is not None: + return "or" + return "not" + + def is_leaf(self) -> bool: + """Return True when this node is a leaf selector/evaluator pair.""" + return self.selector is not None and self.evaluator is not None + + def children_in_order(self) -> list[ConditionNode]: + """Return child conditions in evaluation order.""" + if self.and_ is not None: + return self.and_ + if self.or_ is not None: + return self.or_ + if self.not_ is not None: + return [self.not_] + return [] + + def iter_leaves(self) -> Iterator[ConditionNode]: + """Yield leaf nodes in left-to-right traversal order.""" + if self.is_leaf(): + yield self + return + + for child in self.children_in_order(): + yield from child.iter_leaves() + + def iter_leaf_parts(self) -> Iterator[ConditionLeafParts]: + """Yield leaf selector/evaluator pairs in left-to-right traversal order.""" + leaf_parts = self.leaf_parts() + if leaf_parts is not None: + yield leaf_parts + return + + for child in self.children_in_order(): + yield from child.iter_leaf_parts() + + def max_depth(self) -> int: + """Return the maximum nesting depth of this condition tree.""" + children = self.children_in_order() + if not children: + return 1 + return 1 + max(child.max_depth() for child in children) + + def leaf_parts(self) -> ConditionLeafParts | None: + """Return the selector/evaluator pair for leaf nodes.""" + if not self.is_leaf(): + return None + selector = self.selector + evaluator = self.evaluator + if selector is None or evaluator is None: + return None + return selector, evaluator + + model_config["json_schema_extra"] = { + "examples": [ + { + "selector": {"path": "output"}, + "evaluator": {"name": "regex", "config": {"pattern": r"\d{3}-\d{2}-\d{4}"}}, + }, + { + "and": [ + { + "selector": {"path": "context.risk_level"}, + "evaluator": { + "name": "list", + "config": {"values": ["high", "critical"]}, + }, + }, + { + "not": { + "selector": {"path": "context.user_role"}, + "evaluator": { + "name": "list", + "config": {"values": ["admin", "security"]}, + }, + } + }, + ] + }, + ] + } + + +ConditionNode.model_rebuild() + + class ControlDefinition(BaseModel): """A control definition to evaluate agent interactions. @@ -280,10 +449,13 @@ class ControlDefinition(BaseModel): ) # What to check - selector: ControlSelector = Field(..., description="What data to select from the payload") - - # How to check (unified evaluator-based system) - evaluator: EvaluatorSpec = Field(..., description="How to evaluate the selected data") + condition: ConditionNode = Field( + ..., + description=( + "Recursive boolean condition tree. Leaf nodes contain selector + evaluator; " + "composite nodes contain and/or/not." + ), + ) # What to do action: ControlAction = Field(..., description="What action to take when control matches") @@ -291,6 +463,74 @@ class ControlDefinition(BaseModel): # Metadata tags: list[str] = Field(default_factory=list, description="Tags for categorization") + @classmethod + def canonicalize_payload(cls, data: Any) -> Any: + """Rewrite legacy selector/evaluator payloads into canonical condition shape.""" + if not isinstance(data, dict): + return data + + has_condition = "condition" in data + has_selector = "selector" in data + has_evaluator = "evaluator" in data + + if has_condition and (has_selector or has_evaluator): + raise ValueError( + "Control definition mixes canonical condition fields " + "with legacy selector/evaluator fields." + ) + if has_selector != has_evaluator: + raise ValueError( + "Legacy control definition must include both selector and evaluator." + ) + if not has_condition and has_selector: + canonical = dict(data) + selector = canonical.pop("selector") + evaluator = canonical.pop("evaluator") + canonical["condition"] = { + "selector": selector, + "evaluator": evaluator, + } + return canonical + return data + + @model_validator(mode="before") + @classmethod + def canonicalize_legacy_condition_shape(cls, data: Any) -> Any: + """Accept legacy flat leaf payloads during condition-tree rollout.""" + return cls.canonicalize_payload(data) + + @model_validator(mode="after") + def validate_condition_constraints(self) -> Self: + """Validate cross-field control constraints.""" + if self.condition.max_depth() > MAX_CONDITION_DEPTH: + raise ValueError( + f"Condition nesting depth exceeds maximum of {MAX_CONDITION_DEPTH}" + ) + + if ( + self.action.decision == "steer" + and not self.condition.is_leaf() + and self.action.steering_context is None + ): + raise ValueError( + "Composite steer controls require action.steering_context" + ) + return self + + def iter_condition_leaves(self) -> Iterator[ConditionNode]: + """Yield leaf conditions in evaluation order.""" + yield from self.condition.iter_leaves() + + def iter_condition_leaf_parts(self) -> Iterator[ConditionLeafParts]: + """Yield leaf selector/evaluator pairs in evaluation order.""" + yield from self.condition.iter_leaf_parts() + + def primary_leaf(self) -> ConditionNode | None: + """Return the single leaf node when the whole condition is just one leaf.""" + if self.condition.is_leaf(): + return self.condition + return None + model_config = { "json_schema_extra": { "examples": [ @@ -299,11 +539,13 @@ class ControlDefinition(BaseModel): "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": { - "pattern": r"\b\d{3}-\d{2}-\d{4}\b", + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": { + "pattern": r"\b\d{3}-\d{2}-\d{4}\b", + }, }, }, "action": { diff --git a/models/tests/test_controls.py b/models/tests/test_controls.py new file mode 100644 index 00000000..f69bccde --- /dev/null +++ b/models/tests/test_controls.py @@ -0,0 +1,230 @@ +"""Direct tests for recursive condition-tree models.""" + +from __future__ import annotations + +import pytest +from agent_control_models import ControlDefinition +from pydantic import ValidationError + + +def _leaf( + path: str, + evaluator_name: str = "regex", + config: dict[str, object] | None = None, +) -> dict[str, object]: + return { + "selector": {"path": path}, + "evaluator": { + "name": evaluator_name, + "config": config or {"pattern": "ok"}, + }, + } + + +def test_condition_leaf_requires_selector_and_evaluator() -> None: + # Given: a leaf condition with only a selector + with pytest.raises( + ValidationError, + match="Leaf condition requires both selector and evaluator", + ): + # When: validating the control definition + ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": {"selector": {"path": "input"}}, + "action": {"decision": "deny"}, + } + ) + # Then: validation rejects the incomplete leaf shape + + +def test_condition_node_requires_exactly_one_shape() -> None: + # Given: a condition node that mixes leaf and composite fields + with pytest.raises( + ValidationError, + match="Condition node must contain exactly one of leaf, and, or, not", + ): + # When: validating the control definition + ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": { + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "ok"}}, + "and": [_leaf("output")], + }, + "action": {"decision": "deny"}, + } + ) + # Then: validation rejects the ambiguous node shape + + +def test_legacy_leaf_payload_is_canonicalized() -> None: + # Given: a legacy flat selector/evaluator payload + legacy_payload = { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "ok"}}, + "action": {"decision": "deny"}, + } + + # When: validating the legacy payload + control = ControlDefinition.model_validate(legacy_payload) + + # Then: the model dumps back out in canonical condition form + dumped = control.model_dump(mode="json", exclude_none=True) + assert "selector" not in dumped + assert "evaluator" not in dumped + assert dumped["condition"]["selector"]["path"] == "input" + assert dumped["condition"]["evaluator"]["name"] == "regex" + + +def test_mixed_legacy_and_condition_fields_are_rejected() -> None: + # Given: a payload that mixes canonical condition and legacy flat fields + payload = { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": _leaf("input"), + "selector": {"path": "output"}, + "evaluator": {"name": "regex", "config": {"pattern": "ok"}}, + "action": {"decision": "deny"}, + } + + with pytest.raises( + ValidationError, + match="Control definition mixes canonical condition fields " + "with legacy selector/evaluator fields", + ): + # When: validating the mixed payload + ControlDefinition.model_validate(payload) + # Then: validation rejects the mixed shape + + +def test_condition_and_requires_at_least_one_child() -> None: + # Given: an empty AND condition + with pytest.raises( + ValidationError, + match="'and' must contain at least one child condition", + ): + # When: validating the control definition + ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": {"and": []}, + "action": {"decision": "deny"}, + } + ) + # Then: validation rejects the empty composite + + +def test_condition_iter_leaves_preserves_left_to_right_order() -> None: + # Given: a nested condition tree with leaves in several branches + control = ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": { + "and": [ + _leaf("input.user"), + { + "not": _leaf( + "input.role", + evaluator_name="list", + config={"values": ["admin"]}, + ) + }, + { + "or": [ + _leaf("output.first"), + _leaf("output.second"), + ] + }, + ] + }, + "action": {"decision": "deny"}, + } + ) + + # When: iterating leaves and computing derived helpers + paths = [ + leaf.leaf_parts()[0].path + for leaf in control.iter_condition_leaves() + if leaf.leaf_parts() is not None + ] + + # Then: leaves are visited in evaluation order and tree helpers stay accurate + assert paths == ["input.user", "input.role", "output.first", "output.second"] + assert control.condition.max_depth() == 3 + assert control.primary_leaf() is None + + +def test_condition_depth_limit_is_enforced() -> None: + # Given: a condition tree nested deeper than the allowed maximum + too_deep = _leaf("input") + for _ in range(6): + too_deep = {"not": too_deep} + + with pytest.raises( + ValidationError, + match="Condition nesting depth exceeds maximum of 6", + ): + # When: validating the deep condition tree + ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": too_deep, + "action": {"decision": "deny"}, + } + ) + # Then: validation rejects the over-nested tree + + +def test_composite_steer_requires_steering_context() -> None: + # Given: a composite steer control without steering context + with pytest.raises( + ValidationError, + match="Composite steer controls require action.steering_context", + ): + # When: validating the control definition + ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": { + "or": [ + _leaf("input"), + _leaf("output"), + ] + }, + "action": {"decision": "steer"}, + } + ) + # Then: validation rejects the steer action without guidance + + +def test_single_leaf_control_returns_primary_leaf() -> None: + # Given: a control whose entire condition is a single leaf + control = ControlDefinition.model_validate( + { + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": _leaf("input.value"), + "action": {"decision": "deny"}, + } + ) + + # When: asking for the primary leaf + primary_leaf = control.primary_leaf() + + # Then: the original selector/evaluator pair is returned intact + assert primary_leaf is not None + leaf_parts = primary_leaf.leaf_parts() + assert leaf_parts is not None + selector, evaluator = leaf_parts + assert selector.path == "input.value" + assert evaluator.name == "regex" diff --git a/pyproject.toml b/pyproject.toml index 3ee8b7d9..49552809 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ ignore_missing_imports = true version_toml = [ "pyproject.toml:project.version", "models/pyproject.toml:project.version", + "engine/pyproject.toml:project.version", "sdks/python/pyproject.toml:project.version", "server/pyproject.toml:project.version", "evaluators/builtin/pyproject.toml:project.version", diff --git a/sdks/python/Makefile b/sdks/python/Makefile index 991e1f93..279c71ea 100644 --- a/sdks/python/Makefile +++ b/sdks/python/Makefile @@ -1,6 +1,9 @@ .PHONY: help test lint lint-fix typecheck build publish TEST_DB ?= agent_control_test +TEST_SERVER_PORT ?= 18000 +TEST_SERVER_HOST ?= 127.0.0.1 +TEST_SERVER_URL := http://$(TEST_SERVER_HOST):$(TEST_SERVER_PORT) help: @echo "Agent Control SDK - Makefile commands" @@ -22,16 +25,17 @@ test: DB_DATABASE=$(TEST_DB) uv run --package agent-control-server python scripts/reset_test_db.py DB_DATABASE=$(TEST_DB) $(MAKE) -C ../../ server-alembic-upgrade @# Start server in background and save PID - @DB_DATABASE=$(TEST_DB) uv run --package agent-control-server uvicorn agent_control_server.main:app --port 8000 --host 0.0.0.0 > server.log 2>&1 & echo $$! > server.pid + @DB_DATABASE=$(TEST_DB) uv run --package agent-control-server uvicorn agent_control_server.main:app --port $(TEST_SERVER_PORT) --host $(TEST_SERVER_HOST) > server.log 2>&1 & echo $$! > server.pid @echo "Waiting for server..." - @bash -c 'for i in {1..30}; do if curl -s http://localhost:8000/health >/dev/null; then echo "Server up!"; exit 0; fi; sleep 1; done; echo "Server failed"; cat server.log; exit 1' + @bash -c 'for i in {1..30}; do if curl -s $(TEST_SERVER_URL)/health >/dev/null; then echo "Server up!"; exit 0; fi; sleep 1; done; echo "Server failed"; cat server.log; exit 1' @# Run tests, capture exit code, and ensure cleanup @set -e; \ - DB_DATABASE=$(TEST_DB) uv run pytest --cov=src --cov-report=xml:../../coverage-sdk.xml -q; \ - TEST_EXIT_CODE=$$?; \ - echo "Stopping server..."; \ - if [ -f server.pid ]; then kill `cat server.pid` && rm server.pid; fi; \ - exit $$TEST_EXIT_CODE + cleanup() { \ + echo "Stopping server..."; \ + if [ -f server.pid ]; then kill `cat server.pid` 2>/dev/null || true; rm -f server.pid; fi; \ + }; \ + trap cleanup EXIT; \ + DB_DATABASE=$(TEST_DB) AGENT_CONTROL_TEST_URL=$(TEST_SERVER_URL) AGENT_CONTROL_URL=$(TEST_SERVER_URL) uv run pytest --cov=src --cov-report=xml:../../coverage-sdk.xml -q lint: uv run ruff check --config ../../pyproject.toml src/ diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index dd1a38be..b3e41d90 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -947,7 +947,7 @@ async def create_control( Args: name: Unique name for the control - data: Optional control definition with selector, evaluator, action, etc. + data: Optional control definition with a condition tree, action, scope, etc. server_url: Optional server URL (defaults to AGENT_CONTROL_URL env var) api_key: Optional API key for authentication (defaults to AGENT_CONTROL_API_KEY env var) @@ -972,10 +972,12 @@ async def main(): data={ "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": {"pattern": r"\\d{3}-\\d{2}-\\d{4}"} + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\\d{3}-\\d{2}-\\d{4}"} + } }, "action": {"decision": "deny"} } diff --git a/sdks/python/src/agent_control/client.py b/sdks/python/src/agent_control/client.py index 208d1a50..41ce0425 100644 --- a/sdks/python/src/agent_control/client.py +++ b/sdks/python/src/agent_control/client.py @@ -1,10 +1,15 @@ """Base HTTP client for Agent Control server communication.""" +import logging import os from types import TracebackType import httpx +from . import __version__ as sdk_version + +_logger = logging.getLogger(__name__) + class AgentControlClient: """ @@ -33,10 +38,11 @@ class AgentControlClient: # Environment variable name for API key API_KEY_ENV_VAR = "AGENT_CONTROL_API_KEY" + BASE_URL_ENV_VAR = "AGENT_CONTROL_URL" def __init__( self, - base_url: str = "http://localhost:8000", + base_url: str | None = None, timeout: float = 30.0, api_key: str | None = None, ): @@ -44,15 +50,20 @@ def __init__( Initialize the client. Args: - base_url: Base URL of the Agent Control server + base_url: Base URL of the Agent Control server. If not provided, + AGENT_CONTROL_URL is used, falling back to http://localhost:8000. timeout: Request timeout in seconds api_key: API key for authentication. If not provided, will attempt to read from AGENT_CONTROL_API_KEY environment variable. """ - self.base_url = base_url.rstrip("/") + resolved_base_url = base_url or os.environ.get( + self.BASE_URL_ENV_VAR, "http://localhost:8000" + ) + self.base_url = resolved_base_url.rstrip("/") self.timeout = timeout self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR) self._client: httpx.AsyncClient | None = None + self._server_version_warning_emitted = False @property def api_key(self) -> str | None: @@ -61,17 +72,43 @@ def api_key(self) -> str | None: def _get_headers(self) -> dict[str, str]: """Build request headers including authentication.""" - headers: dict[str, str] = {} + headers: dict[str, str] = { + "X-Agent-Control-SDK": "python", + "X-Agent-Control-SDK-Version": sdk_version, + } if self._api_key: headers["X-API-Key"] = self._api_key return headers + async def _check_server_version(self, response: httpx.Response) -> None: + """Warn once when the server major version differs from the SDK major.""" + if self._server_version_warning_emitted: + return + + server_version = response.headers.get("X-Agent-Control-Server-Version") + if not server_version: + return + + sdk_major = sdk_version.split(".", 1)[0] + server_major = server_version.split(".", 1)[0] + if sdk_major == server_major: + return + + _logger.warning( + "Agent Control SDK major version %s is talking to server major version %s. " + "Upgrade the SDK and server together to avoid control-schema mismatches.", + sdk_version, + server_version, + ) + self._server_version_warning_emitted = True + async def __aenter__(self) -> "AgentControlClient": """Async context manager entry.""" self._client = httpx.AsyncClient( base_url=self.base_url, timeout=self.timeout, headers=self._get_headers(), + event_hooks={"response": [self._check_server_version]}, ) return self @@ -108,4 +145,3 @@ def http_client(self) -> httpx.AsyncClient: if self._client is None: raise RuntimeError("Client not initialized. Use 'async with' context manager.") return self._client - diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index 569aecaa..1da113b1 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -5,7 +5,7 @@ Controls can be associated via policies and direct agent-control links. Architecture: - SERVER defines: Policies -> Controls (stage, selector, evaluator, action) + SERVER defines: Policies -> Controls (stage, condition tree, action) SDK decorator: just marks WHERE controls are evaluated Usage: @@ -744,7 +744,7 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable """ Decorator to apply server-defined controls at this code location. - Controls (stage, selector, evaluator, action) are defined on the SERVER. + Controls (stage, condition tree, action) are defined on the SERVER. This decorator marks WHERE to evaluate controls for the current agent. Args: diff --git a/sdks/python/src/agent_control/controls.py b/sdks/python/src/agent_control/controls.py index 2bc53c73..1305fca4 100644 --- a/sdks/python/src/agent_control/controls.py +++ b/sdks/python/src/agent_control/controls.py @@ -97,7 +97,7 @@ async def get_control( Dictionary containing: - id: Control ID - name: Control name - - data: Control definition (selector, evaluator, action) or None if not configured + - data: Control definition (condition, action, scope, etc.) or None if not configured Raises: httpx.HTTPError: If request fails @@ -132,7 +132,7 @@ async def create_control( Args: client: AgentControlClient instance name: Unique name for the control - data: Optional control definition (selector, evaluator, action, etc.) + data: Optional control definition (condition tree, action, scope, etc.) Returns: Dictionary containing: @@ -158,10 +158,12 @@ async def create_control( data={ "execution": "server", "scope": {"step_types": ["llm"], "stages": ["post"]}, - "selector": {"path": "output"}, - "evaluator": { - "name": "regex", - "config": {"pattern": r"\\d{3}-\\d{2}-\\d{4}"} + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\\d{3}-\\d{2}-\\d{4}"} + } }, "action": {"decision": "deny"} } @@ -195,7 +197,7 @@ async def set_control_data( """ Set the configuration data for a control. - This defines what the control actually does (selector, evaluator, action). + This defines what the control actually does (condition tree, action, scope). Args: client: AgentControlClient instance diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index e30bb3e2..878de63b 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -30,6 +30,19 @@ _FALLBACK_SPAN_ID = "0" * 16 _trace_warning_logged = False + +def _primary_leaf_details( + control_def: ControlDefinition, +) -> tuple[str | None, str | None]: + """Return selector/evaluator identifiers for single-leaf controls only.""" + primary_leaf = control_def.primary_leaf() + primary_parts = primary_leaf.leaf_parts() if primary_leaf else None + if primary_parts is None: + return None, None + selector, evaluator = primary_parts + return selector.path, evaluator.name + + def _map_applies_to(step_type: str) -> Literal["llm_call", "tool_call"]: return "tool_call" if step_type == "tool" else "llm_call" @@ -77,6 +90,9 @@ def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: return for match in matches: ctrl = control_lookup.get(match.control_id) + selector_path, evaluator_name = ( + _primary_leaf_details(ctrl.control) if ctrl else (None, None) + ) add_event( ControlExecutionEvent( control_execution_id=match.control_execution_id, @@ -91,8 +107,8 @@ def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: matched=matched, confidence=match.result.confidence, timestamp=now, - evaluator_name=ctrl.control.evaluator.name if ctrl else None, - selector_path=ctrl.control.selector.path if ctrl else None, + evaluator_name=evaluator_name, + selector_path=selector_path, error_message=match.result.error if not matched else None, metadata=match.result.metadata or {}, ) @@ -213,6 +229,7 @@ async def check_evaluation_with_local( local_controls: list[_ControlAdapter] = [] parse_errors: list[ControlMatch] = [] has_server_controls = False + available_evaluators = list_evaluators() for control in controls: control_data = control.get("control", {}) @@ -225,20 +242,21 @@ async def check_evaluation_with_local( try: control_def = ControlDefinition.model_validate(control_data) - evaluator_name = control_def.evaluator.name - - if ":" in evaluator_name: - raise RuntimeError( - f"Control '{control['name']}' is marked execution='sdk' but uses " - f"agent-scoped evaluator '{evaluator_name}' which is server-only. " - "Set execution='server' or use a built-in evaluator." - ) - if evaluator_name not in list_evaluators(): - raise RuntimeError( - f"Control '{control['name']}' is marked execution='sdk' but evaluator " - f"'{evaluator_name}' is not available in the SDK. " - "Install the evaluator or set execution='server'." - ) + for _, evaluator_spec in control_def.iter_condition_leaf_parts(): + evaluator_name = evaluator_spec.name + + if ":" in evaluator_name: + raise RuntimeError( + f"Control '{control['name']}' is marked execution='sdk' but uses " + f"agent-scoped evaluator '{evaluator_name}' which is server-only. " + "Set execution='server' or use a built-in evaluator." + ) + if evaluator_name not in available_evaluators: + raise RuntimeError( + f"Control '{control['name']}' is marked execution='sdk' but evaluator " + f"'{evaluator_name}' is not available in the SDK. " + "Install the evaluator or set execution='server'." + ) local_controls.append( _ControlAdapter( diff --git a/sdks/python/tests/test_client.py b/sdks/python/tests/test_client.py new file mode 100644 index 00000000..aff6796e --- /dev/null +++ b/sdks/python/tests/test_client.py @@ -0,0 +1,99 @@ +"""Unit tests for AgentControlClient configuration and version warnings.""" + +from __future__ import annotations + +from unittest.mock import patch + +import httpx +import pytest + +from agent_control.client import AgentControlClient, sdk_version + + +def test_client_uses_agent_control_url_env_var( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Given: AGENT_CONTROL_URL is set in the environment + monkeypatch.setenv("AGENT_CONTROL_URL", "http://example.test:9000/") + + # When: constructing a client without an explicit base URL + client = AgentControlClient() + + # Then: the client uses the environment-provided server URL + assert client.base_url == "http://example.test:9000" + + +def test_explicit_base_url_overrides_env_var( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Given: AGENT_CONTROL_URL is set but an explicit base URL is also provided + monkeypatch.setenv("AGENT_CONTROL_URL", "http://env.test:9000") + + # When: constructing the client with an explicit base URL + client = AgentControlClient(base_url="http://explicit.test:8000/") + + # Then: the explicit base URL wins + assert client.base_url == "http://explicit.test:8000" + + +def test_get_headers_include_sdk_metadata_and_api_key() -> None: + # Given: a client configured with an API key + client = AgentControlClient(api_key="test-key") + + # When: building request headers + headers = client._get_headers() + + # Then: SDK metadata and authentication headers are included + assert headers["X-Agent-Control-SDK"] == "python" + assert headers["X-Agent-Control-SDK-Version"] == sdk_version + assert headers["X-API-Key"] == "test-key" + + +@pytest.mark.asyncio +async def test_check_server_version_warns_once_on_major_mismatch() -> None: + # Given: a server response with a mismatched major version header + client = AgentControlClient() + response = httpx.Response( + 200, + headers={"X-Agent-Control-Server-Version": "999.1.0"}, + ) + + # When: version checking runs twice for the same mismatch + with patch("agent_control.client._logger.warning") as mock_warning: + await client._check_server_version(response) + await client._check_server_version(response) + + # Then: the warning is emitted only once + mock_warning.assert_called_once() + + +@pytest.mark.asyncio +async def test_check_server_version_does_not_warn_on_matching_major() -> None: + # Given: a server response whose major version matches the SDK major version + client = AgentControlClient() + matching_major = sdk_version.split(".", 1)[0] + response = httpx.Response( + 200, + headers={"X-Agent-Control-Server-Version": f"{matching_major}.99.0"}, + ) + + # When: version checking runs + with patch("agent_control.client._logger.warning") as mock_warning: + await client._check_server_version(response) + + # Then: no warning is emitted + mock_warning.assert_not_called() + + +@pytest.mark.asyncio +async def test_check_server_version_ignores_missing_header() -> None: + # Given: a response without the server version header + client = AgentControlClient() + response = httpx.Response(200) + + # When: version checking runs + with patch("agent_control.client._logger.warning") as mock_warning: + await client._check_server_version(response) + + # Then: no warning is emitted + mock_warning.assert_not_called() diff --git a/sdks/python/tests/test_integration_agents.py b/sdks/python/tests/test_integration_agents.py index 66a22c43..04ee5f86 100644 --- a/sdks/python/tests/test_integration_agents.py +++ b/sdks/python/tests/test_integration_agents.py @@ -214,10 +214,12 @@ async def test_convenience_agent_association_functions( "enabled": True, "execution": "server", "scope": {"step_types": ["tool"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": {"pattern": ".*"}, + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": {"pattern": ".*"}, + }, }, "action": {"decision": "allow"}, "tags": ["test"], diff --git a/sdks/python/tests/test_integration_health.py b/sdks/python/tests/test_integration_health.py index 948dbfcc..8c8503c0 100644 --- a/sdks/python/tests/test_integration_health.py +++ b/sdks/python/tests/test_integration_health.py @@ -33,7 +33,9 @@ async def test_health_check_workflow( @pytest.mark.asyncio -async def test_client_context_manager() -> None: +async def test_client_context_manager( + monkeypatch: pytest.MonkeyPatch, server_url: str +) -> None: """ Test client context manager behavior. @@ -41,6 +43,10 @@ async def test_client_context_manager() -> None: - Client can be created and closed properly - Context manager handles cleanup """ + # Given: the SDK client is configured from the test server URL environment variable + monkeypatch.setenv("AGENT_CONTROL_URL", server_url) + + # When: using AgentControlClient as an async context manager async with agent_control.AgentControlClient() as client: # Verify client is initialized assert client._client is not None @@ -49,6 +55,7 @@ async def test_client_context_manager() -> None: health = await client.health_check() assert health is not None + # Then: the client works inside the context manager and exits cleanly # After context, client should be closed # (we can't easily verify this without accessing internals) print("✓ Client context manager works correctly") @@ -69,4 +76,3 @@ async def test_invalid_server_url() -> None: await client.health_check() print("✓ Invalid server URL correctly raises error") - diff --git a/sdks/python/tests/test_local_evaluation.py b/sdks/python/tests/test_local_evaluation.py index 94661ab0..bd2eeb3f 100644 --- a/sdks/python/tests/test_local_evaluation.py +++ b/sdks/python/tests/test_local_evaluation.py @@ -12,11 +12,9 @@ from unittest.mock import AsyncMock, MagicMock import pytest - from agent_control_models import ( ControlMatch, EvaluationResponse, - EvaluationResult, EvaluatorResult, Step, ) @@ -27,7 +25,6 @@ check_evaluation_with_local, ) - # ============================================================================= # Test Fixtures # ============================================================================= @@ -75,10 +72,12 @@ def make_control_dict( "enabled": True, "execution": execution, "scope": {"step_types": [step_type], "stages": [stage]}, - "selector": {"path": path}, - "evaluator": { - "name": evaluator, - "config": {"pattern": pattern}, + "condition": { + "selector": {"path": path}, + "evaluator": { + "name": evaluator, + "config": {"pattern": pattern}, + }, }, "action": {"decision": action}, }, @@ -254,6 +253,45 @@ async def test_server_only_controls_calls_server(self, agent_name, llm_payload): assert result.is_safe is True + @pytest.mark.asyncio + async def test_local_legacy_flat_control_executes_locally(self, agent_name, llm_payload): + """Legacy flat selector/evaluator controls should still execute in the SDK.""" + # Given: a local SDK control in the legacy flat selector/evaluator shape + controls = [ + { + "id": 1, + "name": "legacy_local_ctrl", + "control": { + "description": "Legacy local control", + "enabled": True, + "execution": "sdk", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "action": {"decision": "deny"}, + }, + } + ] + client = MagicMock(spec=AgentControlClient) + client.http_client = AsyncMock() + client.http_client.post = AsyncMock() + + # When: running mixed local/server evaluation + result = await check_evaluation_with_local( + client=client, + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + ) + + # Then: the legacy control is canonicalized locally and no server call is needed + client.http_client.post.assert_not_called() + assert result.is_safe is False + assert result.matches is not None + assert len(result.matches) == 1 + assert result.matches[0].control_name == "legacy_local_ctrl" + @pytest.mark.asyncio async def test_local_deny_short_circuits(self, agent_name, llm_payload): """Local deny should return immediately without calling server.""" @@ -634,7 +672,8 @@ async def test_local_control_with_agent_scoped_evaluator_raises(self, agent_name async def test_server_control_with_missing_evaluator_allowed(self, agent_name, llm_payload): """Test that server control with unavailable evaluator is allowed (server handles it). - Given: A server control (execution="server") referencing an evaluator that doesn't exist locally + Given: A server control (execution="server") referencing an evaluator + that doesn't exist locally When: check_evaluation_with_local is called Then: No error, server is called to handle it """ @@ -755,8 +794,6 @@ async def test_local_evaluation_includes_steering_context(self, agent_name, llm_ Then: Response includes steering_context field in matches Coverage: Lines 275, 280, 298-301 in control_decorators.py """ - from agent_control_models.controls import SteeringContext - controls = [ { "id": 1, @@ -766,8 +803,10 @@ async def test_local_evaluation_includes_steering_context(self, agent_name, llm_ "enabled": True, "execution": "sdk", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "condition": { + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + }, "action": { "decision": "steer", "steering_context": { diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index a99d6b7a..12ebc1bd 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -1,13 +1,17 @@ """Tests for observability updates: event emission, non_matches propagation, applies_to mapping.""" from unittest.mock import AsyncMock, MagicMock, patch -from uuid import UUID import pytest +from agent_control_models import ControlDefinition from agent_control import evaluation -from agent_control.evaluation import _map_applies_to, _merge_results - +from agent_control.evaluation import ( + _ControlAdapter, + _emit_local_events, + _map_applies_to, + _merge_results, +) # ============================================================================= # _map_applies_to tests @@ -111,13 +115,12 @@ class TestEmitLocalEvents: def _make_control_adapter(self, id, name, evaluator_name="regex", selector_path="input"): """Create a _ControlAdapter for testing.""" - from agent_control.evaluation import _ControlAdapter - from agent_control_models import ControlDefinition - control_def = ControlDefinition( execution="sdk", - evaluator={"name": evaluator_name, "config": {"pattern": "test"}}, - selector={"path": selector_path}, + condition={ + "evaluator": {"name": evaluator_name, "config": {"pattern": "test"}}, + "selector": {"path": selector_path}, + }, action={"decision": "deny"}, ) return _ControlAdapter(id=id, name=name, control=control_def) @@ -247,6 +250,50 @@ def test_uses_fallback_ids_when_trace_context_missing(self): mock_logger.warning.assert_called_once() assert "fallback" in mock_logger.warning.call_args[0][0].lower() + def test_composite_control_omits_primary_leaf_metadata(self): + """Composite local controls should not emit misleading leaf metadata.""" + # Given: a composite local control and a non-match response for that control + ctrl = _ControlAdapter( + id=1, + name="composite-ctrl", + control=ControlDefinition( + execution="sdk", + condition={ + "and": [ + { + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + }, + { + "selector": {"path": "output"}, + "evaluator": {"name": "regex", "config": {"pattern": "done"}}, + }, + ] + }, + action={"decision": "allow"}, + ), + ) + non_match = self._make_match(1, "composite-ctrl", action="allow", matched=False) + response = self._make_response(non_matches=[non_match]) + request = self._make_request() + + # When: emitting local observability events + with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation.add_event") as mock_add: + _emit_local_events( + response, + request, + [ctrl], + "trace123", + "span456", + "test-agent", + ) + event = mock_add.call_args_list[0][0][0] + + # Then: no single-leaf selector/evaluator metadata is attached + assert event.evaluator_name is None + assert event.selector_path is None + def test_fallback_warning_logged_only_once(self): """The missing-trace-context warning should fire only on the first call.""" import agent_control.evaluation as eval_mod @@ -307,8 +354,10 @@ async def test_emits_events_when_trace_context_provided(self): "id": 1, "name": "test-ctrl", "control": { - "evaluator": {"name": "regex", "config": {"pattern": "test"}}, - "selector": {"path": "input"}, + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, "action": {"decision": "allow"}, "execution": "sdk", }, @@ -359,8 +408,10 @@ async def test_emits_events_without_trace_context(self): "id": 1, "name": "test-ctrl", "control": { - "evaluator": {"name": "regex", "config": {"pattern": "test"}}, - "selector": {"path": "input"}, + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, "action": {"decision": "allow"}, "execution": "sdk", }, @@ -396,8 +447,10 @@ async def test_forwards_trace_headers_to_server(self): "id": 1, "name": "server-ctrl", "control": { - "evaluator": {"name": "regex", "config": {"pattern": "test"}}, - "selector": {"path": "input"}, + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, "action": {"decision": "deny"}, "execution": "server", }, @@ -447,7 +500,7 @@ class TestControlDecoratorsNonMatches: @pytest.mark.asyncio async def test_non_matches_populated_in_stats(self): """non_matches should be properly converted to dicts for stats tracking.""" - from agent_control.control_decorators import ControlContext, _log_control_evaluations + from agent_control.control_decorators import ControlContext # Simulate a result dict with non_matches result = { diff --git a/sdks/typescript/src/generated/lib/config.ts b/sdks/typescript/src/generated/lib/config.ts index aacf1282..3e841117 100644 --- a/sdks/typescript/src/generated/lib/config.ts +++ b/sdks/typescript/src/generated/lib/config.ts @@ -57,8 +57,8 @@ export function serverURLFromOptions(options: SDKOptions): URL | null { export const SDK_METADATA = { language: "typescript", - openapiDocVersion: "0.1.0", + openapiDocVersion: "6.7.2", sdkVersion: "0.1.0", genVersion: "2.827.0", - userAgent: "speakeasy-sdk/typescript 0.1.0 2.827.0 0.1.0 agent-control", + userAgent: "speakeasy-sdk/typescript 0.1.0 2.827.0 6.7.2 agent-control", } as const; diff --git a/sdks/typescript/src/generated/models/condition-node-input.ts b/sdks/typescript/src/generated/models/condition-node-input.ts new file mode 100644 index 00000000..a156596b --- /dev/null +++ b/sdks/typescript/src/generated/models/condition-node-input.ts @@ -0,0 +1,74 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { + ControlSelector, + ControlSelector$Outbound, + ControlSelector$outboundSchema, +} from "./control-selector.js"; +import { + EvaluatorSpec, + EvaluatorSpec$Outbound, + EvaluatorSpec$outboundSchema, +} from "./evaluator-spec.js"; + +/** + * Recursive boolean condition tree for control evaluation. + */ +export type ConditionNodeInput = { + /** + * Logical AND over child conditions. + */ + and?: Array | null | undefined; + /** + * Leaf evaluator. Must be provided together with selector. + */ + evaluator?: EvaluatorSpec | null | undefined; + /** + * Logical NOT over a single child condition. + */ + not?: ConditionNodeInput | null | undefined; + /** + * Logical OR over child conditions. + */ + or?: Array | null | undefined; + /** + * Leaf selector. Must be provided together with evaluator. + */ + selector?: ControlSelector | null | undefined; +}; + +/** @internal */ +export type ConditionNodeInput$Outbound = { + and?: Array | null | undefined; + evaluator?: EvaluatorSpec$Outbound | null | undefined; + not?: ConditionNodeInput$Outbound | null | undefined; + or?: Array | null | undefined; + selector?: ControlSelector$Outbound | null | undefined; +}; + +/** @internal */ +export const ConditionNodeInput$outboundSchema: z.ZodMiniType< + ConditionNodeInput$Outbound, + ConditionNodeInput +> = z.object({ + and: z.optional( + z.nullable(z.array(z.lazy(() => ConditionNodeInput$outboundSchema))), + ), + evaluator: z.optional(z.nullable(EvaluatorSpec$outboundSchema)), + not: z.optional(z.nullable(z.lazy(() => ConditionNodeInput$outboundSchema))), + or: z.optional( + z.nullable(z.array(z.lazy(() => ConditionNodeInput$outboundSchema))), + ), + selector: z.optional(z.nullable(ControlSelector$outboundSchema)), +}); + +export function conditionNodeInputToJSON( + conditionNodeInput: ConditionNodeInput, +): string { + return JSON.stringify( + ConditionNodeInput$outboundSchema.parse(conditionNodeInput), + ); +} diff --git a/sdks/typescript/src/generated/models/condition-node-output.ts b/sdks/typescript/src/generated/models/condition-node-output.ts new file mode 100644 index 00000000..d0320504 --- /dev/null +++ b/sdks/typescript/src/generated/models/condition-node-output.ts @@ -0,0 +1,68 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { safeParse } from "../lib/schemas.js"; +import { Result as SafeParseResult } from "../types/fp.js"; +import { + ControlSelector, + ControlSelector$inboundSchema, +} from "./control-selector.js"; +import { SDKValidationError } from "./errors/sdk-validation-error.js"; +import { + EvaluatorSpec, + EvaluatorSpec$inboundSchema, +} from "./evaluator-spec.js"; + +/** + * Recursive boolean condition tree for control evaluation. + */ +export type ConditionNodeOutput = { + /** + * Logical AND over child conditions. + */ + and?: Array | null | undefined; + /** + * Leaf evaluator. Must be provided together with selector. + */ + evaluator?: EvaluatorSpec | null | undefined; + /** + * Logical NOT over a single child condition. + */ + not?: ConditionNodeOutput | null | undefined; + /** + * Logical OR over child conditions. + */ + or?: Array | null | undefined; + /** + * Leaf selector. Must be provided together with evaluator. + */ + selector?: ControlSelector | null | undefined; +}; + +/** @internal */ +export const ConditionNodeOutput$inboundSchema: z.ZodMiniType< + ConditionNodeOutput, + unknown +> = z.object({ + and: z.optional( + z.nullable(z.array(z.lazy(() => ConditionNodeOutput$inboundSchema))), + ), + evaluator: z.optional(z.nullable(EvaluatorSpec$inboundSchema)), + not: z.optional(z.nullable(z.lazy(() => ConditionNodeOutput$inboundSchema))), + or: z.optional( + z.nullable(z.array(z.lazy(() => ConditionNodeOutput$inboundSchema))), + ), + selector: z.optional(z.nullable(ControlSelector$inboundSchema)), +}); + +export function conditionNodeOutputFromJSON( + jsonString: string, +): SafeParseResult { + return safeParse( + jsonString, + (x) => ConditionNodeOutput$inboundSchema.parse(JSON.parse(x)), + `Failed to parse 'ConditionNodeOutput' from JSON`, + ); +} diff --git a/sdks/typescript/src/generated/models/control-definition-input.ts b/sdks/typescript/src/generated/models/control-definition-input.ts index a5c5c352..f385b0d5 100644 --- a/sdks/typescript/src/generated/models/control-definition-input.ts +++ b/sdks/typescript/src/generated/models/control-definition-input.ts @@ -4,6 +4,11 @@ import * as z from "zod/v4-mini"; import { ClosedEnum } from "../types/enums.js"; +import { + ConditionNodeInput, + ConditionNodeInput$Outbound, + ConditionNodeInput$outboundSchema, +} from "./condition-node-input.js"; import { ControlAction, ControlAction$Outbound, @@ -14,16 +19,6 @@ import { ControlScope$Outbound, ControlScope$outboundSchema, } from "./control-scope.js"; -import { - ControlSelector, - ControlSelector$Outbound, - ControlSelector$outboundSchema, -} from "./control-selector.js"; -import { - EvaluatorSpec, - EvaluatorSpec$Outbound, - EvaluatorSpec$outboundSchema, -} from "./evaluator-spec.js"; /** * Where this control executes @@ -52,6 +47,10 @@ export type ControlDefinitionInput = { * What to do when control matches. */ action: ControlAction; + /** + * Recursive boolean condition tree for control evaluation. + */ + condition: ConditionNodeInput; /** * Detailed description of the control */ @@ -60,17 +59,6 @@ export type ControlDefinitionInput = { * Whether this control is active */ enabled?: boolean | undefined; - /** - * Evaluator specification. See GET /evaluators for available evaluators and schemas. - * - * @remarks - * - * Evaluator reference formats: - * - Built-in: "regex", "list", "json", "sql" - * - External: "galileo.luna2" (requires agent-control-evaluators[galileo]) - * - Agent-scoped: "my-agent:my-evaluator" (validated in endpoint, not here) - */ - evaluator: EvaluatorSpec; /** * Where this control executes */ @@ -79,15 +67,6 @@ export type ControlDefinitionInput = { * Defines when a control applies to a Step. */ scope?: ControlScope | undefined; - /** - * Selects data from a Step payload. - * - * @remarks - * - * - path: which slice of the Step to feed into the evaluator. Optional, defaults to "*" - * meaning the entire Step object. - */ - selector: ControlSelector; /** * Tags for categorization */ @@ -102,12 +81,11 @@ export const ControlDefinitionInputExecution$outboundSchema: z.ZodMiniEnum< /** @internal */ export type ControlDefinitionInput$Outbound = { action: ControlAction$Outbound; + condition: ConditionNodeInput$Outbound; description?: string | null | undefined; enabled: boolean; - evaluator: EvaluatorSpec$Outbound; execution: string; scope?: ControlScope$Outbound | undefined; - selector: ControlSelector$Outbound; tags?: Array | undefined; }; @@ -117,12 +95,11 @@ export const ControlDefinitionInput$outboundSchema: z.ZodMiniType< ControlDefinitionInput > = z.object({ action: ControlAction$outboundSchema, + condition: ConditionNodeInput$outboundSchema, description: z.optional(z.nullable(z.string())), enabled: z._default(z.boolean(), true), - evaluator: EvaluatorSpec$outboundSchema, execution: ControlDefinitionInputExecution$outboundSchema, scope: z.optional(ControlScope$outboundSchema), - selector: ControlSelector$outboundSchema, tags: z.optional(z.array(z.string())), }); diff --git a/sdks/typescript/src/generated/models/control-definition-output.ts b/sdks/typescript/src/generated/models/control-definition-output.ts index ce4e2714..8199d5f3 100644 --- a/sdks/typescript/src/generated/models/control-definition-output.ts +++ b/sdks/typescript/src/generated/models/control-definition-output.ts @@ -8,20 +8,16 @@ import * as openEnums from "../types/enums.js"; import { OpenEnum } from "../types/enums.js"; import { Result as SafeParseResult } from "../types/fp.js"; import * as types from "../types/primitives.js"; +import { + ConditionNodeOutput, + ConditionNodeOutput$inboundSchema, +} from "./condition-node-output.js"; import { ControlAction, ControlAction$inboundSchema, } from "./control-action.js"; import { ControlScope, ControlScope$inboundSchema } from "./control-scope.js"; -import { - ControlSelector, - ControlSelector$inboundSchema, -} from "./control-selector.js"; import { SDKValidationError } from "./errors/sdk-validation-error.js"; -import { - EvaluatorSpec, - EvaluatorSpec$inboundSchema, -} from "./evaluator-spec.js"; /** * Where this control executes @@ -48,6 +44,10 @@ export type ControlDefinitionOutput = { * What to do when control matches. */ action: ControlAction; + /** + * Recursive boolean condition tree for control evaluation. + */ + condition: ConditionNodeOutput; /** * Detailed description of the control */ @@ -56,17 +56,6 @@ export type ControlDefinitionOutput = { * Whether this control is active */ enabled: boolean; - /** - * Evaluator specification. See GET /evaluators for available evaluators and schemas. - * - * @remarks - * - * Evaluator reference formats: - * - Built-in: "regex", "list", "json", "sql" - * - External: "galileo.luna2" (requires agent-control-evaluators[galileo]) - * - Agent-scoped: "my-agent:my-evaluator" (validated in endpoint, not here) - */ - evaluator: EvaluatorSpec; /** * Where this control executes */ @@ -75,15 +64,6 @@ export type ControlDefinitionOutput = { * Defines when a control applies to a Step. */ scope?: ControlScope | undefined; - /** - * Selects data from a Step payload. - * - * @remarks - * - * - path: which slice of the Step to feed into the evaluator. Optional, defaults to "*" - * meaning the entire Step object. - */ - selector: ControlSelector; /** * Tags for categorization */ @@ -100,12 +80,11 @@ export const ControlDefinitionOutput$inboundSchema: z.ZodMiniType< unknown > = z.object({ action: ControlAction$inboundSchema, + condition: ConditionNodeOutput$inboundSchema, description: z.optional(z.nullable(types.string())), enabled: z._default(types.boolean(), true), - evaluator: EvaluatorSpec$inboundSchema, execution: Execution$inboundSchema, scope: types.optional(ControlScope$inboundSchema), - selector: ControlSelector$inboundSchema, tags: types.optional(z.array(types.string())), }); diff --git a/sdks/typescript/src/generated/models/index.ts b/sdks/typescript/src/generated/models/index.ts index 53465775..aba5b19f 100644 --- a/sdks/typescript/src/generated/models/index.ts +++ b/sdks/typescript/src/generated/models/index.ts @@ -10,6 +10,8 @@ export * from "./assoc-response.js"; export * from "./auth-mode.js"; export * from "./batch-events-request.js"; export * from "./batch-events-response.js"; +export * from "./condition-node-input.js"; +export * from "./condition-node-output.js"; export * from "./config-response.js"; export * from "./conflict-mode.js"; export * from "./control-action.js"; diff --git a/server/Makefile b/server/Makefile index 3aae2898..5f8bd011 100644 --- a/server/Makefile +++ b/server/Makefile @@ -9,7 +9,9 @@ SHOW ?= head STAMP ?= head TEST_DB ?= agent_control_test -.PHONY: help run start-dependencies test migrate alembic-migrate alembic-revision alembic-upgrade alembic-downgrade alembic-current alembic-history alembic-heads alembic-show alembic-stamp +.PHONY: help run start-dependencies test migrate migrate-control-conditions alembic-migrate alembic-revision alembic-upgrade alembic-downgrade alembic-current alembic-history alembic-heads alembic-show alembic-stamp + +MIGRATE_ARGS ?= --dry-run help: @echo "Available targets:" @@ -17,6 +19,7 @@ help: @echo " start-dependencies - docker compose up -d (start local dependencies)" @echo " test - run server tests (uses DB_DATABASE=$(TEST_DB))" @echo " migrate - run database migrations (alembic upgrade head)" + @echo " migrate-control-conditions - rewrite stored controls to condition trees (default dry-run; set MIGRATE_ARGS=--apply to commit)" @echo " alembic-migrate MSG='message' - autogenerate alembic revision" @echo " alembic-upgrade UP=head - upgrade to revision" @echo " alembic-downgrade DOWN=-1 - downgrade to revision" @@ -30,6 +33,9 @@ migrate: @echo "Running database migrations..." $(ALEMBIC) upgrade head +migrate-control-conditions: + uv run --package agent-control-server agent-control-migrate-controls $(MIGRATE_ARGS) + alembic-migrate: $(ALEMBIC) revision --autogenerate -m "$(MSG)" @@ -66,4 +72,4 @@ test: DB_DATABASE=$(TEST_DB) uv run --package agent-control-server pytest --cov=src --cov-report=xml:../coverage-server.xml -q run: start-dependencies migrate - uv run --package agent-control-server uvicorn agent_control_server.main:app --reload + uv run -m agent_control_server.main diff --git a/server/pyproject.toml b/server/pyproject.toml index 40bbe243..b6699da4 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -48,6 +48,7 @@ dev = [ [project.scripts] agent-control-server = "agent_control_server.main:run" +agent-control-migrate-controls = "agent_control_server.scripts.migrate_control_conditions:main" diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 23d3a547..8f7f8048 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -3,6 +3,7 @@ from agent_control_engine import list_evaluators from agent_control_models.agent import Agent as APIAgent from agent_control_models.agent import StepSchema +from agent_control_models.controls import ControlDefinition from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.server import ( AgentControlsResponse, @@ -109,43 +110,48 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ if not control.data: continue - evaluator_cfg = control.data.get("evaluator", {}) - evaluator_name = evaluator_cfg.get("name", "") - if not evaluator_name: + try: + control_definition = ControlDefinition.model_validate(control.data) + except ValidationError: + errors.append(f"Control '{control.name}' has corrupted data") continue - parsed = parse_evaluator_ref_full(evaluator_name) - if parsed.type != "agent": - continue # Built-in/external evaluator, already validated at control creation - - # Agent-scoped evaluator - check if target matches this agent - if parsed.namespace != agent.name: - errors.append( - f"Control '{control.name}' references evaluator '{evaluator_name}' " - f"which belongs to agent '{parsed.namespace}', not '{agent.name}'" - ) - continue + for _, evaluator_cfg in control_definition.iter_condition_leaf_parts(): + evaluator_name = evaluator_cfg.name + parsed = parse_evaluator_ref_full(evaluator_name) + if parsed.type != "agent": + continue # Built-in/external evaluator, already validated at control creation - # Check if evaluator exists on this agent - if parsed.local_name not in agent_evaluators: - errors.append( - f"Control '{control.name}' references evaluator '{parsed.local_name}' " - f"which is not registered with agent '{agent.name}'. " - f"Register it via initAgent or use a different evaluator." - ) - continue + # Agent-scoped evaluator - check if target matches this agent + if parsed.namespace != agent.name: + errors.append( + f"Control '{control.name}' references evaluator '{evaluator_name}' " + f"which belongs to agent '{parsed.namespace}', not '{agent.name}'" + ) + continue - # Validate config against schema - registered_ev = agent_evaluators[parsed.local_name] - config = evaluator_cfg.get("config", {}) - if registered_ev.config_schema: - try: - validate_config_against_schema(config, registered_ev.config_schema) - except JSONSchemaValidationError as e: + # Check if evaluator exists on this agent + if parsed.local_name not in agent_evaluators: errors.append( - f"Control '{control.name}' invalid config for " - f"'{parsed.local_name}': {e.message}" + f"Control '{control.name}' references evaluator '{parsed.local_name}' " + f"which is not registered with agent '{agent.name}'. " + f"Register it via initAgent or use a different evaluator." ) + continue + + # Validate config against schema + registered_ev = agent_evaluators[parsed.local_name] + if registered_ev.config_schema: + try: + validate_config_against_schema( + evaluator_cfg.config, + registered_ev.config_schema, + ) + except JSONSchemaValidationError as e: + errors.append( + f"Control '{control.name}' invalid config for " + f"'{parsed.local_name}': {e.message}" + ) return errors @@ -200,15 +206,18 @@ async def _build_overwrite_evaluator_removals( references_by_evaluator: dict[str, list[tuple[int, str]]] = {} for control in controls: - evaluator_ref = control.control.evaluator.name - parsed = parse_evaluator_ref_full(evaluator_ref) - if parsed.type != "agent": - continue - if parsed.namespace != agent.name: - continue - if parsed.local_name not in removed_evaluators: - continue - references_by_evaluator.setdefault(parsed.local_name, []).append((control.id, control.name)) + for _, evaluator_spec in control.control.iter_condition_leaf_parts(): + evaluator_ref = evaluator_spec.name + parsed = parse_evaluator_ref_full(evaluator_ref) + if parsed.type != "agent": + continue + if parsed.namespace != agent.name: + continue + if parsed.local_name not in removed_evaluators: + continue + references_by_evaluator.setdefault(parsed.local_name, []).append( + (control.id, control.name) + ) removals: list[InitAgentEvaluatorRemoval] = [] for evaluator_name in sorted(removed_evaluators): @@ -1616,11 +1625,11 @@ async def patch_agent( referencing_controls: list[tuple[str, str]] = [] # (control_name, evaluator) for ctrl in controls: - evaluator_ref = ctrl.control.evaluator.name - if ":" in evaluator_ref: + for _, evaluator_spec in ctrl.control.iter_condition_leaf_parts(): + evaluator_ref = evaluator_spec.name + if ":" not in evaluator_ref: + continue ref_agent, ref_eval = evaluator_ref.split(":", 1) - # Check if this control references an evaluator we're removing - # AND it's scoped to this agent (by name match) if ref_agent == agent.name and ref_eval in remove_evaluator_set: referencing_controls.append((ctrl.name, ref_eval)) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 7957c377..8329a04f 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -1,5 +1,7 @@ +from collections.abc import Iterator + from agent_control_engine import list_evaluators -from agent_control_models import ControlDefinition +from agent_control_models import ConditionNode, ControlDefinition from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.server import ( AgentRef, @@ -35,11 +37,13 @@ ) from ..logging_utils import get_logger from ..models import Agent, AgentData, Control, agent_controls, agent_policies, policy_controls +from ..services.control_definitions import parse_control_definition_or_api_error from ..services.evaluator_utils import ( parse_evaluator_ref_full, validate_config_against_schema, ) from ..services.query_utils import escape_like_pattern +from ..services.validation_paths import format_field_path # Pagination constants _DEFAULT_PAGINATION_LIMIT = 20 @@ -53,148 +57,185 @@ _logger = get_logger(__name__) +def _iter_condition_leaves( + node: ConditionNode, + *, + path: str = "data.condition", +) -> Iterator[tuple[str, ConditionNode]]: + """Yield each leaf condition with its dot/bracket field path.""" + if node.is_leaf(): + yield path, node + return + + if node.and_ is not None: + for index, child in enumerate(node.and_): + yield from _iter_condition_leaves(child, path=f"{path}.and[{index}]") + return + + if node.or_ is not None: + for index, child in enumerate(node.or_): + yield from _iter_condition_leaves(child, path=f"{path}.or[{index}]") + return + + if node.not_ is not None: + yield from _iter_condition_leaves(node.not_, path=f"{path}.not") + + async def _validate_control_definition( control_def: ControlDefinition, db: AsyncSession ) -> None: """Validate evaluator config for a control definition.""" - evaluator_ref = control_def.evaluator.name - parsed = parse_evaluator_ref_full(evaluator_ref) + available_evaluators = list_evaluators() + agent_data_by_name: dict[str, AgentData] = {} + for field_prefix, leaf in _iter_condition_leaves(control_def.condition): + leaf_parts = leaf.leaf_parts() + if leaf_parts is None: + continue + _, evaluator_spec = leaf_parts + + evaluator_ref = evaluator_spec.name + parsed = parse_evaluator_ref_full(evaluator_ref) + + if parsed.type == "agent": + agent_data = agent_data_by_name.get(parsed.namespace) + if agent_data is None: + agent_result = await db.execute( + select(Agent).where(Agent.name == parsed.namespace) + ) + agent = agent_result.scalars().first() + if agent is None: + raise NotFoundError( + error_code=ErrorCode.AGENT_NOT_FOUND, + detail=f"Agent '{parsed.namespace}' not found", + resource="Agent", + resource_id=parsed.namespace, + hint=( + "Ensure the agent exists before creating controls " + "that reference its evaluators." + ), + ) - if parsed.type == "agent": - # Agent-scoped evaluator: validate against agent's registered schema - agent_result = await db.execute( - select(Agent).where(Agent.name == parsed.namespace) - ) - agent = agent_result.scalars().first() - if agent is None: - raise NotFoundError( - error_code=ErrorCode.AGENT_NOT_FOUND, - detail=f"Agent '{parsed.namespace}' not found", - resource="Agent", - resource_id=parsed.namespace, - hint=( - "Ensure the agent exists before creating controls " - "that reference its evaluators." - ), + try: + agent_data = AgentData.model_validate(agent.data) + except ValidationError as e: + raise APIValidationError( + error_code=ErrorCode.CORRUPTED_DATA, + detail=f"Agent '{parsed.namespace}' has invalid data", + resource="Agent", + errors=[ + ValidationErrorItem( + resource="Agent", + field=format_field_path(err.get("loc", ())), + code=err.get("type", "validation_error"), + message=err.get("msg", "Validation failed"), + ) + for err in e.errors() + ], + ) from e + agent_data_by_name[parsed.namespace] = agent_data + + evaluator = next( + (e for e in (agent_data.evaluators or []) if e.name == parsed.local_name), + None, ) + if evaluator is None: + available = [e.name for e in (agent_data.evaluators or [])] + raise APIValidationError( + error_code=ErrorCode.EVALUATOR_NOT_FOUND, + detail=( + f"Evaluator '{parsed.local_name}' is not registered " + f"with agent '{parsed.namespace}'" + ), + resource="Evaluator", + hint=( + f"Register it via initAgent first. " + f"Available evaluators: {available or 'none'}." + ), + errors=[ + ValidationErrorItem( + resource="Control", + field=f"{field_prefix}.evaluator.name", + code="evaluator_not_found", + message=( + f"Evaluator '{parsed.local_name}' not found " + f"on agent '{parsed.namespace}'" + ), + value=evaluator_ref, + ) + ], + ) + + if evaluator.config_schema: + try: + validate_config_against_schema( + evaluator_spec.config, + evaluator.config_schema, + ) + except JSONSchemaValidationError: + raise APIValidationError( + error_code=ErrorCode.INVALID_CONFIG, + detail=f"Config validation failed for evaluator '{evaluator_ref}'", + resource="Control", + hint=( + "Check the evaluator's config schema for required fields and types." + ), + errors=[ + ValidationErrorItem( + resource="Control", + field=f"{field_prefix}.evaluator.config", + code="schema_validation_error", + message=_SCHEMA_VALIDATION_FAILED_MESSAGE, + ) + ], + ) + continue + + evaluator_cls = available_evaluators.get(parsed.name) + if evaluator_cls is None: + continue try: - agent_data = AgentData.model_validate(agent.data) + evaluator_cls.config_model(**evaluator_spec.config) except ValidationError as e: raise APIValidationError( - error_code=ErrorCode.CORRUPTED_DATA, - detail=f"Agent '{parsed.namespace}' has invalid data", - resource="Agent", + error_code=ErrorCode.INVALID_CONFIG, + detail=f"Config validation failed for evaluator '{parsed.name}'", + resource="Control", + hint="Check the evaluator's config schema for required fields and types.", errors=[ ValidationErrorItem( - resource="Agent", - field=".".join(str(loc) for loc in err.get("loc", [])), + resource="Control", + field=( + f"{field_prefix}.evaluator.config." + f"{format_field_path(err.get('loc', ())) or ''}" + ).rstrip("."), code=err.get("type", "validation_error"), message=err.get("msg", "Validation failed"), ) for err in e.errors() ], ) - - evaluator = next( - (e for e in (agent_data.evaluators or []) if e.name == parsed.local_name), - None, - ) - if evaluator is None: - available = [e.name for e in (agent_data.evaluators or [])] + except TypeError: + _logger.warning( + "Config validation raised TypeError for evaluator '%s'", + parsed.name, + exc_info=True, + ) raise APIValidationError( - error_code=ErrorCode.EVALUATOR_NOT_FOUND, - detail=( - f"Evaluator '{parsed.local_name}' is not registered " - f"with agent '{parsed.namespace}'" - ), - resource="Evaluator", - hint=( - f"Register it via initAgent first. " - f"Available evaluators: {available or 'none'}." - ), + error_code=ErrorCode.INVALID_CONFIG, + detail=f"Invalid config parameters for evaluator '{parsed.name}'", + resource="Control", + hint="Check the evaluator's config schema for valid parameter names.", errors=[ ValidationErrorItem( resource="Control", - field="data.evaluator.name", - code="evaluator_not_found", - message=( - f"Evaluator '{parsed.local_name}' not found " - f"on agent '{parsed.namespace}'" - ), - value=evaluator_ref, + field=f"{field_prefix}.evaluator.config", + code="invalid_parameters", + message=_INVALID_PARAMETERS_MESSAGE, ) ], ) - # Validate config against evaluator's schema - if evaluator.config_schema: - try: - validate_config_against_schema( - control_def.evaluator.config, evaluator.config_schema - ) - except JSONSchemaValidationError: - raise APIValidationError( - error_code=ErrorCode.INVALID_CONFIG, - detail=f"Config validation failed for evaluator '{evaluator_ref}'", - resource="Control", - hint="Check the evaluator's config schema for required fields and types.", - errors=[ - ValidationErrorItem( - resource="Control", - field="data.evaluator.config", - code="schema_validation_error", - message=_SCHEMA_VALIDATION_FAILED_MESSAGE, - ) - ], - ) - else: - # Built-in or external evaluator: validate if registered - evaluator_cls = list_evaluators().get(parsed.name) - if evaluator_cls is not None: - try: - evaluator_cls.config_model(**control_def.evaluator.config) - except ValidationError as e: - raise APIValidationError( - error_code=ErrorCode.INVALID_CONFIG, - detail=f"Config validation failed for evaluator '{parsed.name}'", - resource="Control", - hint="Check the evaluator's config schema for required fields and types.", - errors=[ - ValidationErrorItem( - resource="Control", - field=( - "data.evaluator.config." - f"{'.'.join(str(loc) for loc in err.get('loc', []))}" - ), - code=err.get("type", "validation_error"), - message=err.get("msg", "Validation failed"), - ) - for err in e.errors() - ], - ) - except TypeError: - _logger.warning( - "Config validation raised TypeError for evaluator '%s'", - parsed.name, - exc_info=True, - ) - raise APIValidationError( - error_code=ErrorCode.INVALID_CONFIG, - detail=f"Invalid config parameters for evaluator '{parsed.name}'", - resource="Control", - hint="Check the evaluator's config schema for valid parameter names.", - errors=[ - ValidationErrorItem( - resource="Control", - field="data.evaluator.config", - code="invalid_parameters", - message=_INVALID_PARAMETERS_MESSAGE, - ) - ], - ) - # If evaluator not found, allow it - might be a server-side registered evaluator - @router.put( "", @@ -353,24 +394,12 @@ async def get_control_data( resource_id=str(control_id), hint="Verify the control ID is correct and the control has been created.", ) - try: - control_def = ControlDefinition.model_validate(control.data) - except ValidationError as e: - raise APIValidationError( - error_code=ErrorCode.CORRUPTED_DATA, - detail=f"Control '{control.name}' has invalid data", - resource="Control", - hint="Update the control data using PUT /{control_id}/data.", - errors=[ - ValidationErrorItem( - resource="Control", - field=".".join(str(loc) for loc in err.get("loc", [])), - code=err.get("type", "validation_error"), - message=err.get("msg", "Validation failed"), - ) - for err in e.errors() - ], - ) + control_def = parse_control_definition_or_api_error( + control.data, + detail=f"Control '{control.name}' has invalid data", + hint="Update the control data using PUT /{control_id}/data.", + field_prefix=None, + ) return GetControlDataResponse(data=control_def) @@ -418,17 +447,12 @@ async def set_control_data( # Validate evaluator config using shared logic await _validate_control_definition(request.data, db) - data_json = request.data.model_dump(mode="json", exclude_none=True, exclude_unset=True) - # Pydantic's exclude_none doesn't propagate into nested model dicts after - # serialization, so we re-dump the selector separately to strip null keys. - try: - selector_json = request.data.selector.model_dump(exclude_none=True, exclude_unset=True) # type: ignore[attr-defined] - selector_json = {k: v for k, v in selector_json.items() if v is not None} - if selector_json: - data_json["selector"] = selector_json - except AttributeError: - # Selector doesn't support model_dump, use original serialization - pass + data_json = request.data.model_dump( + mode="json", + by_alias=True, + exclude_none=True, + exclude_unset=True, + ) # Ensure scope does not store null/None for step_names or other optional fields, # so round-trip (save then load) preserves step selection in the UI. if "scope" in data_json and isinstance(data_json["scope"], dict): diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index 34c28c5d..99e96f02 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -58,6 +58,30 @@ def _sanitize_evaluator_error(error_message: str) -> str: return SAFE_EVALUATOR_ERROR +def _sanitize_condition_trace(trace: object) -> object: + """Recursively redact internal evaluator errors from condition traces.""" + if isinstance(trace, list): + return [_sanitize_condition_trace(item) for item in trace] + + if not isinstance(trace, dict): + return trace + + sanitized = { + key: _sanitize_condition_trace(value) + for key, value in trace.items() + } + + raw_error = sanitized.get("error") + if isinstance(raw_error, str) and raw_error: + safe_error = _sanitize_evaluator_error(raw_error) + sanitized["error"] = safe_error + raw_message = sanitized.get("message") + if raw_message is None or isinstance(raw_message, str): + sanitized["message"] = safe_error + + return sanitized + + def _sanitize_control_match(match: ControlMatch) -> ControlMatch: """Redact internal evaluator error strings from a control match.""" if match.result.error is None: @@ -65,10 +89,15 @@ def _sanitize_control_match(match: ControlMatch) -> ControlMatch: safe_error = _sanitize_evaluator_error(match.result.error) safe_message = safe_error + metadata = dict(match.result.metadata or {}) + condition_trace = metadata.get("condition_trace") + if condition_trace is not None: + metadata["condition_trace"] = _sanitize_condition_trace(condition_trace) sanitized_result = match.result.model_copy( update={ "error": safe_error, "message": safe_message, + "metadata": metadata or None, } ) return match.model_copy(update={"result": sanitized_result}) @@ -239,6 +268,8 @@ async def _emit_observability_events( if response.matches: for match in response.matches: ctrl = control_lookup.get(match.control_id) + primary_leaf = ctrl.control.primary_leaf() if ctrl else None + primary_parts = primary_leaf.leaf_parts() if primary_leaf else None events.append( ControlExecutionEvent( control_execution_id=match.control_execution_id, @@ -253,8 +284,8 @@ async def _emit_observability_events( matched=True, confidence=match.result.confidence, timestamp=now, - evaluator_name=ctrl.control.evaluator.name if ctrl else None, - selector_path=ctrl.control.selector.path if ctrl else None, + evaluator_name=primary_parts[1].name if primary_parts else None, + selector_path=primary_parts[0].path if primary_parts else None, error_message=match.result.error, metadata=match.result.metadata or {}, ) @@ -264,6 +295,8 @@ async def _emit_observability_events( if response.errors: for error in response.errors: ctrl = control_lookup.get(error.control_id) + primary_leaf = ctrl.control.primary_leaf() if ctrl else None + primary_parts = primary_leaf.leaf_parts() if primary_leaf else None events.append( ControlExecutionEvent( control_execution_id=error.control_execution_id, @@ -278,8 +311,8 @@ async def _emit_observability_events( matched=False, confidence=error.result.confidence, timestamp=now, - evaluator_name=ctrl.control.evaluator.name if ctrl else None, - selector_path=ctrl.control.selector.path if ctrl else None, + evaluator_name=primary_parts[1].name if primary_parts else None, + selector_path=primary_parts[0].path if primary_parts else None, error_message=error.result.error, metadata=error.result.metadata or {}, ) @@ -289,6 +322,8 @@ async def _emit_observability_events( if response.non_matches: for non_match in response.non_matches: ctrl = control_lookup.get(non_match.control_id) + primary_leaf = ctrl.control.primary_leaf() if ctrl else None + primary_parts = primary_leaf.leaf_parts() if primary_leaf else None events.append( ControlExecutionEvent( control_execution_id=non_match.control_execution_id, @@ -303,8 +338,8 @@ async def _emit_observability_events( matched=False, confidence=non_match.result.confidence, timestamp=now, - evaluator_name=ctrl.control.evaluator.name if ctrl else None, - selector_path=ctrl.control.selector.path if ctrl else None, + evaluator_name=primary_parts[1].name if primary_parts else None, + selector_path=primary_parts[0].path if primary_parts else None, error_message=None, metadata=non_match.result.metadata or {}, ) diff --git a/server/src/agent_control_server/errors.py b/server/src/agent_control_server/errors.py index 94af5890..1066a7cb 100644 --- a/server/src/agent_control_server/errors.py +++ b/server/src/agent_control_server/errors.py @@ -47,6 +47,8 @@ from fastapi import HTTPException, Request from fastapi.responses import JSONResponse +from .services.validation_paths import format_field_path + _logger = logging.getLogger(__name__) _MAX_PUBLIC_TEXT_LENGTH = 500 @@ -575,8 +577,8 @@ async def validation_exception_handler( # Build field path from location loc = error.get("loc", ()) # Skip 'body' prefix in location - field_parts = [str(p) for p in loc if p != "body"] - field = ".".join(field_parts) if field_parts else None + field_parts = [p for p in loc if p != "body"] + field = format_field_path(field_parts) # Determine resource from first path component resource = "Request" @@ -589,7 +591,7 @@ async def validation_exception_handler( "data": "Control", "policy": "Policy", } - first_part = field_parts[0].lower() + first_part = str(field_parts[0]).lower() resource = prefix_map.get(first_part, resource) errors.append( diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 27509762..8f00adc9 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -14,6 +14,7 @@ from fastapi.openapi.utils import get_openapi from starlette_exporter import PrometheusMiddleware, handle_metrics +from . import __version__ as server_version from .auth import require_api_key from .config import observability_settings, settings from .db import AsyncSessionLocal @@ -141,7 +142,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 4. Assign the policy to your agent 5. Query agent's active controls with `/api/v1/agents/{agent_name}/controls` """, - version="0.1.0", + version=server_version, lifespan=lifespan, ) @@ -161,6 +162,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: configure_logging(level=log_level) +@app.middleware("http") +async def attach_version_header(request, call_next): # type: ignore[no-untyped-def] + """Attach server version metadata to every response.""" + response = await call_next(request) + response.headers["X-Agent-Control-Server-Version"] = server_version + return response + + # ============================================================================= # Exception Handlers (RFC 7807 / Kubernetes / GitHub-style) # ============================================================================= @@ -264,7 +273,7 @@ async def health_check() -> HealthResponse: Returns: HealthResponse with status and version """ - return HealthResponse(status="healthy", version="0.1.0") + return HealthResponse(status="healthy", version=server_version) configure_ui_routes(app) diff --git a/server/src/agent_control_server/scripts/__init__.py b/server/src/agent_control_server/scripts/__init__.py new file mode 100644 index 00000000..e8e00625 --- /dev/null +++ b/server/src/agent_control_server/scripts/__init__.py @@ -0,0 +1 @@ +"""Operational scripts for the Agent Control server package.""" diff --git a/server/src/agent_control_server/scripts/migrate_control_conditions.py b/server/src/agent_control_server/scripts/migrate_control_conditions.py new file mode 100644 index 00000000..a25fbae4 --- /dev/null +++ b/server/src/agent_control_server/scripts/migrate_control_conditions.py @@ -0,0 +1,114 @@ +"""Rewrite stored control payloads into canonical condition-tree form.""" + +from __future__ import annotations + +import argparse + +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session + +from agent_control_server.config import db_config +from agent_control_server.models import Control +from agent_control_server.services.control_migration import ( + ControlMigrationResult, + migrate_control_payload, +) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Migrate stored controls from legacy selector/evaluator fields " + "to canonical condition trees." + ), + ) + mode_group = parser.add_mutually_exclusive_group() + mode_group.add_argument( + "--dry-run", + action="store_true", + help="Analyze stored controls without writing changes (default).", + ) + mode_group.add_argument( + "--apply", + action="store_true", + help="Apply the migration after a clean analysis run.", + ) + return parser.parse_args() + + +def _print_summary( + *, + total: int, + unchanged: int, + migrated: list[tuple[Control, ControlMigrationResult]], + invalid: list[tuple[Control, ControlMigrationResult]], + apply: bool, +) -> None: + mode = "apply" if apply else "dry-run" + print(f"Control condition migration summary ({mode})") + print(f"Total controls: {total}") + print(f"Already canonical: {unchanged}") + print(f"Ready to migrate: {len(migrated)}") + print(f"Invalid/corrupted: {len(invalid)}") + + if invalid: + print("") + print("Invalid controls:") + for control, result in invalid: + reason = result.reason or "Unknown validation error." + print(f"- id={control.id} name={control.name}: {reason}") + + +def main() -> int: + args = _parse_args() + apply = bool(args.apply) + + engine = create_engine(db_config.get_url(), future=True) + + try: + with Session(engine) as session: + controls = list(session.execute(select(Control).order_by(Control.id)).scalars().all()) + migrated: list[tuple[Control, ControlMigrationResult]] = [] + invalid: list[tuple[Control, ControlMigrationResult]] = [] + unchanged = 0 + + for control in controls: + result = migrate_control_payload(control.data) + if result.status == "unchanged": + unchanged += 1 + elif result.status == "migrated": + migrated.append((control, result)) + else: + invalid.append((control, result)) + + _print_summary( + total=len(controls), + unchanged=unchanged, + migrated=migrated, + invalid=invalid, + apply=apply, + ) + + if invalid: + if apply: + print("") + print("Aborting apply because invalid controls must be fixed first.") + return 1 + + if not apply: + return 0 + + for control, result in migrated: + assert result.payload is not None + control.data = result.payload + + session.commit() + print("") + print(f"Applied migration to {len(migrated)} controls.") + return 0 + finally: + engine.dispose() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/server/src/agent_control_server/services/control_definitions.py b/server/src/agent_control_server/services/control_definitions.py new file mode 100644 index 00000000..c73c8287 --- /dev/null +++ b/server/src/agent_control_server/services/control_definitions.py @@ -0,0 +1,64 @@ +"""Helpers for parsing stored control definitions consistently.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any, cast + +from agent_control_models import ControlDefinition +from agent_control_models.errors import ErrorCode, ValidationErrorItem +from pydantic import ValidationError + +from ..errors import APIValidationError +from .validation_paths import format_field_path + + +def build_control_validation_errors( + validation_error: ValidationError, + *, + field_prefix: str | None = "data", +) -> list[ValidationErrorItem]: + """Convert ControlDefinition validation errors into API error items.""" + items: list[ValidationErrorItem] = [] + for err in validation_error.errors(): + loc = cast(Sequence[str | int], err.get("loc", ())) + field_suffix = format_field_path(loc) + if field_prefix is None: + field = field_suffix + elif field_suffix is None: + field = field_prefix + else: + field = f"{field_prefix}.{field_suffix}" + + items.append( + ValidationErrorItem( + resource="Control", + field=field, + code=err.get("type", "validation_error"), + message=err.get("msg", "Validation failed"), + ) + ) + return items + + +def parse_control_definition_or_api_error( + data: Any, + *, + detail: str, + hint: str, + resource_id: str | None = None, + context: Mapping[str, Any] | None = None, + field_prefix: str | None = "data", +) -> ControlDefinition: + """Parse stored control data or raise a structured CORRUPTED_DATA error.""" + try: + return ControlDefinition.model_validate(data, context=dict(context) if context else None) + except ValidationError as exc: + raise APIValidationError( + error_code=ErrorCode.CORRUPTED_DATA, + detail=detail, + resource="Control", + resource_id=resource_id, + hint=hint, + errors=build_control_validation_errors(exc, field_prefix=field_prefix), + ) from exc diff --git a/server/src/agent_control_server/services/control_migration.py b/server/src/agent_control_server/services/control_migration.py new file mode 100644 index 00000000..f6b26da3 --- /dev/null +++ b/server/src/agent_control_server/services/control_migration.py @@ -0,0 +1,89 @@ +"""Helpers for migrating stored controls to condition trees.""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Literal + +from agent_control_models import ControlDefinition +from pydantic import ValidationError + +type MigrationStatus = Literal["unchanged", "migrated", "invalid"] + + +@dataclass(frozen=True) +class ControlMigrationResult: + """Outcome of migrating a single stored control payload.""" + + status: MigrationStatus + payload: dict[str, Any] | None = None + reason: str | None = None + + +def _validation_message(error: ValidationError) -> str: + first_error = error.errors()[0] + location = ".".join(str(part) for part in first_error.get("loc", ())) + message = first_error.get("msg", "Validation failed.") + if location: + return f"{location}: {message}" + return message + + +def migrate_control_payload(data: object) -> ControlMigrationResult: + """Migrate a stored control payload to canonical condition-tree shape.""" + if not isinstance(data, dict): + return ControlMigrationResult( + status="invalid", + reason="Stored control data must be a JSON object.", + ) + + has_condition = "condition" in data + has_selector = "selector" in data + has_evaluator = "evaluator" in data + + if has_condition and (has_selector or has_evaluator): + return ControlMigrationResult( + status="invalid", + reason=( + "Stored control data mixes canonical condition fields " + "with legacy selector/evaluator fields." + ), + ) + + candidate = deepcopy(data) + status: MigrationStatus = "unchanged" + + if not has_condition: + if not has_selector: + return ControlMigrationResult( + status="invalid", + reason="Stored control data is missing the condition definition.", + ) + status = "migrated" + + try: + candidate = ControlDefinition.canonicalize_payload(candidate) + except ValueError as error: + return ControlMigrationResult( + status="invalid", + reason=str(error), + ) + + try: + validated = ControlDefinition.model_validate(candidate) + except ValidationError as error: + return ControlMigrationResult( + status="invalid", + reason=_validation_message(error), + ) + + return ControlMigrationResult( + status=status, + payload=validated.model_dump( + mode="json", + by_alias=True, + exclude_none=True, + exclude_unset=True, + ), + ) diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index ea334006..5ab97237 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -1,19 +1,13 @@ from __future__ import annotations -import logging from collections.abc import Sequence -from agent_control_models import ControlDefinition -from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.policy import Control as APIControl -from pydantic import ValidationError from sqlalchemy import select, union from sqlalchemy.ext.asyncio import AsyncSession -from ..errors import APIValidationError from ..models import Control, agent_controls, agent_policies, policy_controls - -_logger = logging.getLogger(__name__) +from .control_definitions import parse_control_definition_or_api_error async def list_controls_for_policy(policy_id: int, db: AsyncSession) -> list[Control]: @@ -67,37 +61,18 @@ async def list_controls_for_agent( # Map DB Control to API Control, raising on invalid definitions api_controls: list[APIControl] = [] for c in db_controls: - try: - context = ( - {"allow_invalid_step_name_regex": True} - if allow_invalid_step_name_regex - else None - ) - control_def = ControlDefinition.model_validate(c.data, context=context) - api_controls.append(APIControl(id=c.id, name=c.name, control=control_def)) - except ValidationError as e: - error_items = [] - for err in e.errors(): - loc: Sequence[str | int] = err.get("loc", []) - field_suffix = ".".join(str(part) for part in loc) if loc else "" - error_items.append( - ValidationErrorItem( - resource="Control", - field=f"data.{field_suffix}" if field_suffix else "data", - code=err.get("type", "validation_error"), - message=err.get("msg", "Validation failed"), - ) - ) - - raise APIValidationError( - error_code=ErrorCode.CORRUPTED_DATA, - detail=f"Control '{c.name}' has corrupted data", - resource="Control", - resource_id=str(c.id), - hint=( - "Update the control data using " - f"PUT /api/v1/controls/{c.id}/data." - ), - errors=error_items, - ) from e + context = ( + {"allow_invalid_step_name_regex": True} + if allow_invalid_step_name_regex + else None + ) + control_def = parse_control_definition_or_api_error( + c.data, + detail=f"Control '{c.name}' has corrupted data", + resource_id=str(c.id), + hint=f"Update the control data using PUT /api/v1/controls/{c.id}/data.", + context=context, + field_prefix="data", + ) + api_controls.append(APIControl(id=c.id, name=c.name, control=control_def)) return api_controls diff --git a/server/src/agent_control_server/services/validation_paths.py b/server/src/agent_control_server/services/validation_paths.py new file mode 100644 index 00000000..9df1052e --- /dev/null +++ b/server/src/agent_control_server/services/validation_paths.py @@ -0,0 +1,18 @@ +"""Helpers for formatting nested validation field paths.""" + +from collections.abc import Sequence + + +def format_field_path(parts: Sequence[str | int]) -> str | None: + """Format nested field parts using dot/bracket notation.""" + field = "" + for part in parts: + if isinstance(part, int): + field += f"[{part}]" + continue + + if field: + field += "." + field += part + + return field or None diff --git a/server/tests/test_agents_additional.py b/server/tests/test_agents_additional.py index 8c8085f8..8c69f79f 100644 --- a/server/tests/test_agents_additional.py +++ b/server/tests/test_agents_additional.py @@ -7,7 +7,7 @@ from fastapi.testclient import TestClient from sqlalchemy import text -from .utils import VALID_CONTROL_PAYLOAD +from .utils import VALID_CONTROL_PAYLOAD, canonicalize_control_payload from .conftest import engine @@ -39,7 +39,10 @@ def _create_control_with_data(client: TestClient, data: dict) -> int: resp = client.put("/api/v1/controls", json={"name": f"control-{uuid.uuid4()}"}) assert resp.status_code == 200 control_id = resp.json()["control_id"] - set_resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": data}) + set_resp = client.put( + f"/api/v1/controls/{control_id}/data", + json={"data": canonicalize_control_payload(data)}, + ) assert set_resp.status_code == 200, set_resp.text return control_id @@ -220,7 +223,7 @@ def test_patch_agent_remove_evaluator_in_use_conflict(client: TestClient) -> Non agent_name, agent_name = _init_agent(client, evaluators=evaluators) control_payload = deepcopy(VALID_CONTROL_PAYLOAD) - control_payload["evaluator"] = { + control_payload["condition"]["evaluator"] = { "name": f"{agent_name}:custom", "config": {"pattern": "x"}, } @@ -255,7 +258,7 @@ def test_set_agent_policy_incompatible_controls(client: TestClient) -> None: agent_a_id, agent_a_name = _init_agent(client, evaluators=evaluators) control_payload = deepcopy(VALID_CONTROL_PAYLOAD) - control_payload["evaluator"] = { + control_payload["condition"]["evaluator"] = { "name": f"{agent_a_name}:custom", "config": {}, } @@ -347,7 +350,7 @@ def test_list_agent_controls_corrupted_control_data_returns_422( # Given: an agent with a policy that includes a control agent_name, _ = _init_agent(client) control_payload = deepcopy(VALID_CONTROL_PAYLOAD) - control_payload["evaluator"] = {"name": "regex", "config": {"pattern": "x"}} + control_payload["condition"]["evaluator"] = {"name": "regex", "config": {"pattern": "x"}} control_id = _create_control_with_data(client, control_payload) policy_id = _create_policy(client) assoc = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") @@ -442,12 +445,15 @@ def test_set_agent_policy_rejects_missing_agent_evaluator(client: TestClient) -> assert assoc.status_code == 200 with engine.begin() as conn: + corrupted_payload = deepcopy(VALID_CONTROL_PAYLOAD) + corrupted_payload["condition"]["evaluator"] = { + "name": f"{agent_name}:missing", + "config": {}, + } conn.execute( text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), { - "data": json.dumps( - {"evaluator": {"name": f"{agent_name}:missing", "config": {}}} - ), + "data": json.dumps(corrupted_payload), "id": control_id, }, ) @@ -484,12 +490,15 @@ def test_set_agent_policy_rejects_invalid_agent_evaluator_config(client: TestCli assert assoc.status_code == 200 with engine.begin() as conn: + corrupted_payload = deepcopy(VALID_CONTROL_PAYLOAD) + corrupted_payload["condition"]["evaluator"] = { + "name": f"{agent_name}:custom", + "config": {}, + } conn.execute( text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), { - "data": json.dumps( - {"evaluator": {"name": f"{agent_name}:custom", "config": {}}} - ), + "data": json.dumps(corrupted_payload), "id": control_id, }, ) @@ -666,8 +675,8 @@ def test_set_agent_policy_skips_controls_without_data(client: TestClient) -> Non assert resp.json()["success"] is True -def test_set_agent_policy_skips_controls_without_evaluator_name(client: TestClient) -> None: - # Given: an agent and a policy with a control missing evaluator name +def test_set_agent_policy_rejects_controls_without_evaluator_name(client: TestClient) -> None: + # Given: an agent and a policy with a stored control whose leaf is missing evaluator name agent_name, _ = _init_agent(client) policy_id = _create_policy(client) control_id = _create_control_with_data(client, VALID_CONTROL_PAYLOAD) @@ -675,17 +684,21 @@ def test_set_agent_policy_skips_controls_without_evaluator_name(client: TestClie assert assoc.status_code == 200 with engine.begin() as conn: + corrupted_payload = deepcopy(VALID_CONTROL_PAYLOAD) + corrupted_payload["condition"]["evaluator"] = {"config": {}} conn.execute( text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), - {"data": json.dumps({"evaluator": {}}), "id": control_id}, + {"data": json.dumps(corrupted_payload), "id": control_id}, ) # When: assigning the policy to the agent resp = client.post(f"/api/v1/agents/{agent_name}/policy/{policy_id}") - # Then: assignment succeeds because evaluator name is missing - assert resp.status_code == 200 - assert resp.json()["success"] is True + # Then: assignment is rejected because the stored control data is corrupted + assert resp.status_code == 400 + body = resp.json() + assert body["error_code"] == "POLICY_CONTROL_INCOMPATIBLE" + assert any("corrupted data" in err.get("message", "").lower() for err in body.get("errors", [])) def test_list_agents_includes_active_controls_count(client: TestClient) -> None: diff --git a/server/tests/test_auth.py b/server/tests/test_auth.py index 57c2bba1..d31522a8 100644 --- a/server/tests/test_auth.py +++ b/server/tests/test_auth.py @@ -5,11 +5,9 @@ import pytest from fastapi.testclient import TestClient +from agent_control_server import __version__ as server_version from agent_control_server.config import auth_settings -from .conftest import TEST_ADMIN_API_KEY, TEST_API_KEY - - class TestHealthEndpoint: """Health endpoint should always be accessible without authentication.""" @@ -22,6 +20,7 @@ def test_health_without_auth(self, unauthenticated_client: TestClient) -> None: assert response.status_code == 200 data = response.json() assert data["status"] == "healthy" + assert response.headers["X-Agent-Control-Server-Version"] == server_version def test_health_with_auth(self, client: TestClient) -> None: """Given valid API key, when requesting health, then returns 200.""" @@ -45,6 +44,7 @@ def test_missing_api_key_returns_401(self, unauthenticated_client: TestClient) - # Then: assert response.status_code == 401 assert "Missing credentials" in response.json()["detail"] + assert response.headers["X-Agent-Control-Server-Version"] == server_version def test_invalid_api_key_returns_401(self, app: object) -> None: """Given invalid API key, when requesting protected endpoint, then returns 401.""" @@ -145,10 +145,12 @@ def test_evaluators_accessible_when_disabled( "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": {"pattern": "test", "flags": []}, + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": {"pattern": "test", "flags": []}, + }, }, "action": {"decision": "deny"}, "tags": ["test"], diff --git a/server/tests/test_control_compatibility.py b/server/tests/test_control_compatibility.py new file mode 100644 index 00000000..6485f1c3 --- /dev/null +++ b/server/tests/test_control_compatibility.py @@ -0,0 +1,177 @@ +"""Compatibility coverage for legacy flat control payloads.""" + +from __future__ import annotations + +import json +import uuid +from copy import deepcopy + +from fastapi.testclient import TestClient +from sqlalchemy import text + +from .conftest import engine +from .utils import VALID_CONTROL_PAYLOAD + + +def _init_agent(client: TestClient, *, agent_name: str | None = None) -> str: + name = (agent_name or f"agent-{uuid.uuid4().hex[:12]}").lower() + if len(name) < 10: + name = f"{name}-agent".replace("--", "-") + resp = client.post( + "/api/v1/agents/initAgent", + json={ + "agent": { + "agent_name": name, + "agent_description": "desc", + "agent_version": "1.0", + }, + "steps": [], + "evaluators": [], + }, + ) + assert resp.status_code == 200 + return name + + +def _create_policy(client: TestClient) -> int: + resp = client.put("/api/v1/policies", json={"name": f"policy-{uuid.uuid4()}"}) + assert resp.status_code == 200 + return resp.json()["policy_id"] + + +def _legacy_control_payload() -> dict[str, object]: + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["selector"] = payload["condition"]["selector"] + payload["evaluator"] = payload["condition"]["evaluator"] + payload.pop("condition") + return payload + + +def test_set_agent_policy_accepts_legacy_stored_control_payload(client: TestClient) -> None: + # Given: an assigned policy whose stored control row has been reverted to the legacy flat shape + agent_name = _init_agent(client) + policy_id = _create_policy(client) + + control_resp = client.put("/api/v1/controls", json={"name": f"control-{uuid.uuid4()}"}) + assert control_resp.status_code == 200 + control_id = control_resp.json()["control_id"] + + set_resp = client.put( + f"/api/v1/controls/{control_id}/data", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + assert set_resp.status_code == 200 + + assoc = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert assoc.status_code == 200 + + with engine.begin() as conn: + conn.execute( + text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), + {"data": json.dumps(_legacy_control_payload()), "id": control_id}, + ) + + # When: assigning the policy to the agent + resp = client.post(f"/api/v1/agents/{agent_name}/policy/{policy_id}") + + # Then: assignment succeeds because the legacy payload is canonicalized on read + assert resp.status_code == 200 + assert resp.json()["success"] is True + + +def test_get_control_data_returns_canonical_shape_for_legacy_stored_payload( + client: TestClient, +) -> None: + # Given: a control whose stored row has been reverted to the legacy flat shape + control_resp = client.put("/api/v1/controls", json={"name": f"control-{uuid.uuid4()}"}) + assert control_resp.status_code == 200 + control_id = control_resp.json()["control_id"] + + with engine.begin() as conn: + conn.execute( + text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), + {"data": json.dumps(_legacy_control_payload()), "id": control_id}, + ) + + # When: fetching control data through the typed API endpoint + resp = client.get(f"/api/v1/controls/{control_id}/data") + + # Then: the response is accepted and serialized back in canonical condition form + assert resp.status_code == 200 + data = resp.json()["data"] + assert "selector" not in data + assert "evaluator" not in data + assert data["condition"]["selector"]["path"] == "input" + assert data["condition"]["evaluator"]["name"] == "regex" + + +def test_list_agent_controls_returns_canonical_shape_for_legacy_stored_payload( + client: TestClient, +) -> None: + # Given: an agent assigned a policy whose control row is stored in legacy flat shape + agent_name = _init_agent(client) + policy_id = _create_policy(client) + + control_resp = client.put("/api/v1/controls", json={"name": f"control-{uuid.uuid4()}"}) + assert control_resp.status_code == 200 + control_id = control_resp.json()["control_id"] + + set_resp = client.put( + f"/api/v1/controls/{control_id}/data", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + assert set_resp.status_code == 200 + + assoc = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert assoc.status_code == 200 + assign = client.post(f"/api/v1/agents/{agent_name}/policy/{policy_id}") + assert assign.status_code == 200 + + with engine.begin() as conn: + conn.execute( + text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), + {"data": json.dumps(_legacy_control_payload()), "id": control_id}, + ) + + # When: listing active controls for the agent + resp = client.get(f"/api/v1/agents/{agent_name}/controls") + + # Then: the control is returned and serialized in canonical condition form + assert resp.status_code == 200 + controls = resp.json()["controls"] + assert len(controls) == 1 + control = controls[0]["control"] + assert "selector" not in control + assert "evaluator" not in control + assert control["condition"]["selector"]["path"] == "input" + assert control["condition"]["evaluator"]["name"] == "regex" + + +def test_get_control_data_rejects_partial_legacy_stored_payload( + client: TestClient, +) -> None: + # Given: a stored control row with only one half of the legacy flat shape + control_resp = client.put("/api/v1/controls", json={"name": f"control-{uuid.uuid4()}"}) + assert control_resp.status_code == 200 + control_id = control_resp.json()["control_id"] + + invalid_payload = _legacy_control_payload() + invalid_payload.pop("evaluator") + with engine.begin() as conn: + conn.execute( + text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), + {"data": json.dumps(invalid_payload), "id": control_id}, + ) + + # When: fetching control data through the typed API endpoint + resp = client.get(f"/api/v1/controls/{control_id}/data") + + # Then: the API reports structured corrupted-data validation instead of silently accepting it + assert resp.status_code == 422 + body = resp.json() + assert body["error_code"] == "CORRUPTED_DATA" + assert any( + "Legacy control definition must include both selector and evaluator." + in error.get("message", "") + for error in body.get("errors", []) + ) diff --git a/server/tests/test_control_migration.py b/server/tests/test_control_migration.py new file mode 100644 index 00000000..d7c97923 --- /dev/null +++ b/server/tests/test_control_migration.py @@ -0,0 +1,250 @@ +"""Tests for stored control condition migration.""" + +from __future__ import annotations + +from argparse import Namespace +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +from agent_control_server.scripts import migrate_control_conditions +from agent_control_server.services.control_migration import migrate_control_payload + +from .utils import VALID_CONTROL_PAYLOAD + + +def test_migrate_control_payload_rewrites_legacy_leaf() -> None: + # Given: a stored control payload in the legacy flat shape + legacy_payload = deepcopy(VALID_CONTROL_PAYLOAD) + legacy_payload["selector"] = legacy_payload["condition"]["selector"] + legacy_payload["evaluator"] = legacy_payload["condition"]["evaluator"] + legacy_payload.pop("condition") + + # When: migrating the stored payload + result = migrate_control_payload(legacy_payload) + + # Then: the payload is rewritten into canonical condition form + assert result.status == "migrated" + assert result.payload is not None + assert "selector" not in result.payload + assert "evaluator" not in result.payload + assert result.payload["condition"]["selector"]["path"] == "input" + + +def test_migrate_control_payload_leaves_canonical_rows_unchanged() -> None: + # Given: a stored payload that is already canonical + # When: migrating the stored payload + result = migrate_control_payload(deepcopy(VALID_CONTROL_PAYLOAD)) + + # Then: no rewrite is needed and the payload is preserved + assert result.status == "unchanged" + assert result.payload == VALID_CONTROL_PAYLOAD + + +def test_migrate_control_payload_rejects_mixed_rows() -> None: + # Given: a stored payload that mixes canonical and legacy fields + mixed_payload = deepcopy(VALID_CONTROL_PAYLOAD) + mixed_payload["selector"] = {"path": "input"} + + # When: migrating the stored payload + result = migrate_control_payload(mixed_payload) + + # Then: migration rejects the ambiguous row as invalid + assert result.status == "invalid" + assert result.reason is not None + assert "mixes canonical condition fields" in result.reason + + +def test_migrate_control_payload_rejects_partial_legacy_rows() -> None: + # Given: a legacy payload that is missing one of selector/evaluator + partial_payload = deepcopy(VALID_CONTROL_PAYLOAD) + partial_payload.pop("condition") + partial_payload["selector"] = {"path": "input"} + + # When: migrating the stored payload + result = migrate_control_payload(partial_payload) + + # Then: migration rejects the incomplete legacy row + assert result.status == "invalid" + assert result.reason == "Legacy control definition must include both selector and evaluator." + + +def test_migrate_control_payload_rejects_non_object_rows() -> None: + # Given: stored control data that is not a JSON object + # When: migrating the stored payload + result = migrate_control_payload(["not", "an", "object"]) + + # Then: migration reports the row as invalid + assert result.status == "invalid" + assert result.reason == "Stored control data must be a JSON object." + + +@dataclass +class _FakeControl: + id: int + name: str + data: dict[str, Any] + + +class _FakeResult: + def __init__(self, controls: list[_FakeControl]) -> None: + self._controls = controls + + def scalars(self) -> _FakeResult: + return self + + def all(self) -> list[_FakeControl]: + return self._controls + + +class _FakeSession: + def __init__(self, controls: list[_FakeControl]) -> None: + self._controls = controls + self.committed = False + + def __enter__(self) -> _FakeSession: + return self + + def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: + return False + + def execute(self, _statement: object) -> _FakeResult: + return _FakeResult(self._controls) + + def commit(self) -> None: + self.committed = True + + +class _FakeEngine: + def __init__(self) -> None: + self.disposed = False + + def dispose(self) -> None: + self.disposed = True + + +def _make_legacy_payload() -> dict[str, Any]: + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["selector"] = payload["condition"]["selector"] + payload["evaluator"] = payload["condition"]["evaluator"] + payload.pop("condition") + return payload + + +def _run_migration_script( + monkeypatch: Any, + *, + controls: list[_FakeControl], + apply: bool, +) -> tuple[int, _FakeSession, _FakeEngine]: + fake_session = _FakeSession(controls) + fake_engine = _FakeEngine() + + # Given: fake engine/session dependencies and parsed CLI args + monkeypatch.setattr( + migrate_control_conditions, + "_parse_args", + lambda: Namespace(apply=apply, dry_run=not apply), + ) + monkeypatch.setattr( + migrate_control_conditions, + "create_engine", + lambda *_args, **_kwargs: fake_engine, + ) + monkeypatch.setattr( + migrate_control_conditions, + "Session", + lambda _engine: fake_session, + ) + + # When: running the migration script entrypoint + exit_code = migrate_control_conditions.main() + + # Then: the caller can assert on exit code and fake side effects + return exit_code, fake_session, fake_engine + + +def test_migration_script_dry_run_reports_summary_without_writing( + monkeypatch: Any, + capsys: Any, +) -> None: + # Given: one canonical row and one legacy row ready to migrate + controls = [ + _FakeControl(id=1, name="canonical", data=deepcopy(VALID_CONTROL_PAYLOAD)), + _FakeControl(id=2, name="legacy", data=_make_legacy_payload()), + ] + + # When: running the script in dry-run mode + exit_code, fake_session, fake_engine = _run_migration_script( + monkeypatch, + controls=controls, + apply=False, + ) + output = capsys.readouterr().out + + # Then: the summary is correct and no commit occurs + assert exit_code == 0 + assert "Already canonical: 1" in output + assert "Ready to migrate: 1" in output + assert "Invalid/corrupted: 0" in output + assert fake_session.committed is False + assert fake_engine.disposed is True + assert "condition" not in controls[1].data + + +def test_migration_script_apply_rewrites_legacy_rows_and_commits( + monkeypatch: Any, + capsys: Any, +) -> None: + # Given: one canonical row and one legacy row ready to migrate + controls = [ + _FakeControl(id=1, name="canonical", data=deepcopy(VALID_CONTROL_PAYLOAD)), + _FakeControl(id=2, name="legacy", data=_make_legacy_payload()), + ] + + # When: running the script in apply mode + exit_code, fake_session, fake_engine = _run_migration_script( + monkeypatch, + controls=controls, + apply=True, + ) + output = capsys.readouterr().out + + # Then: only the legacy row is rewritten and the session commits + assert exit_code == 0 + assert "Applied migration to 1 controls." in output + assert fake_session.committed is True + assert fake_engine.disposed is True + assert controls[0].data == VALID_CONTROL_PAYLOAD + assert "condition" in controls[1].data + assert "selector" not in controls[1].data + assert "evaluator" not in controls[1].data + + +def test_migration_script_apply_aborts_when_invalid_rows_exist( + monkeypatch: Any, + capsys: Any, +) -> None: + # Given: a legacy row plus an invalid partial-legacy row + invalid_partial = _make_legacy_payload() + invalid_partial.pop("evaluator") + controls = [ + _FakeControl(id=1, name="legacy", data=_make_legacy_payload()), + _FakeControl(id=2, name="invalid", data=invalid_partial), + ] + + # When: running the script in apply mode + exit_code, fake_session, fake_engine = _run_migration_script( + monkeypatch, + controls=controls, + apply=True, + ) + output = capsys.readouterr().out + + # Then: apply aborts before commit and leaves rows untouched + assert exit_code == 1 + assert "Invalid/corrupted: 1" in output + assert "Aborting apply because invalid controls must be fixed first." in output + assert fake_session.committed is False + assert fake_engine.disposed is True + assert "condition" not in controls[0].data diff --git a/server/tests/test_controls.py b/server/tests/test_controls.py index c901b884..51520e7a 100644 --- a/server/tests/test_controls.py +++ b/server/tests/test_controls.py @@ -1,5 +1,6 @@ -from typing import Any import uuid +from copy import deepcopy +from typing import Any from fastapi.testclient import TestClient @@ -38,10 +39,12 @@ def test_get_control_data_initially_unconfigured(client: TestClient) -> None: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": { - "name": "regex", - "config": {"pattern": "test", "flags": []} + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": {"pattern": "test", "flags": []} + }, }, "action": {"decision": "deny"}, "tags": ["test"] @@ -67,9 +70,31 @@ def test_set_control_data_replaces_existing(client: TestClient) -> None: assert data["enabled"] == payload["enabled"] assert data["execution"] == payload["execution"] assert data["scope"] == payload["scope"] - assert data["evaluator"] == payload["evaluator"] + assert data["condition"]["evaluator"] == payload["condition"]["evaluator"] assert data["action"] == payload["action"] - assert data["selector"]["path"] == payload["selector"]["path"] + assert data["condition"]["selector"]["path"] == payload["condition"]["selector"]["path"] + + +def test_set_control_data_accepts_legacy_leaf_payload(client: TestClient) -> None: + # Given: a legacy flat selector/evaluator payload + control_id = create_control(client) + payload = deepcopy(VALID_CONTROL_DATA) + payload["selector"] = payload["condition"]["selector"] + payload["evaluator"] = payload["condition"]["evaluator"] + payload.pop("condition") + + # When: saving and reading back the control data + resp_put = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) + + # Then: the stored response is canonicalized into condition form + assert resp_put.status_code == 200, resp_put.text + resp_get = client.get(f"/api/v1/controls/{control_id}/data") + assert resp_get.status_code == 200 + data = resp_get.json()["data"] + assert "selector" not in data + assert "evaluator" not in data + assert data["condition"]["selector"]["path"] == "input" + assert data["condition"]["evaluator"]["name"] == "regex" def test_set_control_data_with_empty_dict_fails(client: TestClient) -> None: @@ -84,11 +109,11 @@ def test_set_control_data_with_empty_dict_fails(client: TestClient) -> None: def test_set_control_data_validates_nested_schema(client: TestClient) -> None: # Given: a control control_id = create_control(client) - + # When: setting invalid data (missing required fields) - invalid_data = {"conditions": "test"} + invalid_data = {"conditions": "test"} r = client.put(f"/api/v1/controls/{control_id}/data", json={"data": invalid_data}) - + # Then: 422 Validation Error assert r.status_code == 422 diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index f68c7b72..a5be9777 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -14,6 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session +from agent_control_models import ConditionNode from agent_control_server.db import get_async_db from agent_control_server.models import Control @@ -657,7 +658,7 @@ def test_set_control_data_agent_scoped_agent_not_found(client: TestClient) -> No # When: setting data with a missing agent in evaluator ref payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": "missing-agent:custom", "config": {"pattern": "x"}} + payload["condition"]["evaluator"] = {"name": "missing-agent:custom", "config": {"pattern": "x"}} resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) # Then: not found @@ -681,7 +682,7 @@ def test_set_control_data_agent_scoped_evaluator_missing(client: TestClient) -> control_id, _ = _create_control(client) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": f"{agent_name}:missing", "config": {"pattern": "x"}} + payload["condition"]["evaluator"] = {"name": f"{agent_name}:missing", "config": {"pattern": "x"}} # When: setting data with evaluator not registered on agent resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -690,7 +691,7 @@ def test_set_control_data_agent_scoped_evaluator_missing(client: TestClient) -> assert resp.status_code == 422 body = resp.json() assert body["error_code"] == "EVALUATOR_NOT_FOUND" - assert any(err.get("field") == "data.evaluator.name" for err in body.get("errors", [])) + assert any(err.get("field") == "data.condition.evaluator.name" for err in body.get("errors", [])) def test_set_control_data_agent_scoped_invalid_schema(client: TestClient) -> None: @@ -719,7 +720,7 @@ def test_set_control_data_agent_scoped_invalid_schema(client: TestClient) -> Non control_id, _ = _create_control(client) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": f"{agent_name}:custom", "config": {}} + payload["condition"]["evaluator"] = {"name": f"{agent_name}:custom", "config": {}} # When: setting data with config missing required fields resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -728,7 +729,7 @@ def test_set_control_data_agent_scoped_invalid_schema(client: TestClient) -> Non assert resp.status_code == 422 body = resp.json() assert body["error_code"] == "INVALID_CONFIG" - assert any(err.get("field") == "data.evaluator.config" for err in body.get("errors", [])) + assert any(err.get("field") == "data.condition.evaluator.config" for err in body.get("errors", [])) def test_patch_control_updates_name_and_enabled(client: TestClient) -> None: @@ -805,7 +806,7 @@ def test_set_control_data_agent_scoped_corrupted_agent_data_returns_422( control_id, _ = _create_control(client) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": f"{agent_name}:custom", "config": {}} + payload["condition"]["evaluator"] = {"name": f"{agent_name}:custom", "config": {}} # When: setting control data referencing the corrupted agent's evaluator resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -819,7 +820,7 @@ def test_set_control_data_unknown_evaluator_allowed(client: TestClient) -> None: # Given: a control with a non-registered evaluator name control_id, _ = _create_control(client) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": "unknown-eval", "config": {}} + payload["condition"]["evaluator"] = {"name": "unknown-eval", "config": {}} # When: setting the control data resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -845,7 +846,7 @@ class DummyEvaluator: ) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": "dummy", "config": {}} + payload["condition"]["evaluator"] = {"name": "dummy", "config": {}} # When: setting control data with invalid config resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -854,7 +855,10 @@ class DummyEvaluator: assert resp.status_code == 422 body = resp.json() assert body["error_code"] == "INVALID_CONFIG" - assert any("data.evaluator.config" in err.get("field", "") for err in body.get("errors", [])) + assert any( + "data.condition.evaluator.config" in err.get("field", "") + for err in body.get("errors", []) + ) def test_set_control_data_builtin_evaluator_invalid_parameters( @@ -875,7 +879,7 @@ def config_model(**_kwargs): # type: ignore[no-untyped-def] ) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["evaluator"] = {"name": "dummy", "config": {"unexpected": "value"}} + payload["condition"]["evaluator"] = {"name": "dummy", "config": {"unexpected": "value"}} # When: setting control data with invalid parameters resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -906,11 +910,7 @@ async def test_set_control_data_selector_without_model_dump_uses_original_serial class DummyData: def __init__(self, data: dict[str, object]) -> None: self._data = data - self.selector = data["selector"] - self.evaluator = SimpleNamespace( - name=data["evaluator"]["name"], - config=data["evaluator"]["config"], - ) + self.condition = ConditionNode.model_validate(data["condition"]) def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: return self._data @@ -923,7 +923,7 @@ def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: # Then: the update succeeds and uses the original selector serialization assert response.success is True await async_db.refresh(control) - assert control.data["selector"] == payload["selector"] + assert control.data["condition"] == payload["condition"] def test_patch_control_rename_preserves_enabled(client: TestClient) -> None: diff --git a/server/tests/test_controls_validation.py b/server/tests/test_controls_validation.py index cd5084ae..6cb1848f 100644 --- a/server/tests/test_controls_validation.py +++ b/server/tests/test_controls_validation.py @@ -1,8 +1,13 @@ """Tests for control validation and schema enforcement.""" + import uuid +from copy import deepcopy + from fastapi.testclient import TestClient + from .utils import VALID_CONTROL_PAYLOAD + def create_control(client: TestClient) -> int: name = f"control-{uuid.uuid4()}" resp = client.put("/api/v1/controls", json={"name": name}) @@ -13,8 +18,8 @@ def test_validation_invalid_logic_enum(client: TestClient): """Test that invalid enum values in config are rejected.""" # Given: a control and a payload with invalid 'logic' value control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() - payload["evaluator"] = { + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"]["evaluator"] = { "name": "list", "config": { "values": ["a", "b"], @@ -22,13 +27,13 @@ def test_validation_invalid_logic_enum(client: TestClient): "match_on": "match" } } - + # When: setting control data resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) # Then: 422 Unprocessable Entity assert resp.status_code == 422 - + # Then: error message mentions the field (RFC 7807 format) response_data = resp.json() errors = response_data.get("errors", []) @@ -40,8 +45,8 @@ def test_validation_discriminator_mismatch(client: TestClient): """Test that config must match the evaluator type.""" # Given: a control and type='list' but config has 'pattern' (RegexEvaluatorConfig) control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() - payload["evaluator"] = { + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"]["evaluator"] = { "name": "list", "config": { "pattern": "some_regex", # Invalid for ListEvaluatorConfig @@ -67,15 +72,15 @@ def test_validation_regex_flags_list(client: TestClient): """Test validation of regex flags list.""" # Given: a control and regex config with invalid flags type (string instead of list) control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() - payload["evaluator"] = { + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"]["evaluator"] = { "name": "regex", "config": { "pattern": "abc", "flags": "IGNORECASE" # Should be ["IGNORECASE"] } } - + # When: setting control data resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -90,21 +95,21 @@ def test_validation_invalid_regex_pattern(client: TestClient): """Test validation of regex pattern syntax.""" # Given: a control and regex config with invalid pattern (unclosed bracket) control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() - payload["evaluator"] = { + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"]["evaluator"] = { "name": "regex", "config": { "pattern": "[", # Invalid regex "flags": [] } } - + # When: setting control data resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) # Then: 422 Unprocessable Entity (RFC 7807 format) assert resp.status_code == 422 - + response_data = resp.json() errors = response_data.get("errors", []) # Then: error message mentions regex compilation failure @@ -116,8 +121,8 @@ def test_validation_empty_string_path_rejected(client: TestClient): """Test that empty string path is rejected.""" # Given: a control and payload with empty string path control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() - payload["selector"] = {"path": ""} + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"]["selector"] = {"path": ""} # When: setting control data resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -136,8 +141,8 @@ def test_validation_none_path_defaults_to_star(client: TestClient): """Test that None/missing path defaults to '*'.""" # Given: a control and payload without path in selector (None) control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() - payload["selector"] = {} # No path specified + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"]["selector"] = {} # No path specified # When: setting control data resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -151,7 +156,7 @@ def test_validation_none_path_defaults_to_star(client: TestClient): # Then: path should default to '*' data = get_resp.json()["data"] - assert data["selector"]["path"] == "*" + assert data["condition"]["selector"]["path"] == "*" def test_get_control_data_returns_typed_response(client: TestClient): @@ -171,9 +176,8 @@ def test_get_control_data_returns_typed_response(client: TestClient): data = resp_get.json()["data"] # Should have required ControlDefinition fields - assert "evaluator" in data + assert "condition" in data assert "action" in data - assert "selector" in data assert "execution" in data assert "scope" in data @@ -182,7 +186,7 @@ def test_validation_empty_step_names_rejected(client: TestClient): """Test that empty step_names list is rejected.""" # Given: a control and payload with empty step_names list control_id = create_control(client) - payload = VALID_CONTROL_PAYLOAD.copy() + payload = deepcopy(VALID_CONTROL_PAYLOAD) payload["scope"] = {"step_names": []} # When: setting control data @@ -196,3 +200,89 @@ def test_validation_empty_step_names_rejected(client: TestClient): errors = response_data.get("errors", []) assert any("step_names" in str(e.get("field", "")) for e in errors) assert any("empty list" in e.get("message", "") for e in errors) + + +def test_validation_nested_condition_error_uses_bracketed_field_path( + client: TestClient, +): + """Nested condition leaf errors should report full dot/bracket paths.""" + # Given: a nested condition whose first leaf has invalid evaluator config + control_id = create_control(client) + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"] = { + "and": [ + { + "selector": {"path": "input"}, + "evaluator": { + "name": "list", + "config": { + "values": ["a", "b"], + "logic": "invalid_logic", + "match_on": "match", + }, + }, + }, + { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": "ok"}, + }, + }, + ] + } + + # When: validating the nested control definition through the API + resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) + + # Then: the error points at the exact nested leaf path + assert resp.status_code == 422 + errors = resp.json().get("errors", []) + assert any( + err.get("field") == "data.condition.and[0].evaluator.logic" + for err in errors + ) + + +def test_validation_nested_agent_scoped_evaluator_error_uses_bracketed_field_path( + client: TestClient, +): + """Nested agent-scoped evaluator failures should identify the exact leaf path.""" + # Given: an agent and a nested condition that references a missing agent evaluator + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + init_resp = client.post( + "/api/v1/agents/initAgent", + json={ + "agent": {"agent_name": agent_name}, + "steps": [], + "evaluators": [], + }, + ) + assert init_resp.status_code == 200 + + control_id = create_control(client) + payload = deepcopy(VALID_CONTROL_PAYLOAD) + payload["condition"] = { + "or": [ + { + "selector": {"path": "input"}, + "evaluator": { + "name": f"{agent_name}:missing-evaluator", + "config": {}, + }, + } + ] + } + + # When: validating the nested control definition through the API + resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) + + # Then: the error points at the exact nested evaluator name field + assert resp.status_code == 422 + body = resp.json() + assert body["error_code"] == "EVALUATOR_NOT_FOUND" + assert any( + err.get("field") == "data.condition.or[0].evaluator.name" + and err.get("code") == "evaluator_not_found" + for err in body.get("errors", []) + ) diff --git a/server/tests/test_error_handling.py b/server/tests/test_error_handling.py index 5f139431..f9689851 100644 --- a/server/tests/test_error_handling.py +++ b/server/tests/test_error_handling.py @@ -474,8 +474,10 @@ async def mock_db_returns_control() -> AsyncGenerator[AsyncSession, None]: "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": {"name": "regex", "config": {"pattern": "x"}}, + "condition": { + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "x"}}, + }, "action": {"decision": "deny"} } resp = client.put( diff --git a/server/tests/test_evaluation_e2e.py b/server/tests/test_evaluation_e2e.py index 8d3f6ad1..7ebb03b9 100644 --- a/server/tests/test_evaluation_e2e.py +++ b/server/tests/test_evaluation_e2e.py @@ -1,8 +1,10 @@ """End-to-end tests for evaluation flow.""" import uuid + from fastapi.testclient import TestClient from agent_control_models import EvaluationRequest, Step -from .utils import create_and_assign_policy + +from .utils import canonicalize_control_payload, create_and_assign_policy def test_evaluation_flow_deny(client: TestClient): @@ -236,7 +238,10 @@ def test_evaluation_deny_precedence(client: TestClient): } resp = client.put("/api/v1/controls", json={"name": f"deny-control-{uuid.uuid4()}"}) deny_control_id = resp.json()["control_id"] - client.put(f"/api/v1/controls/{deny_control_id}/data", json={"data": control_deny}) + client.put( + f"/api/v1/controls/{deny_control_id}/data", + json={"data": canonicalize_control_payload(control_deny)}, + ) # Add Control to Agent's Policy client.post(f"/api/v1/policies/{policy_id}/controls/{deny_control_id}") diff --git a/server/tests/test_evaluation_error_handling.py b/server/tests/test_evaluation_error_handling.py index c984dfdc..d48d8b59 100644 --- a/server/tests/test_evaluation_error_handling.py +++ b/server/tests/test_evaluation_error_handling.py @@ -1,11 +1,18 @@ """End-to-end tests for evaluator error handling.""" import logging import uuid +from unittest.mock import AsyncMock, MagicMock +from agent_control_models import ControlMatch, EvaluationRequest, EvaluatorResult, Step from fastapi.testclient import TestClient -from agent_control_models import EvaluationRequest, Step +from agent_control_server.endpoints.evaluation import ( + SAFE_EVALUATOR_ERROR, + SAFE_EVALUATOR_TIMEOUT_ERROR, + _sanitize_control_match, +) from agent_control_server.observability.ingest.base import IngestResult + from .utils import create_and_assign_policy @@ -95,8 +102,6 @@ def test_evaluation_errors_field_populated_on_evaluator_failure( When: Evaluation is requested Then: Response has errors field populated and is_safe=False (for deny) """ - from unittest.mock import MagicMock, AsyncMock - # Given: an agent with a working control control_data = { "description": "Test control", @@ -154,11 +159,93 @@ def mock_get_evaluator_instance(config): ) assert "RuntimeError" not in data["errors"][0]["result"]["error"] assert "Simulated evaluator crash" not in data["errors"][0]["result"]["error"] + condition_trace = data["errors"][0]["result"]["metadata"]["condition_trace"] + assert condition_trace["error"] == SAFE_EVALUATOR_ERROR + assert condition_trace["message"] == SAFE_EVALUATOR_ERROR + assert "RuntimeError" not in condition_trace["error"] + assert "Simulated evaluator crash" not in condition_trace["message"] # And: no matches are returned because evaluation failed assert data["matches"] is None or len(data["matches"]) == 0 +def test_sanitize_control_match_redacts_nested_condition_trace_errors() -> None: + # Given: a control match whose nested condition trace contains raw evaluator errors + match = ControlMatch( + control_id=1, + control_name="nested-trace", + action="deny", + result=EvaluatorResult( + matched=False, + confidence=0.0, + error="RuntimeError: nested boom", + message="Condition evaluation failed: RuntimeError: nested boom", + metadata={ + "condition_trace": { + "type": "and", + "children": [ + { + "type": "leaf", + "error": "RuntimeError: nested boom", + "message": "Evaluation failed: RuntimeError: nested boom", + } + ], + } + }, + ), + ) + + # When: sanitizing the control match for API output + sanitized = _sanitize_control_match(match) + child_trace = sanitized.result.metadata["condition_trace"]["children"][0] + + # Then: both the top-level result and nested trace are redacted + assert sanitized.result.error == SAFE_EVALUATOR_ERROR + assert sanitized.result.message == SAFE_EVALUATOR_ERROR + assert child_trace["error"] == SAFE_EVALUATOR_ERROR + assert child_trace["message"] == SAFE_EVALUATOR_ERROR + + +def test_sanitize_control_match_redacts_nested_condition_trace_timeouts() -> None: + # Given: a control match whose nested condition trace contains timeout errors + match = ControlMatch( + control_id=1, + control_name="nested-timeout", + action="deny", + result=EvaluatorResult( + matched=False, + confidence=0.0, + error="TimeoutError: Evaluator exceeded 30s timeout", + message="Condition evaluation failed: TimeoutError: Evaluator exceeded 30s timeout", + metadata={ + "condition_trace": { + "type": "or", + "children": [ + { + "type": "leaf", + "error": "TimeoutError: Evaluator exceeded 30s timeout", + "message": ( + "Evaluation failed: TimeoutError: " + "Evaluator exceeded 30s timeout" + ), + } + ], + } + }, + ), + ) + + # When: sanitizing the control match for API output + sanitized = _sanitize_control_match(match) + child_trace = sanitized.result.metadata["condition_trace"]["children"][0] + + # Then: both the top-level result and nested trace use the safe timeout text + assert sanitized.result.error == SAFE_EVALUATOR_TIMEOUT_ERROR + assert sanitized.result.message == SAFE_EVALUATOR_TIMEOUT_ERROR + assert child_trace["error"] == SAFE_EVALUATOR_TIMEOUT_ERROR + assert child_trace["message"] == SAFE_EVALUATOR_TIMEOUT_ERROR + + def test_evaluation_engine_value_error_returns_422(client: TestClient, monkeypatch) -> None: """Test that evaluation returns 422 when the engine raises a ValueError.""" # Given: a valid agent with a control assigned diff --git a/server/tests/test_init_agent_conflict_mode.py b/server/tests/test_init_agent_conflict_mode.py index e39f1c02..98af20e0 100644 --- a/server/tests/test_init_agent_conflict_mode.py +++ b/server/tests/test_init_agent_conflict_mode.py @@ -47,7 +47,7 @@ def _create_policy_with_agent_evaluator_control( control_id = create_control_resp.json()["control_id"] control_data = deepcopy(VALID_CONTROL_PAYLOAD) - control_data["evaluator"] = { + control_data["condition"]["evaluator"] = { "name": f"{agent_name}:{evaluator_name}", "config": {}, } diff --git a/server/tests/test_new_features.py b/server/tests/test_new_features.py index ad5a91c8..f79682a2 100644 --- a/server/tests/test_new_features.py +++ b/server/tests/test_new_features.py @@ -4,6 +4,8 @@ from fastapi.testclient import TestClient +from .utils import canonicalize_control_payload + def make_agent_payload( agent_name: str | None = None, @@ -269,7 +271,7 @@ def _create_policy_with_control( # Set control data data_resp = client.put( f"/api/v1/controls/{control_id}/data", - json={"data": control_data}, + json={"data": canonicalize_control_payload(control_data)}, ) assert data_resp.status_code == 200 @@ -357,8 +359,10 @@ def test_control_creation_with_unregistered_evaluator_fails(client: TestClient) "data": { "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": {"name": f"{agent_name}:nonexistent-eval", "config": {}}, + "condition": { + "selector": {"path": "input"}, + "evaluator": {"name": f"{agent_name}:nonexistent-eval", "config": {}}, + }, "action": {"decision": "deny"}, } }, diff --git a/server/tests/test_validation_paths.py b/server/tests/test_validation_paths.py new file mode 100644 index 00000000..a1f5c736 --- /dev/null +++ b/server/tests/test_validation_paths.py @@ -0,0 +1,22 @@ +"""Tests for nested validation field path formatting.""" + +from agent_control_server.services.validation_paths import format_field_path + + +def test_format_field_path_renders_dot_and_bracket_notation() -> None: + # Given: nested string and integer path parts + # When: formatting the field path + assert ( + format_field_path( + ("data", "condition", "and", 0, "evaluator", "config", "logic") + ) + == "data.condition.and[0].evaluator.config.logic" + ) + # Then: indices use brackets and object keys use dots + + +def test_format_field_path_empty_sequence_returns_none() -> None: + # Given: an empty sequence of field parts + # When: formatting the field path + assert format_field_path(()) is None + # Then: no field path is returned diff --git a/server/tests/utils.py b/server/tests/utils.py index a2a32098..64c455c5 100644 --- a/server/tests/utils.py +++ b/server/tests/utils.py @@ -1,20 +1,32 @@ """Test utilities for server tests.""" import uuid +from copy import deepcopy from typing import Any -from fastapi.testclient import TestClient +from agent_control_models import ControlDefinition +from fastapi.testclient import TestClient VALID_CONTROL_PAYLOAD = { "description": "Valid Control", "enabled": True, "execution": "server", "scope": {"step_types": ["llm"], "stages": ["pre"]}, - "selector": {"path": "input"}, - "evaluator": {"name": "regex", "config": {"pattern": "x"}}, + "condition": { + "selector": {"path": "input"}, + "evaluator": {"name": "regex", "config": {"pattern": "x"}}, + }, "action": {"decision": "deny"} } +def canonicalize_control_payload(payload: dict[str, Any]) -> dict[str, Any]: + """Convert legacy flat test payloads into canonical condition trees.""" + canonical = ControlDefinition.canonicalize_payload(deepcopy(payload)) + if not isinstance(canonical, dict): + raise TypeError("Control payload canonicalization must return a dict.") + return canonical + + def create_and_assign_policy( client: TestClient, control_config: dict[str, Any] | None = None, @@ -31,7 +43,9 @@ def create_and_assign_policy( tuple: (agent_name, control_name) """ if control_config is None: - control_config = VALID_CONTROL_PAYLOAD.copy() + control_config = deepcopy(VALID_CONTROL_PAYLOAD) + else: + control_config = canonicalize_control_payload(control_config) # 1. Create Control control_name = f"control-{uuid.uuid4()}" diff --git a/ui/src/core/api/generated/api-types.ts b/ui/src/core/api/generated/api-types.ts index d54bbae1..55b9538c 100644 --- a/ui/src/core/api/generated/api-types.ts +++ b/ui/src/core/api/generated/api-types.ts @@ -4,6 +4,73 @@ */ export interface paths { + '/api/config': { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * UI configuration + * @description Return configuration flags that drive UI behavior. + * + * If authentication is enabled, this also reports whether the current + * request has an active session (via header or cookie), allowing the UI + * to skip the login prompt on refresh when a valid cookie is present. + */ + get: operations['get_config_api_config_get']; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + '/api/login': { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Login with API key + * @description Validate an API key and issue a signed JWT session cookie. + * + * The raw API key is transmitted only in this single request and is never + * stored in the cookie. Subsequent requests authenticate via the JWT. + */ + post: operations['login_api_login_post']; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + '/api/logout': { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Logout (clear session cookie) + * @description Clear the session cookie. + */ + post: operations['logout_api_logout_post']; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; '/api/v1/agents': { parameters: { query?: never; @@ -62,7 +129,7 @@ export interface paths { * db: Database session (injected) * * Returns: - * InitAgentResponse with created flag and active controls (if policy assigned) + * InitAgentResponse with created flag and active controls (policy-derived + direct) */ post: operations['init_agent_api_v1_agents_initAgent_post']; delete?: never; @@ -1066,6 +1133,12 @@ export interface components { */ success: boolean; }; + /** + * AuthMode + * @description Authentication mode advertised to the UI. + * @enum {string} + */ + AuthMode: 'none' | 'api-key'; /** * BatchEventsRequest * @description Request model for batch event ingestion. @@ -1132,6 +1205,154 @@ export interface components { */ status: 'queued' | 'partial' | 'failed'; }; + /** + * ConditionNode + * @description Recursive boolean condition tree for control evaluation. + * @example { + * "evaluator": { + * "config": { + * "pattern": "\\d{3}-\\d{2}-\\d{4}" + * }, + * "name": "regex" + * }, + * "selector": { + * "path": "output" + * } + * } + * @example { + * "and": [ + * { + * "evaluator": { + * "config": { + * "values": [ + * "high", + * "critical" + * ] + * }, + * "name": "list" + * }, + * "selector": { + * "path": "context.risk_level" + * } + * }, + * { + * "not": { + * "evaluator": { + * "config": { + * "values": [ + * "admin", + * "security" + * ] + * }, + * "name": "list" + * }, + * "selector": { + * "path": "context.user_role" + * } + * } + * } + * ] + * } + */ + 'ConditionNode-Input': { + /** + * And + * @description Logical AND over child conditions. + */ + and?: components['schemas']['ConditionNode-Input'][] | null; + /** @description Leaf evaluator. Must be provided together with selector. */ + evaluator?: components['schemas']['EvaluatorSpec'] | null; + /** @description Logical NOT over a single child condition. */ + not?: components['schemas']['ConditionNode-Input'] | null; + /** + * Or + * @description Logical OR over child conditions. + */ + or?: components['schemas']['ConditionNode-Input'][] | null; + /** @description Leaf selector. Must be provided together with evaluator. */ + selector?: components['schemas']['ControlSelector'] | null; + }; + /** + * ConditionNode + * @description Recursive boolean condition tree for control evaluation. + * @example { + * "evaluator": { + * "config": { + * "pattern": "\\d{3}-\\d{2}-\\d{4}" + * }, + * "name": "regex" + * }, + * "selector": { + * "path": "output" + * } + * } + * @example { + * "and": [ + * { + * "evaluator": { + * "config": { + * "values": [ + * "high", + * "critical" + * ] + * }, + * "name": "list" + * }, + * "selector": { + * "path": "context.risk_level" + * } + * }, + * { + * "not": { + * "evaluator": { + * "config": { + * "values": [ + * "admin", + * "security" + * ] + * }, + * "name": "list" + * }, + * "selector": { + * "path": "context.user_role" + * } + * } + * } + * ] + * } + */ + 'ConditionNode-Output': { + /** + * And + * @description Logical AND over child conditions. + */ + and?: components['schemas']['ConditionNode-Output'][] | null; + /** @description Leaf evaluator. Must be provided together with selector. */ + evaluator?: components['schemas']['EvaluatorSpec'] | null; + /** @description Logical NOT over a single child condition. */ + not?: components['schemas']['ConditionNode-Output'] | null; + /** + * Or + * @description Logical OR over child conditions. + */ + or?: components['schemas']['ConditionNode-Output'][] | null; + /** @description Leaf selector. Must be provided together with evaluator. */ + selector?: components['schemas']['ControlSelector'] | null; + }; + /** + * ConfigResponse + * @description Configuration surface exposed to the UI. + */ + ConfigResponse: { + auth_mode: components['schemas']['AuthMode']; + /** + * Has Active Session + * @default false + */ + has_active_session: boolean; + /** Requires Api Key */ + requires_api_key: boolean; + }; /** * ConflictMode * @description Conflict handling mode for initAgent registration updates. @@ -1179,14 +1400,19 @@ export interface components { * "action": { * "decision": "deny" * }, - * "description": "Block outputs containing US Social Security Numbers", - * "enabled": true, - * "evaluator": { - * "config": { - * "pattern": "\\b\\d{3}-\\d{2}-\\d{4}\\b" + * "condition": { + * "evaluator": { + * "config": { + * "pattern": "\\b\\d{3}-\\d{2}-\\d{4}\\b" + * }, + * "name": "regex" * }, - * "name": "regex" + * "selector": { + * "path": "output" + * } * }, + * "description": "Block outputs containing US Social Security Numbers", + * "enabled": true, * "execution": "server", * "scope": { * "stages": [ @@ -1196,9 +1422,6 @@ export interface components { * "llm" * ] * }, - * "selector": { - * "path": "output" - * }, * "tags": [ * "pii", * "compliance" @@ -1208,6 +1431,8 @@ export interface components { 'ControlDefinition-Input': { /** @description What action to take when control matches */ action: components['schemas']['ControlAction']; + /** @description Recursive boolean condition tree. Leaf nodes contain selector + evaluator; composite nodes contain and/or/not. */ + condition: components['schemas']['ConditionNode-Input']; /** * Description * @description Detailed description of the control @@ -1219,8 +1444,6 @@ export interface components { * @default true */ enabled: boolean; - /** @description How to evaluate the selected data */ - evaluator: components['schemas']['EvaluatorSpec']; /** * Execution * @description Where this control executes @@ -1229,8 +1452,6 @@ export interface components { execution: 'server' | 'sdk'; /** @description Which steps and stages this control applies to */ scope?: components['schemas']['ControlScope']; - /** @description What data to select from the payload */ - selector: components['schemas']['ControlSelector']; /** * Tags * @description Tags for categorization @@ -1247,14 +1468,19 @@ export interface components { * "action": { * "decision": "deny" * }, - * "description": "Block outputs containing US Social Security Numbers", - * "enabled": true, - * "evaluator": { - * "config": { - * "pattern": "\\b\\d{3}-\\d{2}-\\d{4}\\b" + * "condition": { + * "evaluator": { + * "config": { + * "pattern": "\\b\\d{3}-\\d{2}-\\d{4}\\b" + * }, + * "name": "regex" * }, - * "name": "regex" + * "selector": { + * "path": "output" + * } * }, + * "description": "Block outputs containing US Social Security Numbers", + * "enabled": true, * "execution": "server", * "scope": { * "stages": [ @@ -1264,9 +1490,6 @@ export interface components { * "llm" * ] * }, - * "selector": { - * "path": "output" - * }, * "tags": [ * "pii", * "compliance" @@ -1276,6 +1499,8 @@ export interface components { 'ControlDefinition-Output': { /** @description What action to take when control matches */ action: components['schemas']['ControlAction']; + /** @description Recursive boolean condition tree. Leaf nodes contain selector + evaluator; composite nodes contain and/or/not. */ + condition: components['schemas']['ConditionNode-Output']; /** * Description * @description Detailed description of the control @@ -1287,8 +1512,6 @@ export interface components { * @default true */ enabled: boolean; - /** @description How to evaluate the selected data */ - evaluator: components['schemas']['EvaluatorSpec']; /** * Execution * @description Where this control executes @@ -1297,8 +1520,6 @@ export interface components { execution: 'server' | 'sdk'; /** @description Which steps and stages this control applies to */ scope?: components['schemas']['ControlScope']; - /** @description What data to select from the payload */ - selector: components['schemas']['ControlSelector']; /** * Tags * @description Tags for categorization @@ -2619,6 +2840,24 @@ export interface components { evaluators: components['schemas']['EvaluatorSchemaItem'][]; pagination: components['schemas']['PaginationInfo']; }; + /** + * LoginRequest + * @description Request body for the /login endpoint. + */ + LoginRequest: { + /** Api Key */ + api_key: string; + }; + /** + * LoginResponse + * @description Response body for the /login endpoint. + */ + LoginResponse: { + /** Authenticated */ + authenticated: boolean; + /** Is Admin */ + is_admin: boolean; + }; /** * PaginationInfo * @description Pagination metadata for cursor-based pagination. @@ -3116,6 +3355,77 @@ export interface components { } export type $defs = Record; export interface operations { + get_config_api_config_get: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Configuration flags for UI behavior */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + 'application/json': components['schemas']['ConfigResponse']; + }; + }; + }; + }; + login_api_login_post: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + 'application/json': components['schemas']['LoginRequest']; + }; + }; + responses: { + /** @description Authentication result; sets HttpOnly session cookie on success */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + 'application/json': components['schemas']['LoginResponse']; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + 'application/json': components['schemas']['HTTPValidationError']; + }; + }; + }; + }; + logout_api_logout_post: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; list_agents_api_v1_agents_get: { parameters: { query?: { diff --git a/ui/src/core/api/types.ts b/ui/src/core/api/types.ts index abb335a5..a4720f5f 100644 --- a/ui/src/core/api/types.ts +++ b/ui/src/core/api/types.ts @@ -76,6 +76,9 @@ export type ControlStage = NonNullable< export type ControlScope = components['schemas']['ControlScope']; export type ControlSelector = components['schemas']['ControlSelector']; export type ControlAction = components['schemas']['ControlAction']; +export type ConditionNodeInput = components['schemas']['ConditionNode-Input']; +export type ConditionNodeOutput = components['schemas']['ConditionNode-Output']; +export type ConditionNode = ConditionNodeInput | ConditionNodeOutput; export type ControlDefinitionInput = components['schemas']['ControlDefinition-Input']; export type ControlDefinitionOutput = diff --git a/ui/src/core/evaluators/list/index.ts b/ui/src/core/evaluators/list/index.ts index f9deeac5..197c221f 100644 --- a/ui/src/core/evaluators/list/index.ts +++ b/ui/src/core/evaluators/list/index.ts @@ -33,7 +33,8 @@ export const listEvaluator: EvaluatorDefinition = { toConfig: (values) => { // Convert newline-separated string to array, pass rest through directly - const valuesList = values.values.split('\n').filter((v) => v.trim() !== ''); + const rawValues = typeof values.values === 'string' ? values.values : ''; + const valuesList = rawValues.split('\n').filter((v) => v.trim() !== ''); return { ...values, values: valuesList, diff --git a/ui/src/core/page-components/agent-detail/modals/add-new-control/index.tsx b/ui/src/core/page-components/agent-detail/modals/add-new-control/index.tsx index 9fbf0949..5fba8fd6 100644 --- a/ui/src/core/page-components/agent-detail/modals/add-new-control/index.tsx +++ b/ui/src/core/page-components/agent-detail/modals/add-new-control/index.tsx @@ -123,12 +123,14 @@ export function AddNewControlModal({ step_types: ['llm'], stages: ['post'] as ('post' | 'pre')[], }, - selector: { - path: '*', - }, - evaluator: { - name: selectedEvaluator.id, - config: getDefaultConfigForEvaluator(selectedEvaluator.id), + condition: { + selector: { + path: '*', + }, + evaluator: { + name: selectedEvaluator.id, + config: getDefaultConfigForEvaluator(selectedEvaluator.id), + }, }, action: { decision: 'deny' as const }, }, diff --git a/ui/src/core/page-components/agent-detail/modals/edit-control/control-condition.ts b/ui/src/core/page-components/agent-detail/modals/edit-control/control-condition.ts new file mode 100644 index 00000000..3218f2ee --- /dev/null +++ b/ui/src/core/page-components/agent-detail/modals/edit-control/control-condition.ts @@ -0,0 +1,76 @@ +import type { ControlDefinition } from '@/core/api/types'; +import type { AnyEvaluatorDefinition } from '@/core/evaluators'; +import { getEvaluator } from '@/core/evaluators'; + +const COMPOSITE_CONDITION_EDITING_MESSAGE = + 'This control uses a composite condition tree. This PR keeps the old single-condition UI, so saving will preserve the existing tree without editing it.'; + +export type LeafConditionDetails = { + selectorPath: string; + evaluatorName: string; + evaluatorConfig: Record; +}; + +export type ControlConditionState = { + leafCondition: LeafConditionDetails | null; + evaluatorId: string; + evaluator: AnyEvaluatorDefinition | undefined; + canEditLeafCondition: boolean; + conditionEditingMessage: string | null; +}; + +function getLeafConditionDetails( + definition: ControlDefinition +): LeafConditionDetails | null { + const condition = definition.condition; + if (!condition.selector || !condition.evaluator) { + return null; + } + + return { + selectorPath: condition.selector.path ?? '*', + evaluatorName: condition.evaluator.name, + evaluatorConfig: condition.evaluator.config, + }; +} + +export function getControlConditionState( + definition: ControlDefinition +): ControlConditionState { + const leafCondition = getLeafConditionDetails(definition); + const evaluatorId = leafCondition?.evaluatorName ?? ''; + const evaluator = getEvaluator(evaluatorId); + + return { + leafCondition, + evaluatorId, + evaluator, + canEditLeafCondition: Boolean(leafCondition && evaluator), + conditionEditingMessage: leafCondition + ? evaluator + ? null + : `This control uses the "${leafCondition.evaluatorName}" evaluator, which does not have a UI editor here yet. Saving will preserve its current condition.` + : COMPOSITE_CONDITION_EDITING_MESSAGE, + }; +} + +export function buildEditableCondition( + definition: ControlDefinition, + leafCondition: LeafConditionDetails | null, + selectorPath: string, + finalConfig: Record +): ControlDefinition['condition'] { + if (!leafCondition) { + return definition.condition; + } + + return { + selector: { + path: selectorPath, + }, + evaluator: { + name: leafCondition.evaluatorName, + config: finalConfig, + }, + }; +} diff --git a/ui/src/core/page-components/agent-detail/modals/edit-control/control-definition-form.tsx b/ui/src/core/page-components/agent-detail/modals/edit-control/control-definition-form.tsx index afa18567..61e266d0 100644 --- a/ui/src/core/page-components/agent-detail/modals/edit-control/control-definition-form.tsx +++ b/ui/src/core/page-components/agent-detail/modals/edit-control/control-definition-form.tsx @@ -30,6 +30,7 @@ export type ControlDefinitionFormWithStepsProps = ControlDefinitionFormProps & { export const ControlDefinitionForm = ({ form, steps, + disableSelectorPath = false, }: ControlDefinitionFormWithStepsProps) => { return ( @@ -86,6 +87,7 @@ export const ControlDefinitionForm = ({ )} size="sm" placeholder="e.g., input or input.args.command" + disabled={disableSelectorPath} {...form.getInputProps('selector_path')} /> diff --git a/ui/src/core/page-components/agent-detail/modals/edit-control/edit-control-content.tsx b/ui/src/core/page-components/agent-detail/modals/edit-control/edit-control-content.tsx index da1e226e..9d3e0674 100644 --- a/ui/src/core/page-components/agent-detail/modals/edit-control/edit-control-content.tsx +++ b/ui/src/core/page-components/agent-detail/modals/edit-control/edit-control-content.tsx @@ -1,4 +1,12 @@ -import { Box, Divider, Grid, Group, Text, TextInput } from '@mantine/core'; +import { + Alert, + Box, + Divider, + Grid, + Group, + Text, + TextInput, +} from '@mantine/core'; import { useForm } from '@mantine/form'; import { modals } from '@mantine/modals'; import { notifications } from '@mantine/notifications'; @@ -6,8 +14,11 @@ import { Button } from '@rungalileo/jupiter-ds'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { isApiError } from '@/core/api/errors'; -import type { Control, ProblemDetail } from '@/core/api/types'; -import { getEvaluator } from '@/core/evaluators'; +import type { + Control, + ControlDefinition, + ProblemDetail, +} from '@/core/api/types'; import { useAddControlToAgent } from '@/core/hooks/query-hooks/use-add-control-to-agent'; import { useAgent } from '@/core/hooks/query-hooks/use-agent'; import { useUpdateControl } from '@/core/hooks/query-hooks/use-update-control'; @@ -15,6 +26,10 @@ import { useUpdateControlMetadata } from '@/core/hooks/query-hooks/use-update-co import { useValidateControlData } from '@/core/hooks/query-hooks/use-validate-control-data'; import { ApiErrorAlert } from './api-error-alert'; +import { + buildEditableCondition, + getControlConditionState, +} from './control-condition'; import { ControlDefinitionForm } from './control-definition-form'; import { EvaluatorConfigSection } from './evaluator-config-section'; import type { ControlDefinitionFormValues, EditControlMode } from './types'; @@ -43,17 +58,14 @@ export const EditControlContent = ({ onClose, onSuccess, }: EditControlContentProps) => { - // Fetch agent data to get steps - React Query will dedupe requests const { data: agentResponse } = useAgent(agentId); const steps = agentResponse?.steps ?? []; - // API error state + const [apiError, setApiError] = useState(null); - // Errors that couldn't be mapped to form fields (shown in Alert) const [unmappedErrors, setUnmappedErrors] = useState< Array<{ field: string | null; message: string }> >([]); - // Mutation hooks const updateControl = useUpdateControl(); const updateControlMetadata = useUpdateControlMetadata(); const addControlToAgent = useAddControlToAgent(); @@ -63,14 +75,18 @@ export const EditControlContent = ({ ? addControlToAgent.isPending : updateControl.isPending || updateControlMetadata.isPending; - // Track which evaluator the evaluator form has been initialized for const formInitializedForEvaluator = useRef(''); + const { + leafCondition, + evaluatorId, + evaluator, + canEditLeafCondition, + conditionEditingMessage, + } = useMemo( + () => getControlConditionState(control.control), + [control.control] + ); - // Get evaluator for this control - const evaluatorId = control.control.evaluator.name || ''; - const evaluator = useMemo(() => getEvaluator(evaluatorId), [evaluatorId]); - - // Control definition form const definitionForm = useForm({ initialValues: { name: '', @@ -89,10 +105,12 @@ export const EditControlContent = ({ validate: { name: (value) => (!value?.trim() ? 'Control name is required' : null), selector_path: (value) => { + if (!canEditLeafCondition) { + return null; + } if (!value?.trim()) { return 'Selector path is required'; } - // Validate root field matches backend validation const validRoots = ['input', 'output', 'name', 'type', 'context', '*']; const root = value.split('.')[0]; if (!validRoots.includes(root)) { @@ -103,34 +121,53 @@ export const EditControlContent = ({ }, }); - // Evaluator config form - dynamically configured based on evaluator const evaluatorForm = useForm({ initialValues: evaluator?.initialValues ?? {}, validate: evaluator?.validate, }); - // Get config from evaluator form - // If form hasn't been initialized for current evaluator yet, use initial values to avoid crashes - const getEvaluatorConfig = () => { - if (!evaluator) return {}; + const getEvaluatorConfig = useCallback(() => { + if (!leafCondition) { + return {}; + } + if (!evaluator) { + return leafCondition.evaluatorConfig; + } if (formInitializedForEvaluator.current !== evaluatorId) { return evaluator.toConfig(evaluator.initialValues); } return evaluator.toConfig(evaluatorForm.values); - }; + }, [evaluator, evaluatorForm.values, evaluatorId, leafCondition]); - // Sync JSON to form - const syncJsonToForm = (config: Record) => { - if (evaluator) { - evaluatorForm.setValues(evaluator.fromConfig(config)); - } - }; + const syncJsonToForm = useCallback( + (config: Record) => { + if (evaluator) { + evaluatorForm.setValues(evaluator.fromConfig(config)); + } + }, + [evaluator, evaluatorForm] + ); + + const buildCondition = useCallback( + ( + values: ControlDefinitionFormValues, + finalConfig: Record + ): ControlDefinition['condition'] => { + return buildEditableCondition( + control.control, + leafCondition, + values.selector_path.trim(), + finalConfig + ); + }, + [control.control, leafCondition] + ); const buildControlDefinition = useCallback( ( values: ControlDefinitionFormValues, finalConfig: Record - ) => { + ): ControlDefinition => { const stepTypes = values.step_types .map((value) => value.trim()) .filter(Boolean); @@ -148,12 +185,11 @@ export const EditControlContent = ({ if (values.stages.length > 0) scope.stages = values.stages; return { - ...control.control, description: values.description?.trim() || undefined, enabled: values.enabled, execution: values.execution, scope: Object.keys(scope).length > 0 ? scope : undefined, - selector: { ...control.control.selector, path: values.selector_path }, + condition: buildCondition(values, finalConfig), action: { decision: values.action_decision, ...(values.action_decision === 'steer' && @@ -165,18 +201,18 @@ export const EditControlContent = ({ } : {}), }, - evaluator: { ...control.control.evaluator, config: finalConfig }, + tags: control.control.tags, }; }, - [control.control] + [buildCondition, control.control.tags] ); const buildDefinitionForValidation = useCallback( - (finalConfig: Record) => ({ + (finalConfig: Record): ControlDefinition => ({ ...control.control, - evaluator: { ...control.control.evaluator, config: finalConfig }, + condition: buildCondition(definitionForm.values, finalConfig), }), - [control.control] + [buildCondition, control.control, definitionForm.values] ); const validateEvaluatorConfig = useCallback( @@ -198,9 +234,8 @@ export const EditControlContent = ({ onValidateConfig: validateEvaluatorConfig, }); - const { isJsonInvalid, reset } = evaluatorConfig; + const { reset } = evaluatorConfig; - // Clear steering_context when switching away from steer action useEffect(() => { if (definitionForm.values.action_decision !== 'steer') { definitionForm.setFieldValue('action_steering_context', ''); @@ -208,75 +243,85 @@ export const EditControlContent = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [definitionForm.values.action_decision]); - // Reset view mode and errors when evaluator changes useEffect(() => { reset(); setApiError(null); setUnmappedErrors([]); - }, [reset, evaluatorId]); + formInitializedForEvaluator.current = ''; + }, [reset, evaluatorId, control.id]); - // Load control data into forms useEffect(() => { - if (control && evaluator) { - const scope = control.control.scope ?? {}; - const stepNamesValue = (scope.step_names ?? []).join(', '); - const stepRegexValue = scope.step_name_regex ?? ''; - const stepNameMode = - stepRegexValue && !stepNamesValue ? 'regex' : 'names'; - definitionForm.setValues({ - name: control.name, - description: control.control.description ?? '', - enabled: control.control.enabled, - step_types: scope.step_types ?? [], - stages: scope.stages ?? [], - step_names: stepNamesValue, - step_name_regex: stepRegexValue, - step_name_mode: stepNameMode, - selector_path: control.control.selector.path ?? '*', - action_decision: control.control.action.decision, - action_steering_context: - control.control.action.decision === 'steer' - ? (control.control.action.steering_context?.message ?? '') - : '', - execution: control.control.execution ?? 'server', - }); + const scope = control.control.scope ?? {}; + const stepNamesValue = (scope.step_names ?? []).join(', '); + const stepRegexValue = scope.step_name_regex ?? ''; + const stepNameMode = stepRegexValue && !stepNamesValue ? 'regex' : 'names'; + + definitionForm.setValues({ + name: control.name, + description: control.control.description ?? '', + enabled: control.control.enabled, + step_types: scope.step_types ?? [], + stages: scope.stages ?? [], + step_names: stepNamesValue, + step_name_regex: stepRegexValue, + step_name_mode: stepNameMode, + selector_path: leafCondition?.selectorPath ?? '*', + action_decision: control.control.action.decision, + action_steering_context: + control.control.action.decision === 'steer' + ? (control.control.action.steering_context?.message ?? '') + : '', + execution: control.control.execution ?? 'server', + }); + + if (leafCondition && evaluator) { evaluatorForm.setValues( - evaluator.fromConfig(control.control.evaluator.config) + evaluator.fromConfig(leafCondition.evaluatorConfig) ); - // Mark form as initialized for this evaluator formInitializedForEvaluator.current = evaluatorId; } // eslint-disable-next-line react-hooks/exhaustive-deps - }, [control, evaluator, evaluatorId]); + }, [control, evaluator, evaluatorId, leafCondition]); - // Handle form submission const handleSubmit = async (values: ControlDefinitionFormValues) => { - // Clear previous errors setApiError(null); setUnmappedErrors([]); definitionForm.clearErrors(); evaluatorForm.clearErrors(); - let finalConfig: Record; - - if (evaluatorConfig.configViewMode === 'json') { - const jsonConfig = evaluatorConfig.getJsonConfig(); - if (!jsonConfig) return; - finalConfig = jsonConfig; - } else { - // Validate evaluator form - const validation = evaluatorForm.validate(); - if (validation.hasErrors) return; - finalConfig = getEvaluatorConfig(); + if ( + values.action_decision === 'steer' && + !leafCondition && + !values.action_steering_context?.trim() + ) { + definitionForm.setFieldError( + 'action_steering_context', + 'Composite steer controls require steering context' + ); + return; + } + + let finalConfig: Record = + leafCondition?.evaluatorConfig ?? {}; + + if (canEditLeafCondition) { + if (evaluatorConfig.configViewMode === 'json') { + const jsonConfig = evaluatorConfig.getJsonConfig(); + if (!jsonConfig) return; + finalConfig = jsonConfig; + } else { + const validation = evaluatorForm.validate(); + if (validation.hasErrors) return; + finalConfig = getEvaluatorConfig(); + } } - const definition = buildControlDefinition( - { ...values, name: values.name }, - finalConfig - ); + const definition = buildControlDefinition(values, finalConfig); const runSave = async () => { try { + await validateControlDataAsync({ definition }); + if (isCreating) { await addControlToAgent.mutateAsync({ agentId, @@ -311,14 +356,12 @@ export const EditControlContent = ({ problemDetail.detail || 'Control name already exists' ); } else if (problemDetail.status === 422) { - // Mirror the main error-handling behavior so validation errors - // render inline (and in the alert when unmapped). setApiError(problemDetail); if (problemDetail.errors) { const unmapped = applyApiErrorsToForms( problemDetail.errors, definitionForm, - evaluatorForm + canEditLeafCondition ? evaluatorForm : null ); setUnmappedErrors( unmapped.map((e) => ({ @@ -361,8 +404,7 @@ export const EditControlContent = ({ color: 'green', }); } - // Call onSuccess first (which should close all modals) - // Only call onClose if onSuccess is not provided (for backward compatibility) + if (onSuccess) { onSuccess(); } else { @@ -371,44 +413,43 @@ export const EditControlContent = ({ } catch (error) { if (isApiError(error)) { const problemDetail = error.problemDetail; - - // Check if this is a "name already exists" error (409 Conflict) - // and map it to the name field if it's not already in the errors array const isNameExistsError = (problemDetail.status === 409 || problemDetail.error_code === 'CONTROL_NAME_CONFLICT') && - !problemDetail.errors?.some((e) => e.field === 'name'); + !problemDetail.errors?.some((item) => item.field === 'name'); if (isNameExistsError) { - // Set error directly on the name field definitionForm.setFieldError( 'name', problemDetail.detail || 'Control name already exists' ); - // Don't show it in the alert since it's now on the field setApiError(null); setUnmappedErrors([]); - } else { - setApiError(problemDetail); - - if (problemDetail.errors) { - if (evaluatorConfig.configViewMode === 'form') { - const unmapped = applyApiErrorsToForms( - problemDetail.errors, - definitionForm, - evaluatorForm - ); - setUnmappedErrors( - unmapped.map((e) => ({ field: e.field, message: e.message })) - ); - } else { - setUnmappedErrors( - problemDetail.errors.map((e) => ({ - field: e.field, - message: e.message, - })) - ); - } + return; + } + + setApiError(problemDetail); + + if (problemDetail.errors) { + if (evaluatorConfig.configViewMode === 'form') { + const unmapped = applyApiErrorsToForms( + problemDetail.errors, + definitionForm, + canEditLeafCondition ? evaluatorForm : null + ); + setUnmappedErrors( + unmapped.map((e) => ({ + field: e.field, + message: e.message, + })) + ); + } else { + setUnmappedErrors( + problemDetail.errors.map((e) => ({ + field: e.field, + message: e.message, + })) + ); } } } else { @@ -448,8 +489,7 @@ export const EditControlContent = ({ }); }; - // Render the evaluator's form component - const FormComponent = evaluator?.FormComponent; + const formComponent = evaluator?.FormComponent; return ( @@ -476,22 +516,35 @@ export const EditControlContent = ({ - + - + {canEditLeafCondition ? ( + + ) : ( + + {conditionEditingMessage} + + )} - {/* API Error Alert */} {apiError ? ( <> @@ -503,7 +556,6 @@ export const EditControlContent = ({ ) : null} - {/* Buttons */} diff --git a/ui/src/core/page-components/agent-detail/modals/edit-control/types.ts b/ui/src/core/page-components/agent-detail/modals/edit-control/types.ts index 9426f85a..98caf9d5 100644 --- a/ui/src/core/page-components/agent-detail/modals/edit-control/types.ts +++ b/ui/src/core/page-components/agent-detail/modals/edit-control/types.ts @@ -80,4 +80,5 @@ export type EvaluatorJsonViewProps = { export type ControlDefinitionFormProps = { form: UseFormReturnType; + disableSelectorPath?: boolean; }; diff --git a/ui/src/core/page-components/agent-detail/modals/edit-control/utils.ts b/ui/src/core/page-components/agent-detail/modals/edit-control/utils.ts index 1cccc2fc..b962de8a 100644 --- a/ui/src/core/page-components/agent-detail/modals/edit-control/utils.ts +++ b/ui/src/core/page-components/agent-detail/modals/edit-control/utils.ts @@ -23,9 +23,8 @@ type FieldMapping = { * API field paths look like: * - "name" (control name) * - "data.scope.step_types" (definition field) - * - "data.selector.path" → selector_path (definition field) - * - "data.evaluator.config.pattern" (evaluator config field) - * - "data.evaluator.field_types" (evaluator config field without config prefix) + * - "data.condition.selector.path" → selector_path (definition field) + * - "data.condition.evaluator.config.pattern" (evaluator config field) * * Since forms use snake_case, we can directly use the API field names. * @@ -50,32 +49,32 @@ export function mapApiFieldToFormField( const fieldPath = apiField.slice(dataPrefix.length); - // Handle evaluator fields - API may return either: - // - "evaluator.{field}" (e.g., "evaluator.field_types") - // - "evaluator.config.{field}" (e.g., "evaluator.config.pattern") - const evalPrefix = 'evaluator.'; - if (fieldPath.startsWith(evalPrefix)) { - let configField = fieldPath.slice(evalPrefix.length); - - // Strip "config." prefix if present - const configPrefix = 'config.'; - if (configField.startsWith(configPrefix)) { - configField = configField.slice(configPrefix.length); + const leafConditionPrefix = 'condition.'; + if (fieldPath.startsWith(leafConditionPrefix)) { + const conditionField = fieldPath.slice(leafConditionPrefix.length); + + if (conditionField === 'selector.path') { + return { form: 'definition', field: 'selector_path' }; } - // For nested paths like "field_types.name", use the first segment - const firstDotIndex = configField.indexOf('.'); - const field = - firstDotIndex > 0 ? configField.slice(0, firstDotIndex) : configField; + const evalPrefix = 'evaluator.'; + if (conditionField.startsWith(evalPrefix)) { + let configField = conditionField.slice(evalPrefix.length); + if (configField.startsWith('config.')) { + configField = configField.slice('config.'.length); + } + + const firstDotIndex = configField.indexOf('.'); + const field = + firstDotIndex > 0 ? configField.slice(0, firstDotIndex) : configField; + + return { form: 'evaluator', field }; + } - return { form: 'evaluator', field }; + return null; } // Handle definition fields - // Map nested paths like "selector.path" to "selector_path" - if (fieldPath === 'selector.path') { - return { form: 'definition', field: 'selector_path' }; - } if (fieldPath === 'action.decision') { return { form: 'definition', field: 'action_decision' }; } @@ -117,7 +116,7 @@ export function mapApiFieldToFormField( export function applyApiErrorsToForms( errors: ValidationErrorItem[] | undefined, definitionForm: UseFormReturnType, - evaluatorForm: UseFormReturnType + evaluatorForm?: UseFormReturnType | null ): ValidationErrorItem[] { if (!errors || errors.length === 0) { return []; @@ -131,8 +130,10 @@ export function applyApiErrorsToForms( if (mapping) { if (mapping.form === 'definition') { definitionForm.setFieldError(mapping.field, error.message); - } else if (mapping.form === 'evaluator') { + } else if (mapping.form === 'evaluator' && evaluatorForm) { evaluatorForm.setFieldError(mapping.field, error.message); + } else if (mapping.form === 'evaluator') { + unmappedErrors.push(error); } } else { unmappedErrors.push(error); diff --git a/ui/tests/control-store.spec.ts b/ui/tests/control-store.spec.ts index 66d384e3..6cc259b9 100644 --- a/ui/tests/control-store.spec.ts +++ b/ui/tests/control-store.spec.ts @@ -786,22 +786,25 @@ test.describe('Modal Routing', () => { } ); - await mockedPage.route('**/api/v1/controls/*', async (route, request) => { - if (request.method() === 'DELETE') { - cleanupDeleteCalls += 1; - await route.fulfill({ - status: 200, - contentType: 'application/json', - body: JSON.stringify({ - success: true, - dissociated_from_policies: [], - dissociated_from_agents: [], - }), - }); - } else { - await route.continue(); + await mockedPage.route( + /\/api\/v1\/controls\/\d+\?force=true$/, + async (route, request) => { + if (request.method() === 'DELETE') { + cleanupDeleteCalls += 1; + await route.fulfill({ + status: 200, + contentType: 'application/json', + body: JSON.stringify({ + success: true, + dissociated_from_policies: [], + dissociated_from_agents: [], + }), + }); + } else { + await route.continue(); + } } - }); + ); const controlNameInput = createModal.getByPlaceholder('Enter control name'); await controlNameInput.fill('cleanup-test-control'); diff --git a/ui/tests/fixtures.ts b/ui/tests/fixtures.ts index f2289004..9f566844 100644 --- a/ui/tests/fixtures.ts +++ b/ui/tests/fixtures.ts @@ -88,10 +88,12 @@ const controlsList: Control[] = [ enabled: true, execution: 'server', scope: { step_types: ['llm'], stages: ['post'] }, - selector: { path: 'output' }, - evaluator: { - name: 'regex', - config: { pattern: '\\b\\d{3}-\\d{2}-\\d{4}\\b' }, + condition: { + selector: { path: 'output' }, + evaluator: { + name: 'regex', + config: { pattern: '\\b\\d{3}-\\d{2}-\\d{4}\\b' }, + }, }, action: { decision: 'deny' }, tags: ['pii', 'compliance'], @@ -110,10 +112,12 @@ const controlsList: Control[] = [ step_name_regex: '^db_.*', stages: ['pre'], }, - selector: { path: 'input.query' }, - evaluator: { - name: 'sql', - config: { mode: 'safe' }, + condition: { + selector: { path: 'input.query' }, + evaluator: { + name: 'sql', + config: { mode: 'safe' }, + }, }, action: { decision: 'deny' }, tags: ['security'], @@ -127,10 +131,12 @@ const controlsList: Control[] = [ enabled: false, execution: 'server', scope: { step_types: ['llm'], stages: ['pre'] }, - selector: { path: '*' }, - evaluator: { - name: 'list', - config: { values: [], logic: 'any', match_on: 'match' }, + condition: { + selector: { path: '*' }, + evaluator: { + name: 'list', + config: { values: [], logic: 'any', match_on: 'match' }, + }, }, action: { decision: 'allow' }, tags: [],