@@ -194,28 +194,41 @@ def _is_langgraph_resume(self, inputs: Any) -> bool:
194194 inputs , LANGGRAPH_COMMAND_TYPE
195195 )
196196
197- def _consume_root_trace_context (
197+ def _take_root_trace_context (
198198 self , * , inputs : Any , metadata : Optional [Dict [str , Any ]]
199- ) -> Optional [TraceContext ]:
199+ ) -> tuple [ Optional [str ], Optional [ TraceContext ] ]:
200200 if self ._trace_context is not None :
201- return self ._trace_context
201+ return None , self ._trace_context
202202
203203 current_span_context = trace .get_current_span ().get_span_context ()
204204
205205 # Only reuse the pending resume context when this callback run has no active
206206 # parent span of its own. Nested callbacks should attach normally.
207207 if current_span_context .is_valid :
208- return None
208+ return None , None
209209
210210 # Only explicit LangGraph resumes should consume pending trace linkage.
211211 if not self ._is_langgraph_resume (inputs ):
212- return None
212+ return None , None
213213
214214 resume_key = self ._get_langgraph_resume_key (metadata )
215215 if resume_key is None :
216- return None
216+ return None , None
217+
218+ return resume_key , self ._resume_trace_context_by_key .pop (resume_key , None )
219+
220+ def _restore_root_trace_context (
221+ self , * , resume_key : Optional [str ], trace_context : Optional [TraceContext ]
222+ ) -> None :
223+ if self ._trace_context is not None :
224+ return
217225
218- return self ._resume_trace_context_by_key .pop (resume_key , None )
226+ if resume_key is None or trace_context is None :
227+ return
228+
229+ # Span creation failed after we consumed the pending linkage, so put it
230+ # back and let the next retry resume the interrupted trace correctly.
231+ self ._resume_trace_context_by_key .setdefault (resume_key , trace_context )
219232
220233 def _clear_root_run_resume_key (self , run_id : UUID ) -> None :
221234 # Keep the pending interrupt context until an explicit Command(resume=...)
@@ -362,7 +375,7 @@ def on_retriever_error(
362375 langfuse_logger .exception (e )
363376 finally :
364377 if parent_run_id is None :
365- self ._reset ()
378+ self ._reset (run_id )
366379
367380 def _parse_langfuse_trace_attributes (
368381 self , * , metadata : Optional [Dict [str , Any ]], tags : Optional [List [str ]]
@@ -443,6 +456,10 @@ def on_chain_start(
443456 if parent_run_id is None :
444457 self ._set_root_run_resume_key (run_id , metadata )
445458
459+ span = None
460+ resume_key = None
461+ trace_context = None
462+
446463 try :
447464 self ._log_debug_event (
448465 "on_chain_start" , run_id , parent_run_id , inputs = inputs
@@ -481,10 +498,11 @@ def on_chain_start(
481498
482499 obs = self ._get_parent_observation (parent_run_id )
483500 if isinstance (obs , Langfuse ):
501+ resume_key , trace_context = self ._take_root_trace_context (
502+ inputs = inputs , metadata = metadata
503+ )
484504 span = obs .start_observation (
485- trace_context = self ._consume_root_trace_context (
486- inputs = inputs , metadata = metadata
487- ),
505+ trace_context = trace_context ,
488506 name = span_name ,
489507 as_type = observation_type ,
490508 metadata = span_metadata ,
@@ -511,6 +529,13 @@ def on_chain_start(
511529 self .last_trace_id = self ._runs [run_id ].trace_id
512530
513531 except Exception as e :
532+ if span is None :
533+ self ._restore_root_trace_context (
534+ resume_key = resume_key , trace_context = trace_context
535+ )
536+ if parent_run_id is None :
537+ self ._clear_root_run_resume_key (run_id )
538+ self ._exit_propagation_context ()
514539 langfuse_logger .exception (e )
515540
516541 def _register_langfuse_prompt (
@@ -701,7 +726,7 @@ def on_chain_end(
701726 finally :
702727 if parent_run_id is None :
703728 self ._exit_propagation_context ()
704- self ._reset ()
729+ self ._reset (run_id )
705730
706731 def on_chain_error (
707732 self ,
@@ -745,7 +770,7 @@ def on_chain_error(
745770 finally :
746771 if parent_run_id is None :
747772 self ._exit_propagation_context ()
748- self ._reset ()
773+ self ._reset (run_id )
749774
750775 def on_chat_model_start (
751776 self ,
@@ -759,6 +784,8 @@ def on_chat_model_start(
759784 ** kwargs : Any ,
760785 ) -> Any :
761786 self ._child_to_parent_run_id_map [run_id ] = parent_run_id
787+ if parent_run_id is None :
788+ self ._set_root_run_resume_key (run_id , metadata )
762789
763790 try :
764791 self ._log_debug_event (
@@ -824,6 +851,8 @@ def on_tool_start(
824851 ** kwargs : Any ,
825852 ) -> Any :
826853 self ._child_to_parent_run_id_map [run_id ] = parent_run_id
854+ if parent_run_id is None :
855+ self ._set_root_run_resume_key (run_id , metadata )
827856
828857 try :
829858 self ._log_debug_event (
@@ -883,6 +912,8 @@ def on_retriever_start(
883912 ** kwargs : Any ,
884913 ) -> Any :
885914 self ._child_to_parent_run_id_map [run_id ] = parent_run_id
915+ if parent_run_id is None :
916+ self ._set_root_run_resume_key (run_id , metadata )
886917
887918 try :
888919 self ._log_debug_event (
@@ -955,7 +986,7 @@ def on_retriever_end(
955986 langfuse_logger .exception (e )
956987 finally :
957988 if parent_run_id is None :
958- self ._reset ()
989+ self ._reset (run_id )
959990
960991 def on_tool_end (
961992 self ,
@@ -980,6 +1011,9 @@ def on_tool_end(
9801011
9811012 except Exception as e :
9821013 langfuse_logger .exception (e )
1014+ finally :
1015+ if parent_run_id is None :
1016+ self ._reset (run_id )
9831017
9841018 def on_tool_error (
9851019 self ,
@@ -1016,7 +1050,7 @@ def on_tool_error(
10161050 langfuse_logger .exception (e )
10171051 finally :
10181052 if parent_run_id is None :
1019- self ._reset ()
1053+ self ._reset (run_id )
10201054
10211055 def __on_llm_action (
10221056 self ,
@@ -1199,7 +1233,7 @@ def on_llm_end(
11991233
12001234 if parent_run_id is None :
12011235 self ._clear_root_run_resume_key (run_id )
1202- self ._reset ()
1236+ self ._reset (run_id )
12031237
12041238 def on_llm_error (
12051239 self ,
@@ -1237,11 +1271,32 @@ def on_llm_error(
12371271 langfuse_logger .exception (e )
12381272 finally :
12391273 if parent_run_id is None :
1240- self ._reset ()
1274+ self ._reset (run_id )
1275+
1276+ def _run_belongs_to_root (self , run_id : UUID , root_run_id : UUID ) -> bool :
1277+ current_run_id : Optional [UUID ] = run_id
1278+ visited : Set [UUID ] = set ()
1279+
1280+ while current_run_id is not None and current_run_id not in visited :
1281+ if current_run_id == root_run_id :
1282+ return True
1283+
1284+ visited .add (current_run_id )
1285+ current_run_id = self ._child_to_parent_run_id_map .get (current_run_id )
1286+
1287+ return False
1288+
1289+ def _reset (self , root_run_id : UUID ) -> None :
1290+ run_ids_to_clear = [
1291+ run_id
1292+ for run_id in self ._child_to_parent_run_id_map
1293+ if self ._run_belongs_to_root (run_id , root_run_id )
1294+ ]
1295+
1296+ for run_id in run_ids_to_clear :
1297+ self ._child_to_parent_run_id_map .pop (run_id , None )
12411298
1242- def _reset (self ) -> None :
1243- self ._child_to_parent_run_id_map = {}
1244- self ._root_run_resume_key_map = {}
1299+ self ._root_run_resume_key_map .pop (root_run_id , None )
12451300
12461301 def _exit_propagation_context (self ) -> None :
12471302 manager = self ._propagation_context_manager
0 commit comments