Skip to content

Commit ad2e408

Browse files
committed
fix(langchain): scope langgraph resume context
1 parent 19a231e commit ad2e408

File tree

2 files changed

+230
-123
lines changed

2 files changed

+230
-123
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 99 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -84,11 +84,14 @@
8484

8585
LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden"
8686
CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set()
87+
LANGGRAPH_COMMAND_TYPE: Optional[Type[Any]] = None
8788

8889
try:
8990
from langgraph.errors import GraphBubbleUp
91+
from langgraph.types import Command as LangGraphCommand
9092

9193
CONTROL_FLOW_EXCEPTION_TYPES.add(GraphBubbleUp)
94+
LANGGRAPH_COMMAND_TYPE = LangGraphCommand
9295
except ImportError:
9396
pass
9497

@@ -136,8 +139,9 @@ def __init__(
136139
self._propagation_context_manager: Optional[_AgnosticContextManager] = None
137140
self._trace_context = trace_context
138141
# LangGraph resumes as a fresh root callback run after interrupting, so we keep
139-
# just enough trace context to stitch the resume back onto the original trace.
140-
self._resume_trace_context: Optional[TraceContext] = None
142+
# pending resume contexts keyed by thread/session instead of a single shared slot.
143+
self._resume_trace_context_by_key: Dict[str, TraceContext] = {}
144+
self._root_run_resume_key_map: Dict[UUID, str] = {}
141145
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
142146

143147
self.last_trace_id: Optional[str] = None
@@ -164,7 +168,35 @@ def on_llm_new_token(
164168

165169
self._updated_completion_start_time_memo.add(run_id)
166170

167-
def _consume_root_trace_context(self) -> Optional[TraceContext]:
171+
def _get_langgraph_resume_key(
172+
self, metadata: Optional[Dict[str, Any]]
173+
) -> Optional[str]:
174+
thread_id = metadata.get("thread_id") if metadata else None
175+
176+
if thread_id is None:
177+
return None
178+
179+
return str(thread_id)
180+
181+
def _set_root_run_resume_key(
182+
self, run_id: UUID, metadata: Optional[Dict[str, Any]]
183+
) -> None:
184+
resume_key = self._get_langgraph_resume_key(metadata)
185+
186+
if resume_key is not None:
187+
self._root_run_resume_key_map[run_id] = resume_key
188+
189+
def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]:
190+
return self._root_run_resume_key_map.pop(run_id, None)
191+
192+
def _is_langgraph_resume(self, inputs: Any) -> bool:
193+
return LANGGRAPH_COMMAND_TYPE is not None and isinstance(
194+
inputs, LANGGRAPH_COMMAND_TYPE
195+
)
196+
197+
def _consume_root_trace_context(
198+
self, *, inputs: Any, metadata: Optional[Dict[str, Any]]
199+
) -> Optional[TraceContext]:
168200
if self._trace_context is not None:
169201
return self._trace_context
170202

@@ -175,19 +207,30 @@ def _consume_root_trace_context(self) -> Optional[TraceContext]:
175207
if current_span_context.is_valid:
176208
return None
177209

178-
trace_context = self._resume_trace_context
179-
self._resume_trace_context = None
210+
# Only explicit LangGraph resumes should consume pending trace linkage.
211+
if not self._is_langgraph_resume(inputs):
212+
return None
213+
214+
resume_key = self._get_langgraph_resume_key(metadata)
215+
if resume_key is None:
216+
return None
180217

181-
return trace_context
218+
return self._resume_trace_context_by_key.pop(resume_key, None)
182219

183-
def _clear_resume_trace_context(self) -> None:
184-
self._resume_trace_context = None
220+
def _clear_root_run_resume_key(self, run_id: UUID) -> None:
221+
# Keep the pending interrupt context until an explicit Command(resume=...)
222+
# arrives. A separate root run on the same thread_id is not a resume.
223+
self._pop_root_run_resume_key(run_id)
185224

186-
def _persist_resume_trace_context(self, observation: Any) -> None:
225+
def _persist_resume_trace_context(self, *, run_id: UUID, observation: Any) -> None:
187226
if self._trace_context is not None:
188227
return
189228

190-
self._resume_trace_context = {
229+
resume_key = self._pop_root_run_resume_key(run_id)
230+
if resume_key is None:
231+
return
232+
233+
self._resume_trace_context_by_key[resume_key] = {
191234
"trace_id": observation.trace_id,
192235
"parent_span_id": observation.id,
193236
}
@@ -309,12 +352,17 @@ def on_retriever_error(
309352
).end()
310353

311354
if parent_run_id is None and level == "DEFAULT":
312-
self._persist_resume_trace_context(observation)
355+
self._persist_resume_trace_context(
356+
run_id=run_id, observation=observation
357+
)
313358
elif parent_run_id is None:
314-
self._clear_resume_trace_context()
359+
self._clear_root_run_resume_key(run_id)
315360

316361
except Exception as e:
317362
langfuse_logger.exception(e)
363+
finally:
364+
if parent_run_id is None:
365+
self._reset()
318366

319367
def _parse_langfuse_trace_attributes(
320368
self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]]
@@ -383,7 +431,7 @@ def _get_langchain_observation_metadata(
383431
def on_chain_start(
384432
self,
385433
serialized: Optional[Dict[str, Any]],
386-
inputs: Dict[str, Any],
434+
inputs: Any,
387435
*,
388436
run_id: UUID,
389437
parent_run_id: Optional[UUID] = None,
@@ -392,6 +440,8 @@ def on_chain_start(
392440
**kwargs: Any,
393441
) -> Any:
394442
self._child_to_parent_run_id_map[run_id] = parent_run_id
443+
if parent_run_id is None:
444+
self._set_root_run_resume_key(run_id, metadata)
395445

396446
try:
397447
self._log_debug_event(
@@ -432,7 +482,9 @@ def on_chain_start(
432482
obs = self._get_parent_observation(parent_run_id)
433483
if isinstance(obs, Langfuse):
434484
span = obs.start_observation(
435-
trace_context=self._consume_root_trace_context(),
485+
trace_context=self._consume_root_trace_context(
486+
inputs=inputs, metadata=metadata
487+
),
436488
name=span_name,
437489
as_type=observation_type,
438490
metadata=span_metadata,
@@ -636,7 +688,7 @@ def on_chain_end(
636688
)
637689

638690
if parent_run_id is None:
639-
self._clear_resume_trace_context()
691+
self._clear_root_run_resume_key(run_id)
640692
self._exit_propagation_context()
641693

642694
span.end()
@@ -679,9 +731,11 @@ def on_chain_error(
679731

680732
if parent_run_id is None:
681733
if level == "DEFAULT":
682-
self._persist_resume_trace_context(observation)
734+
self._persist_resume_trace_context(
735+
run_id=run_id, observation=observation
736+
)
683737
else:
684-
self._clear_resume_trace_context()
738+
self._clear_root_run_resume_key(run_id)
685739
self._exit_propagation_context()
686740

687741
observation.end()
@@ -739,6 +793,8 @@ def on_llm_start(
739793
**kwargs: Any,
740794
) -> Any:
741795
self._child_to_parent_run_id_map[run_id] = parent_run_id
796+
if parent_run_id is None:
797+
self._set_root_run_resume_key(run_id, metadata)
742798

743799
try:
744800
self._log_debug_event(
@@ -794,7 +850,7 @@ def on_tool_start(
794850
parent_observation = self._get_parent_observation(parent_run_id)
795851
if isinstance(parent_observation, Langfuse):
796852
span = parent_observation.start_observation(
797-
trace_context=self._consume_root_trace_context(),
853+
trace_context=self._trace_context,
798854
name=self.get_langchain_run_name(serialized, **kwargs),
799855
as_type=observation_type,
800856
input=input_str,
@@ -846,7 +902,7 @@ def on_retriever_start(
846902
parent_observation = self._get_parent_observation(parent_run_id)
847903
if isinstance(parent_observation, Langfuse):
848904
span = parent_observation.start_observation(
849-
trace_context=self._consume_root_trace_context(),
905+
trace_context=self._trace_context,
850906
name=span_name,
851907
as_type=observation_type,
852908
metadata=span_metadata,
@@ -889,14 +945,17 @@ def on_retriever_end(
889945

890946
if observation is not None:
891947
if parent_run_id is None:
892-
self._clear_resume_trace_context()
948+
self._clear_root_run_resume_key(run_id)
893949
observation.update(
894950
output=documents,
895951
input=kwargs.get("inputs"),
896952
).end()
897953

898954
except Exception as e:
899955
langfuse_logger.exception(e)
956+
finally:
957+
if parent_run_id is None:
958+
self._reset()
900959

901960
def on_tool_end(
902961
self,
@@ -913,7 +972,7 @@ def on_tool_end(
913972

914973
if observation is not None:
915974
if parent_run_id is None:
916-
self._clear_resume_trace_context()
975+
self._clear_root_run_resume_key(run_id)
917976
observation.update(
918977
output=output,
919978
input=kwargs.get("inputs"),
@@ -947,12 +1006,17 @@ def on_tool_error(
9471006
).end()
9481007

9491008
if parent_run_id is None and level == "DEFAULT":
950-
self._persist_resume_trace_context(observation)
1009+
self._persist_resume_trace_context(
1010+
run_id=run_id, observation=observation
1011+
)
9511012
elif parent_run_id is None:
952-
self._clear_resume_trace_context()
1013+
self._clear_root_run_resume_key(run_id)
9531014

9541015
except Exception as e:
9551016
langfuse_logger.exception(e)
1017+
finally:
1018+
if parent_run_id is None:
1019+
self._reset()
9561020

9571021
def __on_llm_action(
9581022
self,
@@ -965,6 +1029,8 @@ def __on_llm_action(
9651029
**kwargs: Any,
9661030
) -> None:
9671031
self._child_to_parent_run_id_map[run_id] = parent_run_id
1032+
if parent_run_id is None:
1033+
self._set_root_run_resume_key(run_id, metadata)
9681034

9691035
try:
9701036
tools = kwargs.get("invocation_params", {}).get("tools", None)
@@ -1012,7 +1078,7 @@ def __on_llm_action(
10121078
parent_observation = self._get_parent_observation(parent_run_id)
10131079
if isinstance(parent_observation, Langfuse):
10141080
generation = parent_observation.start_observation(
1015-
trace_context=self._consume_root_trace_context(),
1081+
trace_context=self._trace_context,
10161082
as_type="generation",
10171083
**content,
10181084
) # type: ignore
@@ -1132,7 +1198,7 @@ def on_llm_end(
11321198
self._updated_completion_start_time_memo.discard(run_id)
11331199

11341200
if parent_run_id is None:
1135-
self._clear_resume_trace_context()
1201+
self._clear_root_run_resume_key(run_id)
11361202
self._reset()
11371203

11381204
def on_llm_error(
@@ -1161,15 +1227,21 @@ def on_llm_error(
11611227
).end()
11621228

11631229
if parent_run_id is None and level == "DEFAULT":
1164-
self._persist_resume_trace_context(generation)
1230+
self._persist_resume_trace_context(
1231+
run_id=run_id, observation=generation
1232+
)
11651233
elif parent_run_id is None:
1166-
self._clear_resume_trace_context()
1234+
self._clear_root_run_resume_key(run_id)
11671235

11681236
except Exception as e:
11691237
langfuse_logger.exception(e)
1238+
finally:
1239+
if parent_run_id is None:
1240+
self._reset()
11701241

11711242
def _reset(self) -> None:
11721243
self._child_to_parent_run_id_map = {}
1244+
self._root_run_resume_key_map = {}
11731245

11741246
def _exit_propagation_context(self) -> None:
11751247
manager = self._propagation_context_manager

0 commit comments

Comments (0)