FSDP (fully sharded data parallel)¶
Scope: PyTorch FSDP2 (fully_shard), sharding parameters, gradients, and optimizer state across ranks to train models too large for DDP, and how that scales across multiple nodes via HSDP.
Reference templates use real APIs (torch, torchrun, NCCL): pin versions and validate before production use. The torch snippets below need CUDA GPUs and are not executed here; each core algorithm they teach is mirrored by a runnable, asserted numpy block (run with python3) that validates the underlying math.
What it is¶
FSDP is a data-parallel strategy that shards each layer's parameters, gradients, and optimizer state across the data-parallel ranks instead of replicating them on every rank (the replication model described in DDP). Each rank holds a 1/N slice. At runtime FSDP all-gathers the full parameters of a unit just before its forward/backward compute, runs the compute, reduce-scatters the gradients back to shards, and immediately frees the gathered parameters, so the full weight tensor exists only transiently, one unit at a time. Peak parameter memory is therefore set by the rank's own shards plus the largest single gathered unit, not by the whole model.
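The gather, compute, free cycle above can be sketched as a toy count of the parameter elements live on one rank. This is an illustration only: `peak_live_params` and the unit sizes are hypothetical, and the count ignores gradients, optimizer state, activations, and any prefetch that keeps two units gathered at once.

```python
def ceil_div(a, b):
    return -(-a // b)


def peak_live_params(unit_sizes, n_ranks):
    """Peak parameter elements resident on one rank during a forward pass."""
    # Every rank permanently holds a ceil(u / N) slice of every unit.
    resident = sum(ceil_div(u, n_ranks) for u in unit_sizes)
    peak = resident
    for u in unit_sizes:
        # All-gather: the rank temporarily also holds the other ranks' slices.
        gathered_extra = u - ceil_div(u, n_ranks)
        peak = max(peak, resident + gathered_extra)
        # Free: the gathered copy is dropped before the next unit is gathered.
    return peak


units = [100, 400, 100]  # hypothetical per-unit parameter counts
print(peak_live_params(units, n_ranks=4))  # 450: shards (150) + rest of largest unit (300)
print(peak_live_params(units, n_ranks=1))  # 600: nothing is sharded
```

The peak tracks the largest unit, which is why wrapping granularity matters as much as the shard count.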
FSDP2 is the current API: the per-parameter-sharding rewrite exposed as torch.distributed.fsdp.fully_shard (functional, applied per module) replacing the legacy FullyShardedDataParallel wrapper class (FSDP1). FSDP2 uses DTensor-based sharding, cleaner mixed precision via MixedPrecisionPolicy, and composes with torch.compile and tensor parallel (tensor parallelism). This page covers FSDP2. It is one of the parallelism axes surveyed in distributed training.
Why use it¶
- Fit models too big for DDP. DDP needs a full model plus optimizer copy per GPU; for large models that OOMs. FSDP cuts per-GPU state by a factor of ~N (the shard count), so a model that does not fit under DDP can fit under FSDP on the same hardware.
- Optimizer state dominates memory. Adam keeps ~2 extra fp32 states per parameter; sharding these (the ZeRO-3 idea, see DeepSpeed and ZeRO) is where most of the saving comes from.
- Native PyTorch, composable. No external runtime; composes with tensor and pipeline parallelism, activation checkpointing, and torch.compile.
When to use it (and when not)¶
- Use DDP (DDP) when the model plus optimizer plus activations fit comfortably on one GPU: DDP has lower communication (one all-reduce per step vs all-gather plus reduce-scatter) and is simpler and faster in that regime.
- Use FSDP when the model does not fit under DDP but still fits when sharded across the data-parallel group (most 7B-70B dense fine-tunes; pretraining up to mid hundreds of billions with HSDP).
- Add tensor parallel (tensor parallelism) or pipeline parallel (pipeline parallelism) when even a single layer's activations or a transient all-gather is too large, or to keep the FSDP group small. Production frontier pretraining is usually FSDP/HSDP plus TP plus PP (3D/4D parallelism).
- ZeRO via DeepSpeed (DeepSpeed and ZeRO) is the conceptual sibling; choose by stack (native PyTorch goes to FSDP; DeepSpeed ecosystem / ZeRO-Offload goes to DeepSpeed).
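The decision rules above can be written down as a rough first-pass check. This is a sketch under stated assumptions, not a sizing tool: `pick_strategy`, the 0.8 headroom factor, and the example numbers are hypothetical, and activation memory is only covered by the headroom.

```python
def pick_strategy(state_gb, largest_unit_gb, gpu_gb, dp_ranks, headroom=0.8):
    """Suggest a parallelism strategy from coarse memory numbers (all in GB).

    state_gb: params + grads + optimizer state if fully replicated.
    largest_unit_gb: the biggest transiently gathered FSDP unit.
    headroom: hypothetical fraction of GPU memory budgeted for model state.
    """
    budget = gpu_gb * headroom
    if state_gb <= budget:
        return "DDP"                      # everything fits replicated
    if state_gb / dp_ranks + largest_unit_gb <= budget:
        return "FSDP"                     # fits once sharded across the DP group
    return "FSDP + TP/PP"                 # even the sharded state is too large


print(pick_strategy(20, 1, 80, 8))        # DDP
print(pick_strategy(112, 0.4, 80, 8))     # FSDP
print(pick_strategy(11200, 40, 80, 64))   # FSDP + TP/PP
```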
Architecture¶
flowchart LR
subgraph Rank["One rank, per FSDP unit"]
SH["Param shard (1/N)"]
AG["All-gather full params"]
FC["Forward / backward compute"]
RS["Reduce-scatter grads"]
FR["Free full params"]
end
SH -->|"pre-compute"| AG
AG --> FC
FC --> RS
RS -->|"back to shard"| SH
RS --> FR
FR -.->|"next unit"| AG
How to use it¶
Apply fully_shard to each transformer block (a sharding unit), then to the root module. Wrapping per block lets FSDP overlap the next block's all-gather with the current block's compute.
# Reference template: torch + NCCL + CUDA GPUs (not executed here).
# pin torch (FSDP2 API): e.g. torch>=2.6; validate on your build.
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
mp = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # compute/all-gather in bf16
    reduce_dtype=torch.float32,  # grad reduce-scatter in fp32 (stability)
)
# model is an nn.Module already on meta or cuda; shard each block, then root.
for block in model.model.layers:
    fully_shard(block, mp_policy=mp)
fully_shard(model, mp_policy=mp)  # root last
# launch: torchrun --standalone --nproc_per_node=8 train.py

# Reference template: torch + NCCL + CUDA GPUs (not executed here).
# standard loop: FSDP is transparent to the training step.
opt = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True)
for batch in loader:
    loss = model(**batch).loss
    loss.backward()  # reduce-scatter happens here
    opt.step(); opt.zero_grad()
Validated (numpy, runnable). FSDP's whole lifecycle reduces to two operations on a partitioned tensor: all-gather reconstructs a unit's full parameter from its 1/N shards for compute, and reduce-scatter sums the per-rank gradients and leaves each rank with its 1/N slice of the sum. All-gathering the reduce-scattered result is exactly an all-reduce of the gradients. The block proves the shard/all-gather round trip is exact, proves all_gather(reduce_scatter(g)) == all_reduce(g), shows per-rank state is ceil(L/N) elements (the 1/N memory win), and shows that a corrupted shard or a dropped rank changes the result.
import numpy as np
def make_shards(x, n):
    # pad to a multiple of n, then split into n equal contiguous shards
    L = x.shape[0]
    pad = (-L) % n
    xp = np.concatenate([x, np.zeros(pad, x.dtype)])
    return xp.reshape(n, -1), L
def all_gather(shards, L):
    return shards.reshape(-1)[:L]  # concat shards, drop padding
def reduce_scatter(per_rank_grads, n):
    total = per_rank_grads.sum(axis=0)  # sum across ranks (the reduce)
    return make_shards(total, n)  # then scatter shards (rank r owns row r)
rng = np.random.default_rng(0)
N, L = 4, 10 # L % N != 0 exercises the padding boundary
P = rng.standard_normal(L)
# 1) shard -> all-gather reconstructs the parameter exactly (the transient full weight)
shards, L0 = make_shards(P, N)
assert shards.shape == (N, -(-L // N)) # ceil(L/N) per shard => 1/N state
assert np.array_equal(all_gather(shards, L0), P)
# 2) all-gather over reduce-scatter == all-reduce(sum): the FSDP gradient identity
grads = rng.standard_normal((N, L))
rs, Lr = reduce_scatter(grads, N)
assert np.allclose(all_gather(rs, Lr), grads.sum(axis=0))
# 3) adversarial: a corrupted shard is detected, not silently gathered
bad = shards.copy(); bad[2, 0] += 1.0
assert not np.array_equal(all_gather(bad, L0), P)
# 4) adversarial: dropping a rank's gradient breaks the reduce identity
missing = grads.copy(); missing[1] = 0.0
assert not np.allclose(all_gather(reduce_scatter(missing, N)[0], Lr), grads.sum(axis=0))
# 5) memory: per-rank sharded state is 1/N of the replicated parameter (plus padding)
per_rank = shards[0].size
assert per_rank == -(-L // N) # exactly ceil(L/N)
assert per_rank * N >= L and (per_rank - 1) * N < L
print(f"A OK: roundtrip exact; allgather-of-reducescatter==allreduce; "
f"corruption+drop caught; per-rank state {per_rank}/{L} ~ 1/{N}")
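The overlap claim at the start of this section (per-block wrapping lets the next block's all-gather hide behind the current block's compute) can be illustrated with a toy timeline. The functions and timings below are hypothetical and model only the forward pass with one collective in flight at a time.

```python
def step_time_serial(comm, compute):
    """No overlap: every all-gather is fully exposed."""
    return sum(comm) + sum(compute)


def step_time_overlapped(comm, compute):
    """Unit i+1's all-gather runs while unit i computes."""
    t = comm[0]  # the first all-gather has nothing to hide behind
    for i in range(len(compute)):
        next_comm = comm[i + 1] if i + 1 < len(comm) else 0.0
        t += max(compute[i], next_comm)
    return t


# Four blocks, each 2 ms of all-gather and 5 ms of compute (hypothetical).
comm, compute = [2.0] * 4, [5.0] * 4
print(step_time_serial(comm, compute))      # 28.0
print(step_time_overlapped(comm, compute))  # 22.0

# Root-only wrapping: one unit, so there is no later compute to overlap with.
print(step_time_overlapped([8.0], [20.0]))  # 28.0
```

With a single root unit the overlapped and serial times coincide, which is the cost of wrapping too coarsely.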
How to integrate with it¶
FSDP integrates at two levels: the per-wrap knobs that control sharding granularity, precision, and activation memory, and its composition with the other parallelism axes and higher-level trainers.
Wrapping policy and granularity¶
- Wrapping policy. Shard at transformer-block granularity (one fully_shard per block). Too coarse (only root) gives one giant all-gather, no overlap, and high peak memory; too fine (per nn.Linear) gives many tiny collectives and is latency-bound. For HF models, the block list is typically model.model.layers.
- reshard_after_forward controls whether params are re-sharded after forward (default True, lowest memory) or kept for backward (False, less comm, more memory), a key throughput/memory knob.
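Both knobs can be compared with a simple latency-plus-bandwidth estimate. This is a sketch, assuming every collective pays a fixed latency and that each pass over the model moves roughly the full parameter bytes; the function names and the numbers are hypothetical, and the model captures only the latency cost of fine wrapping, not the lost overlap of coarse wrapping.

```python
def collectives_per_step(n_units, reshard_after_forward=True):
    """(all_gathers, reduce_scatters) per training step in this simplified model."""
    # Resharding after forward means each unit is gathered again for backward.
    all_gathers = n_units * (2 if reshard_after_forward else 1)
    reduce_scatters = n_units  # one gradient reduce-scatter per unit
    return all_gathers, reduce_scatters


def comm_time_s(param_bytes, n_units, latency_s, bandwidth_bytes_per_s,
                reshard_after_forward=True):
    """Per-collective latency plus total bytes over bandwidth."""
    ag, rs = collectives_per_step(n_units, reshard_after_forward)
    passes = (ag + rs) / n_units          # 3 with resharding, 2 without
    volume = param_bytes * passes
    return (ag + rs) * latency_s + volume / bandwidth_bytes_per_s


args = dict(param_bytes=14e9, latency_s=20e-6, bandwidth_bytes_per_s=400e9)
per_block = comm_time_s(n_units=32, **args)
per_linear = comm_time_s(n_units=32 * 7, **args)   # 7 linears per block, hypothetical
keep_params = comm_time_s(n_units=32, reshard_after_forward=False, **args)
assert keep_params < per_block < per_linear
```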
Mixed precision¶
MixedPrecisionPolicy sets param_dtype=torch.bfloat16 (compute and all-gather run in bf16) and reduce_dtype=torch.float32 (the gradient reduce-scatter accumulates in fp32 for stability). Keeping the reduce in fp32 is the safe default; bf16 reduction can diverge on large runs.
Validated (numpy, runnable). Reducing gradients in bf16 loses accuracy because bf16 keeps only 7 mantissa bits, so once the running sum grows it swamps the small contributions still being added. The block reduces 4096 small values sequentially in bf16 and in fp32 against a float64 reference: fp32 is orders of magnitude more accurate, and bf16 cannot even resolve 1 + 2^-10. This is why the page reduces in fp32.
import numpy as np
def to_bf16(x):
    # round-to-nearest-even cast float32 -> bfloat16, kept as float32 values
    x = np.asarray(x, np.float32)
    u = x.view(np.uint32).astype(np.uint64)
    bias = ((u >> 16) & 1) + 0x7FFF  # round-to-nearest-even on 16 dropped bits
    u = ((u + bias) & 0xFFFF0000).astype(np.uint32)
    return u.view(np.float32)
def reduce_in(vals, dtype):
    acc = np.float32(0.0)
    for v in vals:
        if dtype == "bf16":
            acc = to_bf16(to_bf16(acc) + to_bf16(np.float32(v)))
        else:
            acc = np.float32(acc + np.float32(v))
    return float(acc)
rng = np.random.default_rng(0)
vals = rng.standard_normal(4096).astype(np.float32) * 1e-2 # many small grad contribs
ref = float(np.sum(vals.astype(np.float64))) # high-precision ground truth
err_bf16 = abs(reduce_in(vals, "bf16") - ref)
err_fp32 = abs(reduce_in(vals, "fp32") - ref)
assert err_fp32 < err_bf16 # fp32 reduce is strictly more accurate
assert err_bf16 > 10 * err_fp32 # and materially so (page's fp32 default)
# adversarial: bf16's 7-bit mantissa cannot resolve 1 + 2^-10 (rounds back to 1.0)
assert to_bf16(np.float32(1.0 + 2 ** -10)) == np.float32(1.0)
assert to_bf16(np.float32(1.0 + 2 ** -6)) != np.float32(1.0) # but 2^-6 (>= ulp) is kept
print(f"B OK: err_fp32={err_fp32:.2e} << err_bf16={err_bf16:.2e} "
f"({err_bf16 / err_fp32:.1f}x); bf16 swamps 1+2^-10")
Activation checkpointing¶
Pair FSDP with checkpointing to trade compute for activation memory, usually the second-largest memory term after optimizer state. Checkpoint each block, then shard it.
# Reference template: torch + NCCL + CUDA GPUs (not executed here).
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, apply_activation_checkpointing,
)
from torch.distributed.fsdp import fully_shard, CPUOffloadPolicy
# checkpoint each block, then shard it.
blocks = set(model.model.layers)  # build once; check_fn is called for every submodule
apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: m in blocks,
)
# optional: offload params/grads to host (slower, frees VRAM)
fully_shard(model, offload_policy=CPUOffloadPolicy())
Validated (numpy, runnable). Activation checkpointing discards a block's intermediate activations after forward and recomputes them during backward. Correctness requires the recomputed activation to reproduce the stored one exactly. The block computes the same gradient both ways and asserts equality, shows the checkpoint path stores zero internal activations, and catches the bug where a skipped recompute (stale zeros) corrupts the gradient.
import numpy as np
def forward_store(x):
    a = np.maximum(x, 0.0)  # relu activation
    loss = float((a * a).sum())  # b = a^2, loss = sum(b)
    return loss, {"a": a}  # activation cached for backward
def backward_from_cache(x, cache):
    a = cache["a"]
    return 2.0 * a * (x > 0)  # dL/dx = 2a * relu'(x)
def backward_recompute(x):
    a = np.maximum(x, 0.0)  # activation RECOMPUTED in backward (checkpointing)
    return 2.0 * a * (x > 0)
rng = np.random.default_rng(0)
x = rng.standard_normal(2048)
loss, cache = forward_store(x)
g_store = backward_from_cache(x, cache)
g_ckpt = backward_recompute(x)
# 1) recompute yields the exact same gradient as the stored-activation path
assert np.array_equal(g_store, g_ckpt)
# 2) closed-form cross-check of the gradient
assert np.allclose(g_store, 2.0 * np.maximum(x, 0.0) * (x > 0))
# 3) memory: checkpointing stores 0 internal activations vs 1 for the stored path
assert len({}) < len(cache)
# 4) adversarial: skipping the recompute (stale zeros) gives the WRONG gradient
g_wrong = 2.0 * np.zeros_like(x) * (x > 0)
assert not np.array_equal(g_wrong, g_store)
assert g_store[x > 0].any() # gradient is non-trivial where x > 0
print("C OK: recompute==store exactly; 0<1 saved tensors; stale-skip gradient caught")
Composition with tensor, pipeline, and compile¶
FSDP2 is DTensor-based, so it composes with the other axes rather than replacing them:
- Tensor parallel (tensor parallelism) and pipeline parallel (pipeline parallelism) stack under FSDP/HSDP to form 3D/4D parallelism when a single layer's activations or a transient all-gather is too large, or to keep the FSDP group small.
- torch.compile composes with fully_shard; compile each block so the compiled graph and the FSDP collectives overlap.
- Higher-level trainers. HF Trainer/accelerate, PyTorch Lightning (strategy="fsdp"), and the RL/SFT stacks below own the wrapping for you: you select FSDP as the backend and they drive the collectives.
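When the axes are stacked, the mesh dimensions have to multiply to the world size, and only the sharding axes reduce per-GPU parameters. The sketch below checks that arithmetic; `param_fraction_per_gpu` and the layout are hypothetical, and it assumes parameters split evenly across pipeline stages and tensor-parallel ranks.

```python
from fractions import Fraction


def param_fraction_per_gpu(world_size, dp_replicate, dp_shard, tp, pp):
    """Fraction of the model's parameters held by one GPU under a 4D layout."""
    if dp_replicate * dp_shard * tp * pp != world_size:
        raise ValueError("mesh dimensions must multiply to the world size")
    # Replication does not shrink per-GPU parameters; the other three axes divide them.
    return Fraction(1, dp_shard * tp * pp)


# 512 GPUs as 2 replicas x 8-way FSDP shard x 8-way TP x 4 pipeline stages.
print(param_fraction_per_gpu(512, dp_replicate=2, dp_shard=8, tp=8, pp=4))  # 1/256
```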
How to run it in production¶
Production FSDP is a launch-orchestration, fabric, and checkpointing problem: the training code is unchanged, and what matters is how ranks start, how the collectives reach the wire, and how the job survives a failed rank.
Launch orchestration (torchrun and Slurm)¶
Enlarge the world with torchrun; rendezvous (--rdzv_backend=c10d) coordinates ranks across hosts. Under Slurm, one torchrun per node, fanned out by srun.
# multi-node launch (run on every node; rdzv coordinates ranks).
torchrun \
--nnodes=4 --nproc_per_node=8 \
--rdzv_backend=c10d --rdzv_endpoint="$MASTER_ADDR:29500" \
--rdzv_id=fsdp-job train.py
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --exclusive
# one torchrun per node; srun fans it out. See cluster-slurm.md for the pattern.
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
srun torchrun --nnodes=$SLURM_NNODES --nproc_per_node=8 \
--rdzv_backend=c10d --rdzv_endpoint="$MASTER_ADDR:29500" \
--rdzv_id=$SLURM_JOB_ID train.py
The Slurm/torchrun launch patterns live in Slurm; end-to-end recipes in distributed-training recipes.
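Inside train.py, each worker can sanity-check the layout that torchrun handed it. torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to every worker; the helper names and the fake environment below are hypothetical, and the rank formula assumes every node runs the same number of workers.

```python
import os


def expected_world_size(nnodes, nproc_per_node):
    return nnodes * nproc_per_node


def global_rank(node_rank, local_rank, nproc_per_node):
    # Assumes a homogeneous job: the same worker count on every node.
    return node_rank * nproc_per_node + local_rank


def read_launch_env(env=None):
    """Read the rank variables torchrun exports to each worker."""
    env = os.environ if env is None else env
    return int(env["RANK"]), int(env["LOCAL_RANK"]), int(env["WORLD_SIZE"])


# Hypothetical worker: second GPU on the second node of a 4 x 8 job.
fake_env = {"RANK": "9", "LOCAL_RANK": "1", "WORLD_SIZE": "32"}
rank, local_rank, world = read_launch_env(fake_env)
assert world == expected_world_size(nnodes=4, nproc_per_node=8)
assert rank == global_rank(node_rank=1, local_rank=local_rank, nproc_per_node=8)
```

Failing this check at startup is cheaper than discovering a short world size when the first collective hangs.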
Hardware and fabric¶
The collectives ride NCCL, so the fabric and its config dominate throughput:
- NVLink / NVSwitch carries the intra-node all-gather plus reduce-scatter, the hot path under HSDP; keep sharding inside the NVLink domain (performance tuning).
- InfiniBand / RoCE plus GPUDirect RDMA carries the inter-node gradient all-reduce. Set NCCL_IB_HCA to the right HCAs and NCCL_NET_GDR_LEVEL=SYS; confirm GDR with NCCL_DEBUG=INFO printing [GDRDMA] (networking fabric).
- NCCL_NVLS_ENABLE=1 turns on NVLink SHARP (in-switch reduction), accelerating reduce-scatter/all-reduce on NVSwitch hardware.
- PCIe ACS must be off for P2P / GDR to engage, or NCCL silently falls back and throughput collapses.
- On Blackwell, bf16 params with fp32 reduce is the safe default; FP8 weights are an optimization layered on top, validated per recipe (the Blackwell platform).
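The environment settings named above can be collected in one place and the GDR check automated. This is a minimal sketch: `nccl_env`, `gdr_engaged`, the HCA names, and the sample log line are hypothetical placeholders, and the check is only the substring match on GDRDMA that the list describes.

```python
def nccl_env(ib_hcas, nvls=True):
    """Build the NCCL environment described above for a launcher to export."""
    env = {
        "NCCL_IB_HCA": ",".join(ib_hcas),   # the HCAs attached to this node's GPUs
        "NCCL_NET_GDR_LEVEL": "SYS",
        "NCCL_DEBUG": "INFO",               # needed to see the transport lines
    }
    if nvls:
        env["NCCL_NVLS_ENABLE"] = "1"       # NVLink SHARP on NVSwitch hardware
    return env


def gdr_engaged(nccl_log_text):
    """True if the NCCL INFO log mentions GPUDirect RDMA."""
    return "GDRDMA" in nccl_log_text


env = nccl_env(["mlx5_0", "mlx5_1"])        # placeholder HCA names
sample_log = "rank0 NCCL INFO Channel 00 : 0 -> 8 via NET/IB/0/GDRDMA"  # illustrative line
print(env["NCCL_IB_HCA"], gdr_engaged(sample_log))
```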
Checkpointing and fault tolerance¶
- Sharded checkpoints. Use torch.distributed.checkpoint (DCP) for sharded, resharding-tolerant save/load; it lets you restart on a different world size without gathering a full state dict (recovery in the checkpoint-recovery runbook).
- Avoid full-state saves at scale. Gathering the full (unsharded) state on rank 0 for very large models can OOM host memory; prefer DCP sharded checkpoints.
- Elastic restarts. torchrun is TorchElastic: with --rdzv_backend=c10d and --max-restarts=N it restarts the workers and re-forms the process group after a rank dies; the training script is responsible for resuming from the last DCP checkpoint on startup.
- Fail loud, not hung. Set TORCH_NCCL_ASYNC_ERROR_HANDLING=1 and pass a timeout to init_process_group so a dead or straggling rank aborts the collective with an error instead of hanging the world (see the NCCL-hang runbook).
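What "resharding-tolerant" means can be shown with a toy model in plain Python: each rank saves only its own shard plus the global length, and a rank in a differently sized world reads just the slice it needs. This is not the DCP API; every name below is hypothetical and the "checkpoint" is an in-memory dict.

```python
def ceil_div(a, b):
    return -(-a // b)


def shard(flat, world):
    per = ceil_div(len(flat), world)
    padded = flat + [0.0] * (per * world - len(flat))
    return [padded[r * per:(r + 1) * per] for r in range(world)]


def save_sharded(shards, length):
    # Each rank contributes only its own shard; no rank holds the full tensor.
    return {"length": length, "shards": {r: list(s) for r, s in enumerate(shards)}}


def load_for_rank(ckpt, rank, new_world):
    """Read only the elements this rank owns under the new world size."""
    length, old = ckpt["length"], ckpt["shards"]
    per_old = len(old[0])
    per_new = ceil_div(length, new_world)
    out = []
    for i in range(rank * per_new, (rank + 1) * per_new):
        # Global index i lives in old shard i // per_old; indices past the end are padding.
        out.append(old[i // per_old][i % per_old] if i < length else 0.0)
    return out


params = [float(i) for i in range(10)]
ckpt = save_sharded(shard(params, 4), len(params))      # saved from 4 ranks
restored = [load_for_rank(ckpt, r, 3) for r in range(3)]  # loaded onto 3 ranks
assert restored == shard(params, 3)
```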
How to maintain it¶
An FSDP job is only as healthy as its slowest rank and its checkpoint format; maintenance is about keeping ranks identical, keeping one API path, and keeping the sharded state restorable over time.
- Pin one FSDP API. FSDP2 fully_shard and FSDP1 FullyShardedDataParallel are different code paths; do not mix them in one model. New code should target FSDP2, since it is the current, maintained API.
- Pin versions across every node. All ranks must run the same PyTorch, CUDA, and NCCL build; a mismatched torch or NCCL across hosts causes silent hangs or divergent numerics. Upgrade the whole job as one coordinated redeploy, never a rolling per-node bump, and re-validate the FSDP2 API on each new torch build.
- Keep checkpoints resharding-tolerant. DCP checkpoints load back on a different world size, so you can grow or shrink the shard/replicate mesh between runs; treat a full-state-dict checkpoint as a migration artifact only, not the steady-state format (the checkpoint-recovery runbook).
- Tune the memory/throughput knobs deliberately. reshard_after_forward and the block-wrapping granularity are the levers between peak memory and communication; change one at a time and measure, since the right point moves with model size and fabric.
- Watch MFU and stragglers. The per-layer collectives are barriers, so the slowest rank sets step time; track model FLOPs utilization and per-step collective time and treat a regression as a fabric or straggler signal (the MFU-regression runbook).
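The MFU and straggler bullet can be made concrete with the common approximation that a dense transformer costs about 6 FLOPs per parameter per token for forward plus backward. The sketch below is illustrative: the function names, the peak FLOP/s figure, and the timings are hypothetical, and attention FLOPs are ignored.

```python
def step_time(per_rank_times_s):
    # Collectives are barriers, so the slowest rank sets the step time.
    return max(per_rank_times_s)


def mfu(n_params, tokens_per_step, step_time_s, n_gpus, peak_flops_per_gpu):
    """Model FLOPs utilization with the ~6 * params * tokens approximation."""
    achieved = 6 * n_params * tokens_per_step / step_time_s
    return achieved / (n_gpus * peak_flops_per_gpu)


healthy = step_time([10.0, 10.0, 10.0, 10.0])
straggling = step_time([10.0, 10.0, 14.0, 10.0])   # one slow rank

job = dict(n_params=7e9, tokens_per_step=4_000_000, n_gpus=32,
           peak_flops_per_gpu=1e15)                # placeholder peak, not a spec
print(f"{mfu(step_time_s=healthy, **job):.3f}")    # 0.525
print(f"{mfu(step_time_s=straggling, **job):.3f}") # 0.375
```

One rank running 40% slow costs the whole job the same 40% in step time, which is why a drop in MFU is worth treating as a straggler or fabric signal.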
How to scale it¶
FSDP scales across multiple nodes by enlarging the data-parallel world via torchrun. Naive multi-node FSDP all-gathers and reduce-scatters across all ranks every layer; that crosses the slow inter-node fabric per layer and bottlenecks on InfiniBand bandwidth. HSDP (Hybrid Sharded Data Parallel) fixes this with a 2-D DeviceMesh:
- Shard intra-node (inner mesh dim, e.g. 8 GPUs on NVLink/NVSwitch): the all-gather plus reduce-scatter run over fast intra-node NVLink, within a group that is only a fraction of the full world size.
- Replicate inter-node (outer mesh dim, across nodes over InfiniBand): only a DDP-style all-reduce of gradients crosses the slow fabric, once per step, the same communication pattern DDP uses (see DDP).
# Reference template: torch + NCCL + CUDA GPUs (not executed here).
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
# nodes = number of nodes; 8 GPUs/node. Outer = replicate, inner = shard.
nodes = 4
mesh = init_device_mesh(
    "cuda", (nodes, 8), mesh_dim_names=("replicate", "shard"),
)
for block in model.model.layers:
    fully_shard(block, mesh=mesh, mp_policy=mp)
fully_shard(model, mesh=mesh, mp_policy=mp)
# intra-node (NVLink): all-gather/reduce-scatter of the 8-way shard
# inter-node (IB): grad all-reduce across the 4 replicas, once per step
Pick the shard dimension to match the NVLink domain (8 for HGX, 4 for some boards); the replicate dimension then spans nodes. Launch these across nodes with the torchrun/Slurm patterns in production. Recipes in distributed-training recipes; cluster launchers in Slurm.
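Which ranks end up in which group follows from the mesh shape. The sketch below assumes the mesh is filled in row-major rank order (rank r sits at row r // shard_size, column r % shard_size) and that the launcher gives consecutive global ranks to GPUs on the same node; the helper names are hypothetical, so verify the actual groups on your build.

```python
def mesh_coords(rank, shard_size):
    """(replicate_index, shard_index) of a rank in a (nodes, shard_size) mesh."""
    return divmod(rank, shard_size)


def shard_group(rank, shard_size):
    # Ranks sharing a replicate index: they all-gather / reduce-scatter together.
    rep, _ = mesh_coords(rank, shard_size)
    return [rep * shard_size + s for s in range(shard_size)]


def replicate_group(rank, n_nodes, shard_size):
    # Ranks sharing a shard index: they all-reduce the same gradient shard.
    _, sh = mesh_coords(rank, shard_size)
    return [n * shard_size + sh for n in range(n_nodes)]


print(mesh_coords(13, 8))          # (1, 5)
print(shard_group(13, 8))          # [8, 9, 10, 11, 12, 13, 14, 15]
print(replicate_group(13, 4, 8))   # [5, 13, 21, 29]
```

Under these assumptions every shard group stays on one node, so only the four-rank replicate groups cross the inter-node fabric.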
Validated (numpy, runnable). HSDP replaces the flat all-reduce over all ranks with reduce-scatter inside each node and a single inter-node all-reduce per shard; the result must equal the naive all-reduce over every rank. The block builds an (R nodes, S shard-ranks) gradient, proves the hierarchical reduction equals the flat one, and catches two failures: doing only the intra-node step, and dropping a node from the inter-node all-reduce.
import numpy as np
def make_shards(x, s):
    L = x.shape[0]; pad = (-L) % s
    return np.concatenate([x, np.zeros(pad, x.dtype)]).reshape(s, -1), L
def gather(shards, L):
    return shards.reshape(-1)[:L]
def hsdp_all_reduce(g, R, S):
    # g: (R nodes, S shard-ranks, L). Hierarchical: reduce-scatter intra-node,
    # all-reduce inter-node per shard, all-gather. Returns the full global sum.
    L = g.shape[-1]
    node_shards = np.stack([make_shards(g[i].sum(axis=0), S)[0] for i in range(R)])
    global_shard = node_shards.sum(axis=0)  # inter-node all-reduce (crosses the fabric)
    return gather(global_shard, L)
rng = np.random.default_rng(0)
R, S, L = 3, 4, 10 # L % S != 0 -> padding boundary
g = rng.standard_normal((R, S, L))
flat = g.reshape(R * S, L).sum(axis=0) # naive all-reduce over all R*S ranks
hier = hsdp_all_reduce(g, R, S)
assert np.allclose(hier, flat) # HSDP == flat all-reduce (equivalence)
assert hier.shape[0] == L # sharding padding is trimmed on gather
# adversarial 1: intra-node only (drop the inter-node all-reduce) != global
intra_only = gather(make_shards(g[0].sum(axis=0), S)[0], L)
assert not np.allclose(intra_only, flat)
# adversarial 2: a missing node changes the reduction (detects a dropped replica)
assert not np.allclose(hsdp_all_reduce(g[:R - 1], R - 1, S), flat)
print("D OK: HSDP hierarchical reduce == flat all-reduce; intra-only + dropped node caught")
Inference¶
FSDP is a training construct. For serving, use a dedicated inference engine (vLLM/SGLang) with tensor/pipeline parallel; see inference serving/disaggregated inference. The one exception is when a model is too large to fit for offline/batch evaluation or generation even on a multi-GPU node: sharded forward (FSDP or a sharded-inference path) can run a single very large model across ranks, gathering each layer transiently as in training. For latency-sensitive online serving this is the wrong tool: its per-layer all-gather adds latency; route to the serving pages.
Fine-tuning¶
FSDP is the default backend for memory-bound fine-tuning:
- Full-parameter fine-tuning of large models that will not fit under DDP: shard params plus optimizer state across the group (fine-tuning and post-training).
- FSDP plus LoRA / QLoRA: shard the frozen base weights (and quantized base for QLoRA) while training only small adapter tensors, cutting both base-weight and optimizer memory; this is how single-node fine-tunes of very large models fit (SFT and LoRA).
- RL post-training stacks (verl, SkyRL, NeMo-RL) use FSDP as a policy/training backend behind the rollout engine (RL libraries, GRPO).
Validated (numpy, runnable). The memory arithmetic behind the page's two central claims. First, FSDP's per-GPU state is exactly 1/N of DDP's replicated footprint, and collapses to the DDP footprint at N=1. Second, under QLoRA only the trainable adapters carry optimizer state, so a 10x larger frozen base leaves the optimizer footprint unchanged and QLoRA's optimizer state is over 100x smaller than full fine-tuning.
import numpy as np
# Mixed-precision Adam memory accounting (bytes/param), ZeRO-style:
P16, G16, M, V, MASTER = 2, 2, 4, 4, 4 # bf16 param+grad, fp32 m+v+master copy
PER_PARAM = P16 + G16 + M + V + MASTER # 16 bytes/param, fully replicated
OPT_ONLY = M + V + MASTER # 12 bytes/param of optimizer state
def ddp_per_gpu(P): return PER_PARAM * P # DDP replicates all
def fsdp_per_gpu(P, N, U): return PER_PARAM * P / N + P16 * U # shard + 1 transient unit
P, U = 7_000_000_000, 200_000_000 # ~7B params, largest block ~200M
for N in (2, 8, 64):
    ddp, fsdp = ddp_per_gpu(P), fsdp_per_gpu(P, N, U)
    assert fsdp < ddp  # FSDP lighter for N>1
    assert abs((PER_PARAM * P / N) / ddp - 1.0 / N) < 1e-12  # sharded state is exactly 1/N
# boundary: N=1 with no transient unit -> no benefit, exactly the DDP footprint
assert fsdp_per_gpu(P, 1, 0) == ddp_per_gpu(P)
# QLoRA: only trainable adapters carry optimizer state; the frozen 4-bit base carries none
def qlora_opt(n_base, n_adapter): return OPT_ONLY * n_adapter # base frozen -> contributes 0
P_adapter = 40_000_000
full_ft = OPT_ONLY * P # optimizer state if all P were trainable
qlora = qlora_opt(P, P_adapter) # only adapters trainable
assert qlora < full_ft / 100 # >100x smaller optimizer state
# adversarial: a 10x larger frozen base leaves QLoRA optimizer state unchanged
assert qlora_opt(7_000_000_000, P_adapter) == qlora_opt(70_000_000_000, P_adapter)
base_4bit = 0.5 * P # 4-bit base weight bytes (no grad/opt state)
assert base_4bit < OPT_ONLY * P # frozen base cheaper than its optimizer state
print(f"E OK: sharded=1/N exact; N=1 no benefit; QLoRA opt {full_ft / qlora:.0f}x smaller")
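Where an adapter count of roughly 40M can come from: LoRA adds two low-rank matrices beside each targeted linear layer, so a d_in x d_out layer gains r * (d_in + d_out) trainable parameters. The shapes below are illustrative values for a 32-layer decoder, not taken from this page, and embeddings are excluded.

```python
def lora_params(d_in, d_out, r):
    # LoRA adds A (r x d_in) and B (d_out x r) next to the frozen d_out x d_in weight.
    return r * (d_in + d_out)


# Illustrative per-block linear shapes (d_in, d_out): 4 attention + 3 MLP projections.
layer_shapes = [(4096, 4096)] * 4 + [(4096, 11008)] * 2 + [(11008, 4096)]
n_layers, r = 32, 16

base = n_layers * sum(d_in * d_out for d_in, d_out in layer_shapes)
adapters = n_layers * sum(lora_params(d_in, d_out, r) for d_in, d_out in layer_shapes)

print(adapters)                   # 39976960
print(f"{adapters / base:.2%}")   # trainable share of the targeted weights
```

Only these adapter parameters carry gradients and Adam state, which is the source of the optimizer-memory saving asserted in the block above.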
Cookbook (common use cases)¶
1. Single-node FSDP (8 GPUs, one box)
# Reference template: torch + NCCL + CUDA GPUs (not executed here).
# torchrun --standalone --nproc_per_node=8 train.py
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for block in model.model.layers:
    fully_shard(block, mp_policy=mp)
fully_shard(model, mp_policy=mp) # 1-D mesh inferred from default group
2. Multi-node HSDP (shard intra-node, replicate inter-node)
# Reference template: torch + NCCL + CUDA GPUs (not executed here).
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
mesh = init_device_mesh("cuda", (NUM_NODES, 8),
mesh_dim_names=("replicate", "shard"))
for block in model.model.layers:
    fully_shard(block, mesh=mesh, mp_policy=mp)
fully_shard(model, mesh=mesh, mp_policy=mp)
# launch on every node: torchrun --nnodes=$NUM_NODES --nproc_per_node=8 ...
3. FSDP plus QLoRA (fit a large model on one node)
# Reference template: torch + peft + bitsandbytes + CUDA GPUs (not executed here).
# base loaded 4-bit (bitsandbytes); train adapters only.
from peft import LoraConfig, get_peft_model
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, target_modules="all-linear"))
for block in model.model.model.layers:  # extra .model from the PEFT wrapper
    fully_shard(block, mp_policy=mp)
fully_shard(model, mp_policy=mp)
# only adapter params carry optimizer state; base shards stay frozen.
Failure modes¶
Several of these are reproduced in miniature by the validated numpy blocks above: a corrupted shard and a dropped rank (in "How to use it"), a bf16 reduce that swamps small gradients (in "How to integrate with it"), and the intra-node-only / dropped-node HSDP reductions (in "How to scale it").
- Wrapping only the root gives a single huge all-gather, OOM, and no compute/comm overlap. Always shard per transformer block (the training-OOM runbook).
- Naive multi-node FSDP (1-D mesh across all ranks) does per-layer collectives across IB and is IB-bandwidth-bound. Use HSDP so only the grad all-reduce crosses nodes.
- reduce_dtype=bf16 can diverge on large runs; reduce gradients in fp32.
- ACS on / wrong NCCL_IB_HCA makes NCCL fall back off GDR and throughput collapses; verify [GDRDMA] (the NCCL-hang runbook).
- Mismatched shard dim vs NVLink domain (e.g. shard=16 spanning two NVLink islands) makes the "intra-node" all-gather cross IB anyway. Match the shard dim to the NVLink group.
- Mixing FSDP1 FullyShardedDataParallel and FSDP2 fully_shard APIs in one model: pick one; they are different code paths.
- Saving full (unsharded) state on rank 0 for very large models can OOM host memory; prefer DCP sharded checkpoints (the checkpoint-recovery runbook).
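Several of the configuration failures above can be caught before any GPU time is spent. The pre-flight check below is a hypothetical helper, not part of PyTorch: it takes plain values describing the planned job and returns a list of problems.

```python
def preflight(world_size, replicate, shard, nvlink_domain, reduce_dtype, wrapped_units):
    """Return human-readable problems with a planned FSDP/HSDP configuration."""
    problems = []
    if replicate * shard != world_size:
        problems.append("replicate x shard does not equal the world size")
    if shard > nvlink_domain:
        problems.append("shard dim spans more than one NVLink domain")
    if reduce_dtype != "float32":
        problems.append("gradient reduction is not in fp32")
    if wrapped_units <= 1:
        problems.append("only the root is wrapped; shard per transformer block")
    return problems


# 4 nodes x 8 GPUs, fp32 reduce, 32 blocks plus the root.
print(preflight(32, 4, 8, 8, "float32", 33))   # []

# shard=16 across two 8-GPU NVLink islands, bf16 reduce, root-only wrapping.
for problem in preflight(32, 2, 16, 8, "bfloat16", 1):
    print("-", problem)
```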
References¶
- PyTorch FSDP docs: https://docs.pytorch.org/docs/stable/fsdp.html
- fully_shard API: https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html
- FSDP2 getting-started tutorial: https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html
- "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (arXiv 2304.11277): https://arxiv.org/abs/2304.11277
- TorchTitan (HSDP/FSDP + TP/PP reference): https://github.com/pytorch/torchtitan
- NCCL env vars: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
Related: Distributed training · DDP · DeepSpeed/ZeRO · Tensor parallel · Pipeline parallel · Training recipes · Slurm · SFT/LoRA · Performance · Glossary