Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 166 additions & 4 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,132 @@ def _run_single_step_action(
return out


def _instrumented_sync_generator(
    generator: Generator,
    lifecycle_adapters: LifecycleAdapterSet,
    stream_initialize_time,
    action: str,
    sequence_id: int,
    app_id: str,
    partition_key: Optional[str],
) -> Generator:
    """Instruments a synchronous stream by bracketing every item pull with lifecycle hooks.

    Each ``__next__()`` call on the underlying generator is surrounded by a
    ``pre_stream_generate`` hook (fired just before pulling) and a
    ``post_stream_generate`` hook (fired just after), so the hooks measure pure
    generation time and exclude whatever the consumer does with each item.
    Hooks receive the item index, the stream initialization timestamp, and the
    action/sequence_id/app_id/partition_key identifying this execution.

    On normal exhaustion (``StopIteration``) the post hook fires once with
    ``item=None``/``exception=None`` and the wrapper returns. If generation
    raises, the post hook fires with the exception attached and the exception
    is re-raised to the consumer.

    Args:
        generator: The synchronous generator to wrap and instrument.
        lifecycle_adapters: Set of lifecycle adapters to call hooks on.
        stream_initialize_time: Timestamp when the stream was initialized.
        action: Name of the action generating the stream.
        sequence_id: Sequence identifier for this execution.
        app_id: Application identifier.
        partition_key: Optional partition key for distributed execution.

    Yields:
        Items from the wrapped generator, one at a time.
    """
    # Hoist the hook dispatcher and the invariant metadata out of the loop;
    # only item_index changes per iteration.
    fire = lifecycle_adapters.call_all_lifecycle_hooks_sync
    source = iter(generator)
    metadata = {
        "stream_initialize_time": stream_initialize_time,
        "action": action,
        "sequence_id": sequence_id,
        "app_id": app_id,
        "partition_key": partition_key,
    }
    index = 0
    while True:
        metadata["item_index"] = index
        fire("pre_stream_generate", **metadata)
        try:
            produced = next(source)
        except StopIteration:
            # Normal end of stream -- report completion, then stop.
            fire("post_stream_generate", item=None, exception=None, **metadata)
            return
        except Exception as err:
            # Generation failed -- report the failure, then propagate it.
            fire("post_stream_generate", item=None, exception=err, **metadata)
            raise
        fire("post_stream_generate", item=produced, exception=None, **metadata)
        yield produced
        index += 1


async def _instrumented_async_generator(
    generator: AsyncGenerator,
    lifecycle_adapters: LifecycleAdapterSet,
    stream_initialize_time,
    action: str,
    sequence_id: int,
    app_id: str,
    partition_key: Optional[str],
) -> AsyncGenerator:
    """Instruments an async stream by bracketing every item pull with lifecycle hooks.

    Each ``__anext__()`` call on the underlying async generator is surrounded
    by a ``pre_stream_generate`` hook (fired just before pulling) and a
    ``post_stream_generate`` hook (fired just after), so the hooks measure pure
    generation time and exclude whatever the consumer does with each item.
    Hooks receive the item index, the stream initialization timestamp, and the
    action/sequence_id/app_id/partition_key identifying this execution.

    On normal exhaustion (``StopAsyncIteration``) the post hook fires once with
    ``item=None``/``exception=None`` and the wrapper returns. If generation
    raises, the post hook fires with the exception attached and the exception
    is re-raised to the consumer.

    Args:
        generator: The asynchronous generator to wrap and instrument.
        lifecycle_adapters: Set of lifecycle adapters to call hooks on.
        stream_initialize_time: Timestamp when the stream was initialized.
        action: Name of the action generating the stream.
        sequence_id: Sequence identifier for this execution.
        app_id: Application identifier.
        partition_key: Optional partition key for distributed execution.

    Yields:
        Items from the wrapped generator, one at a time.
    """
    # Hoist the hook dispatcher and the invariant metadata out of the loop;
    # only item_index changes per iteration.
    fire = lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async
    source = generator.__aiter__()
    metadata = {
        "stream_initialize_time": stream_initialize_time,
        "action": action,
        "sequence_id": sequence_id,
        "app_id": app_id,
        "partition_key": partition_key,
    }
    index = 0
    while True:
        metadata["item_index"] = index
        await fire("pre_stream_generate", **metadata)
        try:
            produced = await source.__anext__()
        except StopAsyncIteration:
            # Normal end of stream -- report completion, then stop.
            await fire("post_stream_generate", item=None, exception=None, **metadata)
            return
        except Exception as err:
            # Generation failed -- report the failure, then propagate it.
            await fire("post_stream_generate", item=None, exception=err, **metadata)
            raise
        await fire("post_stream_generate", item=produced, exception=None, **metadata)
        yield produced
        index += 1


def _run_single_step_streaming_action(
action: SingleStepStreamingAction,
state: State,
Expand All @@ -334,7 +460,16 @@ def _run_single_step_streaming_action(
action.validate_inputs(inputs)
stream_initialize_time = system.now()
first_stream_start_time = None
generator = action.stream_run_and_update(state, **inputs)
raw_generator = action.stream_run_and_update(state, **inputs)
generator = _instrumented_sync_generator(
raw_generator,
lifecycle_adapters,
stream_initialize_time=stream_initialize_time,
action=action.name,
sequence_id=sequence_id,
app_id=app_id,
partition_key=partition_key,
)
result = None
state_update = None
count = 0
Expand Down Expand Up @@ -387,7 +522,16 @@ async def _arun_single_step_streaming_action(
action.validate_inputs(inputs)
stream_initialize_time = system.now()
first_stream_start_time = None
generator = action.stream_run_and_update(state, **inputs)
raw_generator = action.stream_run_and_update(state, **inputs)
generator = _instrumented_async_generator(
raw_generator,
lifecycle_adapters,
stream_initialize_time=stream_initialize_time,
action=action.name,
sequence_id=sequence_id,
app_id=app_id,
partition_key=partition_key,
)
result = None
state_update = None
count = 0
Expand Down Expand Up @@ -446,7 +590,16 @@ def _run_multi_step_streaming_action(
"""
action.validate_inputs(inputs)
stream_initialize_time = system.now()
generator = action.stream_run(state, **inputs)
raw_generator = action.stream_run(state, **inputs)
generator = _instrumented_sync_generator(
raw_generator,
lifecycle_adapters,
stream_initialize_time=stream_initialize_time,
action=action.name,
sequence_id=sequence_id,
app_id=app_id,
partition_key=partition_key,
)
result = None
first_stream_start_time = None
count = 0
Expand Down Expand Up @@ -490,7 +643,16 @@ async def _arun_multi_step_streaming_action(
"""Runs a multi-step streaming action in async. See the synchronous version for more details."""
action.validate_inputs(inputs)
stream_initialize_time = system.now()
generator = action.stream_run(state, **inputs)
raw_generator = action.stream_run(state, **inputs)
generator = _instrumented_async_generator(
raw_generator,
lifecycle_adapters,
stream_initialize_time=stream_initialize_time,
action=action.name,
sequence_id=sequence_id,
app_id=app_id,
partition_key=partition_key,
)
result = None
first_stream_start_time = None
count = 0
Expand Down
Loading
Loading