Skip to content

Commit 9ee80d3

Browse files
committed
fix(observe): preserve context when closing wrappers
1 parent c18d0b4 commit 9ee80d3

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

langfuse/_client/observe.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def close(self) -> None:
600600
return
601601

602602
try:
603-
self.generator.close()
603+
self.context.run(self.generator.close)
604604
except (Exception, asyncio.CancelledError) as error:
605605
self._finalize_with_error(error)
606606
raise
@@ -695,7 +695,13 @@ async def aclose(self) -> None:
695695
return
696696

697697
try:
698-
await self.generator.aclose()
698+
try:
699+
await asyncio.create_task(
700+
self.generator.aclose(),
701+
context=self.context,
702+
) # type: ignore
703+
except TypeError:
704+
await self.generator.aclose()
699705
except (Exception, asyncio.CancelledError) as error:
700706
self._finalize_with_error(error)
701707
raise

tests/unit/test_observe.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,37 @@ def generator() -> Generator[str, None, None]:
137137
assert span.updates == []
138138

139139

140+
def test_sync_generator_wrapper_close_preserves_context() -> None:
141+
marker = contextvars.ContextVar("marker", default="ambient")
142+
seen: list[str] = []
143+
144+
def generator() -> Generator[str, None, None]:
145+
try:
146+
yield "item_0"
147+
yield "item_1"
148+
finally:
149+
seen.append(marker.get())
150+
151+
span = SpanRecorder()
152+
context = contextvars.copy_context()
153+
context.run(marker.set, "preserved")
154+
wrapper = _ContextPreservedSyncGeneratorWrapper(
155+
generator(),
156+
context,
157+
cast(Any, span),
158+
False,
159+
None,
160+
)
161+
162+
assert next(wrapper) == "item_0"
163+
marker.set("ambient-now")
164+
165+
wrapper.close()
166+
167+
assert seen == ["preserved"]
168+
assert span.ended == 1
169+
170+
140171
def test_sync_generator_wrapper_del_ends_span_when_abandoned() -> None:
141172
def generator() -> Generator[str, None, None]:
142173
yield "item_0"
@@ -184,6 +215,38 @@ async def generator() -> AsyncGenerator[str, None]:
184215
assert span.updates == []
185216

186217

218+
@pytest.mark.asyncio
219+
async def test_async_generator_wrapper_aclose_preserves_context() -> None:
220+
marker = contextvars.ContextVar("marker", default="ambient")
221+
seen: list[str] = []
222+
223+
async def generator() -> AsyncGenerator[str, None]:
224+
try:
225+
yield "item_0"
226+
yield "item_1"
227+
finally:
228+
seen.append(marker.get())
229+
230+
span = SpanRecorder()
231+
context = contextvars.copy_context()
232+
context.run(marker.set, "preserved")
233+
wrapper = _ContextPreservedAsyncGeneratorWrapper(
234+
generator(),
235+
context,
236+
cast(Any, span),
237+
False,
238+
None,
239+
)
240+
241+
assert await wrapper.__anext__() == "item_0"
242+
marker.set("ambient-now")
243+
244+
await wrapper.aclose()
245+
246+
assert seen == ["preserved"]
247+
assert span.ended == 1
248+
249+
187250
@pytest.mark.asyncio
188251
async def test_async_generator_wrapper_del_ends_span_when_abandoned() -> None:
189252
async def generator() -> AsyncGenerator[str, None]:

0 commit comments

Comments
 (0)