Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions langfuse/langchain/CallbackHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ def _parse_langfuse_trace_attributes(
):
attributes["user_id"] = metadata["langfuse_user_id"]

if "langfuse_trace_name" in metadata and isinstance(
metadata["langfuse_trace_name"], str
):
attributes["trace_name"] = metadata["langfuse_trace_name"]

if tags is not None or (
"langfuse_tags" in metadata and isinstance(metadata["langfuse_tags"], list)
):
Expand Down Expand Up @@ -369,6 +374,7 @@ def on_chain_start(
session_id=parsed_trace_attributes.get("session_id", None),
tags=parsed_trace_attributes.get("tags", None),
metadata=parsed_trace_attributes.get("metadata", None),
trace_name=parsed_trace_attributes.get("trace_name", None),
)

self._propagation_context_manager.__enter__()
Expand Down Expand Up @@ -1403,6 +1409,7 @@ def _strip_langfuse_keys_from_dict(
"langfuse_session_id",
"langfuse_user_id",
"langfuse_tags",
"langfuse_trace_name",
]

metadata_copy = metadata.copy()
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from contextvars import copy_context
from unittest.mock import patch
from uuid import uuid4

import pytest
from langchain.messages import HumanMessage
Expand Down Expand Up @@ -166,3 +168,84 @@ def test_chat_model_error_marks_generation_error(langfuse_memory_client, get_spa
assert (
"boom" in span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE]
)


def test_root_chain_metadata_propagates_trace_name(
    langfuse_memory_client, get_span, find_spans
):
    """`langfuse_trace_name` passed via run metadata becomes the trace name.

    Builds a prompt | model | parser chain with a patched model, invokes it
    with ``langfuse_trace_name`` in the config metadata, and checks that:
    the name is set on both the root and the generation span, the key is
    stripped from the recorded observation metadata, and no duplicate
    generation span is created.
    """
    # Canned model output so no network call is made.
    usage = {
        "prompt_tokens": 4,
        "completion_tokens": 2,
        "total_tokens": 6,
    }
    fake_generation = ChatGeneration(
        message=AIMessage(content="knock knock"),
        text="knock knock",
    )
    fake_result = ChatResult(
        generations=[fake_generation],
        llm_output={"token_usage": usage, "model_name": "gpt-4o-mini"},
    )

    with patch.object(ChatOpenAI, "_generate", return_value=fake_result):
        handler = CallbackHandler()
        chain = (
            ChatPromptTemplate.from_template("tell me a joke about {topic}")
            | ChatOpenAI(api_key="test", temperature=0)
            | StrOutputParser()
        )

        output = chain.invoke(
            {"topic": "otters"},
            config={
                "callbacks": [handler],
                "metadata": {"langfuse_trace_name": "langchain-trace-name"},
            },
        )

    assert output == "knock knock"

    langfuse_memory_client.flush()

    # The trace name must be present on every span in the trace.
    for span_name in ("RunnableSequence", "ChatOpenAI"):
        span = get_span(span_name)
        assert (
            span.attributes[LangfuseOtelSpanAttributes.TRACE_NAME]
            == "langchain-trace-name"
        )

    # The control key must not leak into the observation metadata.
    root_span = get_span("RunnableSequence")
    stripped_key = (
        f"{LangfuseOtelSpanAttributes.OBSERVATION_METADATA}.langfuse_trace_name"
    )
    assert stripped_key not in root_span.attributes

    # Exactly one generation span — no accidental duplication.
    assert len(find_spans("ChatOpenAI")) == 1


def test_root_chain_exports_when_end_runs_in_copied_context(
    langfuse_memory_client, get_span
):
    """Root span is still exported when ``on_chain_end`` runs in a copied context.

    Simulates an executor that finishes a run inside ``copy_context().run``,
    where contextvars set during ``on_chain_start`` are not directly visible,
    and verifies the root span is exported with the requested trace name.
    """
    handler = CallbackHandler()
    chain_run_id = uuid4()

    handler.on_chain_start(
        {"id": ["RunnableSequence"]},
        {"topic": "otters"},
        run_id=chain_run_id,
        metadata={"langfuse_trace_name": "async-root-trace"},
    )

    # End the run from a *copied* context, as some async runtimes do.
    ctx = copy_context()
    ctx.run(
        handler.on_chain_end,
        {"output": "knock knock"},
        run_id=chain_run_id,
    )

    langfuse_memory_client.flush()

    exported_root = get_span("RunnableSequence")
    assert (
        exported_root.attributes[LangfuseOtelSpanAttributes.TRACE_NAME]
        == "async-root-trace"
    )
Loading