Skip to content

Commit 75eec09

Browse files
committed
fix(observe): preserve streaming context without output capture
1 parent cfbe7a3 commit 75eec09

File tree

2 files changed

+185
-84
lines changed

2 files changed

+185
-84
lines changed

langfuse/_client/observe.py

Lines changed: 95 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -290,42 +290,15 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
290290

291291
try:
292292
result = await func(*args, **kwargs)
293-
294-
if capture_output is True:
295-
if inspect.isgenerator(result):
296-
is_return_type_generator = True
297-
298-
return self._wrap_sync_generator_result(
299-
langfuse_span_or_generation,
300-
result,
301-
transform_to_string,
302-
)
303-
304-
if inspect.isasyncgen(result):
305-
is_return_type_generator = True
306-
307-
return self._wrap_async_generator_result(
308-
langfuse_span_or_generation,
309-
result,
310-
transform_to_string,
311-
)
312-
313-
# handle starlette.StreamingResponse
314-
if type(result).__name__ == "StreamingResponse" and hasattr(
315-
result, "body_iterator"
316-
):
317-
is_return_type_generator = True
318-
319-
result.body_iterator = (
320-
self._wrap_async_generator_result(
321-
langfuse_span_or_generation,
322-
result.body_iterator,
323-
transform_to_string,
324-
)
325-
)
326-
327-
langfuse_span_or_generation.update(output=result)
328-
293+
(
294+
is_return_type_generator,
295+
result,
296+
) = self._handle_observe_result(
297+
langfuse_span_or_generation,
298+
result,
299+
capture_output=capture_output,
300+
transform_to_string=transform_to_string,
301+
)
329302
return result
330303
except (Exception, asyncio.CancelledError) as e:
331304
langfuse_span_or_generation.update(
@@ -408,42 +381,15 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
408381

409382
try:
410383
result = func(*args, **kwargs)
411-
412-
if capture_output is True:
413-
if inspect.isgenerator(result):
414-
is_return_type_generator = True
415-
416-
return self._wrap_sync_generator_result(
417-
langfuse_span_or_generation,
418-
result,
419-
transform_to_string,
420-
)
421-
422-
if inspect.isasyncgen(result):
423-
is_return_type_generator = True
424-
425-
return self._wrap_async_generator_result(
426-
langfuse_span_or_generation,
427-
result,
428-
transform_to_string,
429-
)
430-
431-
# handle starlette.StreamingResponse
432-
if type(result).__name__ == "StreamingResponse" and hasattr(
433-
result, "body_iterator"
434-
):
435-
is_return_type_generator = True
436-
437-
result.body_iterator = (
438-
self._wrap_async_generator_result(
439-
langfuse_span_or_generation,
440-
result.body_iterator,
441-
transform_to_string,
442-
)
443-
)
444-
445-
langfuse_span_or_generation.update(output=result)
446-
384+
(
385+
is_return_type_generator,
386+
result,
387+
) = self._handle_observe_result(
388+
langfuse_span_or_generation,
389+
result,
390+
capture_output=capture_output,
391+
transform_to_string=transform_to_string,
392+
)
447393
return result
448394
except (Exception, asyncio.CancelledError) as e:
449395
langfuse_span_or_generation.update(
@@ -493,6 +439,7 @@ def _wrap_sync_generator_result(
493439
LangfuseGuardrail,
494440
],
495441
generator: Generator,
442+
capture_output: bool,
496443
transform_to_string: Optional[Callable[[Iterable], str]] = None,
497444
) -> Any:
498445
preserved_context = contextvars.copy_context()
@@ -501,6 +448,7 @@ def _wrap_sync_generator_result(
501448
generator,
502449
preserved_context,
503450
langfuse_span_or_generation,
451+
capture_output,
504452
transform_to_string,
505453
)
506454

@@ -518,6 +466,7 @@ def _wrap_async_generator_result(
518466
LangfuseGuardrail,
519467
],
520468
generator: AsyncGenerator,
469+
capture_output: bool,
521470
transform_to_string: Optional[Callable[[Iterable], str]] = None,
522471
) -> Any:
523472
preserved_context = contextvars.copy_context()
@@ -526,9 +475,61 @@ def _wrap_async_generator_result(
526475
generator,
527476
preserved_context,
528477
langfuse_span_or_generation,
478+
capture_output,
529479
transform_to_string,
530480
)
531481

482+
def _handle_observe_result(
483+
self,
484+
langfuse_span_or_generation: Union[
485+
LangfuseSpan,
486+
LangfuseGeneration,
487+
LangfuseAgent,
488+
LangfuseTool,
489+
LangfuseChain,
490+
LangfuseRetriever,
491+
LangfuseEvaluator,
492+
LangfuseEmbedding,
493+
LangfuseGuardrail,
494+
],
495+
result: Any,
496+
*,
497+
capture_output: bool,
498+
transform_to_string: Optional[Callable[[Iterable], str]] = None,
499+
) -> Tuple[bool, Any]:
500+
if inspect.isgenerator(result):
501+
return True, self._wrap_sync_generator_result(
502+
langfuse_span_or_generation,
503+
result,
504+
capture_output,
505+
transform_to_string,
506+
)
507+
508+
if inspect.isasyncgen(result):
509+
return True, self._wrap_async_generator_result(
510+
langfuse_span_or_generation,
511+
result,
512+
capture_output,
513+
transform_to_string,
514+
)
515+
516+
# handle starlette.StreamingResponse
517+
if type(result).__name__ == "StreamingResponse" and hasattr(
518+
result, "body_iterator"
519+
):
520+
result.body_iterator = self._wrap_async_generator_result(
521+
langfuse_span_or_generation,
522+
result.body_iterator,
523+
capture_output,
524+
transform_to_string,
525+
)
526+
return True, result
527+
528+
if capture_output is True:
529+
langfuse_span_or_generation.update(output=result)
530+
531+
return False, result
532+
532533

533534
_decorator = LangfuseDecorator()
534535

@@ -553,12 +554,14 @@ def __init__(
553554
LangfuseEmbedding,
554555
LangfuseGuardrail,
555556
],
557+
capture_output: bool,
556558
transform_fn: Optional[Callable[[Iterable], str]],
557559
) -> None:
558560
self.generator = generator
559561
self.context = context
560562
self.items: List[Any] = []
561563
self.span = span
564+
self.capture_output = capture_output
562565
self.transform_fn = transform_fn
563566

564567
def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
@@ -574,15 +577,18 @@ def __next__(self) -> Any:
574577

575578
except StopIteration:
576579
# Handle output and span cleanup when generator is exhausted
577-
output: Any = self.items
580+
if self.capture_output:
581+
output: Any = self.items
582+
583+
if self.transform_fn is not None:
584+
output = self.transform_fn(self.items)
578585

579-
if self.transform_fn is not None:
580-
output = self.transform_fn(self.items)
586+
elif all(isinstance(item, str) for item in self.items):
587+
output = "".join(self.items)
581588

582-
elif all(isinstance(item, str) for item in self.items):
583-
output = "".join(self.items)
589+
self.span.update(output=output)
584590

585-
self.span.update(output=output).end()
591+
self.span.end()
586592

587593
raise # Re-raise StopIteration
588594

@@ -612,12 +618,14 @@ def __init__(
612618
LangfuseEmbedding,
613619
LangfuseGuardrail,
614620
],
621+
capture_output: bool,
615622
transform_fn: Optional[Callable[[Iterable], str]],
616623
) -> None:
617624
self.generator = generator
618625
self.context = context
619626
self.items: List[Any] = []
620627
self.span = span
628+
self.capture_output = capture_output
621629
self.transform_fn = transform_fn
622630

623631
def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
@@ -642,15 +650,18 @@ async def __anext__(self) -> Any:
642650

643651
except StopAsyncIteration:
644652
# Handle output and span cleanup when generator is exhausted
645-
output: Any = self.items
653+
if self.capture_output:
654+
output: Any = self.items
655+
656+
if self.transform_fn is not None:
657+
output = self.transform_fn(self.items)
646658

647-
if self.transform_fn is not None:
648-
output = self.transform_fn(self.items)
659+
elif all(isinstance(item, str) for item in self.items):
660+
output = "".join(self.items)
649661

650-
elif all(isinstance(item, str) for item in self.items):
651-
output = "".join(self.items)
662+
self.span.update(output=output)
652663

653-
self.span.update(output=output).end()
664+
self.span.end()
654665

655666
raise # Re-raise StopAsyncIteration
656667
except (Exception, asyncio.CancelledError) as e:

tests/unit/test_observe.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import sys
2+
3+
import pytest
4+
5+
from langfuse import observe
6+
from langfuse._client.attributes import LangfuseOtelSpanAttributes
7+
8+
9+
def _finished_spans_by_name(memory_exporter, name: str):
10+
return [span for span in memory_exporter.get_finished_spans() if span.name == name]
11+
12+
13+
def test_sync_generator_preserves_context_without_output_capture(
14+
langfuse_memory_client, memory_exporter
15+
):
16+
@observe(name="child_step")
17+
def child_step(index: int) -> str:
18+
return f"item_{index}"
19+
20+
@observe(name="root", capture_output=False)
21+
def root():
22+
def body():
23+
for index in range(2):
24+
yield child_step(index)
25+
26+
return body()
27+
28+
generator = root()
29+
30+
assert memory_exporter.get_finished_spans() == []
31+
32+
assert list(generator) == ["item_0", "item_1"]
33+
34+
langfuse_memory_client.flush()
35+
36+
root_span = _finished_spans_by_name(memory_exporter, "root")[0]
37+
child_spans = _finished_spans_by_name(memory_exporter, "child_step")
38+
39+
assert len(child_spans) == 2
40+
assert all(child.parent is not None for child in child_spans)
41+
assert all(
42+
child.parent.span_id == root_span.context.span_id for child in child_spans
43+
)
44+
assert all(
45+
child.context.trace_id == root_span.context.trace_id for child in child_spans
46+
)
47+
assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in root_span.attributes
48+
49+
50+
@pytest.mark.asyncio
51+
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher")
52+
async def test_streaming_response_preserves_context_without_output_capture(
53+
langfuse_memory_client, memory_exporter
54+
):
55+
class StreamingResponse:
56+
def __init__(self, body_iterator):
57+
self.body_iterator = body_iterator
58+
59+
@observe(name="stream_step")
60+
async def stream_step(index: int) -> str:
61+
return f"chunk_{index}"
62+
63+
async def body():
64+
for index in range(2):
65+
yield await stream_step(index)
66+
67+
@observe(name="endpoint", capture_output=False)
68+
async def endpoint():
69+
return StreamingResponse(body())
70+
71+
response = await endpoint()
72+
73+
assert memory_exporter.get_finished_spans() == []
74+
75+
assert [item async for item in response.body_iterator] == ["chunk_0", "chunk_1"]
76+
77+
langfuse_memory_client.flush()
78+
79+
endpoint_span = _finished_spans_by_name(memory_exporter, "endpoint")[0]
80+
step_spans = _finished_spans_by_name(memory_exporter, "stream_step")
81+
82+
assert len(step_spans) == 2
83+
assert all(step.parent is not None for step in step_spans)
84+
assert all(
85+
step.parent.span_id == endpoint_span.context.span_id for step in step_spans
86+
)
87+
assert all(
88+
step.context.trace_id == endpoint_span.context.trace_id for step in step_spans
89+
)
90+
assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in endpoint_span.attributes

0 commit comments

Comments (0)