-
Notifications
You must be signed in to change notification settings - Fork 262
Expand file tree
/
Copy pathtest_extract_model.py
More file actions
158 lines (145 loc) · 4.88 KB
/
test_extract_model.py
File metadata and controls
158 lines (145 loc) · 4.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Any, Optional
from unittest.mock import MagicMock

import pytest
from langchain.schema.messages import HumanMessage
from langchain_anthropic import Anthropic, ChatAnthropic
from langchain_aws import BedrockLLM, ChatBedrock
from langchain_community.chat_models import (
    ChatCohere,
    ChatTongyi,
)
from langchain_community.chat_models.fake import FakeMessagesListChatModel

# from langchain_huggingface.llms import HuggingFacePipeline
from langchain_community.llms.textgen import TextGen
from langchain_core.load.dump import default
from langchain_google_vertexai import ChatVertexAI
from langchain_groq import ChatGroq
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_ollama import ChatOllama, OllamaLLM
from langchain_openai import (
    AzureChatOpenAI,
    ChatOpenAI,
    OpenAI,
)

from langfuse.langchain import CallbackHandler
from langfuse.langchain.utils import _extract_model_name
from tests.utils import get_api
@pytest.mark.parametrize(
    "expected_model,model",
    [
        (
            "mixtral-8x7b-32768",
            ChatGroq(
                temperature=0, model_name="mixtral-8x7b-32768", groq_api_key="something"
            ),
        ),
        ("llama3", OllamaLLM(model="llama3")),
        ("llama3", ChatOllama(model="llama3")),
        (
            None,
            FakeMessagesListChatModel(responses=[HumanMessage("Hello, how are you?")]),
        ),
        (
            "mistralai",
            ChatMistralAI(mistral_api_key="mistral_api_key", model="mistralai"),
        ),
        (
            "text-gen",
            TextGen(model_url="some-url"),
        ),  # local deployments, does not have a model name
        ("claude-2", ChatAnthropic(model_name="claude-2")),
        (
            "claude-3-sonnet-20240229",
            ChatAnthropic(model="claude-3-sonnet-20240229"),
        ),
        ("claude-2", Anthropic()),
        ("command", ChatCohere(model="command", cohere_api_key="command")),
        (None, ChatTongyi(dashscope_api_key="dash")),
        (
            "amazon.titan-tg1-large",
            BedrockLLM(
                model="amazon.titan-tg1-large",
                region="us-east-1",
                client=MagicMock(),
            ),
        ),
        (
            "anthropic.claude-3-sonnet-20240229-v1:0",
            ChatBedrock(
                model_id="anthropic.claude-3-sonnet-20240229-v1:0",
                region_name="us-east-1",
                client=MagicMock(),
            ),
        ),
        (
            "claude-1",
            BedrockLLM(
                model="claude-1",
                region="us-east-1",
                client=MagicMock(),
            ),
        ),
    ],
)
def test_models(expected_model: Optional[str], model: Any) -> None:
    """Check the model name extracted from a serialized LangChain model.

    The model instance is serialized with ``default`` and the result is passed
    to ``_extract_model_name``; no LLM call is made.

    Args:
        expected_model: The name the extraction must return, or ``None`` for
            the cases parametrized with no expected name.
        model: The LangChain model instance under test.
    """
    serialized = default(model)
    model_name = _extract_model_name(serialized)
    assert model_name == expected_model
# All models below must be tested through an entire LLM call, because the model name is taken from the kwargs / invocation_params, or an actual call is needed for setup.
@pytest.mark.skip(reason="Flaky")
@pytest.mark.parametrize(
    "expected_model,model",
    [
        ("gpt-3.5-turbo-0125", ChatOpenAI()),
        ("gpt-3.5-turbo-instruct", OpenAI()),
        (
            "gpt-3.5-turbo",
            AzureChatOpenAI(
                openai_api_version="2023-05-15",
                model="gpt-3.5-turbo",
                azure_deployment="your-deployment-name",
                azure_endpoint="https://your-endpoint-name.azurewebsites.net",
            ),
        ),
        # (
        #     "gpt2",
        #     HuggingFacePipeline(
        #         model_id="gpt2",
        #         model_kwargs={
        #             "max_new_tokens": 512,
        #             "top_k": 30,
        #             "temperature": 0.1,
        #             "repetition_penalty": 1.03,
        #         },
        #     ),
        # ),
        (
            "qwen-72b-chat",
            ChatTongyi(model="qwen-72b-chat", dashscope_api_key="dashscope"),
        ),
        (
            "gemini",
            ChatVertexAI(
                model_name="gemini", credentials=MagicMock(), project="some-project"
            ),
        ),
    ],
)
def test_entire_llm_call(expected_model: str, model: Any) -> None:
    """Check the model name recorded on the generation of a traced LLM call.

    The model is invoked with a ``CallbackHandler`` inside a parent span; the
    trace is then fetched through the API client and its generation
    observation is compared against ``expected_model``.

    Args:
        expected_model: The model name the generation observation must carry.
        model: The LangChain model instance to invoke.
    """
    callback = CallbackHandler()

    with callback.client.start_as_current_span(name="parent") as span:
        trace_id = span.trace_id

        try:
            # LLM calls are failing, because of missing API keys etc.
            # However, we are still able to extract the model names beforehand.
            model.invoke("Hello, how are you?", config={"callbacks": [callback]})
        except Exception as e:
            # Deliberately best-effort: the failure is expected, so only report it.
            print(e)

    callback.client.flush()
    api = get_api()

    trace = api.trace.get(trace_id)

    assert len(trace.observations) == 2

    # Indexing the filtered list fails the test with IndexError when no
    # generation was recorded.
    generations = [o for o in trace.observations if o.type == "GENERATION"]
    generation = generations[0]
    assert generation.model == expected_model