114 changes: 112 additions & 2 deletions src/agents/run.py
@@ -1072,6 +1072,7 @@ async def _start_streaming(
tool_use_tracker,
all_tools,
server_conversation_tracker,
session,
)
should_run_agent_start_hooks = False

@@ -1107,7 +1108,7 @@ async def _start_streaming(
AgentUpdatedStreamEvent(new_agent=current_agent)
)

- # Check for soft cancel after handoff
+ # Check for soft cancel after handoff (before next turn)
if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap]
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1158,7 +1159,7 @@ async def _start_streaming(
session, [], turn_result.new_step_items
)

- # Check for soft cancel after turn completion
+ # Check for soft cancel after tool execution completes (before next turn)
if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap]
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1217,6 +1218,7 @@ async def _run_single_turn_streamed(
tool_use_tracker: AgentToolUseTracker,
all_tools: list[Tool],
server_conversation_tracker: _ServerConversationTracker | None = None,
session: Session | None = None,
) -> SingleStepResult:
emitted_tool_call_ids: set[str] = set()
emitted_reasoning_item_ids: set[str] = set()
@@ -1369,6 +1371,114 @@ async def _run_single_turn_streamed(
if not final_response:
raise ModelBehaviorError("Model did not produce a final response!")

# Check for soft cancel after LLM response streaming completes (before tool execution)
# Only cancel here if there are no tools/handoffs to execute - otherwise let tools execute
# and the cancel will be honored after tool execution completes
if streamed_result._cancel_mode == "after_turn":
# Process the model response to check if there are tools/handoffs to execute
processed_response = RunImpl.process_model_response(
agent=agent,
all_tools=all_tools,
response=final_response,
output_schema=output_schema,
handoffs=handoffs,
)

# If there are tools, handoffs, or approvals to execute, let normal flow continue
# The cancel will be honored after tool execution completes (before next step)
if processed_response.has_tools_or_approvals_to_run() or processed_response.handoffs:
# Continue with normal flow - tools will execute,
# then cancel after execution completes
pass
else:
# No tools/handoffs to execute - safe to cancel here and skip tool execution
# Note: We intentionally skip execute_tools_and_side_effects() since there are
# no tools to execute. This allows faster cancellation when the LLM response
# contains no actions.
tool_use_tracker.add_tool_use(agent, processed_response.tools_used)

# Filter out items that have already been sent to avoid duplicates
items_to_save = list(processed_response.new_items)

if emitted_tool_call_ids:
# Filter out tool call items that were already emitted during streaming
items_to_save = [
item
for item in items_to_save
if not (
isinstance(item, ToolCallItem)
and (
call_id := getattr(
item.raw_item, "call_id", getattr(item.raw_item, "id", None)
)
)
and call_id in emitted_tool_call_ids
)
]

if emitted_reasoning_item_ids:
# Filter out reasoning items that were already emitted during streaming
items_to_save = [
item
for item in items_to_save
if not (
isinstance(item, ReasoningItem)
and (reasoning_id := getattr(item.raw_item, "id", None))
and reasoning_id in emitted_reasoning_item_ids
)
]

# Filter out HandoffCallItem to avoid duplicates (already sent earlier)
items_to_save = [
item for item in items_to_save if not isinstance(item, HandoffCallItem)
]

# Create SingleStepResult with NextStepRunAgain (we're stopping mid-turn)
single_step_result = SingleStepResult(
original_input=streamed_result.input,
model_response=final_response,
pre_step_items=streamed_result.new_items,
new_step_items=items_to_save,
next_step=NextStepRunAgain(),
tool_input_guardrail_results=[],
tool_output_guardrail_results=[],
)

# Save session with the model response items
# Exclude ToolCallItem objects to avoid saving incomplete tool calls without outputs
if session is not None:
should_skip_session_save = (
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
streamed_result
)
)
if should_skip_session_save is False:
# Filter out tool calls - they don't have outputs yet, so shouldn't be saved
# This prevents saving incomplete tool calls that violate API requirements
items_for_session = [
item for item in items_to_save if not isinstance(item, ToolCallItem)
]
# Type ignore: intentionally filtering out ToolCallItem to avoid saving
# incomplete tool calls without corresponding outputs
await AgentRunner._save_result_to_session(
session,
[],
items_for_session, # type: ignore[arg-type]
)

# Stream the items to the event queue
RunImpl.stream_step_result_to_queue(
single_step_result, streamed_result._event_queue
)

# Mark as complete and signal completion
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())

return single_step_result

# 3. Now, we can process the turn as we do in the non-streaming case
single_step_result = await cls._get_single_step_result_from_response(
agent=agent,
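For context, the caller-side flow this hunk serves is roughly the following: a minimal sketch, assuming an `Agent` wired to a default model. Only `Runner.run_streamed`, `stream_events()`, `cancel(mode="after_turn")`, and `is_complete` appear in this PR; the agent setup and input are illustrative.

```python
# Minimal sketch of the consumer-side soft-cancel flow (illustrative, not part of the PR).
import asyncio

from agents import Agent, Runner


async def main() -> None:
    agent = Agent(name="Assistant")  # model/tool config omitted; see the tests below
    result = Runner.run_streamed(agent, input="Do some work")

    async for event in result.stream_events():
        if event.type == "run_item_stream_event" and event.name == "tool_called":
            # Soft cancel: in-flight tool calls still execute, then the run
            # stops before the next turn starts.
            result.cancel(mode="after_turn")

    assert result.is_complete


asyncio.run(main())
```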
67 changes: 60 additions & 7 deletions tests/test_soft_cancel.py
@@ -4,7 +4,8 @@

import pytest

- from agents import Agent, Runner, SQLiteSession
+ from agents import Agent, OutputGuardrail, Runner, SQLiteSession
from agents.guardrail import GuardrailFunctionOutput

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@@ -87,13 +88,15 @@ async def test_soft_cancel_with_tool_calls():
if event.type == "run_item_stream_event":
if event.name == "tool_called":
tool_call_seen = True
- # Cancel right after seeing tool call
+ # Cancel right after seeing tool call - tools will execute
+ # then cancel is honored after tool execution completes
result.cancel(mode="after_turn")
elif event.name == "tool_output":
tool_output_seen = True

assert tool_call_seen, "Tool call should be seen"
assert tool_output_seen, "Tool output should be seen (tool should execute before soft cancel)"
assert tool_output_seen, "Tool output SHOULD be seen (tools execute before cancel is honored)"
assert result.is_complete, "Result should be marked complete"


@pytest.mark.asyncio
@@ -293,18 +296,25 @@ async def test_soft_cancel_with_multiple_tool_calls():

result = Runner.run_streamed(agent, input="Execute tools")

tool_calls_seen = 0
tool_outputs_seen = 0
async for event in result.stream_events():
if event.type == "run_item_stream_event":
if event.name == "tool_called":
- # Cancel after seeing first tool call
- if tool_outputs_seen == 0:
+ tool_calls_seen += 1
+ # Cancel after seeing first tool call - tools will execute
+ # then cancel is honored after tool execution completes
+ if tool_calls_seen == 1:
result.cancel(mode="after_turn")
elif event.name == "tool_output":
tool_outputs_seen += 1

- # Both tools should execute
- assert tool_outputs_seen == 2, "Both tools should execute before soft cancel"
+ # Tool calls should be seen, and tools SHOULD execute before cancel is honored
+ assert tool_calls_seen >= 1, "Tool calls should be seen"
+ assert tool_outputs_seen > 0, (
+     "Tool outputs SHOULD be seen (tools execute before cancel is honored)"
+ )
assert result.is_complete, "Result should be marked complete"


@pytest.mark.asyncio
@@ -476,3 +486,46 @@ async def test_soft_cancel_with_session_and_multiple_turns():

# Cleanup
await session.clear_session()


@pytest.mark.asyncio
async def test_soft_cancel_runs_output_guardrails_before_canceling():
"""Verify output guardrails run even when cancellation happens after final output."""
model = FakeModel()

# Track if guardrail was called
guardrail_called = False

def output_guardrail_fn(context, agent, output):
nonlocal guardrail_called
guardrail_called = True
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

agent = Agent(
name="Assistant",
model=model,
output_guardrails=[OutputGuardrail(guardrail_function=output_guardrail_fn)],
)

# Setup: agent produces final output
model.add_multiple_turn_outputs([[get_text_message("Final answer")]])

result = Runner.run_streamed(agent, input="What is the answer?")

# Cancel after seeing the message output event (indicates turn completed)
# but before consuming all events
async for event in result.stream_events():
if event.type == "run_item_stream_event" and event.name == "message_output_created":
# Cancel after turn completes - guardrails should still run
result.cancel(mode="after_turn")
# Don't break - continue consuming to let guardrails complete

# Guardrail should have been called
assert guardrail_called, "Output guardrail should run even when canceling after final output"

# Final output should be set
assert result.final_output is not None, "final_output should be set even when canceling"
assert result.final_output == "Final answer"

# Output guardrail results should be recorded
assert len(result.output_guardrail_results) == 1, "Output guardrail results should be recorded"
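
The three dedup filters in the run.py hunk share one shape: an item is dropped only when it is of an already-streamed kind and its id was recorded in the corresponding emitted-id set during streaming. A self-contained sketch of that pattern, where `ToolCall` and `Reasoning` are hypothetical stand-ins rather than the SDK's `ToolCallItem`/`ReasoningItem`:

```python
# Stand-in sketch of the "filter out already-emitted items" pattern used above.
from dataclasses import dataclass


@dataclass
class ToolCall:
    call_id: str


@dataclass
class Reasoning:
    id: str


def filter_emitted(items, emitted_tool_ids, emitted_reasoning_ids):
    # Keep an item unless its kind was streamed already and its id was recorded.
    kept = [
        item
        for item in items
        if not (isinstance(item, ToolCall) and item.call_id in emitted_tool_ids)
    ]
    return [
        item
        for item in kept
        if not (isinstance(item, Reasoning) and item.id in emitted_reasoning_ids)
    ]


items = [ToolCall("c1"), ToolCall("c2"), Reasoning("r1")]
assert filter_emitted(items, {"c1"}, {"r1"}) == [ToolCall("c2")]
```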