Skip to content

Commit 55bb373

Browse files
committed
refactor(langchain): centralize root resume state
1 parent 9414be8 commit 55bb373

File tree

2 files changed

+176
-93
lines changed

2 files changed

+176
-93
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 137 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import OrderedDict
22
from contextvars import Token
3+
from dataclasses import dataclass, field
34
from typing import (
45
Any,
56
Dict,
@@ -103,6 +104,44 @@
103104
pass
104105

105106

107+
@dataclass
class _RunState:
    """Lineage record kept for every callback run the handler tracks.

    Lets any tracked run ID be resolved to its direct parent and to the
    root run whose shared state it belongs to.
    """

    # Direct parent of this run; ``None`` when the run is itself a root.
    parent_run_id: Optional[UUID]
    # Key under which the shared root-level state for this run is stored.
    root_run_id: UUID
111+
112+
113+
@dataclass
class _RootRunState:
    """Mutable state shared by every run that descends from one root run."""

    # IDs of the tracked runs grouped under this root; iterated when the
    # root is reset so the per-run entries can be removed in bulk.
    run_ids: Set[UUID] = field(default_factory=set)
    # LangGraph resume key derived from the root run's metadata, if any.
    resume_key: Optional[str] = None
    # Attribute-propagation context entered for this root; set back to
    # ``None`` once it has been exited.
    propagation_context_manager: Optional[_AgnosticContextManager] = None
118+
119+
120+
class _PendingResumeTraceContextStore:
    """Size-capped, insertion-ordered map of resume key -> trace context.

    Storing a key (new or already present) makes it the newest entry. When
    the number of entries then exceeds ``max_size``, the oldest entry is
    dropped.
    """

    def __init__(self, max_size: int) -> None:
        """Create an empty store that keeps at most ``max_size`` entries."""
        self._max_size = max_size
        self._contexts: OrderedDict[str, TraceContext] = OrderedDict()

    def store(self, *, resume_key: str, trace_context: TraceContext) -> None:
        """Record ``trace_context`` under ``resume_key`` as the newest entry."""
        pending = self._contexts

        # Remove any previous value first so the key is re-inserted at the end.
        pending.pop(resume_key, None)
        pending[resume_key] = trace_context

        if len(pending) > self._max_size:
            # Over capacity: evict the oldest entry.
            pending.popitem(last=False)

    def take(self, resume_key: str) -> Optional[TraceContext]:
        """Remove and return the context for ``resume_key``, or ``None``."""
        return self._contexts.pop(resume_key, None)

    def __contains__(self, resume_key: str) -> bool:
        """Return whether a context is pending for ``resume_key``."""
        return resume_key in self._contexts

    def __len__(self) -> int:
        """Return the number of pending contexts."""
        return len(self._contexts)

    def keys(self) -> List[str]:
        """Return the pending resume keys, oldest first, as a new list."""
        return list(self._contexts)
143+
144+
106145
class LangchainCallbackHandler(LangchainBaseCallbackHandler):
107146
def __init__(
108147
self,
@@ -143,15 +182,12 @@ def __init__(
143182
self._context_tokens: Dict[UUID, Token] = {}
144183
self._prompt_to_parent_run_map: Dict[UUID, Any] = {}
145184
self._updated_completion_start_time_memo: Set[UUID] = set()
146-
self._propagation_context_manager: Optional[_AgnosticContextManager] = None
147185
self._trace_context = trace_context
148-
# LangGraph resumes as a fresh root callback run after interrupting, so we keep
149-
# pending resume contexts keyed by thread/session instead of a single shared slot.
150-
self._resume_trace_context_by_key: OrderedDict[str, TraceContext] = (
151-
OrderedDict()
186+
self._pending_resume_trace_contexts = _PendingResumeTraceContextStore(
187+
MAX_PENDING_RESUME_TRACE_CONTEXTS
152188
)
153-
self._root_run_resume_key_map: Dict[UUID, str] = {}
154-
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
189+
self._run_states: Dict[UUID, _RunState] = {}
190+
self._root_run_states: Dict[UUID, _RootRunState] = {}
155191

156192
self.last_trace_id: Optional[str] = None
157193

@@ -187,16 +223,62 @@ def _get_langgraph_resume_key(
187223

188224
return str(thread_id)
189225

190-
def _set_root_run_resume_key(
191-
self, run_id: UUID, metadata: Optional[Dict[str, Any]]
226+
def _track_run(
    self,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Register ``run_id`` in the handler's run bookkeeping.

    A run without a parent becomes its own root and gets a fresh root state
    carrying the LangGraph resume key derived from ``metadata``. A run with
    a parent joins its parent's root; when the parent is not tracked, the
    parent's ID is used as the root key. ``metadata`` is only consulted for
    root runs. Calls for an already-tracked run are ignored.
    """
    if run_id in self._run_states:
        # Keep the first registration untouched.
        return

    if parent_run_id is not None:
        parent_state = self._run_states.get(parent_run_id)
        # An untracked parent stands in as the root of this run.
        root_run_id = (
            parent_run_id if parent_state is None else parent_state.root_run_id
        )
        self._root_run_states.setdefault(
            root_run_id, _RootRunState()
        ).run_ids.add(run_id)
    else:
        root_run_id = run_id
        self._root_run_states[root_run_id] = _RootRunState(
            run_ids={run_id},
            resume_key=self._get_langgraph_resume_key(metadata),
        )

    self._run_states[run_id] = _RunState(
        parent_run_id=parent_run_id,
        root_run_id=root_run_id,
    )
256+
257+
def _get_run_state(self, run_id: UUID) -> Optional[_RunState]:
    """Return the tracked state for ``run_id``, or ``None`` if untracked."""
    tracked_runs = self._run_states
    return tracked_runs.get(run_id)
194259

195-
if resume_key is not None:
196-
self._root_run_resume_key_map[run_id] = resume_key
260+
def _get_root_run_state(self, run_id: UUID) -> Optional[_RootRunState]:
    """Return the root-level state that ``run_id`` belongs to.

    Yields ``None`` when the run is not tracked or when no root state is
    stored under the run's root ID.
    """
    run_state = self._get_run_state(run_id)
    return (
        self._root_run_states.get(run_state.root_run_id)
        if run_state is not None
        else None
    )
197267

198268
def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]:
    """Return the resume key of ``run_id``'s root state and clear it.

    Returns ``None`` when no root state can be resolved for ``run_id`` or
    when the root state carries no resume key.
    """
    root_run_state = self._get_root_run_state(run_id)
    if root_run_state is None:
        return None

    # Read the key and reset the slot in one step so it is handed out once.
    resume_key, root_run_state.resume_key = root_run_state.resume_key, None
    return resume_key
278+
279+
def _get_parent_run_id(self, run_id: UUID) -> Optional[UUID]:
    """Return the recorded parent of ``run_id``.

    ``None`` is returned both for untracked runs and for tracked runs that
    have no parent.
    """
    run_state = self._get_run_state(run_id)
    if run_state is None:
        return None

    return run_state.parent_run_id
200282

201283
def _is_langgraph_resume(self, inputs: Any) -> bool:
202284
return (
@@ -208,11 +290,9 @@ def _is_langgraph_resume(self, inputs: Any) -> bool:
208290
def _store_resume_trace_context(
    self, *, resume_key: str, trace_context: TraceContext
) -> None:
    """Hand ``trace_context`` to the pending-resume store under ``resume_key``."""
    pending = self._pending_resume_trace_contexts
    pending.store(resume_key=resume_key, trace_context=trace_context)
216296

217297
def _take_root_trace_context(
218298
self, *, inputs: Any, metadata: Optional[Dict[str, Any]]
@@ -235,7 +315,7 @@ def _take_root_trace_context(
235315
if resume_key is None:
236316
return None, None
237317

238-
return resume_key, self._resume_trace_context_by_key.pop(resume_key, None)
318+
return resume_key, self._pending_resume_trace_contexts.take(resume_key)
239319

240320
def _restore_root_trace_context(
241321
self, *, resume_key: Optional[str], trace_context: Optional[TraceContext]
@@ -477,9 +557,7 @@ def on_chain_start(
477557
metadata: Optional[Dict[str, Any]] = None,
478558
**kwargs: Any,
479559
) -> Any:
480-
self._child_to_parent_run_id_map[run_id] = parent_run_id
481-
if parent_run_id is None:
482-
self._set_root_run_resume_key(run_id, metadata)
560+
self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)
483561

484562
span = None
485563
resume_key = None
@@ -511,15 +589,21 @@ def on_chain_start(
511589
metadata=metadata, tags=tags
512590
)
513591

514-
self._propagation_context_manager = propagate_attributes(
592+
propagation_context_manager = propagate_attributes(
515593
user_id=parsed_trace_attributes.get("user_id", None),
516594
session_id=parsed_trace_attributes.get("session_id", None),
517595
tags=parsed_trace_attributes.get("tags", None),
518596
metadata=parsed_trace_attributes.get("metadata", None),
519597
trace_name=parsed_trace_attributes.get("trace_name", None),
520598
)
521599

522-
self._propagation_context_manager.__enter__()
600+
root_run_state = self._get_root_run_state(run_id)
601+
if root_run_state is not None:
602+
root_run_state.propagation_context_manager = (
603+
propagation_context_manager
604+
)
605+
606+
propagation_context_manager.__enter__()
523607

524608
obs = self._get_parent_observation(parent_run_id)
525609
if isinstance(obs, Langfuse):
@@ -559,7 +643,7 @@ def on_chain_start(
559643
resume_key=resume_key, trace_context=trace_context
560644
)
561645
if parent_run_id is None:
562-
self._exit_propagation_context()
646+
self._exit_propagation_context(run_id)
563647
self._reset(run_id)
564648
langfuse_logger.exception(e)
565649

@@ -665,7 +749,7 @@ def on_agent_action(
665749
**kwargs: Any,
666750
) -> Any:
667751
"""Run on agent action."""
668-
self._child_to_parent_run_id_map[run_id] = parent_run_id
752+
self._track_run(run_id=run_id, parent_run_id=parent_run_id)
669753

670754
try:
671755
self._log_debug_event(
@@ -739,7 +823,7 @@ def on_chain_end(
739823

740824
if parent_run_id is None:
741825
self._clear_root_run_resume_key(run_id)
742-
self._exit_propagation_context()
826+
self._exit_propagation_context(run_id)
743827

744828
span.end()
745829

@@ -750,7 +834,7 @@ def on_chain_end(
750834

751835
finally:
752836
if parent_run_id is None:
753-
self._exit_propagation_context()
837+
self._exit_propagation_context(run_id)
754838
self._reset(run_id)
755839

756840
def on_chain_error(
@@ -786,15 +870,15 @@ def on_chain_error(
786870
)
787871
else:
788872
self._clear_root_run_resume_key(run_id)
789-
self._exit_propagation_context()
873+
self._exit_propagation_context(run_id)
790874

791875
observation.end()
792876

793877
except Exception as e:
794878
langfuse_logger.exception(e)
795879
finally:
796880
if parent_run_id is None:
797-
self._exit_propagation_context()
881+
self._exit_propagation_context(run_id)
798882
self._reset(run_id)
799883

800884
def on_chat_model_start(
@@ -808,9 +892,7 @@ def on_chat_model_start(
808892
metadata: Optional[Dict[str, Any]] = None,
809893
**kwargs: Any,
810894
) -> Any:
811-
self._child_to_parent_run_id_map[run_id] = parent_run_id
812-
if parent_run_id is None:
813-
self._set_root_run_resume_key(run_id, metadata)
895+
self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)
814896

815897
try:
816898
self._log_debug_event(
@@ -844,9 +926,7 @@ def on_llm_start(
844926
metadata: Optional[Dict[str, Any]] = None,
845927
**kwargs: Any,
846928
) -> Any:
847-
self._child_to_parent_run_id_map[run_id] = parent_run_id
848-
if parent_run_id is None:
849-
self._set_root_run_resume_key(run_id, metadata)
929+
self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)
850930

851931
try:
852932
self._log_debug_event(
@@ -875,9 +955,7 @@ def on_tool_start(
875955
metadata: Optional[Dict[str, Any]] = None,
876956
**kwargs: Any,
877957
) -> Any:
878-
self._child_to_parent_run_id_map[run_id] = parent_run_id
879-
if parent_run_id is None:
880-
self._set_root_run_resume_key(run_id, metadata)
958+
self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)
881959

882960
try:
883961
self._log_debug_event(
@@ -936,9 +1014,7 @@ def on_retriever_start(
9361014
metadata: Optional[Dict[str, Any]] = None,
9371015
**kwargs: Any,
9381016
) -> Any:
939-
self._child_to_parent_run_id_map[run_id] = parent_run_id
940-
if parent_run_id is None:
941-
self._set_root_run_resume_key(run_id, metadata)
1017+
self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)
9421018

9431019
try:
9441020
self._log_debug_event(
@@ -1087,9 +1163,7 @@ def __on_llm_action(
10871163
metadata: Optional[Dict[str, Any]] = None,
10881164
**kwargs: Any,
10891165
) -> None:
1090-
self._child_to_parent_run_id_map[run_id] = parent_run_id
1091-
if parent_run_id is None:
1092-
self._set_root_run_resume_key(run_id, metadata)
1166+
self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)
10931167

10941168
try:
10951169
tools = kwargs.get("invocation_params", {}).get("tools", None)
@@ -1113,8 +1187,8 @@ def __on_llm_action(
11131187
self._deregister_langfuse_prompt(current_parent_run_id)
11141188
break
11151189
else:
1116-
current_parent_run_id = self._child_to_parent_run_id_map.get(
1117-
current_parent_run_id, None
1190+
current_parent_run_id = self._get_parent_run_id(
1191+
current_parent_run_id
11181192
)
11191193

11201194
content = {
@@ -1298,38 +1372,30 @@ def on_llm_error(
12981372
if parent_run_id is None:
12991373
self._reset(run_id)
13001374

1301-
def _run_belongs_to_root(self, run_id: UUID, root_run_id: UUID) -> bool:
1302-
current_run_id: Optional[UUID] = run_id
1303-
visited: Set[UUID] = set()
1304-
1305-
while current_run_id is not None and current_run_id not in visited:
1306-
if current_run_id == root_run_id:
1307-
return True
1308-
1309-
visited.add(current_run_id)
1310-
current_run_id = self._child_to_parent_run_id_map.get(current_run_id)
1311-
1312-
return False
1313-
13141375
def _reset(self, root_run_id: UUID) -> None:
    """Drop all run bookkeeping that belongs to ``root_run_id``.

    Removes the root-level state and the per-run state of every run grouped
    under it. Runs whose parent was never tracked are grouped under that
    parent's ID, so the given ID is also tried directly as a root key when
    it has no per-run state of its own; otherwise those entries would never
    be cleared.

    Args:
        root_run_id: ID of the run whose tree should be forgotten.
    """
    run_state = self._get_run_state(root_run_id)

    # Fall back to the ID itself so groups keyed by an untracked parent
    # are still found and released.
    resolved_root_run_id = (
        run_state.root_run_id if run_state is not None else root_run_id
    )

    self._run_states.pop(root_run_id, None)

    root_run_state = self._root_run_states.pop(resolved_root_run_id, None)
    if root_run_state is None:
        return

    for run_id in root_run_state.run_ids:
        self._run_states.pop(run_id, None)
1388+
def _exit_propagation_context(self, run_id: UUID) -> None:
    """Exit the propagation context stored on ``run_id``'s root state.

    Does nothing when no root state can be resolved for ``run_id`` or when
    that root state holds no context manager.
    """
    root_run_state = self._get_root_run_state(run_id)
    manager = (
        root_run_state.propagation_context_manager
        if root_run_state is not None
        else None
    )
    if manager is None:
        return

    # Clear the reference before exiting so a repeated call is a no-op.
    root_run_state.propagation_context_manager = None
    manager.__exit__(None, None, None)
13341400

13351401
def __join_tags_and_metadata(

0 commit comments

Comments
 (0)