Skip to content

Commit 12bc17b

Browse files
committed
fix(langchain): harden langgraph resume state
1 parent fccbffc commit 12bc17b

File tree

2 files changed

Lines changed: 314 additions & 22 deletions

2 files changed

Lines changed: 314 additions & 22 deletions

langfuse/langchain/CallbackHandler.py

Lines changed: 75 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -194,28 +194,41 @@ def _is_langgraph_resume(self, inputs: Any) -> bool:
194194
inputs, LANGGRAPH_COMMAND_TYPE
195195
)
196196

197-
def _consume_root_trace_context(
197+
def _take_root_trace_context(
198198
self, *, inputs: Any, metadata: Optional[Dict[str, Any]]
199-
) -> Optional[TraceContext]:
199+
) -> tuple[Optional[str], Optional[TraceContext]]:
200200
if self._trace_context is not None:
201-
return self._trace_context
201+
return None, self._trace_context
202202

203203
current_span_context = trace.get_current_span().get_span_context()
204204

205205
# Only reuse the pending resume context when this callback run has no active
206206
# parent span of its own. Nested callbacks should attach normally.
207207
if current_span_context.is_valid:
208-
return None
208+
return None, None
209209

210210
# Only explicit LangGraph resumes should consume pending trace linkage.
211211
if not self._is_langgraph_resume(inputs):
212-
return None
212+
return None, None
213213

214214
resume_key = self._get_langgraph_resume_key(metadata)
215215
if resume_key is None:
216-
return None
216+
return None, None
217+
218+
return resume_key, self._resume_trace_context_by_key.pop(resume_key, None)
219+
220+
def _restore_root_trace_context(
221+
self, *, resume_key: Optional[str], trace_context: Optional[TraceContext]
222+
) -> None:
223+
if self._trace_context is not None:
224+
return
217225

218-
return self._resume_trace_context_by_key.pop(resume_key, None)
226+
if resume_key is None or trace_context is None:
227+
return
228+
229+
# Span creation failed after we consumed the pending linkage, so put it
230+
# back and let the next retry resume the interrupted trace correctly.
231+
self._resume_trace_context_by_key.setdefault(resume_key, trace_context)
219232

220233
def _clear_root_run_resume_key(self, run_id: UUID) -> None:
221234
# Keep the pending interrupt context until an explicit Command(resume=...)
@@ -362,7 +375,7 @@ def on_retriever_error(
362375
langfuse_logger.exception(e)
363376
finally:
364377
if parent_run_id is None:
365-
self._reset()
378+
self._reset(run_id)
366379

367380
def _parse_langfuse_trace_attributes(
368381
self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]]
@@ -443,6 +456,10 @@ def on_chain_start(
443456
if parent_run_id is None:
444457
self._set_root_run_resume_key(run_id, metadata)
445458

459+
span = None
460+
resume_key = None
461+
trace_context = None
462+
446463
try:
447464
self._log_debug_event(
448465
"on_chain_start", run_id, parent_run_id, inputs=inputs
@@ -481,10 +498,11 @@ def on_chain_start(
481498

482499
obs = self._get_parent_observation(parent_run_id)
483500
if isinstance(obs, Langfuse):
501+
resume_key, trace_context = self._take_root_trace_context(
502+
inputs=inputs, metadata=metadata
503+
)
484504
span = obs.start_observation(
485-
trace_context=self._consume_root_trace_context(
486-
inputs=inputs, metadata=metadata
487-
),
505+
trace_context=trace_context,
488506
name=span_name,
489507
as_type=observation_type,
490508
metadata=span_metadata,
@@ -511,6 +529,13 @@ def on_chain_start(
511529
self.last_trace_id = self._runs[run_id].trace_id
512530

513531
except Exception as e:
532+
if span is None:
533+
self._restore_root_trace_context(
534+
resume_key=resume_key, trace_context=trace_context
535+
)
536+
if parent_run_id is None:
537+
self._clear_root_run_resume_key(run_id)
538+
self._exit_propagation_context()
514539
langfuse_logger.exception(e)
515540

516541
def _register_langfuse_prompt(
@@ -701,7 +726,7 @@ def on_chain_end(
701726
finally:
702727
if parent_run_id is None:
703728
self._exit_propagation_context()
704-
self._reset()
729+
self._reset(run_id)
705730

706731
def on_chain_error(
707732
self,
@@ -745,7 +770,7 @@ def on_chain_error(
745770
finally:
746771
if parent_run_id is None:
747772
self._exit_propagation_context()
748-
self._reset()
773+
self._reset(run_id)
749774

750775
def on_chat_model_start(
751776
self,
@@ -759,6 +784,8 @@ def on_chat_model_start(
759784
**kwargs: Any,
760785
) -> Any:
761786
self._child_to_parent_run_id_map[run_id] = parent_run_id
787+
if parent_run_id is None:
788+
self._set_root_run_resume_key(run_id, metadata)
762789

763790
try:
764791
self._log_debug_event(
@@ -824,6 +851,8 @@ def on_tool_start(
824851
**kwargs: Any,
825852
) -> Any:
826853
self._child_to_parent_run_id_map[run_id] = parent_run_id
854+
if parent_run_id is None:
855+
self._set_root_run_resume_key(run_id, metadata)
827856

828857
try:
829858
self._log_debug_event(
@@ -883,6 +912,8 @@ def on_retriever_start(
883912
**kwargs: Any,
884913
) -> Any:
885914
self._child_to_parent_run_id_map[run_id] = parent_run_id
915+
if parent_run_id is None:
916+
self._set_root_run_resume_key(run_id, metadata)
886917

887918
try:
888919
self._log_debug_event(
@@ -955,7 +986,7 @@ def on_retriever_end(
955986
langfuse_logger.exception(e)
956987
finally:
957988
if parent_run_id is None:
958-
self._reset()
989+
self._reset(run_id)
959990

960991
def on_tool_end(
961992
self,
@@ -980,6 +1011,9 @@ def on_tool_end(
9801011

9811012
except Exception as e:
9821013
langfuse_logger.exception(e)
1014+
finally:
1015+
if parent_run_id is None:
1016+
self._reset(run_id)
9831017

9841018
def on_tool_error(
9851019
self,
@@ -1016,7 +1050,7 @@ def on_tool_error(
10161050
langfuse_logger.exception(e)
10171051
finally:
10181052
if parent_run_id is None:
1019-
self._reset()
1053+
self._reset(run_id)
10201054

10211055
def __on_llm_action(
10221056
self,
@@ -1199,7 +1233,7 @@ def on_llm_end(
11991233

12001234
if parent_run_id is None:
12011235
self._clear_root_run_resume_key(run_id)
1202-
self._reset()
1236+
self._reset(run_id)
12031237

12041238
def on_llm_error(
12051239
self,
@@ -1237,11 +1271,32 @@ def on_llm_error(
12371271
langfuse_logger.exception(e)
12381272
finally:
12391273
if parent_run_id is None:
1240-
self._reset()
1274+
self._reset(run_id)
1275+
1276+
def _run_belongs_to_root(self, run_id: UUID, root_run_id: UUID) -> bool:
1277+
current_run_id: Optional[UUID] = run_id
1278+
visited: Set[UUID] = set()
1279+
1280+
while current_run_id is not None and current_run_id not in visited:
1281+
if current_run_id == root_run_id:
1282+
return True
1283+
1284+
visited.add(current_run_id)
1285+
current_run_id = self._child_to_parent_run_id_map.get(current_run_id)
1286+
1287+
return False
1288+
1289+
def _reset(self, root_run_id: UUID) -> None:
1290+
run_ids_to_clear = [
1291+
run_id
1292+
for run_id in self._child_to_parent_run_id_map
1293+
if self._run_belongs_to_root(run_id, root_run_id)
1294+
]
1295+
1296+
for run_id in run_ids_to_clear:
1297+
self._child_to_parent_run_id_map.pop(run_id, None)
12411298

1242-
def _reset(self) -> None:
1243-
self._child_to_parent_run_id_map = {}
1244-
self._root_run_resume_key_map = {}
1299+
self._root_run_resume_key_map.pop(root_run_id, None)
12451300

12461301
def _exit_propagation_context(self) -> None:
12471302
manager = self._propagation_context_manager

0 commit comments

Comments
 (0)