@@ -135,6 +135,9 @@ def __init__(
135135 self ._updated_completion_start_time_memo : Set [UUID ] = set ()
136136 self ._propagation_context_manager : Optional [_AgnosticContextManager ] = None
137137 self ._trace_context = trace_context
138+ # LangGraph resumes as a fresh root callback run after interrupting, so we keep
139+ # just enough trace context to stitch the resume back onto the original trace.
140+ self ._resume_trace_context : Optional [TraceContext ] = None
138141 self ._child_to_parent_run_id_map : Dict [UUID , Optional [UUID ]] = {}
139142
140143 self .last_trace_id : Optional [str ] = None
@@ -161,6 +164,44 @@ def on_llm_new_token(
161164
162165 self ._updated_completion_start_time_memo .add (run_id )
163166
def _consume_root_trace_context(self) -> Optional[TraceContext]:
    """Return the trace context a new root observation should start under.

    An explicitly configured ``self._trace_context`` always wins and is
    returned unchanged (it is never consumed). Otherwise the pending resume
    context is handed out at most once: it is returned and cleared, but only
    when the current span context is not valid. When a valid span is active,
    ``None`` is returned and the pending resume context is left untouched.

    Returns:
        The configured trace context, the pending resume context, or
        ``None`` when neither applies.
    """
    if self._trace_context is not None:
        return self._trace_context

    # Only reuse the pending resume context when this callback run has no
    # active parent span of its own. Nested callbacks should attach normally.
    if trace.get_current_span().get_span_context().is_valid:
        return None

    # Hand the stored context out once and reset it, so a later root run
    # cannot pick up the same context again.
    trace_context = self._resume_trace_context
    self._resume_trace_context = None

    return trace_context
182+
def _clear_resume_trace_context(self) -> None:
    """Drop any pending resume trace context by resetting it to ``None``."""
    self._resume_trace_context = None
185+
def _persist_resume_trace_context(self, observation: Any) -> None:
    """Remember ``observation`` as the attachment point for a later root run.

    Stores the observation's ``trace_id`` and its ``id`` (under the key
    ``parent_span_id``) in ``self._resume_trace_context``. Nothing is stored
    when an explicit ``self._trace_context`` is configured.

    Args:
        observation: Object exposing ``trace_id`` and ``id`` attributes.
    """
    if self._trace_context is not None:
        return

    self._resume_trace_context = {
        "trace_id": observation.trace_id,
        "parent_span_id": observation.id,
    }
194+
def _get_error_level_and_status_message(
    self, error: BaseException
) -> tuple[Literal["DEFAULT", "ERROR"], str]:
    """Map an exception to the observation level and status message to record.

    Args:
        error: The exception reported by the callback.

    Returns:
        ``("DEFAULT", message)`` when ``error`` is an instance of one of
        ``CONTROL_FLOW_EXCEPTION_TYPES``, otherwise ``("ERROR", message)``.
        ``message`` is ``str(error)``, or the exception class name when that
        string is empty.
    """
    # Exceptions raised without arguments stringify to "", which would leave
    # the observation without any status message; fall back to the class name
    # so both branches always report something meaningful.
    status_message = str(error) or type(error).__name__

    # LangGraph uses GraphBubbleUp subclasses for expected control flow such as
    # interrupts and handoffs, so they should stay visible without being errors.
    if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
        return "DEFAULT", status_message

    return "ERROR", status_message
204+
164205 def _get_observation_type_from_serialized (
165206 self , serialized : Optional [Dict [str , Any ]], callback_type : str , ** kwargs : Any
166207 ) -> Union [
@@ -256,13 +297,22 @@ def on_retriever_error(
256297 observation = self ._detach_observation (run_id )
257298
258299 if observation is not None :
300+ level , status_message = self ._get_error_level_and_status_message (error )
259301 observation .update (
260- level = "ERROR" ,
261- status_message = str (error ),
302+ level = cast (
303+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
304+ level ,
305+ ),
306+ status_message = status_message ,
262307 input = kwargs .get ("inputs" ),
263308 cost_details = {"total" : 0 },
264309 ).end ()
265310
311+ if parent_run_id is None and level == "DEFAULT" :
312+ self ._persist_resume_trace_context (observation )
313+ elif parent_run_id is None :
314+ self ._clear_resume_trace_context ()
315+
266316 except Exception as e :
267317 langfuse_logger .exception (e )
268318
@@ -376,7 +426,7 @@ def on_chain_start(
376426 obs = self ._get_parent_observation (parent_run_id )
377427 if isinstance (obs , Langfuse ):
378428 span = obs .start_observation (
379- trace_context = self ._trace_context ,
429+ trace_context = self ._consume_root_trace_context () ,
380430 name = span_name ,
381431 as_type = observation_type ,
382432 metadata = span_metadata ,
@@ -580,6 +630,7 @@ def on_chain_end(
580630 )
581631
582632 if parent_run_id is None :
633+ self ._clear_resume_trace_context ()
583634 self ._exit_propagation_context ()
584635
585636 span .end ()
@@ -605,10 +656,7 @@ def on_chain_error(
605656 ) -> None :
606657 try :
607658 self ._log_debug_event ("on_chain_error" , run_id , parent_run_id , error = error )
608- if any (isinstance (error , t ) for t in CONTROL_FLOW_EXCEPTION_TYPES ):
609- level = None
610- else :
611- level = "ERROR"
659+ level , status_message = self ._get_error_level_and_status_message (error )
612660
613661 observation = self ._detach_observation (run_id )
614662
@@ -618,12 +666,16 @@ def on_chain_error(
618666 Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
619667 level ,
620668 ),
621- status_message = str ( error ) if level else None ,
669+ status_message = status_message ,
622670 input = kwargs .get ("inputs" ),
623671 cost_details = {"total" : 0 },
624672 )
625673
626674 if parent_run_id is None :
675+ if level == "DEFAULT" :
676+ self ._persist_resume_trace_context (observation )
677+ else :
678+ self ._clear_resume_trace_context ()
627679 self ._exit_propagation_context ()
628680
629681 observation .end ()
@@ -733,13 +785,24 @@ def on_tool_start(
733785 serialized , "tool" , ** kwargs
734786 )
735787
736- span = self ._get_parent_observation (parent_run_id ).start_observation (
737- name = self .get_langchain_run_name (serialized , ** kwargs ),
738- as_type = observation_type ,
739- input = input_str ,
740- metadata = meta ,
741- level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
742- )
788+ parent_observation = self ._get_parent_observation (parent_run_id )
789+ if isinstance (parent_observation , Langfuse ):
790+ span = parent_observation .start_observation (
791+ trace_context = self ._consume_root_trace_context (),
792+ name = self .get_langchain_run_name (serialized , ** kwargs ),
793+ as_type = observation_type ,
794+ input = input_str ,
795+ metadata = meta ,
796+ level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
797+ )
798+ else :
799+ span = parent_observation .start_observation (
800+ name = self .get_langchain_run_name (serialized , ** kwargs ),
801+ as_type = observation_type ,
802+ input = input_str ,
803+ metadata = meta ,
804+ level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
805+ )
743806
744807 self ._attach_observation (run_id , span )
745808
@@ -774,16 +837,30 @@ def on_retriever_start(
774837 observation_type = self ._get_observation_type_from_serialized (
775838 serialized , "retriever" , ** kwargs
776839 )
777- span = self ._get_parent_observation (parent_run_id ).start_observation (
778- name = span_name ,
779- as_type = observation_type ,
780- metadata = span_metadata ,
781- input = query ,
782- level = cast (
783- Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
784- span_level ,
785- ),
786- )
840+ parent_observation = self ._get_parent_observation (parent_run_id )
841+ if isinstance (parent_observation , Langfuse ):
842+ span = parent_observation .start_observation (
843+ trace_context = self ._consume_root_trace_context (),
844+ name = span_name ,
845+ as_type = observation_type ,
846+ metadata = span_metadata ,
847+ input = query ,
848+ level = cast (
849+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
850+ span_level ,
851+ ),
852+ )
853+ else :
854+ span = parent_observation .start_observation (
855+ name = span_name ,
856+ as_type = observation_type ,
857+ metadata = span_metadata ,
858+ input = query ,
859+ level = cast (
860+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
861+ span_level ,
862+ ),
863+ )
787864
788865 self ._attach_observation (run_id , span )
789866
@@ -805,6 +882,8 @@ def on_retriever_end(
805882 observation = self ._detach_observation (run_id )
806883
807884 if observation is not None :
885+ if parent_run_id is None :
886+ self ._clear_resume_trace_context ()
808887 observation .update (
809888 output = documents ,
810889 input = kwargs .get ("inputs" ),
@@ -827,6 +906,8 @@ def on_tool_end(
827906 observation = self ._detach_observation (run_id )
828907
829908 if observation is not None :
909+ if parent_run_id is None :
910+ self ._clear_resume_trace_context ()
830911 observation .update (
831912 output = output ,
832913 input = kwargs .get ("inputs" ),
@@ -848,13 +929,22 @@ def on_tool_error(
848929 observation = self ._detach_observation (run_id )
849930
850931 if observation is not None :
932+ level , status_message = self ._get_error_level_and_status_message (error )
851933 observation .update (
852- status_message = str (error ),
853- level = "ERROR" ,
934+ status_message = status_message ,
935+ level = cast (
936+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
937+ level ,
938+ ),
854939 input = kwargs .get ("inputs" ),
855940 cost_details = {"total" : 0 },
856941 ).end ()
857942
943+ if parent_run_id is None and level == "DEFAULT" :
944+ self ._persist_resume_trace_context (observation )
945+ elif parent_run_id is None :
946+ self ._clear_resume_trace_context ()
947+
858948 except Exception as e :
859949 langfuse_logger .exception (e )
860950
@@ -913,9 +1003,17 @@ def __on_llm_action(
9131003 "prompt" : registered_prompt ,
9141004 }
9151005
916- generation = self ._get_parent_observation (parent_run_id ).start_observation (
917- as_type = "generation" , ** content
918- ) # type: ignore
1006+ parent_observation = self ._get_parent_observation (parent_run_id )
1007+ if isinstance (parent_observation , Langfuse ):
1008+ generation = parent_observation .start_observation (
1009+ trace_context = self ._consume_root_trace_context (),
1010+ as_type = "generation" ,
1011+ ** content ,
1012+ ) # type: ignore
1013+ else :
1014+ generation = parent_observation .start_observation (
1015+ as_type = "generation" , ** content
1016+ ) # type: ignore
9191017 self ._attach_observation (run_id , generation )
9201018
9211019 self .last_trace_id = self ._runs [run_id ].trace_id
@@ -1028,6 +1126,7 @@ def on_llm_end(
10281126 self ._updated_completion_start_time_memo .discard (run_id )
10291127
10301128 if parent_run_id is None :
1129+ self ._clear_resume_trace_context ()
10311130 self ._reset ()
10321131
10331132 def on_llm_error (
@@ -1044,13 +1143,22 @@ def on_llm_error(
10441143 generation = self ._detach_observation (run_id )
10451144
10461145 if generation is not None :
1146+ level , status_message = self ._get_error_level_and_status_message (error )
10471147 generation .update (
1048- status_message = str (error ),
1049- level = "ERROR" ,
1148+ status_message = status_message ,
1149+ level = cast (
1150+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
1151+ level ,
1152+ ),
10501153 input = kwargs .get ("inputs" ),
10511154 cost_details = {"total" : 0 },
10521155 ).end ()
10531156
1157+ if parent_run_id is None and level == "DEFAULT" :
1158+ self ._persist_resume_trace_context (generation )
1159+ elif parent_run_id is None :
1160+ self ._clear_resume_trace_context ()
1161+
10541162 except Exception as e :
10551163 langfuse_logger .exception (e )
10561164
0 commit comments