Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions judgearena/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from judgearena.cli_common import (
add_common_arguments,
parse_engine_kwargs,
resolve_generation_configs,
resolve_verbosity,
)
from judgearena.estimate_elo_ratings import CliEloArgs
Expand Down Expand Up @@ -172,6 +172,31 @@ def _resolve_model_a(args: argparse.Namespace) -> str | None:
return args.model_A


def _common_base_kwargs(args: argparse.Namespace) -> dict:
    """Collect the keyword arguments common to every CLI dataclass.

    Both ``CliArgs`` and ``CliEloArgs`` are constructed from the mapping
    returned here, so they receive the same per-role
    :class:`GenerationConfig` objects and identical shared fields, and no
    call site has to forward those fields one by one.

    Args:
        args: The parsed command-line namespace.

    Returns:
        A dict that can be splatted (``**``) into a CLI dataclass constructor.
    """
    # Resolve the per-role generation configs up front, before any of the
    # plain attributes are read from the namespace.
    role_configs = resolve_generation_configs(args)

    # Fields copied straight off the namespace under their own name.
    passthrough_fields = (
        "judge_model",
        "n_instructions",
        "provide_explanation",
        "swap_mode",
        "ignore_cache",
        "truncate_all_input_chars",
        "result_folder",
    )
    shared = {field: getattr(args, field) for field in passthrough_fields}

    # Verbosity is derived through a helper rather than read directly.
    shared["verbosity"] = resolve_verbosity(args)
    shared["log_file"] = args.log_file
    shared["no_log_file"] = args.no_log_file

    # One generation config per role: model A, model B and the judge.
    shared["gen_A"] = role_configs["A"]
    shared["gen_B"] = role_configs["B"]
    shared["gen_judge"] = role_configs["judge"]
    return shared


def _build_elo_args(
args: argparse.Namespace, arena: str, model_a: str | None
) -> CliEloArgs:
Expand All @@ -191,21 +216,7 @@ def _build_elo_args(
n_bootstraps=args.n_bootstraps,
seed=args.seed,
baseline_model=args.baseline_model,
judge_model=args.judge_model,
n_instructions=args.n_instructions,
provide_explanation=args.provide_explanation,
swap_mode=args.swap_mode,
ignore_cache=args.ignore_cache,
truncate_all_input_chars=args.truncate_all_input_chars,
max_out_tokens_models=args.max_out_tokens_models,
max_out_tokens_judge=args.max_out_tokens_judge,
max_model_len=args.max_model_len,
chat_template=args.chat_template,
result_folder=args.result_folder,
engine_kwargs=parse_engine_kwargs(args.engine_kwargs),
verbosity=resolve_verbosity(args),
log_file=args.log_file,
no_log_file=args.no_log_file,
**_common_base_kwargs(args),
)


Expand All @@ -219,21 +230,7 @@ def _build_generate_and_evaluate_args(
model_A=model_a,
model_B=args.model_B,
use_tqdm=args.use_tqdm,
judge_model=args.judge_model,
n_instructions=args.n_instructions,
provide_explanation=args.provide_explanation,
swap_mode=args.swap_mode,
ignore_cache=args.ignore_cache,
truncate_all_input_chars=args.truncate_all_input_chars,
max_out_tokens_models=args.max_out_tokens_models,
max_out_tokens_judge=args.max_out_tokens_judge,
max_model_len=args.max_model_len,
chat_template=args.chat_template,
result_folder=args.result_folder,
engine_kwargs=parse_engine_kwargs(args.engine_kwargs),
verbosity=resolve_verbosity(args),
log_file=args.log_file,
no_log_file=args.no_log_file,
**_common_base_kwargs(args),
)


Expand Down
Loading