
Runbook: checkpoint recovery / resume

Scope: recover and resume a distributed training job from its last good checkpoint after a crash, preemption, or hardware fault.

Run this to resume a distributed training job after it stopped: a crash, a preemption (spot/maintenance), or a node loss. Severity: every minute down burns GPU-hours; the goal is a correct resume (loss continues, no silent data repetition), not just a process that restarts. At scale, node failure is steady-state, so this is routine, not exceptional.

The commands below are reference templates built on real APIs; pin versions and validate them before production use.

This is the longform procedure for RB-7 in operational runbooks. The checkpoint strategy (sharded, async, tiered via DCP) is in storage and data; launchers, elastic rendezvous, and what a correct resume must restore are in distributed training; the failure math (checkpoint interval < MTBF) is in reliability and RAS.

Trigger

Pre-checks

  • Classify the failure first. Was it the infrastructure (node, GPU, or fabric: drain the bad node before resuming and follow the GPU-fault or NCCL-hang runbook) or the job (code, OOM, or divergence: resuming onto the same bug just re-crashes)? Do not resume onto known-bad hardware.
  • Capacity available: enough healthy GPUs to meet the world size, or the job is configured elastic (torchrun elastic resizes the world on failure, distributed training). If short, decide: wait for capacity, or resume elastic at a smaller world size.
  • Storage reachable: the checkpoint tier (shared/parallel FS or object store, storage and data) is mounted and healthy on the surviving/replacement nodes.

Flow

stateDiagram-v2
    [*] --> Classify
    Classify --> DrainNode: infra fault
    Classify --> Locate: job fault or preemption
    DrainNode --> Locate: bad node out of pool
    Locate --> Verify_shards: last COMPLETE checkpoint
    Verify_shards --> Fallback: shard missing or corrupt
    Verify_shards --> Resume: all shards present
    Fallback --> Resume: previous checkpoint
    Resume --> Watch_loss: weights, optimizer, RNG, data
    Watch_loss --> Fallback: loss spike
    Watch_loss --> [*]: loss continuous

Procedure

A correct resume restores weights + optimizer state + RNG + dataloader position, not just weights (distributed training). Restoring only weights silently repeats data and resets the optimizer, corrupting the run.
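The toy below illustrates why all four pieces matter. It is plain Python with no PyTorch, and every name in it is hypothetical: one scalar weight trained with momentum SGD on a fixed data order, with RNG noise in the gradient. Resuming from the full state reproduces the uninterrupted run exactly; restoring only the weight replays data and resets the momentum buffer and RNG, so the run drifts.

```python
import random
from dataclasses import dataclass, replace

# Hypothetical toy "dataset", consumed in a fixed order.
DATA = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 3.0, -2.0]


@dataclass(frozen=True)
class ToyState:
    """Stand-in for a training checkpoint (all names hypothetical)."""
    weight: float     # model weights
    momentum: float   # optimizer state (momentum buffer)
    rng_state: tuple  # RNG snapshot from random.Random.getstate()
    step: int         # dataloader position: index of the next sample


def fresh_state(seed: int = 0) -> ToyState:
    return ToyState(0.0, 0.0, random.Random(seed).getstate(), 0)


def train(state: ToyState, n_steps: int, lr: float = 0.05, beta: float = 0.9) -> ToyState:
    """Momentum SGD on a scalar with a noisy gradient; returns the state to checkpoint."""
    rng = random.Random()
    rng.setstate(state.rng_state)  # restore RNG exactly where the checkpoint left it
    w, m, step = state.weight, state.momentum, state.step
    for _ in range(n_steps):
        target = DATA[step % len(DATA)]
        grad = 2.0 * (w - target) + rng.gauss(0.0, 0.01)  # noise stands in for dropout etc.
        m = beta * m + grad
        w -= lr * m
        step += 1
    return ToyState(w, m, rng.getstate(), step)


reference = train(fresh_state(), 8)   # uninterrupted run
ckpt = train(fresh_state(), 4)        # "crash" after 4 steps, full state saved
full_resume = train(ckpt, 4)          # restore everything, run the last 4 steps
weights_only = train(replace(fresh_state(), weight=ckpt.weight), 4)  # optimizer, RNG, data reset

print(full_resume == reference)                 # True: identical continuation
print(weights_only.weight == reference.weight)  # False: replayed data, reset momentum and RNG
```

In a real job the same four fields map to the model state dict, the optimizer state dict, the framework RNG states, and the sampler position.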

CKPT_DIR=/shared/ckpt/run-1234        # sharded DCP checkpoint root
  1. Locate the last COMPLETE sharded checkpoint and verify its integrity: every shard is present, and it is not a half-written checkpoint whose save was interrupted by the same failure (storage and data). With torch.distributed.checkpoint (DCP) each rank writes its own shard, so a complete checkpoint has the full set of shards plus its .metadata file:

    ls -1dt "$CKPT_DIR"/step_* | head -3              # newest checkpoints
    LATEST=$(ls -1dt "$CKPT_DIR"/step_* | head -1)
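    ls -A "$LATEST"                                    # -A lists dotfiles: expect .metadata + every rank shard
    # sanity: shard count matches the world size that wrote it; metadata present and non-empty
    test -s "$LATEST/.metadata" || echo "INCOMPLETE: no .metadata in $LATEST"
    ls "$LATEST"/*.distcp | wc -l                      # assumes DCP's default shard naming (__<rank>_<n>.distcp)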
    
    A directory missing .metadata or short on shards is incomplete: skip it and treat the previous one as latest (Rollback).

  2. Resume the job, restoring the full state from $LATEST. Confirm the world size matches the one that wrote the checkpoint; if it differs (for example an elastic resume at reduced capacity, distributed training), DCP reshards the state on load. DCP's load reads the sharded state back into the model and optimizer; the training loop then restores RNG state and advances the dataloader to the saved position:

    # elastic launcher: world size may differ from the crash; rendezvous reforms the group
    torchrun --nnodes=1:<max> --nproc-per-node=8 \
      --rdzv-backend=c10d --rdzv-endpoint=<head>:29400 \
      train.py --resume-from "$LATEST"
    # train.py must: dcp.load(state) -> set RNG state -> fast-forward dataloader to saved step
    
    If resuming at a different world size, only a resharding-aware checkpoint (DCP) loads cleanly; a flat torch.save of rank-0 state will not.

  3. Confirm the loss continues: read the first logged steps after resume. The loss should pick up at roughly the pre-failure value and keep descending. A loss spike at resume signals a bad restore (missing optimizer state, wrong RNG, or skipped or repeated data): stop and fall back (distributed training).
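Step 1's completeness check can be scripted. A minimal sketch, assuming DCP's default FileSystemWriter layout (a `.metadata` file plus `__<rank>_<n>.distcp` shard files) and the `step_<N>` naming used above; `is_complete` and `latest_complete` are hypothetical helpers, and `expected_shards` is the world size that wrote the checkpoint. It orders candidates by parsed step number rather than mtime, so a copied or touched directory cannot jump the queue.

```python
import re
from pathlib import Path
from typing import Optional

STEP_DIR = re.compile(r"^step_(\d+)$")  # matches the step_<N> naming used in this runbook


def is_complete(ckpt: Path, expected_shards: int) -> bool:
    """Non-empty .metadata plus at least expected_shards non-empty .distcp files.

    Assumed layout: DCP's default FileSystemWriter, roughly one .distcp file per rank
    with default writer settings; expected_shards is the world size that wrote it.
    """
    meta = ckpt / ".metadata"
    if not meta.is_file() or meta.stat().st_size == 0:
        return False
    shards = [p for p in ckpt.iterdir() if p.suffix == ".distcp" and p.stat().st_size > 0]
    return len(shards) >= expected_shards


def latest_complete(ckpt_dir: Path, expected_shards: int) -> Optional[Path]:
    """Newest complete step_* directory, ordered by parsed step number rather than mtime."""
    candidates = []
    for entry in ckpt_dir.iterdir():
        match = STEP_DIR.match(entry.name)
        if match and entry.is_dir():
            candidates.append((int(match.group(1)), entry))
    for _, path in sorted(candidates, reverse=True):
        if is_complete(path, expected_shards):
            return path
        print(f"skipping incomplete checkpoint: {path.name}")
    return None
```

Call it as `latest_complete(Path("/shared/ckpt/run-1234"), world_size)` and pass the result to `--resume-from`; a `None` result means no usable checkpoint exists under that root.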
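Step 2's "fast-forward dataloader to saved step" only works if the data order is a deterministic function of (seed, epoch). A minimal sketch of that idea in plain Python, with hypothetical names: each epoch's order comes from a seeded shuffle, and resuming computes (epoch, offset) from the saved global step instead of replaying batches. It assumes drop-last semantics and a single consumer; a real multi-rank loader also has to hand each rank its slice of the batch, which this omits.

```python
import random
from typing import Iterator, List


def epoch_order(n_samples: int, epoch: int, seed: int) -> List[int]:
    """Deterministic per-epoch shuffle: the same (seed, epoch) always yields the same order."""
    order = list(range(n_samples))
    random.Random(seed * 1_000_003 + epoch).shuffle(order)
    return order


def batches_from(global_step: int, n_samples: int, batch_size: int, seed: int) -> Iterator[List[int]]:
    """Yield sample-index batches starting at global_step, with no replayed or skipped samples.

    Assumes drop-last: each epoch contributes n_samples // batch_size full batches.
    """
    per_epoch = n_samples // batch_size
    epoch, offset = divmod(global_step, per_epoch)  # where the checkpoint left off
    while True:
        order = epoch_order(n_samples, epoch, seed)
        for b in range(offset, per_epoch):
            yield order[b * batch_size:(b + 1) * batch_size]
        epoch, offset = epoch + 1, 0  # later epochs start from their first batch
```

The design choice is to store only the global step in the checkpoint and derive the position from it, so the data position cannot drift from the step counter that the Verification checks read.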
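Step 3 can be turned into a simple gate. A minimal sketch with a hypothetical helper and a hypothetical tolerance (`spike_ratio`, 20% here): it compares the mean of the first post-resume losses with the mean of the last pre-failure losses and fails the resume if the jump exceeds the tolerance. Tune the tolerance to the run's normal step-to-step noise.

```python
from statistics import mean
from typing import Sequence


def resume_looks_healthy(pre_losses: Sequence[float], post_losses: Sequence[float],
                         spike_ratio: float = 1.2) -> bool:
    """True if the post-resume loss window stays within spike_ratio of the pre-failure window.

    spike_ratio is an assumed tolerance, not a recommended value.
    """
    if not pre_losses or not post_losses:
        raise ValueError("need losses from both sides of the resume boundary")
    return mean(post_losses) <= spike_ratio * mean(pre_losses)
```

A `False` result maps to the Watch_loss --> Fallback transition in the flow above.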

Verification

  • Loss curve continuous across the resume boundary: no step discontinuity, no spike; the post-resume loss tracks the pre-failure trajectory.
  • Tokens-seen consistent: the global step / tokens-consumed counter resumes from the checkpoint's value, and the dataloader is at the matching position (no replayed or skipped shards, storage and data).
  • Throughput back to baseline: MFU returns to the run's normal band; if it does not, suspect the replacement node's fabric/parallelism layout (performance tuning, the NCCL-hang runbook), not the resume.
  • Checkpoint cadence still < MTBF: after a node loss the effective MTBF may be lower than planned; confirm the interval is still short enough that the next failure is cheap (reliability and RAS).
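The tokens-seen and cadence checks reduce to arithmetic. A minimal sketch with hypothetical helper names; it assumes a fixed global batch size and sequence length for the whole run, and uses the first-order estimate that each failure costs about half a checkpoint interval of recompute plus restart time. `young_interval_h` is Young's classic approximation for the checkpoint interval, sqrt(2 × checkpoint-write-time × MTBF), offered as one common rule of thumb rather than this runbook's policy.

```python
import math


def tokens_consistent(resumed_step: int, resumed_tokens: int, ckpt_step: int,
                      global_batch: int, seq_len: int) -> bool:
    """Counters resume at the checkpoint's step, and tokens seen = steps x global batch x seq len.

    Assumes a fixed global batch size and sequence length for the whole run.
    """
    return resumed_step == ckpt_step and resumed_tokens == ckpt_step * global_batch * seq_len


def expected_lost_hours(interval_h: float, mtbf_h: float, restart_h: float, run_h: float) -> float:
    """First-order estimate: each failure costs ~half an interval of recompute plus restart time."""
    if interval_h >= mtbf_h:
        raise ValueError("checkpoint interval must stay below MTBF")
    failures = run_h / mtbf_h
    return failures * (interval_h / 2 + restart_h)


def young_interval_h(ckpt_write_h: float, mtbf_h: float) -> float:
    """Young's approximation for the checkpoint interval: sqrt(2 * write time * MTBF)."""
    return math.sqrt(2 * ckpt_write_h * mtbf_h)
```

For example, a 3-minute (0.05 h) checkpoint write and a 10-hour MTBF give an interval of about 1 hour; with that interval, a 15-minute restart, and a 100-hour run, the estimate is about 7.5 hours of lost work.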

Rollback

If the latest checkpoint is corrupt or its resume produces a loss spike, fall back one checkpoint: resume from the previous complete step_* and accept the small amount of recomputed work (this is exactly why cadence is sized below MTBF, reliability and RAS):

PREV=$(ls -1dt "$CKPT_DIR"/step_* | sed -n '2p')     # second-newest
ls -A "$PREV"                                         # -A shows .metadata; confirm complete before resuming
torchrun ... train.py --resume-from "$PREV"

If multiple recent checkpoints are corrupt, suspect the storage path (a write storm or FS fault corrupting shards, storage and data) and validate the filesystem before trusting any further checkpoint.
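The `sed -n '2p'` pick takes the second-newest directory by mtime, whether or not it is complete. A minimal sketch of a stricter walk-back, with hypothetical helper names and an assumed DCP default layout (non-empty `.metadata` plus `.distcp` shards): it skips steps already known to be bad (their resume spiked the loss), skips incomplete ones, and reports the incomplete steps it passed so repeated corruption points at the storage path. The threshold of two in `storage_suspect` is an assumption.

```python
import re
from pathlib import Path
from typing import List, Optional, Set, Tuple

STEP_DIR = re.compile(r"^step_(\d+)$")


def _complete(ckpt: Path, expected_shards: int) -> bool:
    # Assumed DCP default layout: non-empty .metadata plus >= one .distcp file per rank.
    meta = ckpt / ".metadata"
    if not meta.is_file() or meta.stat().st_size == 0:
        return False
    return sum(1 for p in ckpt.iterdir() if p.suffix == ".distcp") >= expected_shards


def pick_fallback(ckpt_dir: Path, expected_shards: int,
                  bad_steps: Set[int]) -> Tuple[Optional[Path], List[int]]:
    """Walk back from the newest step_* dir, skipping known-bad and incomplete checkpoints.

    Returns (checkpoint to resume from, steps skipped because they were incomplete).
    """
    candidates = []
    for entry in ckpt_dir.iterdir():
        match = STEP_DIR.match(entry.name)
        if match and entry.is_dir():
            candidates.append((int(match.group(1)), entry))
    incomplete: List[int] = []
    for step, path in sorted(candidates, reverse=True):
        if step in bad_steps:
            continue  # resumed before and spiked the loss
        if not _complete(path, expected_shards):
            incomplete.append(step)
            continue
        return path, incomplete
    return None, incomplete


def storage_suspect(incomplete: List[int], threshold: int = 2) -> bool:
    """Several recent incomplete checkpoints point at the storage path, not the job."""
    return len(incomplete) >= threshold
```

If `storage_suspect` returns True, validate the filesystem before resuming from the returned path.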

References

  • PyTorch Distributed Checkpoint (DCP) — sharded save/load, resharding: https://pytorch.org/docs/stable/distributed.checkpoint.html
  • torchrun elastic (rendezvous, world-size change on failure): https://docs.pytorch.org/docs/stable/elastic/run.html
  • PyTorch FSDP (state_dict, checkpoint integration): https://docs.pytorch.org/docs/stable/fsdp.html
  • NVIDIA XID errors (classify the failure that stopped the job): https://docs.nvidia.com/deploy/xid-errors/index.html

Related: Storage · Training · Reliability · GPU Fault/RMA · NCCL Hang · MFU Regression · Operational Runbooks · Glossary