diff --git a/benchmarks/swebench/run_infer.py b/benchmarks/swebench/run_infer.py index 33f802373..a5836351a 100644 --- a/benchmarks/swebench/run_infer.py +++ b/benchmarks/swebench/run_infer.py @@ -26,6 +26,7 @@ from benchmarks.utils.console_logging import summarize_instance from benchmarks.utils.constants import EVAL_AGENT_SERVER_IMAGE from benchmarks.utils.conversation import build_event_persistence_callback +from benchmarks.utils.cost_cap import build_cost_cap_callback from benchmarks.utils.critics import create_critic from benchmarks.utils.dataset import get_dataset from benchmarks.utils.evaluation import Evaluation @@ -302,14 +303,30 @@ def evaluate_instance( instance_id=instance.id, attempt=self.current_attempt, ) + callbacks: list = [persist_callback] + + # Optional per-instance cost cap (defence-in-depth against runaway + # cost on hard instances). Constructed before the Conversation so it + # is part of the composed callback chain from the first event, then + # bound to the conversation right after construction so it can call + # ``conversation.pause()`` and read accumulated cost. + cost_cap_callback = None + if self.metadata.max_cost_per_instance is not None: + cost_cap_callback = build_cost_cap_callback( + max_cost_per_instance=self.metadata.max_cost_per_instance, + instance_id=instance.id, + ) + callbacks.append(cost_cap_callback) conversation = Conversation( agent=agent, workspace=workspace, - callbacks=[persist_callback], + callbacks=callbacks, max_iteration_per_run=self.metadata.max_iterations, delete_on_close=True, ) + if cost_cap_callback is not None: + cost_cap_callback.bind(conversation) logger.info("repo_path: %s", repo_path) source_repo_path = self.get_source_repo_path(instance) @@ -441,6 +458,7 @@ def main() -> None: enable_condenser=enable_condenser, condenser_max_size=args.condenser_max_size, condenser_keep_first=args.condenser_keep_first, + max_cost_per_instance=args.max_cost_per_instance, ) # Run orchestrator with a simple JSONL writer diff --git a/benchmarks/utils/args_parser.py b/benchmarks/utils/args_parser.py index 4df17499b..c97469571 100644 --- a/benchmarks/utils/args_parser.py +++ b/benchmarks/utils/args_parser.py @@ -129,6 +129,16 @@ def get_parser(add_llm_config: bool = True) -> argparse.ArgumentParser: type=int, help="Number of initial events to always keep when condensing", ) + parser.add_argument( + "--max-cost-per-instance", + type=float, + default=None, + help=( + "Optional per-instance accumulated cost cap in USD. When set " + "(must be > 0), the conversation is paused as soon as " + "accumulated_cost exceeds this value. Defaults to no cap." + ), + ) return parser diff --git a/benchmarks/utils/cost_cap.py b/benchmarks/utils/cost_cap.py new file mode 100644 index 000000000..455c4fde5 --- /dev/null +++ b/benchmarks/utils/cost_cap.py @@ -0,0 +1,121 @@ +"""Per-instance cost cap callback for benchmarks. + +Some evaluations have shown that a small minority of instances can consume +disproportionately large amounts of money (e.g. for the Gemini-3.5-Flash +SWE-bench run, 22 instances cost >$10 each and accounted for ~20% of the +total $1900+ spend). The dominant mechanism is a combination of: + +* large iteration counts (some instances ran 300+ iterations); +* the LLM-summarising condenser, which periodically rewrites the prompt + prefix and therefore invalidates the provider's prompt cache (cache-read + ratio dropped from ~45% on cheap instances to ~10% on the expensive + ones); and +* high reasoning effort, which makes every uncached call more costly. + +This module provides a small, optional defence-in-depth: a +:class:`Conversation` callback that pauses the conversation once the +accumulated per-instance cost exceeds a configured threshold. It does not +attempt to fix the root cause (which would require SDK-level changes to the +condenser or to enforce ``Metrics.max_budget_per_task``); it simply caps +the blast radius. + +When the cap is hit, the callback calls ``conversation.pause()`` which +takes effect at the next iteration boundary, mirroring the existing +behaviour of ``max_iteration_per_run``. Any patch produced up to that point +is still collected and submitted. +""" + +from __future__ import annotations + +from typing import Callable + +from openhands.sdk import Event, get_logger + + +logger = get_logger(__name__) + +ConversationCallback = Callable[[Event], None] + + +class CostCapCallback: + """Callback that pauses a conversation once accumulated cost exceeds a cap. + + Use :meth:`bind` to attach the callback to the conversation after the + conversation has been constructed; the callback can be passed to the + :class:`Conversation` constructor *before* binding so that the + callback is part of the composed callback chain from the very first + event. + """ + + def __init__(self, max_cost_per_instance: float, instance_id: str) -> None: + """ + Args: + max_cost_per_instance: Maximum allowed accumulated USD cost for + this instance. Must be strictly positive. + instance_id: Identifier used only for log messages. + + Raises: + ValueError: If ``max_cost_per_instance`` is not strictly positive. + """ + if max_cost_per_instance <= 0: + raise ValueError( + f"max_cost_per_instance must be > 0, got {max_cost_per_instance}" + ) + self.max_cost_per_instance = max_cost_per_instance + self.instance_id = instance_id + self._conversation = None # type: ignore[assignment] + self._triggered = False + + def bind(self, conversation: object) -> None: + """Attach the conversation whose cost will be monitored.""" + self._conversation = conversation + + def __call__(self, event: Event) -> None: # noqa: ARG002 - event is unused + if self._triggered or self._conversation is None: + return + try: + cost = self._conversation.conversation_stats.get_combined_metrics().accumulated_cost + except Exception as exc: + # Metrics access should never block the run. + logger.debug( + "cost_cap: failed to read accumulated_cost for %s: %s", + self.instance_id, + exc, + ) + return + + if cost >= self.max_cost_per_instance: + self._triggered = True + logger.warning( + "cost_cap: instance %s exceeded per-instance budget " + "(accumulated_cost=$%.4f >= cap=$%.4f); pausing conversation.", + self.instance_id, + cost, + self.max_cost_per_instance, + ) + try: + self._conversation.pause() + except Exception as exc: + # Defensive: if pause itself fails (e.g. remote conversation + # in an odd state), don't take the whole instance down. + logger.warning( + "cost_cap: pause() raised for %s: %s", + self.instance_id, + exc, + ) + + +def build_cost_cap_callback( + max_cost_per_instance: float, instance_id: str +) -> CostCapCallback: + """Convenience wrapper. See :class:`CostCapCallback`. + + Returns: + An unbound :class:`CostCapCallback`. The caller must call + ``bind(conversation)`` once the conversation has been created so + the callback knows which conversation to read cost from and pause. + """ + return CostCapCallback( + max_cost_per_instance=max_cost_per_instance, + instance_id=instance_id, + ) diff --git a/benchmarks/utils/models.py b/benchmarks/utils/models.py index 9dd471383..8190b2dd8 100644 --- a/benchmarks/utils/models.py +++ b/benchmarks/utils/models.py @@ -92,6 +92,18 @@ class EvalMetadata(BaseModel): ge=0, description="Number of initial events to always keep when condensing", ) + max_cost_per_instance: float | None = Field( + default=None, + gt=0, + description=( + "Optional per-instance accumulated cost cap in USD. When set, " + "the conversation is paused as soon as accumulated_cost exceeds " + "this value, mirroring the behaviour of max_iterations. None " + "(the default) disables the cap. This is a defence-in-depth " + "measure against runaway-cost instances and does not affect " + "behaviour when the cap is not reached." + ), + ) lmnr: LaminarEvalMetadata | None = Field( default=None, description="Laminar evaluation metadata", diff --git a/tests/test_cost_cap.py b/tests/test_cost_cap.py new file mode 100644 index 000000000..36f2423e5 --- /dev/null +++ b/tests/test_cost_cap.py @@ -0,0 +1,143 @@ +"""Tests for the per-instance cost cap callback.""" + +from __future__ import annotations + +import pytest + +from benchmarks.utils.cost_cap import CostCapCallback, build_cost_cap_callback + + +class _FakeMetrics: + def __init__(self, cost: float) -> None: + self.accumulated_cost = cost + + +class _FakeStats: + def __init__(self, cost: float) -> None: + self._cost = cost + + def get_combined_metrics(self) -> _FakeMetrics: + return _FakeMetrics(self._cost) + + +class _FakeConversation: + """Minimal stand-in for ``BaseConversation`` for unit testing. + + Exposes only the surface the cost-cap callback touches: + ``conversation_stats`` and ``pause()``. + """ + + def __init__(self, cost: float) -> None: + self._cost = cost + self.paused = False + self.pause_call_count = 0 + + @property + def conversation_stats(self) -> _FakeStats: + return _FakeStats(self._cost) + + def set_cost(self, cost: float) -> None: + self._cost = cost + + def pause(self) -> None: + self.paused = True + self.pause_call_count += 1 + + +def test_rejects_non_positive_cap(): + with pytest.raises(ValueError): + build_cost_cap_callback(max_cost_per_instance=0.0, instance_id="x") + with pytest.raises(ValueError): + build_cost_cap_callback(max_cost_per_instance=-1.0, instance_id="x") + + +def test_no_pause_below_cap(): + convo = _FakeConversation(cost=2.5) + cb = build_cost_cap_callback(max_cost_per_instance=10.0, instance_id="inst") + cb.bind(convo) + cb(event=object()) # event is unused + assert convo.paused is False + + +def test_pauses_when_cap_reached(): + convo = _FakeConversation(cost=10.0) + cb = build_cost_cap_callback(max_cost_per_instance=10.0, instance_id="inst") + cb.bind(convo) + cb(event=object()) + assert convo.paused is True + assert convo.pause_call_count == 1 + + +def test_pauses_when_cap_exceeded(): + convo = _FakeConversation(cost=0.0) + cb = build_cost_cap_callback(max_cost_per_instance=5.0, instance_id="inst") + cb.bind(convo) + cb(event=object()) + assert convo.paused is False + + convo.set_cost(7.5) + cb(event=object()) + assert convo.paused is True + + +def test_idempotent_once_triggered(): + """Once paused, the callback must not call pause() again on subsequent events.""" + convo = _FakeConversation(cost=20.0) + cb = build_cost_cap_callback(max_cost_per_instance=5.0, instance_id="inst") + cb.bind(convo) + for _ in range(5): + cb(event=object()) + assert convo.pause_call_count == 1 + + +def test_no_op_when_not_bound(): + """Calling the callback before bind() should be a safe no-op.""" + cb = build_cost_cap_callback(max_cost_per_instance=1.0, instance_id="inst") + # Should not raise. + cb(event=object()) + # Still works after binding. + convo = _FakeConversation(cost=2.0) + cb.bind(convo) + cb(event=object()) + assert convo.paused is True + + +def test_metrics_failure_does_not_crash(): + """If reading metrics raises, the callback must swallow the error.""" + + class _BrokenConversation(_FakeConversation): + @property + def conversation_stats(self): + raise RuntimeError("metrics unavailable") + + convo = _BrokenConversation(cost=0.0) + cb = build_cost_cap_callback(max_cost_per_instance=1.0, instance_id="inst") + cb.bind(convo) + cb(event=object()) # must not raise + assert convo.paused is False + + +def test_pause_failure_does_not_crash(): + """If pause() raises, the callback must swallow the error and stay triggered.""" + + class _BrokenPauseConversation(_FakeConversation): + def pause(self) -> None: + self.pause_call_count += 1 + raise RuntimeError("cannot pause") + + convo = _BrokenPauseConversation(cost=10.0) + cb = build_cost_cap_callback(max_cost_per_instance=1.0, instance_id="inst") + cb.bind(convo) + cb(event=object()) # must not raise + # Once it tried to pause once, it stays triggered and won't try again. + cb(event=object()) + assert convo.pause_call_count == 1 + + +def test_callback_class_directly(): + """CostCapCallback can be constructed and used directly.""" + convo = _FakeConversation(cost=15.0) + cb = CostCapCallback(max_cost_per_instance=10.0, instance_id="direct") + cb.bind(convo) + cb(event=object()) + assert convo.paused is True