Skip to content

Commit f5e5865

Browse files
committed
fix(langchain): preserve langgraph control flow traces
1 parent 264b94d commit f5e5865

2 files changed

Lines changed: 287 additions & 33 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 140 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def __init__(
135135
self._updated_completion_start_time_memo: Set[UUID] = set()
136136
self._propagation_context_manager: Optional[_AgnosticContextManager] = None
137137
self._trace_context = trace_context
138+
# LangGraph resumes as a fresh root callback run after interrupting, so we keep
139+
# just enough trace context to stitch the resume back onto the original trace.
140+
self._resume_trace_context: Optional[TraceContext] = None
138141
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
139142

140143
self.last_trace_id: Optional[str] = None
@@ -161,6 +164,44 @@ def on_llm_new_token(
161164

162165
self._updated_completion_start_time_memo.add(run_id)
163166

167+
def _consume_root_trace_context(self) -> Optional[TraceContext]:
168+
if self._trace_context is not None:
169+
return self._trace_context
170+
171+
current_span_context = trace.get_current_span().get_span_context()
172+
173+
# Only reuse the pending resume context when this callback run has no active
174+
# parent span of its own. Nested callbacks should attach normally.
175+
if current_span_context.is_valid:
176+
return None
177+
178+
trace_context = self._resume_trace_context
179+
self._resume_trace_context = None
180+
181+
return trace_context
182+
183+
def _clear_resume_trace_context(self) -> None:
184+
self._resume_trace_context = None
185+
186+
def _persist_resume_trace_context(self, observation: Any) -> None:
187+
if self._trace_context is not None:
188+
return
189+
190+
self._resume_trace_context = {
191+
"trace_id": observation.trace_id,
192+
"parent_span_id": observation.id,
193+
}
194+
195+
def _get_error_level_and_status_message(
    self, error: BaseException
) -> tuple[Literal["DEFAULT", "ERROR"], str]:
    """Map an exception raised by a callback run to a level and status message.

    Exceptions that are instances of any type in
    ``CONTROL_FLOW_EXCEPTION_TYPES`` are reported with level ``"DEFAULT"``;
    everything else is reported with level ``"ERROR"``.

    The status message is ``str(error)``. Exceptions raised without arguments
    stringify to ``""``, so in that case the exception's class name is used
    instead, for both levels, to keep the observation's status informative.

    Args:
        error: The exception passed to the ``on_*_error`` callback.

    Returns:
        A ``(level, status_message)`` tuple.
    """
    # Never emit an empty status message: fall back to the exception type name.
    status_message = str(error) or type(error).__name__

    # LangGraph uses GraphBubbleUp subclasses for expected control flow such as
    # interrupts and handoffs, so they should stay visible without being errors.
    if isinstance(error, tuple(CONTROL_FLOW_EXCEPTION_TYPES)):
        return "DEFAULT", status_message

    return "ERROR", status_message
204+
164205
def _get_observation_type_from_serialized(
165206
self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any
166207
) -> Union[
@@ -256,13 +297,22 @@ def on_retriever_error(
256297
observation = self._detach_observation(run_id)
257298

258299
if observation is not None:
300+
level, status_message = self._get_error_level_and_status_message(error)
259301
observation.update(
260-
level="ERROR",
261-
status_message=str(error),
302+
level=cast(
303+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
304+
level,
305+
),
306+
status_message=status_message,
262307
input=kwargs.get("inputs"),
263308
cost_details={"total": 0},
264309
).end()
265310

311+
if parent_run_id is None and level == "DEFAULT":
312+
self._persist_resume_trace_context(observation)
313+
elif parent_run_id is None:
314+
self._clear_resume_trace_context()
315+
266316
except Exception as e:
267317
langfuse_logger.exception(e)
268318

@@ -382,7 +432,7 @@ def on_chain_start(
382432
obs = self._get_parent_observation(parent_run_id)
383433
if isinstance(obs, Langfuse):
384434
span = obs.start_observation(
385-
trace_context=self._trace_context,
435+
trace_context=self._consume_root_trace_context(),
386436
name=span_name,
387437
as_type=observation_type,
388438
metadata=span_metadata,
@@ -586,6 +636,7 @@ def on_chain_end(
586636
)
587637

588638
if parent_run_id is None:
639+
self._clear_resume_trace_context()
589640
self._exit_propagation_context()
590641

591642
span.end()
@@ -611,10 +662,7 @@ def on_chain_error(
611662
) -> None:
612663
try:
613664
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
614-
if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
615-
level = None
616-
else:
617-
level = "ERROR"
665+
level, status_message = self._get_error_level_and_status_message(error)
618666

619667
observation = self._detach_observation(run_id)
620668

@@ -624,12 +672,16 @@ def on_chain_error(
624672
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
625673
level,
626674
),
627-
status_message=str(error) if level else None,
675+
status_message=status_message,
628676
input=kwargs.get("inputs"),
629677
cost_details={"total": 0},
630678
)
631679

632680
if parent_run_id is None:
681+
if level == "DEFAULT":
682+
self._persist_resume_trace_context(observation)
683+
else:
684+
self._clear_resume_trace_context()
633685
self._exit_propagation_context()
634686

635687
observation.end()
@@ -739,13 +791,24 @@ def on_tool_start(
739791
serialized, "tool", **kwargs
740792
)
741793

742-
span = self._get_parent_observation(parent_run_id).start_observation(
743-
name=self.get_langchain_run_name(serialized, **kwargs),
744-
as_type=observation_type,
745-
input=input_str,
746-
metadata=meta,
747-
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
748-
)
794+
parent_observation = self._get_parent_observation(parent_run_id)
795+
if isinstance(parent_observation, Langfuse):
796+
span = parent_observation.start_observation(
797+
trace_context=self._consume_root_trace_context(),
798+
name=self.get_langchain_run_name(serialized, **kwargs),
799+
as_type=observation_type,
800+
input=input_str,
801+
metadata=meta,
802+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
803+
)
804+
else:
805+
span = parent_observation.start_observation(
806+
name=self.get_langchain_run_name(serialized, **kwargs),
807+
as_type=observation_type,
808+
input=input_str,
809+
metadata=meta,
810+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
811+
)
749812

750813
self._attach_observation(run_id, span)
751814

@@ -780,16 +843,30 @@ def on_retriever_start(
780843
observation_type = self._get_observation_type_from_serialized(
781844
serialized, "retriever", **kwargs
782845
)
783-
span = self._get_parent_observation(parent_run_id).start_observation(
784-
name=span_name,
785-
as_type=observation_type,
786-
metadata=span_metadata,
787-
input=query,
788-
level=cast(
789-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
790-
span_level,
791-
),
792-
)
846+
parent_observation = self._get_parent_observation(parent_run_id)
847+
if isinstance(parent_observation, Langfuse):
848+
span = parent_observation.start_observation(
849+
trace_context=self._consume_root_trace_context(),
850+
name=span_name,
851+
as_type=observation_type,
852+
metadata=span_metadata,
853+
input=query,
854+
level=cast(
855+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
856+
span_level,
857+
),
858+
)
859+
else:
860+
span = parent_observation.start_observation(
861+
name=span_name,
862+
as_type=observation_type,
863+
metadata=span_metadata,
864+
input=query,
865+
level=cast(
866+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
867+
span_level,
868+
),
869+
)
793870

794871
self._attach_observation(run_id, span)
795872

@@ -811,6 +888,8 @@ def on_retriever_end(
811888
observation = self._detach_observation(run_id)
812889

813890
if observation is not None:
891+
if parent_run_id is None:
892+
self._clear_resume_trace_context()
814893
observation.update(
815894
output=documents,
816895
input=kwargs.get("inputs"),
@@ -833,6 +912,8 @@ def on_tool_end(
833912
observation = self._detach_observation(run_id)
834913

835914
if observation is not None:
915+
if parent_run_id is None:
916+
self._clear_resume_trace_context()
836917
observation.update(
837918
output=output,
838919
input=kwargs.get("inputs"),
@@ -854,13 +935,22 @@ def on_tool_error(
854935
observation = self._detach_observation(run_id)
855936

856937
if observation is not None:
938+
level, status_message = self._get_error_level_and_status_message(error)
857939
observation.update(
858-
status_message=str(error),
859-
level="ERROR",
940+
status_message=status_message,
941+
level=cast(
942+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
943+
level,
944+
),
860945
input=kwargs.get("inputs"),
861946
cost_details={"total": 0},
862947
).end()
863948

949+
if parent_run_id is None and level == "DEFAULT":
950+
self._persist_resume_trace_context(observation)
951+
elif parent_run_id is None:
952+
self._clear_resume_trace_context()
953+
864954
except Exception as e:
865955
langfuse_logger.exception(e)
866956

@@ -919,9 +1009,17 @@ def __on_llm_action(
9191009
"prompt": registered_prompt,
9201010
}
9211011

922-
generation = self._get_parent_observation(parent_run_id).start_observation(
923-
as_type="generation", **content
924-
) # type: ignore
1012+
parent_observation = self._get_parent_observation(parent_run_id)
1013+
if isinstance(parent_observation, Langfuse):
1014+
generation = parent_observation.start_observation(
1015+
trace_context=self._consume_root_trace_context(),
1016+
as_type="generation",
1017+
**content,
1018+
) # type: ignore
1019+
else:
1020+
generation = parent_observation.start_observation(
1021+
as_type="generation", **content
1022+
) # type: ignore
9251023
self._attach_observation(run_id, generation)
9261024

9271025
self.last_trace_id = self._runs[run_id].trace_id
@@ -1034,6 +1132,7 @@ def on_llm_end(
10341132
self._updated_completion_start_time_memo.discard(run_id)
10351133

10361134
if parent_run_id is None:
1135+
self._clear_resume_trace_context()
10371136
self._reset()
10381137

10391138
def on_llm_error(
@@ -1050,13 +1149,22 @@ def on_llm_error(
10501149
generation = self._detach_observation(run_id)
10511150

10521151
if generation is not None:
1152+
level, status_message = self._get_error_level_and_status_message(error)
10531153
generation.update(
1054-
status_message=str(error),
1055-
level="ERROR",
1154+
status_message=status_message,
1155+
level=cast(
1156+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
1157+
level,
1158+
),
10561159
input=kwargs.get("inputs"),
10571160
cost_details={"total": 0},
10581161
).end()
10591162

1163+
if parent_run_id is None and level == "DEFAULT":
1164+
self._persist_resume_trace_context(generation)
1165+
elif parent_run_id is None:
1166+
self._clear_resume_trace_context()
1167+
10601168
except Exception as e:
10611169
langfuse_logger.exception(e)
10621170

0 commit comments

Comments
 (0)