Distributed training platform

Scope: the frameworks and mechanics of training across many GPUs, covering launchers, parallelism strategies, libraries, numerics, checkpoint/resume, fault tolerance, and the one efficiency number that matters (MFU). This is the MLOps core for the training side and the closest neighbour to GPU performance and health.

Overview

Above one node, training is an exercise in splitting a model and its data across GPUs without letting communication or imbalance swamp compute. The strategy you pick (which dimensions to parallelise) determines the collective pattern, which determines the fabric requirement (networking fabric) and the achievable efficiency. The skill is laying parallelism onto the topology so the GPUs stay compute-bound, and measuring that with MFU rather than "the GPUs look busy".

Architecture: parallelism onto topology (TP intra-node, PP/FSDP inter-node)

flowchart TB
  subgraph N1["Node 1 — NVLink domain (TP)"]
    A1["GPU0"] --- A2["GPU1"] --- A3["GPU2"] --- A4["GPU3"]
  end
  subgraph N2["Node 2 — NVLink domain (TP)"]
    B1["GPU0"] --- B2["GPU1"] --- B3["GPU2"] --- B4["GPU3"]
  end
  N1 ==>|"PP / FSDP over InfiniBand"| N2

Core knowledge

Launchers and process model

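  • torchrun (elastic, c10d/etcd rendezvous), srun under Slurm (provisioning and scheduling), or mpirun. Under torchrun each rank gets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT; under srun or mpirun the launcher's own variables (e.g. SLURM_PROCID, OMPI_COMM_WORLD_RANK) have to be mapped onto them. Elastic rendezvous lets the world size change on failure (reliability and RAS).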
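The rank variables above can be read in a launcher-agnostic way. The following is a minimal sketch, not a library API: `rank_info` is a hypothetical helper, and it assumes the standard variable names set by torchrun, Slurm's srun, and Open MPI's mpirun.

```python
import os

def rank_info(env=None):
    """Return (rank, local_rank, world_size) from whichever launcher set them.

    Hypothetical helper for illustration; pass a dict as `env` to test it.
    """
    env = os.environ if env is None else env
    # (global rank, local rank, world size) variable names per launcher.
    schemes = [
        ("RANK", "LOCAL_RANK", "WORLD_SIZE"),               # torchrun
        ("SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS"),  # srun
        ("OMPI_COMM_WORLD_RANK",
         "OMPI_COMM_WORLD_LOCAL_RANK",
         "OMPI_COMM_WORLD_SIZE"),                           # Open MPI mpirun
    ]
    for names in schemes:
        if all(name in env for name in names):
            return tuple(int(env[name]) for name in names)
    # No launcher detected: behave as a single-process job.
    return 0, 0, 1
```

Under srun or mpirun, MASTER_ADDR and MASTER_PORT still have to be supplied separately before the process group can initialise from environment variables.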

Parallelism dimensions (they compose)

  • Data parallel (DDP): replicate the model, shard the data, all-reduce gradients. Simplest, but every rank holds the full parameters, gradients, and optimizer state, so model size is capped by single-GPU memory.
  • FSDP / ZeRO: shard parameters, gradients, and optimizer state across ranks to fit large models. With full sharding, comms become all-gather (params) + reduce-scatter (grads). ZeRO stages 1/2/3 (DeepSpeed) shard optimizer state, then gradients, then parameters; FSDP2 is the PyTorch-native form.
  • Tensor parallel (TP): split individual layers across GPUs (intra-layer). Heavy, latency-sensitive comms, so keep inside the NVLink domain (intra-node, or intra-rack on NVL72).
  • Pipeline parallel (PP): split the layer stack into stages across nodes; micro-batches fill the pipeline; mind the bubble (idle time while the pipeline fills and drains).
  • Expert parallel (EP): distribute MoE experts; all-to-all routing.
  • Sequence / context parallel: shard the sequence dimension for long context.
  • Typical composition at scale: TP within a node, PP across nodes, FSDP/DP outermost, laid onto the topology so the heaviest comms ride the fastest links (networking fabric, performance tuning).
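The composition above (TP innermost, PP next, DP outermost) can be sketched as a rank-to-coordinate mapping plus a check that no TP group crosses a node boundary. This is an illustration under stated assumptions: the function names are hypothetical, ranks are assumed to be placed contiguously on nodes, and real frameworks such as Megatron-Core define their own rank ordering.

```python
from collections import defaultdict

def coords(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp) indices with TP varying fastest."""
    assert 0 <= rank < tp * pp * dp
    return rank // (tp * pp), (rank // tp) % pp, rank % tp

def tp_groups(tp, pp, dp):
    """Collect the global ranks that share (dp, pp) indices: one TP group each."""
    groups = defaultdict(list)
    for rank in range(tp * pp * dp):
        d, p, _ = coords(rank, tp, pp, dp)
        groups[(d, p)].append(rank)
    return list(groups.values())

def tp_stays_on_node(tp, pp, dp, gpus_per_node):
    """True if every TP group sits on a single node.

    Assumes rank r runs on node r // gpus_per_node (contiguous placement).
    """
    return all(
        len({rank // gpus_per_node for rank in group}) == 1
        for group in tp_groups(tp, pp, dp)
    )
```

For the two-node, four-GPU-per-node layout in the diagram, `tp_stays_on_node(4, 2, 1, 4)` holds, while `tp_stays_on_node(8, 1, 1, 4)` does not, because the single TP group would span both nodes.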

Frameworks

  • PyTorch (DDP, FSDP2), DeepSpeed (ZeRO, offload), Megatron-Core (the reference TP/PP/EP/SP engine for LLM scale), NVIDIA NeMo (end-to-end, built on Megatron), TorchTitan (PyTorch-native large-scale reference), Ray Train (orchestration), Hugging Face Accelerate/Trainer. Transformer Engine provides FP8 layers (performance tuning).

Numerics

  • BF16 is the default training precision; FP8 (Hopper/Blackwell, via Transformer Engine) for more throughput once numerics are validated. Gradient accumulation simulates a larger global batch; activation/gradient checkpointing recomputes activations to trade compute for memory.
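The gradient-accumulation arithmetic can be made explicit. A minimal sketch with hypothetical helper names: only data-parallel (DP/FSDP) ranks multiply the batch, since TP and PP ranks work on the same samples.

```python
def global_batch(micro_batch, grad_accum_steps, dp_size):
    """Samples consumed per optimizer step across the whole job."""
    return micro_batch * grad_accum_steps * dp_size

def accum_steps_for(target_global_batch, micro_batch, dp_size):
    """Accumulation steps needed to hit a target global batch exactly."""
    per_micro_step = micro_batch * dp_size
    if target_global_batch % per_micro_step:
        raise ValueError(
            "target global batch must be a multiple of micro_batch * dp_size"
        )
    return target_global_batch // per_micro_step
```

For example, a micro-batch of 2 on 64 data-parallel ranks needs 8 accumulation steps to reach a global batch of 1024 samples.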

Checkpoint, resume, fault tolerance

  • Sharded async checkpoints via DCP (storage and data). A correct resume restores optimizer state, RNG, and dataloader position, not just weights.
  • Elastic / fault-tolerant training: torchrun elastic resizes on failure; torchft and redundant replicas keep large jobs alive; automatic restart from the last checkpoint. At scale, node failure is frequent enough to design around from day one (reliability and RAS).
  • Straggler detection: one slow rank gates every collective, so the whole job runs at the pace of the slowest GPU.
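Why a correct resume needs RNG state and dataloader position can be shown with a toy loader. This is a standard-library sketch only: `ToyLoader` is hypothetical and stands in for a stateful dataloader, and a real job would save this state alongside model and optimizer state in its sharded checkpoint.

```python
import random

class ToyLoader:
    """Reshuffles a fixed dataset each epoch and tracks its position."""

    def __init__(self, n_samples, seed):
        self.n = n_samples
        self.rng = random.Random(seed)
        self.order = []
        self.pos = 0

    def next_sample(self):
        if self.pos == len(self.order):  # epoch exhausted (or first call)
            self.order = list(range(self.n))
            self.rng.shuffle(self.order)
            self.pos = 0
        sample = self.order[self.pos]
        self.pos += 1
        return sample

    def state_dict(self):
        # RNG state, current shuffle order, and position are all needed.
        return {"rng": self.rng.getstate(),
                "order": list(self.order),
                "pos": self.pos}

    def load_state_dict(self, state):
        self.rng.setstate(state["rng"])
        self.order = list(state["order"])
        self.pos = state["pos"]
```

A loader restored with `load_state_dict` continues the original sample stream exactly, including across the next epoch's reshuffle; a loader rebuilt from the seed alone replays samples the run has already seen.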

The efficiency metric: MFU

  • MFU (Model FLOPs Utilization) = achieved model FLOP/s / hardware peak FLOP/s, where model FLOPs are those the model mathematically requires per step. HFU (Hardware FLOPs Utilization) additionally counts recomputed FLOPs, e.g. from activation checkpointing. For large LLM pretraining, 35-50% MFU is healthy; below ~35% signals a bottleneck: comms, dataloader (storage and data), imbalance, or a bad parallelism layout. Optimise toward MFU, not raw "utilisation" (observability, performance tuning).
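The MFU ratio can be estimated from throughput alone. The sketch below assumes the common approximation of about 6 FLOPs per parameter per token for a dense transformer (forward plus backward, attention FLOPs ignored); that approximation, the function name, and the example peak figure are assumptions for illustration, not values from this page.

```python
def mfu(params, tokens_per_sec, n_gpus, peak_flops_per_gpu):
    """Estimate Model FLOPs Utilization as a fraction between 0 and 1.

    Assumes ~6 * params FLOPs per trained token (dense transformer,
    forward + backward, attention term omitted).
    """
    achieved_flops_per_sec = 6.0 * params * tokens_per_sec
    return achieved_flops_per_sec / (n_gpus * peak_flops_per_gpu)

# Illustrative numbers only: a 7B-parameter model at 2.4M tokens/s on
# 256 GPUs, with an assumed dense BF16 peak of 989e12 FLOP/s per GPU
# (check the vendor datasheet for the actual part).
example = mfu(7e9, 2.4e6, 256, 989e12)
```

With these assumed inputs `example` comes out just under 0.40, inside the healthy band quoted above. Use the dense (non-sparsity) peak for the precision actually in use, otherwise the ratio is understated.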

Low-communication / geo-distributed (context)

  • DiLoCo and related methods reduce communication frequency to tolerate WAN latency between data centres, the inverse of the single-DC, high-bandwidth, TP-heavy regime. The distinction to hold cleanly is which regime applies when.

Don't-miss checklist

  • Lay parallelism onto topology: TP inside NVLink, PP/DP across the slower fabric (networking fabric).
  • Track MFU; investigate anything under ~35-40%.
  • Sharded async checkpoints; test that resume restores optimizer, RNG, and data position.
  • Design for node failure and stragglers from the start at scale (reliability and RAS).
  • BF16 by default; FP8 only with Transformer Engine and validated numerics.

Failure modes

  • TP spanning nodes over IB instead of NVLink: comms-bound, MFU collapses.
  • Global batch mis-sized, or learning rate not rescaled when the batch size changes: divergence or wasted compute.
  • Resume that loses dataloader position or RNG: silent data repetition, non-reproducibility.
  • A single straggler gating all-reduce: whole job at the slowest rank's pace.
  • OOM at scale from missing activation checkpointing or wrong FSDP wrapping.
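The straggler failure mode can be caught with a simple per-rank timing comparison. A minimal sketch with a hypothetical helper and an assumed 20% tolerance: it expects each rank's compute time measured before the collective, because end-to-end step times are equalised by the collective itself and hide the slow rank.

```python
from statistics import median

def find_stragglers(compute_times, tolerance=1.2):
    """Return ranks whose pre-collective compute time exceeds
    tolerance * median across ranks.

    compute_times: dict mapping rank -> seconds for one step's compute.
    """
    med = median(compute_times.values())
    return sorted(
        rank for rank, seconds in compute_times.items()
        if seconds > tolerance * med
    )
```

For example, `find_stragglers({0: 1.00, 1: 1.02, 2: 0.99, 3: 1.60})` flags rank 3. In practice the timings would be averaged over several steps before flagging a rank.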

Open questions & validation

  • Validate a Megatron-Core / NeMo TP+PP+EP layout against the topology, not just FSDP; confirm comms land on the intended links (performance tuning).
  • FP8 training numerics on Blackwell via Transformer Engine: where it holds and where it diverges.
  • A measured MFU number on a real run, and the bottleneck hunt that follows.

References

  • PyTorch FSDP: https://docs.pytorch.org/docs/stable/fsdp.html
  • Megatron-LM / Megatron-Core: https://github.com/NVIDIA/Megatron-LM
  • DeepSpeed: https://www.deepspeed.ai/
  • NeMo: https://docs.nvidia.com/nemo-framework/user-guide/latest/index.html
  • TorchTitan: https://github.com/pytorch/torchtitan
  • Transformer Engine (FP8): https://docs.nvidia.com/deeplearning/transformer-engine/index.html
  • torchrun elastic: https://docs.pytorch.org/docs/stable/elastic/run.html

Related: Fabric · Provisioning · Performance · Storage · Reliability · Training as a Platform Service · Optimization · Muon / DMuon · Glossary