11from collections import OrderedDict
22from contextvars import Token
3+ from dataclasses import dataclass , field
34from typing import (
45 Any ,
56 Dict ,
103104 pass
104105
105106
@dataclass
class _RunState:
    """Per-run bookkeeping: a run's parent and the root run it is grouped under."""

    # Parent callback run id as reported by LangChain; None for a root run.
    parent_run_id: Optional[UUID]
    # Key of the root group this run was registered under.
    root_run_id: UUID
111+
112+
@dataclass
class _RootRunState:
    """State shared by every run that was registered under one root run."""

    # Ids of all runs tracked under this root group.
    run_ids: Set[UUID] = field(default_factory=set)
    # Resume key derived from the root run's metadata; reset to None once popped.
    resume_key: Optional[str] = None
    # Propagation context manager stored for the root chain run; cleared
    # again when the propagation context is exited.
    propagation_context_manager: Optional[_AgnosticContextManager] = None
118+
119+
class _PendingResumeTraceContextStore:
    """Bounded, insertion-ordered store of trace contexts keyed by resume key.

    Storing a key makes it the most recent entry. When the store grows past
    ``max_size`` the oldest entry is dropped.
    """

    def __init__(self, max_size: int) -> None:
        """Create an empty store that keeps at most ``max_size`` entries."""
        self._max_size = max_size
        self._contexts: OrderedDict[str, TraceContext] = OrderedDict()

    def store(self, *, resume_key: str, trace_context: TraceContext) -> None:
        """Save ``trace_context`` under ``resume_key`` as the newest entry."""
        pending = self._contexts

        # Remove any previous entry first so the key is re-inserted at the
        # most-recent end instead of keeping its old position.
        pending.pop(resume_key, None)
        pending[resume_key] = trace_context

        if len(pending) > self._max_size:
            # Evict the oldest entry.
            pending.popitem(last=False)

    def take(self, resume_key: str) -> Optional[TraceContext]:
        """Remove and return the context for ``resume_key``, or None if absent."""
        return self._contexts.pop(resume_key, None)

    def __contains__(self, resume_key: str) -> bool:
        """Return True when a context is pending for ``resume_key``."""
        return resume_key in self._contexts

    def __len__(self) -> int:
        """Return the number of pending contexts."""
        return len(self._contexts)

    def keys(self) -> List[str]:
        """Return the pending resume keys, oldest first."""
        return list(self._contexts)
143+
144+
106145class LangchainCallbackHandler (LangchainBaseCallbackHandler ):
107146 def __init__ (
108147 self ,
@@ -143,15 +182,12 @@ def __init__(
143182 self ._context_tokens : Dict [UUID , Token ] = {}
144183 self ._prompt_to_parent_run_map : Dict [UUID , Any ] = {}
145184 self ._updated_completion_start_time_memo : Set [UUID ] = set ()
146- self ._propagation_context_manager : Optional [_AgnosticContextManager ] = None
147185 self ._trace_context = trace_context
148- # LangGraph resumes as a fresh root callback run after interrupting, so we keep
149- # pending resume contexts keyed by thread/session instead of a single shared slot.
150- self ._resume_trace_context_by_key : OrderedDict [str , TraceContext ] = (
151- OrderedDict ()
186+ self ._pending_resume_trace_contexts = _PendingResumeTraceContextStore (
187+ MAX_PENDING_RESUME_TRACE_CONTEXTS
152188 )
153- self ._root_run_resume_key_map : Dict [UUID , str ] = {}
154- self ._child_to_parent_run_id_map : Dict [UUID , Optional [ UUID ] ] = {}
189+ self ._run_states : Dict [UUID , _RunState ] = {}
190+ self ._root_run_states : Dict [UUID , _RootRunState ] = {}
155191
156192 self .last_trace_id : Optional [str ] = None
157193
@@ -187,16 +223,62 @@ def _get_langgraph_resume_key(
187223
188224 return str (thread_id )
189225
190- def _set_root_run_resume_key (
191- self , run_id : UUID , metadata : Optional [Dict [str , Any ]]
def _track_run(
    self,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Register ``run_id`` and group it under its root run.

    A run that is already tracked is left untouched. A run without a parent
    becomes its own root and records the resume key derived from
    ``metadata``. Any other run joins the root group of its parent; when the
    parent itself is not tracked, the parent id is used as the group key.
    """
    if run_id in self._run_states:
        return

    if parent_run_id is not None:
        known_parent = self._run_states.get(parent_run_id)
        if known_parent is None:
            # Parent was never registered: group under the parent id itself.
            root_run_id = parent_run_id
        else:
            root_run_id = known_parent.root_run_id

        group = self._root_run_states.setdefault(root_run_id, _RootRunState())
        group.run_ids.add(run_id)
    else:
        root_run_id = run_id
        resume_key = self._get_langgraph_resume_key(metadata)
        self._root_run_states[run_id] = _RootRunState(
            run_ids={run_id},
            resume_key=resume_key,
        )

    self._run_states[run_id] = _RunState(
        parent_run_id=parent_run_id,
        root_run_id=root_run_id,
    )
256+
def _get_run_state(self, run_id: UUID) -> Optional[_RunState]:
    """Return the tracked state for ``run_id``, or None when it is unknown."""
    tracked_runs = self._run_states
    return tracked_runs.get(run_id)
194259
195- if resume_key is not None :
196- self ._root_run_resume_key_map [run_id ] = resume_key
def _get_root_run_state(self, run_id: UUID) -> Optional[_RootRunState]:
    """Return the root-group state that ``run_id`` belongs to.

    Returns None when the run is not tracked or its root group is gone.
    """
    tracked = self._get_run_state(run_id)
    return (
        self._root_run_states.get(tracked.root_run_id)
        if tracked is not None
        else None
    )
197267
def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]:
    """Return the resume key of ``run_id``'s root group and clear it.

    Returns None when the run has no root group or no key is set.
    """
    root_state = self._get_root_run_state(run_id)
    if root_state is None:
        return None

    # Read the key and reset it in one step so a second call yields None.
    popped, root_state.resume_key = root_state.resume_key, None
    return popped
278+
def _get_parent_run_id(self, run_id: UUID) -> Optional[UUID]:
    """Return the recorded parent of ``run_id``, or None if untracked or a root."""
    tracked = self._get_run_state(run_id)
    if tracked is None:
        return None
    return tracked.parent_run_id
200282
201283 def _is_langgraph_resume (self , inputs : Any ) -> bool :
202284 return (
@@ -208,11 +290,9 @@ def _is_langgraph_resume(self, inputs: Any) -> bool:
def _store_resume_trace_context(
    self, *, resume_key: str, trace_context: TraceContext
) -> None:
    """Hand ``trace_context`` to the pending-resume store under ``resume_key``."""
    pending = self._pending_resume_trace_contexts
    pending.store(resume_key=resume_key, trace_context=trace_context)
216296
217297 def _take_root_trace_context (
218298 self , * , inputs : Any , metadata : Optional [Dict [str , Any ]]
@@ -235,7 +315,7 @@ def _take_root_trace_context(
235315 if resume_key is None :
236316 return None , None
237317
238- return resume_key , self ._resume_trace_context_by_key . pop (resume_key , None )
318+ return resume_key , self ._pending_resume_trace_contexts . take (resume_key )
239319
240320 def _restore_root_trace_context (
241321 self , * , resume_key : Optional [str ], trace_context : Optional [TraceContext ]
@@ -477,9 +557,7 @@ def on_chain_start(
477557 metadata : Optional [Dict [str , Any ]] = None ,
478558 ** kwargs : Any ,
479559 ) -> Any :
480- self ._child_to_parent_run_id_map [run_id ] = parent_run_id
481- if parent_run_id is None :
482- self ._set_root_run_resume_key (run_id , metadata )
560+ self ._track_run (run_id = run_id , parent_run_id = parent_run_id , metadata = metadata )
483561
484562 span = None
485563 resume_key = None
@@ -511,15 +589,21 @@ def on_chain_start(
511589 metadata = metadata , tags = tags
512590 )
513591
514- self . _propagation_context_manager = propagate_attributes (
592+ propagation_context_manager = propagate_attributes (
515593 user_id = parsed_trace_attributes .get ("user_id" , None ),
516594 session_id = parsed_trace_attributes .get ("session_id" , None ),
517595 tags = parsed_trace_attributes .get ("tags" , None ),
518596 metadata = parsed_trace_attributes .get ("metadata" , None ),
519597 trace_name = parsed_trace_attributes .get ("trace_name" , None ),
520598 )
521599
522- self ._propagation_context_manager .__enter__ ()
600+ root_run_state = self ._get_root_run_state (run_id )
601+ if root_run_state is not None :
602+ root_run_state .propagation_context_manager = (
603+ propagation_context_manager
604+ )
605+
606+ propagation_context_manager .__enter__ ()
523607
524608 obs = self ._get_parent_observation (parent_run_id )
525609 if isinstance (obs , Langfuse ):
@@ -559,7 +643,7 @@ def on_chain_start(
559643 resume_key = resume_key , trace_context = trace_context
560644 )
561645 if parent_run_id is None :
562- self ._exit_propagation_context ()
646+ self ._exit_propagation_context (run_id )
563647 self ._reset (run_id )
564648 langfuse_logger .exception (e )
565649
@@ -665,7 +749,7 @@ def on_agent_action(
665749 ** kwargs : Any ,
666750 ) -> Any :
667751 """Run on agent action."""
668- self ._child_to_parent_run_id_map [ run_id ] = parent_run_id
752+ self ._track_run ( run_id = run_id , parent_run_id = parent_run_id )
669753
670754 try :
671755 self ._log_debug_event (
@@ -739,7 +823,7 @@ def on_chain_end(
739823
740824 if parent_run_id is None :
741825 self ._clear_root_run_resume_key (run_id )
742- self ._exit_propagation_context ()
826+ self ._exit_propagation_context (run_id )
743827
744828 span .end ()
745829
@@ -750,7 +834,7 @@ def on_chain_end(
750834
751835 finally :
752836 if parent_run_id is None :
753- self ._exit_propagation_context ()
837+ self ._exit_propagation_context (run_id )
754838 self ._reset (run_id )
755839
756840 def on_chain_error (
@@ -786,15 +870,15 @@ def on_chain_error(
786870 )
787871 else :
788872 self ._clear_root_run_resume_key (run_id )
789- self ._exit_propagation_context ()
873+ self ._exit_propagation_context (run_id )
790874
791875 observation .end ()
792876
793877 except Exception as e :
794878 langfuse_logger .exception (e )
795879 finally :
796880 if parent_run_id is None :
797- self ._exit_propagation_context ()
881+ self ._exit_propagation_context (run_id )
798882 self ._reset (run_id )
799883
800884 def on_chat_model_start (
@@ -808,9 +892,7 @@ def on_chat_model_start(
808892 metadata : Optional [Dict [str , Any ]] = None ,
809893 ** kwargs : Any ,
810894 ) -> Any :
811- self ._child_to_parent_run_id_map [run_id ] = parent_run_id
812- if parent_run_id is None :
813- self ._set_root_run_resume_key (run_id , metadata )
895+ self ._track_run (run_id = run_id , parent_run_id = parent_run_id , metadata = metadata )
814896
815897 try :
816898 self ._log_debug_event (
@@ -844,9 +926,7 @@ def on_llm_start(
844926 metadata : Optional [Dict [str , Any ]] = None ,
845927 ** kwargs : Any ,
846928 ) -> Any :
847- self ._child_to_parent_run_id_map [run_id ] = parent_run_id
848- if parent_run_id is None :
849- self ._set_root_run_resume_key (run_id , metadata )
929+ self ._track_run (run_id = run_id , parent_run_id = parent_run_id , metadata = metadata )
850930
851931 try :
852932 self ._log_debug_event (
@@ -875,9 +955,7 @@ def on_tool_start(
875955 metadata : Optional [Dict [str , Any ]] = None ,
876956 ** kwargs : Any ,
877957 ) -> Any :
878- self ._child_to_parent_run_id_map [run_id ] = parent_run_id
879- if parent_run_id is None :
880- self ._set_root_run_resume_key (run_id , metadata )
958+ self ._track_run (run_id = run_id , parent_run_id = parent_run_id , metadata = metadata )
881959
882960 try :
883961 self ._log_debug_event (
@@ -936,9 +1014,7 @@ def on_retriever_start(
9361014 metadata : Optional [Dict [str , Any ]] = None ,
9371015 ** kwargs : Any ,
9381016 ) -> Any :
939- self ._child_to_parent_run_id_map [run_id ] = parent_run_id
940- if parent_run_id is None :
941- self ._set_root_run_resume_key (run_id , metadata )
1017+ self ._track_run (run_id = run_id , parent_run_id = parent_run_id , metadata = metadata )
9421018
9431019 try :
9441020 self ._log_debug_event (
@@ -1087,9 +1163,7 @@ def __on_llm_action(
10871163 metadata : Optional [Dict [str , Any ]] = None ,
10881164 ** kwargs : Any ,
10891165 ) -> None :
1090- self ._child_to_parent_run_id_map [run_id ] = parent_run_id
1091- if parent_run_id is None :
1092- self ._set_root_run_resume_key (run_id , metadata )
1166+ self ._track_run (run_id = run_id , parent_run_id = parent_run_id , metadata = metadata )
10931167
10941168 try :
10951169 tools = kwargs .get ("invocation_params" , {}).get ("tools" , None )
@@ -1113,8 +1187,8 @@ def __on_llm_action(
11131187 self ._deregister_langfuse_prompt (current_parent_run_id )
11141188 break
11151189 else :
1116- current_parent_run_id = self ._child_to_parent_run_id_map . get (
1117- current_parent_run_id , None
1190+ current_parent_run_id = self ._get_parent_run_id (
1191+ current_parent_run_id
11181192 )
11191193
11201194 content = {
@@ -1298,38 +1372,30 @@ def on_llm_error(
12981372 if parent_run_id is None :
12991373 self ._reset (run_id )
13001374
1301- def _run_belongs_to_root (self , run_id : UUID , root_run_id : UUID ) -> bool :
1302- current_run_id : Optional [UUID ] = run_id
1303- visited : Set [UUID ] = set ()
1304-
1305- while current_run_id is not None and current_run_id not in visited :
1306- if current_run_id == root_run_id :
1307- return True
1308-
1309- visited .add (current_run_id )
1310- current_run_id = self ._child_to_parent_run_id_map .get (current_run_id )
1311-
1312- return False
1313-
def _reset(self, root_run_id: UUID) -> None:
    """Forget all tracking state for the run group of ``root_run_id``.

    Removes the group's entry from ``_root_run_states`` and the
    ``_run_states`` entry of every run registered in it. Unknown ids are a
    no-op.

    Args:
        root_run_id: Id of the run whose group should be released.
    """
    run_state = self._run_states.pop(root_run_id, None)

    # Children of a parent that was never tracked are grouped under that
    # parent's id, so fall back to the id itself when it has no run state.
    # Without this fallback such a group could never be released.
    group_key = run_state.root_run_id if run_state is not None else root_run_id

    root_run_state = self._root_run_states.pop(group_key, None)
    if root_run_state is None:
        return

    for run_id in root_run_state.run_ids:
        self._run_states.pop(run_id, None)
13231387
1324- self ._root_run_resume_key_map .pop (root_run_id , None )
def _exit_propagation_context(self, run_id: UUID) -> None:
    """Exit the propagation context stored on ``run_id``'s root group, if any.

    Does nothing when the run has no root group or no context manager is
    stored on it.
    """
    root_state = self._get_root_run_state(run_id)
    active = None if root_state is None else root_state.propagation_context_manager
    if active is None:
        return

    # Detach before exiting so a repeated call finds nothing left to exit.
    root_state.propagation_context_manager = None
    active.__exit__(None, None, None)
13341400
13351401 def __join_tags_and_metadata (
0 commit comments