
Learning-curve extrapolation and early stopping

Scope: how to stop training trials that are already losing, so a search campaign spends its GPU-hours on runs that can still win. It covers forecasting a trial's final metric from the partial curve, the multi-fidelity bandit family (Successive Halving, Hyperband, ASHA, BOHB) that allocates budget in rounds and cuts the losers, and the cooperative-cancellation mechanics that make a stop signal safe across a remote container. It is the throughput lever under autonomous experimentation loops: it turns raw GPU-hours into goodput, and it is the classical core of the schedulers and pruners in Ray Tune and Optuna.

The Python blocks below (the power-law forecaster, the probabilistic stop rule, Successive Halving, the Hyperband bracket schedule, the ASHA promotion rule, and the cooperative-cancellation state machine) use only numpy and the standard library, and are executed and asserted, including adversarial cases. The Optuna and Ray Tune snippets are reference templates: pin versions and validate before production use.

What it is

Two families of early stopping decide when a training trial has lost and cut it before it burns its whole budget. Learning-curve extrapolation fits a model to the partial loss (or metric) curve and forecasts the final value at the full budget T_max, terminating any trial predicted to fall short of the current best (the idea Domhan, Springenberg, and Hutter introduced for deep nets).1 Multi-fidelity bandits skip the forecast entirely: they run many configurations at a small budget, keep the top fraction, and promote survivors to larger budgets, so a loser never trains to completion. The two compose (a bandit allocates rungs, a curve forecast decides promotions within a rung), and both sit on the same plumbing: trials that report progress per step and a cooperative way to stop them.

The two design decisions that dominate every implementation are what forecasts a trial's fate (a parametric curve fit, or its rank against peers at a shared budget) and how a trial actually stops (it exits itself at a safe point; never kill -9).

Why use it

  • It is the single biggest throughput lever in a search campaign. Most proposed configurations lose, and finishing a losing trial to its budget is wasted compute. Cutting the losers early is what turns raw GPU-hours into goodput.
  • It automates the judgment a human already makes. A person tuning a model glances at the loss curve after a few hundred steps and kills a run going nowhere. Early stopping encodes that reflex so a campaign runs unattended.
  • It has documented, large wins. Domhan et al. report a substantial speedup of deep-net hyperparameter optimization from curve extrapolation;1 Hyperband reports speedups of five-fold to over an order of magnitude versus black-box Bayesian optimization on deep-learning and kernel tasks.2
  • It scales to real clusters. ASHA promotes asynchronously and scales near-linearly to hundreds of workers, removing the synchronization barrier that stalls naive Successive Halving.3
  • It composes with a smart proposer. BOHB pairs model-based proposals with bandit early stopping for both strong anytime performance and fast convergence, the natural pairing of the proposer and the early-stopper.4

When to use it (and when not)

Match the stopper to the curve you actually have:

  • Curve extrapolation when trials emit a smooth, saturating curve (most supervised training) and you can afford a few hundred steps of signal before deciding. A power-law or exponential fit forecasts the plateau and cancels runs that will not reach the incumbent.1
  • Successive Halving when you can run many configurations cheaply and rank them at a shared small budget. It is the base primitive: most configs die cheap, only survivors get expensive budget.
  • Hyperband when you cannot guess the right count-vs-budget balance up front. It hedges by running Successive Halving across several brackets, so no single guess sinks the search.2
  • ASHA on a multi-worker cluster, where a synchronization barrier would idle workers. Promote a configuration as soon as it ranks in the top 1/eta of the configs already reported at its rung, rather than waiting for the rung to finish.3
  • BOHB when a model-based proposer is available and you want it to also decide how long each config runs.4

When not to reach for it, or a narrower variant:

  • Do not extrapolate a strongly non-monotonic curve. RL and fine-tuning curves plateau then jump; a saturating fit reads the plateau as "doomed" right before the jump and over-cancels. Widen the margin, lengthen the floor, or prefer bandit rungs with generous budgets for these workloads.
  • Do not decide from a handful of noisy points. Early points are dominated by initialization noise; a stop on two or three measurements cancels good runs. Enforce a min-steps and min-reports floor.
  • Do not stop on a bare point estimate. A slow-starter whose mean is below the incumbent but whose posterior is wide still has real upside. Stop only when the forecast is confidently below the best.
  • Do not import a default margin. The safe stopping margin is workload-specific; measure the false-cancel rate on a held-out set of full runs and tune to it.

Architecture

Every trial reports (step, value) per step. A stopper waits until the trial is past a min-steps and min-reports floor, then decides its fate: a curve forecaster fits the partial curve and asks whether the predicted final can beat the incumbent, or a bandit ranks the trial against peers at the same rung. On a "stop" verdict the controller writes a control signal; the trial polls it on its next progress report and exits itself at a safe point, recorded as cancelled (not failed) so the campaign's failure-rate guard is not tripped.

flowchart LR
  RUN["Trial training"] --> REP["Report metric per step<br/>(step, value)"]
  REP --> MIN{"Past min-steps<br/>and min-reports?"}
  MIN -->|"no"| RUN
  MIN -->|"yes"| FC["Fit curve, forecast final at T_max"]
  FC --> CMP{"Forecast can beat<br/>incumbent best?"}
  CMP -->|"yes"| RUN
  CMP -->|"no"| STOP["Signal stop"]
  STOP --> COOP["Trial polls signal,<br/>exits at safe point"]
  COOP --> CANCELLED["Status: cancelled (not failed)"]

The two families slot into the same skeleton. Curve extrapolation replaces the FC/CMP block with a parametric fit and a probabilistic beat test; a bandit replaces it with a rank-at-rung comparison. Both then feed the same cooperative-cancellation tail (STOP to CANCELLED), which is what makes a stop signal safe across a remote GPU container (GPU consumption models).
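The shared skeleton can be sketched as a controller with a pluggable verdict. This is a minimal sketch under stated assumptions: `Trial`, `Controller`, `decide`, and the in-memory `stop_requested` flag are hypothetical names standing in for whatever control signal your platform uses. `decide` is the slot where a curve forecaster or a rank-at-rung rule plugs in; the floors and the signal write stay the same either way.

```python
# Minimal sketch of the shared stopper skeleton (all names are hypothetical).
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

Curve = List[Tuple[int, float]]

@dataclass
class Trial:
    trial_id: str
    curve: Curve = field(default_factory=list)
    stop_requested: bool = False        # the control signal the trial polls

@dataclass
class Controller:
    # decide(curve, incumbent) -> True means "stop". Plug in a curve forecaster
    # or a rank-at-rung bandit rule here; the rest of the loop does not change.
    decide: Callable[[Curve, float], bool]
    min_steps: int = 100
    min_reports: int = 5
    incumbent: float = float("-inf")    # best final value seen so far
    trials: Dict[str, Trial] = field(default_factory=dict)

    def on_report(self, trial_id: str, step: int, value: float) -> None:
        trial = self.trials.setdefault(trial_id, Trial(trial_id))
        trial.curve.append((step, value))
        if step < self.min_steps or len(trial.curve) < self.min_reports:
            return                      # floor unmet: never decide yet
        if not trial.stop_requested and self.decide(trial.curve, self.incumbent):
            trial.stop_requested = True # write the signal; the trial exits itself

    def on_complete(self, trial_id: str, final_value: float) -> None:
        self.incumbent = max(self.incumbent, final_value)

# Toy verdict (a placeholder, not a recommended rule): the latest value is
# more than 0.1 below the incumbent.
ctrl = Controller(decide=lambda curve, inc: curve[-1][1] < inc - 0.1,
                  min_steps=3, min_reports=3)
ctrl.on_complete("a", 0.90)
for step, value in enumerate([0.20, 0.30, 0.35, 0.40], start=1):
    ctrl.on_report("b", step, value)
print(ctrl.trials["b"].stop_requested)      # True

# The floor wins even against an always-stop verdict.
early = Controller(decide=lambda curve, inc: True)
early.on_report("c", 10, 0.0)
print(early.trials["c"].stop_requested)     # False
```

Keeping the verdict a plain function is the design point: swapping a forecaster for a bandit rule touches one argument, not the cancellation plumbing.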

How to forecast the final metric from the partial curve

The direct approach mimics the human: from the first portion of a learning curve, predict where it lands at the full budget T_max, and terminate runs predicted to fall short of the current best. Domhan et al. fit a weighted ensemble of parametric curve families (power law, exponential, and related saturating forms) to the observed curve with a probabilistic model, then extrapolate and stop unpromising runs, reporting a substantial speedup of deep-net hyperparameter optimization in their experiments.1 A power-law fit (y ~= a - b*t^(-c)) is the simplest usable instance; the ensemble matters because no single family fits every curve.

The block below fits y ~= a - b*t^(-c) with numpy only (a grid over the exponent c, linear least squares for (a, b) at each c, pick the lowest residual). It asserts recovery of a known curve, the flat-curve boundary, and, as the adversarial case, that a saturating fit badly under-predicts a plateau-then-jump RL curve (the exact failure the "when not" section warns about).

# Power-law learning-curve fit + forecast, executed and asserted (numpy-only).
import numpy as np

def fit_power_law(steps, values, c_grid=None):
    """Fit y ~= a - b*t**(-c). For each c on a grid, (a, b) is linear least
    squares on features [1, t**(-c)]; pick the c with lowest SSE."""
    t = np.asarray(steps, dtype=float)
    y = np.asarray(values, dtype=float)
    if c_grid is None:
        c_grid = np.linspace(0.05, 3.0, 120)
    best = None
    for c in c_grid:
        X = np.column_stack([np.ones_like(t), t ** (-c)])  # a*1 + (-b)*t^-c
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        a, b = coef[0], -coef[1]
        resid = y - (a - b * t ** (-c))
        sse = float(resid @ resid)
        if best is None or sse < best[0]:
            best = (sse, a, b, c)
    _, a, b, c = best
    return a, b, c

def forecast(steps, values, t_max):
    a, b, c = fit_power_law(steps, values)
    return a - b * t_max ** (-c)

# Happy path: recover a known power law from its first half under small noise.
rng = np.random.default_rng(0)
a_true, b_true, c_true = 0.90, 0.6, 0.7
t = np.arange(1, 41, dtype=float)
y = a_true - b_true * t ** (-c_true) + rng.normal(0, 0.002, size=t.size)
yhat = forecast(t[:20], y[:20], t_max=200.0)
true_final = a_true - b_true * 200.0 ** (-c_true)
assert abs(yhat - true_final) < 0.02, (yhat, true_final)
assert yhat > y[19], (yhat, y[19])            # still rising toward the plateau

# Edge: a perfectly flat curve forecasts its own plateau.
assert abs(forecast(np.arange(1, 11), np.full(10, 0.5), t_max=1000.0) - 0.5) < 1e-6

# Adversarial: a plateau-then-jump (RL-style) curve fools a saturating fit.
# Seeing only the plateau, the forecast lands far BELOW the true final value.
jump_y = np.concatenate([np.full(15, 0.40), np.full(5, 0.85)])
yhat_jump = forecast(np.arange(1, 16), jump_y[:15], t_max=20.0)
assert yhat_jump < 0.60, yhat_jump                       # believes it saturates low
assert jump_y[-1] - yhat_jump > 0.20, (jump_y[-1], yhat_jump)  # under by > 0.2
print("forecast OK: happy=%.4f true=%.4f jump=%.4f (true 0.85)"
      % (yhat, true_final, yhat_jump))
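Domhan et al. weight several curve families rather than trusting one. A crude stand-in, assuming only two families (the power law above plus an exponential saturation y ~= a - b*exp(-k*t)) and inverse-SSE weights in place of their probabilistic model, looks like this. The family set, the grids, and the weighting are illustrative choices, not the paper's method.

```python
# Crude two-family ensemble forecaster (numpy-only); an illustrative stand-in
# for Domhan et al.'s probabilistic ensemble, not their method.
import numpy as np

def fit_family(t, y, basis, grid):
    """Fit y ~= a - b*basis(t, p): grid over p, linear least squares for (a, b)."""
    best = None
    for p in grid:
        X = np.column_stack([np.ones_like(t), basis(t, p)])
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        sse = float(np.sum((y - X @ coef) ** 2))
        if best is None or sse < best[0]:
            best = (sse, coef, p)
    return best

FAMILIES = {   # name -> (basis(t, p), grid over the nonlinear parameter p)
    "power_law": (lambda t, c: t ** (-c), np.linspace(0.05, 3.0, 60)),
    "exponential": (lambda t, k: np.exp(-k * t), np.linspace(0.01, 2.0, 200)),
}

def ensemble_forecast(steps, values, t_max, eps=1e-12):
    """Inverse-SSE-weighted average of each family's forecast at t_max."""
    t = np.asarray(steps, dtype=float)
    y = np.asarray(values, dtype=float)
    preds, weights = {}, {}
    for name, (basis, grid) in FAMILIES.items():
        sse, coef, p = fit_family(t, y, basis, grid)
        # coef = [a, -b], so a + coef[1]*basis == a - b*basis
        preds[name] = float(coef[0] + coef[1] * basis(np.array([float(t_max)]), p)[0])
        weights[name] = 1.0 / (sse + eps)
    total = sum(weights.values())
    return sum(weights[n] * preds[n] for n in preds) / total, preds

# A curve that saturates exponentially: the exponential family should dominate.
t = np.arange(1, 31, dtype=float)
y = 0.9 - 0.5 * np.exp(-0.1 * t)
ens, preds = ensemble_forecast(t, y, t_max=200.0)
print(round(ens, 3))   # 0.9
```

A family that fits the observed prefix poorly gets little weight, which is the point of ensembling: no single family has to be right for every curve shape.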

The decision is probabilistic, not a point estimate: stop only when the predicted probability of finishing within a safety margin of the incumbent is low, so a slow-starter that would overtake late is not killed prematurely. The next block models the final metric as a Gaussian posterior and stops only when P(final > incumbent - margin) falls below a threshold, so a wider margin is more conservative and cancels fewer trials. It cross-checks the closed-form Gaussian tail against Monte-Carlo and, as the adversarial case, shows a wide-posterior slow-starter surviving even though naive mu < incumbent logic would kill it.

# Probabilistic early-stop decision, executed and asserted (numpy-only).
import math
import numpy as np

def prob_beats(mu, sigma, incumbent, margin=0.0):
    """P(final > incumbent - margin) for Normal(mu, sigma), via erfc (no scipy).
    margin is a safety slack in the trial's favour: wider => fewer cancels."""
    threshold = incumbent - margin
    if sigma <= 0:
        return 1.0 if mu > threshold else 0.0
    z = (mu - threshold) / (sigma * math.sqrt(2.0))
    return 0.5 * math.erfc(-z)                 # 1 - CDF(threshold)

def should_stop(mu, sigma, incumbent, margin=0.0, floor_ok=True, p_thresh=0.05):
    if not floor_ok:                           # min-steps/min-reports floor unmet
        return False
    return prob_beats(mu, sigma, incumbent, margin) < p_thresh

# Closed-form tail matches a Monte-Carlo simulation.
rng = np.random.default_rng(1)
mu, sigma, inc = 0.80, 0.05, 0.83
mc = float(np.mean(rng.normal(mu, sigma, size=400_000) > inc))
assert abs(mc - prob_beats(mu, sigma, inc)) < 2e-3, (mc, prob_beats(mu, sigma, inc))

assert should_stop(mu=0.50, sigma=0.01, incumbent=0.80) is True    # clear loser
assert should_stop(mu=0.90, sigma=0.02, incumbent=0.80) is False   # clear winner

# Adversarial slow-starter: mean below incumbent but WIDE posterior => survives.
# Point-estimate logic (mu < inc) would wrongly kill it; the distribution spares it.
point_estimate_would_stop = 0.78 < 0.80
assert point_estimate_would_stop
assert should_stop(mu=0.78, sigma=0.10, incumbent=0.80) is False, "killed a slow-starter"

# Floor guard spares even an obvious loser while the floor is unmet.
assert should_stop(mu=0.10, sigma=0.001, incumbent=0.80, floor_ok=False) is False
# A wider safety margin spares a near-miss that a zero margin would cancel.
assert should_stop(mu=0.75, sigma=0.01, incumbent=0.80) is True
assert should_stop(mu=0.75, sigma=0.01, incumbent=0.80, margin=0.10) is False
print("decision OK: mc=%.4f closed_form=%.4f" % (mc, prob_beats(mu, sigma, inc)))
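The Gaussian posterior needs a mean and a spread from somewhere. One simple source, assuming a residual bootstrap over the power-law fit (a swapped-in technique, not Domhan et al.'s probabilistic model), is to refit on resampled residuals and read mu and sigma off the spread of T_max forecasts. `bootstrap_posterior` and `p_final_beats` are hypothetical helpers, and both curves are synthetic.

```python
# Residual-bootstrap posterior for the final metric (numpy-only sketch).
import math
import numpy as np

def fit_power_law(t, y, c_grid=np.linspace(0.05, 3.0, 60)):
    """Best y ~= a - b*t**(-c) over a c grid; returns a callable forecaster."""
    best = None
    for c in c_grid:
        X = np.column_stack([np.ones_like(t), t ** (-c)])
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        sse = float(np.sum((y - X @ coef) ** 2))
        if best is None or sse < best[0]:
            best = (sse, coef, c)
    _, coef, c = best
    return lambda tt: coef[0] + coef[1] * np.asarray(tt, dtype=float) ** (-c)

def bootstrap_posterior(steps, values, t_max, n_boot=100, seed=0):
    """Refit on (fitted + resampled residuals); mean/std of the T_max forecasts."""
    t = np.asarray(steps, dtype=float)
    y = np.asarray(values, dtype=float)
    fitted = fit_power_law(t, y)(t)
    resid = y - fitted
    rng = np.random.default_rng(seed)
    finals = [float(fit_power_law(t, fitted + rng.choice(resid, size=resid.size))(t_max))
              for _ in range(n_boot)]
    return float(np.mean(finals)), float(np.std(finals))

def p_final_beats(mu, sigma, incumbent):
    """P(final > incumbent) under Normal(mu, sigma)."""
    if sigma <= 0:
        return 1.0 if mu > incumbent else 0.0
    return 0.5 * math.erfc(-(mu - incumbent) / (sigma * math.sqrt(2.0)))

rng = np.random.default_rng(3)
t = np.arange(1, 31, dtype=float)
curves = {
    "loser": 0.50 - 0.3 * t ** -0.8 + rng.normal(0, 0.005, t.size),  # plateaus ~0.50
    "riser": 0.95 - 0.6 * t ** -0.6 + rng.normal(0, 0.005, t.size),  # ~0.87 now, climbing
}
incumbent = 0.90
verdicts = {}
for name, y in curves.items():
    mu, sigma = bootstrap_posterior(t, y, t_max=1000.0)
    verdicts[name] = p_final_beats(mu, sigma, incumbent) < 0.05     # True => stop
    print(name, round(mu, 3), round(sigma, 4), "stop" if verdicts[name] else "continue")
```

The riser is still below the incumbent at its last report, so a rule on the current value would cancel it; the forecast distribution keeps it running.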

How to allocate budget with multi-fidelity bandits

The bandit family treats budget (steps, epochs, data) as the resource to allocate and never trains a loser to completion:

  • Successive Halving. Run n configurations at a small budget, keep the top 1/eta, multiply the budget by eta, repeat. Most configs die cheap; only survivors get expensive budget. The catch is the n-vs-budget tradeoff: many configs at tiny budget risks cutting a slow-starter; few at large budget wastes compute on losers.
  • Hyperband hedges that tradeoff by running Successive Halving in several brackets with different starting n and aggressiveness, so no single guess about the budget-vs-count balance sinks the search. It reports speedups of over an order of magnitude versus black-box Bayesian optimization on deep-learning and kernel tasks.2
  • ASHA makes it asynchronous and parallel: instead of waiting for a whole rung to finish before promoting, promote a configuration as soon as it ranks in the top 1/eta of the configs already reported at its rung. That removes the synchronization barrier and scales near-linearly to hundreds of workers, the version you want on a real cluster.3
  • BOHB combines model-based proposals (Bayesian optimization picks the configs) with Hyperband early stopping (the bandit decides how long each runs), getting both good anytime performance and fast convergence.4 It is the natural pairing of the proposer and the early-stopper.

Curve extrapolation and bandits compose: use the bandit to allocate rungs and a curve forecast to decide promotions within a rung.
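One way to compose them can be sketched as follows, assuming the bandit fixes how many configs survive a rung (floor(n/eta)) while a power-law forecast decides which ones. `forecast_final` and `promote_by_forecast` are hypothetical names, and the three curves are synthetic; this illustrates the composition, not a specific published scheduler.

```python
# Hypothetical composition: the bandit fixes HOW MANY survive a rung (floor(n/eta)),
# a power-law forecast decides WHICH ones (numpy-only sketch).
import math
import numpy as np

def forecast_final(steps, values, t_max, c_grid=np.linspace(0.05, 3.0, 60)):
    """Power-law forecast y ~= a - b*t**(-c) at t_max (grid over c, LSQ for a, b)."""
    t = np.asarray(steps, dtype=float)
    y = np.asarray(values, dtype=float)
    best = None
    for c in c_grid:
        X = np.column_stack([np.ones_like(t), t ** (-c)])
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        sse = float(np.sum((y - X @ coef) ** 2))
        if best is None or sse < best[0]:
            best = (sse, coef, c)
    _, coef, c = best
    return float(coef[0] + coef[1] * t_max ** (-c))

def promote_by_forecast(curves, eta, t_max):
    """curves: {config_id: (steps, values)} observed up to this rung's budget."""
    k = max(1, math.floor(len(curves) / eta))            # bandit: survivor count
    fc = {cid: forecast_final(s, v, t_max) for cid, (s, v) in curves.items()}
    return sorted(fc, key=fc.get, reverse=True)[:k], fc   # forecast: survivor identity

t = np.arange(1, 11, dtype=float)                        # every config is at budget 10
curves = {
    "fast": (t, 0.80 - 0.3 * t ** -2.0),   # near its ceiling already
    "slow": (t, 0.95 - 0.8 * t ** -0.5),   # behind now, higher ceiling
    "meh":  (t, 0.60 - 0.2 * t ** -1.0),
}
by_current = max(curves, key=lambda cid: curves[cid][1][-1])
promoted, fc = promote_by_forecast(curves, eta=3, t_max=1000.0)
print(by_current, promoted)   # fast ['slow']
```

Ranking by the current value would promote "fast"; ranking by forecast promotes "slow", which covers the bandit's blind spot for slow-starters.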

The block below implements Successive Halving. It asserts the geometry (budgets multiply by eta, survivor counts divide by eta), that total work is a fraction of running everything to full budget, and, as the adversarial correctness check, that the true best is selected far above chance across 400 noisy seeds (guarding against a stopper that just returns a fixed or random survivor). It also rejects eta <= 1, which would eliminate nobody.

# Successive Halving, executed and asserted (numpy-only). Higher score is better.
import math
import numpy as np

def successive_halving(scores_at_budget, n, eta, r_min, r_max):
    """scores_at_budget(config_id, budget) -> float. Returns (best_id, rungs)
    where rungs is a list of (budget, ranked_survivors)."""
    assert eta > 1
    survivors = list(range(n))
    budget = r_min
    rungs = []
    while True:
        ranked = sorted(survivors, key=lambda c: scores_at_budget(c, budget), reverse=True)
        rungs.append((budget, list(ranked)))
        keep = max(1, math.floor(len(ranked) / eta))
        survivors = ranked[:keep]
        if len(survivors) == 1 or budget >= r_max:
            break
        budget = min(budget * eta, r_max)
    best = max(survivors, key=lambda c: scores_at_budget(c, r_max))
    return best, rungs

# Ground truth: each config has a final quality; its observed score rises toward
# that quality with shrinking noise, so early ranking is noisy but the top
# config is eventually identifiable.
N, R = 27, 27.0
rng = np.random.default_rng(2)
final_quality = rng.uniform(0.0, 1.0, size=N)

def score(cid, budget, r_max=R):
    frac = budget / r_max
    return final_quality[cid] * (1 - 0.5 * (1 - frac)) + rng.normal(0, 0.02 * (1 - frac))

best, rungs = successive_halving(score, n=N, eta=3, r_min=1.0, r_max=R)
budgets = [b for b, _ in rungs]
counts = [len(s) for _, s in rungs]
assert budgets == [1.0, 3.0, 9.0], budgets            # multiply by eta=3
assert counts == [27, 9, 3], counts                   # divide by eta=3
work = sum(b * len(s) for b, s in rungs) + R     # + the final survivor's run to r_max
assert work < 0.5 * (N * R), (work, N * R)            # far below brute force

# Adversarial correctness: across many seeds the true best is picked far above
# the 1/N random-guess rate, so the routine is really ranking, not returning noise.
hits = 0
for s in range(400):
    r = np.random.default_rng(1000 + s)
    fq = r.uniform(0, 1, size=N); bt = int(np.argmax(fq))
    def sc(cid, budget, r_max=R, fq=fq, r=r):
        return fq[cid] * (1 - 0.5 * (1 - budget / r_max)) + r.normal(0, 0.02 * (1 - budget / r_max))
    b, _ = successive_halving(sc, n=N, eta=3, r_min=1.0, r_max=R)
    hits += (b == bt)
rate = hits / 400
assert rate > 0.60, rate                              # vs 1/27 ~= 0.037 chance

# Edge: eta must be > 1, else nothing is eliminated.
raised = False
try:
    successive_halving(score, n=9, eta=1, r_min=1.0, r_max=9.0)
except AssertionError:
    raised = True
assert raised, "eta<=1 must be rejected"
print("SHA OK: budgets=%s counts=%s work=%.0f/%.0f best_hit_rate=%.2f"
      % (budgets, counts, work, N * R, rate))

Hyperband wraps Successive Halving in s_max + 1 brackets, each trading count against starting budget. The block below reproduces Li et al. Algorithm 1 and asserts the schedule against Algorithm 1's bracket-size formula for R = 81, eta = 3: exactly five brackets, the most aggressive starting n = 81 at resource 1, the least aggressive running five configs straight to full budget with no halving. The paper's displayed Table 1 differs for the middle brackets, showing starting n = 27, 9, 6 for s = 3, 2, 1 where the formula gives 34, 15, 8, a known inconsistency in the paper.

# Hyperband bracket schedule, executed and asserted (numpy-only). Li et al. Alg 1.
import math

def hyperband_brackets(R, eta):
    """Return (brackets, s_max, B). Each bracket is (s, n, r, rungs) where rungs
    is the (n_i, r_i) schedule Successive Halving runs inside that bracket."""
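    s_max = 0                                  # floor(log_eta(R)) in exact integer
    while eta ** (s_max + 1) <= R:             # math; a float log can land just below
        s_max += 1                             # an exact power and floor one short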
    B = (s_max + 1) * R
    brackets = []
    for s in range(s_max, -1, -1):
        n = int(math.ceil((B / R) * (eta ** s) / (s + 1)))
        r = R * eta ** (-s)
        rungs = [(int(math.floor(n * eta ** (-i))), r * eta ** i) for i in range(s + 1)]
        brackets.append((s, n, r, rungs))
    return brackets, s_max, B

brackets, s_max, B = hyperband_brackets(R=81, eta=3)
assert s_max == 4 and len(brackets) == 5, (s_max, len(brackets))   # log_3(81)=4
assert [b[0] for b in brackets] == [4, 3, 2, 1, 0]                 # s descending
assert B == 405, B                                                # (s_max+1)*R

# Exact match to the paper for the most aggressive bracket (s=4).
assert brackets[0][3] == [(81, 1.0), (27, 3.0), (9, 9.0), (3, 27.0), (1, 81.0)]
# Least aggressive (s=0): 5 configs run straight to full R, no halving.
assert brackets[-1][3] == [(5, 81)], brackets[-1][3]
assert {b[0]: b[1] for b in brackets} == {4: 81, 3: 34, 2: 15, 1: 8, 0: 5}

# Inside every bracket, resource multiplies by eta and count divides by eta.
for s, n, r, rungs in brackets:
    for i in range(1, len(rungs)):
        assert math.isclose(rungs[i][1], rungs[i - 1][1] * 3)
        assert rungs[i][0] == math.floor(rungs[i - 1][0] / 3)
print("hyperband OK: s_max=%d brackets=%d n_by_s=%s"
      % (s_max, len(brackets), {b[0]: b[1] for b in brackets}))
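The bracket sizes are chosen so that each bracket's total spend stays near B. For R = 81, eta = 3 the arithmetic is quick to check: sum n_i * r_i over each bracket's rungs. This is a sketch that recomputes the schedule with integer arithmetic; `bracket_work` is a hypothetical helper name.

```python
# Per-bracket resource spend for Hyperband (R=81, eta=3), integer arithmetic.
import math

def bracket_work(R, eta):
    """Return ({s: sum_i n_i * r_i}, B) for the Hyperband brackets of Li et al."""
    s_max = 0
    while eta ** (s_max + 1) <= R:                       # exact floor(log_eta R)
        s_max += 1
    B = (s_max + 1) * R
    work = {}
    for s in range(s_max, -1, -1):
        n = math.ceil((s_max + 1) * eta ** s / (s + 1))  # ceil((B/R) * eta^s / (s+1))
        work[s] = sum((n // eta ** i) * (R / eta ** (s - i)) for i in range(s + 1))
    return work, B

work, B = bracket_work(81, 3)
print(B, work)          # 405 {4: 405.0, 3: 363.0, 2: 351.0, 1: 378.0, 0: 405.0}
print(sum(work.values()), "<=", len(work) * B)
```

Every bracket spends at most B, so a full Hyperband pass costs at most (s_max + 1) * B; here 1902 of a 2025 ceiling.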

How to promote asynchronously (ASHA) on a cluster

On a real cluster the synchronization barrier is the enemy: waiting for a whole rung to finish before promoting anyone idles workers. ASHA promotes a configuration from rung k to k+1 as soon as it is in the top 1/eta of configs already seen at rung k, no barrier.3 The block below models that promotion decision and asserts three things: with only three configs seen ASHA can already promote the single best (where synchronous Successive Halving could not), the promoted set is monotone as weaker newcomers arrive, and, as the equivalence-to-reference adversarial check, once a rung has fully reported ASHA's promoted set equals what synchronous Successive Halving keeps.

# ASHA async promotion rule, executed and asserted (numpy-only).
import math
from collections import defaultdict

class AshaLadder:
    def __init__(self, eta, max_rung):
        self.eta, self.max_rung = eta, max_rung
        self.rung_scores = defaultdict(dict)   # rung -> {config_id: score}

    def observe(self, rung, config_id, score):
        self.rung_scores[rung][config_id] = score

    def promotable(self, rung):
        """Configs at `rung` in the top 1/eta seen so far, not yet promoted.
        Decided on whoever has reported: no synchronization barrier."""
        if rung >= self.max_rung:
            return []
        scores = self.rung_scores[rung]
        k = math.floor(len(scores) / self.eta)
        if k < 1:
            return []                          # not enough seen to promote anyone
        ranked = sorted(scores, key=lambda c: scores[c], reverse=True)
        already = set(self.rung_scores[rung + 1])
        return [c for c in ranked[:k] if c not in already]

def sha_keep_reference(scores, eta):           # independent slow reference
    k = max(0, math.floor(len(scores) / eta))
    return set(sorted(scores, key=lambda c: scores[c], reverse=True)[:k])

# Barrier-freedom: with 3 configs seen, promote the single best (floor(3/3)=1).
lad = AshaLadder(eta=3, max_rung=2)
for cid, sc in [(0, 0.2), (1, 0.9), (2, 0.5)]:
    lad.observe(0, cid, sc)
assert lad.promotable(0) == [1], lad.promotable(0)

# Monotonicity: a weak newcomer does not evict an already-promoted config.
lad.observe(1, 1, 0.95)                        # config 1 now lives at rung 1
lad.observe(0, 3, 0.1)                         # a weak newcomer at rung 0
assert lad.promotable(0) == [], lad.promotable(0)

# Equivalence to reference: once the rung has fully reported, ASHA's promoted
# set equals synchronous Successive Halving's top floor(n/eta).
lad2 = AshaLadder(eta=3, max_rung=2)
order = {i: 1.0 - i * 0.1 for i in range(9)}   # config 0 best ... 8 worst
for cid, sc in order.items():
    lad2.observe(0, cid, sc)
prom = set(lad2.promotable(0))
assert prom == {0, 1, 2} == sha_keep_reference(order, 3), (prom,)
lad2.observe(2, 0, 0.5)
assert lad2.promotable(2) == []                # top rung never promotes
print("ASHA OK: best_of_3=[1] top3_of_9=%s == sha_ref" % sorted(prom))
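On the worker side, the promotion rule drives job assignment. A sketch modeled on ASHA's job-assignment routine: when a worker frees up, scan rungs from the highest down, promote the first not-yet-promoted config in a rung's top 1/eta, and otherwise start a fresh config at rung 0. `AshaJobQueue` and the toy scores are hypothetical, and each job "finishes" instantly for illustration.

```python
# ASHA-style job assignment (stdlib-only sketch; names are hypothetical).
import itertools
from collections import defaultdict

class AshaJobQueue:
    def __init__(self, eta, max_rung, new_config):
        self.eta, self.max_rung, self.new_config = eta, max_rung, new_config
        self.rungs = defaultdict(dict)      # rung -> {config_id: score}
        self.promoted = defaultdict(set)    # rung -> configs already sent up from it

    def report(self, rung, config_id, score):
        self.rungs[rung][config_id] = score

    def get_job(self):
        """Return (config_id, rung) for the next free worker."""
        for k in range(self.max_rung - 1, -1, -1):       # highest rung first
            scores = self.rungs[k]
            top = sorted(scores, key=scores.get, reverse=True)[: len(scores) // self.eta]
            for cid in top:
                if cid not in self.promoted[k]:
                    self.promoted[k].add(cid)
                    return cid, k + 1                     # promote, no barrier
        return self.new_config(), 0                       # else grow the bottom rung

# One worker, toy scores (higher is better).
quality = {0: 0.3, 1: 0.9, 2: 0.5, 3: 0.7, 4: 0.1, 5: 0.8}
ids = itertools.count()
queue = AshaJobQueue(eta=3, max_rung=2, new_config=lambda: next(ids))
trace = []
for _ in range(8):
    cid, rung = queue.get_job()
    queue.report(rung, cid, quality[cid])
    trace.append((cid, rung))
print(trace)  # [(0, 0), (1, 0), (2, 0), (1, 1), (3, 0), (4, 0), (5, 0), (5, 1)]
```

Config 1 is promoted the moment three configs have reported at rung 0, and config 5 follows once six have; no worker ever waits for a rung to fill.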

In a production HPO framework you do not hand-write the ladder: Ray Tune ships ASHA (ASHAScheduler), Hyperband, and PBT. The scheduler snippet is a reference template (Ray is not vendored here; pin the version). The core promotion math it relies on is validated by the numpy block above.

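# Ray Tune ASHA scheduler (reference template; pin ray[tune], validate before prod).
# The CORE async-promotion rule this relies on is validated by the numpy block above.
# This uses the older function API (tune.run, keyword-argument tune.report); newer
# Ray releases use tune.Tuner and report metrics as a dict, so match your pinned
# version. train_one_step is a placeholder for your own training code.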
from ray import tune
from ray.tune.schedulers import ASHAScheduler

scheduler = ASHAScheduler(
    metric="val_acc", mode="max",
    max_t=81,            # max resource per trial (== R)
    grace_period=1,      # min-steps floor: never stop before this rung
    reduction_factor=3,  # eta
)

def trainable(config):
    for step in range(1, 82):
        acc = train_one_step(config)          # your training step
        tune.report(val_acc=acc)              # report (step, value); scheduler prunes

tune.run(trainable, num_samples=200, scheduler=scheduler,
         config={"lr": tune.loguniform(1e-5, 1e-1)})

How to execute and cancel a trial cooperatively

Whatever decides the stop, how the trial stops matters. A hard kill loses partial results and can leave a remote GPU container in a dirty state. The clean contract is cooperative: the trial periodically reports progress and checks whether it should stop, and the trial itself exits at a safe point. This is the Optuna pruning primitive (trial.report(value, step) then if trial.should_prune(): raise TrialPruned()), where the user's trial code raises rather than the framework killing it.5 Generalized to a per-trial cloud container (GPU consumption models), the controller writes a control signal, the trial polls it on each progress report, and on "stop" exits with a distinct code so the run is recorded as cancelled, not failed.

That cancelled-vs-failed distinction is load-bearing: a search loop usually has a failure-rate stop guard, and a deliberate early-out must not trip it. Cancellations are budget saved on purpose; only genuine errors are failures. The block below models the state machine and asserts that the cooperative stop lands at the safe point (not the end), that a None metric is FAILED (not CANCELLED) even mid-run, and, as the adversarial arithmetic check, that nine cancels plus one real failure yield a 1/10 failure rate rather than the 10/10 a naive guard would compute and halt on.

# Cooperative-cancellation state machine, executed and asserted (standard library only).
from enum import Enum

class Status(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    CANCELLED = "cancelled"    # deliberate early-out, NOT a failure
    FAILED = "failed"          # genuine error

CANCELLED_EXIT = 42            # distinct exit code the controller reads as cancelled

def run_trial(curve, control_signal, poll_every=1):
    """curve: iterable of (step, value). control_signal(step)->bool asks the
    controller 'should I stop?'. Returns (Status, exit_code, last_step)."""
    last_step = -1
    try:
        for i, (step, value) in enumerate(curve):
            last_step = step
            if value is None or value != value:           # missing or NaN metric
                raise ValueError("missing or NaN metric")  # real error -> FAILED
            if i % poll_every == 0 and control_signal(step):
                return Status.CANCELLED, CANCELLED_EXIT, last_step   # safe exit
        return Status.COMPLETED, 0, last_step
    except Exception:
        return Status.FAILED, 1, last_step

def counts_toward_failure_guard(status):
    """The failure-rate stop guard must count ONLY genuine errors."""
    return status is Status.FAILED

full = [(s, 0.5) for s in range(10)]
st, code, last = run_trial(full, lambda step: step >= 5)      # controller stops at 5
assert (st, code, last) == (Status.CANCELLED, CANCELLED_EXIT, 5)   # safe point, not 9
st2, code2, last2 = run_trial(full, lambda step: False)      # no signal -> completes
assert (st2, code2, last2) == (Status.COMPLETED, 0, 9)

# A real error mid-run is FAILED, not CANCELLED.
bad = [(0, 0.5), (1, None), (2, 0.5)]
assert run_trial(bad, lambda step: False)[:2] == (Status.FAILED, 1)

# Adversarial arithmetic: 9 cancels + 1 real failure => 1/10 failure rate,
# NOT 10/10. A guard that counted cancellations as failures would wrongly halt.
statuses = [Status.CANCELLED] * 9 + [Status.FAILED]
fail_rate = sum(counts_toward_failure_guard(s) for s in statuses) / len(statuses)
naive_rate = sum(s is not Status.COMPLETED for s in statuses) / len(statuses)
assert abs(fail_rate - 0.1) < 1e-9 and naive_rate == 1.0, (fail_rate, naive_rate)
print("coop OK: stop_at=%d fail_rate=%.2f naive_wrong=%.2f" % (last, fail_rate, naive_rate))
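Across a process boundary the same contract can be sketched with a control file and an exit code, assuming a shared filesystem path as the control channel (an object-store key or a database row would play the same role). `TRIAL`, `launch`, and the STOP file are hypothetical; exit code 42 mirrors the cancelled code used above.

```python
# Cross-process cooperative cancel via a control file (stdlib-only sketch).
import subprocess
import sys
import tempfile
import textwrap
import time
from pathlib import Path

CANCELLED_EXIT = 42          # distinct exit code the controller reads as cancelled

# Hypothetical trial entry point: argv = [stop_file, max_steps].
TRIAL = textwrap.dedent("""
    import sys, time
    from pathlib import Path
    stop_file, max_steps = Path(sys.argv[1]), int(sys.argv[2])
    for step in range(max_steps):
        time.sleep(0.01)             # stand-in for one training step + progress report
        if stop_file.exists():       # poll the control signal at a safe point
            sys.exit(42)             # a real trial would flush its checkpoint first
    sys.exit(0)
""")

def launch(max_steps, stop_after_s=None):
    """Run one trial process, optionally signal stop after a delay, classify exit."""
    with tempfile.TemporaryDirectory() as d:
        stop_file = Path(d) / "STOP"
        proc = subprocess.Popen(
            [sys.executable, "-c", TRIAL, str(stop_file), str(max_steps)],
            stderr=subprocess.DEVNULL)
        if stop_after_s is not None:
            time.sleep(stop_after_s)
            stop_file.write_text("stop")   # the controller signals; it never kills
        code = proc.wait(timeout=120)
    if code == 0:
        return "completed"
    return "cancelled" if code == CANCELLED_EXIT else "failed"

statuses = [
    launch(max_steps=10_000, stop_after_s=0.5),   # signalled mid-run
    launch(max_steps=5),                          # no signal: runs to the end
    launch(max_steps="oops"),                     # int() raises inside the trial
]
print(statuses)   # ['cancelled', 'completed', 'failed']
```

The controller only writes a signal and reads an exit code; the process always leaves through its own safe point, which is what keeps a remote GPU container clean.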

In a real HPO framework you do not hand-write the poll loop: Optuna implements the pruning contract. The snippet is a reference template (Optuna is not vendored here; pin the version). The core cooperative-exit and cancelled-vs-failed math it relies on is validated by the numpy block above.

# Optuna pruning contract (reference template; pin optuna, validate before prod).
# The CORE cooperative-exit math this relies on is validated by the numpy block above.
import optuna

def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    for step in range(100):
        acc = train_one_step(lr)              # your training step
        trial.report(acc, step)               # report (value, step)
        if trial.should_prune():              # cooperative: the trial decides
            raise optuna.TrialPruned()        # trial raises; framework does not kill
    return acc

study = optuna.create_study(direction="maximize",
                            pruner=optuna.pruners.MedianPruner(n_warmup_steps=5))
study.optimize(objective, n_trials=200)       # pruned trials are recorded, not failures

How to run it in production and tune the guard

Aggressive early stopping trades a small risk of killing a late-bloomer for a large throughput gain; tune the guard to keep that risk low:

  • Min-steps floor. Never decide before the curve has enough signal; early points are dominated by initialization noise.1
  • Min-reports floor. Require several progress reports before the first decision, so a single unlucky measurement cannot cancel a trial.
  • Confidence margin. Stop only when the forecast is confidently below the incumbent, not on a bare point estimate.
  • Poll interval. Re-evaluate periodically (not every step) to bound overhead.
  • Non-monotonic curves. RL and fine-tuning curves plateau then jump; a purely monotone extrapolator will over-cancel. Widen the margin or lengthen the floor for these workloads.
  • Keep cancellations out of the failure-rate guard. A deliberate early-out must not trip the campaign's failure-rate stop guard; separate the statuses (see the cooperative-cancellation block).
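The guard knobs above can be gathered into one small, immutable object. This is a sketch: `StopGuard`, its field names, and its numbers are hypothetical placeholders to tune per workload, not defaults from any framework.

```python
# Hypothetical guard knobs; the numbers are placeholders to tune, not defaults.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class StopGuard:
    min_steps: int = 500      # min-steps floor: no decision before this step
    min_reports: int = 5      # min-reports floor: no decision before this many reports
    poll_every: int = 10      # run the (costly) stopper on every Nth report only
    margin: float = 0.02      # safety slack in the trial's favour (wider = fewer cancels)

    def due(self, step: int, n_reports: int) -> bool:
        """Should the stopper evaluate this trial at this report?"""
        if step < self.min_steps or n_reports < self.min_reports:
            return False
        return n_reports % self.poll_every == 0

    def conservative(self, factor: float) -> "StopGuard":
        """Lengthen the floor and widen the margin, e.g. for RL/fine-tuning curves."""
        return replace(self, min_steps=int(self.min_steps * factor),
                       margin=self.margin * factor)

# A trial that reports every 100 steps, out to step 3000.
reports = [(step, step // 100) for step in range(100, 3001, 100)]   # (step, n_reports)
guard = StopGuard()
rl_guard = guard.conservative(3.0)
print([s for s, n in reports if guard.due(s, n)])      # [1000, 2000, 3000]
print([s for s, n in reports if rl_guard.due(s, n)])   # [2000, 3000]
```

Making the guard frozen and deriving the RL variant with `replace` keeps a per-workload override from silently mutating the shared default.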

How to maintain and scale it

  • Validate the forecaster before trusting it. Which parametric curve family fits a given workload is empirical; fit it against completed runs on the target task and measure the false-cancel rate before letting it cancel live ones.1
  • Re-tune the margin per workload. The safe stopping margin is workload-specific: measure the false-cancel rate on a held-out set of full runs and tune to it, rather than importing a default.
  • Prefer asynchronous promotion at scale. On multi-worker clusters use ASHA-style async promotion to avoid a synchronization barrier; it scales near-linearly to hundreds of workers.3
  • Pair a model-based proposer with bandit early stopping. When both are available, BOHB-style pairing gives strong anytime performance and fast convergence.4
  • Compose the two families. Use bandit rungs to allocate budget and a curve forecast to decide promotions within a rung, so each stopper covers the other's blind spot.
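Measuring the false-cancel rate and tuning the margin against it can be sketched as a replay over held-out full runs. The assumptions: for each run you kept the forecast (mu, sigma) the stopper had at its decision point and the run's true final value; the margin is a slack in the trial's favour (stop when P(final > incumbent - margin) < 0.05); and the six runs below are toy numbers for illustration, not measurements.

```python
# Margin tuning against held-out full runs (stdlib-only sketch, toy data).
import math

def p_beats(mu, sigma, threshold):
    """P(final > threshold) under Normal(mu, sigma)."""
    return 0.5 * math.erfc(-(mu - threshold) / (sigma * math.sqrt(2.0)))

def false_cancel_rate(runs, incumbent, margin, p_thresh=0.05):
    """runs: (mu, sigma, true_final) per held-out FULL run. Returns
    (false cancels / true winners, cancels / runs)."""
    stopped = [p_beats(mu, s, incumbent - margin) < p_thresh for mu, s, _ in runs]
    winners = [final > incumbent for _, _, final in runs]
    false_cancels = sum(st and w for st, w in zip(stopped, winners))
    return false_cancels / max(1, sum(winners)), sum(stopped) / len(runs)

def tune_margin(runs, incumbent, margins, max_false_cancel=0.0):
    """Smallest margin meeting the target (smaller margins cancel more)."""
    for m in sorted(margins):
        fcr, cancel_frac = false_cancel_rate(runs, incumbent, m)
        if fcr <= max_false_cancel:
            return m, fcr, cancel_frac
    return None                                   # no candidate margin is safe

# Toy held-out runs (illustrative numbers only): (forecast mu, sigma, true final).
runs = [(0.60, 0.02, 0.61), (0.70, 0.02, 0.69), (0.74, 0.02, 0.83),  # 3rd: late bloomer
        (0.78, 0.05, 0.82), (0.85, 0.02, 0.86), (0.65, 0.03, 0.66)]
print(false_cancel_rate(runs, incumbent=0.80, margin=0.0))
print(tune_margin(runs, incumbent=0.80, margins=[0.0, 0.02, 0.05, 0.10]))  # (0.05, 0.0, 0.5)
```

With a zero margin the late bloomer is cancelled (one of three winners lost); 0.05 is the smallest candidate that spares it while still cancelling half the runs.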

Don't-miss checklist

  • Make every trial report (step, value) so a stopper has a curve to read.
  • Set a min-steps and min-reports floor before any cancel decision.
  • Stop on a probabilistic margin, not a point estimate, to spare slow-starters.
  • Prefer asynchronous promotion (ASHA-style) on multi-worker clusters to avoid a sync barrier.
  • Pair a model-based proposer with bandit early stopping (BOHB-style) when both are available.
  • Cancel cooperatively (the trial exits at a safe point) and record it as cancelled, not failed.
  • Keep cancellations out of the failure-rate stop guard.
  • Widen margins for non-monotonic (RL/fine-tuning) curves.

Failure modes

  • Killing a late-bloomer. Too tight a margin or too low a min-steps floor cancels a trial that would have won; loosen both.
  • Hard kill dirt. kill -9 on a remote container leaks GPU state and loses partial results; use the cooperative exit.
  • Cancelled counted as failed. Deliberate early-outs trip the failure-rate guard and stall the campaign; separate the statuses.
  • Deciding too early. Extrapolating from a few noisy points cancels good runs; enforce the floors.
  • Monotone assumption on RL curves. A plateau reads as "doomed" right before the jump; be conservative on non-monotonic workloads.
  • Overhead from over-polling. Forecasting every step costs more than it saves; bound the poll interval.

Open questions & validation

  • Which parametric curve family fits a given workload is empirical; validate the forecaster against completed runs on the target task before trusting it to cancel live ones.1
  • The safe stopping margin is workload-specific: measure the false-cancel rate on a held-out set of full runs and tune to it, rather than importing a default.
  • For strongly non-monotonic RL training, forecast-based stopping is less reliable than for supervised curves; bandit rungs with generous budgets may be the safer economy.

References

  • Domhan, Springenberg, Hutter, "Speeding Up Automatic Hyperparameter Optimization of Deep Neural Networks by Extrapolation of Learning Curves," IJCAI 2015: https://www.ijcai.org/Abstract/15/487
  • Li et al., "Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization," JMLR 2018, arXiv:1603.06560: https://arxiv.org/abs/1603.06560
  • Li et al., "A System for Massively Parallel Hyperparameter Tuning" (ASHA), MLSys 2020, arXiv:1810.05934: https://arxiv.org/abs/1810.05934
  • Falkner, Klein, Hutter, "BOHB: Robust and Efficient Hyperparameter Optimization at Scale," ICML 2018, arXiv:1807.01774: https://arxiv.org/abs/1807.01774 · code: https://github.com/automl/HpBandSter
  • Akiba et al., "Optuna: A Next-generation Hyperparameter Optimization Framework" (pruning contract), KDD 2019, arXiv:1907.10902: https://arxiv.org/abs/1907.10902
  • Ray Tune schedulers (ASHA/Hyperband/PBT implementations): https://docs.ray.io/en/latest/tune/api/schedulers.html

Related: Autonomous experimentation loops · Evaluation integrity & anti-gaming · Goodput: useful throughput · GPU consumption models · SLOs: training platform · GRPO post-training recipe · Glossary


  1. Domhan, Springenberg, Hutter fit a weighted ensemble of parametric curve models to the initial part of a learning curve with a probabilistic model, extrapolate to the full budget, and terminate runs predicted to underperform the current best, reporting a substantial speedup of deep-net hyperparameter optimization. IJCAI 2015. 

  2. Li et al. cast hyperparameter optimization as a pure-exploration bandit over resource allocation; Hyperband runs Successive Halving across brackets with different budget/count tradeoffs and reports speedups of five-fold to over an order of magnitude versus Bayesian optimization on deep-learning and kernel problems. JMLR 2018, arXiv:1603.06560. 

  3. Li et al., ASHA promotes a configuration asynchronously as soon as it ranks in the top 1/eta of the configurations already evaluated at its rung, instead of waiting for the rung to complete, removing the synchronization bottleneck and scaling near-linearly to hundreds of workers. MLSys 2020, arXiv:1810.05934. 

  4. Falkner, Klein, Hutter, BOHB combines Bayesian-optimization proposals with Hyperband's bandit-based early stopping for strong anytime performance and fast convergence across a range of problems. ICML 2018, arXiv:1807.01774. 

  5. Akiba et al., Optuna's pruning contract is cooperative: the trial calls report(intermediate_value, step) and raises TrialPruned() when should_prune() is true, so the trial code stops itself rather than being killed by the framework. KDD 2019, arXiv:1907.10902.