Tensor parallelism (TP)
Scope: Megatron-style intra-layer model parallelism (splitting a single layer's matmuls across GPUs) as one axis of the parallelism stack in distributed training.
Code convention: blocks that need torch, Megatron-Core, or vLLM are labelled reference templates (run them on a multi-GPU node with versions pinned). Where a template relies on a mathematical identity, it is paired with a small numpy-only check that verifies that identity and runs anywhere. Validate loss parity before production use.
What it is
Tensor parallelism (TP) splits the weights of an individual layer across GPUs so each holds a slice of the matmul, with a collective to recombine the result. Unlike data parallelism / ZeRO (DeepSpeed and ZeRO) and FSDP, which split the batch and shard parameter, gradient, or optimizer-state storage but still gather each full layer before computing it, TP shards the computation within a layer ("intra-layer model parallelism").
The Megatron formulation pairs two patterns to minimise communication:
- Column-parallel (ColumnParallelLinear): for Y = XA, the weight A is split along its output columns, A = [A_1 ... A_p]. Each rank computes Y_i = X A_i, and the output stays sharded (no collective in the forward pass).
- Row-parallel (RowParallelLinear): the weight B is split along its input rows and the input X is sharded to match, so each rank computes a partial product X_i B_i; an all-reduce sums them into Y = X_1 B_1 + ... + X_p B_p.
Composing column-then-row (e.g. the MLP up/down or attention qkv/out projections) needs only one all-reduce per pair of linears in the forward pass (after the row-parallel linear) and one in the backward pass (on the input gradient of the column-parallel linear); this is the core efficiency of Megatron TP. Sequence parallelism extends this by sharding the LayerNorm/dropout regions along the sequence dimension.
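The backward-pass all-reduce sits on the column-parallel side: each rank holds only A_i and its slice of the output gradient, so it can form only a partial input gradient. A minimal numpy check of that identity, in the same style as the forward checks on this page (shapes and seed are arbitrary choices):

```python
import numpy as np

# Column-parallel backward: Y = X @ A with A split by output columns.
# dL/dX = dY @ A.T, but rank i only has A_i and dY_i, so it forms a partial
# dX_i = dY_i @ A_i.T. Summing the partials (an all-reduce) gives the full dX.
rng = np.random.default_rng(3)
m, d_model, d_ff, tp = 5, 16, 24, 4
X = rng.standard_normal((m, d_model))
A = rng.standard_normal((d_model, d_ff))
dY = rng.standard_normal((m, d_ff))            # upstream gradient of Y

dX_ref = dY @ A.T                              # dense single-GPU gradient
A_shards = np.split(A, tp, axis=1)
dY_shards = np.split(dY, tp, axis=1)           # dY arrives sharded like Y
dX_partials = [dYi @ Ai.T for dYi, Ai in zip(dY_shards, A_shards)]
dX_allreduce = sum(dX_partials)                # backward all-reduce == SUM
assert np.allclose(dX_allreduce, dX_ref, atol=1e-10)

# Weight gradients need no communication: dA_i = X.T @ dY_i is purely local.
dA_ref = X.T @ dY
dA_tp = np.concatenate([X.T @ dYi for dYi in dY_shards], axis=1)
assert np.allclose(dA_tp, dA_ref, atol=1e-10)
print("column-parallel backward all-reduce check: OK")
```

Together with the row-parallel forward all-reduce, this accounts for the one-per-pass count above.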
Why use it
- A single layer is too big: when one layer's parameters + activations exceed a GPU, no amount of data/ZeRO sharding helps. ZeRO and FSDP shard storage but still materialise each full layer to compute it; TP is what splits the layer itself.
- Latency: for inference, TP splits one forward across GPUs so a single request's compute is parallelised, lowering per-token latency for large models (inference serving).
- Composability: TP is the innermost axis of "3D" parallelism (TP x PP x DP/ZeRO; see pipeline parallelism), the standard recipe for the largest models.
When to use it (and when not)
- Use TP for very large models where a layer does not fit on one GPU, or to cut single-request inference latency.
- KEEP TP WITHIN A NODE / NVLink domain. Without sequence parallelism, TP issues two all-reduces per transformer block (after attention and after the MLP) in the forward pass and two more in the backward pass, so it is extremely communication-heavy and latency-bound. Spanning TP across InfiniBand (inter-node) collapses throughput; set TP degree <= GPUs per NVLink/NVSwitch domain.
- Do not over-shard: small TP (2-8) within a node is typical. Prefer data parallelism / ZeRO (DeepSpeed and ZeRO) and pipeline parallelism for axes that must cross nodes, where communication is less frequent.
- When the model fits per GPU, skip TP; the extra collectives only add overhead.
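A rough estimate of the bytes each GPU sends for TP all-reduces in one training step shows why the NVLink rule matters. The sketch below assumes Megatron-style TP without sequence parallelism (two all-reduces per block forward, two backward), a ring all-reduce, and bf16 activations; the model shape and both bandwidth figures are illustrative assumptions, not measurements.

```python
# Back-of-envelope TP communication per training step (no sequence parallelism).
# Assumptions: 4 all-reduces per transformer block per fwd+bwd step; each moves
# one activation tensor of shape (batch, seq, hidden); a ring all-reduce sends
# 2 * (p - 1) / p of the message per GPU.

def tp_comm_per_step(n_layers, batch, seq, hidden, tp, bytes_per_elem=2):
    """Return (number of all-reduces, bytes sent per GPU) for one fwd+bwd step."""
    n_allreduces = 4 * n_layers
    message_bytes = batch * seq * hidden * bytes_per_elem
    ring_factor = 2 * (tp - 1) / tp
    return n_allreduces, n_allreduces * message_bytes * ring_factor


def comm_seconds(bytes_sent, link_gb_per_s):
    """Bandwidth-only time estimate; ignores latency and compute overlap."""
    return bytes_sent / (link_gb_per_s * 1e9)


# Illustrative 32-layer, hidden=4096 model, one 4096-token sequence, bf16, TP=8.
n_ar, sent = tp_comm_per_step(n_layers=32, batch=1, seq=4096, hidden=4096, tp=8)
print(n_ar, sent / 1e9)  # 128 all-reduces, ~7.5 GB sent per GPU per step
# Hypothetical effective bandwidths, NOT measurements: ~300 GB/s intra-node
# NVLink vs ~25 GB/s for a single 200 Gb/s InfiniBand NIC.
print(round(comm_seconds(sent, 300), 3), round(comm_seconds(sent, 25), 3))  # 0.025 0.301
```

Real NCCL throughput depends on message size, algorithm choice, and overlap with compute; the order-of-magnitude gap between the two links is the point, not the absolute numbers.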
Architecture
flowchart LR
X["Input X (replicated)"]
subgraph TPG["TP group (within NVLink domain)"]
C0["GPU 0: col-parallel A_0 -> Y_0"]
C1["GPU 1: col-parallel A_1 -> Y_1"]
R0["GPU 0: row-parallel B_0 (partial)"]
R1["GPU 1: row-parallel B_1 (partial)"]
end
AR["all-reduce (NVLink)"]
OUT["Output Z (replicated)"]
X --> C0
X --> C1
C0 --> R0
C1 --> R1
R0 --> AR
R1 --> AR
AR --> OUT
The column-parallel stage keeps its output sharded (no collective), the row-parallel stage produces per-rank partial sums, and a single all-reduce over NVLink reconstructs the replicated output. The two numpy checks below verify exactly these two identities.
How to use it
There are two mainstream APIs: NVIDIA Megatron-Core (the production training stack) and native PyTorch DTensor (torch.distributed.tensor.parallel). The PyTorch path expresses a per-module plan over a device mesh:
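# Reference template (needs a recent torch 2.x on a multi-GPU node, launched via torchrun).
# torch DTensor TP: shard one transformer block across an 8-GPU TP mesh.
# `block` is your model's transformer block; the submodule names are illustrative.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module, ColwiseParallel, RowwiseParallel,
)
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
plan = {
    # Separate Q/K/V linears: a contiguous column split keeps whole heads per rank.
    # (A fused wqkv would put mostly Q on rank 0 unless its layout is head-interleaved.)
    "attn.wq": ColwiseParallel(),   # column-parallel projections, no comm
    "attn.wk": ColwiseParallel(),
    "attn.wv": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),   # row-parallel + all-reduce
    "mlp.w1": ColwiseParallel(),
    "mlp.w2": RowwiseParallel(),
}
block = parallelize_module(block, tp_mesh, plan)
# Each rank now computes n_heads / 8 heads: update the module's local head count
# (e.g. block.attn.n_heads //= tp_mesh.size()) so its reshapes use the local size.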
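The plan relies on how DTensor shards nn.Linear weights, which are stored as (out_features, in_features): ColwiseParallel shards dim 0 and RowwiseParallel shards dim 1. The pure-Python sketch below computes the resulting per-rank shapes for an illustrative Llama-like block; the dimensions and submodule names are assumptions, and it assumes standard multi-head attention (no grouped-query KV heads).

```python
# Per-rank weight shapes after a DTensor TP plan (illustrative dimensions).
# nn.Linear stores weight as (out_features, in_features); ColwiseParallel shards
# dim 0 (output features), RowwiseParallel shards dim 1 (input features).

def local_shape(full_shape, style, tp):
    out_f, in_f = full_shape
    if style == "colwise":
        assert out_f % tp == 0, "output features must split evenly across TP ranks"
        return (out_f // tp, in_f)
    if style == "rowwise":
        assert in_f % tp == 0, "input features must split evenly across TP ranks"
        return (out_f, in_f // tp)
    raise ValueError(f"unknown style {style!r}")


d_model, n_heads, d_ff, tp = 4096, 32, 14336, 8   # assumed Llama-like block
head_dim = d_model // n_heads                     # 128

# (style, full nn.Linear weight shape) per hypothetical submodule name
plan = {
    "attn.wq": ("colwise", (d_model, d_model)),
    "attn.wk": ("colwise", (d_model, d_model)),
    "attn.wv": ("colwise", (d_model, d_model)),
    "attn.wo": ("rowwise", (d_model, d_model)),
    "mlp.w1": ("colwise", (d_ff, d_model)),
    "mlp.w2": ("rowwise", (d_model, d_ff)),
}
local = {name: local_shape(shape, style, tp) for name, (style, shape) in plan.items()}

local_heads = n_heads // tp                           # heads each rank computes
assert local["attn.wq"][0] == local_heads * head_dim  # whole heads per rank
print(local["attn.wq"], local["attn.wo"], local_heads)  # (512, 4096) (4096, 512) 4
```

The local head count printed here is the value the attention module must use after parallelisation; a dimension that does not divide by the TP degree fails at the assert rather than deep inside a reshape.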
ColwiseParallel splits a weight along its output columns and keeps the output sharded with no forward collective. This numpy-only check verifies that identity against a dense reference, and that a dead (zeroed) shard is detectable:
import numpy as np
# Column-parallel linear: Y = X @ A, weight A split along its OUTPUT columns.
# Each rank computes Y_i = X @ A_i; concatenation reconstructs Y (no collective).
rng = np.random.default_rng(0)
d_model, d_ff, tp = 16, 24, 4
X = rng.standard_normal((5, d_model))
A = rng.standard_normal((d_model, d_ff))
Y_ref = X @ A # single-GPU reference
A_shards = np.split(A, tp, axis=1) # split columns across tp ranks
Y_shards = [X @ Ai for Ai in A_shards] # per-rank matmul, no comm
Y_tp = np.concatenate(Y_shards, axis=1) # gather (or keep sharded)
assert Y_tp.shape == Y_ref.shape
assert np.allclose(Y_tp, Y_ref, atol=1e-12), "column-parallel must equal dense matmul"
# Adversarial: a dead/zeroed rank (corrupted shard) must be detectable.
Y_shards[2][:] = 0.0
Y_bad = np.concatenate(Y_shards, axis=1)
assert not np.allclose(Y_bad, Y_ref), "a zeroed shard must break equivalence"
print("validator #1 (column-parallel): OK")
How to integrate it
The model author's job is to mark which linears are column- vs row-parallel and place the all-reduce at the boundary; the framework inserts the collectives. With Megatron-Core, swap nn.Linear for ColumnParallelLinear/RowParallelLinear and pass the TP/PP/DP degrees:
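# Reference template (needs Megatron-Core + torch on a multi-GPU node, after
# parallel_state.initialize_model_parallel()). `config` is a ModelParallelConfig
# and `init_method` an in-place weight initialiser; recent Megatron-Core releases
# require both as keyword arguments (check your pinned version's API docs).
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
# MLP: column-parallel up-proj (no comm), row-parallel down-proj (all-reduce)
self.up = ColumnParallelLinear(d_model, d_ff, config=config, init_method=init_method,
                               gather_output=False, bias=True)
self.down = RowParallelLinear(d_ff, d_model, config=config, init_method=init_method,
                              input_is_parallel=True, bias=True, skip_bias_add=False)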
Develop on a single node first (TP only), validate that the loss matches a non-TP baseline on a tiny model, then add PP/DP. Watch for two pitfalls: weight init must be sharding-aware (Megatron seeds per rank), and checkpoint save/load must consolidate or shard consistently with the TP degree.
RowParallelLinear consumes a sharded activation and emits a partial sum per rank, which an all-reduce combines. This numpy-only check verifies that the all-reduce is exactly a sum, and that dropping a rank corrupts the output:
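The sharding-aware init requirement can be illustrated with numpy. This sketch assumes the simplest correct scheme (draw the full weight from one seed, then slice it per rank) and contrasts it with a hypothetical bug where every rank reuses the same seed for its own shard; Megatron's real per-rank RNG handling differs in detail, so treat this only as a picture of the invariant.

```python
import numpy as np

# Correct: draw the FULL weight from one seed, then give each rank its slice,
# so TP and the dense baseline start from identical weights.
# Bug (hypothetical): every rank seeds identically and draws only its own shard,
# so all shards come out identical and the TP model != the baseline.
d_model, d_ff, tp, seed = 8, 16, 4, 1234

W_full = np.random.default_rng(seed).standard_normal((d_model, d_ff))
good = np.split(W_full, tp, axis=1)                      # slices of one draw

bad = [np.random.default_rng(seed).standard_normal((d_model, d_ff // tp))
       for _ in range(tp)]                                # same seed on every rank

assert np.array_equal(np.concatenate(good, axis=1), W_full)
assert all(np.array_equal(bad[0], s) for s in bad[1:]), "bug: duplicated shards"
assert not np.allclose(np.concatenate(bad, axis=1), W_full)
print("sharding-aware init check: OK")
```

Duplicated shards are the quiet version of this bug: training still runs, but the loss curve drifts away from the single-GPU baseline, which is what the loss-parity gate catches.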
import numpy as np
# Row-parallel linear: Y = H @ B, weight B split along its INPUT rows and the
# activation H already sharded along the same dim. Each rank makes a partial
# product; the all-reduce SUMS them: sum_i (H_i @ B_i) == H @ B.
rng = np.random.default_rng(1)
m, d_ff, d_model, tp = 5, 24, 16, 4
H = rng.standard_normal((m, d_ff))
B = rng.standard_normal((d_ff, d_model))
Y_ref = H @ B
H_shards = np.split(H, tp, axis=1) # activation sharded along d_ff
B_shards = np.split(B, tp, axis=0) # weight sharded along input rows
partials = [Hi @ Bi for Hi, Bi in zip(H_shards, B_shards)]
Y_allreduce = sum(partials) # all-reduce == elementwise SUM
assert np.allclose(Y_allreduce, Y_ref, atol=1e-12), "row-parallel all-reduce must sum to dense matmul"
# Adversarial: dropping one rank from the all-reduce must corrupt the result.
Y_missing = sum(partials[:-1])
assert not np.allclose(Y_missing, Y_ref), "a missing rank in the all-reduce must be detectable"
print("validator #2 (row-parallel all-reduce): OK")
How to run it in production
Serving inference
TP is heavily used for serving large models: it splits one forward pass across GPUs to fit the model and lower latency. vLLM and SGLang expose it directly:
vllm serve <org>/<model> --tensor-parallel-size 8 # TP across 8 GPUs in a node
python -m sglang.launch_server --model-path <org>/<model> --tp 8
Keep --tensor-parallel-size within the NVLink domain; combine with pipeline-parallel or data-parallel replicas to scale out. Disaggregated prefill/decode and multi-node serving build on top of this. See inference serving, serving open-weight models, and disaggregated inference.
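A rough sizing sketch for choosing --tensor-parallel-size: weight memory per GPU scales as 1/TP, and the TP degree should also split the attention heads evenly. The 70B parameter count, bf16 weights, and head counts are illustrative assumptions; the estimate ignores the KV cache and runtime overhead, and engines differ in how they handle grouped-query KV heads that do not divide evenly.

```python
# Weight memory per GPU under TP, plus a head-divisibility filter.
# Ignores KV cache, activations, and runtime overhead.

def weight_gb_per_gpu(n_params, bytes_per_param, tp):
    """Decimal GB of weights each GPU holds when weights are split tp ways."""
    return n_params * bytes_per_param / tp / 1e9


def valid_tp_degrees(n_heads, gpus_per_node):
    """TP degrees that tile one node evenly and split attention heads evenly."""
    return [tp for tp in range(1, gpus_per_node + 1)
            if gpus_per_node % tp == 0 and n_heads % tp == 0]


print(weight_gb_per_gpu(70e9, 2, 1))   # 140.0: does not fit one 80 GB GPU
print(weight_gb_per_gpu(70e9, 2, 8))   # 17.5 GB per GPU at TP=8
print(valid_tp_degrees(n_heads=64, gpus_per_node=8))  # [1, 2, 4, 8]
```

The remaining headroom at a given TP degree is what the engine can spend on KV cache, which is why a higher TP degree often buys longer contexts or larger batches, not just a model that fits.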
Hardware and fabric
TP communication is heavy and latency-bound: collectives on every block, in both the forward and backward pass, on every step. The fabric is the deciding factor:
- NVLink/NVSwitch is essential; the all-reduce must stay on intra-node high-bandwidth, low-latency links. Never span TP across InfiniBand/RoCE: inter-node latency dominates and throughput collapses (networking fabric, performance tuning).
- NVLink SHARP (NVLS) offloads the all-reduce into the NVSwitch fabric; enable via NCCL_NVLS_ENABLE=1 where supported. In-network reduction (SHARP) on the IB side helps the DP all-reduce, not the TP one.
- NCCL: NCCL_DEBUG=INFO should show NVLink (NVL) transport for the TP group; verify there is no PCIe/TCP fallback. Topology/affinity (Topology Manager, NUMA) must co-locate the TP GPUs.
- Blackwell: FP8/NVFP4 tensor cores and the larger NVLink domain (NVL72) widen the practical TP degree and the size of layer that fits (the Blackwell platform).
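The latency sensitivity follows from how an all-reduce executes. Below is a pure-Python simulation of the classic ring all-reduce (a reduce-scatter followed by an all-gather), one of the algorithms NCCL can select; it is a teaching sketch, not NCCL's implementation. Each of its 2(p - 1) steps is a dependent send/receive, so per-hop latency is paid 2(p - 1) times per all-reduce, and TP issues many all-reduces per step.

```python
def ring_allreduce(buffers):
    """Simulate a ring all-reduce over p ranks; returns (per-rank results, steps)."""
    p = len(buffers)
    n = len(buffers[0])
    assert n % p == 0, "sketch assumes each buffer splits into p equal chunks"
    c = n // p
    data = [list(b) for b in buffers]  # per-rank working copies

    def chunk(rank, k):
        return data[rank][k * c:(k + 1) * c]  # slicing copies the values

    steps = 0
    # Phase 1, reduce-scatter: at step s, rank r sends chunk (r - s) % p to rank
    # r + 1, which adds it in. After p - 1 steps rank r owns the full sum of
    # chunk (r + 1) % p.
    for s in range(p - 1):
        msgs = [((r + 1) % p, (r - s) % p, chunk(r, (r - s) % p)) for r in range(p)]
        for dst, k, payload in msgs:
            for i, v in enumerate(payload):
                data[dst][k * c + i] += v
        steps += 1
    # Phase 2, all-gather: each rank forwards its most recently completed chunk,
    # and the receiver overwrites its stale copy.
    for s in range(p - 1):
        msgs = [((r + 1) % p, (r + 1 - s) % p, chunk(r, (r + 1 - s) % p)) for r in range(p)]
        for dst, k, payload in msgs:
            data[dst][k * c:(k + 1) * c] = payload
        steps += 1
    return data, steps


bufs = [[10 * r + i for i in range(8)] for r in range(4)]  # 4 ranks, 8 values each
result, steps = ring_allreduce(bufs)
print(result[0], steps)  # [60, 64, 68, 72, 76, 80, 84, 88] 6
```

All messages for a step are built before any are applied, mimicking simultaneous sends. Each GPU sends 2(p - 1) chunks of size n/p, which is where the 2(p - 1)/p bandwidth factor comes from.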
How to maintain it
- Checkpoints encode the TP degree. A checkpoint saved at TP=N stores that shard layout. Loading into a job with a different degree needs a resharding/conversion step: Megatron ships checkpoint converters, and torch.distributed.checkpoint (DCP) writes a degree-agnostic format you reshard on load. Never load a TP-N checkpoint into a TP-M run without converting; it fails or silently corrupts (see Failure modes below).
- Pin framework and NCCL versions. TP sharding conventions, fused-kernel behaviour, and NCCL transport/collective selection shift across releases. Pin Megatron-Core / PyTorch / NCCL and re-run the parity check on every bump.
- Keep a loss-parity gate. Promote the single-node, tiny-model loss-vs-baseline check from integration into CI; re-run it whenever the TP degree, framework version, or fabric changes. Per-rank seeding must stay consistent so a reshard reproduces the same loss.
- Watch the transport. Monitor that NCCL keeps the TP group on NVLink (NCCL_DEBUG=INFO shows NVL) and track per-step all-reduce time; alert on any PCIe/TCP fallback or creeping collective latency, which silently caps throughput (performance tuning, networking fabric).
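Resharding comes down to merging shards along the axis they were split on and re-splitting for the new degree. A minimal numpy sketch for TP=4 to TP=2, assuming the (d_in, d_out) math layout used by the checks on this page; real converters also handle fused QKV layouts, vocab padding, and optimizer state, which this ignores.

```python
import numpy as np

# Reshard one column-parallel and one row-parallel weight from TP=4 to TP=2.
# Column-parallel weights were split on axis=1 (output), row-parallel on axis=0
# (input); each must be merged and re-split along its own axis.
rng = np.random.default_rng(4)
d_model, d_ff = 16, 32
W_up = rng.standard_normal((d_model, d_ff))    # column-parallel: axis=1
W_down = rng.standard_normal((d_ff, d_model))  # row-parallel: axis=0


def save_shards(W, tp, axis):
    """What a TP=tp checkpoint stores: one shard per rank."""
    return np.split(W, tp, axis=axis)


def reshard(shards, new_tp, axis):
    """Merge along the split axis, then re-split for the new TP degree."""
    return np.split(np.concatenate(shards, axis=axis), new_tp, axis=axis)


ckpt_up = save_shards(W_up, 4, axis=1)
ckpt_down = save_shards(W_down, 4, axis=0)
up_tp2 = reshard(ckpt_up, 2, axis=1)
down_tp2 = reshard(ckpt_down, 2, axis=0)

# The resharded shards must equal a fresh TP=2 split of the dense weight.
for new, ref in zip(up_tp2, np.split(W_up, 2, axis=1)):
    assert np.array_equal(new, ref)
for new, ref in zip(down_tp2, np.split(W_down, 2, axis=0)):
    assert np.array_equal(new, ref)

# Adversarial: merging a column-parallel checkpoint along the wrong axis yields
# wrong shapes here (and silently wrong values when the shapes happen to match).
wrong = reshard(ckpt_up, 2, axis=0)
assert wrong[0].shape != up_tp2[0].shape
print("TP=4 -> TP=2 reshard check: OK")
```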
How to scale it
TP does not scale well across nodes; its scaling rule is TP degree <= GPUs per node (the NVLink domain). To go beyond one node, compose TP with other axes:
- TP x PP (pipeline parallelism): pipeline stages span nodes (rare, large-payload point-to-point), TP stays intra-node.
- TP x DP/ZeRO (DeepSpeed and ZeRO): replicate the TP group across nodes for data parallelism.
# 3D layout on 4 nodes x 8 GPUs: TP=8 (intra-node), PP=2, DP = 32 / (8 x 2) = 2 (Megatron-style)
torchrun --nnodes=4 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
    pretrain_gpt.py --tensor-model-parallel-size 8 --pipeline-model-parallel-size 2
Rank ordering matters: place the TP group on co-located NVLink-connected GPUs, then PP, then DP outermost (distributed training, distributed-training recipes).
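The ordering rule can be made concrete with a small rank-to-group mapping. This pure-Python sketch assumes ranks are numbered node by node and uses the order described above (TP innermost, then PP, then DP outermost); frameworks may order the outer axes differently, but the invariant asserted here, that every TP group sits inside one node, is the one that matters. The function is illustrative, not Megatron's code.

```python
def rank_groups(world, tp, pp, gpus_per_node):
    """Build TP/PP/DP groups with TP fastest-varying; assert TP stays in-node."""
    assert world % (tp * pp) == 0, "world size must be a multiple of TP x PP"
    dp = world // (tp * pp)
    # rank -> (tp index, pp index, dp index), TP varying fastest
    coord = {r: (r % tp, (r // tp) % pp, r // (tp * pp)) for r in range(world)}

    tp_groups = [[r for r in range(world) if coord[r][1:] == (p, d)]
                 for d in range(dp) for p in range(pp)]
    for g in tp_groups:
        nodes = {r // gpus_per_node for r in g}
        assert len(nodes) == 1, f"TP group {g} spans nodes {sorted(nodes)}"

    pp_groups = [[r for r in range(world) if (coord[r][0], coord[r][2]) == (t, d)]
                 for d in range(dp) for t in range(tp)]
    dp_groups = [[r for r in range(world) if coord[r][:2] == (t, p)]
                 for p in range(pp) for t in range(tp)]
    return tp_groups, pp_groups, dp_groups


# The 4-node x 8-GPU, TP=8, PP=2 layout from the command above.
tp_g, pp_g, dp_g = rank_groups(world=32, tp=8, pp=2, gpus_per_node=8)
print(tp_g[0])  # [0, 1, 2, 3, 4, 5, 6, 7] -> one NVLink node
print(pp_g[0])  # [0, 8]  -> pipeline stages on nodes 0 and 1
print(dp_g[0])  # [0, 16] -> data-parallel replicas on nodes 0 and 2
```

Asking for TP=16 on 8-GPU nodes trips the assert, which is the inter-node TP failure mode caught at layout time.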
Fine-tuning
For very large models, TP is part of the fine-tuning topology exactly as in pretraining: shard each layer within a node, and compose with ZeRO/PP for the rest. Most RL/post-training stacks rely on a Megatron or FSDP backend that already implements TP (verl and NeMo-RL expose a Megatron backend with TP, and slime is built on Megatron; see RL libraries). Parameter-efficient fine-tuning (SFT and LoRA) usually avoids TP: with the base weights frozen there are no full-size gradients or optimizer states, so the footprint often fits per GPU. Reach for TP only when a base-model layer itself does not fit (fine-tuning and post-training).
Worked examples
1. Megatron-Core MLP block (column -> row)
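# Reference template (needs Megatron-Core + torch on a multi-GPU node, after
# parallel_state.initialize_model_parallel()). `config` (ModelParallelConfig) and
# `init_method` are required keyword arguments in recent Megatron-Core releases.
import torch
import torch.nn.functional as F
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

class ParallelMLP(torch.nn.Module):
    def __init__(self, d_model, d_ff, config, init_method):
        super().__init__()
        self.up = ColumnParallelLinear(d_model, d_ff, config=config, init_method=init_method,
                                       gather_output=False, bias=True)
        self.down = RowParallelLinear(d_ff, d_model, config=config, init_method=init_method,
                                      input_is_parallel=True, bias=True, skip_bias_add=False)

    def forward(self, x):
        h, _ = self.up(x)      # sharded along d_ff, no collective
        h = F.gelu(h)
        out, _ = self.down(h)  # all-reduce here
        return out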
The whole block is one column-parallel up-proj, an elementwise GeLU, and one row-parallel down-proj with a single all-reduce. Because GeLU is elementwise, the activation stays sharded across the pair, so the sharded pipeline must reproduce the dense MLP up to floating-point rounding. This numpy-only check is a miniature of the output-parity check the integration workflow relies on, plus an adversarial rank-misalignment case:
import numpy as np
def gelu(x):  # elementwise, so split-invariant
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0/np.pi) * (x + 0.044715 * x**3)))
# Megatron MLP: up-proj column-parallel (shard d_ff, no comm), GeLU, down-proj
# row-parallel (all-reduce). One all-reduce reconstructs the dense output.
rng = np.random.default_rng(2)
b, d_model, d_ff, tp = 3, 16, 32, 4
X = rng.standard_normal((b, d_model))
W1 = rng.standard_normal((d_model, d_ff)) # up-proj
W2 = rng.standard_normal((d_ff, d_model)) # down-proj
Y_ref = gelu(X @ W1) @ W2 # dense single-GPU MLP
W1s = np.split(W1, tp, axis=1) # column-parallel up-proj
W2s = np.split(W2, tp, axis=0) # row-parallel down-proj
partials = [gelu(X @ w1) @ w2 for w1, w2 in zip(W1s, W2s)]
Y_tp = sum(partials) # single all-reduce
assert np.allclose(Y_tp, Y_ref, atol=1e-9), "TP MLP (col->row) must equal dense MLP"
# Adversarial: misaligning the column-shard of W1 with the row-shard of W2
# (a rank-ordering bug between the paired linears) must be detectable.
Y_bug = sum(gelu(X @ w1) @ w2 for w1, w2 in zip(W1s, W2s[::-1]))
assert not np.allclose(Y_bug, Y_ref), "col/row shard misalignment must be detectable"
print("validator #3 (col->row MLP): OK")
2. vLLM TP serving across a node
# 8-GPU NVLink node; one logical replica, model split 8 ways
vllm serve <org>/<model-large> \
--tensor-parallel-size 8 \
--gpu-memory-utilization 0.92 \
--max-model-len 32768
# scale OUT with --pipeline-parallel-size or multiple replicas, NOT TP>GPUs/node
3. 3D layout (TP intra-node, PP + DP across nodes)
# 4 nodes x 8 GPUs: TP=8, PP=2 -> 16 GPUs per model replica, DP=2 replicas
torchrun --nnodes=4 --nproc_per_node=8 \
--rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
pretrain_gpt.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 2 \
--sequence-parallel
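The --sequence-parallel flag rests on two identities from Korthikanti et al. (2022), checked here in numpy: an all-reduce equals a reduce-scatter followed by an all-gather, and LayerNorm normalises each token independently, so each rank can normalise only its own sequence shard between those two collectives. A minimal sketch: learnable LayerNorm parameters and dropout are omitted, and the shapes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(5)
seq, hidden, tp = 8, 16, 4
# Per-rank partial sums, as emitted by a row-parallel linear.
partials = [rng.standard_normal((seq, hidden)) for _ in range(tp)]

allreduced = sum(partials)  # plain TP: all-reduce, every rank holds all seq rows
# Reduce-scatter along the sequence dim: rank i keeps the summed rows of shard i.
scattered = [sum(np.split(p, tp, axis=0)[i] for p in partials) for i in range(tp)]
gathered = np.concatenate(scattered, axis=0)  # all-gather back to the full sequence
assert np.allclose(gathered, allreduced)


def layer_norm(x, eps=1e-5):
    """Per-token LayerNorm over the hidden dim (no scale/shift)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


# Each rank normalises only its seq/tp rows; the result matches the full LayerNorm.
ln_full = layer_norm(allreduced)
ln_sharded = np.concatenate([layer_norm(s) for s in scattered], axis=0)
assert np.allclose(ln_sharded, ln_full)
print("sequence-parallel identities: OK")
```

Between the reduce-scatter and the all-gather each rank stores only seq/tp rows of the LayerNorm/dropout activations, which is the memory the flag recovers.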
Failure modes
- TP spanning nodes (TP > GPUs/node) -> all-reduce over IB, catastrophic slowdown. Hard rule: keep TP inside the NVLink domain.
- PCIe/TCP fallback for the TP group -> NCCL not using NVLink; check that NCCL_DEBUG=INFO shows NVL, ACS is off, and the GPUs are co-located.
- Checkpoint TP-degree mismatch -> load fails or silently corrupts; resharding tools must convert between TP degrees.
- Non-sharding-aware weight init -> divergent loss vs single-GPU baseline; use the framework's per-rank seeding.
- Forgetting sequence parallelism -> LayerNorm/dropout activations replicated, wasting memory the TP split was meant to save.
- Over-sharding small models -> collective overhead exceeds compute saved; only shard layers that do not fit.
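Several of these failure modes can be caught before launch. The function below is a hypothetical pre-flight check (its name and arguments are illustrative, not a Megatron, PyTorch, or vLLM API) that turns the configuration-level ones into explicit messages.

```python
def tp_preflight(tp, gpus_per_node, n_heads, hidden, sequence_parallel):
    """Return a list of human-readable problems with a TP configuration."""
    problems = []
    if tp > gpus_per_node:
        problems.append(f"TP={tp} exceeds {gpus_per_node} GPUs/node: "
                        "all-reduce would cross the inter-node fabric")
    elif gpus_per_node % tp != 0:
        problems.append(f"TP={tp} does not tile a {gpus_per_node}-GPU node evenly")
    if n_heads % tp != 0:
        problems.append(f"{n_heads} attention heads not divisible by TP={tp}")
    if hidden % tp != 0:
        problems.append(f"hidden size {hidden} not divisible by TP={tp}")
    if tp > 1 and not sequence_parallel:
        problems.append("sequence parallelism off: LayerNorm/dropout "
                        "activations stay replicated across TP ranks")
    return problems


print(tp_preflight(8, 8, 32, 4096, sequence_parallel=True))    # []
print(tp_preflight(16, 8, 32, 4096, sequence_parallel=False))  # 2 problems
```

Failing fast at launch is cheaper than discovering an inter-node TP group from a throughput graph hours into a run.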
References
- Megatron-LM (Shoeybi et al., 2019): https://arxiv.org/abs/1909.08053 · Megatron repo: https://github.com/NVIDIA/Megatron-LM
- Megatron-Core tensor-parallel API: https://docs.nvidia.com/megatron-core/developer-guide/latest/apidocs/core/core.tensor_parallel.layers.html
- PyTorch Tensor Parallel (DTensor): https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html · TP tutorial: https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html
- Sequence parallelism / Megatron (Korthikanti et al., 2022): https://arxiv.org/abs/2205.05198
- vLLM distributed serving: https://docs.vllm.ai/en/latest/ · SGLang: https://docs.sglang.ai/
Related: Distributed Training · Pipeline Parallel · FSDP · DeepSpeed/ZeRO · Serving OSS · Networking Fabric · Perf Optimization · Glossary