Runbook: checkpoint recovery / resume
Scope: recover and resume a distributed training job from its last good checkpoint after a crash, preemption, or hardware fault.
Run this when a distributed training job has stopped because of a crash, a preemption (spot or maintenance), or a node loss. Severity: every minute of downtime burns GPU-hours, but the goal is a correct resume (loss continues, no silent data repetition), not just a process that restarts. At scale, node failure is steady-state, so this procedure is routine, not exceptional.
The commands below are reference templates built on real APIs; pin versions and validate them before production use.
This is the longform procedure for RB-7 in operational runbooks. The checkpoint strategy (sharded, async, tiered via DCP) is in storage and data; launchers, elastic rendezvous, and what a correct resume must restore are in distributed training; the failure math (checkpoint interval < MTBF) is in reliability and RAS.
Trigger
- Job crash: a rank raised an exception, hit an OOM, or was killed by an uncorrectable ECC error (reliability and RAS), or the launcher exited non-zero.
- Preemption: a spot/preemptible node was reclaimed, or a maintenance drain (the driver-upgrade runbook) evicted the job.
- Node loss: a GPU fell off the bus (XID 79), a node rebooted, or a fabric fault isolated ranks (reliability and RAS, the NCCL-hang runbook).
Pre-checks
- Classify the failure first: was it the infrastructure (node, GPU, or fabric) or the job (code, OOM, or divergence)? For an infrastructure fault, drain the bad node before resuming (the GPU-fault runbook, the NCCL-hang runbook). For a job fault, resuming onto the same bug just re-crashes. Do not resume onto known-bad hardware.
- Capacity available: enough healthy GPUs to meet the world size, or the job is configured elastic (torchrun elastic resizes the world on failure, distributed training). If short, decide: wait for capacity, or resume elastic at a smaller world size.
- Storage reachable: the checkpoint tier (shared/parallel FS or object store, storage and data) is mounted and healthy on the surviving/replacement nodes.
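The capacity decision above can be written down as a small function. This is a minimal sketch, not part of torchrun or any scheduler: `plan_resume`, `min_world_size`, and the returned labels are hypothetical names introduced for illustration.

```python
def plan_resume(healthy_gpus: int, saved_world_size: int,
                elastic: bool, min_world_size: int = 1) -> str:
    """Capacity pre-check: return 'resume', 'resume-elastic', or 'wait'.

    healthy_gpus      GPUs that passed health checks and are schedulable now
    saved_world_size  world size that wrote the checkpoint
    elastic           True if the job is configured to run at a smaller world size
    min_world_size    smallest world size the job tolerates (hypothetical knob)
    """
    if healthy_gpus >= saved_world_size:
        return "resume"            # same world size as the crashed run
    if elastic and healthy_gpus >= min_world_size:
        return "resume-elastic"    # smaller world; needs a resharding-aware checkpoint
    return "wait"                  # not enough capacity, and shrinking is not allowed


print(plan_resume(healthy_gpus=56, saved_world_size=64, elastic=True, min_world_size=32))
```

Recording the outcome alongside the incident makes it clear later whether a throughput change after resume came from a smaller world size or from something else.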
Flow
stateDiagram-v2
[*] --> Classify
Classify --> DrainNode: infra fault
Classify --> Locate: job fault or preemption
DrainNode --> Locate: bad node out of pool
Locate --> Verify_shards: last COMPLETE checkpoint
Verify_shards --> Fallback: shard missing or corrupt
Verify_shards --> Resume: all shards present
Fallback --> Resume: previous checkpoint
Resume --> Watch_loss: weights, optimizer, RNG, data
Watch_loss --> Fallback: loss spike
Watch_loss --> [*]: loss continuous
Procedure
A correct resume restores weights + optimizer state + RNG + dataloader position, not just weights (distributed training). Restoring only weights silently repeats data and resets the optimizer, corrupting the run.
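Why RNG state and dataloader position matter can be shown with a toy run. The sketch below is an illustration under simplifying assumptions, not the training loop itself: Python's `random` module stands in for the framework RNGs, a list of index batches stands in for the dataloader, and weights and optimizer state are omitted. All function names are hypothetical.

```python
import random


def make_batches(num_samples, batch_size, seed):
    """Deterministic shuffled batches for one epoch (stand-in for a dataloader)."""
    order = list(range(num_samples))
    random.Random(seed).shuffle(order)
    return [order[i:i + batch_size] for i in range(0, num_samples, batch_size)]


def train(batches, start_step, num_steps, rng):
    """Consume batches from start_step; return (sample ids seen, random draws)."""
    seen, noise = [], []
    for step in range(start_step, start_step + num_steps):
        seen.extend(batches[step])
        noise.append(rng.random())  # stand-in for dropout / augmentation randomness
    return seen, noise


batches = make_batches(num_samples=32, batch_size=4, seed=0)

# Reference: 8 uninterrupted steps.
ref_seen, ref_noise = train(batches, 0, 8, random.Random(123))

# Interrupted run: 5 steps, then a checkpoint that records step and RNG state.
rng = random.Random(123)
seen_a, noise_a = train(batches, 0, 5, rng)
checkpoint = {"step": 5, "rng_state": rng.getstate()}

# Full-state resume: restore RNG, continue from the saved step.
resumed_rng = random.Random()
resumed_rng.setstate(checkpoint["rng_state"])
seen_b, noise_b = train(batches, checkpoint["step"], 3, resumed_rng)
full_resume_matches = (seen_a + seen_b == ref_seen) and (noise_a + noise_b == ref_noise)

# Weights-only resume: step counter and RNG start over, so early data is replayed.
seen_bad, _ = train(batches, 0, 3, random.Random(123))
repeated = set(seen_a) & set(seen_bad)

print("full resume matches reference:", full_resume_matches)
print("samples replayed by weights-only resume:", len(repeated))
```

In a real job the same idea applies to each RNG in use (framework CPU and device generators, NumPy, Python) and to the sampler or dataset iterator state.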
- Locate the last COMPLETE sharded checkpoint and verify integrity: all shards present, not a half-written one from a checkpoint that was interrupted by the same failure (storage and data). With torch.distributed.checkpoint (DCP) each rank writes its own shard, so a complete checkpoint has the full set plus its metadata:

      ls -1dt "$CKPT_DIR"/step_* | head -3             # newest checkpoints
      LATEST=$(ls -1dt "$CKPT_DIR"/step_* | head -1)
      ls -A "$LATEST"                                  # expect .metadata + every rank shard (-A lists the dotfile)
      # sanity: shard count matches the world size that wrote it; metadata present and non-empty

  A directory missing .metadata or short on shards is incomplete: skip it and treat the previous one as latest (Rollback).
- Resume the job, restoring the full state from $LATEST. Confirm the world size matches what wrote the checkpoint, or run elastic so DCP reshards on load (distributed training). DCP's load reads the sharded state back into model + optimizer; the training loop restores RNG and advances the dataloader to the saved position:

      # elastic launcher: world size may differ from the crash; rendezvous reforms the group
      torchrun --nnodes=1:<max> --nproc-per-node=8 \
        --rdzv-backend=c10d --rdzv-endpoint=<head>:29400 \
        train.py --resume-from "$LATEST"
      # train.py must: dcp.load(state) -> set RNG state -> fast-forward dataloader to saved step

  If resuming at a different world size, only a resharding-aware checkpoint (DCP) loads cleanly; a flat torch.save of rank-0 state will not.
- Confirm the loss continues: read the first logged steps after resume. The loss should pick up at roughly the pre-failure value and keep descending. A loss spike at resume is the signal of a bad restore (missing optimizer state, wrong RNG, or skipped/repeated data), so stop and fall back (distributed training).
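The completeness check in the first procedure step can be automated. The sketch below rests on stated assumptions: checkpoint directories are named `step_<N>`, shard files end in `.distcp`, and there is exactly one shard file per rank (writer settings can change that, so confirm it for your job). `is_complete` and `latest_complete` are hypothetical helpers. The sketch orders by step number rather than mtime, and checks presence and size only, not content integrity.

```python
import re
from pathlib import Path


def is_complete(ckpt: Path, expected_shards: int) -> bool:
    """True if .metadata is present and non-empty and the shard count matches."""
    meta = ckpt / ".metadata"
    if not meta.is_file() or meta.stat().st_size == 0:
        return False
    shards = list(ckpt.glob("*.distcp"))   # assumption: one .distcp file per rank
    if len(shards) != expected_shards:
        return False
    return all(s.stat().st_size > 0 for s in shards)


def step_of(ckpt: Path) -> int:
    """Parse N from a 'step_<N>' directory name; -1 if the name does not match."""
    match = re.fullmatch(r"step_(\d+)", ckpt.name)
    return int(match.group(1)) if match else -1


def latest_complete(ckpt_dir: Path, expected_shards: int):
    """Newest complete checkpoint by step number, or None if there is none."""
    candidates = [p for p in ckpt_dir.glob("step_*") if p.is_dir() and step_of(p) >= 0]
    for ckpt in sorted(candidates, key=step_of, reverse=True):
        if is_complete(ckpt, expected_shards):
            return ckpt
    return None
```

For example, `latest_complete(Path(ckpt_dir), expected_shards=world_size)` returns the directory to pass as the resume target. It skips any newer directory that the failure left half-written, which is the same fallback the Rollback section performs by hand.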
Verification
- Loss curve continuous across the resume boundary: no step discontinuity, no spike; the post-resume loss tracks the pre-failure trajectory.
- Tokens-seen consistent: the global step / tokens-consumed counter resumes from the checkpoint's value, and the dataloader is at the matching position (no replayed or skipped shards, storage and data).
- Throughput back to baseline: MFU returns to the run's normal band; if it does not, suspect the replacement node's fabric/parallelism layout (performance tuning, the NCCL-hang runbook), not the resume.
- Checkpoint cadence still < MTBF: a node loss is evidence that the effective MTBF may be lower than planned; confirm the interval is still short enough that the next failure is cheap (reliability and RAS).
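The loss-continuity check can be made mechanical. The sketch below is one possible heuristic, not a standard: it compares the first post-resume losses against the mean and spread of the last pre-failure losses. The function name, `window`, and `tolerance` are hypothetical and need tuning to the run's noise level.

```python
from statistics import mean, stdev


def resume_loss_ok(pre_losses, post_losses, window=20, tolerance=3.0):
    """True if the early post-resume losses stay near the pre-failure level.

    Flags a spike when any of the first `window` post-resume losses exceeds
    mean + tolerance * stdev of the last `window` pre-failure losses.
    """
    pre = pre_losses[-window:]
    post = post_losses[:window]
    if len(pre) < 2 or not post:
        raise ValueError("need at least 2 pre-failure and 1 post-resume loss values")
    baseline = mean(pre)
    spread = max(stdev(pre), 1e-8)   # guard against a perfectly flat window
    return max(post) <= baseline + tolerance * spread


pre = [2.10, 2.08, 2.09, 2.07, 2.06]
print(resume_loss_ok(pre, [2.06, 2.05, 2.07]))   # continuous resume
print(resume_loss_ok(pre, [2.90, 2.70, 2.60]))   # spike: treat as a bad restore
```

A failed check is the trigger for the Rollback path; a passing check does not replace confirming that the tokens-seen counter matches the checkpoint.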
Rollback
If the latest checkpoint is corrupt or its resume produces a loss spike, fall back one checkpoint: resume from the previous complete step_* and accept the small amount of recomputed work (this is exactly why cadence is sized below MTBF, reliability and RAS):
PREV=$(ls -1dt "$CKPT_DIR"/step_* | sed -n '2p') # second-newest
ls -A "$PREV" # confirm complete before resuming (-A lists .metadata)
torchrun ... train.py --resume-from "$PREV"
If multiple recent checkpoints are corrupt, suspect the storage path (a write storm or FS fault corrupting shards, storage and data) and validate the filesystem before trusting any further checkpoint.
Related runbooks
- the GPU-fault runbook: GPU fault, drain, reset, RMA (drain the bad node before resuming).
- the NCCL-hang runbook: NCCL hang / collective stall (a stalled job often ends here: drain the node, resume from checkpoint).
- the MFU-regression runbook: Training MFU regression (if throughput does not return to baseline after resume).
- operational runbooks: Operational runbooks index (RB-7).
References
- PyTorch Distributed Checkpoint, DCP (sharded save/load, resharding): https://pytorch.org/docs/stable/distributed.checkpoint.html
- torchrun elastic (rendezvous, world-size change on failure): https://docs.pytorch.org/docs/stable/elastic/run.html
- PyTorch FSDP (state_dict, checkpoint integration): https://docs.pytorch.org/docs/stable/fsdp.html
- NVIDIA XID errors (classify the failure that stopped the job): https://docs.nvidia.com/deploy/xid-errors/index.html