diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 80b7114e5..1349f6ae0 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -1,4 +1,6 @@ +from collections import OrderedDict from contextvars import Token +from dataclasses import dataclass, field from typing import ( Any, Dict, @@ -84,6 +86,8 @@ LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden" CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set() +LANGGRAPH_COMMAND_TYPE: Optional[Type[Any]] = None +MAX_PENDING_RESUME_TRACE_CONTEXTS = 1024 try: from langgraph.errors import GraphBubbleUp @@ -92,6 +96,51 @@ except ImportError: pass +try: + from langgraph.types import Command as LangGraphCommand + + LANGGRAPH_COMMAND_TYPE = LangGraphCommand +except ImportError: + pass + + +@dataclass +class _RunState: + parent_run_id: Optional[UUID] + root_run_id: UUID + + +@dataclass +class _RootRunState: + run_ids: Set[UUID] = field(default_factory=set) + resume_key: Optional[str] = None + propagation_context_manager: Optional[_AgnosticContextManager] = None + + +class _PendingResumeTraceContextStore: + def __init__(self, max_size: int) -> None: + self._max_size = max_size + self._contexts: OrderedDict[str, TraceContext] = OrderedDict() + + def store(self, *, resume_key: str, trace_context: TraceContext) -> None: + self._contexts[resume_key] = trace_context + self._contexts.move_to_end(resume_key) + + if len(self._contexts) > self._max_size: + self._contexts.popitem(last=False) + + def take(self, resume_key: str) -> Optional[TraceContext]: + return self._contexts.pop(resume_key, None) + + def __contains__(self, resume_key: str) -> bool: + return resume_key in self._contexts + + def __len__(self) -> int: + return len(self._contexts) + + def keys(self) -> List[str]: + return list(self._contexts.keys()) + class LangchainCallbackHandler(LangchainBaseCallbackHandler): def __init__( @@ -133,9 +182,12 @@ def __init__( self._context_tokens: Dict[UUID, Token] = {} self._prompt_to_parent_run_map: Dict[UUID, Any] = {} self._updated_completion_start_time_memo: Set[UUID] = set() - self._propagation_context_manager: Optional[_AgnosticContextManager] = None self._trace_context = trace_context - self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {} + self._pending_resume_trace_contexts = _PendingResumeTraceContextStore( + MAX_PENDING_RESUME_TRACE_CONTEXTS + ) + self._run_states: Dict[UUID, _RunState] = {} + self._root_run_states: Dict[UUID, _RootRunState] = {} self.last_trace_id: Optional[str] = None @@ -161,6 +213,156 @@ def on_llm_new_token( self._updated_completion_start_time_memo.add(run_id) + def _get_langgraph_resume_key( + self, metadata: Optional[Dict[str, Any]] + ) -> Optional[str]: + thread_id = metadata.get("thread_id") if metadata else None + + if thread_id is None: + return None + + return str(thread_id) + + def _track_run( + self, + *, + run_id: UUID, + parent_run_id: Optional[UUID], + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + if run_id in self._run_states: + return + + if parent_run_id is None: + root_run_id = run_id + self._root_run_states[root_run_id] = _RootRunState( + run_ids={run_id}, + resume_key=self._get_langgraph_resume_key(metadata), + ) + else: + parent_state = self._run_states.get(parent_run_id) + root_run_id = ( + parent_state.root_run_id if parent_state is not None else parent_run_id + ) + root_run_state = self._root_run_states.setdefault( + root_run_id, _RootRunState() + ) + root_run_state.run_ids.add(run_id) + + self._run_states[run_id] = _RunState( + parent_run_id=parent_run_id, + root_run_id=root_run_id, + ) + + def _get_run_state(self, run_id: UUID) -> Optional[_RunState]: + return self._run_states.get(run_id) + + def _get_root_run_state(self, run_id: UUID) -> Optional[_RootRunState]: + run_state = self._get_run_state(run_id) + + if run_state is None: + return None + + return self._root_run_states.get(run_state.root_run_id) + + def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]: + root_run_state = self._get_root_run_state(run_id) + + if root_run_state is None: + return None + + resume_key = root_run_state.resume_key + root_run_state.resume_key = None + + return resume_key + + def _get_parent_run_id(self, run_id: UUID) -> Optional[UUID]: + run_state = self._get_run_state(run_id) + return run_state.parent_run_id if run_state is not None else None + + def _is_langgraph_resume(self, inputs: Any) -> bool: + return ( + LANGGRAPH_COMMAND_TYPE is not None + and isinstance(inputs, LANGGRAPH_COMMAND_TYPE) + and getattr(inputs, "resume", None) is not None + ) + + def _store_resume_trace_context( + self, *, resume_key: str, trace_context: TraceContext + ) -> None: + self._pending_resume_trace_contexts.store( + resume_key=resume_key, trace_context=trace_context + ) + + def _take_root_trace_context( + self, *, inputs: Any, metadata: Optional[Dict[str, Any]] + ) -> tuple[Optional[str], Optional[TraceContext]]: + if self._trace_context is not None: + return None, self._trace_context + + current_span_context = trace.get_current_span().get_span_context() + + # Only reuse the pending resume context when this callback run has no active + # parent span of its own. Nested callbacks should attach normally. + if current_span_context.is_valid: + return None, None + + # Only explicit LangGraph resumes should consume pending trace linkage. + if not self._is_langgraph_resume(inputs): + return None, None + + resume_key = self._get_langgraph_resume_key(metadata) + if resume_key is None: + return None, None + + return resume_key, self._pending_resume_trace_contexts.take(resume_key) + + def _restore_root_trace_context( + self, *, resume_key: Optional[str], trace_context: Optional[TraceContext] + ) -> None: + if self._trace_context is not None: + return + + if resume_key is None or trace_context is None: + return + + # Span creation failed after we consumed the pending linkage, so put it + # back and let the next retry resume the interrupted trace correctly. + self._store_resume_trace_context( + resume_key=resume_key, trace_context=trace_context + ) + + def _clear_root_run_resume_key(self, run_id: UUID) -> None: + # Keep the pending interrupt context until an explicit Command(resume=...) + # arrives. A separate root run on the same thread_id is not a resume. + self._pop_root_run_resume_key(run_id) + + def _persist_resume_trace_context(self, *, run_id: UUID, observation: Any) -> None: + if self._trace_context is not None: + return + + resume_key = self._pop_root_run_resume_key(run_id) + if resume_key is None: + return + + self._store_resume_trace_context( + resume_key=resume_key, + trace_context={ + "trace_id": observation.trace_id, + "parent_span_id": observation.id, + }, + ) + + def _get_error_level_and_status_message( + self, error: BaseException + ) -> tuple[Literal["DEFAULT", "ERROR"], str]: + # LangGraph uses GraphBubbleUp subclasses for expected control flow such as + # interrupts and handoffs, so they should stay visible without being errors. + if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): + return "DEFAULT", str(error) or type(error).__name__ + + return "ERROR", str(error) + def _get_observation_type_from_serialized( self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any ) -> Union[ @@ -256,15 +458,29 @@ def on_retriever_error( observation = self._detach_observation(run_id) if observation is not None: + level, status_message = self._get_error_level_and_status_message(error) observation.update( - level="ERROR", - status_message=str(error), + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), + status_message=status_message, input=kwargs.get("inputs"), cost_details={"total": 0}, ).end() + if parent_run_id is None and level == "DEFAULT": + self._persist_resume_trace_context( + run_id=run_id, observation=observation + ) + elif parent_run_id is None: + self._clear_root_run_resume_key(run_id) + except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset(run_id) def _parse_langfuse_trace_attributes( self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]] @@ -333,7 +549,7 @@ def _get_langchain_observation_metadata( def on_chain_start( self, serialized: Optional[Dict[str, Any]], - inputs: Dict[str, Any], + inputs: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -341,7 +557,11 @@ def on_chain_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) + + span = None + resume_key = None + trace_context = None try: self._log_debug_event( @@ -369,7 +589,7 @@ def on_chain_start( metadata=metadata, tags=tags ) - self._propagation_context_manager = propagate_attributes( + propagation_context_manager = propagate_attributes( user_id=parsed_trace_attributes.get("user_id", None), session_id=parsed_trace_attributes.get("session_id", None), tags=parsed_trace_attributes.get("tags", None), @@ -377,12 +597,21 @@ def on_chain_start( trace_name=parsed_trace_attributes.get("trace_name", None), ) - self._propagation_context_manager.__enter__() + root_run_state = self._get_root_run_state(run_id) + if root_run_state is not None: + root_run_state.propagation_context_manager = ( + propagation_context_manager + ) + + propagation_context_manager.__enter__() obs = self._get_parent_observation(parent_run_id) if isinstance(obs, Langfuse): + resume_key, trace_context = self._take_root_trace_context( + inputs=inputs, metadata=metadata + ) span = obs.start_observation( - trace_context=self._trace_context, + trace_context=trace_context, name=span_name, as_type=observation_type, metadata=span_metadata, @@ -409,6 +638,13 @@ def on_chain_start( self.last_trace_id = self._runs[run_id].trace_id except Exception as e: + if span is None: + self._restore_root_trace_context( + resume_key=resume_key, trace_context=trace_context + ) + if parent_run_id is None: + self._exit_propagation_context(run_id) + self._reset(run_id) langfuse_logger.exception(e) def _register_langfuse_prompt( @@ -513,7 +749,7 @@ def on_agent_action( **kwargs: Any, ) -> Any: """Run on agent action.""" - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id) try: self._log_debug_event( @@ -586,7 +822,8 @@ def on_chain_end( ) if parent_run_id is None: - self._exit_propagation_context() + self._clear_root_run_resume_key(run_id) + self._exit_propagation_context(run_id) span.end() @@ -597,8 +834,8 @@ def on_chain_end( finally: if parent_run_id is None: - self._exit_propagation_context() - self._reset() + self._exit_propagation_context(run_id) + self._reset(run_id) def on_chain_error( self, @@ -611,10 +848,7 @@ def on_chain_error( ) -> None: try: self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) - if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): - level = None - else: - level = "ERROR" + level, status_message = self._get_error_level_and_status_message(error) observation = self._detach_observation(run_id) @@ -624,13 +858,19 @@ def on_chain_error( Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], level, ), - status_message=str(error) if level else None, + status_message=status_message, input=kwargs.get("inputs"), cost_details={"total": 0}, ) if parent_run_id is None: - self._exit_propagation_context() + if level == "DEFAULT": + self._persist_resume_trace_context( + run_id=run_id, observation=observation + ) + else: + self._clear_root_run_resume_key(run_id) + self._exit_propagation_context(run_id) observation.end() @@ -638,8 +878,8 @@ def on_chain_error( langfuse_logger.exception(e) finally: if parent_run_id is None: - self._exit_propagation_context() - self._reset() + self._exit_propagation_context(run_id) + self._reset(run_id) def on_chat_model_start( self, @@ -652,7 +892,7 @@ def on_chat_model_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -686,7 +926,7 @@ def on_llm_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -715,7 +955,7 @@ def on_tool_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -739,13 +979,24 @@ def on_tool_start( serialized, "tool", **kwargs ) - span = self._get_parent_observation(parent_run_id).start_observation( - name=self.get_langchain_run_name(serialized, **kwargs), - as_type=observation_type, - input=input_str, - metadata=meta, - level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, - ) + parent_observation = self._get_parent_observation(parent_run_id) + if isinstance(parent_observation, Langfuse): + span = parent_observation.start_observation( + trace_context=self._trace_context, + name=self.get_langchain_run_name(serialized, **kwargs), + as_type=observation_type, + input=input_str, + metadata=meta, + level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, + ) + else: + span = parent_observation.start_observation( + name=self.get_langchain_run_name(serialized, **kwargs), + as_type=observation_type, + input=input_str, + metadata=meta, + level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, + ) self._attach_observation(run_id, span) @@ -763,7 +1014,7 @@ def on_retriever_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: self._log_debug_event( @@ -780,16 +1031,30 @@ def on_retriever_start( observation_type = self._get_observation_type_from_serialized( serialized, "retriever", **kwargs ) - span = self._get_parent_observation(parent_run_id).start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=query, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) + parent_observation = self._get_parent_observation(parent_run_id) + if isinstance(parent_observation, Langfuse): + span = parent_observation.start_observation( + trace_context=self._trace_context, + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=query, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) + else: + span = parent_observation.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=query, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) self._attach_observation(run_id, span) @@ -811,6 +1076,8 @@ def on_retriever_end( observation = self._detach_observation(run_id) if observation is not None: + if parent_run_id is None: + self._clear_root_run_resume_key(run_id) observation.update( output=documents, input=kwargs.get("inputs"), @@ -818,6 +1085,9 @@ def on_retriever_end( except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset(run_id) def on_tool_end( self, @@ -833,6 +1103,8 @@ def on_tool_end( observation = self._detach_observation(run_id) if observation is not None: + if parent_run_id is None: + self._clear_root_run_resume_key(run_id) observation.update( output=output, input=kwargs.get("inputs"), @@ -840,6 +1112,9 @@ def on_tool_end( except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset(run_id) def on_tool_error( self, @@ -854,15 +1129,29 @@ def on_tool_error( observation = self._detach_observation(run_id) if observation is not None: + level, status_message = self._get_error_level_and_status_message(error) observation.update( - status_message=str(error), - level="ERROR", + status_message=status_message, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), input=kwargs.get("inputs"), cost_details={"total": 0}, ).end() + if parent_run_id is None and level == "DEFAULT": + self._persist_resume_trace_context( + run_id=run_id, observation=observation + ) + elif parent_run_id is None: + self._clear_root_run_resume_key(run_id) + except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset(run_id) def __on_llm_action( self, @@ -874,7 +1163,7 @@ def __on_llm_action( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - self._child_to_parent_run_id_map[run_id] = parent_run_id + self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata) try: tools = kwargs.get("invocation_params", {}).get("tools", None) @@ -898,8 +1187,8 @@ def __on_llm_action( self._deregister_langfuse_prompt(current_parent_run_id) break else: - current_parent_run_id = self._child_to_parent_run_id_map.get( - current_parent_run_id, None + current_parent_run_id = self._get_parent_run_id( + current_parent_run_id ) content = { @@ -919,9 +1208,17 @@ def __on_llm_action( "prompt": registered_prompt, } - generation = self._get_parent_observation(parent_run_id).start_observation( - as_type="generation", **content - ) # type: ignore + parent_observation = self._get_parent_observation(parent_run_id) + if isinstance(parent_observation, Langfuse): + generation = parent_observation.start_observation( + trace_context=self._trace_context, + as_type="generation", + **content, + ) # type: ignore + else: + generation = parent_observation.start_observation( + as_type="generation", **content + ) # type: ignore self._attach_observation(run_id, generation) self.last_trace_id = self._runs[run_id].trace_id @@ -1034,7 +1331,8 @@ def on_llm_end( self._updated_completion_start_time_memo.discard(run_id) if parent_run_id is None: - self._reset() + self._clear_root_run_resume_key(run_id) + self._reset(run_id) def on_llm_error( self, @@ -1050,26 +1348,54 @@ def on_llm_error( generation = self._detach_observation(run_id) if generation is not None: + level, status_message = self._get_error_level_and_status_message(error) generation.update( - status_message=str(error), - level="ERROR", + status_message=status_message, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), input=kwargs.get("inputs"), cost_details={"total": 0}, ).end() + if parent_run_id is None and level == "DEFAULT": + self._persist_resume_trace_context( + run_id=run_id, observation=generation + ) + elif parent_run_id is None: + self._clear_root_run_resume_key(run_id) + except Exception as e: langfuse_logger.exception(e) + finally: + if parent_run_id is None: + self._reset(run_id) - def _reset(self) -> None: - self._child_to_parent_run_id_map = {} + def _reset(self, root_run_id: UUID) -> None: + run_state = self._get_run_state(root_run_id) + if run_state is None: + return - def _exit_propagation_context(self) -> None: - manager = self._propagation_context_manager + root_run_state = self._root_run_states.pop(run_state.root_run_id, None) + if root_run_state is None: + self._run_states.pop(root_run_id, None) + return + + for run_id in root_run_state.run_ids: + self._run_states.pop(run_id, None) + + def _exit_propagation_context(self, run_id: UUID) -> None: + root_run_state = self._get_root_run_state(run_id) + + if root_run_state is None: + return + manager = root_run_state.propagation_context_manager if manager is None: return - self._propagation_context_manager = None + root_run_state.propagation_context_manager = None manager.__exit__(None, None, None) def __join_tags_and_metadata( diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index 5d8406e9c..27298342c 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -1,3 +1,4 @@ +import importlib from contextvars import copy_context from unittest.mock import patch from uuid import uuid4 @@ -9,16 +10,36 @@ from langchain_core.outputs import ChatGeneration, ChatResult, Generation, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI, OpenAI +from opentelemetry import context as otel_context from langfuse._client.attributes import LangfuseOtelSpanAttributes from langfuse.langchain import CallbackHandler +callback_handler_module = importlib.import_module("langfuse.langchain.CallbackHandler") + def _assert_parent_child(parent_span, child_span) -> None: assert child_span.parent is not None assert child_span.parent.span_id == parent_span.context.span_id +def _has_pending_resume_context(handler, resume_key: str) -> bool: + return resume_key in handler._pending_resume_trace_contexts + + +def _pending_resume_context_keys(handler) -> list[str]: + return handler._pending_resume_trace_contexts.keys() + + +def _get_root_resume_key(handler, root_run_id): + root_run_state = handler._root_run_states.get(root_run_id) + return None if root_run_state is None else root_run_state.resume_key + + +def _has_run_state(handler, run_id) -> bool: + return run_id in handler._run_states + + def test_chat_model_callback_exports_generation_span( langfuse_memory_client, get_span, json_attr ): @@ -249,3 +270,581 @@ def test_root_chain_exports_when_end_runs_in_copied_context( assert root_span.attributes[LangfuseOtelSpanAttributes.TRACE_NAME] == ( "async-root-trace" ) + + +def test_control_flow_errors_use_default_level_and_keep_status_message( + langfuse_memory_client, get_span, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + handler = CallbackHandler() + + tool_run_id = uuid4() + retriever_run_id = uuid4() + llm_run_id = uuid4() + chain_run_id = uuid4() + + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_run_id, + ) + handler.on_tool_error( + DummyControlFlowError("tool interrupt"), + run_id=tool_run_id, + ) + + handler.on_retriever_start( + {"name": "knowledge_base"}, + "approval policy", + run_id=retriever_run_id, + ) + handler.on_retriever_error( + DummyControlFlowError("retriever bubble-up"), + run_id=retriever_run_id, + ) + + handler.on_llm_start( + {"name": "TestLLM"}, + ["need approval"], + run_id=llm_run_id, + invocation_params={}, + ) + handler.on_llm_error( + DummyControlFlowError("llm bubble-up"), + run_id=llm_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=chain_run_id, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=chain_run_id, + ) + + handler._langfuse_client.flush() + + for span_name, message in [ + ("human_approval", "tool interrupt"), + ("knowledge_base", "retriever bubble-up"), + ("TestLLM", "llm bubble-up"), + ("LangGraph", "graph interrupt"), + ]: + span = get_span(span_name) + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] == "DEFAULT" + ) + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] + == message + ) + + +def test_control_flow_resume_uses_thread_keyed_explicit_resume_context( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + + thread_one_interrupt_run_id = uuid4() + thread_two_interrupt_run_id = uuid4() + thread_one_fresh_run_id = uuid4() + thread_two_resume_run_id = uuid4() + thread_one_resume_run_id = uuid4() + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=thread_one_interrupt_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt 1"), + run_id=thread_one_interrupt_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=thread_two_interrupt_run_id, + metadata={"thread_id": "thread-2"}, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt 2"), + run_id=thread_two_interrupt_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["fresh invocation"]}, + run_id=thread_one_fresh_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["completed"]}, + run_id=thread_one_fresh_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=thread_two_resume_run_id, + metadata={"thread_id": "thread-2"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=thread_two_resume_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=thread_one_resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=thread_one_resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 5 + spans_by_trace_id = {} + for span in root_spans: + spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) + + assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2, 2] + + resumed_trace_spans = [ + spans for spans in spans_by_trace_id.values() if len(spans) == 2 + ] + assert len(resumed_trace_spans) == 2 + + for spans in resumed_trace_spans: + initial_span = next(span for span in spans if span.parent is None) + resumed_span = next(span for span in spans if span.parent is not None) + assert resumed_span.parent.span_id == initial_span.context.span_id + + fresh_trace_spans = next( + spans for spans in spans_by_trace_id.values() if len(spans) == 1 + ) + assert fresh_trace_spans[0].parent is None + finally: + otel_context.detach(context_token) + + +def test_control_flow_resume_restores_context_after_failed_root_start( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + + interrupt_run_id = uuid4() + failed_resume_run_id = uuid4() + successful_resume_run_id = uuid4() + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=interrupt_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=interrupt_run_id, + ) + + assert _has_pending_resume_context(handler, "thread-1") + + with patch.object( + handler._langfuse_client, + "start_observation", + side_effect=RuntimeError("trace create failed"), + ): + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=failed_resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + + assert _has_pending_resume_context(handler, "thread-1") + assert _get_root_resume_key(handler, failed_resume_run_id) is None + assert not _has_run_state(handler, failed_resume_run_id) + assert failed_resume_run_id not in handler._root_run_states + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=successful_resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=successful_resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 2 + + initial_span = next(span for span in root_spans if span.parent is None) + resumed_span = next(span for span in root_spans if span.parent is not None) + + assert resumed_span.parent.span_id == initial_span.context.span_id + finally: + otel_context.detach(context_token) + + +def test_control_flow_resume_ignores_non_resume_commands( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + + interrupt_run_id = uuid4() + goto_run_id = uuid4() + resume_run_id = uuid4() + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=interrupt_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_error( + DummyControlFlowError("graph interrupt"), + run_id=interrupt_run_id, + ) + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(goto="approval_node"), + run_id=goto_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["routed"]}, + run_id=goto_run_id, + ) + + assert _has_pending_resume_context(handler, "thread-1") + + handler.on_chain_start( + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=resume_run_id, + metadata={"thread_id": "thread-1"}, + ) + handler.on_chain_end( + {"messages": ["approved"]}, + run_id=resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 3 + + spans_by_trace_id = {} + for span in root_spans: + spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) + + assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2] + + resumed_trace_spans = next( + spans for spans in spans_by_trace_id.values() if len(spans) == 2 + ) + initial_span = next(span for span in resumed_trace_spans if span.parent is None) + resumed_span = next( + span for span in resumed_trace_spans if span.parent is not None + ) + + assert resumed_span.parent.span_id == initial_span.context.span_id + finally: + otel_context.detach(context_token) + + +def test_root_reset_preserves_other_inflight_resume_keys( + memory_exporter, langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + Command = pytest.importorskip("langgraph.types").Command + + context_token = otel_context.attach(otel_context.Context()) + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + try: + handler = CallbackHandler() + root_one_context = copy_context() + root_two_context = copy_context() + + root_one_run_id = uuid4() + root_two_run_id = uuid4() + root_two_resume_run_id = uuid4() + + root_one_context.run( + handler.on_chain_start, + {"name": "LangGraph"}, + {"messages": ["completed"]}, + run_id=root_one_run_id, + metadata={"thread_id": "thread-1"}, + ) + root_two_context.run( + handler.on_chain_start, + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=root_two_run_id, + metadata={"thread_id": "thread-2"}, + ) + + assert _get_root_resume_key(handler, root_two_run_id) == "thread-2" + + root_one_context.run( + handler.on_chain_end, + {"messages": ["completed"]}, + run_id=root_one_run_id, + ) + + assert _get_root_resume_key(handler, root_two_run_id) == "thread-2" + + root_two_context.run( + handler.on_chain_error, + DummyControlFlowError("graph interrupt"), + run_id=root_two_run_id, + ) + + assert _has_pending_resume_context(handler, "thread-2") + + root_two_context.run( + handler.on_chain_start, + {"name": "LangGraph"}, + Command(resume={"approved": True}), + run_id=root_two_resume_run_id, + metadata={"thread_id": "thread-2"}, + ) + root_two_context.run( + handler.on_chain_end, + {"messages": ["approved"]}, + run_id=root_two_resume_run_id, + ) + + handler._langfuse_client.flush() + + root_spans = [ + span + for span in memory_exporter.get_finished_spans() + if span.name == "LangGraph" + ] + + assert len(root_spans) == 3 + + spans_by_trace_id = {} + for span in root_spans: + spans_by_trace_id.setdefault(span.context.trace_id, []).append(span) + + assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2] + finally: + otel_context.detach(context_token) + + +def test_root_tool_and_retriever_runs_seed_resume_keys_and_cleanup( + langfuse_memory_client, monkeypatch +): + class DummyControlFlowError(RuntimeError): + pass + + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + + handler = CallbackHandler() + + tool_error_run_id = uuid4() + tool_end_run_id = uuid4() + retriever_run_id = uuid4() + + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_error_run_id, + metadata={"thread_id": "tool-error-thread"}, + ) + assert _get_root_resume_key(handler, tool_error_run_id) == "tool-error-thread" + + handler.on_tool_error( + DummyControlFlowError("tool interrupt"), + run_id=tool_error_run_id, + ) + + assert _has_pending_resume_context(handler, "tool-error-thread") + assert _get_root_resume_key(handler, tool_error_run_id) is None + assert not _has_run_state(handler, tool_error_run_id) + + handler.on_tool_start( + {"name": "human_approval"}, + "{}", + run_id=tool_end_run_id, + metadata={"thread_id": "tool-end-thread"}, + ) + assert _get_root_resume_key(handler, tool_end_run_id) == "tool-end-thread" + + handler.on_tool_end( + '{"approved": true}', + run_id=tool_end_run_id, + ) + + assert _get_root_resume_key(handler, tool_end_run_id) is None + assert not _has_run_state(handler, tool_end_run_id) + + handler.on_retriever_start( + {"name": "knowledge_base"}, + "approval policy", + run_id=retriever_run_id, + metadata={"thread_id": "retriever-thread"}, + ) + assert _get_root_resume_key(handler, retriever_run_id) == "retriever-thread" + + handler.on_retriever_error( + DummyControlFlowError("retriever interrupt"), + run_id=retriever_run_id, + ) + + assert _has_pending_resume_context(handler, "retriever-thread") + assert _get_root_resume_key(handler, retriever_run_id) is None + assert not _has_run_state(handler, retriever_run_id) + + +def test_pending_resume_contexts_are_capped(langfuse_memory_client, monkeypatch): + class DummyControlFlowError(RuntimeError): + pass + + monkeypatch.setattr( + callback_handler_module, + "CONTROL_FLOW_EXCEPTION_TYPES", + {DummyControlFlowError}, + ) + monkeypatch.setattr( + callback_handler_module, + "MAX_PENDING_RESUME_TRACE_CONTEXTS", + 4, + ) + + handler = CallbackHandler() + + for index in range(5): + run_id = uuid4() + thread_id = f"thread-{index}" + + handler.on_chain_start( + {"name": "LangGraph"}, + {"messages": ["need approval"]}, + run_id=run_id, + metadata={"thread_id": thread_id}, + ) + handler.on_chain_error( + DummyControlFlowError(f"graph interrupt {index}"), + run_id=run_id, + ) + + assert len(handler._pending_resume_trace_contexts) == 4 + assert _pending_resume_context_keys(handler) == [ + "thread-1", + "thread-2", + "thread-3", + "thread-4", + ] + + +def test_graphbubbleup_import_is_independent_from_command_import(): + real_import = __import__ + + def import_without_langgraph_command( + name, globals=None, locals=None, fromlist=(), level=0 + ): + if name == "langgraph.types": + raise ImportError("Command unavailable") + + return real_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=import_without_langgraph_command): + reloaded_module = importlib.reload(callback_handler_module) + assert reloaded_module.LANGGRAPH_COMMAND_TYPE is None + assert any( + exception_type.__name__ == "GraphBubbleUp" + for exception_type in reloaded_module.CONTROL_FLOW_EXCEPTION_TYPES + ) + + importlib.reload(callback_handler_module)