Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 70 additions & 24 deletions langfuse/_client/observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,26 @@ def __init__(
self.items: List[Any] = []
self.span = span
self.transform_fn = transform_fn
self._ended = False

def _end_span(
self, *, level: Optional[str] = None, status_message: Optional[str] = None
) -> None:
if self._ended:
return
self._ended = True

output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)
elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

if level is not None:
self.span.update(output=output, level=level, status_message=status_message).end()
else:
self.span.update(output=output).end()

def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
return self
Expand All @@ -573,26 +593,29 @@ def __next__(self) -> Any:
return item

except StopIteration:
# Handle output and span cleanup when generator is exhausted
output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)

elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

self.span.update(output=output).end()
self._end_span()

raise # Re-raise StopIteration

except (Exception, asyncio.CancelledError) as e:
self.span.update(
self._end_span(
level="ERROR", status_message=str(e) or type(e).__name__
).end()
)

raise

def close(self) -> None:
tokens = []
try:
if self.context:
for var, value in self.context.items():
tokens.append((var, var.set(value)))
self._end_span()
self.generator.close()
finally:
for var, token in tokens:
var.reset(token)


class _ContextPreservedAsyncGeneratorWrapper:
"""Async generator wrapper that ensures each iteration runs in preserved context."""
Expand All @@ -619,6 +642,26 @@ def __init__(
self.items: List[Any] = []
self.span = span
self.transform_fn = transform_fn
self._ended = False

def _end_span(
self, *, level: Optional[str] = None, status_message: Optional[str] = None
) -> None:
if self._ended:
return
self._ended = True

output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)
elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

if level is not None:
self.span.update(output=output, level=level, status_message=status_message).end()
else:
self.span.update(output=output).end()

def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
return self
Expand All @@ -641,21 +684,24 @@ async def __anext__(self) -> Any:
return item

except StopAsyncIteration:
# Handle output and span cleanup when generator is exhausted
output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)

elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

self.span.update(output=output).end()
self._end_span()

raise # Re-raise StopAsyncIteration
except (Exception, asyncio.CancelledError) as e:
self.span.update(
self._end_span(
level="ERROR", status_message=str(e) or type(e).__name__
).end()
)

raise

async def aclose(self) -> None:
tokens = []
try:
if self.context:
for var, value in self.context.items():
tokens.append((var, var.set(value)))
self._end_span()
await self.generator.aclose()
finally:
for var, token in tokens:
var.reset(token)