Skip to content

Commit 189a6a8

Browse files
authored
feat(llama-index): add LlamaIndexInstrumentor (#946)
1 parent 8cbd293 commit 189a6a8

10 files changed

Lines changed: 1245 additions & 746 deletions

langfuse/llama_index/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from .llama_index import LlamaIndexCallbackHandler
2-
from .span_handler import LlamaIndexSpanHandler
2+
from ._instrumentor import LlamaIndexInstrumentor
33

4-
__all__ = ["LlamaIndexCallbackHandler", "LlamaIndexSpanHandler"]
4+
__all__ = [
5+
"LlamaIndexCallbackHandler",
6+
"LlamaIndexInstrumentor",
7+
]

langfuse/llama_index/_context.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from contextvars import ContextVar
2+
from typing import Optional, Any, List, Dict
3+
from ._types import InstrumentorContextData
4+
5+
6+
# Baseline values for a fresh instrumentation context. Every key the
# instrumentor reads is present up front so lookups never raise KeyError.
default_context: "InstrumentorContextData" = {
    "trace_id": None,
    "parent_observation_id": None,
    "update_parent": True,
    "trace_name": None,
    "root_llama_index_span_id": None,
    "is_user_managed_trace": None,
    "user_id": None,
    "session_id": None,
    "version": None,
    "release": None,
    "metadata": None,
    "tags": None,
    "public": None,
}

# Per-execution-context storage for the instrumentor's trace state.
#
# NOTE: {**default_context} builds ONE shallow copy at import time. That single
# dict is what .get() returns in *every* context that has never called .set(),
# so it must never be mutated in place — always .set() a freshly merged dict
# (see InstrumentorContext.update / reset below).
langfuse_instrumentor_context: ContextVar["InstrumentorContextData"] = ContextVar(
    "langfuse_instrumentor_context",
    default={**default_context},
)


class InstrumentorContext:
    """Accessor for the Langfuse instrumentor's per-execution-context state.

    All data lives in the module-level ``langfuse_instrumentor_context``
    ContextVar; instances of this class are stateless views over it, so they
    can be created freely and shared safely.
    """

    @staticmethod
    def _get_context() -> "InstrumentorContextData":
        """Return the context dict for the current execution context.

        If this context never called ``set``, this is the ContextVar's shared
        default dict — callers must treat the result as read-only.
        """
        return langfuse_instrumentor_context.get()

    @property
    def trace_id(self) -> Optional[str]:
        """ID of the Langfuse trace currently associated with this context."""
        return self._get_context()["trace_id"]

    @property
    def parent_observation_id(self) -> Optional[str]:
        """ID of the observation new observations should be nested under."""
        return self._get_context()["parent_observation_id"]

    @property
    def root_llama_index_span_id(self) -> Optional[str]:
        """Span ID of the root LlamaIndex span of the current trace."""
        return self._get_context()["root_llama_index_span_id"]

    @property
    def is_user_managed_trace(self) -> Optional[bool]:
        """Whether the trace was created by the user rather than the instrumentor."""
        return self._get_context()["is_user_managed_trace"]

    @property
    def update_parent(self) -> Optional[bool]:
        """Whether trace-level attributes should be written to the parent trace."""
        return self._get_context()["update_parent"]

    @property
    def trace_name(self) -> Optional[str]:
        """Explicit name configured for the trace, if any."""
        return self._get_context()["trace_name"]

    @property
    def trace_data(self) -> Dict[str, Any]:
        """Trace-level attributes as a dict suitable for a trace update."""
        context = self._get_context()  # single lookup instead of one per key

        return {
            key: context[key]
            for key in (
                "user_id",
                "session_id",
                "version",
                "release",
                "metadata",
                "tags",
                "public",
            )
        }

    @staticmethod
    def reset() -> None:
        """Restore the current context to a fresh copy of the defaults."""
        langfuse_instrumentor_context.set({**default_context})

    def reset_trace_id(self) -> None:
        """Clear trace ID and root span ID while keeping all other state."""
        previous_context = self._get_context()

        langfuse_instrumentor_context.set(
            {**previous_context, "trace_id": None, "root_llama_index_span_id": None}
        )

    @staticmethod
    def update(
        *,
        trace_id: Optional[str] = None,
        parent_observation_id: Optional[str] = None,
        update_parent: Optional[bool] = None,
        root_llama_index_span_id: Optional[str] = None,
        is_user_managed_trace: Optional[bool] = None,
        trace_name: Optional[str] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        version: Optional[str] = None,
        release: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        tags: Optional[List[str]] = None,
        public: Optional[bool] = None,
    ) -> None:
        """Merge the given values into the current context.

        Only keyword arguments that are not None are applied; passing None
        leaves the corresponding existing value untouched.
        """
        updates: Dict[str, Any] = {
            key: value
            for key, value in (
                ("trace_id", trace_id),
                ("parent_observation_id", parent_observation_id),
                ("update_parent", update_parent),
                ("trace_name", trace_name),
                ("root_llama_index_span_id", root_llama_index_span_id),
                ("is_user_managed_trace", is_user_managed_trace),
                ("user_id", user_id),
                ("session_id", session_id),
                ("version", version),
                ("release", release),
                ("metadata", metadata),
                ("tags", tags),
                ("public", public),
            )
            if value is not None
        }

        # Bug fix: the previous implementation did
        # `langfuse_instrumentor_context.get().update(updates)`, mutating the
        # returned dict in place. When nothing had been .set() in the current
        # context yet, .get() returns the ContextVar's single shared default
        # dict, so the update leaked into every other context. Setting a fresh
        # merged dict keeps the write local to the current context.
        langfuse_instrumentor_context.set(
            {**langfuse_instrumentor_context.get(), **updates}
        )
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
from typing import Optional, Any, Union, Dict, Mapping
2+
3+
from langfuse.client import (
4+
Langfuse,
5+
)
6+
from langfuse.model import ModelUsage
7+
8+
9+
try:
10+
from llama_index.core.base.llms.types import (
11+
ChatResponse,
12+
CompletionResponse,
13+
)
14+
from llama_index.core.instrumentation.events import BaseEvent
15+
from llama_index.core.instrumentation.events.embedding import (
16+
EmbeddingStartEvent,
17+
EmbeddingEndEvent,
18+
)
19+
from llama_index.core.instrumentation.event_handlers import BaseEventHandler
20+
from llama_index.core.instrumentation.events.llm import (
21+
LLMCompletionEndEvent,
22+
LLMCompletionStartEvent,
23+
LLMChatEndEvent,
24+
LLMChatStartEvent,
25+
)
26+
from llama_index.core.utilities.token_counting import TokenCounter
27+
28+
except ImportError:
29+
raise ModuleNotFoundError(
30+
"Please install llama-index to use the Langfuse llama-index integration: 'pip install llama-index'"
31+
)
32+
33+
from logging import getLogger
34+
35+
logger = getLogger(__name__)
36+
37+
38+
class LlamaIndexEventHandler(BaseEventHandler, extra="allow"):
    """Receives LlamaIndex instrumentation events and records model metadata
    and token usage into the shared per-span observation-update buffer."""

    def __init__(
        self,
        *,
        langfuse_client: Langfuse,
        observation_updates: Dict[str, Dict[str, Any]],
    ):
        super().__init__()

        self._langfuse = langfuse_client
        self._observation_updates = observation_updates
        self._token_counter = TokenCounter()

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier LlamaIndex uses for this handler class."""
        return "LlamaIndexEventHandler"

    def handle(self, event: BaseEvent) -> None:
        """Dispatch an incoming event to the matching update routine."""
        logger.debug(f"Event {type(event).__name__} received: {event}")

        start_event_types = (
            LLMCompletionStartEvent,
            LLMChatStartEvent,
            EmbeddingStartEvent,
        )
        end_event_types = (LLMCompletionEndEvent, LLMChatEndEvent, EmbeddingEndEvent)

        if isinstance(event, start_event_types):
            self.update_generation_from_start_event(event)
        elif isinstance(event, end_event_types):
            self.update_generation_from_end_event(event)

    def update_generation_from_start_event(
        self,
        event: Union[LLMCompletionStartEvent, LLMChatStartEvent, EmbeddingStartEvent],
    ) -> None:
        """Record the model name and selected model parameters for the span."""
        if event.span_id is None:
            logger.warning("Span ID is not set")
            return

        model_data = event.model_dict
        model = model_data.pop("model", None) or model_data.pop("model_name", None)

        # Only a fixed allow-list of parameters is traced; everything is
        # stringified so it serializes uniformly.
        tracked_parameter_names = {
            "max_tokens",
            "max_retries",
            "temperature",
            "timeout",
            "strict",
            "top_logprobs",
            "logprobs",
            "embed_batch_size",
        }
        traced_model_data = {}
        for name, value in model_data.items():
            if name in tracked_parameter_names and value is not None:
                traced_model_data[name] = str(value)

        self._update_observation_updates(
            event.span_id, model=model, model_parameters=traced_model_data
        )

    def update_generation_from_end_event(
        self, event: Union[LLMCompletionEndEvent, LLMChatEndEvent, EmbeddingEndEvent]
    ) -> None:
        """Attach token usage to the span once the LLM/embedding call ends."""
        if event.span_id is None:
            logger.warning("Span ID is not set")
            return

        usage = None

        if isinstance(event, (LLMCompletionEndEvent, LLMChatEndEvent)):
            if event.response:
                usage = self._parse_token_usage(event.response)
        elif isinstance(event, EmbeddingEndEvent):
            # Embedding events carry no provider-reported usage, so the
            # tokens are counted locally from the embedded chunks.
            counted_tokens = sum(
                self._token_counter.get_string_tokens(chunk) for chunk in event.chunks
            )

            usage = {
                "input": 0,
                "output": 0,
                "total": counted_tokens or None,
            }

        self._update_observation_updates(event.span_id, usage=usage)

    def _update_observation_updates(self, id_: str, **kwargs) -> None:
        """Merge kwargs into the pending update for span `id_`, if tracked."""
        pending = self._observation_updates.get(id_)
        if pending is not None:
            pending.update(kwargs)

    def _parse_token_usage(
        self, response: Union[ChatResponse, CompletionResponse]
    ) -> Optional[ModelUsage]:
        """Extract usage from the provider's raw payload, falling back to
        the response's additional_kwargs."""
        raw = getattr(response, "raw", None)
        if raw and hasattr(raw, "get"):
            usage = raw.get("usage")
            if usage:
                return _parse_usage_from_mapping(usage)

        additional_kwargs = getattr(response, "additional_kwargs", None)
        if additional_kwargs:
            return _parse_usage_from_mapping(additional_kwargs)

        return None
142+
143+
144+
def _parse_usage_from_mapping(
    usage: Union[object, Mapping[str, Any]],
) -> ModelUsage:
    """Normalize a provider usage payload — mapping or attribute object —
    into a ModelUsage dict."""
    parse = (
        _get_token_counts_from_mapping
        if isinstance(usage, Mapping)
        else _parse_usage_from_object
    )

    return parse(usage)
151+
152+
153+
def _parse_usage_from_object(usage: object) -> ModelUsage:
    """Build a ModelUsage dict from an object exposing OpenAI-style
    `prompt_tokens` / `completion_tokens` / `total_tokens` attributes."""
    model_usage: ModelUsage = {
        "unit": None,
        "input": None,
        "output": None,
        "total": None,
        "input_cost": None,
        "output_cost": None,
        "total_cost": None,
    }

    # Copy each count over only when the source attribute exists and is not
    # None; absent counts stay None so callers can tell them apart from 0.
    for attribute, field in (
        ("prompt_tokens", "input"),
        ("completion_tokens", "output"),
        ("total_tokens", "total"),
    ):
        value = getattr(usage, attribute, None)
        if value is not None:
            model_usage[field] = value

    return model_usage
172+
173+
174+
def _get_token_counts_from_mapping(
    usage_mapping: Mapping[str, Any],
) -> ModelUsage:
    """Build a ModelUsage dict from a mapping carrying OpenAI-style
    `prompt_tokens` / `completion_tokens` / `total_tokens` entries."""
    model_usage: ModelUsage = {
        "unit": None,
        "input": None,
        "output": None,
        "total": None,
        "input_cost": None,
        "output_cost": None,
        "total_cost": None,
    }

    # Copy each count over only when the key is present with a non-None
    # value; missing counts stay None so callers can tell them apart from 0.
    for source_key, target_key in (
        ("prompt_tokens", "input"),
        ("completion_tokens", "output"),
        ("total_tokens", "total"),
    ):
        value = usage_mapping.get(source_key)
        if value is not None:
            model_usage[target_key] = value

    return model_usage

0 commit comments

Comments
 (0)