Skip to content

Commit bdf0d17

Browse files
committed
fix(openai): finalize stream exits
1 parent 95c933d commit bdf0d17

File tree

2 files changed: 159 additions and 5 deletions

langfuse/openai.py

Lines changed: 53 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from collections import defaultdict
2222
from dataclasses import dataclass
2323
from datetime import datetime
24-
from inspect import isclass
24+
from inspect import isawaitable, isclass
2525
from typing import Any, Optional, cast
2626

2727
from openai._types import NotGiven
@@ -841,6 +841,7 @@ def _install_openai_stream_iteration_hooks() -> None:
841841

842842
if not _openai_stream_iter_hook_installed:
843843
original_iter = openai.Stream.__iter__
844+
original_aiter = openai.AsyncStream.__aiter__
844845

845846
def traced_iter(self: Any) -> Any:
846847
try:
@@ -850,7 +851,17 @@ def traced_iter(self: Any) -> Any:
850851
if finalize_once is not None:
851852
finalize_once()
852853

854+
async def traced_aiter(self: Any) -> Any:
855+
try:
856+
async for item in original_aiter(self):
857+
yield item
858+
finally:
859+
finalize_once = getattr(self, "_langfuse_finalize_once", None)
860+
if finalize_once is not None:
861+
await finalize_once()
862+
853863
setattr(openai.Stream, "__iter__", traced_iter)
864+
setattr(openai.AsyncStream, "__aiter__", traced_aiter)
854865
_openai_stream_iter_hook_installed = True
855866

856867

@@ -973,6 +984,8 @@ async def finalize_once() -> None:
973984
completion_start_time=completion_start_time,
974985
)
975986

987+
response._langfuse_finalize_once = finalize_once # type: ignore[attr-defined]
988+
976989
async def traced_iterator() -> Any:
977990
nonlocal completion_start_time
978991
try:
@@ -1228,7 +1241,16 @@ def __enter__(self) -> Any:
12281241
return self.__iter__()
12291242

12301243
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
1231-
pass
1244+
self.close()
1245+
1246+
def close(self) -> None:
1247+
close = getattr(self.response, "close", None)
1248+
1249+
try:
1250+
if callable(close):
1251+
close()
1252+
finally:
1253+
self._finalize()
12321254

12331255
def _finalize(self) -> None:
12341256
if self._is_finalized:
@@ -1290,7 +1312,7 @@ async def __aenter__(self) -> Any:
12901312
return self.__aiter__()
12911313

12921314
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
1293-
pass
1315+
await self.aclose()
12941316

12951317
async def _finalize(self) -> None:
12961318
if self._is_finalized:
@@ -1309,11 +1331,37 @@ async def close(self) -> None:
13091331
13101332
Automatically called if the response body is read to completion.
13111333
"""
1312-
await self.response.close()
1334+
close = getattr(self.response, "close", None)
1335+
aclose = getattr(self.response, "aclose", None)
1336+
1337+
try:
1338+
if callable(close):
1339+
result = close()
1340+
if isawaitable(result):
1341+
await result
1342+
elif callable(aclose):
1343+
result = aclose()
1344+
if isawaitable(result):
1345+
await result
1346+
finally:
1347+
await self._finalize()
13131348

13141349
async def aclose(self) -> None:
13151350
"""Close the response and release the connection.
13161351
13171352
Automatically called if the response body is read to completion.
13181353
"""
1319-
await self.response.aclose()
1354+
aclose = getattr(self.response, "aclose", None)
1355+
close = getattr(self.response, "close", None)
1356+
1357+
try:
1358+
if callable(aclose):
1359+
result = aclose()
1360+
if isawaitable(result):
1361+
await result
1362+
elif callable(close):
1363+
result = close()
1364+
if isawaitable(result):
1365+
await result
1366+
finally:
1367+
await self._finalize()

tests/unit/test_openai.py

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from types import SimpleNamespace
23
from unittest.mock import patch
34

@@ -51,6 +52,18 @@ def end(self) -> None:
5152
self.end_calls += 1
5253

5354

55+
class DummyFallbackAsyncResponse:
56+
def __init__(self) -> None:
57+
self.close_calls = 0
58+
self.aclose_calls = 0
59+
60+
async def close(self) -> None:
61+
self.close_calls += 1
62+
63+
async def aclose(self) -> None:
64+
self.aclose_calls += 1
65+
66+
5467
def _make_chat_stream_chunks():
5568
usage = SimpleNamespace(prompt_tokens=3, completion_tokens=1, total_tokens=4)
5669

@@ -469,6 +482,42 @@ async def test_openai_async_stream_supports_anext(
469482
}
470483

471484

485+
@pytest.mark.asyncio
486+
async def test_openai_async_stream_break_still_finalizes_generation(
487+
langfuse_memory_client, get_span
488+
):
489+
openai_client = lf_openai.AsyncOpenAI(api_key="test")
490+
raw_stream = DummyOpenAIAsyncStream(
491+
_make_chat_stream_chunks(), DummyAsyncResponse()
492+
)
493+
494+
with patch.object(openai_client.chat.completions, "_post", return_value=raw_stream):
495+
stream = await openai_client.chat.completions.create(
496+
name="unit-openai-native-async-stream-break",
497+
model="gpt-4o-mini",
498+
messages=[{"role": "user", "content": "1 + 1 = ?"}],
499+
temperature=0,
500+
stream=True,
501+
)
502+
503+
async for chunk in stream:
504+
assert chunk.choices[0].delta.content == "2"
505+
break
506+
507+
# Async generator finalizers are scheduled across event-loop turns.
508+
for _ in range(5):
509+
await asyncio.sleep(0)
510+
511+
langfuse_memory_client.flush()
512+
span = get_span("unit-openai-native-async-stream-break")
513+
514+
assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT] == "2"
515+
assert (
516+
span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_COMPLETION_START_TIME]
517+
is not None
518+
)
519+
520+
472521
def test_fallback_sync_stream_finalizes_once():
473522
resource = SimpleNamespace(object="Completions", type="chat")
474523
generation = DummyGeneration()
@@ -490,6 +539,24 @@ def fallback_stream():
490539
assert generation.end_calls == 1
491540

492541

542+
def test_fallback_sync_stream_exit_finalizes_once():
543+
resource = SimpleNamespace(object="Completions", type="chat")
544+
generation = DummyGeneration()
545+
546+
def fallback_stream():
547+
yield _make_single_chunk_stream()
548+
549+
wrapper = lf_openai_module.LangfuseResponseGeneratorSync(
550+
resource=resource,
551+
response=fallback_stream(),
552+
generation=generation,
553+
)
554+
555+
wrapper.__exit__(None, None, None)
556+
557+
assert generation.end_calls == 1
558+
559+
493560
@pytest.mark.asyncio
494561
async def test_fallback_async_stream_finalizes_once():
495562
resource = SimpleNamespace(object="Completions", type="chat")
@@ -513,6 +580,45 @@ async def fallback_stream():
513580
assert generation.end_calls == 1
514581

515582

583+
@pytest.mark.asyncio
584+
async def test_fallback_async_stream_close_and_exit_finalize_once():
585+
resource = SimpleNamespace(object="Completions", type="chat")
586+
generation = DummyGeneration()
587+
response = DummyFallbackAsyncResponse()
588+
589+
wrapper = lf_openai_module.LangfuseResponseGeneratorAsync(
590+
resource=resource,
591+
response=response,
592+
generation=generation,
593+
)
594+
595+
await wrapper.close()
596+
await wrapper.__aexit__(None, None, None)
597+
598+
assert generation.end_calls == 1
599+
assert response.close_calls == 1
600+
assert response.aclose_calls == 1
601+
602+
603+
@pytest.mark.asyncio
604+
async def test_fallback_async_stream_aclose_finalizes_once():
605+
resource = SimpleNamespace(object="Completions", type="chat")
606+
generation = DummyGeneration()
607+
608+
async def fallback_stream():
609+
yield _make_single_chunk_stream()
610+
611+
wrapper = lf_openai_module.LangfuseResponseGeneratorAsync(
612+
resource=resource,
613+
response=fallback_stream(),
614+
generation=generation,
615+
)
616+
617+
await wrapper.aclose()
618+
619+
assert generation.end_calls == 1
620+
621+
516622
def test_embedding_exports_dimensions_and_count(
517623
langfuse_memory_client, get_span, json_attr
518624
):

0 commit comments

Comments
 (0)