1+ import asyncio
2+ import contextvars
3+ import gc
14import sys
5+ from typing import Any , AsyncGenerator , Generator , cast
26
37import pytest
48
59from langfuse import observe
610from langfuse ._client .attributes import LangfuseOtelSpanAttributes
11+ from langfuse ._client .observe import (
12+ _ContextPreservedAsyncGeneratorWrapper ,
13+ _ContextPreservedSyncGeneratorWrapper ,
14+ )
715
816
9- def _finished_spans_by_name (memory_exporter , name : str ):
17+ class SpanRecorder :
18+ def __init__ (self ) -> None :
19+ self .ended = 0
20+ self .updates : list [dict [str , Any ]] = []
21+
22+ def update (self , ** kwargs : Any ) -> "SpanRecorder" :
23+ self .updates .append (kwargs )
24+ return self
25+
26+ def end (self ) -> "SpanRecorder" :
27+ self .ended += 1
28+ return self
29+
30+
31+ def _finished_spans_by_name (memory_exporter : Any , name : str ) -> list [Any ]:
1032 return [span for span in memory_exporter .get_finished_spans () if span .name == name ]
1133
1234
1335def test_sync_generator_preserves_context_without_output_capture (
14- langfuse_memory_client , memory_exporter
15- ):
36+ langfuse_memory_client : Any , memory_exporter : Any
37+ ) -> None :
1638 @observe (name = "child_step" )
1739 def child_step (index : int ) -> str :
1840 return f"item_{ index } "
1941
2042 @observe (name = "root" , capture_output = False )
21- def root ():
22- def body ():
43+ def root () -> Generator [ str , None , None ] :
44+ def body () -> Generator [ str , None , None ] :
2345 for index in range (2 ):
2446 yield child_step (index )
2547
@@ -30,7 +52,7 @@ def body():
3052 assert memory_exporter .get_finished_spans () == []
3153
3254 assert list (generator ) == ["item_0" , "item_1" ]
33- assert generator .items == []
55+ assert cast ( Any , generator ) .items == []
3456
3557 langfuse_memory_client .flush ()
3658
@@ -51,30 +73,30 @@ def body():
5173@pytest .mark .asyncio
5274@pytest .mark .skipif (sys .version_info < (3 , 11 ), reason = "requires python3.11 or higher" )
5375async def test_streaming_response_preserves_context_without_output_capture (
54- langfuse_memory_client , memory_exporter
55- ):
76+ langfuse_memory_client : Any , memory_exporter : Any
77+ ) -> None :
5678 class StreamingResponse :
57- def __init__ (self , body_iterator ) :
79+ def __init__ (self , body_iterator : AsyncGenerator [ str , None ]) -> None :
5880 self .body_iterator = body_iterator
5981
6082 @observe (name = "stream_step" )
6183 async def stream_step (index : int ) -> str :
6284 return f"chunk_{ index } "
6385
64- async def body ():
86+ async def body () -> AsyncGenerator [ str , None ] :
6587 for index in range (2 ):
6688 yield await stream_step (index )
6789
6890 @observe (name = "endpoint" , capture_output = False )
69- async def endpoint ():
91+ async def endpoint () -> StreamingResponse :
7092 return StreamingResponse (body ())
7193
7294 response = await endpoint ()
7395
7496 assert memory_exporter .get_finished_spans () == []
7597
7698 assert [item async for item in response .body_iterator ] == ["chunk_0" , "chunk_1" ]
77- assert response .body_iterator .items == []
99+ assert cast ( Any , response .body_iterator ) .items == []
78100
79101 langfuse_memory_client .flush ()
80102
@@ -90,3 +112,98 @@ async def endpoint():
90112 step .context .trace_id == endpoint_span .context .trace_id for step in step_spans
91113 )
92114 assert LangfuseOtelSpanAttributes .OBSERVATION_OUTPUT not in endpoint_span .attributes
115+
116+
117+ def test_sync_generator_wrapper_close_ends_span_without_exhaustion () -> None :
118+ def generator () -> Generator [str , None , None ]:
119+ yield "item_0"
120+ yield "item_1"
121+
122+ span = SpanRecorder ()
123+ wrapper = _ContextPreservedSyncGeneratorWrapper (
124+ generator (),
125+ contextvars .copy_context (),
126+ cast (Any , span ),
127+ False ,
128+ None ,
129+ )
130+
131+ assert next (wrapper ) == "item_0"
132+
133+ wrapper .close ()
134+ wrapper .close ()
135+
136+ assert span .ended == 1
137+ assert span .updates == []
138+
139+
140+ def test_sync_generator_wrapper_del_ends_span_when_abandoned () -> None :
141+ def generator () -> Generator [str , None , None ]:
142+ yield "item_0"
143+ yield "item_1"
144+
145+ span = SpanRecorder ()
146+ wrapper = _ContextPreservedSyncGeneratorWrapper (
147+ generator (),
148+ contextvars .copy_context (),
149+ cast (Any , span ),
150+ False ,
151+ None ,
152+ )
153+
154+ assert next (wrapper ) == "item_0"
155+
156+ del wrapper
157+ gc .collect ()
158+
159+ assert span .ended == 1
160+ assert span .updates == []
161+
162+
163+ @pytest .mark .asyncio
164+ async def test_async_generator_wrapper_aclose_ends_span_without_exhaustion () -> None :
165+ async def generator () -> AsyncGenerator [str , None ]:
166+ yield "item_0"
167+ yield "item_1"
168+
169+ span = SpanRecorder ()
170+ wrapper = _ContextPreservedAsyncGeneratorWrapper (
171+ generator (),
172+ contextvars .copy_context (),
173+ cast (Any , span ),
174+ False ,
175+ None ,
176+ )
177+
178+ assert await wrapper .__anext__ () == "item_0"
179+
180+ await wrapper .aclose ()
181+ await wrapper .close ()
182+
183+ assert span .ended == 1
184+ assert span .updates == []
185+
186+
187+ @pytest .mark .asyncio
188+ async def test_async_generator_wrapper_del_ends_span_when_abandoned () -> None :
189+ async def generator () -> AsyncGenerator [str , None ]:
190+ yield "item_0"
191+ yield "item_1"
192+
193+ span = SpanRecorder ()
194+ wrapper = _ContextPreservedAsyncGeneratorWrapper (
195+ generator (),
196+ contextvars .copy_context (),
197+ cast (Any , span ),
198+ False ,
199+ None ,
200+ )
201+
202+ assert await wrapper .__anext__ () == "item_0"
203+
204+ del wrapper
205+ gc .collect ()
206+ await asyncio .sleep (0 )
207+
208+ assert span .ended == 1
209+ assert span .updates == []
0 commit comments