
Recipe: FSDP training in a single datacenter

Scope: a standalone recipe to run FSDP2 (fully_shard) training inside one NVLink/InfiniBand datacenter: the launcher and config (sharding granularity, mixed precision, activation checkpointing), the apply/verify loop (MFU and peak memory), and the failure modes specific to single-DC sharding.

Reference templates built on real PyTorch FSDP2 APIs (torch>=2.6, validated against the 2.9 signature). Pin and validate them against your installed build, and lay the sharding onto your actual topology. Not hardware-tested here.

What it is

FSDP (Fully Sharded Data Parallel) shards each transformer block's parameters, gradients, and optimizer state across the data-parallel ranks, so each rank holds a 1/N slice. At runtime, FSDP all-gathers a block's full parameters just before its forward pass, and again before its backward pass under the default reshard_after_forward=True. It runs the compute, reduce-scatters the gradients back into shards, and frees the gathered parameters. Peak parameter memory is therefore the local shard plus the one or two blocks gathered at any moment (the current block and the prefetched next one), not the whole model; activations come on top. This recipe is the single-DC half of distributed-training recipes and the runnable form of FSDP; for the geo-distributed / low-comms case use DiLoCo per distributed-training recipes.
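A toy, single-process sketch of that per-block cycle follows. Plain Python lists stand in for tensors, and `shard`, `all_gather`, `reduce_scatter`, and `fsdp_block_step` are hypothetical helpers, not torch.distributed calls. Gradients are averaged across ranks, which we understand to be FSDP's default (matching DDP's mean). Prefetching, comm/compute overlap, and mixed precision are not modeled.

```python
from __future__ import annotations

from collections.abc import Callable

Shard = list[float]


def shard(flat: list[float], world: int) -> list[Shard]:
    """Split a flat parameter vector into `world` equal slices, one per rank."""
    if len(flat) % world:
        raise ValueError("toy assumes len(flat) is divisible by world")
    n = len(flat) // world
    return [flat[r * n:(r + 1) * n] for r in range(world)]


def all_gather(shards: list[Shard]) -> list[float]:
    """Every rank ends up with the concatenation of all shards (the full block)."""
    return [x for s in shards for x in s]


def reduce_scatter(grads: list[list[float]], world: int) -> list[Shard]:
    """Average full-size grads across ranks, then give rank r only its slice."""
    averaged = [sum(col) / world for col in zip(*grads)]
    return shard(averaged, world)


def fsdp_block_step(param_shards: list[Shard],
                    local_grad: Callable[[int, list[float]], list[float]],
                    world: int) -> list[Shard]:
    full = all_gather(param_shards)                       # just before fwd/bwd
    grads = [local_grad(r, full) for r in range(world)]   # each rank, own microbatch
    grad_shards = reduce_scatter(grads, world)            # back to 1/N slices
    del full                                              # gathered copy is dropped
    return grad_shards
```

Each rank only ever stores its own entry of `param_shards` and of the returned gradient shards; the full vector exists only transiently inside `fsdp_block_step`.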

Single-DC means every step's collectives ride a fast fabric:

  • Intra-node all-gather + reduce-scatter over NVLink/NVSwitch.
  • Inter-node (when >1 node) gradient all-reduce over InfiniBand/RoCE with GPUDirect RDMA, laid out via HSDP so only the lighter all-reduce crosses nodes, not the per-block collectives.
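To put rough numbers on "the lighter all-reduce", here is a back-of-envelope sketch under stated assumptions:

- Ring-style collectives. Per-rank traffic is (n-1)/n of the buffer for all-gather and reduce-scatter, and 2(n-1)/n for all-reduce.
- Default reshard-after-forward, so parameters are gathered in forward and again in backward.
- bf16 all-gathers and fp32 gradient reduction.

NVLS/tree algorithms and overlap are ignored, and the function names are hypothetical. A ring that spans nodes runs at the pace of its slowest (IB) hop, so the 1-D figure is effectively IB-bound traffic. Under HSDP, only the second figure crosses IB.

```python
from __future__ import annotations

BF16_BYTES = 2  # param all-gather dtype (assumes MixedPrecisionPolicy param_dtype=bf16)
FP32_BYTES = 4  # grad reduction dtype (assumes reduce_dtype=fp32)


def fsdp_1d_bytes_per_rank(n_params: float, world: int) -> float:
    """Per-rank bytes per step for 1-D FSDP over `world` ranks (ring estimate)."""
    f = (world - 1) / world
    all_gathers = 2 * f * n_params * BF16_BYTES   # forward + backward re-gather
    reduce_scatter = f * n_params * FP32_BYTES    # gradients back to shards
    return all_gathers + reduce_scatter


def hsdp_internode_bytes_per_rank(n_params: float, gpus_per_node: int, nodes: int) -> float:
    """Per-rank bytes per step crossing nodes under HSDP: one all-reduce of the grad shard."""
    shard_params = n_params / gpus_per_node
    return 2 * (nodes - 1) / nodes * shard_params * FP32_BYTES
```

Take a 7B model on 4 nodes x 8 GPUs. 1-D FSDP moves about 54 GB per rank per step through rings that include IB hops. Under HSDP, about 5.3 GB per rank crosses IB, and the per-block collectives stay on NVLink.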

Why it matters

  • Fit models too big for DDP. DDP replicates the full model, its gradients, and the Adam optimizer state (two fp32 moments per parameter) on every GPU, so large models OOM. FSDP divides that per-GPU state by ~N, the shard-group size. That brings most 7B-70B dense fine-tunes within reach and, with HSDP, pretraining into the hundreds of billions of parameters.
  • Per-step comms are cheap in one DC. With NVLink intra-node and IB inter-node, all-gather/reduce-scatter overlap with compute and the run stays compute-bound, so MFU is high. This is the single-DC default; choose by the network between workers, not preference (distributed-training recipes).
  • Native PyTorch, composable with activation checkpointing, tensor/pipeline parallel, and torch.compile, with no external runtime.
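A quick way to check "fits under DDP vs needs sharding" is to count bytes of model state per parameter. The sketch assumes mixed-precision Adam with fp32 master parameters, fp32 gradients, and two fp32 Adam moments, for 16 bytes per parameter. It ignores activations and the transient bf16 gathered blocks. The 0.6 `state_budget` split and both function names are hypothetical rules of thumb, not measured values.

```python
from __future__ import annotations

GIB = 1024**3
# Assumed model-state cost per parameter:
# fp32 param (4) + fp32 grad (4) + Adam exp_avg (4) + exp_avg_sq (4) = 16 bytes.
BYTES_PER_PARAM = 16


def model_state_gib(n_params: float, shard_size: int = 1) -> float:
    """Per-GPU GiB of params + grads + Adam state; shard_size=1 models DDP."""
    return n_params * BYTES_PER_PARAM / shard_size / GIB


def fits(n_params: float, shard_size: int, gpu_gib: float, state_budget: float = 0.6) -> bool:
    """Hypothetical rule: model states may use `state_budget` of device memory,
    leaving the rest for activations, gathered blocks, and fragmentation."""
    return model_state_gib(n_params, shard_size) <= state_budget * gpu_gib
```

Assuming an 80 GiB device:

- A 7B model needs about 104 GiB of model state under DDP (does not fit) and about 13 GiB with 8-way sharding.
- A 70B model still needs about 130 GiB per GPU at 8-way sharding. It needs a larger shard group (32-way gives about 33 GiB) or tensor parallelism.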

When it is needed (and when not)

  • Use FSDP when the model does not fit under DDP but fits when sharded across the data-parallel group. Single-node (8 GPUs, one NVLink island) is the simplest case.
  • Use HSDP (shard intra-node, replicate inter-node) as soon as you cross node boundaries. Naive 1-D FSDP across all ranks all-gathers per block over IB and goes bandwidth-bound. See How below.
  • Use DDP (distributed training) when the model + optimizer + activations fit comfortably on one GPU: lower comms (one all-reduce/step), simpler, faster in that regime.
  • Add TP/PP when a single block's activations or its transient all-gather is too large, or to keep the FSDP group small; frontier pretraining is FSDP/HSDP + TP + PP.
  • Use DiLoCo / PRIME (DiLoCo), NOT this recipe, when workers sit across DCs on slow/heterogeneous links; FSDP's per-step comms cannot tolerate that fabric.
flowchart LR
  Q1{"Fits on 1 GPU<br/>under DDP?"} -->|"yes"| DDP["Use DDP"]
  Q1 -->|"no"| Q2{"All workers in<br/>one DC?"}
  Q2 -->|"no, slow links"| DILOCO["DiLoCo / PRIME<br/>(train-diloco)"]
  Q2 -->|"yes"| Q3{"More than<br/>one node?"}
  Q3 -->|"no"| FSDP1["FSDP, 1-D mesh<br/>(single node)"]
  Q3 -->|"yes"| HSDP["HSDP 2-D mesh<br/>shard intra-node,<br/>replicate inter-node"]
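The same decision can also be written as a helper, which is handy in a launcher that picks a strategy from cluster facts. `choose_strategy` and its return labels are hypothetical; the branches follow the flowchart one-to-one.

```python
from __future__ import annotations


def choose_strategy(fits_one_gpu_under_ddp: bool, all_in_one_dc: bool, n_nodes: int) -> str:
    """Mirror of the decision flowchart; labels are this sketch's own."""
    if fits_one_gpu_under_ddp:
        return "DDP"
    if not all_in_one_dc:
        return "DiLoCo / PRIME"
    if n_nodes <= 1:
        return "FSDP 1-D mesh"
    return "HSDP 2-D mesh (replicate x shard)"
```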

How: implement, integrate, maintain

1. Implement: shard, mixed precision, activation checkpointing

Apply fully_shard to each transformer block (so the next block's all-gather overlaps the current block's compute), then to the root module last. MixedPrecisionPolicy runs compute in bf16 while reducing gradients in fp32. Activation checkpointing trades extra forward recomputation for activation memory.

# train_fsdp.py  -- torch>=2.6; validate fully_shard signature on your build.
from __future__ import annotations
import os, time, torch, torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, apply_activation_checkpointing,
)

def init_distributed() -> None:
    # torchrun exports RANK/WORLD_SIZE/LOCAL_RANK; pin this process to its GPU
    # before NCCL creates communicators.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(backend="nccl")

def build_mesh(gpus_per_node: int = 8) -> DeviceMesh | None:
    nodes = int(os.environ["WORLD_SIZE"]) // gpus_per_node
    if nodes <= 1:
        return None  # single node: fully_shard defaults to a 1-D mesh over the world
    # HSDP: outer=replicate across nodes (IB all-reduce), inner=shard intra-node (NVLink)
    return init_device_mesh("cuda", (nodes, gpus_per_node),
                            mesh_dim_names=("replicate", "shard"))

def shard_model(model: torch.nn.Module, mesh: DeviceMesh | None) -> None:
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16,   # compute/all-gather in bf16
                              reduce_dtype=torch.float32)    # grad reduce in fp32 (stability)
    block_ids = {id(b) for b in model.model.layers}          # HF decoder blocks
    apply_activation_checkpointing(                          # checkpoint first, then shard
        model, checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda m: id(m) in block_ids)
    # AC swapped each entry of model.model.layers for a CheckpointWrapper;
    # re-read the list so each wrapped block is the FSDP unit.
    for block in model.model.layers:
        fully_shard(block, mesh=mesh, mp_policy=mp)
    fully_shard(model, mesh=mesh, mp_policy=mp)               # root last
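To see which ranks land in each shard and replicate group, the sketch below combines two layouts. It assumes init_device_mesh arranges ranks row-major by default (torch.arange(world).view(mesh_shape)), and that torchrun numbers ranks contiguously per node. `hsdp_groups` is a hypothetical helper. On a live job, compare its output with dist.get_process_group_ranks(mesh["shard"].get_group()).

```python
from __future__ import annotations


def hsdp_groups(world_size: int, gpus_per_node: int) -> tuple[list[list[int]], list[list[int]]]:
    """Rank groups of a (replicate, shard) mesh laid out row-major over ranks 0..W-1.

    Assumes node k owns ranks k*g .. k*g+g-1 (torchrun's default numbering).
    Shard groups are mesh rows (one node each, NVLink); replicate groups are
    mesh columns (same local GPU index on every node, IB).
    """
    if world_size % gpus_per_node:
        raise ValueError("WORLD_SIZE must be a multiple of gpus_per_node")
    nodes = world_size // gpus_per_node
    shard = [list(range(n * gpus_per_node, (n + 1) * gpus_per_node)) for n in range(nodes)]
    replicate = [list(range(local, world_size, gpus_per_node)) for local in range(gpus_per_node)]
    return shard, replicate
```

With 2 nodes x 8 GPUs, the shard groups are ranks 0-7 and 8-15, and each replicate group pairs local GPU i on both nodes (for example [0, 8]). Only the replicate groups' gradient all-reduce crosses IB.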

2. Integrate: launch with torchrun (single node and HSDP multi-node)

torchrun sets RANK/WORLD_SIZE/LOCAL_RANK; the c10d rendezvous coordinates ranks. Single node uses --standalone; multi-node runs the same command on every node behind a shared MASTER_ADDR. For cluster fan-out use Slurm, gang scheduling via Volcano job on Kubernetes, or Ray.

# Single node, 8 GPUs on one NVLink island.
torchrun --standalone --nproc_per_node=8 train_fsdp.py

#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --exclusive
# HSDP across 4 nodes; srun fans one torchrun per node. NCCL tuning for the IB path.
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
export NCCL_NET_GDR_LEVEL=SYS        # allow GPUDirect RDMA even across NUMA/SMP hops
export NCCL_IB_HCA=mlx5              # set to your actual HCAs; comma-separate if multiple
export NCCL_NVLS_ENABLE=1           # NVLink SHARP (in-switch reduction) on NVSwitch
export NCCL_DEBUG=INFO              # confirm [GDRDMA] / [NVLS] in logs once, then drop
srun torchrun --nnodes=$SLURM_NNODES --nproc_per_node=8 \
  --rdzv_backend=c10d --rdzv_endpoint="$MASTER_ADDR:29500" \
  --rdzv_id="$SLURM_JOB_ID" train_fsdp.py
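torchrun also exports LOCAL_WORLD_SIZE (the --nproc_per_node value), so gpus_per_node in build_mesh need not be hard-coded. The sketch below reads the launcher environment into one object. `LaunchInfo` and `read_torchrun_env` are hypothetical names, and `nodes` assumes every node runs the same number of processes.

```python
from __future__ import annotations

import os
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class LaunchInfo:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int

    @property
    def nodes(self) -> int:
        # Assumes a homogeneous job: same process count on every node.
        return self.world_size // self.local_world_size


def read_torchrun_env(env: Mapping[str, str] = os.environ) -> LaunchInfo:
    """Read the variables torchrun exports to each worker process."""
    return LaunchInfo(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        local_world_size=int(env["LOCAL_WORLD_SIZE"]),
    )
```

Inside train_fsdp.py this could feed build_mesh(read_torchrun_env().local_world_size). Passing an explicit mapping keeps the function testable off-cluster.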

The training step is transparent to FSDP (reduce-scatter happens inside backward()):

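def train_loop(model, loader, steps: int) -> None:
    # Build the optimizer AFTER fully_shard so it holds the sharded (DTensor) params.
    opt = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True)
    model.train()
    for step, batch in enumerate(loader):
        if step >= steps:                                  # honor the step budget
            break
        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
        loss = model(**batch).loss
        loss.backward()              # reduce-scatter (plus HSDP all-reduce) fires here
        opt.step()
        opt.zero_grad()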
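The dt_s passed to report() in the Verify step should come from steady-state steps. The first steps include NCCL communicator setup and allocator warmup, plus compilation if torch.compile is used. Also, CUDA work is asynchronous, so wall-clock timing needs a device sync. A minimal sketch follows. `StepTimer` is hypothetical, and the default of 5 warmup steps is an assumption to tune.

```python
from __future__ import annotations

import time
from collections.abc import Callable


class StepTimer:
    """Mean wall-clock seconds per step, skipping the first `warmup` steps.

    `sync` should block until queued GPU work finishes (e.g. torch.cuda.synchronize);
    without it the host returns before kernels complete and dt is under-measured.
    """

    def __init__(self, warmup: int = 5, sync: Callable[[], None] = lambda: None,
                 clock: Callable[[], float] = time.perf_counter) -> None:
        self.warmup, self.sync, self.clock = warmup, sync, clock
        self.steps = 0       # all completed steps
        self.timed = 0       # steps counted after warmup
        self.total = 0.0     # seconds summed over timed steps
        self._t0 = 0.0

    def start(self) -> None:
        self.sync()
        self._t0 = self.clock()

    def stop(self) -> None:
        self.sync()
        dt = self.clock() - self._t0
        self.steps += 1
        if self.steps > self.warmup:
            self.timed += 1
            self.total += dt

    @property
    def mean_s(self) -> float:
        return self.total / self.timed if self.timed else float("nan")
```

In train_loop, wrap forward/backward/step in timer.start() and timer.stop() with sync=torch.cuda.synchronize, then pass timer.mean_s as dt_s. The clock is injectable so the logic can be checked without a GPU.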

3. Verify: MFU and peak memory

Treat MFU (model FLOPs utilization) and peak reserved VRAM as the two acceptance gates. MFU is achieved FLOPs over peak hardware FLOPs; the dense-transformer estimate is 6 * N_params * tokens_per_step FLOPs per step (forward+backward), plus an attention term you can fold in or ignore for the leading-order number. Below the MFU floor, see the MFU regression runbook.

def report(model_params: int, tokens_per_step: int, dt_s: float,
           world_size: int, peak_tflops_per_gpu: float) -> dict[str, float]:
    # 6*N*tokens is the standard dense fwd+bwd FLOPs estimate (illustrative leading term).
    flops_step = 6.0 * model_params * tokens_per_step
    achieved = flops_step / dt_s / world_size / 1e12          # TFLOP/s per GPU
    mfu = achieved / peak_tflops_per_gpu
    peak_gib = torch.cuda.max_memory_reserved() / 1024**3
    return {"tflops_per_gpu": achieved, "mfu": mfu, "peak_reserved_gib": peak_gib}

# peak_tflops_per_gpu: use your GPU's published bf16 dense rate (vendor datasheet).
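# Acceptance gate example (illustrative): mfu >= 0.40, peak_reserved_gib below device
# memory with headroom, and no allocator retries (torch.cuda.memory_stats()["num_alloc_retries"] == 0).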
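To fold in the attention term instead of ignoring it, one common accounting is the one in the PaLM paper's appendix. It counts 6N model FLOPs per token for the dense matmuls (fwd+bwd) plus 12·L·d_model·T for attention, where L is layers and T is sequence length. The sketch assumes no discount for causal masking, and the function names are hypothetical.

```python
from __future__ import annotations


def model_flops_per_token(n_params: float, n_layers: int = 0, d_model: int = 0,
                          seq_len: int = 0) -> float:
    """6N (dense fwd+bwd) plus 12*L*d_model*T attention FLOPs per token.

    Leave the attention arguments at 0 for the leading-order 6N estimate.
    """
    return 6.0 * n_params + 12.0 * n_layers * d_model * seq_len


def mfu(flops_per_token: float, tokens_per_step: float, step_time_s: float,
        world_size: int, peak_tflops_per_gpu: float) -> float:
    """Achieved per-GPU TFLOP/s divided by the GPU's peak dense bf16 rate."""
    achieved_tflops = flops_per_token * tokens_per_step / step_time_s / world_size / 1e12
    return achieved_tflops / peak_tflops_per_gpu
```

MFU credits only model FLOPs. The extra forward recomputation from activation checkpointing therefore lowers MFU even when the hardware is equally busy, so compare runs with the same AC setting.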

Before blaming the model, run the platform smoke tests, then cross-check fabric health. Run nccl-tests fabric validation for busbw, gate nodes via GPU health gating, and watch the live signals in telemetry / monitoring / alerting against the SLO/SLI catalog.
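nccl-tests reports both algbw (operation size divided by time) and busbw. busbw rescales algbw by a per-collective factor so results can be compared against raw link bandwidth. The sketch below applies the factors documented in nccl-tests' PERFORMANCE.md: 2(n-1)/n for all-reduce and (n-1)/n for all-gather and reduce-scatter. The function name is hypothetical; check the size convention against your nccl-tests version.

```python
from __future__ import annotations

# Bus-bandwidth correction factors, per nccl-tests' PERFORMANCE.md.
_FACTORS = {
    "all_reduce": lambda n: 2 * (n - 1) / n,
    "all_gather": lambda n: (n - 1) / n,
    "reduce_scatter": lambda n: (n - 1) / n,
}


def bus_bandwidth(op: str, size_bytes: float, time_s: float, n_ranks: int) -> float:
    """busbw in bytes/s from an operation size, its duration, and the rank count."""
    algbw = size_bytes / time_s
    return algbw * _FACTORS[op](n_ranks)
```

Because busbw is normalized this way, an all-reduce across 8 GPUs and one across 32 GPUs can both be compared directly against the per-GPU NVLink or NIC line rate.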

A Prometheus/PromQL alert on MFU regression (emit training_mfu from report() via your exporter):

# fire when the 30m mean MFU is >15% below the 6h mean that precedes it (tune to your SLO)
(
  avg_over_time(training_mfu[30m])
  /
  avg_over_time(training_mfu[6h] offset 30m)
) < 0.85
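To tune the threshold on historical data before deploying the rule, the same ratio can be replayed offline. The sketch below mirrors the PromQL windows: the recent window is (now-30m, now], and the baseline is the 6h window ending 30m ago. `mfu_alert` and its (unix_ts, mfu) sample format are hypothetical, and it ignores Prometheus staleness and lookback details.

```python
from __future__ import annotations


def mfu_alert(samples: list[tuple[float, float]], now: float,
              recent_s: float = 1800.0, baseline_s: float = 6 * 3600.0,
              threshold: float = 0.85) -> bool:
    """True when mean MFU over the recent window / mean over the baseline < threshold."""
    recent = [v for t, v in samples if now - recent_s < t <= now]
    base_end = now - recent_s
    baseline = [v for t, v in samples if base_end - baseline_s < t <= base_end]
    if not recent or not baseline:
        return False  # no data: the PromQL expression would return nothing
    ratio = (sum(recent) / len(recent)) / (sum(baseline) / len(baseline))
    return ratio < threshold
```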

4. Maintain: sharded checkpoints and gradient accumulation

Save and load with torch.distributed.checkpoint (DCP): sharded, resharding-tolerant, and parallel across ranks. Avoid gathering a full state dict on rank 0 for large models; it can OOM host memory.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

def save_ckpt(model, opt, path: str) -> None:
    msd, osd = get_state_dict(model, opt)
    dcp.save({"model": msd, "optim": osd}, checkpoint_id=path)

def load_ckpt(model, opt, path: str) -> None:
    msd, osd = get_state_dict(model, opt)          # templates for in-place load
    dcp.load({"model": msd, "optim": osd}, checkpoint_id=path)
    set_state_dict(model, opt, model_state_dict=msd, optim_state_dict=osd)

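For gradient accumulation without per-microbatch comms, disable gradient sync on all but the last microbatch with the FSDP2 method below. Parameters are still all-gathered for every microbatch. The skipped reductions also leave unsharded gradients resident until the last backward, so budget that memory: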

model.set_requires_gradient_sync(False)   # accumulate locally (no reduce-scatter)
# ... run K-1 microbatch backward() calls ...
model.set_requires_gradient_sync(True)    # last microbatch: reduce-scatter fires
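Put together as one optimizer step, the pattern might look like the sketch below. `accumulate_and_step` is a hypothetical helper, and `model` is the fully_shard-ed root. Dividing the loss by K assumes the model returns a per-microbatch mean loss and that microbatches carry equal token counts.

```python
from __future__ import annotations

from collections.abc import Iterable
from typing import Any


def accumulate_and_step(model: Any, opt: Any, microbatches: Iterable[dict]) -> None:
    """One optimizer step over K microbatches with FSDP2 grad-sync gating.

    Sync is enabled only on the last microbatch, so the reduce-scatter (and the
    HSDP all-reduce) runs once per optimizer step; sync is left on afterwards.
    """
    mbs = list(microbatches)
    if not mbs:
        raise ValueError("need at least one microbatch")
    k = len(mbs)
    for i, mb in enumerate(mbs):
        model.set_requires_gradient_sync(i == k - 1)
        loss = model(**mb).loss / k   # scale so summed grads match one K-times-larger batch
        loss.backward()
    opt.step()
    opt.zero_grad()
```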

Failure modes

  • Wrapping only the root -> one giant all-gather, OOM, zero compute/comm overlap. Always shard per block.
  • Naive 1-D FSDP across nodes -> per-block collectives cross IB, IB-bandwidth-bound, low MFU. Use HSDP so only the grad all-reduce crosses nodes.
  • Shard dim mismatched to the NVLink domain (e.g. shard=16 spanning two NVLink islands) -> the "intra-node" all-gather crosses IB anyway. Match the inner shard dim to the NVLink group (8 for HGX, 4 for some boards).
  • reduce_dtype=bf16 can diverge on long runs; reduce grads in fp32.
  • PCIe ACS left on, or a wrong NCCL_IB_HCA -> NCCL silently falls back from GPUDirect RDMA and throughput collapses. PCIe ACS must be off for P2P/GDR; confirm [GDRDMA] appears in the NCCL_DEBUG=INFO output.
  • Mixing FSDP1 FullyShardedDataParallel and FSDP2 fully_shard in one model -> different code paths; pick one.
  • No comms/compute overlap -> GPUs idle on every all-gather; shows up as low MFU with healthy fabric.
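The shard-dim / NVLink-domain mismatch can be caught before launch with a pure arithmetic check. The sketch assumes contiguous rank numbering (domain d owns ranks d*D .. d*D+D-1) and a row-major (replicate, shard) mesh, so shard group j is ranks j*S .. j*S+S-1. `shard_groups_crossing_domains` is a hypothetical helper.

```python
from __future__ import annotations


def shard_groups_crossing_domains(world_size: int, shard_size: int,
                                  nvlink_domain: int) -> list[list[int]]:
    """Return the shard groups whose ranks span more than one NVLink domain."""
    if world_size % shard_size:
        raise ValueError("world_size must be a multiple of shard_size")
    bad = []
    for start in range(0, world_size, shard_size):
        group = list(range(start, start + shard_size))
        if len({r // nvlink_domain for r in group}) > 1:   # touches >1 island
            bad.append(group)
    return bad
```

An empty result means every shard group's all-gather and reduce-scatter stay on NVLink. A non-empty one (for example shard=16 on 8-GPU nodes) means those per-block collectives will cross IB.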

References

  • PyTorch FSDP docs: https://docs.pytorch.org/docs/stable/fsdp.html
  • fully_shard API (2.9): https://docs.pytorch.org/docs/2.9/distributed.fsdp.fully_shard.html
  • FSDP2 getting-started tutorial: https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html
  • Distributed Checkpoint (DCP): https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
  • "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (arXiv 2304.11277): https://arxiv.org/abs/2304.11277
  • TorchTitan (HSDP/FSDP + TP/PP reference): https://github.com/pytorch/torchtitan
  • torchrun (elastic launch): https://docs.pytorch.org/docs/stable/elastic/run.html
  • NCCL environment variables: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html

Related: Distributed-training recipes · FSDP · DiLoCo · Slurm · Kubernetes · Volcano job · nccl-tests validation · MFU regression runbook · Telemetry · Glossary