Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions sdks/python/src/agent_control/control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,12 +828,64 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
except (AttributeError, TypeError):
pass

@functools.wraps(func)
async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
agent = _get_current_agent()
if agent is None:
logger.warning(
"No agent initialized. Call agent_control.init() first. "
"Running without protection."
)
async for chunk in func(*args, **kwargs):
yield chunk
return

controls = _get_server_controls()

existing_trace_id = get_current_trace_id()
if existing_trace_id:
trace_id = existing_trace_id
span_id = _generate_span_id()
else:
trace_id, span_id = get_trace_and_span_ids()

ctx = ControlContext(
agent_name=agent.agent_name,
server_url=_get_server_url(),
func=func,
args=args,
kwargs=kwargs,
trace_id=trace_id,
span_id=span_id,
start_time=time.perf_counter(),
step_name=step_name,
)
ctx.log_start()

try:
# PRE-EXECUTION: Check controls with check_stage="pre"
await _run_control_check(ctx, "pre", ctx.pre_payload(), controls)

# Yield chunks while accumulating full output for post-check
accumulated: list[str] = []
async for chunk in func(*args, **kwargs):
accumulated.append(str(chunk))
yield chunk

# POST-EXECUTION: Check controls on full accumulated output
full_output = "".join(accumulated)
await _run_control_check(ctx, "post", ctx.post_payload(full_output), controls)
finally:
ctx.log_end()

@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return asyncio.run(
_execute_with_control(func, args, kwargs, is_async=False, step_name=step_name)
)

if inspect.isasyncgenfunction(func):
return async_gen_wrapper # type: ignore
if inspect.iscoroutinefunction(func):
return async_wrapper # type: ignore
return sync_wrapper # type: ignore
Expand Down
136 changes: 136 additions & 0 deletions sdks/python/tests/test_control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,3 +886,139 @@ async def test_func():
assert mock_logger.error.call_args[0][0] == "%s-execution control check failed: %s"
assert mock_logger.error.call_args[0][1] == "Post"
assert str(mock_logger.error.call_args[0][2]) == "Post-execution error"


# =============================================================================
# ASYNC GENERATOR (STREAMING) TESTS
# =============================================================================

class TestAsyncGeneratorControl:
"""Tests for @control decorator on async generator (streaming) functions."""

@pytest.mark.asyncio
async def test_preserves_async_gen_type(self, mock_agent, mock_safe_response):
"""Test that decorated async generator is still recognized as async gen."""
import inspect

with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \
patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response):

@control()
async def stream(message: str):
yield "Hello "
yield "world"

assert inspect.isasyncgenfunction(stream)

@pytest.mark.asyncio
async def test_yields_all_chunks(self, mock_agent, mock_safe_response):
"""Test that all chunks are yielded through when safe."""
with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \
patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response):

@control()
async def stream(message: str):
yield "chunk1"
yield "chunk2"
yield "chunk3"

chunks = [chunk async for chunk in stream("test")]
assert chunks == ["chunk1", "chunk2", "chunk3"]

@pytest.mark.asyncio
async def test_pre_check_blocks_before_first_yield(self, mock_agent, mock_unsafe_response):
"""Test that pre-check deny prevents any chunks from being yielded."""
with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \
patch("agent_control.control_decorators._evaluate", return_value=mock_unsafe_response):

executed = False

@control()
async def stream(message: str):
nonlocal executed
executed = True
yield "should not appear"

with pytest.raises(ControlViolationError):
async for _ in stream("toxic input"):
pass

assert not executed

@pytest.mark.asyncio
async def test_post_check_deny_raises_after_stream(self, mock_agent, mock_safe_response, mock_unsafe_response):
"""Test that post-check deny raises after all chunks have been yielded."""
call_count = [0]

def mock_evaluate_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
return mock_safe_response # pre-check passes
return mock_unsafe_response # post-check fails

with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \
patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate_side_effect):

@control()
async def stream(message: str):
yield "chunk1"
yield "chunk2"

collected = []
with pytest.raises(ControlViolationError):
async for chunk in stream("test"):
collected.append(chunk)

# Chunks were yielded before the post-check raised
assert collected == ["chunk1", "chunk2"]

@pytest.mark.asyncio
async def test_no_agent_streams_without_protection(self):
"""Test that async gen passes through if no agent initialized."""
with patch("agent_control.control_decorators._get_current_agent", return_value=None):

@control()
async def stream(message: str):
yield "a"
yield "b"

chunks = [chunk async for chunk in stream("hello")]
assert chunks == ["a", "b"]

@pytest.mark.asyncio
async def test_empty_stream(self, mock_agent, mock_safe_response):
"""Test that empty async generator works correctly."""
with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \
patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response):

@control()
async def stream(message: str):
return
yield # noqa: unreachable - makes this an async generator

chunks = [chunk async for chunk in stream("test")]
assert chunks == []

@pytest.mark.asyncio
async def test_steer_on_post_check(self, mock_agent, mock_safe_response, mock_steer_response):
"""Test that steer action raises ControlSteerError after stream."""
call_count = [0]

def mock_evaluate_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
return mock_safe_response
return mock_steer_response

with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \
patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate_side_effect):

@control()
async def stream(message: str):
yield "response"

with pytest.raises(ControlSteerError) as exc_info:
async for _ in stream("offensive"):
pass

assert exc_info.value.control_name == "steer-control"
Loading