-
Notifications
You must be signed in to change notification settings - Fork 262
Expand file tree
/
Copy pathtest_langchain_callback_unit.py
More file actions
267 lines (221 loc) · 9.42 KB
/
test_langchain_callback_unit.py
File metadata and controls
267 lines (221 loc) · 9.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""Unit tests for CallbackHandler trace-name propagation.
These tests cover the fix for on_chain_start not passing trace_name to
propagate_attributes, which caused non-deterministic trace names on LangGraph
resume (e.g. after a human-in-the-loop interrupt).
No real API calls are made — propagate_attributes and get_client are mocked.
"""
import uuid
from contextlib import contextmanager
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, call, patch
import pytest
from langfuse.langchain import CallbackHandler
from langfuse.langchain.CallbackHandler import (
LangchainCallbackHandler,
_strip_langfuse_keys_from_dict,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_handler() -> CallbackHandler:
    """Build a CallbackHandler whose Langfuse client is replaced by mocks.

    ``get_client`` is patched only while the handler is being constructed.
    Afterwards ``_langfuse_client`` is overwritten with a separate
    ``MagicMock`` whose ``start_observation`` returns an object carrying
    ``trace_id="trace-123"``.
    """
    construction_client = MagicMock()
    construction_client.start_observation.return_value = MagicMock(
        trace_id="trace-123"
    )
    with patch(
        "langfuse.langchain.CallbackHandler.get_client",
        return_value=construction_client,
    ):
        handler = CallbackHandler()

    # Install a dedicated mock so tests can inspect calls made through it.
    inspectable_client = MagicMock()
    inspectable_client.start_observation.return_value = MagicMock(
        trace_id="trace-123"
    )
    handler._langfuse_client = inspectable_client
    return handler
def _make_run_id() -> uuid.UUID:
    """Return a fresh random (version 4) UUID for use as a run id in tests."""
    fresh_id = uuid.uuid4()
    return fresh_id
# ---------------------------------------------------------------------------
# Tests: _parse_langfuse_trace_attributes
# ---------------------------------------------------------------------------
class TestParseLangfuseTraceAttributes:
    """Checks on what ``_parse_langfuse_trace_attributes`` extracts."""

    def _parse(self, handler, metadata=None, tags=None):
        """Call the handler's parser with the given metadata and tags."""
        parser = handler._parse_langfuse_trace_attributes
        return parser(metadata=metadata, tags=tags)

    def test_extracts_trace_name_from_metadata(self):
        """A string ``langfuse_trace_name`` is returned under ``trace_name``."""
        metadata = {"langfuse_trace_name": "my-agent"}
        attributes = self._parse(_make_handler(), metadata=metadata)
        assert attributes["trace_name"] == "my-agent"

    def test_ignores_non_string_trace_name(self):
        """A non-string ``langfuse_trace_name`` yields no ``trace_name`` key."""
        metadata = {"langfuse_trace_name": 42}
        attributes = self._parse(_make_handler(), metadata=metadata)
        assert "trace_name" not in attributes

    def test_does_not_set_trace_name_when_absent(self):
        """Metadata without ``langfuse_trace_name`` yields no ``trace_name``."""
        metadata = {"langfuse_session_id": "s1"}
        attributes = self._parse(_make_handler(), metadata=metadata)
        assert "trace_name" not in attributes

    def test_extracts_all_attributes_together(self):
        """Trace name, session id and user id are all parsed in one pass."""
        metadata = {
            "langfuse_trace_name": "agent",
            "langfuse_session_id": "sess-1",
            "langfuse_user_id": "user-1",
        }
        attributes = self._parse(_make_handler(), metadata=metadata)
        assert attributes["trace_name"] == "agent"
        assert attributes["session_id"] == "sess-1"
        assert attributes["user_id"] == "user-1"
# ---------------------------------------------------------------------------
# Tests: on_chain_start passes trace_name to propagate_attributes
# ---------------------------------------------------------------------------
class TestOnChainStartTraceNamePropagation:
    """Verify that on_chain_start forwards trace_name to propagate_attributes."""

    @staticmethod
    @contextmanager
    def _patched_propagate_attributes():
        """Patch ``propagate_attributes`` and yield the resulting mock.

        The mock's return value has ``__enter__`` and ``__exit__`` configured,
        so the object it returns can be entered and exited as a context
        manager. The patch is undone when the ``with`` block ends, including
        when an assertion inside it fails.
        """
        with patch(
            "langfuse.langchain.CallbackHandler.propagate_attributes"
        ) as mock_pa:
            mock_pa.return_value = MagicMock(
                __enter__=MagicMock(return_value=None),
                __exit__=MagicMock(return_value=False),
            )
            yield mock_pa

    @staticmethod
    def _propagated_kwargs(mock_pa) -> Dict[str, Any]:
        """Return the keyword arguments of the last propagate_attributes call.

        The call is asserted first: without it ``call_args`` is ``None`` and
        unpacking it fails with an unhelpful ``TypeError`` instead of a clear
        "expected to have been called" message.
        """
        mock_pa.assert_called()
        _, kwargs = mock_pa.call_args
        return kwargs

    def _run_on_chain_start(
        self,
        handler: CallbackHandler,
        serialized: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        parent_run_id: Optional[uuid.UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> uuid.UUID:
        """Invoke ``handler.on_chain_start`` with a fresh run id and return it.

        ``serialized`` falls back to ``{}`` when it is ``None`` (or empty), and
        ``name`` is only passed as a keyword argument when it is not ``None``.
        """
        run_id = _make_run_id()
        kwargs: Dict[str, Any] = {}
        if name is not None:
            kwargs["name"] = name
        handler.on_chain_start(
            serialized=serialized or {},
            inputs={},
            run_id=run_id,
            parent_run_id=parent_run_id,
            metadata=metadata,
            **kwargs,
        )
        return run_id

    def test_trace_name_passed_to_propagate_attributes(self):
        """span_name derived from serialized['name'] is forwarded as trace_name."""
        handler = _make_handler()
        with self._patched_propagate_attributes() as mock_pa:
            self._run_on_chain_start(
                handler,
                serialized={"name": "my-agent"},
                parent_run_id=None,
            )
            mock_pa.assert_called_once()
            kwargs = self._propagated_kwargs(mock_pa)
            assert kwargs.get("trace_name") == "my-agent"

    def test_trace_name_uses_kwargs_name_over_serialized(self):
        """The 'name' kwarg takes priority over serialized dict (LangChain convention)."""
        handler = _make_handler()
        with self._patched_propagate_attributes() as mock_pa:
            self._run_on_chain_start(
                handler,
                serialized={"name": "fallback-name"},
                name="explicit-name",
                parent_run_id=None,
            )
            kwargs = self._propagated_kwargs(mock_pa)
            assert kwargs.get("trace_name") == "explicit-name"

    def test_metadata_langfuse_trace_name_overrides_span_name(self):
        """langfuse_trace_name in metadata takes priority over computed span_name."""
        handler = _make_handler()
        with self._patched_propagate_attributes() as mock_pa:
            self._run_on_chain_start(
                handler,
                serialized={"name": "computed-name"},
                metadata={"langfuse_trace_name": "override-name"},
                parent_run_id=None,
            )
            kwargs = self._propagated_kwargs(mock_pa)
            assert kwargs.get("trace_name") == "override-name"

    def test_propagate_attributes_not_called_for_child_runs(self):
        """propagate_attributes must only be called at the root (parent_run_id=None)."""
        handler = _make_handler()
        root_run_id = _make_run_id()
        # Record the root run in the handler's run map with no parent.
        handler._child_to_parent_run_id_map[root_run_id] = None
        with self._patched_propagate_attributes() as mock_pa:
            # Child run: parent_run_id is set.
            self._run_on_chain_start(
                handler,
                serialized={"name": "child-node"},
                parent_run_id=root_run_id,
            )
            mock_pa.assert_not_called()

    def test_empty_span_name_still_propagated(self):
        """With no usable name, the '<unknown>' fallback is still forwarded."""
        handler = _make_handler()
        with self._patched_propagate_attributes() as mock_pa:
            # serialized=None reaches the handler as {} (see _run_on_chain_start),
            # and no 'name' kwarg is passed, so there is nothing to derive a
            # name from.
            self._run_on_chain_start(
                handler,
                serialized=None,
                parent_run_id=None,
            )
            kwargs = self._propagated_kwargs(mock_pa)
            # '<unknown>' is the fallback run name expected in this case.
            assert kwargs.get("trace_name") == "<unknown>"
# ---------------------------------------------------------------------------
# Tests: _strip_langfuse_keys_from_dict strips langfuse_trace_name
# ---------------------------------------------------------------------------
class TestStripLangfuseKeys:
    """Checks on how ``_strip_langfuse_keys_from_dict`` treats trace keys."""

    def test_strips_langfuse_trace_name(self):
        """With the keep flag off, the trace name goes and other keys stay."""
        source = {"langfuse_trace_name": "my-agent", "other_key": "value"}
        stripped = _strip_langfuse_keys_from_dict(
            source, keep_langfuse_trace_attributes=False
        )
        assert "langfuse_trace_name" not in stripped
        assert stripped["other_key"] == "value"

    def test_keeps_langfuse_trace_name_when_flag_set(self):
        """With the keep flag on, ``langfuse_trace_name`` is preserved."""
        source = {"langfuse_trace_name": "my-agent", "other_key": "value"}
        kept = _strip_langfuse_keys_from_dict(
            source, keep_langfuse_trace_attributes=True
        )
        assert kept["langfuse_trace_name"] == "my-agent"

    def test_strips_all_trace_attribute_keys_together(self):
        """With the keep flag off, every trace-attribute key is removed."""
        trace_keys = (
            "langfuse_trace_name",
            "langfuse_session_id",
            "langfuse_user_id",
            "langfuse_tags",
        )
        source = {
            "langfuse_trace_name": "n",
            "langfuse_session_id": "s",
            "langfuse_user_id": "u",
            "langfuse_tags": ["t"],
            "keep_me": 1,
        }
        stripped = _strip_langfuse_keys_from_dict(
            source, keep_langfuse_trace_attributes=False
        )
        # Collect survivors so a failure names every key that leaked through.
        leaked = [key for key in trace_keys if key in stripped]
        assert leaked == []
        assert stripped["keep_me"] == 1