Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 140 additions & 32 deletions langfuse/langchain/CallbackHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@
self._updated_completion_start_time_memo: Set[UUID] = set()
self._propagation_context_manager: Optional[_AgnosticContextManager] = None
self._trace_context = trace_context
# LangGraph resumes as a fresh root callback run after interrupting, so we keep
# just enough trace context to stitch the resume back onto the original trace.
self._resume_trace_context: Optional[TraceContext] = None
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}

self.last_trace_id: Optional[str] = None
Comment thread
claude[bot] marked this conversation as resolved.
Expand All @@ -161,6 +164,44 @@

self._updated_completion_start_time_memo.add(run_id)

def _consume_root_trace_context(self) -> Optional[TraceContext]:
if self._trace_context is not None:
return self._trace_context

current_span_context = trace.get_current_span().get_span_context()

# Only reuse the pending resume context when this callback run has no active
# parent span of its own. Nested callbacks should attach normally.
if current_span_context.is_valid:
return None

trace_context = self._resume_trace_context
self._resume_trace_context = None

Comment thread
hassiebp marked this conversation as resolved.
Outdated
return trace_context

def _clear_resume_trace_context(self) -> None:
self._resume_trace_context = None

def _persist_resume_trace_context(self, observation: Any) -> None:
if self._trace_context is not None:
return

self._resume_trace_context = {
"trace_id": observation.trace_id,
"parent_span_id": observation.id,
}
Comment thread
hassiebp marked this conversation as resolved.
Outdated

def _get_error_level_and_status_message(
self, error: BaseException
) -> tuple[Literal["DEFAULT", "ERROR"], str]:
# LangGraph uses GraphBubbleUp subclasses for expected control flow such as
# interrupts and handoffs, so they should stay visible without being errors.
if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
return "DEFAULT", str(error) or type(error).__name__

return "ERROR", str(error)

def _get_observation_type_from_serialized(
self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any
) -> Union[
Expand Down Expand Up @@ -256,13 +297,22 @@
observation = self._detach_observation(run_id)

if observation is not None:
level, status_message = self._get_error_level_and_status_message(error)
observation.update(
level="ERROR",
status_message=str(error),
level=cast(
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
level,
),
status_message=status_message,
input=kwargs.get("inputs"),
cost_details={"total": 0},
).end()

if parent_run_id is None and level == "DEFAULT":
self._persist_resume_trace_context(observation)
elif parent_run_id is None:
self._clear_resume_trace_context()

Comment on lines 460 to +478
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 In on_chain_error, _persist_resume_trace_context is correctly called before observation.end() (line 868), so the resume key is saved even if end() throws. In the three non-chain handlers (on_retriever_error, on_tool_error, on_llm_error), observation.update().end() is chained and _persist_resume_trace_context is called afterward; if end() throws, the except block is entered without saving the resume key, and the subsequent finally: _reset(run_id) destroys the entire _RootRunState including its resume_key, losing it permanently. Fix by moving _persist_resume_trace_context before the .end() call in these three handlers, mirroring on_chain_error.

Extended reasoning...

What the bug is and how it manifests

In on_chain_error, observation.update() and observation.end() are separate calls, with _persist_resume_trace_context invoked between them (line 868 before line 875). Even if observation.end() raises, the resume key has already been transferred into _pending_resume_trace_contexts. The three non-chain handlers (on_retriever_error, on_tool_error, on_llm_error) use observation.update(...).end() as a single chained expression, then call _persist_resume_trace_context afterward. If the chained .end() raises, Python immediately jumps to the except block and _persist_resume_trace_context is never reached.

The specific code path that triggers it

Three conditions must hold simultaneously: (1) the error handler is invoked for a root-level run (parent_run_id is None); (2) the exception is a LangGraph control-flow type so level == "DEFAULT"; (3) observation.update().end() raises internally. When all three hold, the resume key stored in root_run_state.resume_key is never popped into _pending_resume_trace_contexts.

Why existing code does not prevent it

The finally: if parent_run_id is None: self._reset(run_id) block then executes unconditionally. _reset calls self._root_run_states.pop(run_state.root_run_id, None), which removes the entire _RootRunState object. Since _persist_resume_trace_context was never called, root_run_state.resume_key still holds the thread's resume key — but the object is now discarded. The key never reaches _pending_resume_trace_contexts.

Step-by-step proof

  1. on_tool_start(run_id=A, parent_run_id=None, metadata={"thread_id": "t1"}) creates _RootRunState(resume_key="t1") in _root_run_states[A].
  2. Tool raises GraphBubbleUp; on_tool_error fires with run_id=A, parent_run_id=None.
  3. _get_error_level_and_status_message returns ("DEFAULT", ...); code enters the if observation is not None branch.
  4. observation.update(...).end() raises (SDK internal error, OTel resource exhaustion, etc.).
  5. except Exception catches; execution jumps to finally. _persist_resume_trace_context at line 473 was never reached; _pending_resume_trace_contexts has no entry for "t1".
  6. finally: _reset(A) pops _root_run_states[A] — the _RootRunState with resume_key="t1" is discarded.
  7. Next Command(resume=...) invocation on thread "t1": _take_root_trace_context returns None. A disconnected trace is created instead of stitching to the interrupted one.

How to fix

Split observation.update().end() into separate calls and move _persist_resume_trace_context / _clear_root_run_resume_key before the .end() call in all three non-chain handlers, mirroring on_chain_error. The observation.trace_id and observation.id fields are assigned at span creation time, not at end time, so calling persist before end() is safe.

except Exception as e:
langfuse_logger.exception(e)

Expand Down Expand Up @@ -382,7 +432,7 @@
obs = self._get_parent_observation(parent_run_id)
if isinstance(obs, Langfuse):
span = obs.start_observation(
trace_context=self._trace_context,
trace_context=self._consume_root_trace_context(),
name=span_name,
as_type=observation_type,
metadata=span_metadata,
Expand Down Expand Up @@ -586,6 +636,7 @@
)

if parent_run_id is None:
self._clear_resume_trace_context()
self._exit_propagation_context()

span.end()
Expand All @@ -611,10 +662,7 @@
) -> None:
try:
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
level = None
else:
level = "ERROR"
level, status_message = self._get_error_level_and_status_message(error)

observation = self._detach_observation(run_id)

Expand All @@ -624,12 +672,16 @@
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
level,
),
status_message=str(error) if level else None,
status_message=status_message,
input=kwargs.get("inputs"),
cost_details={"total": 0},
)

if parent_run_id is None:
if level == "DEFAULT":
self._persist_resume_trace_context(observation)
else:
self._clear_resume_trace_context()
self._exit_propagation_context()

observation.end()
Expand Down Expand Up @@ -739,13 +791,24 @@
serialized, "tool", **kwargs
)

span = self._get_parent_observation(parent_run_id).start_observation(
name=self.get_langchain_run_name(serialized, **kwargs),
as_type=observation_type,
input=input_str,
metadata=meta,
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
)
parent_observation = self._get_parent_observation(parent_run_id)
if isinstance(parent_observation, Langfuse):
span = parent_observation.start_observation(
trace_context=self._consume_root_trace_context(),
name=self.get_langchain_run_name(serialized, **kwargs),
as_type=observation_type,
input=input_str,
metadata=meta,
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
)
else:
span = parent_observation.start_observation(
name=self.get_langchain_run_name(serialized, **kwargs),
as_type=observation_type,
input=input_str,
metadata=meta,
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
)

self._attach_observation(run_id, span)

Comment thread
claude[bot] marked this conversation as resolved.
Expand Down Expand Up @@ -780,16 +843,30 @@
observation_type = self._get_observation_type_from_serialized(
serialized, "retriever", **kwargs
)
span = self._get_parent_observation(parent_run_id).start_observation(
name=span_name,
as_type=observation_type,
metadata=span_metadata,
input=query,
level=cast(
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
span_level,
),
)
parent_observation = self._get_parent_observation(parent_run_id)
if isinstance(parent_observation, Langfuse):
span = parent_observation.start_observation(
trace_context=self._consume_root_trace_context(),
name=span_name,
as_type=observation_type,
metadata=span_metadata,
input=query,
level=cast(
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
span_level,
),
)
else:
span = parent_observation.start_observation(
name=span_name,
as_type=observation_type,
metadata=span_metadata,
input=query,
level=cast(
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
span_level,
),
)

self._attach_observation(run_id, span)

Expand All @@ -811,6 +888,8 @@
observation = self._detach_observation(run_id)

if observation is not None:
if parent_run_id is None:
self._clear_resume_trace_context()
observation.update(
output=documents,
input=kwargs.get("inputs"),
Expand All @@ -833,6 +912,8 @@
observation = self._detach_observation(run_id)

if observation is not None:
if parent_run_id is None:
self._clear_resume_trace_context()
observation.update(
output=output,
input=kwargs.get("inputs"),
Comment thread
claude[bot] marked this conversation as resolved.
Expand All @@ -854,15 +935,24 @@
observation = self._detach_observation(run_id)

if observation is not None:
level, status_message = self._get_error_level_and_status_message(error)
observation.update(
status_message=str(error),
level="ERROR",
status_message=status_message,
level=cast(
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
level,
),
input=kwargs.get("inputs"),
cost_details={"total": 0},
).end()

if parent_run_id is None and level == "DEFAULT":
self._persist_resume_trace_context(observation)
elif parent_run_id is None:
self._clear_resume_trace_context()

except Exception as e:
langfuse_logger.exception(e)

Check warning on line 955 in langfuse/langchain/CallbackHandler.py

View check run for this annotation

Claude / Claude Code Review

Missing _reset() in root-level control-flow error paths for tool/retriever/LLM handlers

The three newly-modified error handlers — `on_tool_error` (line 949), `on_retriever_error` (line 309), and `on_llm_error` (line 1161) — call `_persist_resume_trace_context` for root-level control-flow exceptions but never call `_reset()` afterward, unlike `on_chain_error` which correctly includes `_reset()` in its `finally` block. This is a pre-existing gap that the PR makes slightly worse by adding a new side effect (`_persist_resume_trace_context`) without matching cleanup, meaning stale `_chi
Comment thread
claude[bot] marked this conversation as resolved.

def __on_llm_action(
self,
Expand Down Expand Up @@ -919,9 +1009,17 @@
"prompt": registered_prompt,
}

generation = self._get_parent_observation(parent_run_id).start_observation(
as_type="generation", **content
) # type: ignore
parent_observation = self._get_parent_observation(parent_run_id)
if isinstance(parent_observation, Langfuse):
generation = parent_observation.start_observation(
trace_context=self._consume_root_trace_context(),
as_type="generation",
**content,
) # type: ignore
else:
generation = parent_observation.start_observation(
as_type="generation", **content
) # type: ignore
self._attach_observation(run_id, generation)

self.last_trace_id = self._runs[run_id].trace_id
Expand Down Expand Up @@ -1034,6 +1132,7 @@
self._updated_completion_start_time_memo.discard(run_id)

if parent_run_id is None:
self._clear_resume_trace_context()
self._reset()
Comment thread
hassiebp marked this conversation as resolved.
Outdated

def on_llm_error(
Expand All @@ -1050,13 +1149,22 @@
generation = self._detach_observation(run_id)

if generation is not None:
level, status_message = self._get_error_level_and_status_message(error)
generation.update(
status_message=str(error),
level="ERROR",
status_message=status_message,
level=cast(
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
level,
),
input=kwargs.get("inputs"),
cost_details={"total": 0},
).end()

if parent_run_id is None and level == "DEFAULT":
self._persist_resume_trace_context(generation)
elif parent_run_id is None:
self._clear_resume_trace_context()

except Exception as e:
langfuse_logger.exception(e)

Expand Down
Loading
Loading