Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 96 additions & 36 deletions langfuse/_client/observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,56 @@ def __init__(
self.span = span
self.capture_output = capture_output
self.transform_fn = transform_fn
self._span_ended = False

def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
return self

def _finalize(self) -> None:
if self._span_ended:
return

if self.capture_output:
output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)

elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

self.span.update(output=output)

self.span.end()
self._span_ended = True

def _finalize_with_error(self, error: BaseException) -> None:
if self._span_ended:
return

self.span.update(
level="ERROR", status_message=str(error) or type(error).__name__
).end()
self._span_ended = True

def close(self) -> None:
if self._span_ended:
return

try:
self.generator.close()
Comment thread
hassiebp marked this conversation as resolved.
Outdated
except (Exception, asyncio.CancelledError) as error:
self._finalize_with_error(error)
raise
else:
self._finalize()

def __del__(self) -> None:
try:
self.close()
except BaseException:
pass

def __next__(self) -> Any:
try:
# Run the generator's __next__ in the preserved context
Expand All @@ -577,27 +623,11 @@ def __next__(self) -> Any:
return item

except StopIteration:
# Handle output and span cleanup when generator is exhausted
if self.capture_output:
output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)

elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

self.span.update(output=output)

self.span.end()

self._finalize()
raise # Re-raise StopIteration

except (Exception, asyncio.CancelledError) as e:
self.span.update(
level="ERROR", status_message=str(e) or type(e).__name__
).end()

self._finalize_with_error(e)
raise


Expand Down Expand Up @@ -628,10 +658,56 @@ def __init__(
self.span = span
self.capture_output = capture_output
self.transform_fn = transform_fn
self._span_ended = False

def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
return self

def _finalize(self) -> None:
if self._span_ended:
return

if self.capture_output:
output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)

elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

self.span.update(output=output)

self.span.end()
self._span_ended = True

def _finalize_with_error(self, error: BaseException) -> None:
if self._span_ended:
return

self.span.update(
level="ERROR", status_message=str(error) or type(error).__name__
).end()
self._span_ended = True

async def aclose(self) -> None:
if self._span_ended:
return

try:
await self.generator.aclose()
Comment thread
hassiebp marked this conversation as resolved.
Outdated
except (Exception, asyncio.CancelledError) as error:
self._finalize_with_error(error)
raise
else:
self._finalize()

async def close(self) -> None:
await self.aclose()

def __del__(self) -> None:
self._finalize()

async def __anext__(self) -> Any:
try:
# Run the generator's __anext__ in the preserved context
Expand All @@ -651,24 +727,8 @@ async def __anext__(self) -> Any:
return item

except StopAsyncIteration:
# Handle output and span cleanup when generator is exhausted
if self.capture_output:
output: Any = self.items

if self.transform_fn is not None:
output = self.transform_fn(self.items)

elif all(isinstance(item, str) for item in self.items):
output = "".join(self.items)

self.span.update(output=output)

self.span.end()

self._finalize()
raise # Re-raise StopAsyncIteration
except (Exception, asyncio.CancelledError) as e:
self.span.update(
level="ERROR", status_message=str(e) or type(e).__name__
).end()

self._finalize_with_error(e)
raise
141 changes: 129 additions & 12 deletions tests/unit/test_observe.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,47 @@
import asyncio
import contextvars
import gc
import sys
from typing import Any, AsyncGenerator, Generator, cast

import pytest

from langfuse import observe
from langfuse._client.attributes import LangfuseOtelSpanAttributes
from langfuse._client.observe import (
_ContextPreservedAsyncGeneratorWrapper,
_ContextPreservedSyncGeneratorWrapper,
)


def _finished_spans_by_name(memory_exporter, name: str):
class SpanRecorder:
def __init__(self) -> None:
self.ended = 0
self.updates: list[dict[str, Any]] = []

def update(self, **kwargs: Any) -> "SpanRecorder":
self.updates.append(kwargs)
return self

def end(self) -> "SpanRecorder":
self.ended += 1
return self


def _finished_spans_by_name(memory_exporter: Any, name: str) -> list[Any]:
return [span for span in memory_exporter.get_finished_spans() if span.name == name]


def test_sync_generator_preserves_context_without_output_capture(
langfuse_memory_client, memory_exporter
):
langfuse_memory_client: Any, memory_exporter: Any
) -> None:
@observe(name="child_step")
def child_step(index: int) -> str:
return f"item_{index}"

@observe(name="root", capture_output=False)
def root():
def body():
def root() -> Generator[str, None, None]:
def body() -> Generator[str, None, None]:
for index in range(2):
yield child_step(index)

Expand All @@ -30,7 +52,7 @@ def body():
assert memory_exporter.get_finished_spans() == []

assert list(generator) == ["item_0", "item_1"]
assert generator.items == []
assert cast(Any, generator).items == []

langfuse_memory_client.flush()

Expand All @@ -51,30 +73,30 @@ def body():
@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher")
async def test_streaming_response_preserves_context_without_output_capture(
langfuse_memory_client, memory_exporter
):
langfuse_memory_client: Any, memory_exporter: Any
) -> None:
class StreamingResponse:
def __init__(self, body_iterator):
def __init__(self, body_iterator: AsyncGenerator[str, None]) -> None:
self.body_iterator = body_iterator

@observe(name="stream_step")
async def stream_step(index: int) -> str:
return f"chunk_{index}"

async def body():
async def body() -> AsyncGenerator[str, None]:
for index in range(2):
yield await stream_step(index)

@observe(name="endpoint", capture_output=False)
async def endpoint():
async def endpoint() -> StreamingResponse:
return StreamingResponse(body())

response = await endpoint()

assert memory_exporter.get_finished_spans() == []

assert [item async for item in response.body_iterator] == ["chunk_0", "chunk_1"]
assert response.body_iterator.items == []
assert cast(Any, response.body_iterator).items == []

langfuse_memory_client.flush()

Expand All @@ -90,3 +112,98 @@ async def endpoint():
step.context.trace_id == endpoint_span.context.trace_id for step in step_spans
)
assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in endpoint_span.attributes


def test_sync_generator_wrapper_close_ends_span_without_exhaustion() -> None:
def generator() -> Generator[str, None, None]:
yield "item_0"
yield "item_1"

span = SpanRecorder()
wrapper = _ContextPreservedSyncGeneratorWrapper(
generator(),
contextvars.copy_context(),
cast(Any, span),
False,
None,
)

assert next(wrapper) == "item_0"

wrapper.close()
wrapper.close()

assert span.ended == 1
assert span.updates == []


def test_sync_generator_wrapper_del_ends_span_when_abandoned() -> None:
def generator() -> Generator[str, None, None]:
yield "item_0"
yield "item_1"

span = SpanRecorder()
wrapper = _ContextPreservedSyncGeneratorWrapper(
generator(),
contextvars.copy_context(),
cast(Any, span),
False,
None,
)

assert next(wrapper) == "item_0"

del wrapper
gc.collect()

assert span.ended == 1
assert span.updates == []


@pytest.mark.asyncio
async def test_async_generator_wrapper_aclose_ends_span_without_exhaustion() -> None:
async def generator() -> AsyncGenerator[str, None]:
yield "item_0"
yield "item_1"

span = SpanRecorder()
wrapper = _ContextPreservedAsyncGeneratorWrapper(
generator(),
contextvars.copy_context(),
cast(Any, span),
False,
None,
)

assert await wrapper.__anext__() == "item_0"

await wrapper.aclose()
await wrapper.close()

assert span.ended == 1
assert span.updates == []


@pytest.mark.asyncio
async def test_async_generator_wrapper_del_ends_span_when_abandoned() -> None:
async def generator() -> AsyncGenerator[str, None]:
yield "item_0"
yield "item_1"

span = SpanRecorder()
wrapper = _ContextPreservedAsyncGeneratorWrapper(
generator(),
contextvars.copy_context(),
cast(Any, span),
False,
None,
)

assert await wrapper.__anext__() == "item_0"

del wrapper
gc.collect()
await asyncio.sleep(0)

assert span.ended == 1
assert span.updates == []
Loading