Skip to content

Commit 3399a69

Browse files
authored
fix(openai): structured output parsing with openai >= 1.50 (#957)
1 parent 42d3b6e commit 3399a69

2 files changed

Lines changed: 81 additions & 13 deletions

File tree

langfuse/openai.py

Lines changed: 33 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,11 @@
1919

2020
import copy
2121
import logging
22+
from inspect import isclass
2223
import types
2324

2425
from collections import defaultdict
26+
from dataclasses import dataclass
2527
from typing import List, Optional
2628

2729
import openai.resources
@@ -34,6 +36,7 @@
3436
from langfuse.decorators import langfuse_context
3537
from langfuse.utils import _get_timestamp
3638
from langfuse.utils.langfuse_singleton import LangfuseSingleton
39+
from pydantic import BaseModel
3740

3841
try:
3942
import openai
@@ -53,19 +56,14 @@
5356
log = logging.getLogger("langfuse")
5457

5558

59+
@dataclass
5660
class OpenAiDefinition:
5761
module: str
5862
object: str
5963
method: str
6064
type: str
6165
sync: bool
62-
63-
def __init__(self, module: str, object: str, method: str, type: str, sync: bool):
64-
self.module = module
65-
self.object = object
66-
self.method = method
67-
self.type = type
68-
self.sync = sync
66+
min_version: Optional[str] = None
6967

7068

7169
OPENAI_METHODS_V0 = [
@@ -115,6 +113,22 @@ def __init__(self, module: str, object: str, method: str, type: str, sync: bool)
115113
type="completion",
116114
sync=False,
117115
),
116+
OpenAiDefinition(
117+
module="openai.resources.beta.chat.completions",
118+
object="Completions",
119+
method="parse",
120+
type="chat",
121+
sync=True,
122+
min_version="1.50.0",
123+
),
124+
OpenAiDefinition(
125+
module="openai.resources.beta.chat.completions",
126+
object="AsyncCompletions",
127+
method="parse",
128+
type="chat",
129+
sync=False,
130+
min_version="1.50.0",
131+
),
118132
]
119133

120134

@@ -136,7 +150,13 @@ def __init__(
136150
self.args["metadata"] = (
137151
metadata
138152
if "response_format" not in kwargs
139-
else {**(metadata or {}), "response_format": kwargs["response_format"]}
153+
else {
154+
**(metadata or {}),
155+
"response_format": kwargs["response_format"].model_json_schema()
156+
if isclass(kwargs["response_format"])
157+
and issubclass(kwargs["response_format"], BaseModel)
158+
else kwargs["response_format"],
159+
}
140160
)
141161
self.args["trace_id"] = trace_id
142162
self.args["session_id"] = session_id
@@ -651,6 +671,11 @@ def register_tracing(self):
651671
resources = OPENAI_METHODS_V1 if _is_openai_v1() else OPENAI_METHODS_V0
652672

653673
for resource in resources:
674+
if resource.min_version is not None and Version(
675+
openai.__version__
676+
) < Version(resource.min_version):
677+
continue
678+
654679
wrap_function_wrapper(
655680
resource.module,
656681
f"{resource.object}.{resource.method}",

tests/test_openai.py

Lines changed: 48 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1482,26 +1482,69 @@ def test_structured_output_response_format_kwarg():
14821482

14831483
def test_structured_output_beta_completions_parse():
14841484
from typing import List
1485+
from packaging.version import Version
14851486

14861487
class CalendarEvent(BaseModel):
14871488
name: str
14881489
date: str
14891490
participants: List[str]
14901491

1491-
openai.beta.chat.completions.parse(
1492-
model="gpt-4o-2024-08-06",
1493-
messages=[
1492+
generation_name = create_uuid()
1493+
api = get_api()
1494+
1495+
params = {
1496+
"model": "gpt-4o-2024-08-06",
1497+
"messages": [
14941498
{"role": "system", "content": "Extract the event information."},
14951499
{
14961500
"role": "user",
14971501
"content": "Alice and Bob are going to a science fair on Friday.",
14981502
},
14991503
],
1500-
response_format=CalendarEvent,
1501-
)
1504+
"response_format": CalendarEvent,
1505+
"name": generation_name,
1506+
}
1507+
1508+
# The beta API is only wrapped for openai >= 1.50.0. Prior to that, another wrapped method was called implicitly.
1509+
if Version(openai.__version__) < Version("1.50.0"):
1510+
params.pop("name")
1511+
1512+
openai.beta.chat.completions.parse(**params)
15021513

15031514
openai.flush_langfuse()
15041515

1516+
if Version(openai.__version__) >= Version("1.50.0"):
1517+
# Check the trace and observation properties
1518+
generation = api.observations.get_many(name=generation_name, type="GENERATION")
1519+
1520+
assert len(generation.data) == 1
1521+
assert generation.data[0].name == generation_name
1522+
assert generation.data[0].type == "GENERATION"
1523+
assert generation.data[0].model == "gpt-4o-2024-08-06"
1524+
assert generation.data[0].start_time is not None
1525+
assert generation.data[0].end_time is not None
1526+
assert generation.data[0].start_time < generation.data[0].end_time
1527+
1528+
# Check input and output
1529+
assert len(generation.data[0].input) == 2
1530+
assert generation.data[0].input[0]["role"] == "system"
1531+
assert generation.data[0].input[1]["role"] == "user"
1532+
assert isinstance(generation.data[0].output, dict)
1533+
assert "name" in generation.data[0].output["content"]
1534+
assert "date" in generation.data[0].output["content"]
1535+
assert "participants" in generation.data[0].output["content"]
1536+
1537+
# Check usage
1538+
assert generation.data[0].usage.input is not None
1539+
assert generation.data[0].usage.output is not None
1540+
assert generation.data[0].usage.total is not None
1541+
1542+
# Check trace
1543+
trace = api.trace.get(generation.data[0].trace_id)
1544+
1545+
assert trace.input is not None
1546+
assert trace.output is not None
1547+
15051548

15061549
@pytest.mark.asyncio
15071550
async def test_close_async_stream():

0 commit comments

Comments
 (0)