Skip to content

Commit 2f60526

Browse files
committed
fix(observe): finalize abandoned generator wrappers
1 parent 6af775c commit 2f60526

File tree

2 files changed

+225
-48
lines changed

2 files changed

+225
-48
lines changed

langfuse/_client/observe.py

Lines changed: 96 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -563,10 +563,56 @@ def __init__(
563563
self.span = span
564564
self.capture_output = capture_output
565565
self.transform_fn = transform_fn
566+
self._span_ended = False
566567

567568
def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
568569
return self
569570

571+
def _finalize(self) -> None:
572+
if self._span_ended:
573+
return
574+
575+
if self.capture_output:
576+
output: Any = self.items
577+
578+
if self.transform_fn is not None:
579+
output = self.transform_fn(self.items)
580+
581+
elif all(isinstance(item, str) for item in self.items):
582+
output = "".join(self.items)
583+
584+
self.span.update(output=output)
585+
586+
self.span.end()
587+
self._span_ended = True
588+
589+
def _finalize_with_error(self, error: BaseException) -> None:
590+
if self._span_ended:
591+
return
592+
593+
self.span.update(
594+
level="ERROR", status_message=str(error) or type(error).__name__
595+
).end()
596+
self._span_ended = True
597+
598+
def close(self) -> None:
599+
if self._span_ended:
600+
return
601+
602+
try:
603+
self.generator.close()
604+
except (Exception, asyncio.CancelledError) as error:
605+
self._finalize_with_error(error)
606+
raise
607+
else:
608+
self._finalize()
609+
610+
def __del__(self) -> None:
611+
try:
612+
self.close()
613+
except BaseException:
614+
pass
615+
570616
def __next__(self) -> Any:
571617
try:
572618
# Run the generator's __next__ in the preserved context
@@ -577,27 +623,11 @@ def __next__(self) -> Any:
577623
return item
578624

579625
except StopIteration:
580-
# Handle output and span cleanup when generator is exhausted
581-
if self.capture_output:
582-
output: Any = self.items
583-
584-
if self.transform_fn is not None:
585-
output = self.transform_fn(self.items)
586-
587-
elif all(isinstance(item, str) for item in self.items):
588-
output = "".join(self.items)
589-
590-
self.span.update(output=output)
591-
592-
self.span.end()
593-
626+
self._finalize()
594627
raise # Re-raise StopIteration
595628

596629
except (Exception, asyncio.CancelledError) as e:
597-
self.span.update(
598-
level="ERROR", status_message=str(e) or type(e).__name__
599-
).end()
600-
630+
self._finalize_with_error(e)
601631
raise
602632

603633

@@ -628,10 +658,56 @@ def __init__(
628658
self.span = span
629659
self.capture_output = capture_output
630660
self.transform_fn = transform_fn
661+
self._span_ended = False
631662

632663
def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
633664
return self
634665

666+
def _finalize(self) -> None:
667+
if self._span_ended:
668+
return
669+
670+
if self.capture_output:
671+
output: Any = self.items
672+
673+
if self.transform_fn is not None:
674+
output = self.transform_fn(self.items)
675+
676+
elif all(isinstance(item, str) for item in self.items):
677+
output = "".join(self.items)
678+
679+
self.span.update(output=output)
680+
681+
self.span.end()
682+
self._span_ended = True
683+
684+
def _finalize_with_error(self, error: BaseException) -> None:
685+
if self._span_ended:
686+
return
687+
688+
self.span.update(
689+
level="ERROR", status_message=str(error) or type(error).__name__
690+
).end()
691+
self._span_ended = True
692+
693+
async def aclose(self) -> None:
694+
if self._span_ended:
695+
return
696+
697+
try:
698+
await self.generator.aclose()
699+
except (Exception, asyncio.CancelledError) as error:
700+
self._finalize_with_error(error)
701+
raise
702+
else:
703+
self._finalize()
704+
705+
async def close(self) -> None:
706+
await self.aclose()
707+
708+
def __del__(self) -> None:
709+
self._finalize()
710+
635711
async def __anext__(self) -> Any:
636712
try:
637713
# Run the generator's __anext__ in the preserved context
@@ -651,24 +727,8 @@ async def __anext__(self) -> Any:
651727
return item
652728

653729
except StopAsyncIteration:
654-
# Handle output and span cleanup when generator is exhausted
655-
if self.capture_output:
656-
output: Any = self.items
657-
658-
if self.transform_fn is not None:
659-
output = self.transform_fn(self.items)
660-
661-
elif all(isinstance(item, str) for item in self.items):
662-
output = "".join(self.items)
663-
664-
self.span.update(output=output)
665-
666-
self.span.end()
667-
730+
self._finalize()
668731
raise # Re-raise StopAsyncIteration
669732
except (Exception, asyncio.CancelledError) as e:
670-
self.span.update(
671-
level="ERROR", status_message=str(e) or type(e).__name__
672-
).end()
673-
733+
self._finalize_with_error(e)
674734
raise

tests/unit/test_observe.py

Lines changed: 129 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,47 @@
1+
import asyncio
2+
import contextvars
3+
import gc
14
import sys
5+
from typing import Any, AsyncGenerator, Generator, cast
26

37
import pytest
48

59
from langfuse import observe
610
from langfuse._client.attributes import LangfuseOtelSpanAttributes
11+
from langfuse._client.observe import (
12+
_ContextPreservedAsyncGeneratorWrapper,
13+
_ContextPreservedSyncGeneratorWrapper,
14+
)
715

816

9-
def _finished_spans_by_name(memory_exporter, name: str):
17+
class SpanRecorder:
18+
def __init__(self) -> None:
19+
self.ended = 0
20+
self.updates: list[dict[str, Any]] = []
21+
22+
def update(self, **kwargs: Any) -> "SpanRecorder":
23+
self.updates.append(kwargs)
24+
return self
25+
26+
def end(self) -> "SpanRecorder":
27+
self.ended += 1
28+
return self
29+
30+
31+
def _finished_spans_by_name(memory_exporter: Any, name: str) -> list[Any]:
1032
return [span for span in memory_exporter.get_finished_spans() if span.name == name]
1133

1234

1335
def test_sync_generator_preserves_context_without_output_capture(
14-
langfuse_memory_client, memory_exporter
15-
):
36+
langfuse_memory_client: Any, memory_exporter: Any
37+
) -> None:
1638
@observe(name="child_step")
1739
def child_step(index: int) -> str:
1840
return f"item_{index}"
1941

2042
@observe(name="root", capture_output=False)
21-
def root():
22-
def body():
43+
def root() -> Generator[str, None, None]:
44+
def body() -> Generator[str, None, None]:
2345
for index in range(2):
2446
yield child_step(index)
2547

@@ -30,7 +52,7 @@ def body():
3052
assert memory_exporter.get_finished_spans() == []
3153

3254
assert list(generator) == ["item_0", "item_1"]
33-
assert generator.items == []
55+
assert cast(Any, generator).items == []
3456

3557
langfuse_memory_client.flush()
3658

@@ -51,30 +73,30 @@ def body():
5173
@pytest.mark.asyncio
5274
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher")
5375
async def test_streaming_response_preserves_context_without_output_capture(
54-
langfuse_memory_client, memory_exporter
55-
):
76+
langfuse_memory_client: Any, memory_exporter: Any
77+
) -> None:
5678
class StreamingResponse:
57-
def __init__(self, body_iterator):
79+
def __init__(self, body_iterator: AsyncGenerator[str, None]) -> None:
5880
self.body_iterator = body_iterator
5981

6082
@observe(name="stream_step")
6183
async def stream_step(index: int) -> str:
6284
return f"chunk_{index}"
6385

64-
async def body():
86+
async def body() -> AsyncGenerator[str, None]:
6587
for index in range(2):
6688
yield await stream_step(index)
6789

6890
@observe(name="endpoint", capture_output=False)
69-
async def endpoint():
91+
async def endpoint() -> StreamingResponse:
7092
return StreamingResponse(body())
7193

7294
response = await endpoint()
7395

7496
assert memory_exporter.get_finished_spans() == []
7597

7698
assert [item async for item in response.body_iterator] == ["chunk_0", "chunk_1"]
77-
assert response.body_iterator.items == []
99+
assert cast(Any, response.body_iterator).items == []
78100

79101
langfuse_memory_client.flush()
80102

@@ -90,3 +112,98 @@ async def endpoint():
90112
step.context.trace_id == endpoint_span.context.trace_id for step in step_spans
91113
)
92114
assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in endpoint_span.attributes
115+
116+
117+
def test_sync_generator_wrapper_close_ends_span_without_exhaustion() -> None:
118+
def generator() -> Generator[str, None, None]:
119+
yield "item_0"
120+
yield "item_1"
121+
122+
span = SpanRecorder()
123+
wrapper = _ContextPreservedSyncGeneratorWrapper(
124+
generator(),
125+
contextvars.copy_context(),
126+
cast(Any, span),
127+
False,
128+
None,
129+
)
130+
131+
assert next(wrapper) == "item_0"
132+
133+
wrapper.close()
134+
wrapper.close()
135+
136+
assert span.ended == 1
137+
assert span.updates == []
138+
139+
140+
def test_sync_generator_wrapper_del_ends_span_when_abandoned() -> None:
141+
def generator() -> Generator[str, None, None]:
142+
yield "item_0"
143+
yield "item_1"
144+
145+
span = SpanRecorder()
146+
wrapper = _ContextPreservedSyncGeneratorWrapper(
147+
generator(),
148+
contextvars.copy_context(),
149+
cast(Any, span),
150+
False,
151+
None,
152+
)
153+
154+
assert next(wrapper) == "item_0"
155+
156+
del wrapper
157+
gc.collect()
158+
159+
assert span.ended == 1
160+
assert span.updates == []
161+
162+
163+
@pytest.mark.asyncio
164+
async def test_async_generator_wrapper_aclose_ends_span_without_exhaustion() -> None:
165+
async def generator() -> AsyncGenerator[str, None]:
166+
yield "item_0"
167+
yield "item_1"
168+
169+
span = SpanRecorder()
170+
wrapper = _ContextPreservedAsyncGeneratorWrapper(
171+
generator(),
172+
contextvars.copy_context(),
173+
cast(Any, span),
174+
False,
175+
None,
176+
)
177+
178+
assert await wrapper.__anext__() == "item_0"
179+
180+
await wrapper.aclose()
181+
await wrapper.close()
182+
183+
assert span.ended == 1
184+
assert span.updates == []
185+
186+
187+
@pytest.mark.asyncio
188+
async def test_async_generator_wrapper_del_ends_span_when_abandoned() -> None:
189+
async def generator() -> AsyncGenerator[str, None]:
190+
yield "item_0"
191+
yield "item_1"
192+
193+
span = SpanRecorder()
194+
wrapper = _ContextPreservedAsyncGeneratorWrapper(
195+
generator(),
196+
contextvars.copy_context(),
197+
cast(Any, span),
198+
False,
199+
None,
200+
)
201+
202+
assert await wrapper.__anext__() == "item_0"
203+
204+
del wrapper
205+
gc.collect()
206+
await asyncio.sleep(0)
207+
208+
assert span.ended == 1
209+
assert span.updates == []

0 commit comments

Comments
 (0)