From 5dbe0dcf900dbeab7dfe2c0eed748089a8f90fd1 Mon Sep 17 00:00:00 2001 From: kazche Date: Thu, 12 Mar 2026 00:32:37 -0700 Subject: [PATCH] =?UTF-8?q?fix(sdk):=20support=20async=20generator=20funct?= =?UTF-8?q?ions=20in=20control()=20decorator=20The=20control()=20decorator?= =?UTF-8?q?=20silently=20broke=20on=20streaming=20(async=20generator)=20fu?= =?UTF-8?q?nctions=20=E2=80=94=20the=20standard=20pattern=20for=20LLM=20re?= =?UTF-8?q?sponse=20streaming.This=20adds=20an=20async=5Fgen=5Fwrapper=20p?= =?UTF-8?q?ath=20that=20runs=20pre-check=20before=20the=20first=20chunk,?= =?UTF-8?q?=20yields=20chunks=20in=20real-time=20while=20accumulating=20ou?= =?UTF-8?q?tput,=20and=20runs=20post-check=20on=20the=20full=20accumulated?= =?UTF-8?q?=20output=20after=20the=20stream=20completes.=20fixes=20#113?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/agent_control/control_decorators.py | 52 +++++++ sdks/python/tests/test_control_decorators.py | 136 ++++++++++++++++++ 2 files changed, 188 insertions(+) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index 569aecaa..c86b45a4 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -828,12 +828,64 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: except (AttributeError, TypeError): pass + @functools.wraps(func) + async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: + agent = _get_current_agent() + if agent is None: + logger.warning( + "No agent initialized. Call agent_control.init() first. " + "Running without protection." + ) + async for chunk in func(*args, **kwargs): + yield chunk + return + + controls = _get_server_controls() + + existing_trace_id = get_current_trace_id() + if existing_trace_id: + trace_id = existing_trace_id + span_id = _generate_span_id() + else: + trace_id, span_id = get_trace_and_span_ids() + + ctx = ControlContext( + agent_name=agent.agent_name, + server_url=_get_server_url(), + func=func, + args=args, + kwargs=kwargs, + trace_id=trace_id, + span_id=span_id, + start_time=time.perf_counter(), + step_name=step_name, + ) + ctx.log_start() + + try: + # PRE-EXECUTION: Check controls with check_stage="pre" + await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) + + # Yield chunks while accumulating full output for post-check + accumulated: list[str] = [] + async for chunk in func(*args, **kwargs): + accumulated.append(str(chunk)) + yield chunk + + # POST-EXECUTION: Check controls on full accumulated output + full_output = "".join(accumulated) + await _run_control_check(ctx, "post", ctx.post_payload(full_output), controls) + finally: + ctx.log_end() + @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: return asyncio.run( _execute_with_control(func, args, kwargs, is_async=False, step_name=step_name) ) + if inspect.isasyncgenfunction(func): + return async_gen_wrapper # type: ignore if inspect.iscoroutinefunction(func): return async_wrapper # type: ignore return sync_wrapper # type: ignore diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index d13ec46d..c8b770d9 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -886,3 +886,139 @@ async def test_func(): assert mock_logger.error.call_args[0][0] == "%s-execution control check failed: %s" assert mock_logger.error.call_args[0][1] == "Post" assert str(mock_logger.error.call_args[0][2]) == "Post-execution error" + + +# ============================================================================= +# ASYNC GENERATOR (STREAMING) TESTS +# ============================================================================= + +class TestAsyncGeneratorControl: + """Tests for @control decorator on async generator (streaming) functions.""" + + @pytest.mark.asyncio + async def test_preserves_async_gen_type(self, mock_agent, mock_safe_response): + """Test that decorated async generator is still recognized as async gen.""" + import inspect + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + yield "Hello " + yield "world" + + assert inspect.isasyncgenfunction(stream) + + @pytest.mark.asyncio + async def test_yields_all_chunks(self, mock_agent, mock_safe_response): + """Test that all chunks are yielded through when safe.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + yield "chunk3" + + chunks = [chunk async for chunk in stream("test")] + assert chunks == ["chunk1", "chunk2", "chunk3"] + + @pytest.mark.asyncio + async def test_pre_check_blocks_before_first_yield(self, mock_agent, mock_unsafe_response): + """Test that pre-check deny prevents any chunks from being yielded.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_unsafe_response): + + executed = False + + @control() + async def stream(message: str): + nonlocal executed + executed = True + yield "should not appear" + + with pytest.raises(ControlViolationError): + async for _ in stream("toxic input"): + pass + + assert not executed + + @pytest.mark.asyncio + async def test_post_check_deny_raises_after_stream(self, mock_agent, mock_safe_response, mock_unsafe_response): + """Test that post-check deny raises after all chunks have been yielded.""" + call_count = [0] + + def mock_evaluate_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_safe_response # pre-check passes + return mock_unsafe_response # post-check fails + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate_side_effect): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + + collected = [] + with pytest.raises(ControlViolationError): + async for chunk in stream("test"): + collected.append(chunk) + + # Chunks were yielded before the post-check raised + assert collected == ["chunk1", "chunk2"] + + @pytest.mark.asyncio + async def test_no_agent_streams_without_protection(self): + """Test that async gen passes through if no agent initialized.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=None): + + @control() + async def stream(message: str): + yield "a" + yield "b" + + chunks = [chunk async for chunk in stream("hello")] + assert chunks == ["a", "b"] + + @pytest.mark.asyncio + async def test_empty_stream(self, mock_agent, mock_safe_response): + """Test that empty async generator works correctly.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + return + yield # noqa: unreachable - makes this an async generator + + chunks = [chunk async for chunk in stream("test")] + assert chunks == [] + + @pytest.mark.asyncio + async def test_steer_on_post_check(self, mock_agent, mock_safe_response, mock_steer_response): + """Test that steer action raises ControlSteerError after stream.""" + call_count = [0] + + def mock_evaluate_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_safe_response + return mock_steer_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate_side_effect): + + @control() + async def stream(message: str): + yield "response" + + with pytest.raises(ControlSteerError) as exc_info: + async for _ in stream("offensive"): + pass + + assert exc_info.value.control_name == "steer-control"