diff --git a/judgearena/cli.py b/judgearena/cli.py index eb94c83..e8ffed8 100644 --- a/judgearena/cli.py +++ b/judgearena/cli.py @@ -12,7 +12,7 @@ from judgearena.cli_common import ( add_common_arguments, - parse_engine_kwargs, + resolve_generation_configs, resolve_verbosity, ) from judgearena.estimate_elo_ratings import CliEloArgs @@ -172,6 +172,31 @@ def _resolve_model_a(args: argparse.Namespace) -> str | None: return args.model_A +def _common_base_kwargs(args: argparse.Namespace) -> dict: + """Build the kwargs shared by every CLI dataclass from a parsed Namespace. + + Centralised here so that ``CliArgs`` and ``CliEloArgs`` see the same + per-role :class:`GenerationConfig` instances and the same shared + fields without each call site repeating the field-by-field forwarding. + """ + gen_configs = resolve_generation_configs(args) + return { + "judge_model": args.judge_model, + "n_instructions": args.n_instructions, + "provide_explanation": args.provide_explanation, + "swap_mode": args.swap_mode, + "ignore_cache": args.ignore_cache, + "truncate_all_input_chars": args.truncate_all_input_chars, + "result_folder": args.result_folder, + "verbosity": resolve_verbosity(args), + "log_file": args.log_file, + "no_log_file": args.no_log_file, + "gen_A": gen_configs["A"], + "gen_B": gen_configs["B"], + "gen_judge": gen_configs["judge"], + } + + def _build_elo_args( args: argparse.Namespace, arena: str, model_a: str | None ) -> CliEloArgs: @@ -191,21 +216,7 @@ def _build_elo_args( n_bootstraps=args.n_bootstraps, seed=args.seed, baseline_model=args.baseline_model, - judge_model=args.judge_model, - n_instructions=args.n_instructions, - provide_explanation=args.provide_explanation, - swap_mode=args.swap_mode, - ignore_cache=args.ignore_cache, - truncate_all_input_chars=args.truncate_all_input_chars, - max_out_tokens_models=args.max_out_tokens_models, - max_out_tokens_judge=args.max_out_tokens_judge, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - result_folder=args.result_folder, - engine_kwargs=parse_engine_kwargs(args.engine_kwargs), - verbosity=resolve_verbosity(args), - log_file=args.log_file, - no_log_file=args.no_log_file, + **_common_base_kwargs(args), ) @@ -219,21 +230,7 @@ def _build_generate_and_evaluate_args( model_A=model_a, model_B=args.model_B, use_tqdm=args.use_tqdm, - judge_model=args.judge_model, - n_instructions=args.n_instructions, - provide_explanation=args.provide_explanation, - swap_mode=args.swap_mode, - ignore_cache=args.ignore_cache, - truncate_all_input_chars=args.truncate_all_input_chars, - max_out_tokens_models=args.max_out_tokens_models, - max_out_tokens_judge=args.max_out_tokens_judge, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - result_folder=args.result_folder, - engine_kwargs=parse_engine_kwargs(args.engine_kwargs), - verbosity=resolve_verbosity(args), - log_file=args.log_file, - no_log_file=args.no_log_file, + **_common_base_kwargs(args), ) diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index 58ce78b..2874ee5 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -3,14 +3,46 @@ Houses the base dataclass fields and argparse definitions that are common to both ``judgearena`` (generate_and_evaluate) and ``judgearena-elo`` (estimate_elo_ratings) CLI tools. 
+ +The CLI exposes per-role generation parameters (``_A``, ``_B``, +``_judge``) so that every knob that can affect the generated text - +temperature, top_p, top_k, seed, max_tokens, max_model_len, chat_template, +engine_kwargs - is recorded explicitly in the run. Older flags like +``--max_out_tokens_models`` keep working as deprecated fan-out aliases. """ from __future__ import annotations import argparse import json +import warnings from dataclasses import dataclass, field +ROLE_NAMES: tuple[str, ...] = ("A", "B", "judge") +"""The three roles a JudgeArena run distinguishes for generation settings. + +``A`` and ``B`` are the two battle models; ``judge`` is the LLM judge. +""" + + +@dataclass(frozen=True) +class GenerationConfig: + """Sampling and inference configuration applied to a single role. + + Every field is optional (``None``/empty defaults mean "let the backend + pick"). The cache layer hashes this dataclass, so any change to a + field invalidates the cached completions for that role. + """ + + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None + seed: int | None = None + max_tokens: int = 32768 + max_model_len: int | None = None + chat_template: str | None = None + engine_kwargs: dict = field(default_factory=dict) + @dataclass class BaseCliArgs: @@ -19,20 +51,22 @@ class BaseCliArgs: judge_model: str n_instructions: int | None = None - provide_explanation: bool = False swap_mode: str = "fixed" ignore_cache: bool = False truncate_all_input_chars: int = 8192 - max_out_tokens_models: int = 32768 - max_out_tokens_judge: int = 32768 - max_model_len: int | None = None - chat_template: str | None = None result_folder: str = "results" - engine_kwargs: dict = field(default_factory=dict) verbosity: int = 0 log_file: str | None = None no_log_file: bool = False + # Per-role generation configuration. Built by ``resolve_generation_configs`` + # from the CLI flags or set programmatically by API users. + gen_A: GenerationConfig = field(default_factory=GenerationConfig) + gen_B: GenerationConfig = field(default_factory=GenerationConfig) + gen_judge: GenerationConfig = field(default_factory=GenerationConfig) + + provide_explanation: bool = False + def __post_init__(self): supported_modes = ["fixed", "both"] assert self.swap_mode in supported_modes, ( @@ -40,6 +74,120 @@ def __post_init__(self): ) +# --------------------------------------------------------------------------- +# CLI plumbing +# --------------------------------------------------------------------------- + + +def _add_per_role_generation_arguments(parser: argparse.ArgumentParser) -> None: + """Register the 8 sampling flags x 3 roles (24 flags total).""" + for role in ROLE_NAMES: + suffix = role + parser.add_argument( + f"--temperature_{suffix}", + type=float, + default=None, + help=f"Sampling temperature for the {role!r} role.", + ) + parser.add_argument( + f"--top_p_{suffix}", + type=float, + default=None, + help=f"Nucleus-sampling top-p for the {role!r} role.", + ) + parser.add_argument( + f"--top_k_{suffix}", + type=int, + default=None, + help=f"Top-k sampling for the {role!r} role.", + ) + parser.add_argument( + f"--seed_{suffix}", + type=int, + default=None, + help=( + f"Random seed forwarded to the {role!r} role's backend. " + "Hosted providers honour this on a best-effort basis." + ), + ) + parser.add_argument( + f"--max_out_tokens_{suffix}", + type=int, + default=None, + help=( + f"Generation token budget for the {role!r} role. " + "For VLLM, keep this <= --max_model_len_* (if provided)." 
+ ), + ) + parser.add_argument( + f"--max_model_len_{suffix}", + type=int, + default=None, + help=( + f"Optional total context window for the {role!r} role's " + "vLLM model (prompt + generation)." + ), + ) + parser.add_argument( + f"--chat_template_{suffix}", + type=str, + default=None, + help=( + f"Jinja2 chat template string used for the {role!r} role " + "instead of the model tokenizer's template." + ), + ) + parser.add_argument( + f"--engine_kwargs_{suffix}", + type=str, + default=None, + help=( + f"JSON dict of engine-specific kwargs forwarded to the " + f"{role!r} role's backend." + ), + ) + + +def _add_deprecated_aliases(parser: argparse.ArgumentParser) -> None: + """Register the legacy non-role-aware flags that fan out to A/B/(judge).""" + parser.add_argument( + "--max_out_tokens_models", + type=int, + default=None, + help=( + "[DEPRECATED] Use --max_out_tokens_A and --max_out_tokens_B. " + "Sets both A and B when neither role-specific flag is provided." + ), + ) + parser.add_argument( + "--max_model_len", + type=int, + default=None, + help=( + "[DEPRECATED] Use --max_model_len_{A,B,judge}. Fans out to all " + "three roles when no role-specific flag is provided." + ), + ) + parser.add_argument( + "--chat_template", + type=str, + default=None, + help=( + "[DEPRECATED] Use --chat_template_{A,B,judge}. Fans out to all " + "three roles when no role-specific flag is provided." + ), + ) + parser.add_argument( + "--engine_kwargs", + type=str, + default=None, + help=( + "[DEPRECATED] Use --engine_kwargs_{A,B,judge}. Fans out to all " + "three roles when no role-specific flag is provided." + ), + ) + + def add_common_arguments(parser: argparse.ArgumentParser) -> None: """Register the CLI flags shared by all judgearena entrypoints.""" parser.add_argument( @@ -106,60 +254,10 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: "completion before judge evaluation." ), ) - parser.add_argument( - "--max_out_tokens_models", - type=int, - required=False, - default=32768, - help=( - "Generation token budget for each model A/B response. For VLLM, " - "keep this <= --max_model_len (if provided)." - ), - ) - parser.add_argument( - "--max_out_tokens_judge", - type=int, - required=False, - default=32768, - help=( - "Generation token budget for the judge response (reasoning + scores). " - "For VLLM, keep this <= --max_model_len (if provided)." - ), - ) - parser.add_argument( - "--max_model_len", - type=int, - required=False, - default=None, - help=( - "Optional total context window for VLLM models (prompt + generation). " - "This is independent from --max_out_tokens_models/--max_out_tokens_judge, " - "which only cap generated tokens. This is useful on smaller GPUs to " - "avoid OOM." - ), - ) - parser.add_argument( - "--chat_template", - type=str, - required=False, - default=None, - help=( - "Jinja2 chat template string to use instead of the model's tokenizer " - "template. If not provided, ChatML is used as fallback for models " - "without a chat template." - ), - ) - parser.add_argument( - "--engine_kwargs", - type=str, - required=False, - default="{}", - help=( - "JSON dict of engine-specific kwargs forwarded to the underlying " - "engine. Example for vLLM: " - '\'{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}\'.' 
- ), - ) + + _add_per_role_generation_arguments(parser) + _add_deprecated_aliases(parser) + parser.add_argument( "-v", "--verbose", @@ -194,14 +292,19 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: ) -def parse_engine_kwargs(raw: str) -> dict: - """Parse and validate a JSON string into an engine-kwargs dict.""" +def parse_engine_kwargs(raw: str | None) -> dict: + """Parse and validate a JSON string into an engine-kwargs dict. + + ``None`` and empty strings both resolve to ``{}``. + """ + if raw is None or raw == "": + return {} try: - engine_kwargs = json.loads(raw) if raw else {} + engine_kwargs = json.loads(raw) if not isinstance(engine_kwargs, dict): raise ValueError("engine_kwargs must be a JSON object") except Exception as e: - raise SystemExit(f"Failed to parse --engine_kwargs: {e}") from e + raise SystemExit(f"Failed to parse engine_kwargs: {e}") from e return engine_kwargs @@ -213,3 +316,127 @@ def resolve_verbosity(args: argparse.Namespace) -> int: if getattr(args, "quiet", False): return -1 return getattr(args, "verbose", 0) + + +# --------------------------------------------------------------------------- +# Resolver: argparse Namespace -> per-role GenerationConfig dict +# --------------------------------------------------------------------------- + + +_LEGACY_DEFAULT_MAX_TOKENS = 32768 + + +def _warn_deprecated_alias(flag: str, replacements: tuple[str, ...]) -> None: + """Emit a single ``DeprecationWarning`` pointing users at the new flags.""" + warnings.warn( + f"{flag} is deprecated; use {', '.join(replacements)} instead.", + DeprecationWarning, + stacklevel=3, + ) + + +def resolve_generation_configs( + args: argparse.Namespace, +) -> dict[str, GenerationConfig]: + """Build ``{A, B, judge}`` :class:`GenerationConfig` from a parsed Namespace. + + Per-role flags take precedence; deprecated aliases fan out as fallbacks + and emit a ``DeprecationWarning``. 
+ """ + legacy_max_tokens_models = getattr(args, "max_out_tokens_models", None) + legacy_max_model_len = getattr(args, "max_model_len", None) + legacy_chat_template = getattr(args, "chat_template", None) + legacy_engine_kwargs_raw = getattr(args, "engine_kwargs", None) + legacy_engine_kwargs = ( + parse_engine_kwargs(legacy_engine_kwargs_raw) + if legacy_engine_kwargs_raw is not None + else None + ) + + if legacy_max_tokens_models is not None: + _warn_deprecated_alias( + "--max_out_tokens_models", + ("--max_out_tokens_A", "--max_out_tokens_B"), + ) + if legacy_max_model_len is not None: + _warn_deprecated_alias( + "--max_model_len", + ("--max_model_len_A", "--max_model_len_B", "--max_model_len_judge"), + ) + if legacy_chat_template is not None: + _warn_deprecated_alias( + "--chat_template", + ("--chat_template_A", "--chat_template_B", "--chat_template_judge"), + ) + if legacy_engine_kwargs_raw is not None: + _warn_deprecated_alias( + "--engine_kwargs", + ("--engine_kwargs_A", "--engine_kwargs_B", "--engine_kwargs_judge"), + ) + + configs: dict[str, GenerationConfig] = {} + for role in ROLE_NAMES: + explicit_max_tokens = getattr(args, f"max_out_tokens_{role}", None) + if explicit_max_tokens is not None: + max_tokens = explicit_max_tokens + elif role in ("A", "B") and legacy_max_tokens_models is not None: + max_tokens = legacy_max_tokens_models + else: + max_tokens = _LEGACY_DEFAULT_MAX_TOKENS + + explicit_max_model_len = getattr(args, f"max_model_len_{role}", None) + max_model_len = ( + explicit_max_model_len + if explicit_max_model_len is not None + else legacy_max_model_len + ) + + explicit_chat_template = getattr(args, f"chat_template_{role}", None) + chat_template = ( + explicit_chat_template + if explicit_chat_template is not None + else legacy_chat_template + ) + + explicit_engine_kwargs_raw = getattr(args, f"engine_kwargs_{role}", None) + if explicit_engine_kwargs_raw is not None: + engine_kwargs = parse_engine_kwargs(explicit_engine_kwargs_raw) + elif legacy_engine_kwargs is not None: + engine_kwargs = dict(legacy_engine_kwargs) + else: + engine_kwargs = {} + + configs[role] = GenerationConfig( + temperature=getattr(args, f"temperature_{role}", None), + top_p=getattr(args, f"top_p_{role}", None), + top_k=getattr(args, f"top_k_{role}", None), + seed=getattr(args, f"seed_{role}", None), + max_tokens=max_tokens, + max_model_len=max_model_len, + chat_template=chat_template, + engine_kwargs=engine_kwargs, + ) + return configs + + +def gen_config_to_invoke_kwargs(gen: GenerationConfig) -> dict: + """Flatten a :class:`GenerationConfig` to the kwargs accepted by ``make_model``. + + Backends ignore unknown kwargs (vLLM-only fields are stripped inside + :func:`judgearena.utils.make_model` for hosted providers). 
+ """ + kwargs: dict[str, object] = {"max_tokens": gen.max_tokens} + if gen.temperature is not None: + kwargs["temperature"] = gen.temperature + if gen.top_p is not None: + kwargs["top_p"] = gen.top_p + if gen.top_k is not None: + kwargs["top_k"] = gen.top_k + if gen.seed is not None: + kwargs["seed"] = gen.seed + if gen.max_model_len is not None: + kwargs["max_model_len"] = gen.max_model_len + if gen.chat_template is not None: + kwargs["chat_template"] = gen.chat_template + kwargs.update(gen.engine_kwargs) + return kwargs diff --git a/judgearena/estimate_elo_ratings.py b/judgearena/estimate_elo_ratings.py index 51ba6e2..9cc2ed9 100644 --- a/judgearena/estimate_elo_ratings.py +++ b/judgearena/estimate_elo_ratings.py @@ -7,7 +7,7 @@ from sklearn.linear_model import LogisticRegression from judgearena.arenas_utils import _extract_instruction_text, load_arena_dataframe -from judgearena.cli_common import BaseCliArgs +from judgearena.cli_common import BaseCliArgs, gen_config_to_invoke_kwargs from judgearena.evaluate import judge_and_parse_prefs from judgearena.generate import generate_instructions from judgearena.log import get_logger @@ -188,34 +188,36 @@ def main(args: CliEloArgs) -> dict: # Step 2: Generate completions for the model under evaluation logger.info("Step 2: Generating completions with %s", args.model) - # Only pass extra engine kwargs that are not None - extra_kwargs = dict(args.engine_kwargs) - if args.max_model_len is not None: - extra_kwargs["max_model_len"] = args.max_model_len - if args.chat_template is not None: - extra_kwargs["chat_template"] = args.chat_template + # The ``A`` generation config drives the model under evaluation; the + # ``judge`` config drives the LLM judge that scores it. Both are + # resolved from per-role CLI flags by ``cli.resolve_generation_configs``. 
+ gen_a = args.gen_A + gen_judge = args.gen_judge + + invoke_kwargs_a = gen_config_to_invoke_kwargs(gen_a) + invoke_kwargs_a.pop("max_tokens", None) use_tqdm = False gen_fun = partial( generate_instructions, truncate_input_chars=args.truncate_all_input_chars, - max_tokens=args.max_out_tokens_models, + max_tokens=gen_a.max_tokens, use_tqdm=use_tqdm, - **extra_kwargs, + **invoke_kwargs_a, ) def replace_slash(s: str) -> str: return s.replace("/", "_") languages_str = "-".join(sorted(args.languages)) if args.languages else "all" - extra_kwargs_str = ( - "_".join(f"{k}={v}" for k, v in sorted(extra_kwargs.items())) - if extra_kwargs - else "" + extra_kwargs_str = "_".join( + f"{k}={v}" + for k, v in sorted(gen_config_to_invoke_kwargs(gen_a).items()) + if k != "max_tokens" ) cache_suffix = ( f"{args.arena}_{replace_slash(args.model)}_" f"{args.n_instructions}_{args.n_instructions_per_language}_" - f"{languages_str}_{args.truncate_all_input_chars}_{args.max_out_tokens_models}" + f"{languages_str}_{args.truncate_all_input_chars}_{gen_a.max_tokens}" + (f"_{extra_kwargs_str}" if extra_kwargs_str else "") ) if len(cache_suffix) > 100: @@ -268,17 +270,12 @@ def replace_slash(s: str) -> str: for i in range(n) ] - judge_extra_kwargs = {} - if args.max_model_len is not None: - judge_extra_kwargs["max_model_len"] = args.max_model_len - if args.chat_template is not None: - judge_extra_kwargs["chat_template"] = args.chat_template + judge_invoke_kwargs = gen_config_to_invoke_kwargs(gen_judge) def run_judge() -> pd.DataFrame: judge_chat_model = make_model( model=args.judge_model, - max_tokens=args.max_out_tokens_judge, - **judge_extra_kwargs, + **judge_invoke_kwargs, ) annotations, _, prefs = judge_and_parse_prefs( judge_chat_model=judge_chat_model, diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 2919280..71a2e73 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -11,7 +11,11 @@ import pandas as pd -from judgearena.cli_common import BaseCliArgs +from judgearena.cli_common import ( + BaseCliArgs, + GenerationConfig, + gen_config_to_invoke_kwargs, +) from judgearena.evaluate import judge_and_parse_prefs, resolve_judge_prompts from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions @@ -183,28 +187,37 @@ def main(args: CliArgs): args.model_B, ) - # TODO currently we just support base models for fluency, we could also support instruction-tuned models - gen_fun = ( - partial( - generate_base, - truncate_input_chars=args.truncate_all_input_chars, - max_tokens=args.max_out_tokens_models, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - use_tqdm=args.use_tqdm, - **args.engine_kwargs, - ) - if is_fluency_task - else partial( + # Per-role generation configs (resolved by the CLI dispatcher). Each + # config carries every knob that can affect the generated text - + # temperature, top_p, top_k, seed, max_tokens, max_model_len, + # chat_template, engine_kwargs - so model A, model B and the judge + # can be configured independently. + gen_a: GenerationConfig = args.gen_A + gen_b: GenerationConfig = args.gen_B + gen_judge: GenerationConfig = args.gen_judge + + def _build_gen_fn(gen: GenerationConfig): + invoke_kwargs = gen_config_to_invoke_kwargs(gen) + # ``max_tokens`` is passed as a dedicated arg below. 
+ invoke_kwargs.pop("max_tokens", None) + if is_fluency_task: + # TODO currently we just support base models for fluency, we + # could also support instruction-tuned models. + return partial( + generate_base, + truncate_input_chars=args.truncate_all_input_chars, + max_tokens=gen.max_tokens, + use_tqdm=args.use_tqdm, + **invoke_kwargs, + ) + return partial( generate_instructions, truncate_input_chars=args.truncate_all_input_chars, - max_tokens=args.max_out_tokens_models, - max_model_len=args.max_model_len, - chat_template=args.chat_template, + max_tokens=gen.max_tokens, use_tqdm=args.use_tqdm, - **args.engine_kwargs, + **invoke_kwargs, ) - ) + dataset_completions_A = try_load_dataset_completions( args.task, args.model_A, n_instructions ) @@ -213,8 +226,9 @@ def main(args: CliArgs): :, "completion" ] else: + gen_fn_a = _build_gen_fn(gen_a) completions_A = cache_function_dataframe( - lambda: gen_fun( + lambda: gen_fn_a( instructions=instructions, model=args.model_A, use_tqdm=args.use_tqdm, @@ -232,8 +246,9 @@ def main(args: CliArgs): :, "completion" ] else: + gen_fn_b = _build_gen_fn(gen_b) completions_B = cache_function_dataframe( - lambda: gen_fun( + lambda: gen_fn_b( instructions=instructions, model=args.model_B, use_tqdm=args.use_tqdm, @@ -249,10 +264,7 @@ def main(args: CliArgs): judge_chat_model = make_model( model=args.judge_model, - max_tokens=args.max_out_tokens_judge, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - **args.engine_kwargs, + **gen_config_to_invoke_kwargs(gen_judge), ) # save argument for results analysis diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index b28f859..9d173bb 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -15,6 +15,7 @@ import pandas as pd +from judgearena.cli_common import GenerationConfig, gen_config_to_invoke_kwargs from judgearena.eval_utils import _compute_grouped_stats, print_results from judgearena.generate import generate_multiturn from judgearena.instruction_dataset import load_instructions @@ -39,26 +40,34 @@ def _generate_mt_bench_completions( ) -> tuple[pd.DataFrame, pd.DataFrame]: cache_prefix = "mt-bench" - def _run_generation(model_name: str) -> pd.DataFrame: + def _run_generation(model_name: str, gen: GenerationConfig) -> pd.DataFrame: + invoke_kwargs = gen_config_to_invoke_kwargs(gen) + invoke_kwargs.pop("max_tokens", None) + # MT-Bench's category-aware temperatures only kick in when the user + # has *not* explicitly pinned a per-role temperature; otherwise the + # CLI override should win for reproducibility. 
+ if gen.temperature is None: + temperature_config = FASTCHAT_TEMPERATURE_CONFIG + else: + temperature_config = None return generate_multiturn( questions=questions_df, model=model_name, truncate_input_chars=args.truncate_all_input_chars, - max_tokens=args.max_out_tokens_models, + max_tokens=gen.max_tokens, use_tqdm=args.use_tqdm, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - temperature_config=FASTCHAT_TEMPERATURE_CONFIG, + temperature_config=temperature_config, + **invoke_kwargs, ) completions_a = cache_function_dataframe( - lambda: _run_generation(args.model_A), + lambda: _run_generation(args.model_A, args.gen_A), ignore_cache=ignore_cache, cache_name=f"{cache_prefix}_{args.model_A}_{args.n_instructions}", ).set_index("instruction_index") completions_b = cache_function_dataframe( - lambda: _run_generation(args.model_B), + lambda: _run_generation(args.model_B, args.gen_B), ignore_cache=ignore_cache, cache_name=f"{cache_prefix}_{args.model_B}_{args.n_instructions}", ).set_index("instruction_index") @@ -160,12 +169,14 @@ def run_mt_bench( questions_df=questions_df, ignore_cache=ignore_cache, ) + # MT-Bench historically forced judge temperature to 0 so its + # category-driven verdicts stay deterministic; we honour any explicit + # CLI override but fall back to 0 otherwise. + judge_invoke_kwargs = gen_config_to_invoke_kwargs(args.gen_judge) + judge_invoke_kwargs.setdefault("temperature", 0.0) judge_chat_model = make_model( model=args.judge_model, - max_tokens=args.max_out_tokens_judge, - temperature=0.0, - max_model_len=args.max_model_len, - chat_template=args.chat_template, + **judge_invoke_kwargs, ) return _run_mt_bench_fastchat( args=args, diff --git a/judgearena/utils.py b/judgearena/utils.py index 993ef01..a4006ee 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -127,7 +127,43 @@ def safe_text(value: object, truncate_chars: int | None) -> str: return truncate(str(value), max_len=truncate_chars) -def do_inference(chat_model, inputs, use_tqdm: bool = False): +def _extract_response_metadata(message) -> dict | None: + """Pull provider response metadata (e.g. ``system_fingerprint``) off an AIMessage. + + ChatVLLM and DummyModel return plain strings (no metadata). LangChain + chat models return ``AIMessage`` whose ``response_metadata`` attribute + contains the upstream provider's bookkeeping; we keep the keys that + matter for reproducibility. + """ + response_metadata = getattr(message, "response_metadata", None) + if not isinstance(response_metadata, dict) or not response_metadata: + return None + keys_of_interest = ( + "system_fingerprint", + "model_name", + "model", + "id", + "finish_reason", + ) + captured = { + k: response_metadata[k] for k in keys_of_interest if k in response_metadata + } + return captured or None + + +def do_inference( + chat_model, + inputs, + use_tqdm: bool = False, + out_metadata: list[dict | None] | None = None, +): + """Run a batch of LLM calls, with retries and optional response-metadata capture. + + When ``out_metadata`` is provided, one entry per input is appended; + each entry is either a dict of provider fingerprint fields (model, + ``system_fingerprint``, etc.) or ``None`` for backends that don't + expose any. + """ # Retries on rate-limit/server errors with exponential backoff. # Async path retries individual calls; batch path splits into 4^attempt chunks on failure. 
invoke_kwargs = { @@ -199,6 +235,10 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): res = batch_with_retry(inputs) + if out_metadata is not None: + for raw in res: + out_metadata.append(_extract_response_metadata(raw)) + # Not sure why the API of Langchain returns sometime a string and sometimes an AIMessage object # is it because of using Chat and barebones models? # when using OpenAI, the output is AIMessage not a string... @@ -207,8 +247,17 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): class DummyModel: - def __init__(self, name: str): + """In-process stub backend used in tests and offline smoke runs. + + The constructor accepts the same keyword arguments that hosted + backends do (``temperature``, ``top_p``, ``seed``, ...) and stores + them under :attr:`init_kwargs` so tests can assert that the + per-role generation configs reach the model layer. + """ + + def __init__(self, name: str, **init_kwargs): self.name = name + self.init_kwargs = dict(init_kwargs) self.message = "/".join(name.split("/")[1:]) def batch(self, inputs, **invoke_kwargs) -> list[str]: @@ -241,6 +290,10 @@ def __init__( model: str, max_tokens: int = 8192, chat_template: str | None = None, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, + seed: int | None = None, **vllm_kwargs, ): from vllm import LLM, SamplingParams @@ -273,12 +326,24 @@ def __init__( stacklevel=2, ) + if seed is not None: + # vLLM honours the seed argument both at engine-init time (for + # tensor parallel determinism) and at sampling time. + vllm_kwargs.setdefault("seed", seed) + self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) - self.sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=0.6, - top_p=0.95, - ) + # Keep historical defaults when the caller did not specify a value; + # forward explicit values straight through so they are reproducible. + self._effective_sampling_kwargs: dict = { + "max_tokens": max_tokens, + "temperature": 0.6 if temperature is None else float(temperature), + "top_p": 0.95 if top_p is None else float(top_p), + } + if top_k is not None: + self._effective_sampling_kwargs["top_k"] = int(top_k) + if seed is not None: + self._effective_sampling_kwargs["seed"] = int(seed) + self.sampling_params = SamplingParams(**self._effective_sampling_kwargs) # Resolve chat template: # 1. Explicit override always wins → use chat() with that template @@ -388,6 +453,20 @@ async def ainvoke(self, input_item, **invoke_kwargs): None, lambda: self.invoke(input_item, **invoke_kwargs) ) + def set_temperature(self, temperature: float) -> None: + """Mutate the active SamplingParams to use ``temperature``. + + Used by MT-Bench's category-aware temperature switching so we + don't have to reload the vLLM engine between categories. + """ + from vllm import SamplingParams + + self._effective_sampling_kwargs["temperature"] = float(temperature) + self.sampling_params = SamplingParams(**self._effective_sampling_kwargs) + + +_VLLM_ONLY_KWARGS = ("max_model_len", "chat_template") + def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): """Instantiate a model wrapper from a provider/model-name string. @@ -397,6 +476,9 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): ``VLLM/meta-llama/Llama-3.3-70B-Instruct``. max_tokens: Maximum tokens the model may generate. **engine_kwargs: Engine-specific options forwarded to the model wrapper. 
+ Common keys honoured across backends: ``temperature``, ``top_p``, + ``top_k``, ``seed``. vLLM-only keys (``max_model_len``, + ``chat_template``) are stripped before reaching hosted providers. """ # Avoid mutating the original engine_kwargs dictionary # NOTE: this is a shallow copy since we are not modifying any @@ -406,17 +488,34 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): # Dedicated arguments like max_tokens always win over engine_kwargs. engine_kwargs["max_tokens"] = max_tokens or 8192 + # Pluck out cross-backend sampling controls so we can route them + # through the provider-appropriate constructor argument. + temperature = engine_kwargs.pop("temperature", None) + top_p = engine_kwargs.pop("top_p", None) + top_k = engine_kwargs.pop("top_k", None) + seed = engine_kwargs.pop("seed", None) + model_provider = model.split("/")[0] # vLLM-engine-only kwargs must not leak to remote-API providers # (OpenRouter, OpenAI, Together): langchain-openai forwards unknown # kwargs via model_kwargs into chat.completions.create, which rejects them. if model_provider != "VLLM": - engine_kwargs.pop("max_model_len", None) - engine_kwargs.pop("chat_template", None) + for key in _VLLM_ONLY_KWARGS: + engine_kwargs.pop(key, None) if model_provider == "Dummy": - return DummyModel(model) + # Forward sampling kwargs so tests can assert they reached the model. + dummy_kwargs = {k: v for k, v in engine_kwargs.items() if v is not None} + if temperature is not None: + dummy_kwargs["temperature"] = temperature + if top_p is not None: + dummy_kwargs["top_p"] = top_p + if top_k is not None: + dummy_kwargs["top_k"] = top_k + if seed is not None: + dummy_kwargs["seed"] = seed + return DummyModel(model, **dummy_kwargs) model_name = "/".join(model.split("/")[1:]) logger.info("Loading %s(model=%s)", model_provider, model_name) @@ -425,6 +524,14 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): if model_provider == "VLLM": engine_kwargs = {k: v for k, v in engine_kwargs.items() if v is not None} engine_kwargs["chat_template"] = engine_kwargs.get("chat_template", None) + if temperature is not None: + engine_kwargs["temperature"] = temperature + if top_p is not None: + engine_kwargs["top_p"] = top_p + if top_k is not None: + engine_kwargs["top_k"] = top_k + if seed is not None: + engine_kwargs["seed"] = seed return ChatVLLM( model=model_name, @@ -433,11 +540,23 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): if model_provider == "OpenRouter": # Special case we need to override API url and key + openai_kwargs = dict(engine_kwargs) + if temperature is not None: + openai_kwargs["temperature"] = temperature + if top_p is not None: + openai_kwargs["top_p"] = top_p + if seed is not None: + openai_kwargs["seed"] = seed + # ``top_k`` isn't a first-class OpenAI parameter; tunnel it through + # ``model_kwargs`` so providers that recognise it (vLLM/OpenRouter + # via Anthropic-compatible models) can still pick it up. 
+ if top_k is not None: + openai_kwargs.setdefault("model_kwargs", {})["top_k"] = top_k return ChatOpenAI( api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1", model=model_name, - **engine_kwargs, + **openai_kwargs, ) else: model_classes = [ @@ -446,8 +565,24 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): ] if model_provider == "LlamaCpp": engine_kwargs["model_path"] = model_name + if temperature is not None: + engine_kwargs["temperature"] = temperature + if top_p is not None: + engine_kwargs["top_p"] = top_p + if top_k is not None: + engine_kwargs["top_k"] = top_k + if seed is not None: + engine_kwargs["seed"] = seed else: engine_kwargs["model"] = model_name + if temperature is not None: + engine_kwargs["temperature"] = temperature + if top_p is not None: + engine_kwargs["top_p"] = top_p + if seed is not None: + engine_kwargs["seed"] = seed + if top_k is not None: + engine_kwargs.setdefault("model_kwargs", {})["top_k"] = top_k try: from langchain_together.llms import Together diff --git a/tests/test_cli.py b/tests/test_cli.py index 30be4fa..7b9b07e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -312,7 +312,56 @@ def test_elo_forwards_optional_flags(capture_mains): assert elo_args.baseline_model == "gpt-4o" -def test_engine_kwargs_parsed_as_json(capture_mains): +def test_engine_kwargs_fans_out_to_all_roles_as_deprecation(capture_mains): + """Deprecated --engine_kwargs fans out to every role's GenerationConfig.""" + with pytest.warns(DeprecationWarning, match="--engine_kwargs is deprecated"): + cli_module.cli( + [ + "--task", + "alpaca-eval", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + "--engine_kwargs", + '{"tensor_parallel_size": 4}', + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.gen_A.engine_kwargs == {"tensor_parallel_size": 4} + assert ge_args.gen_B.engine_kwargs == {"tensor_parallel_size": 4} + assert ge_args.gen_judge.engine_kwargs == {"tensor_parallel_size": 4} + + +def test_per_role_engine_kwargs_override_legacy_flag(capture_mains): + """A per-role --engine_kwargs_A overrides the deprecated fan-out alias.""" + with pytest.warns(DeprecationWarning): + cli_module.cli( + [ + "--task", + "alpaca-eval", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + "--engine_kwargs", + '{"shared": 1}', + "--engine_kwargs_A", + '{"only_A": true}', + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.gen_A.engine_kwargs == {"only_A": True} + assert ge_args.gen_B.engine_kwargs == {"shared": 1} + assert ge_args.gen_judge.engine_kwargs == {"shared": 1} + + +def test_per_role_temperature_and_seed_flags(capture_mains): + """Per-role sampling flags land on the right GenerationConfig.""" cli_module.cli( [ "--task", @@ -323,9 +372,48 @@ def test_engine_kwargs_parsed_as_json(capture_mains): "Dummy/B", "--judge", "Dummy/J", - "--engine_kwargs", - '{"tensor_parallel_size": 4}', + "--temperature_A", + "0.0", + "--temperature_B", + "0.7", + "--temperature_judge", + "0.0", + "--seed_A", + "42", + "--seed_judge", + "13", ] ) ge_args: CliArgs = capture_mains["args"] - assert ge_args.engine_kwargs == {"tensor_parallel_size": 4} + assert ge_args.gen_A.temperature == 0.0 + assert ge_args.gen_B.temperature == 0.7 + assert ge_args.gen_judge.temperature == 0.0 + assert ge_args.gen_A.seed == 42 + assert ge_args.gen_judge.seed == 13 + assert ge_args.gen_B.seed is None + + +def test_max_out_tokens_models_fans_out_to_a_and_b(capture_mains): + 
"""Deprecated --max_out_tokens_models populates A and B but not judge.""" + with pytest.warns( + DeprecationWarning, match="--max_out_tokens_models is deprecated" + ): + cli_module.cli( + [ + "--task", + "alpaca-eval", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + "--max_out_tokens_models", + "256", + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.gen_A.max_tokens == 256 + assert ge_args.gen_B.max_tokens == 256 + # judge keeps the historical default when --max_out_tokens_judge is unset + assert ge_args.gen_judge.max_tokens == 32768 diff --git a/tests/test_seed_plumbing.py b/tests/test_seed_plumbing.py new file mode 100644 index 0000000..c0e69f8 --- /dev/null +++ b/tests/test_seed_plumbing.py @@ -0,0 +1,96 @@ +"""Verify per-role sampling parameters reach the underlying backend. + +Uses the ``Dummy`` provider (which records every kwarg the constructor sees +on :attr:`DummyModel.init_kwargs`) so we can assert that ``--seed_A``, +``--temperature_A`` etc. are actually forwarded to the model layer rather +than dropped on the floor by the CLI dispatcher. +""" + +from __future__ import annotations + +from judgearena.cli_common import GenerationConfig, gen_config_to_invoke_kwargs +from judgearena.utils import make_model + + +def test_make_model_dummy_captures_temperature_and_seed(): + """Dummy backend records the constructor kwargs so tests can verify them.""" + model = make_model("Dummy/foo", max_tokens=64, temperature=0.3, seed=42) + assert model.init_kwargs.get("temperature") == 0.3 + assert model.init_kwargs.get("seed") == 42 + + +def test_make_model_dummy_forwards_top_p_and_top_k(): + model = make_model("Dummy/foo", max_tokens=64, top_p=0.95, top_k=50) + assert model.init_kwargs.get("top_p") == 0.95 + assert model.init_kwargs.get("top_k") == 50 + + +def test_make_model_dummy_forwards_engine_kwargs(): + """Arbitrary engine-specific kwargs flow through to DummyModel.init_kwargs.""" + model = make_model("Dummy/foo", max_tokens=64, my_extra_flag="hello") + assert model.init_kwargs.get("my_extra_flag") == "hello" + + +def test_gen_config_to_invoke_kwargs_skips_none(): + """Unset fields (None) on GenerationConfig are not forwarded.""" + cfg = GenerationConfig(temperature=None, seed=None, top_p=None, top_k=None) + kwargs = gen_config_to_invoke_kwargs(cfg) + # Only the always-present max_tokens plus whatever was explicitly set. 
+ assert kwargs == {"max_tokens": cfg.max_tokens} + + +def test_gen_config_to_invoke_kwargs_includes_set_fields(): + cfg = GenerationConfig( + temperature=0.0, + top_p=0.95, + top_k=50, + seed=7, + max_tokens=128, + max_model_len=8192, + chat_template="", + engine_kwargs={"tensor_parallel_size": 2}, + ) + kwargs = gen_config_to_invoke_kwargs(cfg) + assert kwargs["temperature"] == 0.0 + assert kwargs["top_p"] == 0.95 + assert kwargs["top_k"] == 50 + assert kwargs["seed"] == 7 + assert kwargs["max_tokens"] == 128 + assert kwargs["max_model_len"] == 8192 + assert kwargs["chat_template"] == "" + assert kwargs["tensor_parallel_size"] == 2 + + +def test_per_role_seed_only_reaches_correct_model(monkeypatch): + """A per-role --seed_A flag must not bleed into model B / judge.""" + from judgearena import cli as cli_module + + captured_args: dict = {} + + def fake_main_ge(args) -> None: + captured_args["args"] = args + + monkeypatch.setattr(cli_module, "main_generate_and_evaluate", fake_main_ge) + + cli_module.cli( + [ + "--task", + "alpaca-eval", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + "--seed_A", + "11", + "--temperature_judge", + "0.0", + ] + ) + args = captured_args["args"] + assert args.gen_A.seed == 11 + assert args.gen_B.seed is None + assert args.gen_judge.seed is None + assert args.gen_judge.temperature == 0.0 + assert args.gen_A.temperature is None