diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index c648a0a62..250a15a56 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -560,6 +560,26 @@ def __init__( self.items: List[Any] = [] self.span = span self.transform_fn = transform_fn + self._ended = False + + def _end_span( + self, *, level: Optional[str] = None, status_message: Optional[str] = None + ) -> None: + if self._ended: + return + self._ended = True + + output: Any = self.items + + if self.transform_fn is not None: + output = self.transform_fn(self.items) + elif all(isinstance(item, str) for item in self.items): + output = "".join(self.items) + + if level is not None: + self.span.update(output=output, level=level, status_message=status_message).end() + else: + self.span.update(output=output).end() def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper": return self @@ -573,26 +593,29 @@ def __next__(self) -> Any: return item except StopIteration: - # Handle output and span cleanup when generator is exhausted - output: Any = self.items - - if self.transform_fn is not None: - output = self.transform_fn(self.items) - - elif all(isinstance(item, str) for item in self.items): - output = "".join(self.items) - - self.span.update(output=output).end() + self._end_span() raise # Re-raise StopIteration except (Exception, asyncio.CancelledError) as e: - self.span.update( + self._end_span( level="ERROR", status_message=str(e) or type(e).__name__ - ).end() + ) raise + def close(self) -> None: + tokens = [] + try: + if self.context: + for var, value in self.context.items(): + tokens.append((var, var.set(value))) + self._end_span() + self.generator.close() + finally: + for var, token in tokens: + var.reset(token) + class _ContextPreservedAsyncGeneratorWrapper: """Async generator wrapper that ensures each iteration runs in preserved context.""" @@ -619,6 +642,26 @@ def __init__( self.items: List[Any] = [] self.span = span self.transform_fn = transform_fn + self._ended = False + + def _end_span( + self, *, level: Optional[str] = None, status_message: Optional[str] = None + ) -> None: + if self._ended: + return + self._ended = True + + output: Any = self.items + + if self.transform_fn is not None: + output = self.transform_fn(self.items) + elif all(isinstance(item, str) for item in self.items): + output = "".join(self.items) + + if level is not None: + self.span.update(output=output, level=level, status_message=status_message).end() + else: + self.span.update(output=output).end() def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper": return self @@ -641,21 +684,24 @@ async def __anext__(self) -> Any: return item except StopAsyncIteration: - # Handle output and span cleanup when generator is exhausted - output: Any = self.items - - if self.transform_fn is not None: - output = self.transform_fn(self.items) - - elif all(isinstance(item, str) for item in self.items): - output = "".join(self.items) - - self.span.update(output=output).end() + self._end_span() raise # Re-raise StopAsyncIteration except (Exception, asyncio.CancelledError) as e: - self.span.update( + self._end_span( level="ERROR", status_message=str(e) or type(e).__name__ - ).end() + ) raise + + async def aclose(self) -> None: + tokens = [] + try: + if self.context: + for var, value in self.context.items(): + tokens.append((var, var.set(value))) + self._end_span() + await self.generator.aclose() + finally: + for var, token in tokens: + var.reset(token)