-
Notifications
You must be signed in to change notification settings - Fork 262
Expand file tree
/
Copy pathextract_model.py
More file actions
162 lines (135 loc) · 6.07 KB
/
extract_model.py
File metadata and controls
162 lines (135 loc) · 6.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""@private"""
import re
from typing import Any, Dict, List, Literal, Optional
# NOTE ON DEPENDENCIES:
# - since Jan 2024, there is https://pypi.org/project/langchain-openai/ which is a separate package and imports openai models.
# Decided to not make this a dependency of langfuse as few people will have this. Need to match these models manually
# - langchain_community is loaded as a dependency of langchain, but extremely unreliable. Decided to not depend on it.
def _extract_model_name(
    serialized: Optional[Dict[str, Any]],
    **kwargs: Any,
):
    """Extract the model name from the serialized or kwargs object.

    This is used to get the model names for Langfuse. Returns the first
    match found, or None if no strategy succeeds — so the order of the
    strategies below is significant.
    """
    # First, extract known models where we know the model name aka id.
    # The model name is read from the provided path (array of keys) in the
    # serialized or kwargs object.
    models_by_id = [
        ("ChatGoogleGenerativeAI", ["kwargs", "model"], "serialized"),
        ("ChatMistralAI", ["kwargs", "model"], "serialized"),
        ("ChatVertexAi", ["kwargs", "model_name"], "serialized"),
        ("ChatVertexAI", ["kwargs", "model_name"], "serialized"),
        ("OpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("ChatOpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "model"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "azure_deployment"], "kwargs"),
        ("HuggingFacePipeline", ["invocation_params", "model_id"], "kwargs"),
        ("BedrockChat", ["kwargs", "model_id"], "serialized"),
        ("Bedrock", ["kwargs", "model_id"], "serialized"),
        ("BedrockLLM", ["kwargs", "model_id"], "serialized"),
        ("ChatBedrock", ["kwargs", "model_id"], "serialized"),
        ("LlamaCpp", ["invocation_params", "model_path"], "kwargs"),
        ("WatsonxLLM", ["invocation_params", "model_id"], "kwargs"),
    ]

    for model_name, keys, select_from in models_by_id:
        model = _extract_model_by_path_for_id(
            model_name, serialized, kwargs, keys, select_from
        )
        if model:
            return model

    # Second, we match AzureOpenAI as we need to extract the model name,
    # deployment version and deployment name.
    if serialized is not None and (serialized.get("id") or [None])[-1] == "AzureOpenAI":
        invocation_params = kwargs.get("invocation_params") or {}
        if invocation_params.get("model"):
            return invocation_params.get("model")

        if invocation_params.get("model_name"):
            return invocation_params.get("model_name")

        serialized_kwargs = serialized.get("kwargs") or {}

        # NOTE(review): the version read is guarded on openai_api_version
        # being present (not on deployment_version itself) — kept as-is to
        # match the original control flow.
        deployment_version = None
        if serialized_kwargs.get("openai_api_version"):
            deployment_version = serialized_kwargs.get("deployment_version")

        deployment_name = None
        if serialized_kwargs.get("deployment_name"):
            deployment_name = serialized_kwargs.get("deployment_name")

        # Original code assigned deployment_name from the deployment_version
        # key and then concatenated unconditionally, raising TypeError when a
        # part was None. Join only the parts that are actually present.
        parts = [part for part in (deployment_name, deployment_version) if part]
        return "-".join(parts) if parts else None

    # Third, for some models we are unable to extract the model by a path in
    # an object. Langchain provides us with a string representation of the
    # model objects; we use regex to extract the model from that repr string.
    models_by_pattern = [
        ("Anthropic", "model", "anthropic"),
        ("ChatAnthropic", "model", None),
        ("ChatTongyi", "model_name", None),
        ("ChatCohere", "model", None),
        ("Cohere", "model", None),
        ("HuggingFaceHub", "model", None),
        ("ChatAnyscale", "model_name", None),
        ("TextGen", "model", "text-gen"),
        ("Ollama", "model", None),
        ("OllamaLLM", "model", None),
        ("ChatOllama", "model", None),
        ("ChatFireworks", "model", None),
        ("ChatPerplexity", "model", None),
        ("VLLM", "model", None),
        ("Xinference", "model_uid", None),
        ("ChatOCIGenAI", "model_id", None),
        ("DeepInfra", "model_id", None),
    ]

    for model_name, pattern, default in models_by_pattern:
        model = _extract_model_from_repr_by_pattern(
            model_name, serialized, pattern, default
        )
        if model:
            return model

    # Finally, we try to extract the most likely paths as a catch-all.
    random_paths = [
        ["kwargs", "model_name"],
        ["kwargs", "model"],
        ["invocation_params", "model_name"],
        ["invocation_params", "model"],
    ]
    for select in ["kwargs", "serialized"]:
        for path in random_paths:
            model = _extract_model_by_path(serialized, kwargs, path, select)
            if model:
                return model

    return None
def _extract_model_from_repr_by_pattern(
    id: str,
    serialized: Optional[Dict[str, Any]],
    pattern: str,
    default: Optional[str] = None,
):
    """Extract the model name from serialized["repr"] via regex.

    Only applies when the last segment of serialized["id"] equals *id*;
    falls back to *default* (or None) when the regex finds nothing.
    """
    if serialized is None:
        return None

    if serialized.get("id")[-1] != id:
        return None

    repr_str = serialized.get("repr")
    if not repr_str:
        return None

    extracted = _extract_model_with_regex(pattern, repr_str)
    if extracted:
        return extracted

    return default if default else None
def _extract_model_with_regex(pattern: str, text: str):
match = re.search(rf"{pattern}='(.*?)'", text)
if match:
return match.group(1)
return None
def _extract_model_by_path_for_id(
    id: str,
    serialized: Optional[Dict[str, Any]],
    kwargs: dict,
    keys: List[str],
    select_from: Literal["serialized", "kwargs"],
):
    """Extract the model via *keys* when serialized matches class *id*.

    Returns None when *serialized* is missing, its "id" does not end with
    *id*, or the path lookup fails.
    """
    # serialized is needed for the id check regardless of select_from; the
    # original only guarded the "serialized" case and crashed with
    # AttributeError on serialized=None when select_from was "kwargs".
    if serialized is None:
        return None

    serialized_id = serialized.get("id")
    if not serialized_id or serialized_id[-1] != id:
        return None

    return _extract_model_by_path(serialized, kwargs, keys, select_from)
def _extract_model_by_path(
serialized: Optional[Dict[str, Any]],
kwargs: dict,
keys: List[str],
select_from: Literal["serialized", "kwargs"],
):
if serialized is None and select_from == "serialized":
return None
current_obj = kwargs if select_from == "kwargs" else serialized
for key in keys:
current_obj = current_obj.get(key)
if not current_obj:
return None
return current_obj if current_obj else None