Skip to content

Commit e0e81a6

Browse files
committed
fix(langchain): preserve langgraph control flow traces
1 parent cfbe7a3 commit e0e81a6

2 files changed

Lines changed: 288 additions & 32 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 140 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def __init__(
135135
self._updated_completion_start_time_memo: Set[UUID] = set()
136136
self._propagation_context_manager: Optional[_AgnosticContextManager] = None
137137
self._trace_context = trace_context
138+
# LangGraph resumes as a fresh root callback run after interrupting, so we keep
139+
# just enough trace context to stitch the resume back onto the original trace.
140+
self._resume_trace_context: Optional[TraceContext] = None
138141
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
139142

140143
self.last_trace_id: Optional[str] = None
@@ -161,6 +164,44 @@ def on_llm_new_token(
161164

162165
self._updated_completion_start_time_memo.add(run_id)
163166

167+
def _consume_root_trace_context(self) -> Optional[TraceContext]:
    """Choose the trace context to pass when starting a root-level observation.

    Returns:
        The handler's own ``_trace_context`` when one is set (it is returned
        as-is and never cleared). Otherwise ``None`` when the current
        OpenTelemetry span context is valid, leaving any pending resume
        context in place. Otherwise the pending resume context, which is
        cleared as it is handed out so it is used at most once.
    """
    explicit_context = self._trace_context
    if explicit_context is not None:
        return explicit_context

    # Only reuse the pending resume context when this callback run has no active
    # parent span of its own. Nested callbacks should attach normally.
    if trace.get_current_span().get_span_context().is_valid:
        return None

    # Hand out the pending context and reset the slot in one step.
    pending_context, self._resume_trace_context = self._resume_trace_context, None
    return pending_context
182+
183+
def _clear_resume_trace_context(self) -> None:
    """Drop any pending resume trace context held on this handler."""
    self._resume_trace_context = None
185+
186+
def _persist_resume_trace_context(self, observation: Any) -> None:
    """Remember where a later root callback run should re-attach.

    Stores the observation's ``trace_id`` and ``id`` (as ``parent_span_id``)
    in ``_resume_trace_context``. Does nothing when the handler already has
    an explicit ``_trace_context``.

    Args:
        observation: Object exposing ``trace_id`` and ``id`` attributes.
    """
    if self._trace_context is None:
        origin_trace_id = observation.trace_id
        origin_span_id = observation.id
        self._resume_trace_context = {
            "trace_id": origin_trace_id,
            "parent_span_id": origin_span_id,
        }
194+
195+
def _get_error_level_and_status_message(
    self, error: BaseException
) -> tuple[Literal["DEFAULT", "ERROR"], str]:
    """Map an exception to the observation level and status message to record.

    Args:
        error: The exception reported by the failing callback run.

    Returns:
        A ``(level, status_message)`` pair. Exceptions that are instances of
        any type in ``CONTROL_FLOW_EXCEPTION_TYPES`` yield ``"DEFAULT"``;
        everything else yields ``"ERROR"``. The message is ``str(error)``,
        falling back to the exception class name when that string is empty.
    """
    # An exception raised without arguments stringifies to "", which would
    # leave the observation with no hint about what happened. Apply the same
    # class-name fallback to both levels.
    status_message = str(error) or type(error).__name__

    # LangGraph uses GraphBubbleUp subclasses for expected control flow such as
    # interrupts and handoffs, so they should stay visible without being errors.
    if isinstance(error, tuple(CONTROL_FLOW_EXCEPTION_TYPES)):
        return "DEFAULT", status_message

    return "ERROR", status_message
204+
164205
def _get_observation_type_from_serialized(
165206
self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any
166207
) -> Union[
@@ -256,13 +297,22 @@ def on_retriever_error(
256297
observation = self._detach_observation(run_id)
257298

258299
if observation is not None:
300+
level, status_message = self._get_error_level_and_status_message(error)
259301
observation.update(
260-
level="ERROR",
261-
status_message=str(error),
302+
level=cast(
303+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
304+
level,
305+
),
306+
status_message=status_message,
262307
input=kwargs.get("inputs"),
263308
cost_details={"total": 0},
264309
).end()
265310

311+
if parent_run_id is None and level == "DEFAULT":
312+
self._persist_resume_trace_context(observation)
313+
elif parent_run_id is None:
314+
self._clear_resume_trace_context()
315+
266316
except Exception as e:
267317
langfuse_logger.exception(e)
268318

@@ -376,7 +426,7 @@ def on_chain_start(
376426
obs = self._get_parent_observation(parent_run_id)
377427
if isinstance(obs, Langfuse):
378428
span = obs.start_observation(
379-
trace_context=self._trace_context,
429+
trace_context=self._consume_root_trace_context(),
380430
name=span_name,
381431
as_type=observation_type,
382432
metadata=span_metadata,
@@ -580,6 +630,7 @@ def on_chain_end(
580630
)
581631

582632
if parent_run_id is None:
633+
self._clear_resume_trace_context()
583634
self._exit_propagation_context()
584635

585636
span.end()
@@ -605,10 +656,7 @@ def on_chain_error(
605656
) -> None:
606657
try:
607658
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
608-
if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
609-
level = None
610-
else:
611-
level = "ERROR"
659+
level, status_message = self._get_error_level_and_status_message(error)
612660

613661
observation = self._detach_observation(run_id)
614662

@@ -618,12 +666,16 @@ def on_chain_error(
618666
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
619667
level,
620668
),
621-
status_message=str(error) if level else None,
669+
status_message=status_message,
622670
input=kwargs.get("inputs"),
623671
cost_details={"total": 0},
624672
)
625673

626674
if parent_run_id is None:
675+
if level == "DEFAULT":
676+
self._persist_resume_trace_context(observation)
677+
else:
678+
self._clear_resume_trace_context()
627679
self._exit_propagation_context()
628680

629681
observation.end()
@@ -733,13 +785,24 @@ def on_tool_start(
733785
serialized, "tool", **kwargs
734786
)
735787

736-
span = self._get_parent_observation(parent_run_id).start_observation(
737-
name=self.get_langchain_run_name(serialized, **kwargs),
738-
as_type=observation_type,
739-
input=input_str,
740-
metadata=meta,
741-
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
742-
)
788+
parent_observation = self._get_parent_observation(parent_run_id)
789+
if isinstance(parent_observation, Langfuse):
790+
span = parent_observation.start_observation(
791+
trace_context=self._consume_root_trace_context(),
792+
name=self.get_langchain_run_name(serialized, **kwargs),
793+
as_type=observation_type,
794+
input=input_str,
795+
metadata=meta,
796+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
797+
)
798+
else:
799+
span = parent_observation.start_observation(
800+
name=self.get_langchain_run_name(serialized, **kwargs),
801+
as_type=observation_type,
802+
input=input_str,
803+
metadata=meta,
804+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
805+
)
743806

744807
self._attach_observation(run_id, span)
745808

@@ -774,16 +837,30 @@ def on_retriever_start(
774837
observation_type = self._get_observation_type_from_serialized(
775838
serialized, "retriever", **kwargs
776839
)
777-
span = self._get_parent_observation(parent_run_id).start_observation(
778-
name=span_name,
779-
as_type=observation_type,
780-
metadata=span_metadata,
781-
input=query,
782-
level=cast(
783-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
784-
span_level,
785-
),
786-
)
840+
parent_observation = self._get_parent_observation(parent_run_id)
841+
if isinstance(parent_observation, Langfuse):
842+
span = parent_observation.start_observation(
843+
trace_context=self._consume_root_trace_context(),
844+
name=span_name,
845+
as_type=observation_type,
846+
metadata=span_metadata,
847+
input=query,
848+
level=cast(
849+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
850+
span_level,
851+
),
852+
)
853+
else:
854+
span = parent_observation.start_observation(
855+
name=span_name,
856+
as_type=observation_type,
857+
metadata=span_metadata,
858+
input=query,
859+
level=cast(
860+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
861+
span_level,
862+
),
863+
)
787864

788865
self._attach_observation(run_id, span)
789866

@@ -805,6 +882,8 @@ def on_retriever_end(
805882
observation = self._detach_observation(run_id)
806883

807884
if observation is not None:
885+
if parent_run_id is None:
886+
self._clear_resume_trace_context()
808887
observation.update(
809888
output=documents,
810889
input=kwargs.get("inputs"),
@@ -827,6 +906,8 @@ def on_tool_end(
827906
observation = self._detach_observation(run_id)
828907

829908
if observation is not None:
909+
if parent_run_id is None:
910+
self._clear_resume_trace_context()
830911
observation.update(
831912
output=output,
832913
input=kwargs.get("inputs"),
@@ -848,13 +929,22 @@ def on_tool_error(
848929
observation = self._detach_observation(run_id)
849930

850931
if observation is not None:
932+
level, status_message = self._get_error_level_and_status_message(error)
851933
observation.update(
852-
status_message=str(error),
853-
level="ERROR",
934+
status_message=status_message,
935+
level=cast(
936+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
937+
level,
938+
),
854939
input=kwargs.get("inputs"),
855940
cost_details={"total": 0},
856941
).end()
857942

943+
if parent_run_id is None and level == "DEFAULT":
944+
self._persist_resume_trace_context(observation)
945+
elif parent_run_id is None:
946+
self._clear_resume_trace_context()
947+
858948
except Exception as e:
859949
langfuse_logger.exception(e)
860950

@@ -913,9 +1003,17 @@ def __on_llm_action(
9131003
"prompt": registered_prompt,
9141004
}
9151005

916-
generation = self._get_parent_observation(parent_run_id).start_observation(
917-
as_type="generation", **content
918-
) # type: ignore
1006+
parent_observation = self._get_parent_observation(parent_run_id)
1007+
if isinstance(parent_observation, Langfuse):
1008+
generation = parent_observation.start_observation(
1009+
trace_context=self._consume_root_trace_context(),
1010+
as_type="generation",
1011+
**content,
1012+
) # type: ignore
1013+
else:
1014+
generation = parent_observation.start_observation(
1015+
as_type="generation", **content
1016+
) # type: ignore
9191017
self._attach_observation(run_id, generation)
9201018

9211019
self.last_trace_id = self._runs[run_id].trace_id
@@ -1028,6 +1126,7 @@ def on_llm_end(
10281126
self._updated_completion_start_time_memo.discard(run_id)
10291127

10301128
if parent_run_id is None:
1129+
self._clear_resume_trace_context()
10311130
self._reset()
10321131

10331132
def on_llm_error(
@@ -1044,13 +1143,22 @@ def on_llm_error(
10441143
generation = self._detach_observation(run_id)
10451144

10461145
if generation is not None:
1146+
level, status_message = self._get_error_level_and_status_message(error)
10471147
generation.update(
1048-
status_message=str(error),
1049-
level="ERROR",
1148+
status_message=status_message,
1149+
level=cast(
1150+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
1151+
level,
1152+
),
10501153
input=kwargs.get("inputs"),
10511154
cost_details={"total": 0},
10521155
).end()
10531156

1157+
if parent_run_id is None and level == "DEFAULT":
1158+
self._persist_resume_trace_context(generation)
1159+
elif parent_run_id is None:
1160+
self._clear_resume_trace_context()
1161+
10541162
except Exception as e:
10551163
langfuse_logger.exception(e)
10561164

0 commit comments

Comments
 (0)