Skip to content

Commit 9414be8

Browse files
committed
fix(langchain): harden langgraph resume detection
1 parent 12bc17b commit 9414be8

File tree

2 files changed

+185
-10
lines changed

2 files changed

+185
-10
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import OrderedDict
12
from contextvars import Token
23
from typing import (
34
Any,
@@ -85,12 +86,18 @@
8586
LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden"
8687
CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set()
8788
LANGGRAPH_COMMAND_TYPE: Optional[Type[Any]] = None
89+
MAX_PENDING_RESUME_TRACE_CONTEXTS = 1024
8890

8991
try:
9092
from langgraph.errors import GraphBubbleUp
91-
from langgraph.types import Command as LangGraphCommand
9293

9394
CONTROL_FLOW_EXCEPTION_TYPES.add(GraphBubbleUp)
95+
except ImportError:
96+
pass
97+
98+
try:
99+
from langgraph.types import Command as LangGraphCommand
100+
94101
LANGGRAPH_COMMAND_TYPE = LangGraphCommand
95102
except ImportError:
96103
pass
@@ -140,7 +147,9 @@ def __init__(
140147
self._trace_context = trace_context
141148
# LangGraph resumes as a fresh root callback run after interrupting, so we keep
142149
# pending resume contexts keyed by thread/session instead of a single shared slot.
143-
self._resume_trace_context_by_key: Dict[str, TraceContext] = {}
150+
self._resume_trace_context_by_key: OrderedDict[str, TraceContext] = (
151+
OrderedDict()
152+
)
144153
self._root_run_resume_key_map: Dict[UUID, str] = {}
145154
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
146155

@@ -190,10 +199,21 @@ def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]:
190199
return self._root_run_resume_key_map.pop(run_id, None)
191200

192201
def _is_langgraph_resume(self, inputs: Any) -> bool:
193-
return LANGGRAPH_COMMAND_TYPE is not None and isinstance(
194-
inputs, LANGGRAPH_COMMAND_TYPE
202+
return (
203+
LANGGRAPH_COMMAND_TYPE is not None
204+
and isinstance(inputs, LANGGRAPH_COMMAND_TYPE)
205+
and getattr(inputs, "resume", None) is not None
195206
)
196207

208+
def _store_resume_trace_context(
209+
self, *, resume_key: str, trace_context: TraceContext
210+
) -> None:
211+
self._resume_trace_context_by_key[resume_key] = trace_context
212+
self._resume_trace_context_by_key.move_to_end(resume_key)
213+
214+
if len(self._resume_trace_context_by_key) > MAX_PENDING_RESUME_TRACE_CONTEXTS:
215+
self._resume_trace_context_by_key.popitem(last=False)
216+
197217
def _take_root_trace_context(
198218
self, *, inputs: Any, metadata: Optional[Dict[str, Any]]
199219
) -> tuple[Optional[str], Optional[TraceContext]]:
@@ -228,7 +248,9 @@ def _restore_root_trace_context(
228248

229249
# Span creation failed after we consumed the pending linkage, so put it
230250
# back and let the next retry resume the interrupted trace correctly.
231-
self._resume_trace_context_by_key.setdefault(resume_key, trace_context)
251+
self._store_resume_trace_context(
252+
resume_key=resume_key, trace_context=trace_context
253+
)
232254

233255
def _clear_root_run_resume_key(self, run_id: UUID) -> None:
234256
# Keep the pending interrupt context until an explicit Command(resume=...)
@@ -243,10 +265,13 @@ def _persist_resume_trace_context(self, *, run_id: UUID, observation: Any) -> No
243265
if resume_key is None:
244266
return
245267

246-
self._resume_trace_context_by_key[resume_key] = {
247-
"trace_id": observation.trace_id,
248-
"parent_span_id": observation.id,
249-
}
268+
self._store_resume_trace_context(
269+
resume_key=resume_key,
270+
trace_context={
271+
"trace_id": observation.trace_id,
272+
"parent_span_id": observation.id,
273+
},
274+
)
250275

251276
def _get_error_level_and_status_message(
252277
self, error: BaseException
@@ -534,8 +559,8 @@ def on_chain_start(
534559
resume_key=resume_key, trace_context=trace_context
535560
)
536561
if parent_run_id is None:
537-
self._clear_root_run_resume_key(run_id)
538562
self._exit_propagation_context()
563+
self._reset(run_id)
539564
langfuse_logger.exception(e)
540565

541566
def _register_langfuse_prompt(

tests/unit/test_langchain.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ class DummyControlFlowError(RuntimeError):
494494

495495
assert "thread-1" in handler._resume_trace_context_by_key
496496
assert failed_resume_run_id not in handler._root_run_resume_key_map
497+
assert failed_resume_run_id not in handler._child_to_parent_run_id_map
497498
assert handler._propagation_context_manager is None
498499

499500
handler.on_chain_start(
@@ -525,6 +526,92 @@ class DummyControlFlowError(RuntimeError):
525526
otel_context.detach(context_token)
526527

527528

529+
def test_control_flow_resume_ignores_non_resume_commands(
530+
memory_exporter, langfuse_memory_client, monkeypatch
531+
):
532+
class DummyControlFlowError(RuntimeError):
533+
pass
534+
535+
Command = pytest.importorskip("langgraph.types").Command
536+
537+
context_token = otel_context.attach(otel_context.Context())
538+
monkeypatch.setattr(
539+
callback_handler_module,
540+
"CONTROL_FLOW_EXCEPTION_TYPES",
541+
{DummyControlFlowError},
542+
)
543+
544+
try:
545+
handler = CallbackHandler()
546+
547+
interrupt_run_id = uuid4()
548+
goto_run_id = uuid4()
549+
resume_run_id = uuid4()
550+
551+
handler.on_chain_start(
552+
{"name": "LangGraph"},
553+
{"messages": ["need approval"]},
554+
run_id=interrupt_run_id,
555+
metadata={"thread_id": "thread-1"},
556+
)
557+
handler.on_chain_error(
558+
DummyControlFlowError("graph interrupt"),
559+
run_id=interrupt_run_id,
560+
)
561+
562+
handler.on_chain_start(
563+
{"name": "LangGraph"},
564+
Command(goto="approval_node"),
565+
run_id=goto_run_id,
566+
metadata={"thread_id": "thread-1"},
567+
)
568+
handler.on_chain_end(
569+
{"messages": ["routed"]},
570+
run_id=goto_run_id,
571+
)
572+
573+
assert "thread-1" in handler._resume_trace_context_by_key
574+
575+
handler.on_chain_start(
576+
{"name": "LangGraph"},
577+
Command(resume={"approved": True}),
578+
run_id=resume_run_id,
579+
metadata={"thread_id": "thread-1"},
580+
)
581+
handler.on_chain_end(
582+
{"messages": ["approved"]},
583+
run_id=resume_run_id,
584+
)
585+
586+
handler._langfuse_client.flush()
587+
588+
root_spans = [
589+
span
590+
for span in memory_exporter.get_finished_spans()
591+
if span.name == "LangGraph"
592+
]
593+
594+
assert len(root_spans) == 3
595+
596+
spans_by_trace_id = {}
597+
for span in root_spans:
598+
spans_by_trace_id.setdefault(span.context.trace_id, []).append(span)
599+
600+
assert sorted(len(spans) for spans in spans_by_trace_id.values()) == [1, 2]
601+
602+
resumed_trace_spans = next(
603+
spans for spans in spans_by_trace_id.values() if len(spans) == 2
604+
)
605+
initial_span = next(span for span in resumed_trace_spans if span.parent is None)
606+
resumed_span = next(
607+
span for span in resumed_trace_spans if span.parent is not None
608+
)
609+
610+
assert resumed_span.parent.span_id == initial_span.context.span_id
611+
finally:
612+
otel_context.detach(context_token)
613+
614+
528615
def test_root_reset_preserves_other_inflight_resume_keys(
529616
memory_exporter, langfuse_memory_client, monkeypatch
530617
):
@@ -681,3 +768,66 @@ class DummyControlFlowError(RuntimeError):
681768
assert "retriever-thread" in handler._resume_trace_context_by_key
682769
assert retriever_run_id not in handler._root_run_resume_key_map
683770
assert retriever_run_id not in handler._child_to_parent_run_id_map
771+
772+
773+
def test_pending_resume_contexts_are_capped(langfuse_memory_client, monkeypatch):
774+
class DummyControlFlowError(RuntimeError):
775+
pass
776+
777+
monkeypatch.setattr(
778+
callback_handler_module,
779+
"CONTROL_FLOW_EXCEPTION_TYPES",
780+
{DummyControlFlowError},
781+
)
782+
monkeypatch.setattr(
783+
callback_handler_module,
784+
"MAX_PENDING_RESUME_TRACE_CONTEXTS",
785+
4,
786+
)
787+
788+
handler = CallbackHandler()
789+
790+
for index in range(5):
791+
run_id = uuid4()
792+
thread_id = f"thread-{index}"
793+
794+
handler.on_chain_start(
795+
{"name": "LangGraph"},
796+
{"messages": ["need approval"]},
797+
run_id=run_id,
798+
metadata={"thread_id": thread_id},
799+
)
800+
handler.on_chain_error(
801+
DummyControlFlowError(f"graph interrupt {index}"),
802+
run_id=run_id,
803+
)
804+
805+
assert len(handler._resume_trace_context_by_key) == 4
806+
assert list(handler._resume_trace_context_by_key) == [
807+
"thread-1",
808+
"thread-2",
809+
"thread-3",
810+
"thread-4",
811+
]
812+
813+
814+
def test_graphbubbleup_import_is_independent_from_command_import():
815+
real_import = __import__
816+
817+
def import_without_langgraph_command(
818+
name, globals=None, locals=None, fromlist=(), level=0
819+
):
820+
if name == "langgraph.types":
821+
raise ImportError("Command unavailable")
822+
823+
return real_import(name, globals, locals, fromlist, level)
824+
825+
with patch("builtins.__import__", side_effect=import_without_langgraph_command):
826+
reloaded_module = importlib.reload(callback_handler_module)
827+
assert reloaded_module.LANGGRAPH_COMMAND_TYPE is None
828+
assert any(
829+
exception_type.__name__ == "GraphBubbleUp"
830+
for exception_type in reloaded_module.CONTROL_FLOW_EXCEPTION_TYPES
831+
)
832+
833+
importlib.reload(callback_handler_module)

0 commit comments

Comments (0)