-
Notifications
You must be signed in to change notification settings - Fork 5
feat: add support for the original mt-bench #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 25 commits
ba4220d
d2a5a42
a828adb
0dcebf9
d60073b
38f63ee
6f5e0fc
42ff2ae
8fcb032
df958af
35856f2
6a11182
fecd3ed
0b4eaec
29340b0
2c294f1
4be61bf
51d2597
8dee7b2
fdc9410
48c5373
648a9be
14f747e
e67ea79
4089be8
03f5cce
8ffe3a6
b877f11
c2056b5
41cd15d
0ca66c5
a295305
0fb9700
e5670ea
6dd78fd
0094eea
f522e5b
6a851c3
caaa079
2e8e04e
5a314a7
8c91606
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,16 +4,10 @@ | |
| from openjury.utils import ( | ||
| do_inference, | ||
| make_model, | ||
| truncate, | ||
| ) | ||
|
|
||
|
|
||
| def truncate(s: str, max_len: int | None = None): | ||
| if max_len is not None: | ||
| return s[:max_len] | ||
| else: | ||
| return s | ||
|
|
||
|
|
||
| def generate_instructions( | ||
| instructions: pd.Series, | ||
| model: str, | ||
|
|
@@ -57,6 +51,92 @@ def generate_instructions( | |
| return df_outputs | ||
|
|
||
|
|
||
| def generate_multiturn( | ||
| questions: pd.DataFrame, | ||
| model: str, | ||
| truncate_input_chars: int | None = 8192, | ||
| max_tokens: int | None = 8192, | ||
| use_tqdm: bool = True, | ||
| **model_kwargs, | ||
| ) -> pd.DataFrame: | ||
| """Generate two-turn completions for MT-Bench style questions. | ||
|
|
||
| Generates turn 1 answers first, then uses them as conversation context | ||
| to generate turn 2 answers. | ||
|
|
||
| Args: | ||
| questions: DataFrame with columns turn_1, turn_2, and index instruction_index. | ||
| model: Model specification string (e.g. "VLLM/model-name"). | ||
| **model_kwargs: Provider-specific options forwarded to make_model | ||
| (e.g. max_model_len, chat_template for VLLM). | ||
| Returns: | ||
| DataFrame with columns: instruction_index, completion_turn_1, completion_turn_2 | ||
| """ | ||
| chat_model = make_model(model, max_tokens=max_tokens, **model_kwargs) | ||
|
|
||
| system_prompt = "You are a helpful assistant." | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can use a better
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, we have a naive default also in general (it is not blocking for this PR as we can change/improve it later).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added the |
||
| turn1_template = ChatPromptTemplate.from_messages( | ||
| [("system", system_prompt), ("user", "{user_prompt}")] | ||
| ) | ||
|
|
||
| turn1_inputs = turn1_template.batch( | ||
| [ | ||
| {"user_prompt": truncate(row["turn_1"], max_len=truncate_input_chars)} | ||
| for _, row in questions.iterrows() | ||
| ] | ||
| ) | ||
|
|
||
| print(f"Generating turn 1 completions ({len(turn1_inputs)} questions).") | ||
| completions_turn_1 = do_inference( | ||
| chat_model=chat_model, | ||
| inputs=turn1_inputs, | ||
| use_tqdm=use_tqdm, | ||
| ) | ||
|
|
||
| turn2_inputs = [] | ||
| for (_, row), t1_answer in zip(questions.iterrows(), completions_turn_1): | ||
| if row["turn_2"] is None: | ||
| turn2_inputs.append( | ||
| turn1_template.invoke( | ||
| {"user_prompt": "No follow-up question."} | ||
| ) | ||
| ) | ||
| else: | ||
| multi_turn_template = ChatPromptTemplate.from_messages( | ||
| [ | ||
| ("system", system_prompt), | ||
| ("user", "{turn_1}"), | ||
| ("assistant", "{turn_1_answer}"), | ||
| ("user", "{turn_2}"), | ||
| ] | ||
| ) | ||
| turn2_inputs.append( | ||
| multi_turn_template.invoke( | ||
| { | ||
| "turn_1": truncate(row["turn_1"], max_len=truncate_input_chars), | ||
| "turn_1_answer": truncate(str(t1_answer), max_len=truncate_input_chars), | ||
| "turn_2": truncate(row["turn_2"], max_len=truncate_input_chars), | ||
| } | ||
| ) | ||
| ) | ||
|
|
||
| print(f"Generating turn 2 completions ({len(turn2_inputs)} questions).") | ||
| completions_turn_2 = do_inference( | ||
| chat_model=chat_model, | ||
| inputs=turn2_inputs, | ||
| use_tqdm=use_tqdm, | ||
| ) | ||
|
|
||
| df_outputs = pd.DataFrame( | ||
| data={ | ||
| "instruction_index": questions.index.tolist(), | ||
| "completion_turn_1": completions_turn_1, | ||
| "completion_turn_2": completions_turn_2, | ||
| }, | ||
| ) | ||
| return df_outputs | ||
|
|
||
|
|
||
| def generate_base( | ||
| instructions: pd.Series, | ||
| model: str, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💪