From f5e5865b3df4da752fb92cd04d880b1b31ee2a1c Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:32:19 +0200 Subject: [PATCH 1/6] fix(langchain): preserve langgraph control flow traces --- langfuse/langchain/CallbackHandler.py | 172 +++++++++++++++++++++----- tests/unit/test_langchain.py | 148 +++++++++++++++++++++- 2 files changed, 287 insertions(+), 33 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 80b7114e5..b4c3eccf7 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -135,6 +135,9 @@ def __init__( self._updated_completion_start_time_memo: Set[UUID] = set() self._propagation_context_manager: Optional[_AgnosticContextManager] = None self._trace_context = trace_context + # LangGraph resumes as a fresh root callback run after interrupting, so we keep + # just enough trace context to stitch the resume back onto the original trace. + self._resume_trace_context: Optional[TraceContext] = None self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {} self.last_trace_id: Optional[str] = None @@ -161,6 +164,44 @@ def on_llm_new_token( self._updated_completion_start_time_memo.add(run_id) + def _consume_root_trace_context(self) -> Optional[TraceContext]: + if self._trace_context is not None: + return self._trace_context + + current_span_context = trace.get_current_span().get_span_context() + + # Only reuse the pending resume context when this callback run has no active + # parent span of its own. Nested callbacks should attach normally. 
+ if current_span_context.is_valid: + return None + + trace_context = self._resume_trace_context + self._resume_trace_context = None + + return trace_context + + def _clear_resume_trace_context(self) -> None: + self._resume_trace_context = None + + def _persist_resume_trace_context(self, observation: Any) -> None: + if self._trace_context is not None: + return + + self._resume_trace_context = { + "trace_id": observation.trace_id, + "parent_span_id": observation.id, + } + + def _get_error_level_and_status_message( + self, error: BaseException + ) -> tuple[Literal["DEFAULT", "ERROR"], str]: + # LangGraph uses GraphBubbleUp subclasses for expected control flow such as + # interrupts and handoffs, so they should stay visible without being errors. + if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): + return "DEFAULT", str(error) or type(error).__name__ + + return "ERROR", str(error) + def _get_observation_type_from_serialized( self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any ) -> Union[ @@ -256,13 +297,22 @@ def on_retriever_error( observation = self._detach_observation(run_id) if observation is not None: + level, status_message = self._get_error_level_and_status_message(error) observation.update( - level="ERROR", - status_message=str(error), + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), + status_message=status_message, input=kwargs.get("inputs"), cost_details={"total": 0}, ).end() + if parent_run_id is None and level == "DEFAULT": + self._persist_resume_trace_context(observation) + elif parent_run_id is None: + self._clear_resume_trace_context() + except Exception as e: langfuse_logger.exception(e) @@ -382,7 +432,7 @@ def on_chain_start( obs = self._get_parent_observation(parent_run_id) if isinstance(obs, Langfuse): span = obs.start_observation( - trace_context=self._trace_context, + trace_context=self._consume_root_trace_context(), name=span_name, as_type=observation_type, 
metadata=span_metadata, @@ -586,6 +636,7 @@ def on_chain_end( ) if parent_run_id is None: + self._clear_resume_trace_context() self._exit_propagation_context() span.end() @@ -611,10 +662,7 @@ def on_chain_error( ) -> None: try: self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) - if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): - level = None - else: - level = "ERROR" + level, status_message = self._get_error_level_and_status_message(error) observation = self._detach_observation(run_id) @@ -624,12 +672,16 @@ def on_chain_error( Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], level, ), - status_message=str(error) if level else None, + status_message=status_message, input=kwargs.get("inputs"), cost_details={"total": 0}, ) if parent_run_id is None: + if level == "DEFAULT": + self._persist_resume_trace_context(observation) + else: + self._clear_resume_trace_context() self._exit_propagation_context() observation.end() @@ -739,13 +791,24 @@ def on_tool_start( serialized, "tool", **kwargs ) - span = self._get_parent_observation(parent_run_id).start_observation( - name=self.get_langchain_run_name(serialized, **kwargs), - as_type=observation_type, - input=input_str, - metadata=meta, - level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, - ) + parent_observation = self._get_parent_observation(parent_run_id) + if isinstance(parent_observation, Langfuse): + span = parent_observation.start_observation( + trace_context=self._consume_root_trace_context(), + name=self.get_langchain_run_name(serialized, **kwargs), + as_type=observation_type, + input=input_str, + metadata=meta, + level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, + ) + else: + span = parent_observation.start_observation( + name=self.get_langchain_run_name(serialized, **kwargs), + as_type=observation_type, + input=input_str, + metadata=meta, + level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, + ) 
self._attach_observation(run_id, span) @@ -780,16 +843,30 @@ def on_retriever_start( observation_type = self._get_observation_type_from_serialized( serialized, "retriever", **kwargs ) - span = self._get_parent_observation(parent_run_id).start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=query, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) + parent_observation = self._get_parent_observation(parent_run_id) + if isinstance(parent_observation, Langfuse): + span = parent_observation.start_observation( + trace_context=self._consume_root_trace_context(), + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=query, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) + else: + span = parent_observation.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=query, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) self._attach_observation(run_id, span) @@ -811,6 +888,8 @@ def on_retriever_end( observation = self._detach_observation(run_id) if observation is not None: + if parent_run_id is None: + self._clear_resume_trace_context() observation.update( output=documents, input=kwargs.get("inputs"), @@ -833,6 +912,8 @@ def on_tool_end( observation = self._detach_observation(run_id) if observation is not None: + if parent_run_id is None: + self._clear_resume_trace_context() observation.update( output=output, input=kwargs.get("inputs"), @@ -854,13 +935,22 @@ def on_tool_error( observation = self._detach_observation(run_id) if observation is not None: + level, status_message = self._get_error_level_and_status_message(error) observation.update( - status_message=str(error), - level="ERROR", + status_message=status_message, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), 
input=kwargs.get("inputs"), cost_details={"total": 0}, ).end() + if parent_run_id is None and level == "DEFAULT": + self._persist_resume_trace_context(observation) + elif parent_run_id is None: + self._clear_resume_trace_context() + except Exception as e: langfuse_logger.exception(e) @@ -919,9 +1009,17 @@ def __on_llm_action( "prompt": registered_prompt, } - generation = self._get_parent_observation(parent_run_id).start_observation( - as_type="generation", **content - ) # type: ignore + parent_observation = self._get_parent_observation(parent_run_id) + if isinstance(parent_observation, Langfuse): + generation = parent_observation.start_observation( + trace_context=self._consume_root_trace_context(), + as_type="generation", + **content, + ) # type: ignore + else: + generation = parent_observation.start_observation( + as_type="generation", **content + ) # type: ignore self._attach_observation(run_id, generation) self.last_trace_id = self._runs[run_id].trace_id @@ -1034,6 +1132,7 @@ def on_llm_end( self._updated_completion_start_time_memo.discard(run_id) if parent_run_id is None: + self._clear_resume_trace_context() self._reset() def on_llm_error( @@ -1050,13 +1149,22 @@ def on_llm_error( generation = self._detach_observation(run_id) if generation is not None: + level, status_message = self._get_error_level_and_status_message(error) generation.update( - status_message=str(error), - level="ERROR", + status_message=status_message, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), input=kwargs.get("inputs"), cost_details={"total": 0}, ).end() + if parent_run_id is None and level == "DEFAULT": + self._persist_resume_trace_context(generation) + elif parent_run_id is None: + self._clear_resume_trace_context() + except Exception as e: langfuse_logger.exception(e) diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index 5d8406e9c..e43579b28 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py 
@@ -12,6 +12,7 @@ from langfuse._client.attributes import LangfuseOtelSpanAttributes from langfuse.langchain import CallbackHandler +from langfuse.langchain.CallbackHandler import CONTROL_FLOW_EXCEPTION_TYPES def _assert_parent_child(parent_span, child_span) -> None: @@ -169,7 +170,6 @@ def test_chat_model_error_marks_generation_error(langfuse_memory_client, get_spa "boom" in span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] ) - def test_root_chain_metadata_propagates_trace_name( langfuse_memory_client, get_span, find_spans ): @@ -249,3 +249,149 @@ def test_root_chain_exports_when_end_runs_in_copied_context( assert root_span.attributes[LangfuseOtelSpanAttributes.TRACE_NAME] == ( "async-root-trace" ) + + +def test_control_flow_errors_use_default_level_and_keep_status_message( + langfuse_memory_client, get_span +): + class DummyControlFlowError(RuntimeError): + pass + + original_control_flow_types = set(CONTROL_FLOW_EXCEPTION_TYPES) + CONTROL_FLOW_EXCEPTION_TYPES.clear() + CONTROL_FLOW_EXCEPTION_TYPES.add(DummyControlFlowError) + + try: + handler = CallbackHandler() + + tool_run_id = uuid4() + retriever_run_id = uuid4() + llm_run_id = uuid4() + chain_run_id = uuid4() + + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_run_id, + ) + handler.on_tool_error( + DummyControlFlowError("tool interrupt"), + run_id=tool_run_id, + ) + + handler.on_retriever_start( + {"name": "knowledge_base"}, + "approval policy", + run_id=retriever_run_id, + ) + handler.on_retriever_error( + DummyControlFlowError("retriever bubble-up"), + run_id=retriever_run_id, + ) + + handler.on_llm_start( + {"name": "TestLLM"}, + ["need approval"], + run_id=llm_run_id, + invocation_params={}, + ) + handler.on_llm_error( + DummyControlFlowError("llm bubble-up"), + run_id=llm_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=chain_run_id, + ) + handler.on_chain_error( + DummyControlFlowError("graph 
interrupt"), + run_id=chain_run_id, + ) + + handler._langfuse_client.flush() + + for span_name, message in [ + ("human_approval", "tool interrupt"), + ("knowledge_base", "retriever bubble-up"), + ("TestLLM", "llm bubble-up"), + ("LangGraph", "graph interrupt"), + ]: + span = get_span(span_name) + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] + == "DEFAULT" + ) + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] + == message + ) + finally: + CONTROL_FLOW_EXCEPTION_TYPES.clear() + CONTROL_FLOW_EXCEPTION_TYPES.update(original_control_flow_types) + + +def test_control_flow_resume_reuses_trace_until_terminal_completion( + memory_exporter, langfuse_memory_client +): + class DummyControlFlowError(RuntimeError): + pass + + original_control_flow_types = set(CONTROL_FLOW_EXCEPTION_TYPES) + CONTROL_FLOW_EXCEPTION_TYPES.clear() + CONTROL_FLOW_EXCEPTION_TYPES.add(DummyControlFlowError) + + try: + handler = CallbackHandler() + + interrupted_run_id = uuid4() + resumed_run_id = uuid4() + fresh_run_id = uuid4() + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=interrupted_run_id, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=interrupted_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + {"resume": True}, + run_id=resumed_run_id, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=resumed_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["fresh invocation"]}, + run_id=fresh_run_id, + ) + handler.on_chain_end( + {"messages": ["completed"]}, + run_id=fresh_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 3 + assert root_spans[0].context.trace_id == root_spans[1].context.trace_id + assert root_spans[1].parent is not None + assert 
root_spans[1].parent.span_id == root_spans[0].context.span_id + assert root_spans[2].context.trace_id != root_spans[1].context.trace_id + finally: + CONTROL_FLOW_EXCEPTION_TYPES.clear() + CONTROL_FLOW_EXCEPTION_TYPES.update(original_control_flow_types) From 19a231e570dbda477c227ddcfc0ccd28329a57a3 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:41:55 +0200 Subject: [PATCH 2/6] test(langchain): isolate resume trace assertions --- tests/unit/test_langchain.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index e43579b28..a30fbf136 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -9,6 +9,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult, Generation, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI, OpenAI +from opentelemetry import context as otel_context from langfuse._client.attributes import LangfuseOtelSpanAttributes from langfuse.langchain import CallbackHandler @@ -338,6 +339,7 @@ def test_control_flow_resume_reuses_trace_until_terminal_completion( class DummyControlFlowError(RuntimeError): pass + context_token = otel_context.attach(otel_context.Context()) original_control_flow_types = set(CONTROL_FLOW_EXCEPTION_TYPES) CONTROL_FLOW_EXCEPTION_TYPES.clear() CONTROL_FLOW_EXCEPTION_TYPES.add(DummyControlFlowError) @@ -388,10 +390,22 @@ class DummyControlFlowError(RuntimeError): ] assert len(root_spans) == 3 - assert root_spans[0].context.trace_id == root_spans[1].context.trace_id - assert root_spans[1].parent is not None - assert root_spans[1].parent.span_id == root_spans[0].context.span_id - assert root_spans[2].context.trace_id != root_spans[1].context.trace_id + spans_by_trace_id = {} + for span in root_spans: + spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) + + assert 
sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2] + + resumed_trace_spans = next( + spans for spans in spans_by_trace_id.values() if len(spans) == 2 + ) + initial_span = next(span for span in resumed_trace_spans if span.parent is None) + resumed_span = next(span for span in resumed_trace_spans if span.parent is not None) + fresh_span = next(span for span in root_spans if span.context.trace_id != initial_span.context.trace_id) + + assert resumed_span.parent.span_id == initial_span.context.span_id + assert fresh_span.parent is None finally: CONTROL_FLOW_EXCEPTION_TYPES.clear() CONTROL_FLOW_EXCEPTION_TYPES.update(original_control_flow_types) + otel_context.detach(context_token) From ad2e408fdfdc11e87241817ac6b64940dfe08ab9 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:00:57 +0200 Subject: [PATCH 3/6] fix(langchain): scope langgraph resume context --- langfuse/langchain/CallbackHandler.py | 126 +++++++++++--- tests/unit/test_langchain.py | 227 +++++++++++++++----------- 2 files changed, 230 insertions(+), 123 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index b4c3eccf7..986548844 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -84,11 +84,14 @@ LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden" CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set() +LANGGRAPH_COMMAND_TYPE: Optional[Type[Any]] = None try: from langgraph.errors import GraphBubbleUp + from langgraph.types import Command as LangGraphCommand CONTROL_FLOW_EXCEPTION_TYPES.add(GraphBubbleUp) + LANGGRAPH_COMMAND_TYPE = LangGraphCommand except ImportError: pass @@ -136,8 +139,9 @@ def __init__( self._propagation_context_manager: Optional[_AgnosticContextManager] = None self._trace_context = trace_context # LangGraph resumes as a fresh root callback run after interrupting, so we keep - # just enough trace context 
to stitch the resume back onto the original trace. - self._resume_trace_context: Optional[TraceContext] = None + # pending resume contexts keyed by thread/session instead of a single shared slot. + self._resume_trace_context_by_key: Dict[str, TraceContext] = {} + self._root_run_resume_key_map: Dict[UUID, str] = {} self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {} self.last_trace_id: Optional[str] = None @@ -164,7 +168,35 @@ def on_llm_new_token( self._updated_completion_start_time_memo.add(run_id) - def _consume_root_trace_context(self) -> Optional[TraceContext]: + def _get_langgraph_resume_key( + self, metadata: Optional[Dict[str, Any]] + ) -> Optional[str]: + thread_id = metadata.get("thread_id") if metadata else None + + if thread_id is None: + return None + + return str(thread_id) + + def _set_root_run_resume_key( + self, run_id: UUID, metadata: Optional[Dict[str, Any]] + ) -> None: + resume_key = self._get_langgraph_resume_key(metadata) + + if resume_key is not None: + self._root_run_resume_key_map[run_id] = resume_key + + def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]: + return self._root_run_resume_key_map.pop(run_id, None) + + def _is_langgraph_resume(self, inputs: Any) -> bool: + return LANGGRAPH_COMMAND_TYPE is not None and isinstance( + inputs, LANGGRAPH_COMMAND_TYPE + ) + + def _consume_root_trace_context( + self, *, inputs: Any, metadata: Optional[Dict[str, Any]] + ) -> Optional[TraceContext]: if self._trace_context is not None: return self._trace_context @@ -175,19 +207,30 @@ def _consume_root_trace_context(self) -> Optional[TraceContext]: if current_span_context.is_valid: return None - trace_context = self._resume_trace_context - self._resume_trace_context = None + # Only explicit LangGraph resumes should consume pending trace linkage. 
+ if not self._is_langgraph_resume(inputs): + return None + + resume_key = self._get_langgraph_resume_key(metadata) + if resume_key is None: + return None - return trace_context + return self._resume_trace_context_by_key.pop(resume_key, None) - def _clear_resume_trace_context(self) -> None: - self._resume_trace_context = None + def _clear_root_run_resume_key(self, run_id: UUID) -> None: + # Keep the pending interrupt context until an explicit Command(resume=...) + # arrives. A separate root run on the same thread_id is not a resume. + self._pop_root_run_resume_key(run_id) - def _persist_resume_trace_context(self, observation: Any) -> None: + def _persist_resume_trace_context(self, *, run_id: UUID, observation: Any) -> None: if self._trace_context is not None: return - self._resume_trace_context = { + resume_key = self._pop_root_run_resume_key(run_id) + if resume_key is None: + return + + self._resume_trace_context_by_key[resume_key] = { "trace_id": observation.trace_id, "parent_span_id": observation.id, } @@ -309,12 +352,17 @@ def on_retriever_error( ).end() if parent_run_id is None and level == "DEFAULT": - self._persist_resume_trace_context(observation) + self._persist_resume_trace_context( + run_id=run_id, observation=observation + ) elif parent_run_id is None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset() def _parse_langfuse_trace_attributes( self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]] @@ -383,7 +431,7 @@ def _get_langchain_observation_metadata( def on_chain_start( self, serialized: Optional[Dict[str, Any]], - inputs: Dict[str, Any], + inputs: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -392,6 +440,8 @@ def on_chain_start( **kwargs: Any, ) -> Any: self._child_to_parent_run_id_map[run_id] = parent_run_id + if parent_run_id is None: + self._set_root_run_resume_key(run_id, 
metadata) try: self._log_debug_event( @@ -432,7 +482,9 @@ def on_chain_start( obs = self._get_parent_observation(parent_run_id) if isinstance(obs, Langfuse): span = obs.start_observation( - trace_context=self._consume_root_trace_context(), + trace_context=self._consume_root_trace_context( + inputs=inputs, metadata=metadata + ), name=span_name, as_type=observation_type, metadata=span_metadata, @@ -636,7 +688,7 @@ def on_chain_end( ) if parent_run_id is None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) self._exit_propagation_context() span.end() @@ -679,9 +731,11 @@ def on_chain_error( if parent_run_id is None: if level == "DEFAULT": - self._persist_resume_trace_context(observation) + self._persist_resume_trace_context( + run_id=run_id, observation=observation + ) else: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) self._exit_propagation_context() observation.end() @@ -739,6 +793,8 @@ def on_llm_start( **kwargs: Any, ) -> Any: self._child_to_parent_run_id_map[run_id] = parent_run_id + if parent_run_id is None: + self._set_root_run_resume_key(run_id, metadata) try: self._log_debug_event( @@ -794,7 +850,7 @@ def on_tool_start( parent_observation = self._get_parent_observation(parent_run_id) if isinstance(parent_observation, Langfuse): span = parent_observation.start_observation( - trace_context=self._consume_root_trace_context(), + trace_context=self._trace_context, name=self.get_langchain_run_name(serialized, **kwargs), as_type=observation_type, input=input_str, @@ -846,7 +902,7 @@ def on_retriever_start( parent_observation = self._get_parent_observation(parent_run_id) if isinstance(parent_observation, Langfuse): span = parent_observation.start_observation( - trace_context=self._consume_root_trace_context(), + trace_context=self._trace_context, name=span_name, as_type=observation_type, metadata=span_metadata, @@ -889,7 +945,7 @@ def on_retriever_end( if observation is not None: if parent_run_id is 
None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) observation.update( output=documents, input=kwargs.get("inputs"), @@ -897,6 +953,9 @@ def on_retriever_end( except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset() def on_tool_end( self, @@ -913,7 +972,7 @@ def on_tool_end( if observation is not None: if parent_run_id is None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) observation.update( output=output, input=kwargs.get("inputs"), @@ -947,12 +1006,17 @@ def on_tool_error( ).end() if parent_run_id is None and level == "DEFAULT": - self._persist_resume_trace_context(observation) + self._persist_resume_trace_context( + run_id=run_id, observation=observation + ) elif parent_run_id is None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset() def __on_llm_action( self, @@ -965,6 +1029,8 @@ def __on_llm_action( **kwargs: Any, ) -> None: self._child_to_parent_run_id_map[run_id] = parent_run_id + if parent_run_id is None: + self._set_root_run_resume_key(run_id, metadata) try: tools = kwargs.get("invocation_params", {}).get("tools", None) @@ -1012,7 +1078,7 @@ def __on_llm_action( parent_observation = self._get_parent_observation(parent_run_id) if isinstance(parent_observation, Langfuse): generation = parent_observation.start_observation( - trace_context=self._consume_root_trace_context(), + trace_context=self._trace_context, as_type="generation", **content, ) # type: ignore @@ -1132,7 +1198,7 @@ def on_llm_end( self._updated_completion_start_time_memo.discard(run_id) if parent_run_id is None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) self._reset() def on_llm_error( @@ -1161,15 +1227,21 @@ def on_llm_error( ).end() if parent_run_id is None and level == "DEFAULT": - 
self._persist_resume_trace_context(generation) + self._persist_resume_trace_context( + run_id=run_id, observation=generation + ) elif parent_run_id is None: - self._clear_resume_trace_context() + self._clear_root_run_resume_key(run_id) except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset() def _reset(self) -> None: self._child_to_parent_run_id_map = {} + self._root_run_resume_key_map = {} def _exit_propagation_context(self) -> None: manager = self._propagation_context_manager diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index a30fbf136..9f1463ca7 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -1,3 +1,4 @@ +import importlib from contextvars import copy_context from unittest.mock import patch from uuid import uuid4 @@ -13,7 +14,8 @@ from langfuse._client.attributes import LangfuseOtelSpanAttributes from langfuse.langchain import CallbackHandler -from langfuse.langchain.CallbackHandler import CONTROL_FLOW_EXCEPTION_TYPES + +callback_handler_module = importlib.import_module("langfuse.langchain.CallbackHandler") def _assert_parent_child(parent_span, child_span) -> None: @@ -171,6 +173,7 @@ def test_chat_model_error_marks_generation_error(langfuse_memory_client, get_spa "boom" in span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] ) + def test_root_chain_metadata_propagates_trace_name( langfuse_memory_client, get_span, find_spans ): @@ -253,132 +256,161 @@ def test_root_chain_exports_when_end_runs_in_copied_context( def test_control_flow_errors_use_default_level_and_keep_status_message( - langfuse_memory_client, get_span + langfuse_memory_client, get_span, monkeypatch ): class DummyControlFlowError(RuntimeError): pass - original_control_flow_types = set(CONTROL_FLOW_EXCEPTION_TYPES) - CONTROL_FLOW_EXCEPTION_TYPES.clear() - CONTROL_FLOW_EXCEPTION_TYPES.add(DummyControlFlowError) + monkeypatch.setattr( + callback_handler_module, + 
"CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) - try: - handler = CallbackHandler() + handler = CallbackHandler() - tool_run_id = uuid4() - retriever_run_id = uuid4() - llm_run_id = uuid4() - chain_run_id = uuid4() + tool_run_id = uuid4() + retriever_run_id = uuid4() + llm_run_id = uuid4() + chain_run_id = uuid4() - handler.on_tool_start( - {"name": "human_approval"}, - "{}", - run_id=tool_run_id, - ) - handler.on_tool_error( - DummyControlFlowError("tool interrupt"), - run_id=tool_run_id, - ) + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_run_id, + ) + handler.on_tool_error( + DummyControlFlowError("tool interrupt"), + run_id=tool_run_id, + ) - handler.on_retriever_start( - {"name": "knowledge_base"}, - "approval policy", - run_id=retriever_run_id, - ) - handler.on_retriever_error( - DummyControlFlowError("retriever bubble-up"), - run_id=retriever_run_id, - ) + handler.on_retriever_start( + {"name": "knowledge_base"}, + "approval policy", + run_id=retriever_run_id, + ) + handler.on_retriever_error( + DummyControlFlowError("retriever bubble-up"), + run_id=retriever_run_id, + ) - handler.on_llm_start( - {"name": "TestLLM"}, - ["need approval"], - run_id=llm_run_id, - invocation_params={}, - ) - handler.on_llm_error( - DummyControlFlowError("llm bubble-up"), - run_id=llm_run_id, - ) + handler.on_llm_start( + {"name": "TestLLM"}, + ["need approval"], + run_id=llm_run_id, + invocation_params={}, + ) + handler.on_llm_error( + DummyControlFlowError("llm bubble-up"), + run_id=llm_run_id, + ) - handler.on_chain_start( - {"name": "LangGraph"}, - {"messages": ["need approval"]}, - run_id=chain_run_id, + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=chain_run_id, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=chain_run_id, + ) + + handler._langfuse_client.flush() + + for span_name, message in [ + ("human_approval", "tool interrupt"), + 
("knowledge_base", "retriever bubble-up"), + ("TestLLM", "llm bubble-up"), + ("LangGraph", "graph interrupt"), + ]: + span = get_span(span_name) + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] + == "DEFAULT" ) - handler.on_chain_error( - DummyControlFlowError("graph interrupt"), - run_id=chain_run_id, + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] + == message ) - handler._langfuse_client.flush() - for span_name, message in [ - ("human_approval", "tool interrupt"), - ("knowledge_base", "retriever bubble-up"), - ("TestLLM", "llm bubble-up"), - ("LangGraph", "graph interrupt"), - ]: - span = get_span(span_name) - assert ( - span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] - == "DEFAULT" - ) - assert ( - span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] - == message - ) - finally: - CONTROL_FLOW_EXCEPTION_TYPES.clear() - CONTROL_FLOW_EXCEPTION_TYPES.update(original_control_flow_types) - - -def test_control_flow_resume_reuses_trace_until_terminal_completion( - memory_exporter, langfuse_memory_client +def test_control_flow_resume_uses_thread_keyed_explicit_resume_context( + memory_exporter, langfuse_memory_client, monkeypatch ): class DummyControlFlowError(RuntimeError): pass + Command = pytest.importorskip("langgraph.types").Command + context_token = otel_context.attach(otel_context.Context()) - original_control_flow_types = set(CONTROL_FLOW_EXCEPTION_TYPES) - CONTROL_FLOW_EXCEPTION_TYPES.clear() - CONTROL_FLOW_EXCEPTION_TYPES.add(DummyControlFlowError) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) try: handler = CallbackHandler() - interrupted_run_id = uuid4() - resumed_run_id = uuid4() - fresh_run_id = uuid4() + thread_one_interrupt_run_id = uuid4() + thread_two_interrupt_run_id = uuid4() + thread_one_fresh_run_id = uuid4() + thread_two_resume_run_id = uuid4() + thread_one_resume_run_id = uuid4() 
handler.on_chain_start( {"name": "LangGraph"}, {"messages": ["need approval"]}, - run_id=interrupted_run_id, + run_id=thread_one_interrupt_run_id, + metadata={"thread_id": "thread-1"}, ) handler.on_chain_error( - DummyControlFlowError("graph interrupt"), - run_id=interrupted_run_id, + DummyControlFlowError("graph interrupt 1"), + run_id=thread_one_interrupt_run_id, ) handler.on_chain_start( {"name": "LangGraph"}, - {"resume": True}, - run_id=resumed_run_id, + {"messages": ["need approval"]}, + run_id=thread_two_interrupt_run_id, + metadata={"thread_id": "thread-2"}, ) - handler.on_chain_end( - {"messages": ["approved"]}, - run_id=resumed_run_id, + handler.on_chain_error( + DummyControlFlowError("graph interrupt 2"), + run_id=thread_two_interrupt_run_id, ) handler.on_chain_start( {"name": "LangGraph"}, {"messages": ["fresh invocation"]}, - run_id=fresh_run_id, + run_id=thread_one_fresh_run_id, + metadata={"thread_id": "thread-1"}, ) handler.on_chain_end( {"messages": ["completed"]}, - run_id=fresh_run_id, + run_id=thread_one_fresh_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=thread_two_resume_run_id, + metadata={"thread_id": "thread-2"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=thread_two_resume_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=thread_one_resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=thread_one_resume_run_id, ) handler._langfuse_client.flush() @@ -389,23 +421,26 @@ class DummyControlFlowError(RuntimeError): if span.name == "LangGraph" ] - assert len(root_spans) == 3 + assert len(root_spans) == 5 spans_by_trace_id = {} for span in root_spans: spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) - assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2] + assert sorted(len(spans) for spans in 
spans_by_trace_id.values()) == [1, 2, 2] - resumed_trace_spans = next( + resumed_trace_spans = [ spans for spans in spans_by_trace_id.values() if len(spans) == 2 - ) - initial_span = next(span for span in resumed_trace_spans if span.parent is None) - resumed_span = next(span for span in resumed_trace_spans if span.parent is not None) - fresh_span = next(span for span in root_spans if span.context.trace_id != initial_span.context.trace_id) + ] + assert len(resumed_trace_spans) == 2 + + for spans in resumed_trace_spans: + initial_span = next(span for span in spans if span.parent is None) + resumed_span = next(span for span in spans if span.parent is not None) + assert resumed_span.parent.span_id == initial_span.context.span_id - assert resumed_span.parent.span_id == initial_span.context.span_id - assert fresh_span.parent is None + fresh_trace_spans = next( + spans for spans in spans_by_trace_id.values() if len(spans) == 1 + ) + assert fresh_trace_spans[0].parent is None finally: - CONTROL_FLOW_EXCEPTION_TYPES.clear() - CONTROL_FLOW_EXCEPTION_TYPES.update(original_control_flow_types) otel_context.detach(context_token) From 12bc17b58b32988f367c3a9bffb5a5e4a3ac608e Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:01:04 +0200 Subject: [PATCH 4/6] fix(langchain): harden langgraph resume state --- langfuse/langchain/CallbackHandler.py | 95 +++++++--- tests/unit/test_langchain.py | 241 +++++++++++++++++++++++++- 2 files changed, 314 insertions(+), 22 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 986548844..f4f80fa18 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -194,28 +194,41 @@ def _is_langgraph_resume(self, inputs: Any) -> bool: inputs, LANGGRAPH_COMMAND_TYPE ) - def _consume_root_trace_context( + def _take_root_trace_context( self, *, inputs: Any, metadata: Optional[Dict[str, Any]] - ) -> 
Optional[TraceContext]: + ) -> tuple[Optional[str], Optional[TraceContext]]: if self._trace_context is not None: - return self._trace_context + return None, self._trace_context current_span_context = trace.get_current_span().get_span_context() # Only reuse the pending resume context when this callback run has no active # parent span of its own. Nested callbacks should attach normally. if current_span_context.is_valid: - return None + return None, None # Only explicit LangGraph resumes should consume pending trace linkage. if not self._is_langgraph_resume(inputs): - return None + return None, None resume_key = self._get_langgraph_resume_key(metadata) if resume_key is None: - return None + return None, None + + return resume_key, self._resume_trace_context_by_key.pop(resume_key, None) + + def _restore_root_trace_context( + self, *, resume_key: Optional[str], trace_context: Optional[TraceContext] + ) -> None: + if self._trace_context is not None: + return - return self._resume_trace_context_by_key.pop(resume_key, None) + if resume_key is None or trace_context is None: + return + + # Span creation failed after we consumed the pending linkage, so put it + # back and let the next retry resume the interrupted trace correctly. + self._resume_trace_context_by_key.setdefault(resume_key, trace_context) def _clear_root_run_resume_key(self, run_id: UUID) -> None: # Keep the pending interrupt context until an explicit Command(resume=...) 
@@ -362,7 +375,7 @@ def on_retriever_error( langfuse_logger.exception(e) finally: if parent_run_id is None: - self._reset() + self._reset(run_id) def _parse_langfuse_trace_attributes( self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]] @@ -443,6 +456,10 @@ def on_chain_start( if parent_run_id is None: self._set_root_run_resume_key(run_id, metadata) + span = None + resume_key = None + trace_context = None + try: self._log_debug_event( "on_chain_start", run_id, parent_run_id, inputs=inputs @@ -481,10 +498,11 @@ def on_chain_start( obs = self._get_parent_observation(parent_run_id) if isinstance(obs, Langfuse): + resume_key, trace_context = self._take_root_trace_context( + inputs=inputs, metadata=metadata + ) span = obs.start_observation( - trace_context=self._consume_root_trace_context( - inputs=inputs, metadata=metadata - ), + trace_context=trace_context, name=span_name, as_type=observation_type, metadata=span_metadata, @@ -511,6 +529,13 @@ def on_chain_start( self.last_trace_id = self._runs[run_id].trace_id except Exception as e: + if span is None: + self._restore_root_trace_context( + resume_key=resume_key, trace_context=trace_context + ) + if parent_run_id is None: + self._clear_root_run_resume_key(run_id) + self._exit_propagation_context() langfuse_logger.exception(e) def _register_langfuse_prompt( @@ -701,7 +726,7 @@ def on_chain_end( finally: if parent_run_id is None: self._exit_propagation_context() - self._reset() + self._reset(run_id) def on_chain_error( self, @@ -745,7 +770,7 @@ def on_chain_error( finally: if parent_run_id is None: self._exit_propagation_context() - self._reset() + self._reset(run_id) def on_chat_model_start( self, @@ -759,6 +784,8 @@ def on_chat_model_start( **kwargs: Any, ) -> Any: self._child_to_parent_run_id_map[run_id] = parent_run_id + if parent_run_id is None: + self._set_root_run_resume_key(run_id, metadata) try: self._log_debug_event( @@ -824,6 +851,8 @@ def on_tool_start( **kwargs: Any, ) -> Any: 
self._child_to_parent_run_id_map[run_id] = parent_run_id + if parent_run_id is None: + self._set_root_run_resume_key(run_id, metadata) try: self._log_debug_event( @@ -883,6 +912,8 @@ def on_retriever_start( **kwargs: Any, ) -> Any: self._child_to_parent_run_id_map[run_id] = parent_run_id + if parent_run_id is None: + self._set_root_run_resume_key(run_id, metadata) try: self._log_debug_event( @@ -955,7 +986,7 @@ def on_retriever_end( langfuse_logger.exception(e) finally: if parent_run_id is None: - self._reset() + self._reset(run_id) def on_tool_end( self, @@ -980,6 +1011,9 @@ def on_tool_end( except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset(run_id) def on_tool_error( self, @@ -1016,7 +1050,7 @@ def on_tool_error( langfuse_logger.exception(e) finally: if parent_run_id is None: - self._reset() + self._reset(run_id) def __on_llm_action( self, @@ -1199,7 +1233,7 @@ def on_llm_end( if parent_run_id is None: self._clear_root_run_resume_key(run_id) - self._reset() + self._reset(run_id) def on_llm_error( self, @@ -1237,11 +1271,32 @@ def on_llm_error( langfuse_logger.exception(e) finally: if parent_run_id is None: - self._reset() + self._reset(run_id) + + def _run_belongs_to_root(self, run_id: UUID, root_run_id: UUID) -> bool: + current_run_id: Optional[UUID] = run_id + visited: Set[UUID] = set() + + while current_run_id is not None and current_run_id not in visited: + if current_run_id == root_run_id: + return True + + visited.add(current_run_id) + current_run_id = self._child_to_parent_run_id_map.get(current_run_id) + + return False + + def _reset(self, root_run_id: UUID) -> None: + run_ids_to_clear = [ + run_id + for run_id in self._child_to_parent_run_id_map + if self._run_belongs_to_root(run_id, root_run_id) + ] + + for run_id in run_ids_to_clear: + self._child_to_parent_run_id_map.pop(run_id, None) - def _reset(self) -> None: - self._child_to_parent_run_id_map = {} - self._root_run_resume_key_map = {} + 
self._root_run_resume_key_map.pop(root_run_id, None) def _exit_propagation_context(self) -> None: manager = self._propagation_context_manager diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index 9f1463ca7..45d9f30c5 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -325,8 +325,7 @@ class DummyControlFlowError(RuntimeError): ]: span = get_span(span_name) assert ( - span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] - == "DEFAULT" + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] == "DEFAULT" ) assert ( span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] @@ -444,3 +443,241 @@ class DummyControlFlowError(RuntimeError): assert fresh_trace_spans[0].parent is None finally: otel_context.detach(context_token) + + +def test_control_flow_resume_restores_context_after_failed_root_start( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + + interrupt_run_id = uuid4() + failed_resume_run_id = uuid4() + successful_resume_run_id = uuid4() + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=interrupt_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=interrupt_run_id, + ) + + assert "thread-1" in handler._resume_trace_context_by_key + + with patch.object( + handler._langfuse_client, + "start_observation", + side_effect=RuntimeError("trace create failed"), + ): + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=failed_resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + + 
assert "thread-1" in handler._resume_trace_context_by_key + assert failed_resume_run_id not in handler._root_run_resume_key_map + assert handler._propagation_context_manager is None + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=successful_resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=successful_resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 2 + + initial_span = next(span for span in root_spans if span.parent is None) + resumed_span = next(span for span in root_spans if span.parent is not None) + + assert resumed_span.parent.span_id == initial_span.context.span_id + finally: + otel_context.detach(context_token) + + +def test_root_reset_preserves_other_inflight_resume_keys( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + root_one_context = copy_context() + root_two_context = copy_context() + + root_one_run_id = uuid4() + root_two_run_id = uuid4() + root_two_resume_run_id = uuid4() + + root_one_context.run( + handler.on_chain_start, + {"name": "LangGraph"}, + {"messages": ["completed"]}, + run_id=root_one_run_id, + metadata={"thread_id": "thread-1"}, + ) + root_two_context.run( + handler.on_chain_start, + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=root_two_run_id, + metadata={"thread_id": "thread-2"}, + ) + + assert handler._root_run_resume_key_map[root_two_run_id] == "thread-2" + + root_one_context.run( + handler.on_chain_end, + 
{"messages": ["completed"]}, + run_id=root_one_run_id, + ) + + assert handler._root_run_resume_key_map[root_two_run_id] == "thread-2" + + root_two_context.run( + handler.on_chain_error, + DummyControlFlowError("graph interrupt"), + run_id=root_two_run_id, + ) + + assert "thread-2" in handler._resume_trace_context_by_key + + root_two_context.run( + handler.on_chain_start, + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=root_two_resume_run_id, + metadata={"thread_id": "thread-2"}, + ) + root_two_context.run( + handler.on_chain_end, + {"messages": ["approved"]}, + run_id=root_two_resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 3 + + spans_by_trace_id = {} + for span in root_spans: + spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) + + assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2] + finally: + otel_context.detach(context_token) + + +def test_root_tool_and_retriever_runs_seed_resume_keys_and_cleanup( + langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + handler = CallbackHandler() + + tool_error_run_id = uuid4() + tool_end_run_id = uuid4() + retriever_run_id = uuid4() + + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_error_run_id, + metadata={"thread_id": "tool-error-thread"}, + ) + assert handler._root_run_resume_key_map[tool_error_run_id] == "tool-error-thread" + + handler.on_tool_error( + DummyControlFlowError("tool interrupt"), + run_id=tool_error_run_id, + ) + + assert "tool-error-thread" in handler._resume_trace_context_by_key + assert tool_error_run_id not in handler._root_run_resume_key_map + assert tool_error_run_id not in 
handler._child_to_parent_run_id_map + + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_end_run_id, + metadata={"thread_id": "tool-end-thread"}, + ) + assert handler._root_run_resume_key_map[tool_end_run_id] == "tool-end-thread" + + handler.on_tool_end( + '{"approved": true}', + run_id=tool_end_run_id, + ) + + assert tool_end_run_id not in handler._root_run_resume_key_map + assert tool_end_run_id not in handler._child_to_parent_run_id_map + + handler.on_retriever_start( + {"name": "knowledge_base"}, + "approval policy", + run_id=retriever_run_id, + metadata={"thread_id": "retriever-thread"}, + ) + assert handler._root_run_resume_key_map[retriever_run_id] == "retriever-thread" + + handler.on_retriever_error( + DummyControlFlowError("retriever interrupt"), + run_id=retriever_run_id, + ) + + assert "retriever-thread" in handler._resume_trace_context_by_key + assert retriever_run_id not in handler._root_run_resume_key_map + assert retriever_run_id not in handler._child_to_parent_run_id_map From 9414be8c5bb34316272582f89623de7370015f8e Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:01:41 +0200 Subject: [PATCH 5/6] fix(langchain): harden langgraph resume detection --- langfuse/langchain/CallbackHandler.py | 45 ++++++-- tests/unit/test_langchain.py | 150 ++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 10 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index f4f80fa18..795a1807f 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from contextvars import Token from typing import ( Any, @@ -85,12 +86,18 @@ LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden" CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set() LANGGRAPH_COMMAND_TYPE: Optional[Type[Any]] = None +MAX_PENDING_RESUME_TRACE_CONTEXTS = 1024 try: from 
langgraph.errors import GraphBubbleUp - from langgraph.types import Command as LangGraphCommand CONTROL_FLOW_EXCEPTION_TYPES.add(GraphBubbleUp) +except ImportError: + pass + +try: + from langgraph.types import Command as LangGraphCommand + LANGGRAPH_COMMAND_TYPE = LangGraphCommand except ImportError: pass @@ -140,7 +147,9 @@ def __init__( self._trace_context = trace_context # LangGraph resumes as a fresh root callback run after interrupting, so we keep # pending resume contexts keyed by thread/session instead of a single shared slot. - self._resume_trace_context_by_key: Dict[str, TraceContext] = {} + self._resume_trace_context_by_key: OrderedDict[str, TraceContext] = ( + OrderedDict() + ) self._root_run_resume_key_map: Dict[UUID, str] = {} self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {} @@ -190,10 +199,21 @@ def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]: return self._root_run_resume_key_map.pop(run_id, None) def _is_langgraph_resume(self, inputs: Any) -> bool: - return LANGGRAPH_COMMAND_TYPE is not None and isinstance( - inputs, LANGGRAPH_COMMAND_TYPE + return ( + LANGGRAPH_COMMAND_TYPE is not None + and isinstance(inputs, LANGGRAPH_COMMAND_TYPE) + and getattr(inputs, "resume", None) is not None ) + def _store_resume_trace_context( + self, *, resume_key: str, trace_context: TraceContext + ) -> None: + self._resume_trace_context_by_key[resume_key] = trace_context + self._resume_trace_context_by_key.move_to_end(resume_key) + + if len(self._resume_trace_context_by_key) > MAX_PENDING_RESUME_TRACE_CONTEXTS: + self._resume_trace_context_by_key.popitem(last=False) + def _take_root_trace_context( self, *, inputs: Any, metadata: Optional[Dict[str, Any]] ) -> tuple[Optional[str], Optional[TraceContext]]: @@ -228,7 +248,9 @@ def _restore_root_trace_context( # Span creation failed after we consumed the pending linkage, so put it # back and let the next retry resume the interrupted trace correctly. 
- self._resume_trace_context_by_key.setdefault(resume_key, trace_context) + self._store_resume_trace_context( + resume_key=resume_key, trace_context=trace_context + ) def _clear_root_run_resume_key(self, run_id: UUID) -> None: # Keep the pending interrupt context until an explicit Command(resume=...) @@ -243,10 +265,13 @@ def _persist_resume_trace_context(self, *, run_id: UUID, observation: Any) -> No if resume_key is None: return - self._resume_trace_context_by_key[resume_key] = { - "trace_id": observation.trace_id, - "parent_span_id": observation.id, - } + self._store_resume_trace_context( + resume_key=resume_key, + trace_context={ + "trace_id": observation.trace_id, + "parent_span_id": observation.id, + }, + ) def _get_error_level_and_status_message( self, error: BaseException @@ -534,8 +559,8 @@ def on_chain_start( resume_key=resume_key, trace_context=trace_context ) if parent_run_id is None: - self._clear_root_run_resume_key(run_id) self._exit_propagation_context() + self._reset(run_id) langfuse_logger.exception(e) def _register_langfuse_prompt( diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index 45d9f30c5..628597537 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -494,6 +494,7 @@ class DummyControlFlowError(RuntimeError): assert "thread-1" in handler._resume_trace_context_by_key assert failed_resume_run_id not in handler._root_run_resume_key_map + assert failed_resume_run_id not in handler._child_to_parent_run_id_map assert handler._propagation_context_manager is None handler.on_chain_start( @@ -525,6 +526,92 @@ class DummyControlFlowError(RuntimeError): otel_context.detach(context_token) +def test_control_flow_resume_ignores_non_resume_commands( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( 
+ callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + + interrupt_run_id = uuid4() + goto_run_id = uuid4() + resume_run_id = uuid4() + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=interrupt_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=interrupt_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(goto="approval_node"), + run_id=goto_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["routed"]}, + run_id=goto_run_id, + ) + + assert "thread-1" in handler._resume_trace_context_by_key + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 3 + + spans_by_trace_id = {} + for span in root_spans: + spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) + + assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2] + + resumed_trace_spans = next( + spans for spans in spans_by_trace_id.values() if len(spans) == 2 + ) + initial_span = next(span for span in resumed_trace_spans if span.parent is None) + resumed_span = next( + span for span in resumed_trace_spans if span.parent is not None + ) + + assert resumed_span.parent.span_id == initial_span.context.span_id + finally: + otel_context.detach(context_token) + + def test_root_reset_preserves_other_inflight_resume_keys( memory_exporter, langfuse_memory_client, monkeypatch ): @@ -681,3 +768,66 @@ class DummyControlFlowError(RuntimeError): assert "retriever-thread" in 
handler._resume_trace_context_by_key assert retriever_run_id not in handler._root_run_resume_key_map assert retriever_run_id not in handler._child_to_parent_run_id_map + + +def test_pending_resume_contexts_are_capped(langfuse_memory_client, monkeypatch): + class DummyControlFlowError(RuntimeError): + pass + + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + monkeypatch.setattr( + callback_handler_module, + "MAX_PENDING_RESUME_TRACE_CONTEXTS", + 4, + ) + + handler = CallbackHandler() + + for index in range(5): + run_id = uuid4() + thread_id = f"thread-{index}" + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=run_id, + metadata={"thread_id": thread_id}, + ) + handler.on_chain_error( + DummyControlFlowError(f"graph interrupt {index}"), + run_id=run_id, + ) + + assert len(handler._resume_trace_context_by_key) == 4 + assert list(handler._resume_trace_context_by_key) == [ + "thread-1", + "thread-2", + "thread-3", + "thread-4", + ] + + +def test_graphbubbleup_import_is_independent_from_command_import(): + real_import = __import__ + + def import_without_langgraph_command( + name, globals=None, locals=None, fromlist=(), level=0 + ): + if name == "langgraph.types": + raise ImportError("Command unavailable") + + return real_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=import_without_langgraph_command): + reloaded_module = importlib.reload(callback_handler_module) + assert reloaded_module.LANGGRAPH_COMMAND_TYPE is None + assert any( + exception_type.__name__ == "GraphBubbleUp" + for exception_type in reloaded_module.CONTROL_FLOW_EXCEPTION_TYPES + ) + + importlib.reload(callback_handler_module) From 55bb373a3ca8f83275cfcf61e28da2b74ce1296b Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:31:51 +0200 Subject: [PATCH 6/6] refactor(langchain): 
centralize root resume state --- langfuse/langchain/CallbackHandler.py | 208 +++++++++++++++++--------- tests/unit/test_langchain.py | 61 +++++--- 2 files changed, 176 insertions(+), 93 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 795a1807f..1349f6ae0 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -1,5 +1,6 @@ from collections import OrderedDict from contextvars import Token +from dataclasses import dataclass, field from typing import ( Any, Dict, @@ -103,6 +104,44 @@ pass +@dataclass +class _RunState: + parent_run_id: Optional[UUID] + root_run_id: UUID + + +@dataclass +class _RootRunState: + run_ids: Set[UUID] = field(default_factory=set) + resume_key: Optional[str] = None + propagation_context_manager: Optional[_AgnosticContextManager] = None + + +class _PendingResumeTraceContextStore: + def __init__(self, max_size: int) -> None: + self._max_size = max_size + self._contexts: OrderedDict[str, TraceContext] = OrderedDict() + + def store(self, *, resume_key: str, trace_context: TraceContext) -> None: + self._contexts[resume_key] = trace_context + self._contexts.move_to_end(resume_key) + + if len(self._contexts) > self._max_size: + self._contexts.popitem(last=False) + + def take(self, resume_key: str) -> Optional[TraceContext]: + return self._contexts.pop(resume_key, None) + + def __contains__(self, resume_key: str) -> bool: + return resume_key in self._contexts + + def __len__(self) -> int: + return len(self._contexts) + + def keys(self) -> List[str]: + return list(self._contexts.keys()) + + class LangchainCallbackHandler(LangchainBaseCallbackHandler): def __init__( self, @@ -143,15 +182,12 @@ def __init__( self._context_tokens: Dict[UUID, Token] = {} self._prompt_to_parent_run_map: Dict[UUID, Any] = {} self._updated_completion_start_time_memo: Set[UUID] = set() - self._propagation_context_manager: Optional[_AgnosticContextManager] = None 
self._trace_context = trace_context - # LangGraph resumes as a fresh root callback run after interrupting, so we keep - # pending resume contexts keyed by thread/session instead of a single shared slot. - self._resume_trace_context_by_key: OrderedDict[str, TraceContext] = ( - OrderedDict() + self._pending_resume_trace_contexts = _PendingResumeTraceContextStore( + MAX_PENDING_RESUME_TRACE_CONTEXTS ) - self._root_run_resume_key_map: Dict[UUID, str] = {} - self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {} + self._run_states: Dict[UUID, _RunState] = {} + self._root_run_states: Dict[UUID, _RootRunState] = {} self.last_trace_id: Optional[str] = None @@ -187,16 +223,62 @@ def _get_langgraph_resume_key( return str(thread_id) - def _set_root_run_resume_key( - self, run_id: UUID, metadata: Optional[Dict[str, Any]] + def _track_run( + self, + *, + run_id: UUID, + parent_run_id: Optional[UUID], + metadata: Optional[Dict[str, Any]] = None, ) -> None: - resume_key = self._get_langgraph_resume_key(metadata) + if run_id in self._run_states: + return + + if parent_run_id is None: + root_run_id = run_id + self._root_run_states[root_run_id] = _RootRunState( + run_ids={run_id}, + resume_key=self._get_langgraph_resume_key(metadata), + ) + else: + parent_state = self._run_states.get(parent_run_id) + root_run_id = ( + parent_state.root_run_id if parent_state is not None else parent_run_id + ) + root_run_state = self._root_run_states.setdefault( + root_run_id, _RootRunState() + ) + root_run_state.run_ids.add(run_id) + + self._run_states[run_id] = _RunState( + parent_run_id=parent_run_id, + root_run_id=root_run_id, + ) + + def _get_run_state(self, run_id: UUID) -> Optional[_RunState]: + return self._run_states.get(run_id) - if resume_key is not None: - self._root_run_resume_key_map[run_id] = resume_key + def _get_root_run_state(self, run_id: UUID) -> Optional[_RootRunState]: + run_state = self._get_run_state(run_id) + + if run_state is None: + return None + + return 
self._root_run_states.get(run_state.root_run_id) def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]: - return self._root_run_resume_key_map.pop(run_id, None) + root_run_state = self._get_root_run_state(run_id) + + if root_run_state is None: + return None + + resume_key = root_run_state.resume_key + root_run_state.resume_key = None + + return resume_key + + def _get_parent_run_id(self, run_id: UUID) -> Optional[UUID]: + run_state = self._get_run_state(run_id) + return run_state.parent_run_id if run_state is not None else None def _is_langgraph_resume(self, inputs: Any) -> bool: return ( @@ -208,11 +290,9 @@ def _is_langgraph_resume(self, inputs: Any) -> bool: def _store_resume_trace_context( self, *, resume_key: str, trace_context: TraceContext ) -> None: - self._resume_trace_context_by_key[resume_key] = trace_context - self._resume_trace_context_by_key.move_to_end(resume_key) - - if len(self._resume_trace_context_by_key) > MAX_PENDING_RESUME_TRACE_CONTEXTS: - self._resume_trace_context_by_key.popitem(last=False) + self._pending_resume_trace_contexts.store( + resume_key=resume_key, trace_context=trace_context + ) def _take_root_trace_context( self, *, inputs: Any, metadata: Optional[Dict[str, Any]] @@ -235,7 +315,7 @@ def _take_root_trace_context( if resume_key is None: return None, None - return resume_key, self._resume_trace_context_by_key.pop(resume_key, None) + return resume_key, self._pending_resume_trace_contexts.take(resume_key) def _restore_root_trace_context( self, *, resume_key: Optional[str], trace_context: Optional[TraceContext] @@ -477,9 +557,7 @@ def on_chain_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id - if parent_run_id is None: - self._set_root_run_resume_key(run_id, metadata) + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) span = None resume_key = None @@ -511,7 +589,7 @@ def on_chain_start( metadata=metadata, 
tags=tags ) - self._propagation_context_manager = propagate_attributes( + propagation_context_manager = propagate_attributes( user_id=parsed_trace_attributes.get("user_id", None), session_id=parsed_trace_attributes.get("session_id", None), tags=parsed_trace_attributes.get("tags", None), @@ -519,7 +597,13 @@ def on_chain_start( trace_name=parsed_trace_attributes.get("trace_name", None), ) - self._propagation_context_manager.__enter__() + root_run_state = self._get_root_run_state(run_id) + if root_run_state is not None: + root_run_state.propagation_context_manager = ( + propagation_context_manager + ) + + propagation_context_manager.__enter__() obs = self._get_parent_observation(parent_run_id) if isinstance(obs, Langfuse): @@ -559,7 +643,7 @@ def on_chain_start( resume_key=resume_key, trace_context=trace_context ) if parent_run_id is None: - self._exit_propagation_context() + self._exit_propagation_context(run_id) self._reset(run_id) langfuse_logger.exception(e) @@ -665,7 +749,7 @@ def on_agent_action( **kwargs: Any, ) -> Any: """Run on agent action.""" - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id) try: self._log_debug_event( @@ -739,7 +823,7 @@ def on_chain_end( if parent_run_id is None: self._clear_root_run_resume_key(run_id) - self._exit_propagation_context() + self._exit_propagation_context(run_id) span.end() @@ -750,7 +834,7 @@ def on_chain_end( finally: if parent_run_id is None: - self._exit_propagation_context() + self._exit_propagation_context(run_id) self._reset(run_id) def on_chain_error( @@ -786,7 +870,7 @@ def on_chain_error( ) else: self._clear_root_run_resume_key(run_id) - self._exit_propagation_context() + self._exit_propagation_context(run_id) observation.end() @@ -794,7 +878,7 @@ def on_chain_error( langfuse_logger.exception(e) finally: if parent_run_id is None: - self._exit_propagation_context() + self._exit_propagation_context(run_id) self._reset(run_id) def 
on_chat_model_start( @@ -808,9 +892,7 @@ def on_chat_model_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id - if parent_run_id is None: - self._set_root_run_resume_key(run_id, metadata) + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -844,9 +926,7 @@ def on_llm_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id - if parent_run_id is None: - self._set_root_run_resume_key(run_id, metadata) + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -875,9 +955,7 @@ def on_tool_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id - if parent_run_id is None: - self._set_root_run_resume_key(run_id, metadata) + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -936,9 +1014,7 @@ def on_retriever_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id - if parent_run_id is None: - self._set_root_run_resume_key(run_id, metadata) + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -1087,9 +1163,7 @@ def __on_llm_action( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - self._child_to_parent_run_id_map[run_id] = parent_run_id - if parent_run_id is None: - self._set_root_run_resume_key(run_id, metadata) + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: tools = kwargs.get("invocation_params", {}).get("tools", None) @@ -1113,8 +1187,8 @@ def __on_llm_action( self._deregister_langfuse_prompt(current_parent_run_id) break else: - current_parent_run_id = 
self._child_to_parent_run_id_map.get( - current_parent_run_id, None + current_parent_run_id = self._get_parent_run_id( + current_parent_run_id ) content = { @@ -1298,38 +1372,30 @@ def on_llm_error( if parent_run_id is None: self._reset(run_id) - def _run_belongs_to_root(self, run_id: UUID, root_run_id: UUID) -> bool: - current_run_id: Optional[UUID] = run_id - visited: Set[UUID] = set() - - while current_run_id is not None and current_run_id not in visited: - if current_run_id == root_run_id: - return True - - visited.add(current_run_id) - current_run_id = self._child_to_parent_run_id_map.get(current_run_id) - - return False - def _reset(self, root_run_id: UUID) -> None: - run_ids_to_clear = [ - run_id - for run_id in self._child_to_parent_run_id_map - if self._run_belongs_to_root(run_id, root_run_id) - ] + run_state = self._get_run_state(root_run_id) + if run_state is None: + return - for run_id in run_ids_to_clear: - self._child_to_parent_run_id_map.pop(run_id, None) + root_run_state = self._root_run_states.pop(run_state.root_run_id, None) + if root_run_state is None: + self._run_states.pop(root_run_id, None) + return + + for run_id in root_run_state.run_ids: + self._run_states.pop(run_id, None) - self._root_run_resume_key_map.pop(root_run_id, None) + def _exit_propagation_context(self, run_id: UUID) -> None: + root_run_state = self._get_root_run_state(run_id) - def _exit_propagation_context(self) -> None: - manager = self._propagation_context_manager + if root_run_state is None: + return + manager = root_run_state.propagation_context_manager if manager is None: return - self._propagation_context_manager = None + root_run_state.propagation_context_manager = None manager.__exit__(None, None, None) def __join_tags_and_metadata( diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index 628597537..27298342c 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -23,6 +23,23 @@ def _assert_parent_child(parent_span, 
child_span) -> None: assert child_span.parent.span_id == parent_span.context.span_id +def _has_pending_resume_context(handler, resume_key: str) -> bool: + return resume_key in handler._pending_resume_trace_contexts + + +def _pending_resume_context_keys(handler) -> list[str]: + return list(handler._pending_resume_trace_contexts.keys()) + + +def _get_root_resume_key(handler, root_run_id): + root_run_state = handler._root_run_states.get(root_run_id) + return None if root_run_state is None else root_run_state.resume_key + + +def _has_run_state(handler, run_id) -> bool: + return run_id in handler._run_states + + def test_chat_model_callback_exports_generation_span( langfuse_memory_client, get_span, json_attr ): @@ -478,7 +495,7 @@ class DummyControlFlowError(RuntimeError): run_id=interrupt_run_id, ) - assert "thread-1" in handler._resume_trace_context_by_key + assert _has_pending_resume_context(handler, "thread-1") with patch.object( handler._langfuse_client, @@ -492,10 +509,10 @@ class DummyControlFlowError(RuntimeError): metadata={"thread_id": "thread-1"}, ) - assert "thread-1" in handler._resume_trace_context_by_key - assert failed_resume_run_id not in handler._root_run_resume_key_map - assert failed_resume_run_id not in handler._child_to_parent_run_id_map - assert handler._propagation_context_manager is None + assert _has_pending_resume_context(handler, "thread-1") + assert _get_root_resume_key(handler, failed_resume_run_id) is None + assert not _has_run_state(handler, failed_resume_run_id) + assert failed_resume_run_id not in handler._root_run_states handler.on_chain_start( {"name": "LangGraph"}, @@ -570,7 +587,7 @@ class DummyControlFlowError(RuntimeError): run_id=goto_run_id, ) - assert "thread-1" in handler._resume_trace_context_by_key + assert _has_pending_resume_context(handler, "thread-1") handler.on_chain_start( {"name": "LangGraph"}, @@ -651,7 +668,7 @@ class DummyControlFlowError(RuntimeError): metadata={"thread_id": "thread-2"}, ) - assert 
handler._root_run_resume_key_map[root_two_run_id] == "thread-2" + assert _get_root_resume_key(handler, root_two_run_id) == "thread-2" root_one_context.run( handler.on_chain_end, @@ -659,7 +676,7 @@ class DummyControlFlowError(RuntimeError): run_id=root_one_run_id, ) - assert handler._root_run_resume_key_map[root_two_run_id] == "thread-2" + assert _get_root_resume_key(handler, root_two_run_id) == "thread-2" root_two_context.run( handler.on_chain_error, @@ -667,7 +684,7 @@ class DummyControlFlowError(RuntimeError): run_id=root_two_run_id, ) - assert "thread-2" in handler._resume_trace_context_by_key + assert _has_pending_resume_context(handler, "thread-2") root_two_context.run( handler.on_chain_start, @@ -725,16 +742,16 @@ class DummyControlFlowError(RuntimeError): run_id=tool_error_run_id, metadata={"thread_id": "tool-error-thread"}, ) - assert handler._root_run_resume_key_map[tool_error_run_id] == "tool-error-thread" + assert _get_root_resume_key(handler, tool_error_run_id) == "tool-error-thread" handler.on_tool_error( DummyControlFlowError("tool interrupt"), run_id=tool_error_run_id, ) - assert "tool-error-thread" in handler._resume_trace_context_by_key - assert tool_error_run_id not in handler._root_run_resume_key_map - assert tool_error_run_id not in handler._child_to_parent_run_id_map + assert _has_pending_resume_context(handler, "tool-error-thread") + assert _get_root_resume_key(handler, tool_error_run_id) is None + assert not _has_run_state(handler, tool_error_run_id) handler.on_tool_start( {"name": "human_approval"}, @@ -742,15 +759,15 @@ class DummyControlFlowError(RuntimeError): run_id=tool_end_run_id, metadata={"thread_id": "tool-end-thread"}, ) - assert handler._root_run_resume_key_map[tool_end_run_id] == "tool-end-thread" + assert _get_root_resume_key(handler, tool_end_run_id) == "tool-end-thread" handler.on_tool_end( '{"approved": true}', run_id=tool_end_run_id, ) - assert tool_end_run_id not in handler._root_run_resume_key_map - assert tool_end_run_id 
not in handler._child_to_parent_run_id_map + assert _get_root_resume_key(handler, tool_end_run_id) is None + assert not _has_run_state(handler, tool_end_run_id) handler.on_retriever_start( {"name": "knowledge_base"}, @@ -758,16 +775,16 @@ class DummyControlFlowError(RuntimeError): run_id=retriever_run_id, metadata={"thread_id": "retriever-thread"}, ) - assert handler._root_run_resume_key_map[retriever_run_id] == "retriever-thread" + assert _get_root_resume_key(handler, retriever_run_id) == "retriever-thread" handler.on_retriever_error( DummyControlFlowError("retriever interrupt"), run_id=retriever_run_id, ) - assert "retriever-thread" in handler._resume_trace_context_by_key - assert retriever_run_id not in handler._root_run_resume_key_map - assert retriever_run_id not in handler._child_to_parent_run_id_map + assert _has_pending_resume_context(handler, "retriever-thread") + assert _get_root_resume_key(handler, retriever_run_id) is None + assert not _has_run_state(handler, retriever_run_id) def test_pending_resume_contexts_are_capped(langfuse_memory_client, monkeypatch): @@ -802,8 +819,8 @@ class DummyControlFlowError(RuntimeError): run_id=run_id, ) - assert len(handler._resume_trace_context_by_key) == 4 - assert list(handler._resume_trace_context_by_key) == [ + assert len(handler._pending_resume_trace_contexts) == 4 + assert _pending_resume_context_keys(handler) == [ "thread-1", "thread-2", "thread-3",