Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion py/packages/genkit/src/genkit/_core/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,19 @@ def send_chunk(c: ChunkT) -> None:
channel.set_close_future(asyncio.create_task(resp))

result_future: asyncio.Future[OutputT] = asyncio.Future()
channel.closed.add_done_callback(lambda _: result_future.set_result(channel.closed.result().response))

def _propagate_closed_to_result(_: asyncio.Future[ActionResponse[OutputT]]) -> None:
if result_future.done():
return
closed = channel.closed
if closed.cancelled():
result_future.cancel()
elif (exc := closed.exception()) is not None:
result_future.set_exception(exc)
else:
result_future.set_result(closed.result().response)

channel.closed.add_done_callback(_propagate_closed_to_result)

return StreamResponse(stream=channel, response=result_future)

Expand Down
46 changes: 46 additions & 0 deletions py/packages/genkit/tests/genkit/core/action_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

"""Tests for the action module."""

import asyncio
import json
from typing import cast
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -177,6 +179,50 @@ async def foo(
assert chunks == ['1', '2']


@pytest.mark.asyncio
async def test_stream_cancellation_does_not_crash_callback() -> None:
"""ACC-562: When stream task is cancelled, the channel.closed callback must not crash.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix comment format


Previously the callback called channel.closed.result() which raises CancelledError
when the future was cancelled. The fix propagates cancellation to result_future
instead of crashing.
"""
block = asyncio.Event()

async def blocking_stream(
_input: str,
ctx: ActionRunContext,
) -> int:
ctx.send_chunk('started')
await block.wait() # Never completes unless we set it
return 42

action = Action(name='blocking', kind=ActionKind.CUSTOM, fn=blocking_stream)

captured_task: asyncio.Task | None = None
original_create_task = asyncio.create_task

def capturing_create_task(coro: object) -> asyncio.Task:
nonlocal captured_task
task = original_create_task(coro)
captured_task = task
return task

with patch('genkit._core._action.asyncio.create_task', side_effect=capturing_create_task):
result = action.stream('x')

# Consume the first chunk so the action progresses to block.wait()
async for _ in result.stream:
break

assert captured_task is not None
captured_task.cancel()

# Awaiting response should raise CancelledError, not crash in callback
with pytest.raises(asyncio.CancelledError):
await asyncio.wait_for(result.response, timeout=1.0)


def test_parse_plugin_name_from_action_name() -> None:
"""Parse plugin name from the action name."""
assert parse_plugin_name_from_action_name('foo') is None
Expand Down
Loading