Skip to content

Commit c700431

Browse files
authored
fix(langchain): allow prompt linking with langchain v1 create_agent (#1481)
* fix(langchain): allow prompt linking with langchain v1 create_agent * push
1 parent 32dddd5 commit c700431

1 file changed

Lines changed: 46 additions & 10 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
LangfuseSpan,
2929
LangfuseTool,
3030
)
31-
from langfuse.types import TraceContext
3231
from langfuse._utils import _get_timestamp
3332
from langfuse.langchain.utils import _extract_model_name
3433
from langfuse.logger import langfuse_logger
34+
from langfuse.types import TraceContext
3535

3636
try:
3737
import langchain
@@ -132,6 +132,7 @@ def __init__(
132132
LangfuseRetriever,
133133
],
134134
] = {}
135+
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
135136
self.context_tokens: Dict[UUID, Token] = {}
136137
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
137138
self.updated_completion_start_time_memo: Set[UUID] = set()
@@ -302,6 +303,8 @@ def on_chain_start(
302303
metadata: Optional[Dict[str, Any]] = None,
303304
**kwargs: Any,
304305
) -> Any:
306+
self._child_to_parent_run_id_map[run_id] = parent_run_id
307+
305308
try:
306309
self._log_debug_event(
307310
"on_chain_start", run_id, parent_run_id, inputs=inputs
@@ -480,6 +483,8 @@ def on_agent_action(
480483
**kwargs: Any,
481484
) -> Any:
482485
"""Run on agent action."""
486+
self._child_to_parent_run_id_map[run_id] = parent_run_id
487+
483488
try:
484489
self._log_debug_event(
485490
"on_agent_action", run_id, parent_run_id, action=action
@@ -560,6 +565,10 @@ def on_chain_end(
560565
except Exception as e:
561566
langfuse_logger.exception(e)
562567

568+
finally:
569+
if parent_run_id is None:
570+
self._reset()
571+
563572
def on_chain_error(
564573
self,
565574
error: BaseException,
@@ -603,6 +612,8 @@ def on_chat_model_start(
603612
metadata: Optional[Dict[str, Any]] = None,
604613
**kwargs: Any,
605614
) -> Any:
615+
self._child_to_parent_run_id_map[run_id] = parent_run_id
616+
606617
try:
607618
self._log_debug_event(
608619
"on_chat_model_start", run_id, parent_run_id, messages=messages
@@ -635,6 +646,8 @@ def on_llm_start(
635646
metadata: Optional[Dict[str, Any]] = None,
636647
**kwargs: Any,
637648
) -> Any:
649+
self._child_to_parent_run_id_map[run_id] = parent_run_id
650+
638651
try:
639652
self._log_debug_event(
640653
"on_llm_start", run_id, parent_run_id, prompts=prompts
@@ -662,6 +675,8 @@ def on_tool_start(
662675
metadata: Optional[Dict[str, Any]] = None,
663676
**kwargs: Any,
664677
) -> Any:
678+
self._child_to_parent_run_id_map[run_id] = parent_run_id
679+
665680
try:
666681
self._log_debug_event(
667682
"on_tool_start", run_id, parent_run_id, input_str=input_str
@@ -704,6 +719,8 @@ def on_retriever_start(
704719
metadata: Optional[Dict[str, Any]] = None,
705720
**kwargs: Any,
706721
) -> Any:
722+
self._child_to_parent_run_id_map[run_id] = parent_run_id
723+
707724
try:
708725
self._log_debug_event(
709726
"on_retriever_start", run_id, parent_run_id, query=query
@@ -809,6 +826,8 @@ def __on_llm_action(
809826
metadata: Optional[Dict[str, Any]] = None,
810827
**kwargs: Any,
811828
) -> None:
829+
self._child_to_parent_run_id_map[run_id] = parent_run_id
830+
812831
try:
813832
tools = kwargs.get("invocation_params", {}).get("tools", None)
814833
if tools and isinstance(tools, list):
@@ -817,14 +836,23 @@ def __on_llm_action(
817836
model_name = self._parse_model_and_log_errors(
818837
serialized=serialized, metadata=metadata, kwargs=kwargs
819838
)
820-
registered_prompt = (
821-
self.prompt_to_parent_run_map.get(parent_run_id)
822-
if parent_run_id is not None
823-
else None
824-
)
825839

826-
if registered_prompt:
827-
self._deregister_langfuse_prompt(parent_run_id)
840+
registered_prompt = None
841+
current_parent_run_id = parent_run_id
842+
843+
# Check all parents for registered prompt
844+
while current_parent_run_id is not None:
845+
registered_prompt = self.prompt_to_parent_run_map.get(
846+
current_parent_run_id
847+
)
848+
849+
if registered_prompt:
850+
self._deregister_langfuse_prompt(current_parent_run_id)
851+
break
852+
else:
853+
current_parent_run_id = self._child_to_parent_run_id_map.get(
854+
current_parent_run_id, None
855+
)
828856

829857
content = {
830858
"name": self.get_langchain_run_name(serialized, **kwargs),
@@ -956,6 +984,9 @@ def on_llm_end(
956984
finally:
957985
self.updated_completion_start_time_memo.discard(run_id)
958986

987+
if parent_run_id is None:
988+
self._reset()
989+
959990
def on_llm_error(
960991
self,
961992
error: BaseException,
@@ -980,6 +1011,9 @@ def on_llm_error(
9801011
except Exception as e:
9811012
langfuse_logger.exception(e)
9821013

1014+
def _reset(self) -> None:
1015+
self._child_to_parent_run_id_map = {}
1016+
9831017
def __join_tags_and_metadata(
9841018
self,
9851019
tags: Optional[List[str]] = None,
@@ -1047,7 +1081,7 @@ def _log_debug_event(
10471081
**kwargs: Any,
10481082
) -> None:
10491083
langfuse_logger.debug(
1050-
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}"
1084+
f"Event: {event_name}, run_id: {run_id}, parent_run_id: {parent_run_id}"
10511085
)
10521086

10531087

@@ -1210,7 +1244,9 @@ def _parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> Any:
12101244
usage_model["input"] = max(0, usage_model["input"] - value)
12111245

12121246
if f"input_modality_{item['modality']}" in usage_model:
1213-
usage_model[f"input_modality_{item['modality']}"] = max(0, usage_model[f"input_modality_{item['modality']}"] - value)
1247+
usage_model[f"input_modality_{item['modality']}"] = max(
1248+
0, usage_model[f"input_modality_{item['modality']}"] - value
1249+
)
12141250

12151251
usage_model = {k: v for k, v in usage_model.items() if isinstance(v, int)}
12161252

0 commit comments

Comments
 (0)