8484
8585LANGSMITH_TAG_HIDDEN : str = "langsmith:hidden"
8686CONTROL_FLOW_EXCEPTION_TYPES : Set [Type [BaseException ]] = set ()
87+ LANGGRAPH_COMMAND_TYPE : Optional [Type [Any ]] = None
8788
8889try :
8990 from langgraph .errors import GraphBubbleUp
91+ from langgraph .types import Command as LangGraphCommand
9092
9193 CONTROL_FLOW_EXCEPTION_TYPES .add (GraphBubbleUp )
94+ LANGGRAPH_COMMAND_TYPE = LangGraphCommand
9295except ImportError :
9396 pass
9497
@@ -136,8 +139,9 @@ def __init__(
136139 self ._propagation_context_manager : Optional [_AgnosticContextManager ] = None
137140 self ._trace_context = trace_context
138141 # LangGraph resumes as a fresh root callback run after interrupting, so we keep
139- # just enough trace context to stitch the resume back onto the original trace.
140- self ._resume_trace_context : Optional [TraceContext ] = None
142+ # pending resume contexts keyed by thread/session instead of a single shared slot.
143+ self ._resume_trace_context_by_key : Dict [str , TraceContext ] = {}
144+ self ._root_run_resume_key_map : Dict [UUID , str ] = {}
141145 self ._child_to_parent_run_id_map : Dict [UUID , Optional [UUID ]] = {}
142146
143147 self .last_trace_id : Optional [str ] = None
@@ -164,7 +168,35 @@ def on_llm_new_token(
164168
165169 self ._updated_completion_start_time_memo .add (run_id )
166170
167- def _consume_root_trace_context (self ) -> Optional [TraceContext ]:
171+ def _get_langgraph_resume_key (
172+ self , metadata : Optional [Dict [str , Any ]]
173+ ) -> Optional [str ]:
174+ thread_id = metadata .get ("thread_id" ) if metadata else None
175+
176+ if thread_id is None :
177+ return None
178+
179+ return str (thread_id )
180+
181+ def _set_root_run_resume_key (
182+ self , run_id : UUID , metadata : Optional [Dict [str , Any ]]
183+ ) -> None :
184+ resume_key = self ._get_langgraph_resume_key (metadata )
185+
186+ if resume_key is not None :
187+ self ._root_run_resume_key_map [run_id ] = resume_key
188+
189+ def _pop_root_run_resume_key (self , run_id : UUID ) -> Optional [str ]:
190+ return self ._root_run_resume_key_map .pop (run_id , None )
191+
192+ def _is_langgraph_resume (self , inputs : Any ) -> bool :
193+ return LANGGRAPH_COMMAND_TYPE is not None and isinstance (
194+ inputs , LANGGRAPH_COMMAND_TYPE
195+ )
196+
197+ def _consume_root_trace_context (
198+ self , * , inputs : Any , metadata : Optional [Dict [str , Any ]]
199+ ) -> Optional [TraceContext ]:
168200 if self ._trace_context is not None :
169201 return self ._trace_context
170202
@@ -175,19 +207,30 @@ def _consume_root_trace_context(self) -> Optional[TraceContext]:
175207 if current_span_context .is_valid :
176208 return None
177209
178- trace_context = self ._resume_trace_context
179- self ._resume_trace_context = None
210+ # Only explicit LangGraph resumes should consume pending trace linkage.
211+ if not self ._is_langgraph_resume (inputs ):
212+ return None
213+
214+ resume_key = self ._get_langgraph_resume_key (metadata )
215+ if resume_key is None :
216+ return None
180217
181- return trace_context
218+ return self . _resume_trace_context_by_key . pop ( resume_key , None )
182219
183- def _clear_resume_trace_context (self ) -> None :
184- self ._resume_trace_context = None
220+ def _clear_root_run_resume_key (self , run_id : UUID ) -> None :
221+ # Keep the pending interrupt context until an explicit Command(resume=...)
222+ # arrives. A separate root run on the same thread_id is not a resume.
223+ self ._pop_root_run_resume_key (run_id )
185224
186- def _persist_resume_trace_context (self , observation : Any ) -> None :
225+ def _persist_resume_trace_context (self , * , run_id : UUID , observation : Any ) -> None :
187226 if self ._trace_context is not None :
188227 return
189228
190- self ._resume_trace_context = {
229+ resume_key = self ._pop_root_run_resume_key (run_id )
230+ if resume_key is None :
231+ return
232+
233+ self ._resume_trace_context_by_key [resume_key ] = {
191234 "trace_id" : observation .trace_id ,
192235 "parent_span_id" : observation .id ,
193236 }
@@ -309,12 +352,17 @@ def on_retriever_error(
309352 ).end ()
310353
311354 if parent_run_id is None and level == "DEFAULT" :
312- self ._persist_resume_trace_context (observation )
355+ self ._persist_resume_trace_context (
356+ run_id = run_id , observation = observation
357+ )
313358 elif parent_run_id is None :
314- self ._clear_resume_trace_context ( )
359+ self ._clear_root_run_resume_key ( run_id )
315360
316361 except Exception as e :
317362 langfuse_logger .exception (e )
363+ finally :
364+ if parent_run_id is None :
365+ self ._reset ()
318366
319367 def _parse_langfuse_trace_attributes (
320368 self , * , metadata : Optional [Dict [str , Any ]], tags : Optional [List [str ]]
@@ -383,7 +431,7 @@ def _get_langchain_observation_metadata(
383431 def on_chain_start (
384432 self ,
385433 serialized : Optional [Dict [str , Any ]],
386- inputs : Dict [ str , Any ] ,
434+ inputs : Any ,
387435 * ,
388436 run_id : UUID ,
389437 parent_run_id : Optional [UUID ] = None ,
@@ -392,6 +440,8 @@ def on_chain_start(
392440 ** kwargs : Any ,
393441 ) -> Any :
394442 self ._child_to_parent_run_id_map [run_id ] = parent_run_id
443+ if parent_run_id is None :
444+ self ._set_root_run_resume_key (run_id , metadata )
395445
396446 try :
397447 self ._log_debug_event (
@@ -432,7 +482,9 @@ def on_chain_start(
432482 obs = self ._get_parent_observation (parent_run_id )
433483 if isinstance (obs , Langfuse ):
434484 span = obs .start_observation (
435- trace_context = self ._consume_root_trace_context (),
485+ trace_context = self ._consume_root_trace_context (
486+ inputs = inputs , metadata = metadata
487+ ),
436488 name = span_name ,
437489 as_type = observation_type ,
438490 metadata = span_metadata ,
@@ -636,7 +688,7 @@ def on_chain_end(
636688 )
637689
638690 if parent_run_id is None :
639- self ._clear_resume_trace_context ( )
691+ self ._clear_root_run_resume_key ( run_id )
640692 self ._exit_propagation_context ()
641693
642694 span .end ()
@@ -679,9 +731,11 @@ def on_chain_error(
679731
680732 if parent_run_id is None :
681733 if level == "DEFAULT" :
682- self ._persist_resume_trace_context (observation )
734+ self ._persist_resume_trace_context (
735+ run_id = run_id , observation = observation
736+ )
683737 else :
684- self ._clear_resume_trace_context ( )
738+ self ._clear_root_run_resume_key ( run_id )
685739 self ._exit_propagation_context ()
686740
687741 observation .end ()
@@ -739,6 +793,8 @@ def on_llm_start(
739793 ** kwargs : Any ,
740794 ) -> Any :
741795 self ._child_to_parent_run_id_map [run_id ] = parent_run_id
796+ if parent_run_id is None :
797+ self ._set_root_run_resume_key (run_id , metadata )
742798
743799 try :
744800 self ._log_debug_event (
@@ -794,7 +850,7 @@ def on_tool_start(
794850 parent_observation = self ._get_parent_observation (parent_run_id )
795851 if isinstance (parent_observation , Langfuse ):
796852 span = parent_observation .start_observation (
797- trace_context = self ._consume_root_trace_context () ,
853+ trace_context = self ._trace_context ,
798854 name = self .get_langchain_run_name (serialized , ** kwargs ),
799855 as_type = observation_type ,
800856 input = input_str ,
@@ -846,7 +902,7 @@ def on_retriever_start(
846902 parent_observation = self ._get_parent_observation (parent_run_id )
847903 if isinstance (parent_observation , Langfuse ):
848904 span = parent_observation .start_observation (
849- trace_context = self ._consume_root_trace_context () ,
905+ trace_context = self ._trace_context ,
850906 name = span_name ,
851907 as_type = observation_type ,
852908 metadata = span_metadata ,
@@ -889,14 +945,17 @@ def on_retriever_end(
889945
890946 if observation is not None :
891947 if parent_run_id is None :
892- self ._clear_resume_trace_context ( )
948+ self ._clear_root_run_resume_key ( run_id )
893949 observation .update (
894950 output = documents ,
895951 input = kwargs .get ("inputs" ),
896952 ).end ()
897953
898954 except Exception as e :
899955 langfuse_logger .exception (e )
956+ finally :
957+ if parent_run_id is None :
958+ self ._reset ()
900959
901960 def on_tool_end (
902961 self ,
@@ -913,7 +972,7 @@ def on_tool_end(
913972
914973 if observation is not None :
915974 if parent_run_id is None :
916- self ._clear_resume_trace_context ( )
975+ self ._clear_root_run_resume_key ( run_id )
917976 observation .update (
918977 output = output ,
919978 input = kwargs .get ("inputs" ),
@@ -947,12 +1006,17 @@ def on_tool_error(
9471006 ).end ()
9481007
9491008 if parent_run_id is None and level == "DEFAULT" :
950- self ._persist_resume_trace_context (observation )
1009+ self ._persist_resume_trace_context (
1010+ run_id = run_id , observation = observation
1011+ )
9511012 elif parent_run_id is None :
952- self ._clear_resume_trace_context ( )
1013+ self ._clear_root_run_resume_key ( run_id )
9531014
9541015 except Exception as e :
9551016 langfuse_logger .exception (e )
1017+ finally :
1018+ if parent_run_id is None :
1019+ self ._reset ()
9561020
9571021 def __on_llm_action (
9581022 self ,
@@ -965,6 +1029,8 @@ def __on_llm_action(
9651029 ** kwargs : Any ,
9661030 ) -> None :
9671031 self ._child_to_parent_run_id_map [run_id ] = parent_run_id
1032+ if parent_run_id is None :
1033+ self ._set_root_run_resume_key (run_id , metadata )
9681034
9691035 try :
9701036 tools = kwargs .get ("invocation_params" , {}).get ("tools" , None )
@@ -1012,7 +1078,7 @@ def __on_llm_action(
10121078 parent_observation = self ._get_parent_observation (parent_run_id )
10131079 if isinstance (parent_observation , Langfuse ):
10141080 generation = parent_observation .start_observation (
1015- trace_context = self ._consume_root_trace_context () ,
1081+ trace_context = self ._trace_context ,
10161082 as_type = "generation" ,
10171083 ** content ,
10181084 ) # type: ignore
@@ -1132,7 +1198,7 @@ def on_llm_end(
11321198 self ._updated_completion_start_time_memo .discard (run_id )
11331199
11341200 if parent_run_id is None :
1135- self ._clear_resume_trace_context ( )
1201+ self ._clear_root_run_resume_key ( run_id )
11361202 self ._reset ()
11371203
11381204 def on_llm_error (
@@ -1161,15 +1227,21 @@ def on_llm_error(
11611227 ).end ()
11621228
11631229 if parent_run_id is None and level == "DEFAULT" :
1164- self ._persist_resume_trace_context (generation )
1230+ self ._persist_resume_trace_context (
1231+ run_id = run_id , observation = generation
1232+ )
11651233 elif parent_run_id is None :
1166- self ._clear_resume_trace_context ( )
1234+ self ._clear_root_run_resume_key ( run_id )
11671235
11681236 except Exception as e :
11691237 langfuse_logger .exception (e )
1238+ finally :
1239+ if parent_run_id is None :
1240+ self ._reset ()
11701241
11711242 def _reset (self ) -> None :
11721243 self ._child_to_parent_run_id_map = {}
1244+ self ._root_run_resume_key_map = {}
11731245
11741246 def _exit_propagation_context (self ) -> None :
11751247 manager = self ._propagation_context_manager
0 commit comments