From 564fcde84b0674141f650dc1b5e1b5c7e44403dc Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 28 Feb 2026 15:24:35 -0800 Subject: [PATCH 1/3] feat: add configurable streaming telemetry to OpenTelemetryBridge Add pre_stream_generate/post_stream_generate lifecycle hooks that bracket each generator __next__/__anext__ call, enabling per-yield instrumentation for streaming actions. Also improves attributes added for streaming to distinguish internal computation from consumer time. OpenTelemetryBridge accepts a new streaming_telemetry parameter (StreamingTelemetryMode enum) controlling how streaming actions are instrumented: - SINGLE_SPAN (default): single action span, backwards compatible - EVENT: action span + stream_completed summary event with generation/consumer timing, iteration count, and TTFT - CHUNK_SPANS: per-yield child spans, no action span - BOTH: action span with summary event + per-yield child spans --- burr/core/application.py | 170 ++- burr/integrations/opentelemetry.py | 301 +++- burr/lifecycle/__init__.py | 8 + burr/lifecycle/base.py | 96 ++ .../streaming_telemetry_modes.py | 119 ++ tests/core/test_application.py | 365 +++++ tests/integrations/test_opentelemetry.py | 1234 +++++++++++++++++ 7 files changed, 2276 insertions(+), 17 deletions(-) create mode 100644 examples/opentelemetry/streaming_telemetry_modes.py diff --git a/burr/core/application.py b/burr/core/application.py index dc8067c4b..a92cdcb19 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -320,6 +320,132 @@ def _run_single_step_action( return out +def _instrumented_sync_generator( + generator: Generator, + lifecycle_adapters: LifecycleAdapterSet, + stream_initialize_time, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], +) -> Generator: + """Wraps a synchronous generator to instrument stream generation with lifecycle hooks. 
+ + This function wraps a generator and fires pre_stream_generate and post_stream_generate + hooks around each __next__() call. This brackets the actual generation time for each + item, excluding consumer processing time. The hooks receive metadata including the + item index, action name, sequence_id, app_id, and partition_key. + + Exceptions raised during generation are propagated after firing the post_stream_generate + hook with the exception. StopIteration is handled gracefully to signal generator completion. + + Args: + generator: The synchronous generator to wrap and instrument. + lifecycle_adapters: Set of lifecycle adapters to call hooks on. + stream_initialize_time: Timestamp when the stream was initialized. + action: Name of the action generating the stream. + sequence_id: Sequence identifier for this execution. + app_id: Application identifier. + partition_key: Optional partition key for distributed execution. + + Yields: + Items from the wrapped generator, one at a time. + """ + gen_iter = iter(generator) + count = 0 + while True: + hook_kwargs = dict( + item_index=count, + stream_initialize_time=stream_initialize_time, + action=action, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) + lifecycle_adapters.call_all_lifecycle_hooks_sync("pre_stream_generate", **hook_kwargs) + try: + item = next(gen_iter) + except StopIteration: + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_generate", item=None, exception=None, **hook_kwargs + ) + return + except Exception as e: + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_generate", item=None, exception=e, **hook_kwargs + ) + raise + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_generate", item=item, exception=None, **hook_kwargs + ) + yield item + count += 1 + + +async def _instrumented_async_generator( + generator: AsyncGenerator, + lifecycle_adapters: LifecycleAdapterSet, + stream_initialize_time, + action: str, + sequence_id: 
int, + app_id: str, + partition_key: Optional[str], +) -> AsyncGenerator: + """Wraps an asynchronous generator to instrument stream generation with lifecycle hooks. + + This function wraps an async generator and fires pre_stream_generate and post_stream_generate + hooks around each __anext__() call. This brackets the actual generation time for each + item, excluding consumer processing time. The hooks receive metadata including the + item index, action name, sequence_id, app_id, and partition_key. + + Exceptions raised during generation are propagated after firing the post_stream_generate + hook with the exception. StopAsyncIteration is handled gracefully to signal generator completion. + + Args: + generator: The asynchronous generator to wrap and instrument. + lifecycle_adapters: Set of lifecycle adapters to call hooks on. + stream_initialize_time: Timestamp when the stream was initialized. + action: Name of the action generating the stream. + sequence_id: Sequence identifier for this execution. + app_id: Application identifier. + partition_key: Optional partition key for distributed execution. + + Yields: + Items from the wrapped generator, one at a time. 
+ """ + aiter = generator.__aiter__() + count = 0 + while True: + hook_kwargs = dict( + item_index=count, + stream_initialize_time=stream_initialize_time, + action=action, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "pre_stream_generate", **hook_kwargs + ) + try: + item = await aiter.__anext__() + except StopAsyncIteration: + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_generate", item=None, exception=None, **hook_kwargs + ) + return + except Exception as e: + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_generate", item=None, exception=e, **hook_kwargs + ) + raise + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_generate", item=item, exception=None, **hook_kwargs + ) + yield item + count += 1 + + def _run_single_step_streaming_action( action: SingleStepStreamingAction, state: State, @@ -334,7 +460,16 @@ def _run_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + raw_generator = action.stream_run_and_update(state, **inputs) + generator = _instrumented_sync_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None state_update = None count = 0 @@ -387,7 +522,16 @@ async def _arun_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + raw_generator = action.stream_run_and_update(state, **inputs) + generator = _instrumented_async_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + 
action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None state_update = None count = 0 @@ -446,7 +590,16 @@ def _run_multi_step_streaming_action( """ action.validate_inputs(inputs) stream_initialize_time = system.now() - generator = action.stream_run(state, **inputs) + raw_generator = action.stream_run(state, **inputs) + generator = _instrumented_sync_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None first_stream_start_time = None count = 0 @@ -490,7 +643,16 @@ async def _arun_multi_step_streaming_action( """Runs a multi-step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) stream_initialize_time = system.now() - generator = action.stream_run(state, **inputs) + raw_generator = action.stream_run(state, **inputs) + generator = _instrumented_async_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None first_stream_start_time = None count = 0 diff --git a/burr/integrations/opentelemetry.py b/burr/integrations/opentelemetry.py index 32dc4dd7f..7dc09c6d4 100644 --- a/burr/integrations/opentelemetry.py +++ b/burr/integrations/opentelemetry.py @@ -17,12 +17,14 @@ import dataclasses import datetime +import enum import importlib import importlib.metadata import json import logging import random import sys +import time from contextvars import ContextVar from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple @@ -47,8 +49,10 @@ from burr.lifecycle import ( PostApplicationExecuteCallHook, PostRunStepHook, + PostStreamGenerateHook, PreApplicationExecuteCallHook, PreRunStepHook, + PreStreamGenerateHook, ) from burr.lifecycle.base import 
DoLogAttributeHook, ExecuteMethod, PostEndSpanHook, PreStartSpanHook
 from burr.tracking import LocalTrackingClient
 
@@ -87,6 +91,43 @@ def get_cached_span(span_id: int) -> Optional[FullSpanContext]:
 tracker_context = ContextVar[Optional[SyncTrackingClient]]("tracker_context", default=None)
 
+# Tracks whether the action-level span was skipped for streaming actions
+_skipped_action_span = ContextVar[bool]("_skipped_action_span", default=False)
+
+
+# Valid streaming telemetry modes
+class StreamingTelemetryMode(enum.Enum):
+    """Controls how streaming actions are instrumented by the OpenTelemetryBridge.
+
+    - ``SINGLE_SPAN``: A single action span covers the full generator lifetime (default).
+    - ``EVENT``: No action span. A ``stream_completed`` summary span event is added to the method span.
+    - ``CHUNK_SPANS``: No action span. Per-yield child spans under the method span.
+    - ``SINGLE_AND_CHUNK_SPANS``: Action span with streaming attributes plus per-yield child spans.
+    """
+
+    SINGLE_SPAN = "single_span"
+    EVENT = "event"
+    CHUNK_SPANS = "chunk_spans"
+    SINGLE_AND_CHUNK_SPANS = "single_and_chunk_spans"
+
+
+@dataclasses.dataclass
+class _StreamingAccumulator:
+    """Accumulates timing data across stream yields for the span event summary."""
+
+    generation_time_ns: int = 0
+    consumer_time_ns: int = 0
+    iteration_count: int = 0
+    first_item_time_ns: Optional[int] = None
+    stream_start_ns: Optional[int] = None
+    last_post_generate_ns: Optional[int] = None
+    _pre_generate_ns: Optional[int] = None
+
+
+_streaming_accumulator = ContextVar[Optional[_StreamingAccumulator]](
+    "_streaming_accumulator", default=None
+)
+
 
 def _is_homogeneous_sequence(value: Sequence):
     if len(value) == 0:
@@ -149,20 +190,34 @@ class OpenTelemetryBridge(
     PreStartSpanHook,
     PostEndSpanHook,
     DoLogAttributeHook,
+    PreStreamGenerateHook,
+    PostStreamGenerateHook,
 ):
-    """Adapter to log Burr events to OpenTelemetry. 
At a high level, this works as follows: + """Lifecycle adapter that maps Burr execution events to OpenTelemetry spans and events. + + **How it works** + + The bridge implements Burr lifecycle hooks to create a span hierarchy that mirrors the + execution structure: - 1. On any of the start/pre hooks (pre_run_execute_call, pre_run_step, pre_start_span), we start a new span - 2. On any of the post ones we exit the span, accounting for the error (setting it if needed) - 3. On do_log_attributes, we log the attributes to the current span -- these are serialized using the serde module + 1. ``pre_run_execute_call`` / ``post_run_execute_call`` — creates a top-level **method span** + for the application method being called (e.g. ``step``, ``astream_result``). + 2. ``pre_run_step`` / ``post_run_step`` — creates an **action span** as a child of the + method span. For streaming actions, behavior depends on the ``streaming_telemetry`` mode. + 3. ``pre_start_span`` / ``post_end_span`` — creates **sub-action spans** for user-defined + visibility spans (via ``TracerFactory`` / ``__tracer``). + 4. ``do_log_attributes`` — sets OTel attributes on the current span. + 5. ``pre_stream_generate`` / ``post_stream_generate`` — for streaming actions, optionally + creates per-yield **chunk spans** and/or accumulates timing data for a summary event. - This works by logging to OpenTelemetry, and setting the span processor to be the right one (that knows about the tracker). + All spans are managed via a ContextVar-based token stack (``token_stack``) to correctly + handle nesting across sync and async execution. - You can use this as follows: + **Usage** .. 
code-block:: python - # replace with instructions from your prefered vendor + # replace with instructions from your preferred vendor my_vendor_library_or_tracer_provider.init() app = ( @@ -174,15 +229,44 @@ class OpenTelemetryBridge( .build() ) - app.run() # will log to OpenTelemetry + app.run() # will log to OpenTelemetry + + **Streaming telemetry modes** + + The ``streaming_telemetry`` parameter controls how streaming actions are instrumented. + Non-streaming actions are unaffected — they always produce a single action span. + + - ``StreamingTelemetryMode.SINGLE_SPAN`` (default): A single action span covers the full + generator lifetime (including consumer wait time). Streaming **attributes** are set on + the span with the generation/consumer timing breakdown: + + - ``stream.generation_time_ms`` — time spent inside the generator producing items + - ``stream.consumer_time_ms`` — time the consumer spent processing yielded items + - ``stream.iteration_count`` — number of items yielded + - ``stream.first_item_time_ms`` — time to first item (TTFT) + + - ``StreamingTelemetryMode.EVENT``: No action span. A ``stream_completed`` (or + ``stream_error``) span event is added to the **method span** with the timing summary + (including ``stream.total_time_ms`` since there is no action span to carry duration). + - ``StreamingTelemetryMode.CHUNK_SPANS``: No action span. A child span + (``{action}::chunk_{N}``) is created for each generator yield under the method span. + Each chunk span measures only generation time (excludes consumer processing time). + - ``StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS``: Combines ``SINGLE_SPAN`` and ``CHUNK_SPANS`` — the + action span (with streaming attributes) plus per-yield chunk spans as its children. 
""" - def __init__(self, tracer_name: str = None, tracer: trace.Tracer = None): + def __init__( + self, + tracer_name: str = None, + tracer: trace.Tracer = None, + streaming_telemetry: StreamingTelemetryMode = StreamingTelemetryMode.SINGLE_SPAN, + ): """Initializes an OpenTel adapter. Passes in a tracer_name or a tracer object, should only pass one. :param tracer_name: Name of the tracer if you want it to initialize for you -- not including it will use a default :param tracer: Tracer object if you want to pass it in yourself + :param streaming_telemetry: How to instrument streaming actions. See :class:`StreamingTelemetryMode`. """ if tracer_name and tracer: raise ValueError( @@ -192,6 +276,54 @@ def __init__(self, tracer_name: str = None, tracer: trace.Tracer = None): self.tracer = tracer else: self.tracer = trace.get_tracer(__name__ if tracer_name is None else tracer_name) + self.streaming_telemetry = streaming_telemetry + + @property + def _emit_chunk_spans(self) -> bool: + """Whether to create per-yield chunk spans (CHUNK_SPANS or BOTH).""" + return self.streaming_telemetry in ( + StreamingTelemetryMode.CHUNK_SPANS, + StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS, + ) + + @property + def _emit_event(self) -> bool: + """Whether to emit a summary span event on the method span (EVENT only). + + EVENT mode skips the action span entirely and attaches a ``stream_completed`` + event to the method span instead. + """ + return self.streaming_telemetry == StreamingTelemetryMode.EVENT + + @property + def _emit_attributes(self) -> bool: + """Whether to set streaming attributes on the action span (SINGLE_SPAN or BOTH). + + These modes create an action span and set generation time, consumer time, + iteration count, and TTFT as span attributes. 
+ """ + return self.streaming_telemetry in ( + StreamingTelemetryMode.SINGLE_SPAN, + StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS, + ) + + @property + def _use_accumulator(self) -> bool: + """Whether timing accumulation is needed (all modes except CHUNK_SPANS).""" + return self.streaming_telemetry != StreamingTelemetryMode.CHUNK_SPANS + + @property + def _skip_single_action_span_for_streaming(self) -> bool: + """Whether to skip the action-level span for streaming actions. + + True for EVENT and CHUNK_SPANS modes. EVENT attaches data to the method span + instead. CHUNK_SPANS replaces the action span with per-yield child spans. + In SINGLE_SPAN and BOTH modes, the action span is created normally. + """ + return self.streaming_telemetry in ( + StreamingTelemetryMode.EVENT, + StreamingTelemetryMode.CHUNK_SPANS, + ) def pre_run_execute_call( self, @@ -199,6 +331,11 @@ def pre_run_execute_call( method: ExecuteMethod, **future_kwargs: Any, ): + """Opens the top-level **method span** (e.g. ``step``, ``astream_result``). + + This is the outermost span in the Burr trace hierarchy. Action spans and chunk + spans are nested under it. + """ # TODO -- handle links -- we need to wire this through _enter_span(method.value, self.tracer) @@ -208,6 +345,11 @@ def do_log_attributes( attributes: Dict[str, Any], **future_kwargs: Any, ): + """Sets key-value attributes on the current OTel span. + + Values are serialized via :func:`convert_to_otel_attribute` to ensure they are + OTel-compatible types (str, bool, int, float, or homogeneous sequences thereof). + """ otel_span = get_current_span() if otel_span is None: logger.warning( @@ -224,7 +366,22 @@ def pre_run_step( action: "Action", **future_kwargs: Any, ): - _enter_span(action.name, self.tracer) + """Opens an **action span** for the step about to execute. + + For streaming actions in ``EVENT`` or ``CHUNK_SPANS`` mode, the action span is + skipped. 
In ``SINGLE_SPAN`` and ``SINGLE_AND_CHUNK_SPANS`` modes, the action span is created normally. + + For all modes except ``CHUNK_SPANS``, a :class:`_StreamingAccumulator` is initialized + to collect timing data across generator yields. + """ + if getattr(action, "streaming", False) and self._skip_single_action_span_for_streaming: + _skipped_action_span.set(True) + else: + _skipped_action_span.set(False) + _enter_span(action.name, self.tracer) + # Initialize accumulator for modes that need timing data + if getattr(action, "streaming", False) and self._use_accumulator: + _streaming_accumulator.set(_StreamingAccumulator()) def pre_start_span( self, @@ -232,6 +389,11 @@ def pre_start_span( span: "ActionSpan", **future_kwargs: Any, ): + """Opens a **sub-action span** for a user-defined visibility span. + + These are created by the ``TracerFactory`` (``__tracer``) context manager inside + actions, and are nested under the current action span. + """ _enter_span(span.name, self.tracer) def post_end_span( @@ -240,6 +402,7 @@ def post_end_span( span: "ActionSpan", **future_kwargs: Any, ): + """Closes a sub-action span opened by :meth:`pre_start_span`.""" # TODO -- wire through exceptions _exit_span() @@ -249,7 +412,120 @@ def post_run_step( exception: Exception, **future_kwargs: Any, ): - _exit_span(exception) + """Closes the action span and, for streaming actions, emits summary telemetry. + + Behavior depends on mode: + + - ``SINGLE_SPAN`` / ``SINGLE_AND_CHUNK_SPANS``: Sets streaming attributes on the action span, then + closes it. + - ``EVENT``: Emits a ``stream_completed`` (or ``stream_error``) span event on the + method span (the action span was skipped). Resets the skipped flag. + - ``CHUNK_SPANS``: The action span was skipped; just resets the flag. 
+ """ + acc = _streaming_accumulator.get() + if acc is not None: + first_item_ms = 0.0 + if acc.first_item_time_ns is not None and acc.stream_start_ns is not None: + first_item_ms = (acc.first_item_time_ns - acc.stream_start_ns) / 1e6 + + if self._emit_attributes: + # SINGLE_SPAN / BOTH: set attributes on the action span + otel_span = get_current_span() + if otel_span is not None: + otel_span.set_attributes( + { + "stream.generation_time_ms": acc.generation_time_ns / 1e6, + "stream.consumer_time_ms": acc.consumer_time_ns / 1e6, + "stream.iteration_count": acc.iteration_count, + "stream.first_item_time_ms": first_item_ms, + } + ) + + elif self._emit_event: + # EVENT: emit span event on the method span (action span was skipped) + otel_span = get_current_span() + if otel_span is not None: + total_time_ns = 0 + if acc.stream_start_ns is not None and acc.last_post_generate_ns is not None: + total_time_ns = acc.last_post_generate_ns - acc.stream_start_ns + event_name = "stream_error" if exception else "stream_completed" + attrs: Dict[str, Any] = { + "stream.generation_time_ms": acc.generation_time_ns / 1e6, + "stream.consumer_time_ms": acc.consumer_time_ns / 1e6, + "stream.total_time_ms": total_time_ns / 1e6, + "stream.iteration_count": acc.iteration_count, + "stream.first_item_time_ms": first_item_ms, + } + if exception: + attrs["stream.error"] = str(exception) + otel_span.add_event(event_name, attributes=attrs) + + _streaming_accumulator.set(None) + + if _skipped_action_span.get(): + _skipped_action_span.set(False) + else: + _exit_span(exception) + + def pre_stream_generate( + self, + *, + action: str, + item_index: int, + **future_kwargs: Any, + ): + """Called just before each ``__next__()`` / ``__anext__()`` on the generator. + + For modes with accumulation (``SINGLE_SPAN``, ``EVENT``, ``SINGLE_AND_CHUNK_SPANS``), records the + start of generation time and accumulates consumer time (the gap between the previous + ``post_stream_generate`` and now). 
+ + In ``CHUNK_SPANS`` or ``SINGLE_AND_CHUNK_SPANS`` mode, opens a child span named + ``{action}::chunk_{item_index}``. + """ + now_ns = time.time_ns() + acc = _streaming_accumulator.get() + if acc is not None: + if acc.stream_start_ns is None: + acc.stream_start_ns = now_ns + if acc.last_post_generate_ns is not None: + acc.consumer_time_ns += now_ns - acc.last_post_generate_ns + acc._pre_generate_ns = now_ns # stash for post + + if self._emit_chunk_spans: + _enter_span(f"{action}::chunk_{item_index}", self.tracer) + + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + exception: Optional[Exception], + **future_kwargs: Any, + ): + """Called just after each ``__next__()`` / ``__anext__()`` returns (or raises). + + For modes with accumulation (``SINGLE_SPAN``, ``EVENT``, ``SINGLE_AND_CHUNK_SPANS``), accumulates + generation time and updates the iteration count. When ``item`` is not ``None``, + the item is counted; a ``None`` item signals generator exhaustion (``StopIteration``). + + In ``CHUNK_SPANS`` or ``SINGLE_AND_CHUNK_SPANS`` mode, closes the chunk span opened by + :meth:`pre_stream_generate`, setting an error status if ``exception`` is provided. 
+ """ + now_ns = time.time_ns() + acc = _streaming_accumulator.get() + if acc is not None: + pre_ns = acc._pre_generate_ns + if pre_ns is not None: + acc.generation_time_ns += now_ns - pre_ns + if item is not None: + acc.iteration_count += 1 + if acc.first_item_time_ns is None: + acc.first_item_time_ns = now_ns + acc.last_post_generate_ns = now_ns + + if self._emit_chunk_spans: + _exit_span(exception) def post_run_execute_call( self, @@ -257,6 +533,7 @@ def post_run_execute_call( exception: Optional[Exception], **future_kwargs, ): + """Closes the top-level method span opened by :meth:`pre_run_execute_call`.""" _exit_span(exception) @@ -705,8 +982,6 @@ def init_instruments(*instruments: INSTRUMENTS, init_all: bool = False) -> None: tracker = LocalTrackingClient("otel_test") opentel_adapter = OpenTelemetryTracker(burr_tracker=tracker) - import time - from burr.core import ApplicationBuilder, Result, action, default, expr from burr.visibility import TracerFactory diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py index 4ae24073a..9cf10eff5 100644 --- a/burr/lifecycle/__init__.py +++ b/burr/lifecycle/__init__.py @@ -23,11 +23,15 @@ PostEndSpanHook, PostRunStepHook, PostRunStepHookAsync, + PostStreamGenerateHook, + PostStreamGenerateHookAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PreRunStepHook, PreRunStepHookAsync, PreStartSpanHook, + PreStreamGenerateHook, + PreStreamGenerateHookAsync, ) from burr.lifecycle.default import StateAndResultsFullLogger @@ -45,4 +49,8 @@ "PostApplicationCreateHook", "PostEndSpanHook", "PreStartSpanHook", + "PreStreamGenerateHook", + "PreStreamGenerateHookAsync", + "PostStreamGenerateHook", + "PostStreamGenerateHookAsync", ] diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index 66d8bd7e6..a2c3e3ff1 100644 --- a/burr/lifecycle/base.py +++ b/burr/lifecycle/base.py @@ -492,6 +492,98 @@ async def post_end_stream( pass +@lifecycle.base_hook("pre_stream_generate") +class 
PreStreamGenerateHook(abc.ABC): + """Hook that runs before the generator produces its next item. + Paired with PostStreamGenerateHook to bracket the actual generation time + for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("pre_stream_generate") +class PreStreamGenerateHookAsync(abc.ABC): + """Hook that runs before the generator produces its next item (async variant). + Paired with PostStreamGenerateHookAsync to bracket the actual generation time + for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + async def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("post_stream_generate") +class PostStreamGenerateHook(abc.ABC): + """Hook that runs after the generator has produced an item (or exhausted/errored). + Paired with PreStreamGenerateHook to bracket the actual generation time + for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("post_stream_generate") +class PostStreamGenerateHookAsync(abc.ABC): + """Hook that runs after the generator has produced an item (or exhausted/errored). + Async variant. Paired with PreStreamGenerateHookAsync to bracket the actual + generation time for each stream item, excluding consumer processing time. 
+ """ + + @abc.abstractmethod + async def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception], + **future_kwargs: Any, + ): + pass + + # strictly for typing -- this conflicts a bit with the lifecycle decorator above, but its fine for now # This makes IDE completion/type-hinting easier LifecycleAdapter = Union[ @@ -515,4 +607,8 @@ async def post_end_stream( PreStartStreamHookAsync, PostStreamItemHookAsync, PostEndStreamHookAsync, + PreStreamGenerateHook, + PreStreamGenerateHookAsync, + PostStreamGenerateHook, + PostStreamGenerateHookAsync, ] diff --git a/examples/opentelemetry/streaming_telemetry_modes.py b/examples/opentelemetry/streaming_telemetry_modes.py new file mode 100644 index 000000000..c50381f12 --- /dev/null +++ b/examples/opentelemetry/streaming_telemetry_modes.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Demonstrates the four StreamingTelemetryMode options for OpenTelemetryBridge. 
+ +Runs a simple async streaming action under each mode with the OTel console +exporter so you can see the spans and events printed to stdout. + +Usage: + python examples/opentelemetry/streaming_telemetry_modes.py + +No external APIs are needed — the streaming action simulates an LLM by yielding +tokens with small delays. +""" + +import asyncio + +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor + +from burr.core import ApplicationBuilder, State +from burr.core.action import streaming_action +from burr.core.graph import GraphBuilder +from burr.integrations.opentelemetry import OpenTelemetryBridge, StreamingTelemetryMode + +# --------------------------------------------------------------------------- +# A simple streaming action that simulates token-by-token LLM output +# --------------------------------------------------------------------------- + + +@streaming_action(reads=["prompt"], writes=["response"]) +async def generate_response(state: State) -> None: + """Simulates a streaming LLM response, yielding one token at a time.""" + tokens = state["prompt"].split() + buffer = [] + for token in tokens: + await asyncio.sleep(0.02) # simulate generation latency per token + buffer.append(token) + yield {"token": token}, None + + response = " ".join(buffer) + yield {"token": "", "response": response}, state.update(response=response) + + +# --------------------------------------------------------------------------- +# Build the graph (shared across all modes) +# --------------------------------------------------------------------------- + +graph = GraphBuilder().with_actions(generate=generate_response).with_transitions().build() + + +# --------------------------------------------------------------------------- +# Run one mode +# --------------------------------------------------------------------------- + + +async def run_with_mode(mode: StreamingTelemetryMode) -> None: + """Builds an 
app with the given streaming telemetry mode and runs it.""" + # Each run gets its own tracer provider so console output stays grouped + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) + tracer = provider.get_tracer("streaming-telemetry-demo") + + bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=mode) + + app = ( + ApplicationBuilder() + .with_graph(graph) + .with_entrypoint("generate") + .with_state(State({"prompt": "hello world from burr streaming"})) + .with_hooks(bridge) + .with_identifiers(app_id=f"demo-{mode.value}") + .build() + ) + + action, container = await app.astream_result(halt_after=["generate"]) + async for item in container: + await asyncio.sleep(0.05) # simulate consumer processing time per token + await container.get() + + provider.shutdown() + + +# --------------------------------------------------------------------------- +# Main — run all four modes +# --------------------------------------------------------------------------- + + +async def main(): + modes = [ + StreamingTelemetryMode.SINGLE_SPAN, + StreamingTelemetryMode.EVENT, + StreamingTelemetryMode.CHUNK_SPANS, + StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS, + ] + for mode in modes: + print(f"\n{'=' * 70}") + print(f" StreamingTelemetryMode.{mode.name}") + print(f"{'=' * 70}\n") + await run_with_mode(mode) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/core/test_application.py b/tests/core/test_application.py index c90c40676..c9facd9be 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -83,12 +83,16 @@ PostApplicationExecuteCallHook, PostApplicationExecuteCallHookAsync, PostEndStreamHook, + PostStreamGenerateHook, + PostStreamGenerateHookAsync, PostStreamItemHook, PostStreamItemHookAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PreStartStreamHook, PreStartStreamHookAsync, + PreStreamGenerateHook, + PreStreamGenerateHookAsync, 
) from burr.lifecycle.internal import LifecycleAdapterSet from burr.tracking.base import SyncTrackingClient @@ -3761,3 +3765,364 @@ def noop(state: State) -> State: app = builder.build() assert app.state["x"] == 100 + + +# ============================================================================ +# Tests for pre_stream_generate / post_stream_generate lifecycle hooks +# ============================================================================ + + +class GenerateEventCaptureTracker( + PreStartStreamHook, + PostEndStreamHook, +): + """Captures pre/post_stream_generate calls via the new hooks, plus + existing pre_start_stream/post_end_stream for ordering verification.""" + + def __init__(self): + self.calls: list[tuple[str, dict]] = [] + + def pre_start_stream( + self, *, action: str, sequence_id: int, app_id: str, partition_key, **future_kwargs + ): + self.calls.append(("pre_start_stream", {"action": action})) + + def post_end_stream( + self, *, action: str, sequence_id: int, app_id: str, partition_key, **future_kwargs + ): + self.calls.append(("post_end_stream", {"action": action})) + + +class StreamGenerateTracker(PreStreamGenerateHook, PostStreamGenerateHook): + """Sync tracker that captures pre/post_stream_generate hook calls.""" + + def __init__(self): + self.calls: list[tuple[str, int]] = [] # (hook_name, item_index) + + def pre_stream_generate(self, *, item_index: int, action: str, **future_kwargs): + self.calls.append(("pre_stream_generate", item_index)) + + def post_stream_generate( + self, *, item, item_index: int, action: str, exception, **future_kwargs + ): + self.calls.append(("post_stream_generate", item_index)) + + +class StreamGenerateTrackerAsync(PreStreamGenerateHookAsync, PostStreamGenerateHookAsync): + """Async tracker that captures pre/post_stream_generate hook calls.""" + + def __init__(self): + self.calls: list[tuple[str, int]] = [] + + async def pre_stream_generate(self, *, item_index: int, action: str, **future_kwargs): + 
self.calls.append(("pre_stream_generate", item_index)) + + async def post_stream_generate( + self, *, item, item_index: int, action: str, exception, **future_kwargs + ): + self.calls.append(("post_stream_generate", item_index)) + + +# --- Test #1: sync single-step calls pre/post_stream_generate --- + + +def test__run_single_step_streaming_action_calls_stream_generate_hooks(): + tracker = StreamGenerateTracker() + action = base_streaming_single_step_counter.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + collections.deque(generator, maxlen=0) # exhaust + # 10 intermediate yields + 1 final yield = 11 items from the action generator + pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"] + post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"] + # pre fires before each __next__ including the StopIteration attempt (11 items + 1 stop = 12) + assert len(pre_calls) == 12 + assert len(post_calls) == 12 # matched: 11 items + 1 StopIteration + + +# --- Test #2: sync multi-step calls pre/post_stream_generate --- + + +def test__run_multi_step_streaming_action_calls_stream_generate_hooks(): + tracker = StreamGenerateTracker() + action = base_streaming_counter.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + collections.deque(generator, maxlen=0) + pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"] + post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"] + # Multi-step: 11 items from generator + 1 StopIteration attempt = 12 pre, 12 post + assert len(pre_calls) == 12 + assert len(post_calls) 
== 12 + + +# --- Test #3: async single-step calls pre/post_stream_generate --- + + +async def test__arun_single_step_streaming_action_calls_stream_generate_hooks(): + tracker = StreamGenerateTrackerAsync() + action = base_streaming_single_step_counter_async.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + async for _ in generator: + pass + pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"] + post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"] + assert len(pre_calls) == 12 + assert len(post_calls) == 12 + + +# --- Test #4: async multi-step calls pre/post_stream_generate --- + + +async def test__arun_multi_step_streaming_action_calls_stream_generate_hooks(): + tracker = StreamGenerateTrackerAsync() + action = base_streaming_counter_async.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + async for _ in generator: + pass + pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"] + post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"] + assert len(pre_calls) == 12 + assert len(post_calls) == 12 + + +# --- Test #5: hook ordering (pre always before corresponding post) --- + + +def test__run_single_step_streaming_action_stream_generate_hook_ordering(): + tracker = StreamGenerateTracker() + action = base_streaming_single_step_counter.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + 
lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + collections.deque(generator, maxlen=0) + # Verify strict interleaving: pre(0), post(0), pre(1), post(1), ... + for i in range(0, len(tracker.calls), 2): + assert tracker.calls[i] == ("pre_stream_generate", i // 2) + assert tracker.calls[i + 1] == ("post_stream_generate", i // 2) + + +async def test__arun_multi_step_streaming_action_stream_generate_hook_ordering(): + tracker = StreamGenerateTrackerAsync() + action = base_streaming_counter_async.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + async for _ in generator: + pass + for i in range(0, len(tracker.calls), 2): + assert tracker.calls[i] == ("pre_stream_generate", i // 2) + assert tracker.calls[i + 1] == ("post_stream_generate", i // 2) + + +# --- Test #7: error mid-stream --- + + +class ErrorAfterNSingleStep(SingleStepStreamingAction): + """Action that raises after n intermediate yields.""" + + def __init__(self, n: int): + super().__init__() + self.n = n + + def stream_run_and_update(self, state, **run_kwargs): + for i in range(self.n): + yield {"i": i}, None + raise RuntimeError("boom") + + @property + def reads(self): + return [] + + @property + def writes(self): + return [] + + +def test__run_single_step_streaming_action_stream_generate_on_error(): + tracker = StreamGenerateTracker() + action = ErrorAfterNSingleStep(3).with_name("errorer") + state = State({}) + generator = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + with pytest.raises(RuntimeError, match="boom"): + collections.deque(generator, maxlen=0) + pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"] + post_calls = [c for c 
in tracker.calls if c[0] == "post_stream_generate"] + # 3 successful yields + 1 that raises = 4 pre, 4 post (error post has exception) + assert len(pre_calls) == 4 + assert len(post_calls) == 4 + + +# --- Test #8: existing post_stream_item callbacks unchanged --- + + +def test__run_single_step_streaming_action_existing_callbacks_unchanged_with_generate_hooks(): + class TrackingCallback(PostStreamItemHook): + def __init__(self): + self.items = [] + + def post_stream_item(self, item, **future_kwargs): + self.items.append(item) + + hook = TrackingCallback() + gen_tracker = StreamGenerateTracker() + action = base_streaming_single_step_counter.with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + lifecycle_adapters=LifecycleAdapterSet(hook, gen_tracker), + ) + collections.deque(generator, maxlen=0) + # post_stream_item still fires exactly 10 times (only for intermediate items) + assert len(hook.items) == 10 + # But pre/post_stream_generate fire for all items + StopIteration + pre_calls = [c for c in gen_tracker.calls if c[0] == "pre_stream_generate"] + assert len(pre_calls) == 12 + + +# --- Test #18: single yield --- + + +class SingleYieldAction(SingleStepStreamingAction): + def stream_run_and_update(self, state, **run_kwargs): + yield {"val": 1}, None + yield {"val": 2}, state + + @property + def reads(self): + return [] + + @property + def writes(self): + return [] + + +def test__stream_generate_hooks_single_intermediate_yield(): + tracker = StreamGenerateTracker() + action = SingleYieldAction().with_name("single") + state = State({}) + gen = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + collections.deque(gen, maxlen=0) + pre_calls = [c for c in tracker.calls if c[0] == 
"pre_stream_generate"] + post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"] + # 2 items + 1 StopIteration = 3 pre/post pairs + assert len(pre_calls) == 3 + assert len(post_calls) == 3 + + +# --- Test #19: zero intermediate yields --- + + +class NoIntermediateYieldAction(SingleStepStreamingAction): + def stream_run_and_update(self, state, **run_kwargs): + yield {"val": 1}, state + + @property + def reads(self): + return [] + + @property + def writes(self): + return [] + + +def test__stream_generate_hooks_zero_intermediate_yields(): + tracker = StreamGenerateTracker() + action = NoIntermediateYieldAction().with_name("noint") + state = State({}) + gen = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="pk", + app_id="app", + lifecycle_adapters=LifecycleAdapterSet(tracker), + ) + collections.deque(gen, maxlen=0) + pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"] + post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"] + # 1 item (final) + 1 StopIteration = 2 pre/post pairs + assert len(pre_calls) == 2 + assert len(post_calls) == 2 + + +# --- Test #20: non-streaming action doesn't fire stream generate hooks --- + + +def test__non_streaming_action_does_not_fire_stream_generate_hooks(): + tracker = StreamGenerateTracker() + action = base_single_step_counter.with_name("counter") + state = State({"count": 0, "tracker": []}) + _run_single_step_action(action, state, inputs={}) + # No stream generate hooks should have been called + assert len(tracker.calls) == 0 diff --git a/tests/integrations/test_opentelemetry.py b/tests/integrations/test_opentelemetry.py index b86062764..f9644f476 100644 --- a/tests/integrations/test_opentelemetry.py +++ b/tests/integrations/test_opentelemetry.py @@ -15,10 +15,1244 @@ # specific language governing permissions and limitations # under the License. 
+import asyncio +import datetime +import threading +import time import typing +from typing import Sequence +from unittest.mock import MagicMock + +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult import burr.integrations.opentelemetry as burr_otel +from burr.core.action import SingleStepAction, SingleStepStreamingAction +from burr.core.application import Application, _arun_single_step_streaming_action +from burr.core.graph import Graph +from burr.core.state import State +from burr.integrations.opentelemetry import OpenTelemetryBridge +from burr.integrations.opentelemetry import StreamingTelemetryMode as STM +from burr.integrations.opentelemetry import ( + _exit_span, + _skipped_action_span, + _streaming_accumulator, + token_stack, +) +from burr.lifecycle.internal import LifecycleAdapterSet + +# ============================================================================ +# Simple in-memory exporter (not available in all otel SDK versions) +# ============================================================================ + + +class _InMemorySpanExporter(SpanExporter): + """Collects finished spans in memory for test assertions.""" + + def __init__(self): + self._spans = [] + self._lock = threading.Lock() + + def export(self, spans: Sequence) -> SpanExportResult: + with self._lock: + self._spans.extend(spans) + return SpanExportResult.SUCCESS + + def shutdown(self): + pass + + def get_finished_spans(self): + with self._lock: + return list(self._spans) + + def clear(self): + with self._lock: + self._spans.clear() def test_instrument_specs_match_instruments_literal(): assert set(typing.get_args(burr_otel.INSTRUMENTS)) == set(burr_otel.INSTRUMENTS_SPECS.keys()) + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def 
_make_bridge_and_exporter(streaming_telemetry: STM = STM.SINGLE_SPAN): + """Creates an OpenTelemetryBridge with an in-memory exporter for test assertions.""" + exporter = _InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=streaming_telemetry) + return bridge, exporter + + +def _make_mock_action(name: str, streaming: bool = False): + """Creates a mock Action object with the given name and streaming flag.""" + action = MagicMock() + action.name = name + action.streaming = streaming + return action + + +def _reset_token_stack(): + """Reset the token_stack and streaming ContextVars to clean state.""" + token_stack.set(None) + _skipped_action_span.set(False) + _streaming_accumulator.set(None) + + +# ============================================================================ +# Test #9: pre_stream_generate enters a span +# ============================================================================ + + +def test_bridge_pre_stream_generate_enters_span(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + bridge.pre_stream_generate( + action="my_action", + item_index=0, + stream_initialize_time=datetime.datetime.now(), + sequence_id=0, + app_id="app", + partition_key="pk", + ) + + stack = token_stack.get() + assert stack is not None + assert len(stack) == 1 + _, span = stack[0] + assert span.name == "my_action::chunk_0" + + # Clean up + _exit_span() + _reset_token_stack() + + +# ============================================================================ +# Test #10: post_stream_generate exits a span +# ============================================================================ + + +def test_bridge_post_stream_generate_exits_span(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + 
bridge.pre_stream_generate( + action="my_action", + item_index=0, + stream_initialize_time=datetime.datetime.now(), + sequence_id=0, + app_id="app", + partition_key="pk", + ) + assert len(token_stack.get()) == 1 + + bridge.post_stream_generate( + item={"chunk": "data"}, + item_index=0, + stream_initialize_time=datetime.datetime.now(), + action="my_action", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + stack = token_stack.get() + assert len(stack) == 0 + + spans = exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "my_action::chunk_0" + assert spans[0].status.status_code == trace.StatusCode.OK + + _reset_token_stack() + + +# ============================================================================ +# Test #12: span naming for multiple chunks +# ============================================================================ + + +def test_bridge_stream_span_naming(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + now = datetime.datetime.now() + for i in range(3): + bridge.pre_stream_generate( + action="my_action", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_action", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + spans = exporter.get_finished_spans() + assert [s.name for s in spans] == [ + "my_action::chunk_0", + "my_action::chunk_1", + "my_action::chunk_2", + ] + _reset_token_stack() + + +# ============================================================================ +# Test #14: span closed on generator error +# ============================================================================ + + +def test_bridge_stream_span_closed_on_generator_error(): + _reset_token_stack() + bridge, exporter = 
_make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + now = datetime.datetime.now() + exc = RuntimeError("generator failed") + + bridge.pre_stream_generate( + action="my_action", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item=None, + item_index=0, + stream_initialize_time=now, + action="my_action", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=exc, + ) + + spans = exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == trace.StatusCode.ERROR + assert "generator failed" in spans[0].status.description + + stack = token_stack.get() + assert len(stack) == 0 + _reset_token_stack() + + +# ============================================================================ +# Test #21: pre_run_step skips span for streaming action +# ============================================================================ + + +def test_bridge_pre_run_step_skips_span_for_streaming_action(): + _reset_token_stack() + bridge, _ = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + action = _make_mock_action("stream_action", streaming=True) + + bridge.pre_run_step(action=action) + + assert _skipped_action_span.get() is True + stack = token_stack.get() + assert stack is None or len(stack) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Test #22: pre_run_step creates span for non-streaming action +# ============================================================================ + + +def test_bridge_pre_run_step_creates_span_for_non_streaming_action(): + _reset_token_stack() + bridge, _ = _make_bridge_and_exporter() + action = _make_mock_action("normal_action", streaming=False) + + bridge.pre_run_step(action=action) + + assert _skipped_action_span.get() is False + stack = token_stack.get() + assert stack is not None + assert len(stack) == 1 + _, span = 
stack[0] + assert span.name == "normal_action" + + # Clean up + _exit_span() + _reset_token_stack() + + +# ============================================================================ +# Test #23: post_run_step skips exit when action span was skipped +# ============================================================================ + + +def test_bridge_post_run_step_skips_exit_when_action_span_was_skipped(): + _reset_token_stack() + bridge, _ = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + # Simulate streaming action: pre_run_step skipped the span + _skipped_action_span.set(True) + + # post_run_step should not pop anything (nothing was pushed) + bridge.post_run_step(exception=None) + + assert _skipped_action_span.get() is False + stack = token_stack.get() + assert stack is None or len(stack) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Test #24: post_run_step exits span for non-streaming action +# ============================================================================ + + +def test_bridge_post_run_step_exits_span_for_non_streaming_action(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter() + action = _make_mock_action("normal_action", streaming=False) + + bridge.pre_run_step(action=action) + bridge.post_run_step(exception=None) + + stack = token_stack.get() + assert len(stack) == 0 + + spans = exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "normal_action" + assert spans[0].status.status_code == trace.StatusCode.OK + + _reset_token_stack() + + +# ============================================================================ +# Test #25: streaming hierarchy has no action span — chunks are children of method span +# ============================================================================ + + +def test_bridge_streaming_span_hierarchy_no_action_span(): + _reset_token_stack() + bridge, exporter = 
_make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + # Simulate full streaming hook sequence + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have 4 spans: 3 chunks + 1 method. No "my_stream" action span. + assert "my_stream" not in span_names + assert "my_stream::chunk_0" in span_names + assert "my_stream::chunk_1" in span_names + assert "my_stream::chunk_2" in span_names + assert "stream_result" in span_names + assert len(spans) == 4 + + # Chunk spans should be children of the stream_result method span + method_span = next(s for s in spans if s.name == "stream_result") + for s in spans: + if s.name.startswith("my_stream::chunk_"): + assert s.parent is not None + assert s.parent.span_id == method_span.context.span_id + + _reset_token_stack() + + +# ============================================================================ +# Test #26: non-streaming then streaming — no state leak +# ============================================================================ + + +def test_bridge_non_streaming_then_streaming_no_state_leak(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + now = datetime.datetime.now() + + # First: non-streaming action + normal_action = _make_mock_action("normal", 
streaming=False) + bridge.pre_run_step(action=normal_action) + bridge.post_run_step(exception=None) + assert _skipped_action_span.get() is False + + # Second: streaming action + stream_action = _make_mock_action("streamer", streaming=True) + bridge.pre_run_step(action=stream_action) + assert _skipped_action_span.get() is True + + bridge.pre_stream_generate( + action="streamer", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"x": 1}, + item_index=0, + stream_initialize_time=now, + action="streamer", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + bridge.post_run_step(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + # Should have: "normal" action span + "streamer::chunk_0" chunk span + assert "normal" in span_names + assert "streamer::chunk_0" in span_names + assert "streamer" not in span_names # no action-level span for streaming + + _reset_token_stack() + + +# ============================================================================ +# Test #27: streaming then non-streaming — no state leak +# ============================================================================ + + +def test_bridge_streaming_then_non_streaming_no_state_leak(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + now = datetime.datetime.now() + + # First: streaming action + stream_action = _make_mock_action("streamer", streaming=True) + bridge.pre_run_step(action=stream_action) + bridge.pre_stream_generate( + action="streamer", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"x": 1}, + item_index=0, + stream_initialize_time=now, + action="streamer", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + bridge.post_run_step(exception=None) + 
+ # Second: non-streaming action + normal_action = _make_mock_action("normal", streaming=False) + bridge.pre_run_step(action=normal_action) + assert _skipped_action_span.get() is False + stack = token_stack.get() + assert len(stack) == 1 + _, span = stack[0] + assert span.name == "normal" + + bridge.post_run_step(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + assert "streamer::chunk_0" in span_names + assert "normal" in span_names + assert "streamer" not in span_names + + _reset_token_stack() + + +# ============================================================================ +# Test #11 (updated): child spans under action span for non-streaming +# ============================================================================ + + +def test_bridge_non_streaming_creates_child_spans_under_action_span(): + """For non-streaming actions, pre/post_start_span creates children of the action span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter() + action = _make_mock_action("my_action", streaming=False) + + bridge.pre_run_step(action=action) + # Simulate a nested span (e.g., from TracerFactory) + mock_span = MagicMock() + mock_span.name = "inner_op" + bridge.pre_start_span(span=mock_span) + bridge.post_end_span(span=mock_span) + bridge.post_run_step(exception=None) + + spans = exporter.get_finished_spans() + assert len(spans) == 2 + action_span = next(s for s in spans if s.name == "my_action") + inner_span = next(s for s in spans if s.name == "inner_op") + assert inner_span.parent is not None + assert inner_span.parent.span_id == action_span.context.span_id + + _reset_token_stack() + + +# ============================================================================ +# Test #13: span timing excludes consumer time +# ============================================================================ + + +async def test_bridge_stream_span_timing_excludes_consumer_time(): + """Verify that chunk spans measure 
generation time, not consumer processing time.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + class SlowGeneratorAction(SingleStepStreamingAction): + """Each yield takes ~50ms of 'generation time'.""" + + async def stream_run_and_update(self, state, **run_kwargs): + for i in range(3): + await asyncio.sleep(0.05) # simulate generation time + yield {"i": i}, None + await asyncio.sleep(0.05) + yield {"i": 3}, state + + @property + def reads(self): + return [] + + @property + def writes(self): + return [] + + action = SlowGeneratorAction().with_name("slow_gen") + state = State({}) + + generator = _arun_single_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(bridge), + ) + + # Consumer adds significant delay + async for item, state_update in generator: + await asyncio.sleep(0.3) # simulate slow consumer + + spans = exporter.get_finished_spans() + chunk_spans = [s for s in spans if "chunk_" in s.name] + assert len(chunk_spans) >= 3 # at least 3 intermediate + final + stop + + for span in chunk_spans: + duration_ns = span.end_time - span.start_time + duration_ms = duration_ns / 1e6 + # Each chunk should take roughly 50ms of generation time, + # NOT 350ms (50ms generation + 300ms consumer). + # Use generous tolerance to avoid flakiness. + assert duration_ms < 200, ( + f"Span {span.name} took {duration_ms:.0f}ms, expected <200ms. " + f"Consumer time is leaking into the span." 
+ ) + + _reset_token_stack() + + +# ============================================================================ +# Test #15: full integration — astream_result produces per-yield spans +# ============================================================================ + + +async def test_astream_result_with_otel_bridge_produces_per_yield_spans(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + + class SimpleStreamer(SingleStepStreamingAction): + async def stream_run_and_update(self, state, **run_kwargs): + for i in range(5): + yield {"i": i}, None + yield {"i": 5}, state.update(done=True) + + @property + def reads(self): + return [] + + @property + def writes(self): + return ["done"] + + streamer = SimpleStreamer().with_name("streamer") + app = Application( + state=State({"done": False}), + entrypoint="streamer", + adapter_set=LifecycleAdapterSet(bridge), + partition_key="test", + uid="test-app", + graph=Graph( + actions=[streamer], + transitions=[], + ), + ) + + action, container = await app.astream_result(halt_after=["streamer"]) + _ = [item async for item in container] + result, state = await container.get() + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have: stream_result method span + chunk spans (no action span) + assert "stream_result" in span_names + assert "streamer" not in span_names # no action-level span + chunk_names = [n for n in span_names if n.startswith("streamer::chunk_")] + # 5 intermediate + 1 final + 1 StopIteration = 7 chunk spans + assert len(chunk_names) >= 5 + + # All chunk spans are children of stream_result + method_span = next(s for s in spans if s.name == "stream_result") + for s in spans: + if s.name.startswith("streamer::chunk_"): + assert s.parent is not None + assert s.parent.span_id == method_span.context.span_id + + _reset_token_stack() + + +# ============================================================================ +# Test 
#17: non-streaming action via astream_result still gets action span +# ============================================================================ + + +async def test_astream_result_with_otel_bridge_non_streaming_action(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter() + + class SimpleAction(SingleStepAction): + def run_and_update(self, state, **run_kwargs): + return {"val": 1}, state.update(val=1) + + @property + def reads(self): + return [] + + @property + def writes(self): + return ["val"] + + action_obj = SimpleAction().with_name("simple") + app = Application( + state=State({"val": 0}), + entrypoint="simple", + adapter_set=LifecycleAdapterSet(bridge), + partition_key="test", + uid="test-app", + graph=Graph( + actions=[action_obj], + transitions=[], + ), + ) + + action, container = await app.astream_result(halt_after=["simple"]) + result, state = await container.get() + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Non-streaming should have an action span + assert "simple" in span_names + # No chunk spans + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Mode: "single_span" — backwards compatible (action span, no chunks, no events) +# ============================================================================ + + +def test_bridge_single_span_mode_creates_action_span_for_streaming(): + """In 'single_span' mode, streaming actions get a normal action span, no chunk spans.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.SINGLE_SPAN) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # The action span should have been created + assert 
_skipped_action_span.get() is False + stack = token_stack.get() + assert len(stack) == 2 # method span + action span + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have action span + method span, no chunk spans + assert "my_stream" in span_names + assert "stream_result" in span_names + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + # No span events on the action span + action_span = next(s for s in spans if s.name == "my_stream") + assert len(action_span.events) == 0 + + _reset_token_stack() + + +def test_bridge_single_span_mode_sets_attributes_on_action_span(): + """In 'single_span' mode, streaming attributes are set on the action span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.SINGLE_SPAN) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Accumulator should be initialized + assert _streaming_accumulator.get() is not None + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + 
bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + action_span = next(s for s in spans if s.name == "my_stream") + + # Should have attributes, not events + assert len(action_span.events) == 0 + attrs = dict(action_span.attributes) + assert attrs["stream.iteration_count"] == 3 + assert "stream.generation_time_ms" in attrs + assert "stream.consumer_time_ms" in attrs + assert "stream.first_item_time_ms" in attrs + + # No chunk spans + chunk_names = [s.name for s in spans if "chunk_" in s.name] + assert len(chunk_names) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Mode: "event" — action span + summary event, no chunk spans +# ============================================================================ + + +def test_bridge_event_mode_emits_stream_completed_event(): + """In 'event' mode, no action span. A stream_completed event is added to the method span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Action span should be skipped + assert _skipped_action_span.get() is True + # Accumulator should be initialized + assert _streaming_accumulator.get() is not None + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + # Signal end of stream (StopIteration case) + bridge.pre_stream_generate( + action="my_stream", + item_index=3, + 
stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item=None, + item_index=3, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have only method span, no action span, no chunk spans + assert "stream_result" in span_names + assert "my_stream" not in span_names + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + # Method span should have a stream_completed event + method_span = next(s for s in spans if s.name == "stream_result") + assert len(method_span.events) == 1 + event = method_span.events[0] + assert event.name == "stream_completed" + attrs = dict(event.attributes) + assert "stream.generation_time_ms" in attrs + assert "stream.consumer_time_ms" in attrs + assert "stream.total_time_ms" in attrs + assert attrs["stream.iteration_count"] == 3 + assert "stream.first_item_time_ms" in attrs + + # Accumulator should be cleaned up + assert _streaming_accumulator.get() is None + + _reset_token_stack() + + +def test_bridge_event_mode_emits_stream_error_event(): + """In 'event' mode with an exception, a stream_error event is emitted on method span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # One successful yield + bridge.pre_stream_generate( + action="my_stream", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": 0}, + item_index=0, + 
stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + # Error on second yield + exc = RuntimeError("stream failed") + bridge.pre_stream_generate( + action="my_stream", + item_index=1, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item=None, + item_index=1, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=exc, + ) + + bridge.post_run_step(exception=exc) + bridge.post_run_execute_call(exception=exc) + + spans = exporter.get_finished_spans() + # No action span — event is on the method span + method_span = next(s for s in spans if s.name == "stream_result") + + assert len(method_span.events) == 1 + event = method_span.events[0] + assert event.name == "stream_error" + attrs = dict(event.attributes) + assert attrs["stream.iteration_count"] == 1 # only 1 successful yield + assert "stream.error" in attrs + assert "stream failed" in attrs["stream.error"] + + _reset_token_stack() + + +def test_bridge_event_mode_accumulator_timing_values(): + """Verify that the accumulator separates generation time from consumer time.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Simulate 2 yields with measurable time gaps + bridge.pre_stream_generate( + action="my_stream", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + time.sleep(0.05) # ~50ms generation time + bridge.post_stream_generate( + item={"i": 0}, + item_index=0, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + 
time.sleep(0.1) # ~100ms consumer time + + bridge.pre_stream_generate( + action="my_stream", + item_index=1, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + time.sleep(0.05) # ~50ms generation time + bridge.post_stream_generate( + item={"i": 1}, + item_index=1, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + # Event is on the method span (no action span in EVENT mode) + method_span = next(s for s in spans if s.name == "stream_result") + event = method_span.events[0] + attrs = dict(event.attributes) + + gen_ms = attrs["stream.generation_time_ms"] + consumer_ms = attrs["stream.consumer_time_ms"] + total_ms = attrs["stream.total_time_ms"] + first_item_ms = attrs["stream.first_item_time_ms"] + + # Generation: ~100ms total (2 × 50ms) + assert gen_ms >= 50, f"generation_time_ms={gen_ms}, expected >= 50" + assert gen_ms < 300, f"generation_time_ms={gen_ms}, expected < 300" + + # Consumer: ~100ms (gap between first post and second pre) + assert consumer_ms >= 50, f"consumer_time_ms={consumer_ms}, expected >= 50" + assert consumer_ms < 300, f"consumer_time_ms={consumer_ms}, expected < 300" + + # Total should be >= generation + consumer + assert total_ms >= gen_ms, f"total_time_ms={total_ms} < generation_time_ms={gen_ms}" + + # First item time should be close to first generation time (~50ms) + assert first_item_ms >= 20, f"first_item_time_ms={first_item_ms}, expected >= 20" + assert first_item_ms < 200, f"first_item_time_ms={first_item_ms}, expected < 200" + + assert attrs["stream.iteration_count"] == 2 + + _reset_token_stack() + + +# ============================================================================ +# Mode: SINGLE_AND_CHUNK_SPANS — action span with attributes + per-yield child spans +# 
============================================================================ + + +def test_bridge_single_and_chunk_spans_mode(): + """In SINGLE_AND_CHUNK_SPANS mode, action span has streaming attributes AND per-yield child spans exist.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.SINGLE_AND_CHUNK_SPANS) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Action span should be created (not skipped) + assert _skipped_action_span.get() is False + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have: method span + action span + 3 chunk spans + assert "stream_result" in span_names + assert "my_stream" in span_names + assert "my_stream::chunk_0" in span_names + assert "my_stream::chunk_1" in span_names + assert "my_stream::chunk_2" in span_names + assert len(spans) == 5 + + # Action span should have streaming attributes (not events) + action_span = next(s for s in spans if s.name == "my_stream") + assert len(action_span.events) == 0 + attrs = dict(action_span.attributes) + assert attrs["stream.iteration_count"] == 3 + assert "stream.generation_time_ms" in attrs + assert "stream.consumer_time_ms" in attrs + assert "stream.first_item_time_ms" in attrs + + # Chunk spans should be children of the action span + for s in spans: + if s.name.startswith("my_stream::chunk_"): + 
assert s.parent is not None + assert s.parent.span_id == action_span.context.span_id + + _reset_token_stack() + + +# ============================================================================ +# Mode: "chunk_spans" — per-yield spans, no action span (already covered above, +# this test verifies no event is emitted) +# ============================================================================ + + +def test_bridge_chunk_spans_mode_no_event_emitted(): + """In 'chunk_spans' mode, no span event is emitted (no accumulator).""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # No accumulator in spans-only mode + assert _streaming_accumulator.get() is None + + bridge.pre_stream_generate( + action="my_stream", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": 0}, + item_index=0, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + + # Only chunk span + method span, no action span + span_names = [s.name for s in spans] + assert "my_stream::chunk_0" in span_names + assert "stream_result" in span_names + assert "my_stream" not in span_names + + # No events on any span + for s in spans: + assert len(s.events) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Integration: "event" mode with astream_result +# ============================================================================ + + +async def 
test_astream_result_event_mode_produces_summary_event(): + """Full integration test: event mode with astream_result produces summary event.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + + class SimpleStreamer(SingleStepStreamingAction): + async def stream_run_and_update(self, state, **run_kwargs): + for i in range(5): + yield {"i": i}, None + yield {"i": 5}, state.update(done=True) + + @property + def reads(self): + return [] + + @property + def writes(self): + return ["done"] + + streamer = SimpleStreamer().with_name("streamer") + app = Application( + state=State({"done": False}), + entrypoint="streamer", + adapter_set=LifecycleAdapterSet(bridge), + partition_key="test", + uid="test-app", + graph=Graph( + actions=[streamer], + transitions=[], + ), + ) + + action, container = await app.astream_result(halt_after=["streamer"]) + _ = [item async for item in container] + result, state = await container.get() + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have only method span, no action span, no chunk spans + assert "stream_result" in span_names + assert "streamer" not in span_names + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + # Method span should have stream_completed event + method_span = next(s for s in spans if s.name == "stream_result") + assert len(method_span.events) == 1 + event = method_span.events[0] + assert event.name == "stream_completed" + attrs = dict(event.attributes) + # 5 intermediate + 1 final = 6 yielded items + assert attrs["stream.iteration_count"] >= 5 + assert attrs["stream.generation_time_ms"] >= 0 + assert attrs["stream.consumer_time_ms"] >= 0 + assert attrs["stream.total_time_ms"] >= 0 + assert attrs["stream.first_item_time_ms"] >= 0 + + _reset_token_stack() From 7f34286be6cfb473a671d9f77c2420ba3764597d Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 28 Feb 2026 22:35:23 -0800 
Subject: [PATCH 2/3] Add streaming timing to Burr tracker, UI, and documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Surface generation-vs-consumer timing data in the Burr tracker and UI, independent of the OpenTelemetry streaming telemetry modes. Tracker changes: - Extend StreamState with timing fields (generation_time_ns, consumer_time_ns, first_item_time_ns, etc.) accumulated via PreStreamGenerateHook/PostStreamGenerateHook - Add generate hooks to SyncTrackingClient in both base.py and client.py with defensive getattr for subclass compatibility - Add optional timing fields to EndStreamModel (generation_time_ms, consumer_time_ms, first_item_time_ms) with None defaults for backwards compatibility - Update post_end_stream in both LocalTrackingClient and S3TrackingClient to convert and write accumulated timing UI changes: - Add timing fields to EndStreamModel.ts TypeScript type - Update StepList.tsx end_stream rendering to show "gen: Xms · consumer: Yms · N items · TTFT: Zms" when timing data is available, falling back to legacy throughput display Documentation: - Add "Streaming Telemetry Modes" section to additional-visibility.rst - Add StreamingTelemetryMode to opentelemetry.rst API reference - Add 4 new generate hook classes to lifecycle.rst - Add "Telemetry & Observability" section to streaming-actions.rst - Add "Streaming Timing" section to tracking.rst - Update monitoring.rst and examples/opentelemetry/README.md - Update design doc to reflect final implementation Example: - Add --tracker flag to streaming_telemetry_modes.py for validating timing data in the Burr UI Tests: - 6 new tests covering StreamState defaults, timing accumulation, defensive noop, EndStreamModel backwards compat, and sync/async end-to-end with LocalTrackingClient --- burr/tracking/base.py | 105 ++++++- burr/tracking/client.py | 136 +++++++++ burr/tracking/common/models.py | 17 +- burr/tracking/s3client.py | 14 + 
docs/concepts/additional-visibility.rst | 42 +++ docs/concepts/streaming-actions.rst | 18 ++ docs/examples/deployment/monitoring.rst | 5 +- docs/reference/integrations/opentelemetry.rst | 4 + docs/reference/lifecycle.rst | 12 + docs/reference/tracking.rst | 15 + examples/opentelemetry/README.md | 15 +- .../streaming_telemetry_modes.py | 38 ++- telemetry/ui/src/api/models/EndStreamModel.ts | 11 +- .../ui/src/components/routes/app/StepList.tsx | 24 +- tests/tracking/test_local_tracking_client.py | 272 +++++++++++++++++- 15 files changed, 710 insertions(+), 18 deletions(-) diff --git a/burr/tracking/base.py b/burr/tracking/base.py index d8b3f54f3..ab26784df 100644 --- a/burr/tracking/base.py +++ b/burr/tracking/base.py @@ -16,6 +16,9 @@ # under the License. import abc +import datetime +import time +from typing import Any, Optional from burr.lifecycle import ( PostApplicationCreateHook, @@ -27,8 +30,10 @@ from burr.lifecycle.base import ( DoLogAttributeHook, PostEndStreamHook, + PostStreamGenerateHook, PostStreamItemHook, PreStartStreamHook, + PreStreamGenerateHook, ) @@ -42,10 +47,106 @@ class SyncTrackingClient( PreStartStreamHook, PostStreamItemHook, PostEndStreamHook, + PreStreamGenerateHook, + PostStreamGenerateHook, abc.ABC, ): - """Base class for synchronous tracking clients. All tracking clients must implement from this - TODO -- create an async tracking client""" + """Base class for synchronous tracking clients. + + Inherits from PreStreamGenerateHook/PostStreamGenerateHook so that all + tracker implementations automatically accumulate generation-vs-consumer + timing for streaming actions. The accumulated data is written to the + EndStreamModel in post_end_stream. + + Subclasses do NOT need to override pre_stream_generate/post_stream_generate + unless they want custom behavior — the default implementations here handle + timing accumulation using the StreamState dataclass. 
+ + TODO -- create an async tracking client + """ + + def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + """Records the start of a single generator __next__() call. + + Uses defensive getattr to access stream_state so that custom subclasses + that don't call super().__init__() or don't have stream_state won't crash. + """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state._pre_generate_ns = now_ns + + # Record the stream start time on the first yield + if state.stream_start_ns is None: + state.stream_start_ns = now_ns + + # Consumer time = gap between previous post_stream_generate and this + # pre_stream_generate. On the first call there's no previous post, so + # consumer_time stays at 0. + if state.last_post_generate_ns is not None: + state.consumer_time_ns += now_ns - state.last_post_generate_ns + + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception] = None, + **future_kwargs: Any, + ): + """Records the end of a single generator __next__() call. + + Accumulates generation_time_ns from the paired pre_stream_generate call, + tracks iteration_count, and captures first_item_time_ns for TTFT. + + Uses defensive getattr to access stream_state so that custom subclasses + that don't call super().__init__() or don't have stream_state won't crash. 
+ """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state.last_post_generate_ns = now_ns + + # Accumulate generation time (time spent inside the generator) + if state._pre_generate_ns is not None: + state.generation_time_ns += now_ns - state._pre_generate_ns + state._pre_generate_ns = None + + # Track iteration count (only for actual items, not StopIteration) + if item is not None: + state.iteration_count += 1 + + # Capture TTFT (time from stream start to first item) + if state.first_item_time_ns is None and item is not None: + if state.stream_start_ns is not None: + state.first_item_time_ns = now_ns - state.stream_start_ns @abc.abstractmethod def copy(self): diff --git a/burr/tracking/client.py b/burr/tracking/client.py index 44919aed5..1ba49af35 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -18,13 +18,16 @@ import abc import dataclasses import datetime +import time from burr.common.types import BaseCopyable from burr.lifecycle.base import ( DoLogAttributeHook, PostEndStreamHook, + PostStreamGenerateHook, PostStreamItemHook, PreStartStreamHook, + PreStreamGenerateHook, ) # this is a quick hack to get it to work on windows @@ -120,9 +123,38 @@ def _allowed_project_name(project_name: str, on_windows: bool) -> bool: @dataclasses.dataclass class StreamState: + """Tracks state for an in-progress stream. + + The timing fields (generation_time_ns, consumer_time_ns, etc.) are populated + by the PreStreamGenerateHook/PostStreamGenerateHook implementations on the + tracker. They accumulate generation vs. consumer timing across all yields, + enabling the tracker to write a timing summary when the stream ends. + + These fields default to 0/None so that existing code that only uses + stream_init_time/count continues to work unchanged. 
+ """ + stream_init_time: datetime.datetime count: Optional[int] + # --- Streaming timing fields (populated by pre/post_stream_generate) --- + # Accumulated wall-clock nanoseconds the generator spent producing items. + generation_time_ns: int = 0 + # Accumulated wall-clock nanoseconds the consumer spent processing items. + consumer_time_ns: int = 0 + # Total number of items the generator has yielded so far. + iteration_count: int = 0 + # Nanosecond timestamp of the first item produced (for TTFT calculation). + first_item_time_ns: Optional[int] = None + # Nanosecond timestamp when the stream started (first pre_stream_generate). + stream_start_ns: Optional[int] = None + # Nanosecond timestamp of the most recent post_stream_generate call, + # used to compute consumer_time between yields. + last_post_generate_ns: Optional[int] = None + # Nanosecond timestamp captured at the start of the current generation + # (set in pre_stream_generate, consumed in post_stream_generate). + _pre_generate_ns: Optional[int] = None + StateKey = Tuple[str, str, Optional[str]] @@ -137,9 +169,98 @@ class SyncTrackingClient( PreStartStreamHook, PostStreamItemHook, PostEndStreamHook, + PreStreamGenerateHook, + PostStreamGenerateHook, BaseCopyable, ABC, ): + """Synchronous tracking client base class (client.py variant). + + Includes PreStreamGenerateHook/PostStreamGenerateHook so that all tracker + implementations automatically accumulate generation-vs-consumer timing for + streaming actions. The concrete implementations below populate the + StreamState timing fields; post_end_stream reads them to write the + EndStreamModel with timing data. + + Subclasses do NOT need to override pre/post_stream_generate unless they + want custom behavior. 
+ """ + + def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + """Records the start of a single generator __next__() call. + + Uses defensive getattr so custom subclasses without stream_state + won't crash. + """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state._pre_generate_ns = now_ns + + if state.stream_start_ns is None: + state.stream_start_ns = now_ns + + # Consumer time = gap between previous post_stream_generate and this call. + # On the first call there's no previous post, so consumer_time stays at 0. + if state.last_post_generate_ns is not None: + state.consumer_time_ns += now_ns - state.last_post_generate_ns + + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception] = None, + **future_kwargs: Any, + ): + """Records the end of a single generator __next__() call. + + Accumulates generation_time_ns, tracks iteration_count, and captures + first_item_time_ns for TTFT. Uses defensive getattr for compatibility. 
+ """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state.last_post_generate_ns = now_ns + + if state._pre_generate_ns is not None: + state.generation_time_ns += now_ns - state._pre_generate_ns + state._pre_generate_ns = None + + if item is not None: + state.iteration_count += 1 + + if state.first_item_time_ns is None and item is not None: + if state.stream_start_ns is not None: + state.first_item_time_ns = now_ns - state.stream_start_ns + @abc.abstractmethod def copy(self) -> Self: """Clones the tracking client. This is useful for forking applications. @@ -591,12 +712,27 @@ def post_end_stream( **future_kwargs: Any, ): stream_state = self.stream_state[app_id, action, partition_key] + # Convert nanosecond timing accumulated by pre/post_stream_generate + # into millisecond floats for the EndStreamModel. If stream_start_ns + # is None, the generate hooks never fired (e.g. the action isn't using + # the instrumented generator), so we leave timing fields as None. 
+ generation_time_ms = None + consumer_time_ms = None + first_item_time_ms = None + if stream_state.stream_start_ns is not None: + generation_time_ms = stream_state.generation_time_ns / 1_000_000 + consumer_time_ms = stream_state.consumer_time_ns / 1_000_000 + if stream_state.first_item_time_ns is not None: + first_item_time_ms = stream_state.first_item_time_ns / 1_000_000 self._append_write_line( EndStreamModel( action_sequence_id=sequence_id, span_id=None, end_time=system.now(), items_streamed=stream_state.count, + generation_time_ms=generation_time_ms, + consumer_time_ms=consumer_time_ms, + first_item_time_ms=first_item_time_ms, ) ) del self.stream_state[app_id, action, partition_key] diff --git a/burr/tracking/common/models.py b/burr/tracking/common/models.py index 5980bf9df..911b1e48e 100644 --- a/burr/tracking/common/models.py +++ b/burr/tracking/common/models.py @@ -261,7 +261,14 @@ def sequence_id(self) -> int: class EndStreamModel(IdentifyingModel): - """Pydantic model that represents an entry for the first item of a stream""" + """Pydantic model that represents the end of a stream. + + The optional timing fields (generation_time_ms, consumer_time_ms, etc.) are + populated when the tracker has PreStreamGenerateHook/PostStreamGenerateHook + support. They are Optional so that: + - Old log files (without timing) still parse with new server code. + - New log files don't crash old server code (Pydantic ignores extra keys). + """ action_sequence_id: int span_id: Optional[ @@ -271,6 +278,14 @@ class EndStreamModel(IdentifyingModel): items_streamed: int type: str = "end_stream" + # --- Streaming timing summary (Optional for backwards compatibility) --- + # Sum of time spent inside the generator producing items (excludes consumer wait). + generation_time_ms: Optional[float] = None + # Sum of time the consumer spent processing yielded items between yields. 
+ consumer_time_ms: Optional[float] = None + # Time from stream start to first item produced (time to first token / TTFT). + first_item_time_ms: Optional[float] = None + @property def sequence_id(self) -> int: return self.action_sequence_id diff --git a/burr/tracking/s3client.py b/burr/tracking/s3client.py index 561d517af..e6857b896 100644 --- a/burr/tracking/s3client.py +++ b/burr/tracking/s3client.py @@ -500,12 +500,26 @@ def post_end_stream( **future_kwargs: Any, ): stream_state = self.stream_state[app_id, action, partition_key] + # Convert nanosecond timing accumulated by pre/post_stream_generate + # into millisecond floats for the EndStreamModel. If stream_start_ns + # is None, the generate hooks never fired, so we leave timing as None. + generation_time_ms = None + consumer_time_ms = None + first_item_time_ms = None + if stream_state.stream_start_ns is not None: + generation_time_ms = stream_state.generation_time_ns / 1_000_000 + consumer_time_ms = stream_state.consumer_time_ns / 1_000_000 + if stream_state.first_item_time_ns is not None: + first_item_time_ms = stream_state.first_item_time_ns / 1_000_000 self.submit_log_event( EndStreamModel( action_sequence_id=sequence_id, span_id=None, end_time=system.now(), items_streamed=stream_state.count, + generation_time_ms=generation_time_ms, + consumer_time_ms=consumer_time_ms, + first_item_time_ms=first_item_time_ms, ), app_id, partition_key, diff --git a/docs/concepts/additional-visibility.rst b/docs/concepts/additional-visibility.rst index a315e68da..656ac44c2 100644 --- a/docs/concepts/additional-visibility.rst +++ b/docs/concepts/additional-visibility.rst @@ -308,6 +308,48 @@ it as you see fit). With this you can log to any OpenTelemetry provider. +Streaming Telemetry Modes +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When using streaming actions with the ``OpenTelemetryBridge``, you can control how +streaming telemetry is emitted via the ``streaming_telemetry`` parameter. 
This accepts
+a :py:class:`StreamingTelemetryMode <burr.integrations.opentelemetry.StreamingTelemetryMode>` enum value:
+
+.. code-block:: python
+
+    from burr.integrations.opentelemetry import OpenTelemetryBridge, StreamingTelemetryMode
+
+    # Default: one action span with streaming attributes (generation time, consumer time, TTFT)
+    bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.SINGLE_SPAN)
+
+    # Lightest-weight: no action span, single summary event on the method span
+    bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.EVENT)
+
+    # Per-yield spans: one child span per generator yield, measuring generation time only
+    bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.CHUNK_SPANS)
+
+    # Maximum visibility: action span with attributes + per-yield child spans
+    bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS)
+
+The four modes are:
+
+- **SINGLE_SPAN** (default) — The action span covers the full generator lifetime. Span attributes
+  provide the generation-vs-consumer time breakdown (``stream.generation_time_ms``,
+  ``stream.consumer_time_ms``, ``stream.iteration_count``, ``stream.first_item_time_ms``).
+  This is backwards-compatible with the pre-existing behavior.
+- **EVENT** — No action span is created. A single ``stream_completed`` event is added to the
+  method span with the same timing attributes. This is the lightest-weight option.
+- **CHUNK_SPANS** — No action span. One child span per generator yield, each measuring only
+  the generation time for that item.
+- **SINGLE_AND_CHUNK_SPANS** — Combines SINGLE_SPAN and CHUNK_SPANS: the action span with
+  streaming attributes, plus per-yield child spans nested under it.
+
+Non-streaming actions are unaffected by this setting — they always get an action span regardless
+of the mode.
+
+See the `streaming_telemetry_modes.py example <https://github.com/dagworks-inc/burr/blob/main/examples/opentelemetry/streaming_telemetry_modes.py>`_
+for a runnable demo of all four modes.
+
 
 Instrumenting libraries
 -----------------------
diff --git a/docs/concepts/streaming-actions.rst b/docs/concepts/streaming-actions.rst
index 95e3a06e0..edf642417 100644
--- a/docs/concepts/streaming-actions.rst
+++ b/docs/concepts/streaming-actions.rst
@@ -242,3 +242,21 @@ be consistent with the asynchronous method. If you're using the old version, the
 1. The return type of the streaming action should be ``Generator[Tuple[dict, Optional[State], None, None]]`` instead of ``Generator[dict, None, Tuple[dict, State]]``.
 2. All intermediate results should be yielded as ``yield {'response': delta}, None`` instead of ``yield {'response': delta}``.
 3. The final result will be a ``yield`` instead of a ``return``
+
+-----------------------
+Telemetry & Observability
+-----------------------
+
+Streaming actions are instrumented with lifecycle hooks that bracket each generator
+yield. This enables two forms of observability:
+
+**OpenTelemetry** — The :py:class:`OpenTelemetryBridge <burr.integrations.opentelemetry.OpenTelemetryBridge>`
+supports configurable streaming telemetry modes (``SINGLE_SPAN``, ``EVENT``, ``CHUNK_SPANS``,
+``SINGLE_AND_CHUNK_SPANS``) that control how streaming spans and events are emitted. These
+modes provide timing attributes such as generation time, consumer time, iteration count, and
+time to first token (TTFT). See :ref:`Streaming Telemetry Modes <streaming-telemetry-modes>` for details.
+
+**Burr Tracker** — The :py:class:`LocalTrackingClient <burr.tracking.client.LocalTrackingClient>` (and
+``S3TrackingClient``) automatically accumulate generation-vs-consumer timing for streaming
+actions and write it to the ``end_stream`` log entry. The Burr UI displays this data in the
+step detail view, showing the generation/consumer time split and TTFT.
diff --git a/docs/examples/deployment/monitoring.rst b/docs/examples/deployment/monitoring.rst
index 44f104e24..0a0909e4d 100644
--- a/docs/examples/deployment/monitoring.rst
+++ b/docs/examples/deployment/monitoring.rst
@@ -22,7 +22,10 @@ Monitoring in Production
 ------------------------
 
 Burr's telemetry UI is meant both for debugging and running in production. It can consume `OpenTelemetry traces <https://opentelemetry.io/docs/concepts/signals/traces/>`_,
-and has a suite of useful capabilities for debugging Burr applications.
+and has a suite of useful capabilities for debugging Burr applications. For streaming actions, the tracker
+and UI surface generation-vs-consumer timing (including time to first token), and the
+``OpenTelemetryBridge`` supports :ref:`configurable streaming telemetry modes <streaming-telemetry-modes>` for
+controlling span and event granularity.
 
 It has two (current) implementations:
 
diff --git a/docs/reference/integrations/opentelemetry.rst b/docs/reference/integrations/opentelemetry.rst
index e68ffd7c4..ff8755e60 100644
--- a/docs/reference/integrations/opentelemetry.rst
+++ b/docs/reference/integrations/opentelemetry.rst
@@ -41,4 +41,8 @@ Reference for the various useful methods:
 .. autoclass:: burr.integrations.opentelemetry.OpenTelemetryBridge
    :members:
 
+.. autoclass:: burr.integrations.opentelemetry.StreamingTelemetryMode
+   :members:
+   :undoc-members:
+
 .. autofunction:: burr.integrations.opentelemetry.init_instruments
diff --git a/docs/reference/lifecycle.rst b/docs/reference/lifecycle.rst
index 84703e931..0175e76ba 100644
--- a/docs/reference/lifecycle.rst
+++ b/docs/reference/lifecycle.rst
@@ -68,6 +68,18 @@ and add instances to the application builder to customize your state machines's
 .. autoclass:: burr.lifecycle.base.PostApplicationExecuteCallHookAsync
    :members:
 
+.. autoclass:: burr.lifecycle.base.PreStreamGenerateHook
+   :members:
+
+.. autoclass:: burr.lifecycle.base.PreStreamGenerateHookAsync
+   :members:
+
+.. autoclass:: burr.lifecycle.base.PostStreamGenerateHook
+   :members:
+
+.. autoclass:: burr.lifecycle.base.PostStreamGenerateHookAsync
+   :members:
+
 These hooks are available for you to use:
 
 .. autoclass:: burr.lifecycle.default.StateAndResultsFullLogger
diff --git a/docs/reference/tracking.rst b/docs/reference/tracking.rst
index 4bd268be4..b48a36518 100644
--- a/docs/reference/tracking.rst
+++ b/docs/reference/tracking.rst
@@ -29,3 +29,18 @@ Rather, you should use this through/in conjunction with :py:meth:`burr.core.appl
    :members:
 
    .. automethod:: __init__
+
+Streaming Timing
+~~~~~~~~~~~~~~~~
+
+For streaming actions, the tracker automatically accumulates timing data by implementing
+``PreStreamGenerateHook`` and ``PostStreamGenerateHook``. When a streaming action completes,
+the ``end_stream`` log entry includes the following optional timing fields:
+
+- ``generation_time_ms`` — Sum of time spent inside the generator producing items (excludes consumer wait time).
+- ``consumer_time_ms`` — Sum of time the consumer spent processing yielded items between yields.
+- ``first_item_time_ms`` — Time from stream start to first item produced (time to first token / TTFT).
+
+These fields are ``null`` when the streaming timing hooks have not fired (e.g. old log files or
+non-instrumented generators). The Burr UI renders these fields in the step detail view when
+available, falling back to the legacy throughput calculation otherwise.
diff --git a/examples/opentelemetry/README.md b/examples/opentelemetry/README.md
index d0c18c6ca..ff350c2d5 100644
--- a/examples/opentelemetry/README.md
+++ b/examples/opentelemetry/README.md
@@ -27,6 +27,19 @@ We have two modes:
 2. Log Burr to OpenTelemetry
 
 See [notebook.ipynb](./notebook.ipynb) for a simple overview.
-See [application.py](./application.py) for the full code
+See [application.py](./application.py) for the full code.
+ +## Streaming Telemetry + +For streaming actions, the `OpenTelemetryBridge` supports four configurable +telemetry modes via `StreamingTelemetryMode`: + +- **SINGLE_SPAN** (default) — one action span with streaming attributes (generation time, consumer time, TTFT) +- **EVENT** — no action span, single summary event on the method span +- **CHUNK_SPANS** — per-yield child spans measuring generation time only +- **SINGLE_AND_CHUNK_SPANS** — action span with attributes + per-yield child spans + +See [streaming_telemetry_modes.py](./streaming_telemetry_modes.py) for a runnable +demo exercising all four modes with the console exporter. See the [documentation](https://burr.dagworks.io/concepts/additional-visibility/#open-telemetry) for more info diff --git a/examples/opentelemetry/streaming_telemetry_modes.py b/examples/opentelemetry/streaming_telemetry_modes.py index c50381f12..67a653bce 100644 --- a/examples/opentelemetry/streaming_telemetry_modes.py +++ b/examples/opentelemetry/streaming_telemetry_modes.py @@ -20,14 +20,23 @@ Runs a simple async streaming action under each mode with the OTel console exporter so you can see the spans and events printed to stdout. +When --tracker is passed, each mode also gets a LocalTrackingClient so the +results show up in the Burr UI (run ``burr`` to open it). + Usage: + # OTel console output only python examples/opentelemetry/streaming_telemetry_modes.py + # OTel console output + Burr tracker (viewable in the UI) + python examples/opentelemetry/streaming_telemetry_modes.py --tracker + No external APIs are needed — the streaming action simulates an LLM by yielding tokens with small delays. 
""" +import argparse import asyncio +import time from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor @@ -68,7 +77,7 @@ async def generate_response(state: State) -> None: # --------------------------------------------------------------------------- -async def run_with_mode(mode: StreamingTelemetryMode) -> None: +async def run_with_mode(mode: StreamingTelemetryMode, use_tracker: bool = False) -> None: """Builds an app with the given streaming telemetry mode and runs it.""" # Each run gets its own tracer provider so console output stays grouped provider = TracerProvider() @@ -77,16 +86,20 @@ async def run_with_mode(mode: StreamingTelemetryMode) -> None: bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=mode) - app = ( + builder = ( ApplicationBuilder() .with_graph(graph) .with_entrypoint("generate") .with_state(State({"prompt": "hello world from burr streaming"})) .with_hooks(bridge) - .with_identifiers(app_id=f"demo-{mode.value}") - .build() + .with_identifiers(app_id=f"demo-{mode.value}-{time.time()}") ) + if use_tracker: + builder = builder.with_tracker(project="streaming-telemetry-modes", tracker="local") + + app = builder.build() + action, container = await app.astream_result(halt_after=["generate"]) async for item in container: await asyncio.sleep(0.05) # simulate consumer processing time per token @@ -100,7 +113,7 @@ async def run_with_mode(mode: StreamingTelemetryMode) -> None: # --------------------------------------------------------------------------- -async def main(): +async def main(use_tracker: bool = False): modes = [ StreamingTelemetryMode.SINGLE_SPAN, StreamingTelemetryMode.EVENT, @@ -111,9 +124,20 @@ async def main(): print(f"\n{'=' * 70}") print(f" StreamingTelemetryMode.{mode.name}") print(f"{'=' * 70}\n") - await run_with_mode(mode) + await run_with_mode(mode, use_tracker=use_tracker) print() + if use_tracker: + print("Tracker data written to 
~/.burr/streaming-telemetry-modes/") + print("Run `burr` to open the UI and inspect the streaming timing data.") + if __name__ == "__main__": - asyncio.run(main()) + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--tracker", + action="store_true", + help="Enable the Burr LocalTrackingClient so results appear in the UI", + ) + args = parser.parse_args() + asyncio.run(main(use_tracker=args.tracker)) diff --git a/telemetry/ui/src/api/models/EndStreamModel.ts b/telemetry/ui/src/api/models/EndStreamModel.ts index 60dcd1d7e..061c95a2e 100644 --- a/telemetry/ui/src/api/models/EndStreamModel.ts +++ b/telemetry/ui/src/api/models/EndStreamModel.ts @@ -22,7 +22,10 @@ /* tslint:disable */ /* eslint-disable */ /** - * Pydantic model that represents an entry for the first item of a stream + * Pydantic model that represents the end of a stream. + * + * The optional timing fields are populated when the tracker has + * PreStreamGenerateHook/PostStreamGenerateHook support. */ export type EndStreamModel = { type?: string; @@ -30,4 +33,10 @@ export type EndStreamModel = { span_id: string | null; end_time: string; items_streamed: number; + /** Sum of time spent inside the generator producing items (ms). */ + generation_time_ms?: number | null; + /** Sum of time the consumer spent processing yielded items (ms). */ + consumer_time_ms?: number | null; + /** Time from stream start to first item produced / TTFT (ms). */ + first_item_time_ms?: number | null; }; diff --git a/telemetry/ui/src/components/routes/app/StepList.tsx b/telemetry/ui/src/components/routes/app/StepList.tsx index 5f0d197e8..f12ec130e 100644 --- a/telemetry/ui/src/components/routes/app/StepList.tsx +++ b/telemetry/ui/src/components/routes/app/StepList.tsx @@ -892,9 +892,27 @@ const StepSubTable = (props: { new Date(firstStream?.first_item_time).getTime() : undefined; const numStreamed = streamModel.items_streamed; - // const name = ellapsedStreamTime ? 
`last token (throughput=${ellapsedStreamTime/streamModel.items_streamed} ms/token)` : 'last token'; - // const name = `throughput: ${(ellapsedStreamTime || 0) / numStreamed} ms/token (${numStreamed} tokens/${ellapsedStreamTime}ms)`; - const name = `throughput: ${((ellapsedStreamTime || 0) / numStreamed).toFixed(1)} ms/token (${numStreamed}/${ellapsedStreamTime}ms)`; + // Build a descriptive name that includes generation/consumer timing + // when available (from the new PreStreamGenerateHook/PostStreamGenerateHook + // timing accumulation), falling back to the legacy throughput calculation. + const genTime = streamModel.generation_time_ms; + const consTime = streamModel.consumer_time_ms; + const ttftTime = streamModel.first_item_time_ms; + let name: string; + if ( + genTime !== null && + genTime !== undefined && + consTime !== null && + consTime !== undefined + ) { + const ttft = + ttftTime !== null && ttftTime !== undefined + ? ` · TTFT: ${ttftTime.toFixed(0)}ms` + : ''; + name = `gen: ${genTime.toFixed(0)}ms · consumer: ${consTime.toFixed(0)}ms · ${numStreamed} items${ttft}`; + } else { + name = `throughput: ${((ellapsedStreamTime || 0) / numStreamed).toFixed(1)} ms/token (${numStreamed}/${ellapsedStreamTime}ms)`; + } return ( (generation) -> post + tracker.pre_stream_generate(item_index=0, **common) + state = tracker.stream_state[key] + assert state.stream_start_ns is not None # set on first call + assert state._pre_generate_ns is not None + + tracker.post_stream_generate(item={"token": "hello"}, item_index=0, **common) + assert state.iteration_count == 1 + assert state.generation_time_ns > 0 + assert state.first_item_time_ns is not None # TTFT captured + first_gen_time = state.generation_time_ns + + # Yield 1: pre -> (generation) -> post + tracker.pre_stream_generate(item_index=1, **common) + # Consumer time should now be > 0 (gap between previous post and this pre) + assert state.consumer_time_ns > 0 + + tracker.post_stream_generate(item={"token": "world"}, 
item_index=1, **common) + assert state.iteration_count == 2 + assert state.generation_time_ns > first_gen_time + + # Final yield (item=None signals StopIteration) + tracker.pre_stream_generate(item_index=2, **common) + tracker.post_stream_generate(item=None, item_index=2, **common) + # item=None should NOT increment iteration_count + assert state.iteration_count == 2 + + +def test_pre_stream_generate_no_stream_state_is_noop(): + """pre/post_stream_generate should silently do nothing when there's no + matching stream_state entry (defensive getattr pattern).""" + import datetime + + tracker = LocalTrackingClient("test", "/tmp/unused") + now = datetime.datetime.now() + common = dict( + stream_initialize_time=now, + action="missing", + sequence_id=0, + app_id="missing", + partition_key=None, + ) + # Should not raise + tracker.pre_stream_generate(item_index=0, **common) + tracker.post_stream_generate(item={"x": 1}, item_index=0, **common) + + +# --------------------------------------------------------------------------- +# EndStreamModel backwards compatibility tests +# --------------------------------------------------------------------------- + + +def test_end_stream_model_without_timing_fields(): + """Old-style EndStreamModel JSON (no timing fields) should parse into the + new model with None timing values — backwards compatibility.""" + old_json = ( + '{"type":"end_stream","action_sequence_id":1,"span_id":null,' + '"end_time":"2024-01-01T00:00:00","items_streamed":10}' + ) + model = EndStreamModel.model_validate_json(old_json) + assert model.items_streamed == 10 + assert model.generation_time_ms is None + assert model.consumer_time_ms is None + assert model.first_item_time_ms is None + + +def test_end_stream_model_with_timing_fields(): + """EndStreamModel with timing fields should round-trip through JSON.""" + import datetime + + model = EndStreamModel( + action_sequence_id=1, + span_id=None, + end_time=datetime.datetime.now(), + items_streamed=47, + 
generation_time_ms=245.3, + consumer_time_ms=1830.1, + first_item_time_ms=52.0, + ) + dumped = model.model_dump_json() + restored = EndStreamModel.model_validate_json(dumped) + assert restored.generation_time_ms == 245.3 + assert restored.consumer_time_ms == 1830.1 + assert restored.first_item_time_ms == 52.0 + + +# --------------------------------------------------------------------------- +# End-to-end streaming test with LocalTrackingClient +# --------------------------------------------------------------------------- + + +class _SimpleStreamingAction(StreamingAction): + """A streaming action that yields a fixed number of items with a small + delay to produce measurable generation time.""" + + @property + def reads(self) -> list[str]: + return ["prompt"] + + @property + def writes(self) -> list[str]: + return ["response"] + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + tokens = state["prompt"].split() + for token in tokens: + time.sleep(0.01) # small delay so generation_time_ns > 0 + yield {"token": token} + + def update(self, result: dict, state: State) -> State: + return state.update(response=result.get("token", "")) + + +def test_streaming_action_end_to_end_writes_timing(tmpdir): + """Integration test: run a streaming action through ApplicationBuilder with + a LocalTrackingClient and verify that the end_stream log entry contains + non-null timing fields.""" + app_id = str(uuid.uuid4()) + log_dir = os.path.join(tmpdir, "tracking") + project_name = "test_streaming_timing" + + tracker = LocalTrackingClient(project=project_name, storage_dir=log_dir) + app = ( + ApplicationBuilder() + .with_state(prompt="hello world test", response="") + .with_actions(generate=_SimpleStreamingAction()) + .with_transitions() + .with_entrypoint("generate") + .with_tracker(tracker) + .with_identifiers(app_id=app_id) + .build() + ) + + action_, streaming_container = app.stream_result(halt_after=["generate"]) + for _ in streaming_container: + 
time.sleep(0.01) # simulate consumer processing + streaming_container.get() + + # Read the log file and find the end_stream entry + log_path = os.path.join(log_dir, project_name, app_id, LocalTrackingClient.LOG_FILENAME) + assert os.path.exists(log_path) + with open(log_path) as f: + log_lines = [json.loads(line) for line in f.readlines()] + + end_stream_entries = [ + EndStreamModel.model_validate(line) for line in log_lines if line["type"] == "end_stream" + ] + assert len(end_stream_entries) == 1 + end_stream = end_stream_entries[0] + + # Verify timing fields are populated (not None) + assert end_stream.generation_time_ms is not None + assert ( + end_stream.generation_time_ms > 0 + ), "generation_time_ms should be > 0 (we slept in stream_run)" + assert end_stream.consumer_time_ms is not None + assert ( + end_stream.consumer_time_ms > 0 + ), "consumer_time_ms should be > 0 (we slept between items)" + assert end_stream.first_item_time_ms is not None + assert end_stream.first_item_time_ms > 0, "first_item_time_ms (TTFT) should be > 0" + # items_streamed is tracked by the existing post_stream_item hook, which + # may not count all yields depending on the streaming container semantics. 
+ assert end_stream.items_streamed >= 1 + + +async def test_async_streaming_action_end_to_end_writes_timing(tmpdir): + """Async variant: verify timing fields appear in end_stream log entry.""" + + @streaming_action(reads=["prompt"], writes=["response"]) + async def async_generate(state: State): + tokens = state["prompt"].split() + buffer = [] + for token in tokens: + await asyncio.sleep(0.01) + buffer.append(token) + yield {"token": token}, None + yield {"token": ""}, state.update(response=" ".join(buffer)) + + app_id = str(uuid.uuid4()) + log_dir = os.path.join(tmpdir, "tracking") + project_name = "test_async_streaming_timing" + + tracker = LocalTrackingClient(project=project_name, storage_dir=log_dir) + app = ( + ApplicationBuilder() + .with_state(prompt="async streaming test tokens", response="") + .with_actions(generate=async_generate) + .with_transitions() + .with_entrypoint("generate") + .with_tracker(tracker) + .with_identifiers(app_id=app_id) + .build() + ) + + action_, streaming_container = await app.astream_result(halt_after=["generate"]) + async for _ in streaming_container: + await asyncio.sleep(0.01) + await streaming_container.get() + + log_path = os.path.join(log_dir, project_name, app_id, LocalTrackingClient.LOG_FILENAME) + assert os.path.exists(log_path) + with open(log_path) as f: + log_lines = [json.loads(line) for line in f.readlines()] + + end_stream_entries = [ + EndStreamModel.model_validate(line) for line in log_lines if line["type"] == "end_stream" + ] + assert len(end_stream_entries) == 1 + end_stream = end_stream_entries[0] + + assert end_stream.generation_time_ms is not None + assert end_stream.generation_time_ms > 0 + assert end_stream.consumer_time_ms is not None + assert end_stream.consumer_time_ms > 0 + assert end_stream.first_item_time_ms is not None + assert end_stream.first_item_time_ms > 0 + assert end_stream.items_streamed >= 1 From 448b03ccdf28697864dd41897ea2d7016bfcdb65 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: 
Mon, 2 Mar 2026 17:53:27 -0800 Subject: [PATCH 3/3] Fixes docs issue --- docs/concepts/streaming-actions.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/concepts/streaming-actions.rst b/docs/concepts/streaming-actions.rst index edf642417..26bc2db8c 100644 --- a/docs/concepts/streaming-actions.rst +++ b/docs/concepts/streaming-actions.rst @@ -243,9 +243,9 @@ be consistent with the asynchronous method. If you're using the old version, the 2. All intermediate results should be yielded as ``yield {'response': delta}, None`` instead of ``yield {'response': delta}``. 3. The final result will be a ``yield`` instead of a ``return`` ------------------------ +------------------------- Telemetry & Observability ------------------------ +------------------------- Streaming actions are instrumented with lifecycle hooks that bracket each generator yield. This enables two forms of observability: