diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index 3ada60bc9..64882a20f 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -563,10 +563,56 @@ def __init__( self.span = span self.capture_output = capture_output self.transform_fn = transform_fn + self._span_ended = False def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper": return self + def _finalize(self) -> None: + if self._span_ended: + return + + if self.capture_output: + output: Any = self.items + + if self.transform_fn is not None: + output = self.transform_fn(self.items) + + elif all(isinstance(item, str) for item in self.items): + output = "".join(self.items) + + self.span.update(output=output) + + self.span.end() + self._span_ended = True + + def _finalize_with_error(self, error: BaseException) -> None: + if self._span_ended: + return + + self.span.update( + level="ERROR", status_message=str(error) or type(error).__name__ + ).end() + self._span_ended = True + + def close(self) -> None: + if self._span_ended: + return + + try: + self.context.run(self.generator.close) + except (Exception, asyncio.CancelledError) as error: + self._finalize_with_error(error) + raise + else: + self._finalize() + + def __del__(self) -> None: + try: + self.close() + except BaseException: + pass + def __next__(self) -> Any: try: # Run the generator's __next__ in the preserved context @@ -577,27 +623,11 @@ def __next__(self) -> Any: return item except StopIteration: - # Handle output and span cleanup when generator is exhausted - if self.capture_output: - output: Any = self.items - - if self.transform_fn is not None: - output = self.transform_fn(self.items) - - elif all(isinstance(item, str) for item in self.items): - output = "".join(self.items) - - self.span.update(output=output) - - self.span.end() - + self._finalize() raise # Re-raise StopIteration except (Exception, asyncio.CancelledError) as e: - self.span.update( - level="ERROR", status_message=str(e) or type(e).__name__ - ).end() - + self._finalize_with_error(e) raise @@ -628,22 +658,77 @@ def __init__( self.span = span self.capture_output = capture_output self.transform_fn = transform_fn + self._span_ended = False def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper": return self + def _finalize(self) -> None: + if self._span_ended: + return + + if self.capture_output: + output: Any = self.items + + if self.transform_fn is not None: + output = self.transform_fn(self.items) + + elif all(isinstance(item, str) for item in self.items): + output = "".join(self.items) + + self.span.update(output=output) + + self.span.end() + self._span_ended = True + + def _finalize_with_error(self, error: BaseException) -> None: + if self._span_ended: + return + + self.span.update( + level="ERROR", status_message=str(error) or type(error).__name__ + ).end() + self._span_ended = True + + async def aclose(self) -> None: + if self._span_ended: + return + + try: + try: + await asyncio.create_task( + self.generator.aclose(), + context=self.context, + ) # type: ignore + except TypeError: + await self.context.run(asyncio.create_task, self.generator.aclose()) + except (Exception, asyncio.CancelledError) as error: + self._finalize_with_error(error) + raise + else: + self._finalize() + + async def close(self) -> None: + await self.aclose() + + def __del__(self) -> None: + self._finalize() + async def __anext__(self) -> Any: try: # Run the generator's __anext__ in the preserved context try: - # Python 3.10+ approach with context parameter + # Python 3.11+ approach with explicit task context item = await asyncio.create_task( self.generator.__anext__(), # type: ignore context=self.context, ) # type: ignore except TypeError: - # Python < 3.10 fallback - context parameter not supported - item = await self.generator.__anext__() + # Python 3.10 fallback - create the task inside the preserved context. + item = await self.context.run( + asyncio.create_task, + self.generator.__anext__(), # type: ignore + ) if self.capture_output: self.items.append(item) @@ -651,24 +736,8 @@ async def __anext__(self) -> Any: return item except StopAsyncIteration: - # Handle output and span cleanup when generator is exhausted - if self.capture_output: - output: Any = self.items - - if self.transform_fn is not None: - output = self.transform_fn(self.items) - - elif all(isinstance(item, str) for item in self.items): - output = "".join(self.items) - - self.span.update(output=output) - - self.span.end() - + self._finalize() raise # Re-raise StopAsyncIteration except (Exception, asyncio.CancelledError) as e: - self.span.update( - level="ERROR", status_message=str(e) or type(e).__name__ - ).end() - + self._finalize_with_error(e) raise diff --git a/tests/unit/test_observe.py b/tests/unit/test_observe.py index 94a2cbb83..24f79c3fc 100644 --- a/tests/unit/test_observe.py +++ b/tests/unit/test_observe.py @@ -1,25 +1,47 @@ +import asyncio +import contextvars +import gc import sys +from typing import Any, AsyncGenerator, Generator, cast import pytest from langfuse import observe from langfuse._client.attributes import LangfuseOtelSpanAttributes +from langfuse._client.observe import ( + _ContextPreservedAsyncGeneratorWrapper, + _ContextPreservedSyncGeneratorWrapper, +) -def _finished_spans_by_name(memory_exporter, name: str): +class SpanRecorder: + def __init__(self) -> None: + self.ended = 0 + self.updates: list[dict[str, Any]] = [] + + def update(self, **kwargs: Any) -> "SpanRecorder": + self.updates.append(kwargs) + return self + + def end(self) -> "SpanRecorder": + self.ended += 1 + return self + + +def _finished_spans_by_name(memory_exporter: Any, name: str) -> list[Any]: return [span for span in memory_exporter.get_finished_spans() if span.name == name] def test_sync_generator_preserves_context_without_output_capture( - langfuse_memory_client, memory_exporter -): + langfuse_memory_client: Any, memory_exporter: Any +) -> None: @observe(name="child_step") def child_step(index: int) -> str: return f"item_{index}" @observe(name="root", capture_output=False) - def root(): - def body(): + def root() -> Generator[str, None, None]: + def body() -> Generator[str, None, None]: for index in range(2): yield child_step(index) @@ -30,7 +52,7 @@ def body(): assert memory_exporter.get_finished_spans() == [] assert list(generator) == ["item_0", "item_1"] - assert generator.items == [] + assert cast(Any, generator).items == [] langfuse_memory_client.flush() @@ -51,22 +73,22 @@ def body(): @pytest.mark.asyncio @pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") async def test_streaming_response_preserves_context_without_output_capture( - langfuse_memory_client, memory_exporter -): + langfuse_memory_client: Any, memory_exporter: Any +) -> None: class StreamingResponse: - def __init__(self, body_iterator): + def __init__(self, body_iterator: AsyncGenerator[str, None]) -> None: self.body_iterator = body_iterator @observe(name="stream_step") async def stream_step(index: int) -> str: return f"chunk_{index}" - async def body(): + async def body() -> AsyncGenerator[str, None]: for index in range(2): yield await stream_step(index) @observe(name="endpoint", capture_output=False) - async def endpoint(): + async def endpoint() -> StreamingResponse: return StreamingResponse(body()) response = await endpoint() @@ -74,7 +96,7 @@ async def endpoint(): assert memory_exporter.get_finished_spans() == [] assert [item async for item in response.body_iterator] == ["chunk_0", "chunk_1"] - assert response.body_iterator.items == [] + assert cast(Any, response.body_iterator).items == [] langfuse_memory_client.flush() @@ -90,3 +112,204 @@ async def endpoint(): step.context.trace_id == endpoint_span.context.trace_id for step in step_spans ) assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in endpoint_span.attributes + + +def test_sync_generator_wrapper_close_ends_span_without_exhaustion() -> None: + def generator() -> Generator[str, None, None]: + yield "item_0" + yield "item_1" + + span = SpanRecorder() + wrapper = _ContextPreservedSyncGeneratorWrapper( + generator(), + contextvars.copy_context(), + cast(Any, span), + False, + None, + ) + + assert next(wrapper) == "item_0" + + wrapper.close() + wrapper.close() + + assert span.ended == 1 + assert span.updates == [] + + +def test_sync_generator_wrapper_close_preserves_context() -> None: + marker = contextvars.ContextVar("marker", default="ambient") + seen: list[str] = [] + + def generator() -> Generator[str, None, None]: + try: + yield "item_0" + yield "item_1" + finally: + seen.append(marker.get()) + + span = SpanRecorder() + context = contextvars.copy_context() + context.run(marker.set, "preserved") + wrapper = _ContextPreservedSyncGeneratorWrapper( + generator(), + context, + cast(Any, span), + False, + None, + ) + + assert next(wrapper) == "item_0" + marker.set("ambient-now") + + wrapper.close() + + assert seen == ["preserved"] + assert span.ended == 1 + + +def test_sync_generator_wrapper_del_ends_span_when_abandoned() -> None: + def generator() -> Generator[str, None, None]: + yield "item_0" + yield "item_1" + + span = SpanRecorder() + wrapper = _ContextPreservedSyncGeneratorWrapper( + generator(), + contextvars.copy_context(), + cast(Any, span), + False, + None, + ) + + assert next(wrapper) == "item_0" + + del wrapper + gc.collect() + + assert span.ended == 1 + assert span.updates == [] + + +@pytest.mark.asyncio +async def test_async_generator_wrapper_aclose_ends_span_without_exhaustion() -> None: + async def generator() -> AsyncGenerator[str, None]: + yield "item_0" + yield "item_1" + + span = SpanRecorder() + wrapper = _ContextPreservedAsyncGeneratorWrapper( + generator(), + contextvars.copy_context(), + cast(Any, span), + False, + None, + ) + + assert await wrapper.__anext__() == "item_0" + + await wrapper.aclose() + await wrapper.close() + + assert span.ended == 1 + assert span.updates == [] + + +@pytest.mark.asyncio +async def test_async_generator_wrapper_aclose_preserves_context() -> None: + marker = contextvars.ContextVar("marker", default="ambient") + seen: list[str] = [] + + async def generator() -> AsyncGenerator[str, None]: + try: + yield "item_0" + yield "item_1" + finally: + seen.append(marker.get()) + + span = SpanRecorder() + context = contextvars.copy_context() + context.run(marker.set, "preserved") + wrapper = _ContextPreservedAsyncGeneratorWrapper( + generator(), + context, + cast(Any, span), + False, + None, + ) + + assert await wrapper.__anext__() == "item_0" + marker.set("ambient-now") + + await wrapper.aclose() + + assert seen == ["preserved"] + assert span.ended == 1 + + +@pytest.mark.asyncio +async def test_async_generator_wrapper_fallback_preserves_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + marker = contextvars.ContextVar("marker", default="ambient") + seen: list[str] = [] + original_create_task = asyncio.create_task + + def create_task_with_type_error(*args: Any, **kwargs: Any) -> asyncio.Task[Any]: + if "context" in kwargs: + raise TypeError("context argument unsupported") + + return original_create_task(*args, **kwargs) + + monkeypatch.setattr(asyncio, "create_task", create_task_with_type_error) + + async def generator() -> AsyncGenerator[str, None]: + try: + yield marker.get() + yield "item_1" + finally: + seen.append(marker.get()) + + span = SpanRecorder() + context = contextvars.copy_context() + context.run(marker.set, "preserved") + wrapper = _ContextPreservedAsyncGeneratorWrapper( + generator(), + context, + cast(Any, span), + False, + None, + ) + + assert await wrapper.__anext__() == "preserved" + marker.set("ambient-now") + + await wrapper.aclose() + + assert seen == ["preserved"] + assert span.ended == 1 + + +@pytest.mark.asyncio +async def test_async_generator_wrapper_del_ends_span_when_abandoned() -> None: + async def generator() -> AsyncGenerator[str, None]: + yield "item_0" + yield "item_1" + + span = SpanRecorder() + wrapper = _ContextPreservedAsyncGeneratorWrapper( + generator(), + contextvars.copy_context(), + cast(Any, span), + False, + None, + ) + + assert await wrapper.__anext__() == "item_0" + + del wrapper + gc.collect() + await asyncio.sleep(0) + + assert span.ended == 1 + assert span.updates == []