@@ -135,6 +135,9 @@ def __init__(
135135 self ._updated_completion_start_time_memo : Set [UUID ] = set ()
136136 self ._propagation_context_manager : Optional [_AgnosticContextManager ] = None
137137 self ._trace_context = trace_context
138+ # LangGraph resumes as a fresh root callback run after interrupting, so we keep
139+ # just enough trace context to stitch the resume back onto the original trace.
140+ self ._resume_trace_context : Optional [TraceContext ] = None
138141 self ._child_to_parent_run_id_map : Dict [UUID , Optional [UUID ]] = {}
139142
140143 self .last_trace_id : Optional [str ] = None
@@ -161,6 +164,44 @@ def on_llm_new_token(
161164
162165 self ._updated_completion_start_time_memo .add (run_id )
163166
167+ def _consume_root_trace_context (self ) -> Optional [TraceContext ]:
168+ if self ._trace_context is not None :
169+ return self ._trace_context
170+
171+ current_span_context = trace .get_current_span ().get_span_context ()
172+
173+ # Only reuse the pending resume context when this callback run has no active
174+ # parent span of its own. Nested callbacks should attach normally.
175+ if current_span_context .is_valid :
176+ return None
177+
178+ trace_context = self ._resume_trace_context
179+ self ._resume_trace_context = None
180+
181+ return trace_context
182+
183+ def _clear_resume_trace_context (self ) -> None :
184+ self ._resume_trace_context = None
185+
186+ def _persist_resume_trace_context (self , observation : Any ) -> None :
187+ if self ._trace_context is not None :
188+ return
189+
190+ self ._resume_trace_context = {
191+ "trace_id" : observation .trace_id ,
192+ "parent_span_id" : observation .id ,
193+ }
194+
def _get_error_level_and_status_message(
    self, error: BaseException
) -> tuple[Literal["DEFAULT", "ERROR"], str]:
    """Map an exception to the observation level and status message to record.

    Args:
        error: The exception reported by the callback.

    Returns:
        A ``(level, status_message)`` pair. ``level`` is ``"DEFAULT"`` when
        ``error`` is an instance of one of ``CONTROL_FLOW_EXCEPTION_TYPES``
        and ``"ERROR"`` otherwise. ``status_message`` is ``str(error)``, or
        the exception class name when that string is empty.
    """
    # Exceptions raised without arguments stringify to "", which would leave
    # the observation without any hint of what happened. Fall back to the
    # class name for every level, not only for control-flow exceptions.
    status_message = str(error) or type(error).__name__

    # LangGraph uses GraphBubbleUp subclasses for expected control flow such as
    # interrupts and handoffs, so they should stay visible without being errors.
    if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
        return "DEFAULT", status_message

    return "ERROR", status_message
204+
164205 def _get_observation_type_from_serialized (
165206 self , serialized : Optional [Dict [str , Any ]], callback_type : str , ** kwargs : Any
166207 ) -> Union [
@@ -256,13 +297,22 @@ def on_retriever_error(
256297 observation = self ._detach_observation (run_id )
257298
258299 if observation is not None :
300+ level , status_message = self ._get_error_level_and_status_message (error )
259301 observation .update (
260- level = "ERROR" ,
261- status_message = str (error ),
302+ level = cast (
303+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
304+ level ,
305+ ),
306+ status_message = status_message ,
262307 input = kwargs .get ("inputs" ),
263308 cost_details = {"total" : 0 },
264309 ).end ()
265310
311+ if parent_run_id is None and level == "DEFAULT" :
312+ self ._persist_resume_trace_context (observation )
313+ elif parent_run_id is None :
314+ self ._clear_resume_trace_context ()
315+
266316 except Exception as e :
267317 langfuse_logger .exception (e )
268318
@@ -382,7 +432,7 @@ def on_chain_start(
382432 obs = self ._get_parent_observation (parent_run_id )
383433 if isinstance (obs , Langfuse ):
384434 span = obs .start_observation (
385- trace_context = self ._trace_context ,
435+ trace_context = self ._consume_root_trace_context () ,
386436 name = span_name ,
387437 as_type = observation_type ,
388438 metadata = span_metadata ,
@@ -586,6 +636,7 @@ def on_chain_end(
586636 )
587637
588638 if parent_run_id is None :
639+ self ._clear_resume_trace_context ()
589640 self ._exit_propagation_context ()
590641
591642 span .end ()
@@ -611,10 +662,7 @@ def on_chain_error(
611662 ) -> None :
612663 try :
613664 self ._log_debug_event ("on_chain_error" , run_id , parent_run_id , error = error )
614- if any (isinstance (error , t ) for t in CONTROL_FLOW_EXCEPTION_TYPES ):
615- level = None
616- else :
617- level = "ERROR"
665+ level , status_message = self ._get_error_level_and_status_message (error )
618666
619667 observation = self ._detach_observation (run_id )
620668
@@ -624,12 +672,16 @@ def on_chain_error(
624672 Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
625673 level ,
626674 ),
627- status_message = str ( error ) if level else None ,
675+ status_message = status_message ,
628676 input = kwargs .get ("inputs" ),
629677 cost_details = {"total" : 0 },
630678 )
631679
632680 if parent_run_id is None :
681+ if level == "DEFAULT" :
682+ self ._persist_resume_trace_context (observation )
683+ else :
684+ self ._clear_resume_trace_context ()
633685 self ._exit_propagation_context ()
634686
635687 observation .end ()
@@ -739,13 +791,24 @@ def on_tool_start(
739791 serialized , "tool" , ** kwargs
740792 )
741793
742- span = self ._get_parent_observation (parent_run_id ).start_observation (
743- name = self .get_langchain_run_name (serialized , ** kwargs ),
744- as_type = observation_type ,
745- input = input_str ,
746- metadata = meta ,
747- level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
748- )
794+ parent_observation = self ._get_parent_observation (parent_run_id )
795+ if isinstance (parent_observation , Langfuse ):
796+ span = parent_observation .start_observation (
797+ trace_context = self ._consume_root_trace_context (),
798+ name = self .get_langchain_run_name (serialized , ** kwargs ),
799+ as_type = observation_type ,
800+ input = input_str ,
801+ metadata = meta ,
802+ level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
803+ )
804+ else :
805+ span = parent_observation .start_observation (
806+ name = self .get_langchain_run_name (serialized , ** kwargs ),
807+ as_type = observation_type ,
808+ input = input_str ,
809+ metadata = meta ,
810+ level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
811+ )
749812
750813 self ._attach_observation (run_id , span )
751814
@@ -780,16 +843,30 @@ def on_retriever_start(
780843 observation_type = self ._get_observation_type_from_serialized (
781844 serialized , "retriever" , ** kwargs
782845 )
783- span = self ._get_parent_observation (parent_run_id ).start_observation (
784- name = span_name ,
785- as_type = observation_type ,
786- metadata = span_metadata ,
787- input = query ,
788- level = cast (
789- Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
790- span_level ,
791- ),
792- )
846+ parent_observation = self ._get_parent_observation (parent_run_id )
847+ if isinstance (parent_observation , Langfuse ):
848+ span = parent_observation .start_observation (
849+ trace_context = self ._consume_root_trace_context (),
850+ name = span_name ,
851+ as_type = observation_type ,
852+ metadata = span_metadata ,
853+ input = query ,
854+ level = cast (
855+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
856+ span_level ,
857+ ),
858+ )
859+ else :
860+ span = parent_observation .start_observation (
861+ name = span_name ,
862+ as_type = observation_type ,
863+ metadata = span_metadata ,
864+ input = query ,
865+ level = cast (
866+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
867+ span_level ,
868+ ),
869+ )
793870
794871 self ._attach_observation (run_id , span )
795872
@@ -811,6 +888,8 @@ def on_retriever_end(
811888 observation = self ._detach_observation (run_id )
812889
813890 if observation is not None :
891+ if parent_run_id is None :
892+ self ._clear_resume_trace_context ()
814893 observation .update (
815894 output = documents ,
816895 input = kwargs .get ("inputs" ),
@@ -833,6 +912,8 @@ def on_tool_end(
833912 observation = self ._detach_observation (run_id )
834913
835914 if observation is not None :
915+ if parent_run_id is None :
916+ self ._clear_resume_trace_context ()
836917 observation .update (
837918 output = output ,
838919 input = kwargs .get ("inputs" ),
@@ -854,13 +935,22 @@ def on_tool_error(
854935 observation = self ._detach_observation (run_id )
855936
856937 if observation is not None :
938+ level , status_message = self ._get_error_level_and_status_message (error )
857939 observation .update (
858- status_message = str (error ),
859- level = "ERROR" ,
940+ status_message = status_message ,
941+ level = cast (
942+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
943+ level ,
944+ ),
860945 input = kwargs .get ("inputs" ),
861946 cost_details = {"total" : 0 },
862947 ).end ()
863948
949+ if parent_run_id is None and level == "DEFAULT" :
950+ self ._persist_resume_trace_context (observation )
951+ elif parent_run_id is None :
952+ self ._clear_resume_trace_context ()
953+
864954 except Exception as e :
865955 langfuse_logger .exception (e )
866956
@@ -919,9 +1009,17 @@ def __on_llm_action(
9191009 "prompt" : registered_prompt ,
9201010 }
9211011
922- generation = self ._get_parent_observation (parent_run_id ).start_observation (
923- as_type = "generation" , ** content
924- ) # type: ignore
1012+ parent_observation = self ._get_parent_observation (parent_run_id )
1013+ if isinstance (parent_observation , Langfuse ):
1014+ generation = parent_observation .start_observation (
1015+ trace_context = self ._consume_root_trace_context (),
1016+ as_type = "generation" ,
1017+ ** content ,
1018+ ) # type: ignore
1019+ else :
1020+ generation = parent_observation .start_observation (
1021+ as_type = "generation" , ** content
1022+ ) # type: ignore
9251023 self ._attach_observation (run_id , generation )
9261024
9271025 self .last_trace_id = self ._runs [run_id ].trace_id
@@ -1034,6 +1132,7 @@ def on_llm_end(
10341132 self ._updated_completion_start_time_memo .discard (run_id )
10351133
10361134 if parent_run_id is None :
1135+ self ._clear_resume_trace_context ()
10371136 self ._reset ()
10381137
10391138 def on_llm_error (
@@ -1050,13 +1149,22 @@ def on_llm_error(
10501149 generation = self ._detach_observation (run_id )
10511150
10521151 if generation is not None :
1152+ level , status_message = self ._get_error_level_and_status_message (error )
10531153 generation .update (
1054- status_message = str (error ),
1055- level = "ERROR" ,
1154+ status_message = status_message ,
1155+ level = cast (
1156+ Optional [Literal ["DEBUG" , "DEFAULT" , "WARNING" , "ERROR" ]],
1157+ level ,
1158+ ),
10561159 input = kwargs .get ("inputs" ),
10571160 cost_details = {"total" : 0 },
10581161 ).end ()
10591162
1163+ if parent_run_id is None and level == "DEFAULT" :
1164+ self ._persist_resume_trace_context (generation )
1165+ elif parent_run_id is None :
1166+ self ._clear_resume_trace_context ()
1167+
10601168 except Exception as e :
10611169 langfuse_logger .exception (e )
10621170
0 commit comments