Activation checkpointing and memory offloading
Scope: trade compute and bandwidth for HBM capacity. Recompute activations in the backward pass (activation/gradient checkpointing), and offload activations, parameters, gradients, and optimizer state to CPU/NVMe, with selective policies and FSDP/DeepSpeed integration.
What it is
Two distinct techniques that relieve GPU memory pressure by spending a different resource.
Activation (gradient) checkpointing. Instead of storing every intermediate activation from the forward pass for reuse in the backward pass, you store only the inputs to a checkpointed region and recompute the intermediate activations on the fly during backward, only when the gradient needs them. This cuts peak activation memory at the cost of an extra forward pass over the checkpointed region. PyTorch automates it with torch.utils.checkpoint: wrap a layer or sequence of layers and their forward activations are not retained.1,4
Memory offloading. Move tensors that do not need to be resident on the GPU (infrequently accessed parameters, gradients, optimizer state, or activations) to CPU DRAM or NVMe, and stream them back just-in-time. This trades interconnect bandwidth (and latency) for capacity. DeepSpeed ZeRO-Infinity (training) and ZeRO-Inference (inference) automate layer-by-layer prefetch from CPU or NVMe, overlapping transfers with compute.2,8
The two compose: checkpointing reduces the activation footprint you have to keep or offload; offloading handles parameters/optimizer state that checkpointing does not touch.
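The store-inputs-and-recompute idea can be checked end to end on a toy chain. The sketch below is not PyTorch's implementation: the scalar tanh layers, the helper names (grad_store_all, grad_checkpointed), and the hand-written chain rule are all assumptions made for illustration. It shows that recomputing a segment from its boundary input yields the same gradient while fewer activations are live at once.

```python
import math


def layer(w, x):
    """One toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_grad(w, x):
    """dy/dx of the toy layer, evaluated from the layer's input x."""
    return w * (1.0 - math.tanh(w * x) ** 2)


def grad_store_all(weights, x0):
    """Baseline: keep every layer input from the forward pass."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)  # one stored activation per layer
        x = layer(w, x)
    g = 1.0
    for w, xin in zip(reversed(weights), reversed(inputs)):
        g *= layer_grad(w, xin)
    return g, len(inputs)


def grad_checkpointed(weights, x0, segment):
    """Keep only segment-boundary inputs; recompute interiors in backward."""
    saved, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            saved[i] = x  # boundary input: the only thing kept
        x = layer(w, x)
    g, peak_live = 1.0, 0
    for start in sorted(saved, reverse=True):
        end = min(start + segment, len(weights))
        # Rematerialize this segment's layer inputs from its boundary.
        inputs, x = [], saved[start]
        for w in weights[start:end]:
            inputs.append(x)
            x = layer(w, x)
        # All boundaries stay resident, plus this segment's interior inputs.
        peak_live = max(peak_live, len(saved) + len(inputs) - 1)
        for w, xin in zip(reversed(weights[start:end]), reversed(inputs)):
            g *= layer_grad(w, xin)
    return g, peak_live


weights = [0.5 + 0.1 * i for i in range(16)]
g_full, stored_full = grad_store_all(weights, 0.3)
g_ckpt, stored_ckpt = grad_checkpointed(weights, 0.3, segment=4)
print(g_full, g_ckpt, stored_full, stored_ckpt)
```

With 16 layers and a segment size of 4, the baseline holds 16 layer inputs while the checkpointed version peaks at 7 (4 boundaries plus 3 interior inputs), at the cost of running each segment's forward a second time.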
Architecture
Both techniques sit on the same axis: give up a cheap, abundant resource (compute FLOPS, interconnect bandwidth) to buy a scarce one (HBM capacity). The forward/backward data flow makes the trade concrete.
flowchart LR
FWD["Forward pass"] --> KEEP["Store checkpoint<br/>boundary inputs only"]
FWD --> DROP["Discard intermediate<br/>activations (saves HBM)"]
KEEP --> SAVE["Lower peak<br/>HBM usage"]
BWD["Backward pass<br/>(needs activations)"] --> RECOMP["Recompute discarded<br/>activations (extra forward FLOPS)"]
DROP -.-> RECOMP
BWD --> FETCH["Stream offloaded tensors<br/>back from CPU/NVMe"]
OFFLOAD["Offload params / grads /<br/>optimizer state to CPU / NVMe"] -.->|"PCIe / NVLink-C2C bandwidth"| FETCH
RECOMP --> COST["Extra cost:<br/>compute + interconnect"]
FETCH --> COST
In the forward pass the checkpointed region stores only its boundary inputs and drops the intermediate activations, which is where the HBM saving comes from. In the backward pass those dropped activations are recomputed on demand (one extra partial forward), and any offloaded parameters, gradients, or optimizer state are streamed back from CPU/NVMe over PCIe or NVLink-C2C. Both paths pay in a resource other than the scarce one: extra compute for recompute, extra interconnect traffic for offload. Whether that cost is free or fatal depends entirely on overlap, which the timing model in How to run it in production makes precise.
Why use it
Activations dominate training memory for long sequences and deep stacks. For a large LLM it is frequently impossible to store all intermediate activations for backprop without an OOM. Without checkpointing the only levers are smaller batch size or fewer experts/layers; checkpointing lets you keep the larger batch and fit the model.1
The trade is favorable on current hardware because modern GPUs supply abundant FLOPS relative to HBM capacity, so there is compute headroom to absorb the extra recomputation. The book frames it directly: "you exchange some of the GPU's ample compute FLOPS to overcome its limited HBM capacity."1 This is the same memory-vs-compute axis discussed in Roofline Model and Arithmetic Intensity and GPU Memory Hierarchy. Recompute raises arithmetic intensity of the backward region while shrinking its working set.
Offloading extends the same logic past HBM entirely: trillion-parameter models can stage components on NVMe and swap into GPU memory just-in-time, provided transfers overlap compute so the training loop never stalls.2
The size of the win is not vague. For a chain of n uniform layers, checkpointing at segment boundaries turns peak activation memory from O(n) (store everything) into O(sqrt(n)) at the optimal segment size, in exchange for recomputing the interior layers once. The block below is a self-contained model of that trade, checked against a brute-force reference across every segment size and against the sqrt-n optimum found by search (numpy only, runnable as-is):
import numpy as np


def peak_activation_memory(n_layers, segment_size):
    """Peak activation memory (in per-layer units) when checkpointing a chain
    of n_layers with checkpoints every segment_size layers. Peak is hit on the
    first segment: all k saved boundaries stay resident, plus the (segment_size
    - 1) interior activations recomputed for the segment in flight."""
    if segment_size < 1 or segment_size > n_layers:
        raise ValueError("segment_size must be in [1, n_layers]")
    k = int(np.ceil(n_layers / segment_size))  # saved boundaries
    return k + segment_size - 1


def recompute_layer_forwards(n_layers, segment_size):
    """Extra forward passes in backward: every non-boundary layer once."""
    if segment_size < 1:
        raise ValueError("segment_size must be >= 1")
    return max(n_layers - int(np.ceil(n_layers / segment_size)), 0)


def brute_force_peak(n_layers, segment_size):
    """Slow independent reference: all boundary checkpoints stay live for the
    whole backward pass; each recomputed segment adds its interior transiently."""
    boundaries = list(range(0, n_layers, segment_size))
    k = len(boundaries)
    peak = 0
    for start in reversed(boundaries):
        interior = min(start + segment_size, n_layers) - start - 1
        peak = max(peak, k + interior)
    return peak


n = 100
seg = int(round(np.sqrt(n)))  # sqrt rule -> 10
peak = peak_activation_memory(n, seg)
assert peak < n, f"checkpointing must beat the {n}-unit baseline, got {peak}"
assert peak == 19, f"expected 19 units at seg={seg}, got {peak}"

# Equivalence to the reference across ALL segment sizes.
for s in range(1, n + 1):
    assert peak_activation_memory(n, s) == brute_force_peak(n, s), f"mismatch at seg={s}"

# The O(sqrt(n)) minimum: search agrees with the sqrt rule, bound holds.
peaks = {s: peak_activation_memory(n, s) for s in range(1, n + 1)}
best = min(peaks, key=peaks.get)
assert abs(best - seg) <= 1, f"optimum {best} must be near sqrt(n)={seg}"
assert peaks[best] <= 2 * np.sqrt(n) + 1, "peak at optimum must be O(sqrt(n))"

# Recompute is strictly less than one full extra pass over all layers.
assert recompute_layer_forwards(n, seg) == n - int(np.ceil(n / seg)) < n

# Adversarial boundaries: seg=1 stores all, recomputes none; seg=n saves nothing.
assert peak_activation_memory(n, 1) == n and recompute_layer_forwards(n, 1) == 0
assert peak_activation_memory(n, n) == n
for bad in (0, n + 1, -5):
    try:
        peak_activation_memory(n, bad)
        raised = False
    except ValueError:
        raised = True
    assert raised, f"segment_size={bad} must raise"

print(f"peak={peak} vs baseline {n} ({n / peak:.1f}x), "
      f"recompute={recompute_layer_forwards(n, seg)} layer-forwards, optimum seg={best}")
Running it prints peak=19 vs baseline 100 (5.3x), recompute=90 layer-forwards, optimum seg=10: at 100 layers the sqrt-optimal segment cuts peak activation memory 5.3x for the price of recomputing 90 layer-forwards (less than one full extra pass). Real transformer layers are not perfectly uniform, so treat the sqrt point as the starting guess, then measure.
When to use it (and when not)
Use activation checkpointing when:
- Peak activation memory (not parameters) is the OOM cause: long context, deep transformer/MoE stacks, large microbatch.
- You are willing to pay roughly one extra forward per checkpointed region for materially lower peak memory.
Checkpoint selectively, not everything. A common strategy is to checkpoint only the transformer blocks (which hold the bulk of activations) and leave small layers (layer norms, embeddings) uncheckpointed. This captures most of the memory savings with minimal recompute overhead.1 Full recompute of cheap pointwise ops wastes compute for little gain; selective activation checkpointing (SAC) addresses exactly this (see Selective activation checkpointing (SAC) under How to use it).5
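The "roughly one extra forward" cost can be turned into a step-time multiplier. The sketch below assumes the backward pass costs about twice the forward pass, a common rule of thumb rather than a measurement of any particular model; recompute_overhead and its parameters are hypothetical names introduced for this illustration.

```python
def recompute_overhead(recomputed_fraction, backward_to_forward=2.0):
    """Step-compute multiplier when `recomputed_fraction` of the forward FLOPS
    is executed a second time during backward.

    Assumption: backward costs `backward_to_forward` times the forward pass
    (2.0 is a rule of thumb, not a measured value)."""
    if not 0.0 <= recomputed_fraction <= 1.0:
        raise ValueError("recomputed_fraction must be in [0, 1]")
    forward = 1.0
    backward = backward_to_forward * forward
    baseline = forward + backward
    return (baseline + recomputed_fraction * forward) / baseline


full = recompute_overhead(1.0)        # every block fully recomputed
selective = recompute_overhead(0.25)  # only a quarter of forward FLOPS recomputed
print(f"full recompute: {full:.3f}x, selective: {selective:.3f}x")
```

Under that assumption, recomputing everything costs about a third more compute per step, while a policy that recomputes a quarter of the forward FLOPS costs under a tenth more. Measured overhead depends on overlap and kernel efficiency, so this is only a first estimate.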
Use offloading when:
- The model state (parameters + gradients + optimizer state) exceeds aggregate GPU memory even after sharding, or you want a larger batch than HBM allows.
- You have interconnect bandwidth to hide the transfers: NVLink-C2C on Grace-Hopper/Grace-Blackwell superchips, or PCIe/NVMe with enough headroom and overlap.2
Prefer not to offload when transfers cannot be hidden behind compute. Offloading onto a slow PCIe link without overlap converts a memory problem into a throughput collapse. The book notes that NVIDIA Unified Memory and NVMe swap can paper over capacity but introduce unpredictable paging stalls; explicit, managed offload is usually preferable for predictable performance. Treat NVMe reached through OS swapping as a last resort, not as a primary backing store for Unified Memory.2
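A capacity check is the first filter for the offload decision above. The sketch below assumes 16 bytes of model state per parameter (bf16 parameters and gradients plus fp32 Adam state, the usual mixed-precision accounting); choose_offload_tier, its arguments, and the example capacities are hypothetical and only illustrate the arithmetic.

```python
GIB = 1024 ** 3
# Assumption: 2 B bf16 params + 2 B bf16 grads + 12 B fp32 Adam state per parameter.
BYTES_PER_PARAM = 16


def choose_offload_tier(n_params, n_gpus, hbm_for_state_gib, cpu_dram_gib):
    """Return the shallowest tier that can hold the fully sharded model state.

    hbm_for_state_gib is the per-GPU HBM left for model state after reserving
    room for activations and workspace; cpu_dram_gib is total host DRAM."""
    state_bytes = n_params * BYTES_PER_PARAM
    hbm_bytes = n_gpus * hbm_for_state_gib * GIB
    if state_bytes <= hbm_bytes:
        return "none"   # sharding alone is enough
    if state_bytes <= hbm_bytes + cpu_dram_gib * GIB:
        return "cpu"    # spill the remainder to host DRAM
    return "nvme"       # host DRAM is not enough either


for n_params in (7_000_000_000, 70_000_000_000, 1_000_000_000_000):
    tier = choose_offload_tier(n_params, n_gpus=8, hbm_for_state_gib=60, cpu_dram_gib=1024)
    print(f"{n_params / 1e9:.0f}B params -> offload tier: {tier}")
```

Capacity only says whether offload is needed; whether it is affordable still depends on the link bandwidth and overlap, which this check deliberately ignores.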
How to use it
Wrap the layer (or block) you want to rematerialize. The book wraps blocks with torch.utils.checkpoint.1 The official PyTorch docs say use_reentrant should be passed explicitly and recommend use_reentrant=False: the non-reentrant variant records the autograd graph, stops recomputation once the needed activations are available, and supports all backward APIs, whereas use_reentrant=True always recomputes the full region and has known limitations. The docs state that PyTorch 2.9 raises an exception if the flag is omitted.4
Reference template (requires torch, not installed here). The numpy model in Why use it already validates the underlying memory/recompute trade this wrapper realizes.
# reference template - requires torch
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def forward(self, x):
        # ... attention + MLP ...
        return x


def forward_with_checkpointing(blocks, x):
    for block in blocks:
        # Recompute this block's activations in backward instead of storing them.
        # use_reentrant=False is the recommended (non-reentrant) variant.
        x = checkpoint(block, x, use_reentrant=False)
    return x
Selective activation checkpointing (SAC)
All-or-nothing recompute over-pays for cheap ops. SAC (PyTorch 2.5 prototype) recomputes cheap ops while force-saving expensive ones (matmuls, attention, convolutions) via a policy function. This is not in the book; it is the current recommended way to tune the recompute-vs-memory point per op.5
Reference template (requires torch, not installed here):
# reference template - requires torch
from functools import partial

import torch
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

# Ops whose outputs are expensive to recompute, so their results are saved.
compute_intensive = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def policy_fn(ctx, op, *args, **kwargs):
    # Save expensive ops; recompute the rest.
    if op in compute_intensive:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


# fn and args are placeholders for the region to checkpoint and its inputs.
out = checkpoint(
    fn,
    *args,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
Under torch.compile, an automatic alternative is the activation memory budget: a value in [0.0, 1.0] (0.0 = plain checkpointing behavior, 1.0 = default compile behavior). The knob is cited here as torch._dynamo.config.activation_memory_budget; recent PyTorch sources define activation_memory_budget in torch._functorch.config, so confirm the module path against your installed version. The compiler then applies a Pareto-optimal SAC policy for that budget. Experimental, compile-only.5 See torch.compile: Graph Capture, Backends, and Recompiles.
The "save expensive ops, recompute cheap ones" rule is not a heuristic hint: under a memory budget it is exactly a 0/1 knapsack (weight = stored activation bytes, value = recompute FLOPS avoided), and the Pareto-optimal split is found by dynamic programming. The numpy block below builds that optimum, checks it against brute force over 3000 randomized instances, and shows that a naive density-ordered greedy is strictly suboptimal on some of them, so the exact policy matters (numpy only, runnable as-is):
import itertools

import numpy as np


def sac_cost(save, mem, flops):
    """(peak stored memory, recompute FLOPS) for a boolean save decision per op."""
    save = np.asarray(save, dtype=bool)
    mem, flops = np.asarray(mem, float), np.asarray(flops, float)
    return float(mem[save].sum()), float(flops[~save].sum())


def optimal_sac_policy(mem, flops, budget):
    """Which ops to force-SAVE under a memory budget to minimize recompute FLOPS.
    Equivalent to maximizing recompute-avoided (value=flops) subject to stored
    memory (weight=mem) <= budget: a 0/1 knapsack, solved exactly by DP.
    Requires integer memory units (activation-tile counts)."""
    mem = np.asarray(mem)
    flops = np.asarray(flops, float)
    if not np.allclose(mem, np.round(mem)):
        raise ValueError("DP requires integer memory units")
    mem = mem.astype(int)
    cap, n = int(budget), len(mem)
    dp = np.zeros(cap + 1)
    keep = [[False] * (cap + 1) for _ in range(n)]
    for i in range(n):
        w, v = int(mem[i]), flops[i]
        for c in range(cap, w - 1, -1):
            if dp[c - w] + v > dp[c] + 1e-12:
                dp[c], keep[i][c] = dp[c - w] + v, True
    save = np.zeros(n, dtype=bool)
    c = int(np.argmax(dp))
    for i in range(n - 1, -1, -1):
        if keep[i][c]:
            save[i] = True
            c -= int(mem[i])
    return save


def greedy_sac_policy(mem, flops, budget):
    """Density-ordered greedy (save highest FLOPS-per-byte first). Fast but NOT
    optimal for 0/1 knapsack; included only to show DP can strictly beat it."""
    mem, flops = np.asarray(mem, float), np.asarray(flops, float)
    save, used = np.zeros(len(mem), dtype=bool), 0.0
    for i in np.argsort(-(flops / np.maximum(mem, 1e-12))):
        if used + mem[i] <= budget + 1e-9:
            save[i], used = True, used + mem[i]
    return save


def brute_force(mem, flops, budget):
    best = None
    for bits in itertools.product([False, True], repeat=len(mem)):
        peak, rc = sac_cost(bits, mem, flops)
        if peak <= budget + 1e-9 and (best is None or rc < best[1]):
            best = (np.array(bits), rc)
    return best


# attn, mlp matmul (expensive), layernorm, gelu, dropout (cheap).
names = ["attn", "mlp_mm", "layernorm", "gelu", "dropout"]
mem = [4, 4, 1, 2, 2]
flops = [50, 40, 1, 2, 1]
budget = 9  # baseline full-save = 13
save = optimal_sac_policy(mem, flops, budget)
peak, rc = sac_cost(save, mem, flops)
assert peak <= budget, f"over budget: {peak} > {budget}"
assert abs(rc - brute_force(mem, flops, budget)[1]) < 1e-9, "DP must equal optimum"
saved = {names[i] for i in range(5) if save[i]}
recomputed = {names[i] for i in range(5) if not save[i]}
assert {"attn", "mlp_mm"} <= saved, f"must save the matmuls, got {saved}"
assert recomputed <= {"layernorm", "gelu", "dropout"}, f"recompute cheap ops: {recomputed}"

# Strictly inside the all-or-nothing frontier.
_, rc_all_recompute = sac_cost([False] * 5, mem, flops)
peak_all_save, _ = sac_cost([True] * 5, mem, flops)
assert rc < rc_all_recompute and peak < peak_all_save

# Looser budget never raises recompute (Pareto monotonicity); full budget -> 0.
prev = None
for b in [4, 6, 9, 11, 13]:
    _, r = sac_cost(optimal_sac_policy(mem, flops, b), mem, flops)
    assert prev is None or r <= prev + 1e-9, "recompute rose as budget grew"
    prev = r
assert prev == 0.0

# Adversarial: starvation budget recomputes everything.
assert sac_cost(optimal_sac_policy(mem, flops, 0), mem, flops)[1] == rc_all_recompute

# Fuzz: DP == brute force everywhere, and strictly beats greedy on some instances.
rng = np.random.default_rng(1)
dp_beats_greedy = 0
for _ in range(3000):
    n = int(rng.integers(3, 8))
    m, f = rng.integers(1, 6, n), rng.integers(1, 60, n).astype(float)
    b = int(rng.integers(0, int(m.sum()) + 1))
    _, rc_dp = sac_cost(optimal_sac_policy(m, f, b), m, f)
    _, rc_bf = sac_cost(brute_force(m, f, b)[0], m, f)
    _, rc_gr = sac_cost(greedy_sac_policy(m, f, b), m, f)
    assert abs(rc_dp - rc_bf) < 1e-9, "DP != brute-force optimum"
    assert rc_dp <= rc_gr + 1e-9, "DP must never lose to greedy"
    dp_beats_greedy += rc_dp < rc_gr - 1e-9
assert dp_beats_greedy > 0, "greedy must be strictly suboptimal somewhere"

print(f"save={sorted(saved)} recompute={sorted(recomputed)} "
      f"peak={peak:.0f}<= {budget} rc={rc:.0f} vs {rc_all_recompute:.0f} all-recompute; "
      f"DP beat greedy on {dp_beats_greedy}/3000")
Running it prints save=['attn', 'layernorm', 'mlp_mm'] recompute=['dropout', 'gelu'] peak=9<= 9 rc=3 vs 94 all-recompute; DP beat greedy on 404/3000: the optimum keeps the two matmuls (and, since it has spare budget, the tiny layernorm) and recomputes only the cheap pointwise ops, cutting recompute from 94 down to 3 units while staying at the 9-unit memory budget. This is the exact behavior the MUST_SAVE/PREFER_RECOMPUTE policy and the activation_memory_budget knob approximate at op granularity.
How to integrate: FSDP automatic checkpointing and CPU offload
FSDP (ZeRO Stage-3: shards parameters, gradients, optimizer state across GPUs) can apply activation checkpointing and offload parameters/gradients under the hood. The book wraps the model and sets the checkpointing policy plus CPU offload (FSDP1 API shown).3 Reference template (requires torch with a CUDA/NCCL build, not installed here):
# reference template - requires torch + CUDA/NCCL
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    ShardingStrategy,
    BackwardPrefetch,
    MixedPrecision,
)

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = MyModel().cuda()  # MyModel is a placeholder for your nn.Module
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    use_orig_params=True,  # expose original parameters instead of FlatParameters
    cpu_offload=CPUOffload(offload_params=True),  # params + grads to CPU when idle
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # As printed in the book. This keyword is not in the documented
    # FullyShardedDataParallel signature; verify it against your PyTorch
    # version before relying on it.
    activation_checkpointing_policy={
        nn.TransformerEncoderLayer,
        nn.TransformerDecoderLayer,
        nn.MultiheadAttention,
    },
)
optimizer = torch.optim.AdamW(fsdp_model.parameters(), weight_decay=0.01)
As the book presents it, activation_checkpointing_policy tells FSDP to rematerialize activations in those submodules, giving the same memory savings as manual checkpoint() wrappers without the boilerplate; FSDP also handles uneven per-GPU batch sizes (useful for MoE). CPUOffload(offload_params=True) moves parameters and gradients to CPU when not needed on-GPU, reducing peak GPU memory; per the PyTorch FSDP docs, offloading parameters also offloads gradients, so the optimizer step runs on CPU.3
Note on API surface: the book's FullyShardedDataParallel + CPUOffload is the FSDP1 interface, and activation_checkpointing_policy does not appear in the documented FullyShardedDataParallel constructor signature, so treat that keyword as the book's illustration and verify it against your PyTorch version. For composing checkpointing with FSDP, official PyTorch provides apply_activation_checkpointing with a check_fn (and checkpoint_wrapper / CheckpointImpl) from torch.distributed.algorithms._checkpoint.checkpoint_wrapper; naive torch.utils.checkpoint on submodules can conflict with FSDP sharding because recompute needs full parameters. Prefer the FSDP-integrated path.6 PyTorch also offers a newer fully-sharded API (torch.distributed.fsdp.fully_shard, "FSDP2"); check your PyTorch version for the supported surface.7
How to scale it
Checkpointing shrinks the per-GPU activation footprint; sharding and offload shrink the parameter, gradient, and optimizer-state footprint. As you add GPUs, the sharding strategy sets how much each GPU holds versus how much crosses the fabric (book):3
- FULL_SHARD (ZeRO-3): smallest per-GPU footprint; best with a very fast multinode fabric.
- HYBRID_SHARD: shard within a node, replicate shards across nodes; lower cross-node traffic, higher per-node memory, often higher throughput on a decent-but-not-blistering fabric.
- SHARD_GRAD_OP (ZeRO-2): shard only gradients + optimizer state, full parameter copy per GPU; for slow networks or few GPUs.
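The memory side of those three strategies can be estimated with simple arithmetic. The sketch below follows the descriptions in the list above and assumes the usual mixed-precision Adam accounting (2 bytes of bf16 parameters, 2 bytes of gradients, 12 bytes of fp32 optimizer state per parameter). per_gpu_state_bytes is a hypothetical helper, and the numbers ignore activations, buffers, and fragmentation.

```python
GIB = 1024 ** 3
# Assumed bytes per parameter: bf16 params, bf16 grads, fp32 Adam state.
PARAM_B, GRAD_B, OPTIM_B = 2, 2, 12


def per_gpu_state_bytes(n_params, strategy, world_size, gpus_per_node):
    """Approximate peak model-state bytes per GPU for an FSDP sharding strategy."""
    total = PARAM_B + GRAD_B + OPTIM_B
    if strategy == "FULL_SHARD":
        # Everything is sharded across all ranks.
        return n_params * total / world_size
    if strategy == "HYBRID_SHARD":
        # Sharded inside a node, replicated across nodes.
        return n_params * total / gpus_per_node
    if strategy == "SHARD_GRAD_OP":
        # Full parameter copy during compute; grads + optimizer state sharded.
        return n_params * (PARAM_B + (GRAD_B + OPTIM_B) / world_size)
    raise ValueError(f"unknown strategy: {strategy}")


n_params, world_size, gpus_per_node = 7_000_000_000, 16, 8
footprint = {
    s: per_gpu_state_bytes(n_params, s, world_size, gpus_per_node)
    for s in ("FULL_SHARD", "HYBRID_SHARD", "SHARD_GRAD_OP")
}
for name, size in footprint.items():
    print(f"{name:14s} {size / GIB:6.2f} GiB per GPU")
```

For this 7B-parameter, two-node example the ordering is FULL_SHARD < HYBRID_SHARD < SHARD_GRAD_OP in per-GPU memory, which is the inverse of their ordering in cross-node traffic. The right choice depends on which of the two is the binding constraint.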
For layers too large for one GPU, FSDP composes with tensor parallel (within node) and pipeline parallel (across nodes). See FSDP, Distributed Training Platform.
How to run it in production: offloading to CPU and NVMe (DeepSpeed)
For state that exceeds even sharded GPU memory, DeepSpeed ZeRO-Infinity offloads optimizer state and parameters to CPU or NVMe and streams them back. Configuration is JSON, not code (verify field names against your DeepSpeed version):9
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "/local_nvme",
      "pin_memory": true,
      "buffer_count": 4
    },
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/local_nvme",
      "pin_memory": true,
      "buffer_count": 5,
      "buffer_size": 1e8,
      "max_in_cpu": 1e9
    }
  },
  "aio": {
    "block_size": 1048576,
    "queue_depth": 8,
    "thread_count": 1,
    "single_submit": false,
    "overlap_events": true
  }
}
device may be cpu or nvme; the aio block tunes asynchronous NVMe I/O. The critical correctness/perf rule from the book holds regardless of framework: pin host buffers and use asynchronous non-blocking DMA (.to(device, non_blocking=True) / cudaMemcpyAsync) so transfers overlap compute and the loop does not stall.2,9 See DeepSpeed and ZeRO.
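Because the configuration is plain JSON, a few of the rules above can be checked before a job is launched. The sketch below is not part of DeepSpeed: check_offload_config is a hypothetical helper, the trimmed config mirrors the one on this page, and the stage-3 requirement for offload_param reflects a reading of the DeepSpeed configuration docs that should be confirmed for your version.

```python
import json

# Trimmed copy of the config shown above; the values are illustrative only.
CONFIG_TEXT = """
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {"device": "nvme", "nvme_path": "/local_nvme", "pin_memory": true},
    "offload_param": {"device": "nvme", "nvme_path": "/local_nvme", "pin_memory": true}
  }
}
"""


def check_offload_config(cfg):
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    zero = cfg.get("zero_optimization", {})
    for key in ("offload_optimizer", "offload_param"):
        section = zero.get(key)
        if section is None:
            continue
        device = section.get("device")
        if device not in ("cpu", "nvme"):
            problems.append(f"{key}.device should be 'cpu' or 'nvme', got {device!r}")
        if device == "nvme" and not section.get("nvme_path"):
            problems.append(f"{key}: device 'nvme' needs nvme_path")
        if not section.get("pin_memory", False):
            problems.append(f"{key}: pin_memory is off, so copies may not overlap compute")
    # Assumption from the DeepSpeed docs: parameter offload requires ZeRO stage 3.
    if "offload_param" in zero and zero.get("stage") != 3:
        problems.append("offload_param expects zero_optimization.stage == 3")
    return problems


problems = check_offload_config(json.loads(CONFIG_TEXT))
print(problems or "no problems found")
```

A check like this only catches structural mistakes. It cannot tell whether the NVMe device or the host link is fast enough to hide the transfers.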
Whether offload is free or ruinous reduces to one inequality: the transfer of the tensor you prefetch must finish inside the compute it overlaps, otherwise the step is transfer-bound and stalls (the same roofline logic as Roofline Model and Arithmetic Intensity). With perfect overlap the step costs max(compute_time, transfer_time). The block below makes that precise, cross-checks it against an independent event-driven simulator, tests the exact break-even bandwidth, and shows why the same 2 GiB shard is fine over NVLink-C2C but catastrophic over PCIe (numpy only, runnable as-is):
import numpy as np

GB = 1024**3


def transfer_time_s(num_bytes, bandwidth_bytes_per_s):
    if bandwidth_bytes_per_s <= 0:
        raise ValueError("bandwidth must be positive")
    return num_bytes / bandwidth_bytes_per_s


def step_time_with_offload(compute_s, num_bytes, bandwidth_bytes_per_s):
    """Per-step wall time when prefetching concurrently with compute: with full
    overlap the transfer either hides under compute (step = compute) or the
    step is transfer-bound (step = transfer). Hence max(compute, transfer)."""
    return max(compute_s, transfer_time_s(num_bytes, bandwidth_bytes_per_s))


def is_hidden(compute_s, num_bytes, bandwidth_bytes_per_s):
    return transfer_time_s(num_bytes, bandwidth_bytes_per_s) <= compute_s


def simulate_step(compute_s, num_bytes, bw, dt=1e-4):
    """Independent reference: tick a clock; compute and transfer run in
    parallel; the step ends only when BOTH complete."""
    t = done_c = done_x = 0.0
    while done_c < compute_s - 1e-15 or done_x < num_bytes - 1e-6:
        done_c = min(compute_s, done_c + dt)
        done_x = min(num_bytes, done_x + bw * dt)
        t += dt
    return t


bytes_shard = 2 * GB  # a 2 GiB optimizer/param shard to prefetch
compute = 20e-3       # 20 ms of GPU compute per step

# PCIe Gen4 x16 (~25 GB/s): 2 GiB takes 80 ms -> cannot hide, loop STALLS.
pcie = 25 * GB
assert not is_hidden(compute, bytes_shard, pcie)
t_pcie = step_time_with_offload(compute, bytes_shard, pcie)
assert abs(t_pcie - transfer_time_s(bytes_shard, pcie)) < 1e-12  # transfer-bound
assert t_pcie / compute > 3.5, "PCIe step must be several x slower than compute"

# NVLink-C2C (~900 GB/s): 2 GiB takes ~2.2 ms -> fully hidden, compute-bound.
c2c = 900 * GB
assert is_hidden(compute, bytes_shard, c2c)
assert abs(step_time_with_offload(compute, bytes_shard, c2c) - compute) < 1e-12

# Knife-edge: the exact bandwidth where transfer == compute counts as hidden;
# one byte/s slower does not.
bw_break_even = bytes_shard / compute
assert is_hidden(compute, bytes_shard, bw_break_even)
assert not is_hidden(compute, bytes_shard, bw_break_even - 1.0)

# Equivalence to the independent simulator across regimes.
for bw in (pcie, c2c, bw_break_even):
    fast = step_time_with_offload(compute, bytes_shard, bw)
    slow = simulate_step(compute, bytes_shard, bw)
    assert abs(fast - slow) / fast < 0.02, f"sim disagrees at {bw / GB:.0f} GB/s"

# Adversarial: non-positive bandwidth must raise, never return inf silently.
for bad in (0.0, -1.0):
    try:
        transfer_time_s(bytes_shard, bad)
        raised = False
    except ValueError:
        raised = True
    assert raised, f"bandwidth={bad} must raise"

print(f"PCIe: transfer={transfer_time_s(bytes_shard, pcie) * 1e3:.0f} ms -> "
      f"STALL step={t_pcie * 1e3:.0f} ms ({t_pcie / compute:.1f}x); "
      f"C2C: transfer={transfer_time_s(bytes_shard, c2c) * 1e3:.1f} ms -> hidden; "
      f"break-even={bw_break_even / GB:.0f} GB/s")
Running it prints PCIe: transfer=80 ms -> STALL step=80 ms (4.0x); C2C: transfer=2.2 ms -> hidden; break-even=100 GB/s: the identical 2 GiB prefetch is invisible on a 900 GB/s NVLink-C2C link but quadruples step time on a 25 GB/s PCIe link, and the crossover for a 20 ms step sits at 100 GB/s. The 900 GB/s figure is NVLink-C2C's aggregate bidirectional bandwidth; at roughly half that in one direction the same transfer takes under 5 ms and is still hidden. This is why the book insists on pinned memory plus async DMA and treats NVMe-via-Unified-Memory paging (which cannot guarantee overlap) as a last resort. GPUDirect Storage can page parameters directly from NVMe to GPU without CPU involvement; on superchips, NVLink-C2C makes CPU offload cheap enough that the book describes a CPU-GPU superchip offload system ("SuperOffload") overlapping speculative CPU optimizer updates with GPU backprop.2
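The single-step inequality extends to layer-by-layer streaming. The sketch below is a simplified timeline model, not DeepSpeed's scheduler: it assumes one copy engine, two staging buffers, and a prefetch distance of one layer, and the function names and timings are hypothetical. It compares fetching each layer right before it runs against prefetching the next layer while the current one computes.

```python
def serialized_time(compute_ms, transfer_ms):
    """No overlap: fetch layer i, wait, then compute layer i."""
    return sum(c + x for c, x in zip(compute_ms, transfer_ms))


def prefetched_time(compute_ms, transfer_ms):
    """One-layer-ahead prefetch with two staging buffers.

    The fetch of layer i starts once the link is free (fetch i-1 done) and the
    buffer it reuses is free (compute i-2 done). Layer i computes once the GPU
    is free (compute i-1 done) and its own fetch has finished."""
    fetch_end, comp_end = [], []
    for i, (c, x) in enumerate(zip(compute_ms, transfer_ms)):
        link_free = fetch_end[i - 1] if i >= 1 else 0.0
        buffer_free = comp_end[i - 2] if i >= 2 else 0.0
        fetch_end.append(max(link_free, buffer_free) + x)
        gpu_free = comp_end[i - 1] if i >= 1 else 0.0
        comp_end.append(max(gpu_free, fetch_end[i]) + c)
    return comp_end[-1] if comp_end else 0.0


layers = 8
compute = [20.0] * layers  # ms of GPU work per layer (illustrative)
for label, per_layer_transfer in (("fast link", 5.0), ("slow link", 80.0)):
    transfer = [per_layer_transfer] * layers
    print(f"{label}: serialized={serialized_time(compute, transfer):.0f} ms, "
          f"prefetched={prefetched_time(compute, transfer):.0f} ms")
```

On the fast link, prefetching leaves only the first fetch exposed (165 ms against 160 ms of pure compute). On the slow link the timeline is set almost entirely by the transfers (660 ms), so prefetching alone cannot rescue a transfer-bound configuration.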
How to maintain it
- Profile that recompute/transfer actually overlaps. Inspect the timeline in Nsight Systems or torch.profiler; checkpointing should lower peak memory in the memory profile, and offload transfers should sit under compute, not block it. See Profiling GPUs: Nsight Systems and Nsight Compute.
- Watch peak with torch.cuda.memory_stats() / torch.cuda.max_memory_allocated() before and after enabling each technique to quantify the saving.
- Apply one change at a time and measure; checkpointing, offload, and torch.compile's own recompute interact.
- Guard goodput: recompute and stalled transfers cost wall-clock. Track effective throughput, not just "it fits." See Goodput: Measuring Useful AI Throughput.
Failure modes
- Recompute cost swamps the saving: checkpointing cheap pointwise ops (layer norms, activations, dropout) buys almost no memory but pays a full recompute. Checkpoint transformer blocks, not everything; use SAC or the activation_memory_budget knob to force-save the matmuls and attention. The SAC block shows recompute falling from 94 units (recompute everything) to 3 units within a 9-unit memory budget.1,5
- Transfer-bound offload (throughput collapse): prefetches that cannot hide under compute make every step wait on the link. The production block shows the same 2 GiB shard going from invisible on 900 GB/s NVLink-C2C to a 4x step-time blowup on 25 GB/s PCIe. Check the tensor-bytes / bandwidth vs compute-time inequality before enabling offload, not after.2
- Synchronous or unpinned copies: forgetting pin_memory / non_blocking=True serializes the copy against compute even when bandwidth was sufficient, silently converting a hideable transfer into a stall. Pin host buffers and use async DMA.2,9
- Unified Memory / NVMe-swap paging: leaning on NVIDIA Unified Memory or OS swap to "just fit" introduces unpredictable page-fault stalls with no overlap guarantee. Prefer explicit, managed offload; treat NVMe as a last resort.2
- use_reentrant=True limitations: the reentrant variant always recomputes the whole region and does not support all backward APIs (for example, it breaks with some hooks and with grad() on non-leaf inputs). Pass use_reentrant=False; the docs state that PyTorch 2.9 raises if the flag is omitted.4
- Naive checkpoint under FSDP sharding: wrapping submodules with plain torch.utils.checkpoint conflicts with FSDP, because recompute needs the full unsharded parameters that FSDP has released. Use the FSDP-integrated apply_activation_checkpointing / checkpoint_wrapper path.6
- Wrong sharding strategy for the fabric: FULL_SHARD on a slow interconnect drowns in all-gather traffic; a full parameter replica (SHARD_GRAD_OP) on too many GPUs wastes memory. Match the strategy to fabric speed and GPU count.3
- Optimizing "it fits" instead of goodput: recompute FLOPS and stalled transfers are real wall-clock cost. A configuration that fits but halves throughput can be worse than a smaller batch that runs at full speed. Track effective throughput. See Goodput: Measuring Useful AI Throughput.
References
- Chris Fregly, AI Systems Performance Engineering (O'Reilly), Chapter 13: "Activation Checkpointing for Memory Savings," "Offloading Parameters to CPU and NVMe," "SuperOffload," and "FSDP Automatic Checkpointing and Offloading."
- PyTorch, torch.utils.checkpoint (use_reentrant semantics): https://docs.pytorch.org/docs/stable/checkpoint.html
- PyTorch blog, "Current and New Activation Checkpointing Techniques in PyTorch" (SAC, create_selective_checkpoint_contexts, CheckpointPolicy, activation_memory_budget): https://pytorch.org/blog/activation-checkpointing-techniques/
- PyTorch, torch.distributed.fsdp (FSDP / fully_shard): https://docs.pytorch.org/docs/stable/fsdp.html
- PyTorch, activation checkpointing with FSDP (apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl): https://docs.pytorch.org/docs/stable/distributed.algorithms.join.html and the FSDP tutorials at https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html
- DeepSpeed, ZeRO / ZeRO-Infinity tutorial: https://www.deepspeed.ai/tutorials/zero/
- DeepSpeed, Configuration JSON (offload_optimizer, offload_param, aio): https://www.deepspeed.ai/docs/config-json/
Related: FSDP · DeepSpeed and ZeRO · torch.compile: Graph Capture, Backends, and Recompiles · PyTorch CUDA Caching Allocator Tuning · GPU Memory Hierarchy · Distributed Training Platform · Glossary
1. Fregly, Ch. 13, "Activation Checkpointing for Memory Savings": store only region inputs and recompute intermediate activations in backward via torch.utils.checkpoint; checkpoint transformer blocks but not small layers (layer norms, embeddings); "you exchange some of the GPU's ample compute FLOPS to overcome its limited HBM capacity."
2. Fregly, Ch. 13, "Offloading Parameters to CPU and NVMe" / "SuperOffload": offload infrequently used components to CPU/NVMe with pinned memory and async DMA; ZeRO-Infinity / ZeRO-Inference stream weights layer-by-layer; overlap transfers with compute; Unified Memory paging is unpredictable, NVMe is a last resort.
3. Fregly, Ch. 13, "FSDP Automatic Checkpointing and Offloading": FSDP = ZeRO-3; activation_checkpointing_policy, CPUOffload(offload_params=True), use_orig_params=True; FULL_SHARD / HYBRID_SHARD / SHARD_GRAD_OP trade-offs; compose with TP and PP.
4. PyTorch, torch.utils.checkpoint docs: use_reentrant must be passed explicitly (exception in 2.9 if omitted); use_reentrant=False recommended (records autograd graph, stops recompute once activations available, supports all backward APIs). https://docs.pytorch.org/docs/stable/checkpoint.html
5. PyTorch blog, "Current and New Activation Checkpointing Techniques": create_selective_checkpoint_contexts(policy_fn) + CheckpointPolicy.{MUST_SAVE,PREFER_RECOMPUTE} (2.5 prototype); torch._dynamo.config.activation_memory_budget in [0,1] (2.4 experimental, compile-only). https://pytorch.org/blog/activation-checkpointing-techniques/
6. PyTorch, activation checkpointing with FSDP: apply_activation_checkpointing with check_fn, checkpoint_wrapper, CheckpointImpl from torch.distributed.algorithms._checkpoint.checkpoint_wrapper; naive submodule torch.utils.checkpoint can conflict with FSDP sharding. https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html
7. PyTorch, torch.distributed.fsdp: FSDP1 (FullyShardedDataParallel) and the newer fully_shard (FSDP2). https://docs.pytorch.org/docs/stable/fsdp.html
8. DeepSpeed, ZeRO / ZeRO-Infinity tutorial. https://www.deepspeed.ai/tutorials/zero/
9. DeepSpeed, Configuration JSON: offload_optimizer / offload_param (device cpu|nvme, nvme_path, pin_memory, buffer_count, buffer_size, max_in_cpu) and aio block. https://www.deepspeed.ai/docs/config-json/