diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index c648a0a62..3ada60bc9 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -290,42 +290,15 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: try: result = await func(*args, **kwargs) - - if capture_output is True: - if inspect.isgenerator(result): - is_return_type_generator = True - - return self._wrap_sync_generator_result( - langfuse_span_or_generation, - result, - transform_to_string, - ) - - if inspect.isasyncgen(result): - is_return_type_generator = True - - return self._wrap_async_generator_result( - langfuse_span_or_generation, - result, - transform_to_string, - ) - - # handle starlette.StreamingResponse - if type(result).__name__ == "StreamingResponse" and hasattr( - result, "body_iterator" - ): - is_return_type_generator = True - - result.body_iterator = ( - self._wrap_async_generator_result( - langfuse_span_or_generation, - result.body_iterator, - transform_to_string, - ) - ) - - langfuse_span_or_generation.update(output=result) - + ( + is_return_type_generator, + result, + ) = self._handle_observe_result( + langfuse_span_or_generation, + result, + capture_output=capture_output, + transform_to_string=transform_to_string, + ) return result except (Exception, asyncio.CancelledError) as e: langfuse_span_or_generation.update( @@ -408,42 +381,15 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: try: result = func(*args, **kwargs) - - if capture_output is True: - if inspect.isgenerator(result): - is_return_type_generator = True - - return self._wrap_sync_generator_result( - langfuse_span_or_generation, - result, - transform_to_string, - ) - - if inspect.isasyncgen(result): - is_return_type_generator = True - - return self._wrap_async_generator_result( - langfuse_span_or_generation, - result, - transform_to_string, - ) - - # handle starlette.StreamingResponse - if type(result).__name__ == "StreamingResponse" and hasattr( - result, "body_iterator" - ): - is_return_type_generator = True - - result.body_iterator = ( - self._wrap_async_generator_result( - langfuse_span_or_generation, - result.body_iterator, - transform_to_string, - ) - ) - - langfuse_span_or_generation.update(output=result) - + ( + is_return_type_generator, + result, + ) = self._handle_observe_result( + langfuse_span_or_generation, + result, + capture_output=capture_output, + transform_to_string=transform_to_string, + ) return result except (Exception, asyncio.CancelledError) as e: langfuse_span_or_generation.update( @@ -493,6 +439,7 @@ def _wrap_sync_generator_result( LangfuseGuardrail, ], generator: Generator, + capture_output: bool, transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> Any: preserved_context = contextvars.copy_context() @@ -501,6 +448,7 @@ def _wrap_sync_generator_result( generator, preserved_context, langfuse_span_or_generation, + capture_output, transform_to_string, ) @@ -518,6 +466,7 @@ def _wrap_async_generator_result( LangfuseGuardrail, ], generator: AsyncGenerator, + capture_output: bool, transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> Any: preserved_context = contextvars.copy_context() @@ -526,9 +475,61 @@ def _wrap_async_generator_result( generator, preserved_context, langfuse_span_or_generation, + capture_output, transform_to_string, ) + def _handle_observe_result( + self, + langfuse_span_or_generation: Union[ + LangfuseSpan, + LangfuseGeneration, + LangfuseAgent, + LangfuseTool, + LangfuseChain, + LangfuseRetriever, + LangfuseEvaluator, + LangfuseEmbedding, + LangfuseGuardrail, + ], + result: Any, + *, + capture_output: bool, + transform_to_string: Optional[Callable[[Iterable], str]] = None, + ) -> Tuple[bool, Any]: + if inspect.isgenerator(result): + return True, self._wrap_sync_generator_result( + langfuse_span_or_generation, + result, + capture_output, + transform_to_string, + ) + + if inspect.isasyncgen(result): + return True, self._wrap_async_generator_result( + langfuse_span_or_generation, + result, + capture_output, + transform_to_string, + ) + + # handle starlette.StreamingResponse + if type(result).__name__ == "StreamingResponse" and hasattr( + result, "body_iterator" + ): + result.body_iterator = self._wrap_async_generator_result( + langfuse_span_or_generation, + result.body_iterator, + capture_output, + transform_to_string, + ) + return True, result + + if capture_output is True: + langfuse_span_or_generation.update(output=result) + + return False, result + _decorator = LangfuseDecorator() @@ -553,12 +554,14 @@ def __init__( LangfuseEmbedding, LangfuseGuardrail, ], + capture_output: bool, transform_fn: Optional[Callable[[Iterable], str]], ) -> None: self.generator = generator self.context = context self.items: List[Any] = [] self.span = span + self.capture_output = capture_output self.transform_fn = transform_fn def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper": @@ -568,21 +571,25 @@ def __next__(self) -> Any: try: # Run the generator's __next__ in the preserved context item = self.context.run(next, self.generator) - self.items.append(item) + if self.capture_output: + self.items.append(item) return item except StopIteration: # Handle output and span cleanup when generator is exhausted - output: Any = self.items + if self.capture_output: + output: Any = self.items + + if self.transform_fn is not None: + output = self.transform_fn(self.items) - if self.transform_fn is not None: - output = self.transform_fn(self.items) + elif all(isinstance(item, str) for item in self.items): + output = "".join(self.items) - elif all(isinstance(item, str) for item in self.items): - output = "".join(self.items) + self.span.update(output=output) - self.span.update(output=output).end() + self.span.end() raise # Re-raise StopIteration @@ -612,12 +619,14 @@ def __init__( LangfuseEmbedding, LangfuseGuardrail, ], + capture_output: bool, transform_fn: Optional[Callable[[Iterable], str]], ) -> None: self.generator = generator self.context = context self.items: List[Any] = [] self.span = span + self.capture_output = capture_output self.transform_fn = transform_fn def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper": @@ -636,21 +645,25 @@ async def __anext__(self) -> Any: # Python < 3.10 fallback - context parameter not supported item = await self.generator.__anext__() - self.items.append(item) + if self.capture_output: + self.items.append(item) return item except StopAsyncIteration: # Handle output and span cleanup when generator is exhausted - output: Any = self.items + if self.capture_output: + output: Any = self.items + + if self.transform_fn is not None: + output = self.transform_fn(self.items) - if self.transform_fn is not None: - output = self.transform_fn(self.items) + elif all(isinstance(item, str) for item in self.items): + output = "".join(self.items) - elif all(isinstance(item, str) for item in self.items): - output = "".join(self.items) + self.span.update(output=output) - self.span.update(output=output).end() + self.span.end() raise # Re-raise StopAsyncIteration except (Exception, asyncio.CancelledError) as e: diff --git a/tests/unit/test_observe.py b/tests/unit/test_observe.py new file mode 100644 index 000000000..94a2cbb83 --- /dev/null +++ b/tests/unit/test_observe.py @@ -0,0 +1,92 @@ +import sys + +import pytest + +from langfuse import observe +from langfuse._client.attributes import LangfuseOtelSpanAttributes + + +def _finished_spans_by_name(memory_exporter, name: str): + return [span for span in memory_exporter.get_finished_spans() if span.name == name] + + +def test_sync_generator_preserves_context_without_output_capture( + langfuse_memory_client, memory_exporter +): + @observe(name="child_step") + def child_step(index: int) -> str: + return f"item_{index}" + + @observe(name="root", capture_output=False) + def root(): + def body(): + for index in range(2): + yield child_step(index) + + return body() + + generator = root() + + assert memory_exporter.get_finished_spans() == [] + + assert list(generator) == ["item_0", "item_1"] + assert generator.items == [] + + langfuse_memory_client.flush() + + root_span = _finished_spans_by_name(memory_exporter, "root")[0] + child_spans = _finished_spans_by_name(memory_exporter, "child_step") + + assert len(child_spans) == 2 + assert all(child.parent is not None for child in child_spans) + assert all( + child.parent.span_id == root_span.context.span_id for child in child_spans + ) + assert all( + child.context.trace_id == root_span.context.trace_id for child in child_spans + ) + assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in root_span.attributes + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") +async def test_streaming_response_preserves_context_without_output_capture( + langfuse_memory_client, memory_exporter +): + class StreamingResponse: + def __init__(self, body_iterator): + self.body_iterator = body_iterator + + @observe(name="stream_step") + async def stream_step(index: int) -> str: + return f"chunk_{index}" + + async def body(): + for index in range(2): + yield await stream_step(index) + + @observe(name="endpoint", capture_output=False) + async def endpoint(): + return StreamingResponse(body()) + + response = await endpoint() + + assert memory_exporter.get_finished_spans() == [] + + assert [item async for item in response.body_iterator] == ["chunk_0", "chunk_1"] + assert response.body_iterator.items == [] + + langfuse_memory_client.flush() + + endpoint_span = _finished_spans_by_name(memory_exporter, "endpoint")[0] + step_spans = _finished_spans_by_name(memory_exporter, "stream_step") + + assert len(step_spans) == 2 + assert all(step.parent is not None for step in step_spans) + assert all( + step.parent.span_id == endpoint_span.context.span_id for step in step_spans + ) + assert all( + step.context.trace_id == endpoint_span.context.trace_id for step in step_spans + ) + assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in endpoint_span.attributes