diff --git a/judgearena/cli.py b/judgearena/cli.py
index eb94c83..d226c69 100644
--- a/judgearena/cli.py
+++ b/judgearena/cli.py
@@ -193,6 +193,9 @@ def _build_elo_args(
         baseline_model=args.baseline_model,
         judge_model=args.judge_model,
         n_instructions=args.n_instructions,
+        judge_prompt_preset=args.judge_prompt_preset,
+        judge_system_prompt_file=args.judge_system_prompt_file,
+        judge_user_prompt_file=args.judge_user_prompt_file,
         provide_explanation=args.provide_explanation,
         swap_mode=args.swap_mode,
         ignore_cache=args.ignore_cache,
@@ -221,6 +224,9 @@ def _build_generate_and_evaluate_args(
         use_tqdm=args.use_tqdm,
         judge_model=args.judge_model,
         n_instructions=args.n_instructions,
+        judge_prompt_preset=args.judge_prompt_preset,
+        judge_system_prompt_file=args.judge_system_prompt_file,
+        judge_user_prompt_file=args.judge_user_prompt_file,
         provide_explanation=args.provide_explanation,
         swap_mode=args.swap_mode,
         ignore_cache=args.ignore_cache,
diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py
index 58ce78b..38ebf7c 100644
--- a/judgearena/cli_common.py
+++ b/judgearena/cli_common.py
@@ -19,6 +19,14 @@ class BaseCliArgs:
     judge_model: str
     n_instructions: int | None = None
 
+    # Judge-prompt selection (see ``judgearena.prompts.registry``).
+    # ``judge_prompt_preset`` picks a named preset; the ``_file`` overrides
+    # take a path on disk and win over the preset. ``provide_explanation``
+    # is kept for backward compatibility and is equivalent to setting the
+    # preset to ``default_with_explanation``.
+    judge_prompt_preset: str | None = None
+    judge_system_prompt_file: str | None = None
+    judge_user_prompt_file: str | None = None
     provide_explanation: bool = False
     swap_mode: str = "fixed"
     ignore_cache: bool = False
@@ -59,10 +67,46 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None:
         type=int,
         required=False,
     )
+    parser.add_argument(
+        "--judge_prompt_preset",
+        type=str,
+        required=False,
+        default=None,
+        help=(
+            "Name of a judge-prompt preset registered in "
+            "``judgearena.prompts.registry`` (e.g. ``default``, "
+            "``default_with_explanation``, ``fluency``, "
+            "``fastchat-pairwise``). When omitted, the per-task default "
+            "is used."
+        ),
+    )
+    parser.add_argument(
+        "--judge_system_prompt_file",
+        type=str,
+        required=False,
+        default=None,
+        help=(
+            "Path to a custom judge system prompt; takes precedence over "
+            "--judge_prompt_preset. Must be combined with "
+            "--judge_user_prompt_file."
+        ),
+    )
+    parser.add_argument(
+        "--judge_user_prompt_file",
+        type=str,
+        required=False,
+        default=None,
+        help=(
+            "Path to a custom judge user-prompt template; takes precedence "
+            "over --judge_prompt_preset. Must be combined with "
+            "--judge_system_prompt_file."
+        ),
+    )
     parser.add_argument(
         "--provide_explanation",
         action="store_true",
         help=(
+            "Equivalent to --judge_prompt_preset default_with_explanation. "
             "If specified, judge will provide explanation before making a "
             "judgement. Does not necessarily improve the accuracy of the judge "
             "but enables some result interpretation."
diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py
index 7eb8599..1c5ba60 100644
--- a/judgearena/evaluate.py
+++ b/judgearena/evaluate.py
@@ -15,6 +15,10 @@
     is_arena_hard_dataset,
 )
 from judgearena.log import get_logger
+from judgearena.prompts.registry import (
+    ResolvedJudgePrompt,
+    resolve_judge_prompt,
+)
 from judgearena.repro import _to_jsonable, write_run_metadata
 from judgearena.utils import (
     compute_pref_summary,
@@ -59,57 +63,79 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1):
         return float(m.group(group_index).strip(" "))
 
 
-_COMPLETION_LABEL_SINGLE = "Answer"
-_COMPLETION_LABEL_MULTI_TURN = "Conversation with User"
-_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement"
-_SCORE_FENCE = "\n```"
-
-
 def load_judge_system_and_user_prompt(
     provide_explanation: bool = True,
     multi_turn: bool = False,
 ) -> tuple[str, str]:
-    prompts_dir = Path(__file__).parent / "prompts"
-    system_prompt = (prompts_dir / "system-prompt.txt").read_text()
+    """Load the bundled default judge prompts (back-compat shim).
 
-    prompt_filename = (
-        "prompt-with-explanation.txt" if provide_explanation else "prompt.txt"
-    )
-    user_prompt_template = (prompts_dir / prompt_filename).read_text()
-    user_prompt_template = user_prompt_template.replace(
-        "{completion_label}",
-        _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE,
-    )
-    user_prompt_template = user_prompt_template.replace(
-        "{explanation_suffix}",
-        _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE,
+    Prefer :func:`judgearena.prompts.registry.resolve_judge_prompt` for new
+    code; this function delegates to the registry but returns the
+    ``(system, user_template)`` tuple expected by older callers.
+    """
+    resolved = resolve_judge_prompt(
+        preset=("default_with_explanation" if provide_explanation else "default"),
+        multi_turn=multi_turn,
     )
-
-    return system_prompt, user_prompt_template
+    return resolved.system_text, resolved.user_template_text
 
 
 def resolve_judge_prompts(
     *,
-    provide_explanation: bool,
+    provide_explanation: bool = False,
     multi_turn: bool = False,
     system_prompt: str | None = None,
     user_prompt_template: str | None = None,
+    task: str | None = None,
+    preset: str | None = None,
+    system_file: str | None = None,
+    user_file: str | None = None,
 ) -> tuple[str, str]:
-    default_system_prompt, default_user_prompt_template = (
-        load_judge_system_and_user_prompt(
-            provide_explanation=provide_explanation, multi_turn=multi_turn
-        )
+    """Resolve the judge ``(system_prompt, user_prompt_template)`` for a run.
+
+    Direct ``system_prompt`` / ``user_prompt_template`` overrides win.
+    Otherwise the registry is consulted with ``task`` / ``preset`` /
+    ``system_file`` / ``user_file``. Legacy callers that pass nothing
+    end up with the ``default`` preset (or ``default_with_explanation``
+    when ``provide_explanation=True``) for backward compatibility.
+ """ + if system_prompt is not None and user_prompt_template is not None: + return system_prompt, user_prompt_template + + resolved = resolve_judge_prompt( + task=task, + preset=preset, + system_file=system_file, + user_file=user_file, + multi_turn=multi_turn, + provide_explanation=provide_explanation, ) return ( - system_prompt if system_prompt is not None else default_system_prompt, + system_prompt if system_prompt is not None else resolved.system_text, ( user_prompt_template if user_prompt_template is not None - else default_user_prompt_template + else resolved.user_template_text ), ) +def resolve_run_judge_prompt(task: str, cli_args) -> ResolvedJudgePrompt: + """Resolve the judge prompt for a run from the CLI args dataclass. + + Accepts a :class:`judgearena.cli_common.BaseCliArgs` instance (or any + object exposing the same attributes) and returns the full + :class:`ResolvedJudgePrompt`, including hashes/paths for metadata. + """ + return resolve_judge_prompt( + task=task, + preset=getattr(cli_args, "judge_prompt_preset", None), + system_file=getattr(cli_args, "judge_system_prompt_file", None), + user_file=getattr(cli_args, "judge_user_prompt_file", None), + provide_explanation=getattr(cli_args, "provide_explanation", False), + ) + + def evaluate_completions( dataset: str = "alpaca-eval", judge_chat_model: LLM = None, diff --git a/judgearena/prompts/__init__.py b/judgearena/prompts/__init__.py index 18afe87..0c8f05a 100644 --- a/judgearena/prompts/__init__.py +++ b/judgearena/prompts/__init__.py @@ -1 +1,22 @@ -"""Prompt templates bundled with JudgeArena.""" +"""Prompt templates bundled with JudgeArena. + +The :mod:`judgearena.prompts.registry` submodule exposes the named presets +used by the judge plus a per-task default mapping; see ``--judge_prompt_preset`` +on the CLI. +""" + +from judgearena.prompts.registry import ( + PRESETS, + TASK_DEFAULT_PRESET, + ResolvedJudgePrompt, + default_preset_for_task, + resolve_judge_prompt, +) + +__all__ = [ + "PRESETS", + "TASK_DEFAULT_PRESET", + "ResolvedJudgePrompt", + "default_preset_for_task", + "resolve_judge_prompt", +] diff --git a/judgearena/prompts/registry.py b/judgearena/prompts/registry.py new file mode 100644 index 0000000..c98ef5d --- /dev/null +++ b/judgearena/prompts/registry.py @@ -0,0 +1,216 @@ +"""Judge prompt presets and per-task default mapping. + +JudgeArena ships a small registry of named prompt presets so that every +benchmark gets a sensible default that is *also* recorded by hash in the +run metadata. Users can either pick a preset by name with +``--judge_prompt_preset NAME`` or supply a custom ``(system, user)`` pair +with ``--judge_system_prompt_file`` / ``--judge_user_prompt_file``. + +The MT-Bench pipeline keeps its own category-aware prompt selection (see +:mod:`judgearena.mt_bench.fastchat_compat`); the registry just records the +preset name ``fastchat-pairwise`` so the metadata bundle still answers +"which judge prompt was used here?". +""" + +from __future__ import annotations + +from dataclasses import dataclass +from importlib.resources import files +from pathlib import Path + +PROMPTS_PACKAGE = "judgearena.prompts" + +FLUENCY_SYSTEM = ( + "You are a highly efficient assistant, who evaluates and selects the best " + "large language model based on the quality of completion of a sentence. " + "You will see a sentence to be completed and two completions from " + "Assistant A and Assistant B and will have to decide which one is best. 
" + "Make sure to not over-confidently prefer one assistant or the other and " + "also make sure to not bias your preference based on the ordering or on " + "the length of the answers." +) + + +@dataclass(frozen=True) +class _PresetSpec: + """Internal description of a prompt preset. + + Either ``system_file`` or ``inline_system`` is set (file wins when both). + ``user_file`` is required for non-delegated presets. ``delegated`` + presets do not produce ``(system, user)`` strings; callers must use + their own prompt-selection machinery. + """ + + system_file: str | None = None + inline_system: str | None = None + user_file: str | None = None + delegated: bool = False + + +PRESETS: dict[str, _PresetSpec] = { + "default": _PresetSpec(system_file="system-prompt.txt", user_file="prompt.txt"), + "default_with_explanation": _PresetSpec( + system_file="system-prompt.txt", user_file="prompt-with-explanation.txt" + ), + "fluency": _PresetSpec(inline_system=FLUENCY_SYSTEM, user_file="prompt.txt"), + "fastchat-pairwise": _PresetSpec(delegated=True), +} + + +# Per-task default preset. Tasks not listed fall back through prefix rules +# in ``default_preset_for_task`` (m-arena-hard*, fluency-*) and ultimately +# to ``default``. +TASK_DEFAULT_PRESET: dict[str, str] = { + "alpaca-eval": "default", + "arena-hard-v0.1": "default", + "arena-hard-v2.0": "default", + "m-arena-hard": "default", + "mt-bench": "fastchat-pairwise", +} + + +@dataclass(frozen=True) +class ResolvedJudgePrompt: + """The judge prompt that will actually be used for a run. + + ``delegated=True`` signals that the calling pipeline (currently only + MT-Bench) selects its templates per item; ``system_text`` and + ``user_template_text`` are then unused but the ``name`` is still + recorded in the run metadata. + """ + + name: str + system_text: str + user_template_text: str + delegated: bool = False + source: str = "preset" # "preset" or "file" + system_path: str | None = None + user_path: str | None = None + + +_COMPLETION_LABEL_SINGLE = "Answer" +_COMPLETION_LABEL_MULTI_TURN = "Conversation with User" +_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" +_SCORE_FENCE = "\n```" + + +def default_preset_for_task(task: str | None) -> str: + """Return the preset name that should be used for ``task`` by default.""" + if not task: + return "default" + if task in TASK_DEFAULT_PRESET: + return TASK_DEFAULT_PRESET[task] + if task.startswith("m-arena-hard"): + return "default" + if task.startswith("fluency-") or task.startswith("fluency"): + return "fluency" + return "default" + + +def _load_packaged_text(filename: str) -> str: + return files(PROMPTS_PACKAGE).joinpath(filename).read_text(encoding="utf-8") + + +def _materialize_user_template( + text: str, *, multi_turn: bool, with_explanation: bool +) -> str: + """Apply the ``{completion_label}`` and ``{explanation_suffix}`` substitutions. + + These placeholders exist in the bundled ``prompt.txt`` but not in + ``prompt-with-explanation.txt``; both spellings are handled idempotently + so callers don't need to know which preset they got. 
+ """ + text = text.replace( + "{completion_label}", + _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, + ) + text = text.replace( + "{explanation_suffix}", + _EXPLANATION_SUFFIX if with_explanation else _SCORE_FENCE, + ) + return text + + +def resolve_judge_prompt( + *, + task: str | None = None, + preset: str | None = None, + system_file: str | Path | None = None, + user_file: str | Path | None = None, + multi_turn: bool = False, + provide_explanation: bool = False, +) -> ResolvedJudgePrompt: + """Resolve the prompt that should be used for this run. + + Resolution order: + + 1. If both ``system_file`` and ``user_file`` are given, they win. + 2. Else if ``preset`` is given, it is used. + 3. Else if ``provide_explanation=True``, ``default_with_explanation`` is used + (legacy alias kept for backward compatibility). + 4. Else the per-task default preset is selected (see + :data:`TASK_DEFAULT_PRESET`). + """ + if (system_file is None) != (user_file is None): + raise ValueError( + "Both --judge_system_prompt_file and --judge_user_prompt_file " + "must be provided together." + ) + + if system_file is not None: + sys_path = Path(system_file) + usr_path = Path(user_file) # type: ignore[arg-type] + sys_text = sys_path.read_text(encoding="utf-8") + usr_text = _materialize_user_template( + usr_path.read_text(encoding="utf-8"), + multi_turn=multi_turn, + with_explanation=provide_explanation, + ) + return ResolvedJudgePrompt( + name=f"file:{sys_path.name}+{usr_path.name}", + system_text=sys_text, + user_template_text=usr_text, + source="file", + system_path=str(sys_path), + user_path=str(usr_path), + ) + + if preset is None: + if provide_explanation: + preset = "default_with_explanation" + else: + preset = default_preset_for_task(task) + + if preset not in PRESETS: + raise KeyError( + f"Unknown judge prompt preset {preset!r}. 
+            f"Available: {sorted(PRESETS)}"
+        )
+    spec = PRESETS[preset]
+    if spec.delegated:
+        return ResolvedJudgePrompt(
+            name=preset,
+            system_text="",
+            user_template_text="",
+            delegated=True,
+            source="preset",
+        )
+
+    sys_text = (
+        spec.inline_system
+        if spec.inline_system is not None
+        else _load_packaged_text(spec.system_file)  # type: ignore[arg-type]
+    )
+    user_text = _load_packaged_text(spec.user_file)  # type: ignore[arg-type]
+    user_text = _materialize_user_template(
+        user_text,
+        multi_turn=multi_turn,
+        with_explanation=(preset == "default_with_explanation"),
+    )
+    return ResolvedJudgePrompt(
+        name=preset,
+        system_text=sys_text,
+        user_template_text=user_text,
+        source="preset",
+        system_path=spec.system_file,
+        user_path=spec.user_file,
+    )
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 30be4fa..a5b336c 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -329,3 +329,24 @@ def test_engine_kwargs_parsed_as_json(capture_mains):
     )
     ge_args: CliArgs = capture_mains["args"]
     assert ge_args.engine_kwargs == {"tensor_parallel_size": 4}
+
+
+def test_judge_prompt_preset_flag_is_forwarded(capture_mains):
+    cli_module.cli(
+        [
+            "--task",
+            "alpaca-eval",
+            "--model_A",
+            "Dummy/A",
+            "--model_B",
+            "Dummy/B",
+            "--judge",
+            "Dummy/J",
+            "--judge_prompt_preset",
+            "default_with_explanation",
+        ]
+    )
+    ge_args: CliArgs = capture_mains["args"]
+    assert ge_args.judge_prompt_preset == "default_with_explanation"
+    assert ge_args.judge_system_prompt_file is None
+    assert ge_args.judge_user_prompt_file is None
diff --git a/tests/test_prompt_registry.py b/tests/test_prompt_registry.py
new file mode 100644
index 0000000..559fb0c
--- /dev/null
+++ b/tests/test_prompt_registry.py
@@ -0,0 +1,133 @@
+"""Judge prompt registry resolution tests."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import pytest
+
+from judgearena.prompts.registry import (
+    PRESETS,
+    TASK_DEFAULT_PRESET,
+    default_preset_for_task,
+    resolve_judge_prompt,
+)
+
+
+def test_default_preset_for_task_known_keys():
+    for task, preset in TASK_DEFAULT_PRESET.items():
+        assert default_preset_for_task(task) == preset
+
+
+def test_default_preset_for_fluency_prefix():
+    """Any task starting with ``fluency-`` selects the fluency preset by default."""
+    assert default_preset_for_task("fluency-french") == "fluency"
+    assert default_preset_for_task("fluency-spanish") == "fluency"
+
+
+def test_default_preset_for_m_arena_hard_prefix():
+    """Per-language m-arena-hard tasks fall back to the default preset."""
+    assert default_preset_for_task("m-arena-hard-EU") == "default"
+    assert default_preset_for_task("m-arena-hard-fr") == "default"
+
+
+def test_default_preset_for_unknown_task():
+    """Anything not in TASK_DEFAULT_PRESET defaults to ``default``."""
+    assert default_preset_for_task("brand-new-bench") == "default"
+    assert default_preset_for_task(None) == "default"
+
+
+def test_alpaca_eval_picks_default_preset():
+    resolved = resolve_judge_prompt(task="alpaca-eval")
+    assert resolved.name == "default"
+    assert resolved.system_text  # bundled system prompt
+    assert "{user_prompt}" in resolved.user_template_text
+
+
+def test_mt_bench_picks_fastchat_pairwise_and_is_delegated():
+    resolved = resolve_judge_prompt(task="mt-bench")
+    assert resolved.name == "fastchat-pairwise"
+    assert resolved.delegated is True
+    # Delegated presets do not produce concrete prompt strings; the caller
+    # uses its own prompt-selection machinery (FastChat-compatible).
+    assert resolved.system_text == ""
+    assert resolved.user_template_text == ""
+
+
+def test_explicit_preset_wins_over_task_default():
+    resolved = resolve_judge_prompt(
+        task="alpaca-eval", preset="default_with_explanation"
+    )
+    assert resolved.name == "default_with_explanation"
+
+
+def test_provide_explanation_legacy_alias_picks_explanation_preset():
+    resolved = resolve_judge_prompt(task="alpaca-eval", provide_explanation=True)
+    assert resolved.name == "default_with_explanation"
+
+
+def test_unknown_preset_raises():
+    with pytest.raises(KeyError, match="Unknown judge prompt preset"):
+        resolve_judge_prompt(task="alpaca-eval", preset="does-not-exist")
+
+
+def test_file_overrides_must_come_in_pair(tmp_path):
+    sys_file = tmp_path / "sys.txt"
+    sys_file.write_text("My system prompt", encoding="utf-8")
+    with pytest.raises(ValueError, match="must be provided together"):
+        resolve_judge_prompt(task="alpaca-eval", system_file=str(sys_file))
+
+
+def test_file_overrides_take_precedence_over_preset(tmp_path):
+    sys_file = tmp_path / "sys.txt"
+    usr_file = tmp_path / "usr.txt"
+    sys_file.write_text("My system prompt", encoding="utf-8")
+    usr_file.write_text("My user prompt {completion_label}", encoding="utf-8")
+    resolved = resolve_judge_prompt(
+        task="alpaca-eval",
+        preset="default_with_explanation",  # should be ignored
+        system_file=str(sys_file),
+        user_file=str(usr_file),
+    )
+    assert resolved.system_text == "My system prompt"
+    assert "Answer" in resolved.user_template_text  # placeholder substituted
+    assert resolved.name.startswith("file:")
+    assert resolved.source == "file"
+
+
+def test_resolve_run_judge_prompt_reads_from_cli_args():
+    """``resolve_run_judge_prompt`` plucks the right knobs off a BaseCliArgs-shaped object."""
+    from judgearena.evaluate import resolve_run_judge_prompt
+
+    @dataclass
+    class FakeArgs:
+        judge_prompt_preset: str | None = None
+        judge_system_prompt_file: str | None = None
+        judge_user_prompt_file: str | None = None
+        provide_explanation: bool = False
+
+    resolved = resolve_run_judge_prompt("alpaca-eval", FakeArgs())
+    assert resolved.name == "default"
+
+    resolved_explain = resolve_run_judge_prompt(
+        "alpaca-eval",
+        FakeArgs(judge_prompt_preset="default_with_explanation"),
+    )
+    assert resolved_explain.name == "default_with_explanation"
+
+    resolved_legacy = resolve_run_judge_prompt(
+        "alpaca-eval", FakeArgs(provide_explanation=True)
+    )
+    assert resolved_legacy.name == "default_with_explanation"
+
+
+def test_every_preset_round_trips_or_is_delegated():
+    """Every entry in PRESETS resolves cleanly."""
+    for name, spec in PRESETS.items():
+        resolved = resolve_judge_prompt(preset=name)
+        assert resolved.name == name
+        if spec.delegated:
+            assert resolved.delegated is True
+        else:
+            assert resolved.system_text  # non-empty
+            assert resolved.user_template_text  # non-empty
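
Usage sketch (illustrative, not part of the patch): with the change above applied and the judgearena package importable, the resolution order the tests exercise can be driven directly from Python. The preset and task names come from the registry added in this diff; nothing else is assumed.

# sketch.py -- assumes this patch is applied to judgearena
from judgearena.prompts.registry import resolve_judge_prompt

# Per-task default: alpaca-eval resolves to the "default" preset.
resolved = resolve_judge_prompt(task="alpaca-eval")
print(resolved.name, resolved.source)  # default preset

# An explicit preset overrides the per-task default.
resolved = resolve_judge_prompt(task="alpaca-eval", preset="fluency")
print(resolved.name)  # fluency

# The legacy flag maps onto the explanation preset.
resolved = resolve_judge_prompt(task="alpaca-eval", provide_explanation=True)
print(resolved.name)  # default_with_explanation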