
Runbook: RL checkpoint/resume validation

Scope: certify that a resumed RL post-training job is statistically the same experiment it was before the interruption. This is the validity layer that runs after a resume mechanically succeeds. Locating a complete checkpoint and restarting the job are covered by the checkpoint-recovery runbook. This layer matters most for a managed RL training API. There, preemption is routine, and an invalid resume silently changes a customer's result without raising an error anywhere.

Run this when designing the checkpoint schema for an RL training service, after any framework or engine upgrade that touches saved state, or whenever a resumed run's reward/KL/loss curves diverge from the pre-crash trajectory. Severity: correctness of a paid customer artifact. A run that resumes incorrectly burns the remaining GPU budget on a different experiment than the one the customer started.

Reference templates are written against real APIs but are not validated here; pin versions and validate them before production use.

An RL trainer carries much more resumable state than a supervised run: beyond weights and optimizer moments there are rollout buffers with a generation-time policy version, a pinned reward model or verifier, a frozen KL reference, and running reward/advantage normalizers (PPO, the GRPO recipe). Losing any of these does not crash the resume; it forks the experiment. The disaggregated-trainer background is in async RL systems and delta weight sync; run metadata and artifact pinning live in experiment tracking and the model registry.
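Of those items, the running normalizers are the easiest to forget, because losing them raises nothing. The sketch below makes the cost concrete. It assumes a Welford-style running mean/variance for rewards. The `RunningNormalizer` class and its `state_dict`/`load_state_dict` methods are hypothetical, named after the PyTorch convention. The sketch compares a restored normalizer with a forgotten one on the same post-resume reward:

```python
import math


class RunningNormalizer:
    """Welford running mean/variance; count, mean and m2 are resumable state."""

    def __init__(self) -> None:
        self.count, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def normalize(self, x: float) -> float:
        # Population std; with fewer than two samples, fall back to unit scale.
        std = math.sqrt(self.m2 / self.count) if self.count > 1 else 1.0
        return (x - self.mean) / (std + 1e-8)

    def state_dict(self) -> dict:
        return {"count": self.count, "mean": self.mean, "m2": self.m2}

    def load_state_dict(self, state: dict) -> None:
        self.count, self.mean, self.m2 = state["count"], state["mean"], state["m2"]


# Pre-crash reward history, then a checkpoint of the normalizer state.
norm = RunningNormalizer()
for r in [4.8, 5.1, 5.0, 4.9, 5.2, 5.0]:
    norm.update(r)
saved = norm.state_dict()

restored = RunningNormalizer()
restored.load_state_dict(saved)
cold = RunningNormalizer()  # a resume whose save path never wrote the normalizer

print(round(restored.normalize(5.3), 2), round(cold.normalize(5.3), 2))  # 2.32 5.3
```

The restored normalizer places a reward of 5.3 about 2.3 standard deviations above the pre-crash mean. The cold one passes it through unscaled. The first post-resume advantages are therefore computed under a different scale than the rest of the run.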

Trigger

  • A resumed run's curves break trajectory: reward, KL, or loss shows a step discontinuity or a transient at the resume boundary that the pre-crash curve did not predict (learning-curve extrapolation gives the band-fitting tools).
  • Designing or changing the checkpoint schema for a training service: new field, new framework version, new engine pairing between trainer and rollout workers.
  • A framework upgrade changed what state_dict contains (optimizer, scheduler, RNG handling), or the rollout engine changed sampling behavior.
  • Two "identical" resumed runs diverge more than seed-level noise explains.

Pre-checks

  • The mechanical resume already works: a complete checkpoint was located, all shards are present, and the job restarts. That path is covered by the checkpoint-recovery runbook. This runbook assumes it and asks whether the result is valid.
  • The pre-crash metric log survived. The replay gate (step 3) needs the recorded reward/KL/loss series for the steps between the checkpoint and the crash; confirm the metrics store retained them (experiment tracking).
  • Artifact pins are recorded per run: reward-model version, verifier version, KL-reference weights hash, dataset revision. If the run record does not pin them, fix that before debugging any single resume.
  • Decide the reproducibility tier the run class needs (step 5): bitwise for debugging runs, statistical for customer runs. The checks below differ in cost between the two.
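The pin pre-check can run as code against the run record. This is a minimal sketch. It assumes the record is a dict with a `pins` mapping, and it treats `latest`, `main`, `head` and empty strings as moving tags. The key names, tag list and example values are all hypothetical:

```python
from __future__ import annotations

# Hypothetical pin names and moving-tag list; align them with the real run record.
REQUIRED_PINS = ("reward_model", "verifier", "kl_ref", "dataset")
FLOATING_TAGS = {"latest", "main", "head", ""}


def missing_or_floating_pins(run_record: dict) -> list[str]:
    """Return the pin names that are absent or point at a moving tag."""
    pins = run_record.get("pins", {})
    problems = []
    for name in REQUIRED_PINS:
        value = str(pins.get(name, "")).strip().lower()  # absent -> "" -> flagged
        if value in FLOATING_TAGS:
            problems.append(name)
    return problems


record = {"run_id": "run-123",
          "pins": {"reward_model": "rm-v7", "verifier": "latest",
                   "kl_ref": "sha256:9f2c", "dataset": "rev-2024-05"}}
print(missing_or_floating_pins(record))  # ['verifier']
```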

Flow

flowchart TB
    A["Resume requested"] --> B{"Checkpoint schema version matches?"}
    B -->|"no"| R1["Reject: never load silently"]
    B -->|"yes"| C{"Sentinel digests match?"}
    C -->|"no: corrupt"| R2["Fall back one checkpoint"]
    C -->|"yes"| D["Restore full state inventory"]
    D --> E["Replay overlap window vs recorded metrics"]
    E -->|"within tolerance band"| F["Resume certified: continue past crash point"]
    E -->|"diverges"| G["Inventory gap: find the state item that was lost"]
    G --> R2
    F --> H["Log pins + schema version to the run audit record"]
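The flowchart reads as a short-circuiting gate. In this minimal sketch (all names hypothetical), each check is passed as a callable. The expensive replay therefore runs only after the cheap schema and digest checks pass. A corrupt digest and a replay divergence lead to the same action:

```python
from __future__ import annotations

from enum import Enum
from typing import Callable


class Decision(Enum):
    REJECT_SCHEMA = "reject: never load silently"
    FALL_BACK = "fall back one checkpoint"
    CERTIFIED = "resume certified: continue past crash point"


def resume_decision(schema_ok: Callable[[], bool],
                    digests_ok: Callable[[], bool],
                    replay_ok: Callable[[], bool]) -> Decision:
    """Evaluate the gates in flowchart order; later gates run only if earlier ones pass."""
    if not schema_ok():
        return Decision.REJECT_SCHEMA
    if not digests_ok():
        return Decision.FALL_BACK          # corrupt checkpoint
    if not replay_ok():
        return Decision.FALL_BACK          # inventory gap: some state item was lost
    return Decision.CERTIFIED


calls: list[str] = []


def replay() -> bool:
    calls.append("replay")                 # stands in for the costly overlap replay
    return True


print(resume_decision(lambda: False, lambda: True, replay).name, calls)  # REJECT_SCHEMA []
print(resume_decision(lambda: True, lambda: True, replay).name, calls)   # CERTIFIED ['replay']
```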

Procedure

  1. Inventory the state the checkpoint must carry, and what losing each item costs. Review the trainer's save path against this list; anything absent is a latent resume fork:

    | State item | If lost on resume |
    |---|---|
    | Model weights | Not a resume at all; always present |
    | Optimizer moments (Adam m/v) | Silent cold-restart transient: the first steps take distorted update sizes, curves dip or spike, then recover into a different trajectory |
    | LR scheduler step | Wrong learning rate from step one; warmup replays or decay skips |
    | Per-rank RNG streams (python/numpy/torch, CPU and CUDA) | Different rollouts and dropout; statistically acceptable for RL, but bit-reproducibility is gone (see step 5) |
    | Prompt-sampler / dataloader position | Prompts repeat or skip; effective epoch count silently changes |
    | Rollout buffer + the policy version that generated it | Replayed samples are off-policy relative to the resumed weights; see step 4 |
    | Reward-model / verifier version pin | A different judge scores the rest of the run: a different experiment |
    | KL-reference weights | The KL penalty is measured against a different anchor; divergence budget silently moves |
    | Running normalizers (reward/advantage statistics) | Early post-resume advantages are mis-scaled; a shorter transient than lost optimizer state, same silent shape |
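The table doubles as a save-time test. This minimal sketch assumes a flat checkpoint dict. Its top-level keys follow the step 2 template, plus hypothetical `rollout_buffer`, `policy_version` and `normalizers` keys. It reports each missing item with its cost:

```python
from __future__ import annotations

# Hypothetical key names: match them to whatever the trainer's save path writes.
REQUIRED_STATE = {
    "model": "not a resume at all",
    "optim": "cold-restart transient in Adam moments",
    "sched": "wrong learning rate from step one",
    "rng": "bit-reproducibility lost",
    "sampler_pos": "prompts repeat or skip",
    "rollout_buffer": "off-policy samples on resume",
    "policy_version": "cannot judge buffer staleness",
    "pins": "different judge or KL anchor",
    "normalizers": "mis-scaled early advantages",
}


def inventory_gaps(state: dict) -> dict[str, str]:
    """Map each missing state item to what losing it costs."""
    return {key: cost for key, cost in REQUIRED_STATE.items() if key not in state}


ckpt = {"model": {}, "optim": {}, "sched": {}, "rng": {}, "sampler_pos": 0,
        "pins": {}, "rollout_buffer": [], "policy_version": 40}
for key, cost in inventory_gaps(ckpt).items():
    print(f"missing {key}: {cost}")  # missing normalizers: mis-scaled early advantages
```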
  2. Engineer the write side so an invalid checkpoint cannot look valid. Three properties, all cheap at save time (reference template):

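    # Reference template: write-side validity. Atomic rename, versioned schema,
    # sentinel digests of a few tensors recorded at save time.
    import os
    import random

    import numpy as np
    import torch

    SCHEMA_VERSION = 3                              # example value; bump on any layout change


    def save_checkpoint(final_path, model, optimizer, scheduler, sampler, pins):
        state = {
            "schema_version": SCHEMA_VERSION,
            "model": model.state_dict(),
            "optim": optimizer.state_dict(),
            "sched": scheduler.state_dict(),
            "rng": {"python": random.getstate(), "numpy": np.random.get_state(),
                    "torch": torch.get_rng_state(),
                    "cuda": torch.cuda.get_rng_state_all()},  # per rank: save on every rank
            "sampler_pos": sampler.state_dict(),    # assumes a stateful sampler
            "pins": pins,                           # reward_model, verifier, kl_ref, dataset
            "sentinels": {name: float(p.detach().float().abs().sum())
                          for name, p in list(model.named_parameters())[:4]},
        }
        tmp_path = final_path + ".tmp"
        torch.save(state, tmp_path)                 # write to a temp name
        os.replace(tmp_path, final_path)            # atomic publish, never half-visible
        # Recent PyTorch defaults torch.load to weights_only=True, which rejects the
        # numpy RNG state; load these trusted files with weights_only=False.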
    

    On load, reject a schema-version mismatch outright and recompute the sentinel sums before trusting anything else. For sharded saves, torch.distributed.checkpoint writes a metadata file that marks a complete save. The schema version and sentinels still belong in the saved state. The completeness check is in the checkpoint-recovery runbook.
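The load side mirrors the write side. This minimal sketch runs without a GPU stack, using `pickle` and NumPy arrays in place of `torch.save` and tensors. All names and the schema value are hypothetical. It rejects a schema mismatch and recomputes the sentinels before returning any state. Only load pickles the service wrote itself, because unpickling an untrusted file can execute code.

```python
from __future__ import annotations

import os
import pickle
import tempfile

import numpy as np

SCHEMA_VERSION = 2  # hypothetical; bump on any layout change


def sentinel_sums(params: dict[str, np.ndarray], k: int = 4) -> dict[str, float]:
    """Abs-sums of the first k tensors in sorted-name order (a cheap digest)."""
    return {name: float(np.abs(params[name]).sum()) for name in sorted(params)[:k]}


def save(path: str, params: dict[str, np.ndarray]) -> None:
    state = {"schema_version": SCHEMA_VERSION, "params": params,
             "sentinels": sentinel_sums(params)}
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
        f.flush()
        os.fsync(f.fileno())          # bytes on disk before the name is published
    os.replace(tmp, path)             # atomic publish, never half-visible


def load(path: str) -> dict[str, np.ndarray]:
    with open(path, "rb") as f:
        state = pickle.load(f)
    if state.get("schema_version") != SCHEMA_VERSION:
        raise ValueError(f"incompatible checkpoint schema: {state.get('schema_version')}")
    if sentinel_sums(state["params"]) != state["sentinels"]:
        raise ValueError("sentinel digest mismatch: checkpoint corrupt")
    return state["params"]


with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "step40.pkl")
    save(ckpt, {"layer0.w": np.ones((2, 2)), "layer0.b": np.zeros(2)})
    # Simulate silent corruption: change a value but keep the stored sentinels.
    with open(ckpt, "rb") as f:
        state = pickle.load(f)
    state["params"]["layer0.w"][0, 0] = 5.0
    with open(ckpt, "wb") as f:
        pickle.dump(state, f)
    try:
        load(ckpt)
    except ValueError as err:
        print(err)  # sentinel digest mismatch: checkpoint corrupt
```

Sentinels only cover the sampled tensors. Corruption elsewhere passes this check. It has to be caught by the replay gate or by a full-content hash, which costs more at save time.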

  3. Gate every resume with the replay-overlap check. The service keeps the metric log up to the crash; the checkpoint sits some steps earlier. Resume from the checkpoint, replay the overlap window, and compare against the recorded curve before letting the run continue past the crash point. With the full inventory and per-step RNG streams restored, the overlap replays exactly on deterministic kernels and within the tolerance band otherwise. A lost state item shows up as divergence inside the window. The gate logic is executed and asserted here. It uses simulated Adam-like dynamics, so it validates the check, not a framework:

    # resume_check.py - validated: the replay-overlap gate that certifies a resume as
    # statistically valid. Simulated Adam-like dynamics on a quadratic, seeded per-step
    # RNG streams; it validates the CHECK, not any training framework.
    from __future__ import annotations
    
    import numpy as np
    
    SCHEMA = 2
    
    
    def step_noise(step: int, seed: int = 7) -> float:
        """Per-step RNG stream keyed by step index, so a replay sees identical noise."""
        return float(np.random.default_rng(seed * 100_003 + step).normal(0.0, 0.05))
    
    
    def train(w: float, m: float, v: float, start: int, steps: int,
              lr: float = 0.2) -> tuple[list[float], float, float, float]:
        """Adam-like updates on loss = w^2 with per-step noise; returns the metric series."""
        losses: list[float] = []
        for t in range(start, start + steps):
            g = 2.0 * w + step_noise(t)
            m = 0.9 * m + 0.1 * g
            v = 0.999 * v + 0.001 * g * g
            w -= lr * m / (np.sqrt(v) + 1e-8)
            losses.append(w * w)
        return losses, w, m, v
    
    
    def checkpoint(w: float, m: float, v: float, step: int) -> dict:
        """Versioned snapshot with a sentinel digest computed at save time."""
        return {"schema": SCHEMA, "w": w, "m": m, "v": v, "step": step,
                "digest": round(w + m + v, 12)}
    
    
    def load(ck: dict) -> tuple[float, float, float, int]:
        """Fail fast on schema drift or a corrupted sentinel; never load silently."""
        if ck.get("schema") != SCHEMA:
            raise ValueError(f"incompatible checkpoint schema: {ck.get('schema')}")
        # Explicit raise, not assert: `python -O` strips asserts and would skip this check.
        if round(ck["w"] + ck["m"] + ck["v"], 12) != ck["digest"]:
            raise ValueError("sentinel digest mismatch")
        return ck["w"], ck["m"], ck["v"], ck["step"]
    
    
    def replay_valid(recorded: list[float], replayed: list[float], tol: float) -> bool:
        """The resume gate: the replayed overlap window must track the recorded curve."""
        assert len(recorded) == len(replayed) and recorded
        return max(abs(a - b) for a, b in zip(recorded, replayed)) <= tol
    
    
    # Reference run: steps 0..49, checkpoint at step 40, crash at step 50. The metric
    # log for steps 40..49 survives the crash; the checkpoint is the resume point.
    losses, w, m, v = train(1.0, 0.0, 0.0, start=0, steps=40)
    ck = checkpoint(w, m, v, 40)
    tail, *_ = train(w, m, v, start=40, steps=10)
    recorded = losses + tail                                # metric log up to the crash
    
    # 1) Full-state resume: the replayed overlap window matches the log exactly.
    w2, m2, v2, s2 = load(ck)
    replay_full, *_ = train(w2, m2, v2, start=s2, steps=10)
    assert replay_valid(recorded[40:50], replay_full, tol=1e-12)
    
    # 2) Resume that silently dropped optimizer moments: warm Adam becomes a cold
    # restart, the overlap window diverges, and the gate rejects the resume.
    replay_cold, *_ = train(w2, 0.0, 0.0, start=s2, steps=10)
    assert not replay_valid(recorded[40:50], replay_cold, tol=1e-3)
    dev = max(abs(a - b) for a, b in zip(recorded[40:50], replay_cold))
    
    # 3) Schema drift is rejected at load, never absorbed.
    try:
        load({**ck, "schema": 1})
        raise SystemExit("schema mismatch must be rejected")
    except ValueError:
        pass
    
    print("full-state replay: max deviation 0.0 over the 10-step overlap window")
    print(f"cold-optimizer replay: max deviation {dev:.4f} > tol 1e-3, resume rejected")
    print("all resume-validation assertions passed")
    

    Output of the run: full-state replay deviates by exactly 0.0 over the overlap window; the cold-optimizer replay deviates by 0.2960 against a 1e-3 tolerance and is rejected; the schema mismatch raises. In production the tolerance is a configurable band, not a constant: for bitwise-tier runs it is zero on deterministic kernels, and for statistical-tier runs a band fitted from the run's own step-to-step metric noise is a reasonable starting point (rule of thumb, not a measured universal).
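The tolerance choice can live next to the gate. This minimal sketch assumes the statistical band is k times the standard deviation of step-to-step changes in the pre-checkpoint metric log. The function names and the default k = 3 are assumptions to tune, not measured values:

```python
from __future__ import annotations

import numpy as np


def replay_tolerance(tier: str, pre_window: list[float], k: float = 3.0) -> float:
    """Tolerance band for the replay-overlap gate.

    bitwise: zero, meaningful only on deterministic kernels.
    statistical: k * std of step-to-step changes in the pre-checkpoint log.
    """
    if tier == "bitwise":
        return 0.0
    if tier != "statistical":
        raise ValueError(f"unknown tier: {tier}")
    diffs = np.diff(np.asarray(pre_window, dtype=float))
    if diffs.size < 2:
        raise ValueError("need at least 3 logged steps to estimate step noise")
    return float(k * diffs.std(ddof=1))


def gate(recorded: list[float], replayed: list[float], tol: float) -> bool:
    """Replayed overlap window must stay within tol of the recorded curve."""
    if len(recorded) != len(replayed) or not recorded:
        raise ValueError("overlap windows must be non-empty and the same length")
    return max(abs(a - b) for a, b in zip(recorded, replayed)) <= tol


pre = [0.50, 0.52, 0.49, 0.51, 0.50, 0.53, 0.48]   # hypothetical pre-checkpoint loss log
tol = replay_tolerance("statistical", pre)
print(f"statistical band: {tol:.4f}")
print(gate([0.47, 0.49], [0.50, 0.46], tol))
```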

  4. Apply the RL staleness rules before continuing past the crash point.

    • Rollout buffer: the buffer saved with the checkpoint holds samples generated by the checkpoint's policy version or earlier. Anything generated after the checkpoint was written came from weights the resume will recompute. If the trainer is strictly on-policy, discard all post-checkpoint samples; they belong to steps the resume will redo. An async trainer that already tolerates bounded lag can keep samples within its staleness budget and importance-correct as it normally does (async RL systems).
    • Reward model and verifier: reload exactly the pinned version. A "latest" tag here is how a resumed run silently becomes a new experiment (reward-model training, experiment tracking).
    • KL reference: confirm the reference weights hash matches the pin; a re-materialized reference from a different base snapshot moves the divergence anchor.
    • Rollout engine weights: after resume, push the restored trainer weights to the rollout fleet before generating; a fleet still serving pre-crash weights re-creates the staleness the checkpoint just resolved (delta weight sync).
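The buffer and KL-reference rules are mechanical enough to encode. This minimal sketch assumes two things:

- each rollout records the policy version that generated it;
- the KL reference is pinned as a SHA-256 of its serialized weights.

The `Rollout` type, field names and digest format are hypothetical. A real digest would hash tensors in a fixed key order.

```python
from __future__ import annotations

import hashlib
from dataclasses import dataclass


@dataclass
class Rollout:
    prompt_id: int
    policy_version: int  # version of the weights that generated the sample


def filter_buffer(buffer: list[Rollout], ckpt_version: int,
                  max_lag: int = 0) -> list[Rollout]:
    """Keep only samples the resumed trainer may still use.

    Samples newer than the checkpoint came from weights the resume will
    recompute, so they are always dropped. max_lag=0 is strictly on-policy;
    an async trainer passes its own staleness budget and still applies its
    usual importance correction downstream.
    """
    return [r for r in buffer
            if r.policy_version <= ckpt_version
            and ckpt_version - r.policy_version <= max_lag]


def weights_digest(blob: bytes) -> str:
    return "sha256:" + hashlib.sha256(blob).hexdigest()


def check_kl_reference(ref_blob: bytes, pinned: str) -> None:
    """Refuse to continue if the KL reference is not the pinned one."""
    if weights_digest(ref_blob) != pinned:
        raise ValueError("KL reference does not match its pin: divergence anchor moved")


buf = [Rollout(0, 38), Rollout(1, 39), Rollout(2, 40), Rollout(3, 41)]
print([r.prompt_id for r in filter_buffer(buf, ckpt_version=40)])             # [2]
print([r.prompt_id for r in filter_buffer(buf, ckpt_version=40, max_lag=1)])  # [1, 2]

ref = b"serialized reference weights"          # stand-in for the real tensors
check_kl_reference(ref, weights_digest(ref))   # matches its pin, so no error
```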
  5. Choose the reproducibility tier per run class, and price it.

    • Statistically valid (the product default): full state inventory plus the replay gate on metric curves. Costs almost nothing at runtime and survives world-size changes and nondeterministic kernels.
    • Bitwise reproducible (debugging and dispute resolution): additionally requires deterministic kernels, fixed world size and batching order, and per-rank RNG restoration; expect a throughput penalty from deterministic algorithms and document it as a debug mode, not the default (PyTorch's reproducibility notes cover the switches).

Verification

  • Overlap window certified: the replayed steps between checkpoint and crash track the recorded series within the configured band, and the certification result (window, tolerance, deviation) is written to the run's audit record.
  • Curves continuous past the crash point: reward, KL, and entropy resume their pre-crash trajectory with no unexplained transient over the next observation window (learning-curve extrapolation).
  • Pins logged: schema version, reward-model version, KL-reference hash, and dataset revision recorded for the resumed segment; the run record shows one experiment, not two.
  • Rollout fleet synchronized: rollout workers confirm the post-resume weight version before the first new generation (delta weight sync).
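The certification result is worth writing as one structured record rather than a free-text log line. This minimal sketch assumes a JSON-lines audit log. The field names are hypothetical, and the example numbers reuse the simulated cold-optimizer replay from step 3:

```python
from __future__ import annotations

import json
import time


def certification_record(run_id: str, ckpt_step: int, crash_step: int,
                         tolerance: float, max_deviation: float,
                         schema_version: int, pins: dict[str, str]) -> str:
    """One resume certification as a JSON line for the run's audit log."""
    record = {
        "event": "resume_certification",
        "run_id": run_id,
        "window": [ckpt_step, crash_step],     # replayed steps ckpt_step..crash_step-1
        "tolerance": tolerance,
        "max_deviation": max_deviation,
        "certified": max_deviation <= tolerance,
        "schema_version": schema_version,
        "pins": pins,
        "recorded_at": time.time(),
    }
    return json.dumps(record, sort_keys=True)


line = certification_record(
    "run-123", ckpt_step=40, crash_step=50, tolerance=1e-3, max_deviation=0.2960,
    schema_version=2,
    pins={"reward_model": "rm-v7", "kl_ref": "sha256:9f2c", "dataset": "rev-2024-05"})
print(json.loads(line)["certified"])   # False: this resume goes to rollback
```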

Rollback

  • Gate failure: fall back one checkpoint and re-run the gate there (the checkpoint-recovery runbook owns the fallback mechanics). Repeated failures across checkpoints mean the save path is dropping a state item; fix the writer, not the resume.
  • No checkpoint passes: mark the run non-resumable and invoke the service policy for restart or credit; a managed API must define this outcome explicitly rather than resuming invalid state (SLOs for the training platform).
  • Schema migration gone wrong: keep the old loader available behind the version check so pre-migration checkpoints remain loadable; never mutate old checkpoints in place.
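The first and third rollback rules combine into a small loader. This minimal sketch uses hypothetical loaders keyed by schema version:

- an unknown version is still rejected outright;
- a known older version is migrated in memory, on a copy;
- the search walks checkpoints newest to oldest until one passes the gate, and returns None when none does.

```python
from __future__ import annotations

from typing import Callable


def load_v1(raw: dict) -> dict:
    state = dict(raw)                                  # copy: never mutate the original
    state["normalizers"] = state.pop("reward_stats")   # hypothetical v1 key name
    state["schema_version"] = 2
    return state


def load_v2(raw: dict) -> dict:
    return dict(raw)


# Old loaders stay registered so pre-migration checkpoints remain loadable.
LOADERS: dict[int, Callable[[dict], dict]] = {1: load_v1, 2: load_v2}


def load_any(raw: dict) -> dict:
    loader = LOADERS.get(raw.get("schema_version"))
    if loader is None:
        raise ValueError(f"no loader for schema {raw.get('schema_version')}")
    return loader(raw)


def first_resumable(checkpoints: list[dict],
                    gate: Callable[[dict], bool]) -> dict | None:
    """Walk newest to oldest; return the first checkpoint that passes the gate."""
    for raw in sorted(checkpoints, key=lambda c: c["step"], reverse=True):
        try:
            state = load_any(raw)
        except (ValueError, KeyError):
            continue                                   # unloadable: try the next one
        if gate(state):
            return state
    return None  # non-resumable: hand off to the service restart/credit policy


ckpts = [
    {"schema_version": 2, "step": 60, "normalizers": {}, "ok": False},  # fails the gate
    {"schema_version": 1, "step": 40, "reward_stats": {}, "ok": True},
    {"schema_version": 9, "step": 80},                                  # unknown schema
]
chosen = first_resumable(ckpts, gate=lambda s: s["ok"])
print(chosen["step"], chosen["schema_version"])  # 40 2
print("reward_stats" in ckpts[1])                # True: the old checkpoint is untouched
```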

References

  • PyTorch Distributed Checkpoint (sharded save/load, resharding): https://pytorch.org/docs/stable/distributed.checkpoint.html
  • PyTorch reproducibility notes (RNG state, deterministic algorithms and their cost): https://pytorch.org/docs/stable/notes/randomness.html
  • torchrun elastic (resume at a different world size): https://docs.pytorch.org/docs/stable/elastic/run.html
  • Mohan et al., CheckFreq: Frequent, Fine-Grained DNN Checkpointing (FAST '21; low-overhead checkpoint cadence and resume correctness): https://www.usenix.org/conference/fast21/presentation/mohan

Related: Checkpoint recovery · Async RL systems · Delta weight sync · GRPO post-training recipe · PPO · Reward model training · Experiment tracking and model registry · Learning-curve extrapolation · SLOs: training platform · Operational runbooks