@@ -137,6 +137,37 @@ def generator() -> Generator[str, None, None]:
137137 assert span .updates == []
138138
139139
140+ def test_sync_generator_wrapper_close_preserves_context () -> None :
141+ marker = contextvars .ContextVar ("marker" , default = "ambient" )
142+ seen : list [str ] = []
143+
144+ def generator () -> Generator [str , None , None ]:
145+ try :
146+ yield "item_0"
147+ yield "item_1"
148+ finally :
149+ seen .append (marker .get ())
150+
151+ span = SpanRecorder ()
152+ context = contextvars .copy_context ()
153+ context .run (marker .set , "preserved" )
154+ wrapper = _ContextPreservedSyncGeneratorWrapper (
155+ generator (),
156+ context ,
157+ cast (Any , span ),
158+ False ,
159+ None ,
160+ )
161+
162+ assert next (wrapper ) == "item_0"
163+ marker .set ("ambient-now" )
164+
165+ wrapper .close ()
166+
167+ assert seen == ["preserved" ]
168+ assert span .ended == 1
169+
170+
140171def test_sync_generator_wrapper_del_ends_span_when_abandoned () -> None :
141172 def generator () -> Generator [str , None , None ]:
142173 yield "item_0"
@@ -184,6 +215,38 @@ async def generator() -> AsyncGenerator[str, None]:
184215 assert span .updates == []
185216
186217
218+ @pytest .mark .asyncio
219+ async def test_async_generator_wrapper_aclose_preserves_context () -> None :
220+ marker = contextvars .ContextVar ("marker" , default = "ambient" )
221+ seen : list [str ] = []
222+
223+ async def generator () -> AsyncGenerator [str , None ]:
224+ try :
225+ yield "item_0"
226+ yield "item_1"
227+ finally :
228+ seen .append (marker .get ())
229+
230+ span = SpanRecorder ()
231+ context = contextvars .copy_context ()
232+ context .run (marker .set , "preserved" )
233+ wrapper = _ContextPreservedAsyncGeneratorWrapper (
234+ generator (),
235+ context ,
236+ cast (Any , span ),
237+ False ,
238+ None ,
239+ )
240+
241+ assert await wrapper .__anext__ () == "item_0"
242+ marker .set ("ambient-now" )
243+
244+ await wrapper .aclose ()
245+
246+ assert seen == ["preserved" ]
247+ assert span .ended == 1
248+
249+
187250@pytest .mark .asyncio
188251async def test_async_generator_wrapper_del_ends_span_when_abandoned () -> None :
189252 async def generator () -> AsyncGenerator [str , None ]:
0 commit comments