# Tensor cores and mixed precision
Scope: how Tensor Cores and reduced-precision formats (TF32, BF16/FP16, FP8, NVFP4, INT8) raise arithmetic intensity and throughput, why accumulation precision and 2:4 sparsity matter, and how to drive them from PyTorch and CUTLASS.
## What it is
A Tensor Core is a fixed-function matrix-multiply-accumulate (MMA) unit on the SM. It consumes two low-precision operand tiles, multiplies them, and accumulates partial sums in a higher-precision accumulator. One MMA instruction issues many fused multiply-adds per cycle, so a Tensor Core kernel does far more FLOPs per issued instruction (and per byte loaded) than a scalar CUDA-core FMA loop.
The throughput lever is operand width. Each drop in precision halves (FP16 to FP8) or quarters (FP16 to FP4) the bytes per element, which doubles or quadruples both the elements delivered per byte of HBM traffic and the math the Tensor Cores can issue per cycle. Reduced precision is therefore the primary way to move a tensor-heavy kernel rightward on the roofline toward the compute roof.
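As a worked check of the bytes-per-element argument, the sketch below computes the ideal arithmetic intensity of a square GEMM at each operand width. It assumes A, B, and C each cross HBM exactly once at the same element width, which overstates the gain for real kernels that usually write C in BF16 or FP32; `gemm_intensity` is a hypothetical helper, not a library call.

```python
# Ideal arithmetic intensity (FLOPs per HBM byte) of an M x N x K GEMM.
# Assumption: A, B, and C each cross HBM once, all at the same element width.
BYTES_PER_ELEM = {"FP32": 4.0, "BF16/FP16": 2.0, "FP8": 1.0, "FP4": 0.5}

def gemm_intensity(m, n, k, bytes_per_elem):
    flops = 2 * m * n * k                                # one multiply + one add per MAC
    traffic = bytes_per_elem * (m * k + k * n + m * n)   # read A and B, write C
    return flops / traffic

for fmt, nbytes in BYTES_PER_ELEM.items():
    ai = gemm_intensity(4096, 4096, 4096, nbytes)
    print(f"{fmt:>9}: {ai:7.1f} FLOPs/byte")
```

In this model, halving the element width exactly doubles the intensity, so FP16 to FP4 quadruples it.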
The formats, by exponent/mantissa layout:
| Format | Exponent | Mantissa | Role |
|---|---|---|---|
| FP32 | 8 | 23 | Baseline; CUDA-core compute and safe accumulation |
| TF32 | 8 | 10 | FP32 range, FP16 precision; Tensor Core path for fp32 matmuls |
| BF16 | 8 | 7 | FP32 range, no loss scaling; preferred for training |
| FP16 | 5 | 10 | Half precision; needs loss scaling to avoid underflow |
| FP8 (E4M3) | 4 | 3 | Inference and FP8 training with per-tensor/per-block scaling |
| NVFP4 (E2M1) | 2 | 1 | Blackwell 4-bit with two-level micro-scaling |
| INT8 | — | — | Integer inference via DP4A and integer MMA |
TF32's layout (8-bit exponent like FP32, 10-bit mantissa like FP16) is confirmed by NVIDIA: it keeps FP32 range while running on Tensor Cores at higher throughput than FP32 on CUDA cores (NVIDIA TF32 blog). BF16 matches FP32's 8-bit exponent; FP16 uses a narrower 5-bit exponent (Ampere Tuning Guide).
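The exponent and mantissa widths in the table fix each format's range and resolution. The sketch below derives the largest finite value and the spacing just above 1.0 from those widths. It assumes IEEE-style encodings for FP32/TF32/BF16/FP16 (all-ones exponent reserved for inf/NaN), the OCP FP8 E4M3 convention (no infinities, only the all-ones mantissa at the top exponent is NaN), and an E2M1 with no inf/NaN codes; `max_finite` is a hypothetical helper.

```python
# Derive range and resolution from (exponent bits, mantissa bits) for the table's formats.
FORMATS = {
    # name: (exponent_bits, mantissa_bits, encoding convention)
    "FP32": (8, 23, "ieee"),
    "TF32": (8, 10, "ieee"),
    "BF16": (8, 7, "ieee"),
    "FP16": (5, 10, "ieee"),
    "FP8 E4M3": (4, 3, "e4m3"),
    "NVFP4 E2M1": (2, 1, "no_inf"),
}

def max_finite(e_bits, m_bits, kind):
    """Largest finite magnitude of a sign/exponent/mantissa float format."""
    bias = 2 ** (e_bits - 1) - 1
    top = 2 ** e_bits - 1                       # all-ones exponent field
    if kind == "ieee":                          # all-ones exponent reserved for inf/NaN
        return (2 - 2.0 ** -m_bits) * 2.0 ** (top - 1 - bias)
    if kind == "e4m3":                          # top exponent usable; all-ones mantissa = NaN
        return (2 - 2.0 ** -(m_bits - 1)) * 2.0 ** (top - bias)
    return (2 - 2.0 ** -m_bits) * 2.0 ** (top - bias)  # every code is a finite value

for name, (e, m, kind) in FORMATS.items():
    step = 2.0 ** -m                            # gap between 1.0 and the next value
    print(f"{name:>10}: max={max_finite(e, m, kind):.6g}  step above 1.0={step:g}")
```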
## Why use it
Compute throughput on each GPU generation is growing faster than HBM bandwidth, so kernels are increasingly memory-bound unless arithmetic intensity rises. Reduced-precision Tensor Cores attack both terms of FLOPs/byte at once: more FLOPs issued per cycle, fewer bytes moved per element.
Reported multipliers (book, Blackwell-class hardware, not independently benchmarked here):
- BF16/FP16: Tensor Cores sustain >90% of half-precision peak, roughly 4x the FP32 peak.
- FP8 with FP32/TF32 accumulation: 2-3x the BF16/FP16 TFLOPS, given acceptable quantization error.
- NVFP4 on B200: ~10 PFLOPS dense vs ~80 TFLOPS dense FP32 peak, about two orders of magnitude higher theoretical throughput per weight; B300 (Ultra) reports ~15 PFLOPS dense (50% more than B200).
- INT8: weight traffic drops 75% vs FP32 (1 byte vs 4); DP4A issues 4 INT8 multiply-accumulates per instruction vs 1 FP32 FMA.
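To make the DP4A bullet concrete, here is a behavioral Python model of one DP4A-style instruction: four signed 8-bit lanes packed into a 32-bit word, multiplied pairwise, and summed into a 32-bit integer accumulator. It only loosely mirrors CUDA's `__dp4a` intrinsic (no timing, no unsigned variants), and `pack_int8x4`/`dp4a` are illustrative names.

```python
import random
import struct

def pack_int8x4(vals):
    """Pack four signed 8-bit integers into one 32-bit word (lane 0 = lowest byte)."""
    assert len(vals) == 4 and all(-128 <= v <= 127 for v in vals)
    return struct.unpack("<I", struct.pack("<4b", *vals))[0]

def unpack_int8x4(word):
    return struct.unpack("<4b", struct.pack("<I", word))

def dp4a(a_word, b_word, c):
    """One DP4A-style op: 4 INT8 x INT8 products summed into an INT32 accumulator c."""
    acc = c + sum(x * y for x, y in zip(unpack_int8x4(a_word), unpack_int8x4(b_word)))
    return (acc + 2**31) % 2**32 - 2**31        # wrap like a 32-bit register

# A 64-element INT8 dot product takes 16 DP4A ops instead of 64 scalar multiply-adds.
rng = random.Random(0)
x = [rng.randint(-128, 127) for _ in range(64)]
w = [rng.randint(-128, 127) for _ in range(64)]
acc = 0
for i in range(0, 64, 4):
    acc = dp4a(pack_int8x4(x[i:i + 4]), pack_int8x4(w[i:i + 4]), acc)
assert acc == sum(a * b for a, b in zip(x, w))
print("DP4A model matches the exact INT8 dot product:", acc)
```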
In Nsight Compute, the shift is observable: as you move from FP32 to TF32/BF16/FP16/FP8/FP4, the Speed-of-Light view shows memory throughput falling relative to compute throughput, memory-related warp stall reasons (e.g. "Memory Throttle") drop, and the remaining stalls shift toward dependency/pipeline stalls, the signature of moving from memory-bound to compute-bound. See nsight profiling and observability.
## When to use it (and when not)
Use it when:
- Large GEMMs or attention dominate runtime (LLM training and inference serving): these saturate Tensor Cores and amortize per-tile overhead.
- The kernel is memory-bound on the roofline and operands tolerate reduced precision.
- Batch/tile granularity is large enough (e.g. 128-256) to amortize format conversion, scaling, and sparse-index overhead. A batch of 1 in FP8 or 2:4 may see little benefit.
Pick the precision deliberately:
- BF16 for training: matches FP32 exponent range, so it rarely needs `GradScaler`; preferred over FP16 on modern GPUs.
- FP16 for training only with loss scaling (`GradScaler`) to keep small gradients out of the 5-bit-exponent underflow region.
- TF32 as a near-free upgrade for fp32 matmuls: same code, FP32 range preserved.
- FP8/NVFP4 when the model tolerates the precision drop after calibration; accuracy must be validated per model.
- INT8 for inference paths that tolerate quantization.
Skip or be cautious when:
- The op is accuracy-sensitive (layer norm, softmax, reductions). Keep these in FP32. AMP does this automatically.
- Accumulation in low precision would lose signal, so always accumulate at a higher precision than the operands (FP32 when in doubt), never in the operand format.
- 2:4 sparsity during training: gradients do not benefit, maintaining sparsity in updates is complex, and NVIDIA's 2:4 Sparse Tensor Core feature is primarily an inference feature. Training support is limited and framework-dependent; verify before relying on it.
## Architecture
Two low-precision operand tiles enter the MMA unit, the multiply happens at operand precision, and partial sums land in a higher-precision accumulator before the output tile is written back. Narrower operands buy throughput (more FLOPs/cycle, fewer bytes/element); the accumulator buys back the numerical range the operands gave up.
```mermaid
flowchart LR
    A["Operand A tile<br/>(FP16 / BF16 / FP8 / NVFP4)"] --> MMA
    B["Operand B tile<br/>(FP16 / BF16 / FP8 / NVFP4)"] --> MMA
    MMA["Tensor Core MMA<br/>(multiply at low precision)"] --> ACC["Accumulator (FP32 / BF16 / FP16)"]
    ACC --> OUT["Output tile C"]
    WIDTH["Narrower operands: more FLOPs/cycle, fewer bytes/element"] -.->|throughput| A
    RANGE["Lower precision: narrower numerical range, needs scaling"] -.->|tradeoff| MMA
```
On Blackwell the operand/accumulator staging is hardware-managed: tcgen05.mma instructions read operand tiles from shared memory (operand A can also come from TMEM) and accumulate into a dedicated ~256 KB per-SM Tensor Memory (TMEM), while the Tensor Memory Accelerator (TMA) streams tiles from HBM into shared memory. In normal use you do not allocate TMEM yourself; CUTLASS, cuBLAS, cuDNN, and Triton manage it.
Accumulation precision is the invariant that makes the rest usable. Low precision is for operands, not accumulators. The MMA executes e.g. FP8 x FP8 or FP16 x FP16, but partial sums accumulate at a higher precision than the operands (FP32 in most kernels; the exact type is kernel-dependent), held in TMEM on Blackwell and in registers on earlier generations. The low-precision-to-accumulator conversions happen automatically inside the MMA path: the kernel reads FP4/FP8 inputs from HBM, Tensor Cores multiply at low precision, and the accumulator holds the wider type. This is what makes aggressive quantization usable without destroying numerics. Keep accumulation in FP32 for stability whenever in doubt.
The block below makes that contract concrete: it multiplies FP16 operands but accumulates in FP32, and shows that accumulating in the operand format instead swamps the small partial sums (a long dot product where the FP16 running sum outgrows its own ULP).
```python
# Runnable on system python3 (numpy). Core mechanism of a Tensor Core MMA: operands are LOW
# precision (FP16 here) but partial sums accumulate in a HIGHER-precision register (FP32),
# never in the operand format. This is what makes low-precision matmul numerically usable.
import numpy as np

def to_fp16(x):
    return np.asarray(x, dtype=np.float32).astype(np.float16)

rng = np.random.default_rng(0)
N = 8192
a = to_fp16(rng.uniform(0.5, 1.5, N))
b = to_fp16(rng.uniform(0.5, 1.5, N))
ref = float(np.dot(a.astype(np.float64), b.astype(np.float64)))  # trusted reference

acc_fp32 = np.float32(0.0)  # accumulate FP16*FP16 products in FP32 (TMEM)
for i in range(N):
    acc_fp32 = np.float32(acc_fp32 + np.float32(a[i]) * np.float32(b[i]))

acc_fp16 = np.float16(0.0)  # accumulate in the operand format (wrong)
for i in range(N):
    acc_fp16 = np.float16(acc_fp16 + (a[i] * b[i]))

err_fp32 = abs(float(acc_fp32) - ref)
err_fp16 = abs(float(acc_fp16) - ref)
assert err_fp32 / ref < 1e-4, err_fp32 / ref  # FP32 accumulation tracks
assert err_fp16 / ref > 1e-2, err_fp16 / ref  # FP16 accumulation swamps
assert err_fp16 > err_fp32 * 100, (err_fp16, err_fp32)
print(f"accumulation OK: fp32 rel_err={err_fp32/ref:.2e} fp16 rel_err={err_fp16/ref:.2e}")
# accumulation OK: fp32 rel_err=2.91e-06 fp16 rel_err=4.75e-01
```
## How to use it
The accumulation contract from Architecture is the one invariant to preserve: operands go low, accumulators stay in FP32. The PyTorch knobs below switch on the low-precision operand paths while keeping that invariant, and keep accuracy-sensitive ops (layer norm, softmax, reductions) in FP32.
### TF32 for FP32 matmuls
Enabling TF32 makes torch.matmul and torch.nn.Linear run as TF32 Tensor Core kernels instead of FP32 on CUDA cores:
```python
# Reference template (requires torch + a CUDA GPU). The TF32 numerics are validated in numpy below.
import torch

# {'highest' | 'high' | 'medium'}; 'high'/'medium' select TF32, 'highest' keeps FP32
torch.set_float32_matmul_precision("high")
```
`high`/`medium` are equivalent to `torch.backends.cuda.matmul.allow_tf32 = True`; `highest` (the default) keeps internal FP32 (torch.set_float32_matmul_precision).
TF32's contract, validated in numpy: it keeps FP32's 8-bit exponent (so a value that overflows FP16 stays finite) while truncating the mantissa to 10 bits (so sub-ULP increments round away). The adversarial case is the large value that overflows FP16 yet stays finite in TF32.
```python
# Runnable on system python3 (numpy). TF32 = FP32's 8-bit exponent (full range) with the
# mantissa truncated to 10 bits (FP16-like precision). Modelled by rounding the FP32 bit
# pattern to keep sign + 8 exponent + 10 mantissa bits.
import numpy as np

def to_tf32(x):
    u = np.asarray(x, dtype=np.float32).view(np.uint32).copy()
    lsb = (u >> 13) & 1  # round-to-nearest-even on 13 dropped bits
    u = (u + np.uint32(0x0FFF) + lsb) & np.uint32(0xFFFFE000)
    return u.view(np.float32)

# 1) RANGE preserved: a value that overflows FP16 (max 65504) stays finite in TF32.
big = np.float32(1e30)
with np.errstate(over="ignore"):
    assert np.isinf(np.float16(big)), "FP16 must overflow at 1e30"
tf = to_tf32(big)
assert np.isfinite(tf) and abs(float(tf) - 1e30) / 1e30 < 1e-3, "TF32 keeps FP32 range"

# 2) PRECISION ~10 mantissa bits: 1 + 2^-11 is exactly half a TF32 ULP above 1.0, a tie that
# round-to-nearest-even resolves back to 1.0, while 1 + 2^-9 (two ULPs) survives.
# FP32 keeps the finer increment.
one = np.float32(1.0)
assert to_tf32(one + np.float32(2.0**-11)) == one, "half-ULP tie must round to even (1.0)"
assert to_tf32(one + np.float32(2.0**-9)) > one, "a >1-ULP increment must survive"
assert np.float32(1.0) + np.float32(2.0**-11) > one, "FP32 keeps the finer increment"

# 3) EQUIVALENCE: values already fitting in 10 mantissa bits are unchanged by TF32 rounding.
exact = np.array([0.0, 1.0, 1.5, 2.0, -4.0, 0.5, 256.0], dtype=np.float32)
assert np.array_equal(to_tf32(exact), exact), "10-bit-representable values are unchanged"
print("TF32 OK: range like FP32, precision ~10 mantissa bits")
# TF32 OK: range like FP32, precision ~10 mantissa bits
```
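The numpy check above tests single values. The sketch below models what a TF32 matmul does to a dot product, assuming the Tensor Core contract from the Architecture section: round each FP32 input to TF32, form the (exact) products, and keep an FP32 running sum, emulated here by rounding through `struct`. The first case is constructed so the answer is exact: every input carries a 2^-12 increment that FP32 keeps and TF32 rounds away. `f32`, `tf32`, and `dot_fp32_acc` are illustrative helpers, not PyTorch APIs.

```python
import random
import struct

def f32(x):
    """Round a Python float (a double) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def tf32(x):
    """Round to TF32: sign + 8 exponent bits + top 10 mantissa bits, round-to-nearest-even."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    u = (u + 0x0FFF + ((u >> 13) & 1)) & 0xFFFFE000
    return struct.unpack("<f", struct.pack("<I", u))[0]

def dot_fp32_acc(a, b, round_input):
    """Dot product with operands rounded by round_input and an FP32 running sum."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = f32(acc + f32(round_input(x) * round_input(y)))
    return acc

# Constructed case: 1 + 2^-12 is exact in FP32 but rounds to 1.0 in TF32.
a = [f32(1.0 + 2.0 ** -12)] * 1024
b = [1.0] * 1024
exact = 1024 * (1.0 + 2.0 ** -12)
print("fp32 path:", dot_fp32_acc(a, b, f32), "tf32 path:", dot_fp32_acc(a, b, tf32), "exact:", exact)

# Random nonnegative case: TF32 input rounding bounds the relative error near 2^-10.
rng = random.Random(0)
ra = [rng.random() for _ in range(256)]
rb = [rng.random() for _ in range(256)]
ref = sum(x * y for x, y in zip(ra, rb))
rel_tf32 = abs(dot_fp32_acc(ra, rb, tf32) - ref) / ref
rel_fp32 = abs(dot_fp32_acc(ra, rb, f32) - ref) / ref
print(f"random dot: tf32 rel_err={rel_tf32:.1e}  fp32 rel_err={rel_fp32:.1e}")
```

This is the effect the `high`/`medium` settings trade for throughput: range is untouched, but each input keeps only about three decimal digits.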
### Automatic mixed precision (AMP)
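Under AMP's autocast, eligible ops such as matmuls and convolutions run in the chosen FP16 or BF16 dtype (the GEMMs still accumulate in FP32), while precision-sensitive ops run in FP32. BF16 is the safer default (FP32 exponent range, no `GradScaler`):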
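```python
# Reference template (requires torch + a CUDA GPU). The BF16 vs FP16 range is validated in numpy below.
# model, input, target, and loss_fn are placeholders for your own module and data.
import torch

with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    output = model(input)
    loss = loss_fn(output, target)
loss.backward()  # backward outside autocast; it reuses the dtypes chosen in the forward
```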
Why BF16 is the safer default, in numpy: it keeps FP32's 8-bit exponent range where FP16 overflows and underflows, trading mantissa bits for range. The edge cases are the large value FP16 sends to infinity and the small value FP16 flushes to zero, both of which BF16 keeps finite and nonzero.
```python
# Runnable on system python3 (numpy). AMP picks BF16 or FP16 for the matmul and accumulates
# in FP32. BF16 is the safe default: it shares FP32's 8-bit exponent (same dynamic range),
# so it needs no loss scaling; FP16's 5-bit exponent overflows and underflows far sooner.
# numpy has float16 natively; BF16 is modelled by rounding FP32 to its top 16 bits.
import numpy as np

def to_bf16(x):
    u = np.asarray(x, dtype=np.float32).view(np.uint32).copy()
    lsb = (u >> 16) & 1  # round-to-nearest-even, keep sign+8exp+7mant
    u = (u + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return u.view(np.float32)

# 1) RANGE: a large activation (1e5 > FP16 max 65504) overflows FP16 but is finite in BF16.
big = np.float32(1e5)
with np.errstate(over="ignore"):
    assert np.isinf(np.float16(big)), "FP16 (5-bit exponent) overflows at 1e5"
assert np.isfinite(to_bf16(big)) and abs(float(to_bf16(big)) - 1e5) / 1e5 < 1e-2, "BF16 finite"

# 2) UNDERFLOW: a small gradient (1e-8) flushes to 0 in FP16 but stays nonzero in BF16.
small = np.float32(1e-8)
assert np.float16(small) == np.float16(0.0), "FP16 underflows 1e-8 to zero"
assert to_bf16(small) != 0.0, "BF16 keeps 1e-8 (FP32 exponent range)"

# 3) PRECISION trade: BF16 has 7 mantissa bits, coarser than FP16's 10. Near 1.0 the BF16
# step (2^-7) exceeds the FP16 step (2^-10): BF16 buys range by spending mantissa.
one = np.float32(1.0)
assert to_bf16(one + np.float32(2.0**-8)) == one, "half-ULP tie rounds to 1.0 in BF16 (7 mantissa bits)"
assert np.float16(1.0) + np.float16(2.0**-8) > np.float16(1.0), "FP16 (10 mantissa bits) keeps it"
print("BF16 vs FP16 OK: BF16 keeps FP32 range, FP16 keeps more mantissa")
# BF16 vs FP16 OK: BF16 keeps FP32 range, FP16 keeps more mantissa
```
FP16 training needs a gradient scaler to prevent underflow (BF16 does not):
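```python
# Reference template (requires torch + a CUDA GPU). The loss-scaling math is validated in numpy below.
# model, loader, loss_fn, and optimizer are placeholders.
import torch

scaler = torch.amp.GradScaler("cuda")  # FP16 only; omit for BF16
for x, y in loader:
    optimizer.zero_grad()
    with torch.amp.autocast("cuda", dtype=torch.float16):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss, outside autocast
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale for the next iteration
```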
`torch.cuda.amp.autocast`/`GradScaler` are deprecated; use `torch.amp.*` with an explicit `"cuda"` device (torch.amp). Under `torch.compile`, TorchInductor fuses the precision casts into the generated kernels: GEMMs still run on Tensor Cores in the autocast dtype (or TF32 for fp32 matmuls when enabled) with FP32 accumulation, and sensitive ops (layer norm, softmax) still compute in FP32.
Why FP16 needs the scaler, in numpy: a gradient below FP16's subnormal floor flushes to zero, while multiplying the loss by a scale and dividing the gradient by it afterwards carries the gradient through FP16 inside its representable range. The adversarial case is a scale too small to lift the gradient; a dynamic scaler counters that by periodically growing the scale, and backs the scale off only when gradients overflow to inf/NaN.
```python
# Runnable on system python3 (numpy). FP16 loss scaling: small gradients live below FP16's
# smallest subnormal (~6e-8) and flush to zero. Multiplying the loss by S shifts the backward
# pass up into representable range; dividing by S (unscaling) before the optimizer step
# recovers the true gradient. BF16's wider exponent needs none of this.
import numpy as np

def to_bf16(x):
    u = np.asarray(x, np.float32).view(np.uint32).copy()
    u = (u + np.uint32(0x7FFF) + ((u >> 16) & 1)) & np.uint32(0xFFFF0000)
    return u.view(np.float32)

grad = np.float32(1e-9)  # a real gradient below FP16's floor
assert np.float16(grad) == np.float16(0.0), "unscaled: FP16 flushes 1e-9 to zero (signal lost)"
S = np.float32(2.0**16)  # loss scale
scaled = np.float16(grad * S)  # backward runs in FP16 on the scaled loss
assert scaled != np.float16(0.0), "scaling lifts the gradient into FP16 range"
recovered = np.float32(scaled) / S  # unscale before the optimizer step
assert abs(recovered - grad) / grad < 3e-3, ("unscale rel_err", abs(recovered - grad) / grad)
# ADVERSARIAL: too small a scale still underflows. A dynamic scaler counters this by growing
# the scale after a run of clean steps (it backs off only on inf/NaN overflow).
assert np.float16(grad * np.float32(16.0)) == np.float16(0.0), "under-scaled gradient still lost"
# BF16 (8-bit exponent) represents 1e-9 directly, so it needs no loss scaling at all.
assert to_bf16(grad) != 0.0, "BF16 keeps 1e-9 without any loss scaling"
print(f"loss scaling OK: 1e-9 lost unscaled, recovered at S=2^16 (rel_err={abs(recovered-grad)/grad:.1e})")
# loss scaling OK: 1e-9 lost unscaled, recovered at S=2^16 (rel_err=4.4e-04)
```
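A real `GradScaler` does not use one fixed S. The sketch below is a simplified, hypothetical model of the dynamic policy `torch.amp.GradScaler` documents: when gradients contain inf/NaN, skip the step and multiply the scale by `backoff_factor`; after `growth_interval` consecutive clean steps, multiply it by `growth_factor` (defaults: scale 2^16, growth 2.0, backoff 0.5, interval 2000). The real class checks each optimizer's gradients and applies the change in `update()`; this version folds that into one `step` call.

```python
import math

class DynamicLossScaler:
    """Simplified model of GradScaler's dynamic loss scaling (hypothetical class)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def step(self, scaled_grads):
        """Unscale gradients and return them, or return None when the step is skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            self.scale *= self.backoff_factor   # overflow: skip this step, shrink the scale
            self.clean_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= self.growth_factor    # long clean run: probe a larger scale
            self.clean_steps = 0
        return grads

# growth_interval shortened from the default 2000 so the growth step is visible.
scaler = DynamicLossScaler(growth_interval=3)
assert scaler.step([float("inf"), 1.0]) is None and scaler.scale == 2.0**15
for _ in range(3):
    grads = scaler.step([2.0**15, 2.0**14])
assert grads == [1.0, 0.5] and scaler.scale == 2.0**16
print("dynamic loss scaling OK: backoff on inf, growth after a clean run")
```

Growth is what rescues the under-scaled case from the numpy block: the scaler cannot detect underflow directly, so it keeps probing larger scales until an overflow pushes it back.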
## How to integrate it
### FP8 and NVFP4 via the Transformer Engine
The Transformer Engine combines Tensor Core low-precision hardware with a software runtime for scaling and casting. NVFP4 (E2M1, values up to +/-6) uses two-level micro-scaling: each 16-element microblock gets an FP8 E4M3 scale, plus a per-tensor FP32 scale to avoid overflow (Transformer Engine FP8/FP4 primer, NVFP4 docs). This is what lets 4-bit storage retain usable accuracy. Use library/TE paths rather than hand-rolling FP8/FP4 scaling.
```python
# Reference template (requires torch + transformer_engine on a Hopper/Blackwell GPU).
# The two-level micro-scaling math is validated in numpy below.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# E4M3 for all FP8 tensors (Format.HYBRID would use E5M2 for gradients); NVFP4 recipes analogous
fp8_recipe = DelayedScaling(fp8_format=Format.E4M3)
model = te.Linear(4096, 4096, bias=True)
inp = torch.randn(32, 4096, device="cuda")  # placeholder activations
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)  # low-precision operands, FP32 accumulation
```
The core of NVFP4's accuracy, in numpy: two-level micro-scaling maps each block into E2M1's usable range, where a single unscaled cast to 4-bit would saturate everything above +/-6. The adversarial case is two blocks whose magnitudes differ by orders of magnitude, which one global scale cannot serve.
```python
# Runnable on system python3 (numpy). NVFP4 stores weights as 4-bit E2M1 (magnitudes on the
# grid {0,.5,1,1.5,2,3,4,6}, max +/-6) and recovers accuracy with two-level micro-scaling:
# a per-16-element block scale (FP8 E4M3) times a per-tensor FP32 scale. Without scaling,
# anything above 6 saturates; per-block scaling maps each block into E2M1's usable range.
import numpy as np

E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # nonneg E2M1 magnitudes

def q_e2m1(x):  # nearest-magnitude E2M1 quantizer
    x = np.asarray(x, np.float32)
    idx = np.abs(np.abs(x)[..., None] - E2M1).argmin(axis=-1)
    return (np.sign(x) * E2M1[idx]).astype(np.float32)

def q_e4m3(s):  # FP8 E4M3 scale: 3 mantissa bits, max 448
    s = np.clip(np.asarray(s, np.float32), 2.0**-9, 448.0).view(np.uint32).copy()
    s = (s + np.uint32(0x7FFFF) + ((s >> 20) & 1)) & np.uint32(0xFFF00000)
    return s.view(np.float32)

rng = np.random.default_rng(7)
BLK = 16
# Adversarial: two blocks with wildly different magnitudes (~0.02 vs ~120), the case a single
# global scale cannot serve, which is the whole reason micro-scaling exists.
w = np.concatenate([rng.uniform(-0.02, 0.02, BLK),
                    rng.uniform(-120, 120, BLK)]).astype(np.float32)

# (1) NO scaling: cast straight to E2M1. The big block saturates at 6, error is enormous.
naive = q_e2m1(w)
rel_naive = np.linalg.norm(naive - w) / np.linalg.norm(w)
assert (np.abs(naive) <= 6.0).all() and np.abs(naive).max() == 6.0, "unscaled saturates at 6"
assert rel_naive > 0.5, rel_naive

# (2) TWO-LEVEL micro-scaling: per-tensor FP32 scale then per-block E4M3 scale into +/-6.
per_tensor = np.float32(np.abs(w).max() / 448.0)  # keeps block scales in E4M3 range
wn = (w / per_tensor).reshape(-1, BLK)
blk_scale = q_e4m3(np.abs(wn).max(axis=1, keepdims=True) / 6.0)
deq = (q_e2m1(wn / blk_scale) * blk_scale * per_tensor).reshape(-1)
rel_mx = np.linalg.norm(deq - w) / np.linalg.norm(w)

# (3) Micro-scaling beats no-scaling by a wide margin; block scales are valid E4M3 values.
assert rel_mx < rel_naive / 5.0, (rel_mx, rel_naive)
assert rel_mx < 0.2, rel_mx
assert np.array_equal(q_e4m3(blk_scale), blk_scale), "block scales are exact E4M3 values"
print(f"NVFP4 micro-scaling OK: rel_err {rel_naive:.2f} (unscaled) -> {rel_mx:.3f} (two-level)")
# NVFP4 micro-scaling OK: rel_err 0.93 (unscaled) -> 0.073 (two-level)
```
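The FP8 template above uses a `DelayedScaling` recipe: each tensor's scale comes from amax values recorded on earlier steps rather than the current one. The sketch below models that idea with a rolling amax history, the scale rule `(E4M3_MAX / amax) / 2**margin`, and a saturating E4M3 rounder. It is not Transformer Engine's implementation; `DelayedScaler`, the history length, and saturation on overflow are assumptions for illustration. The last step shows the hazard of delay: an amax spike the history has not seen yet saturates on the step where it first appears.

```python
import math

E4M3_MAX = 448.0

def round_e4m3(v):
    """Round to the nearest FP8 E4M3 value, saturating at +/-448 (assumption)."""
    if v == 0.0 or math.isnan(v):
        return v
    mag = min(abs(v), E4M3_MAX)
    exp = max(math.frexp(mag)[1] - 1, -6)   # unbiased exponent; -6 is E4M3's min normal
    step = 2.0 ** (exp - 3)                 # 3 mantissa bits (subnormal step 2^-9)
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), v)

class DelayedScaler:
    """Per-tensor FP8 scaling from a rolling amax history (hypothetical sketch)."""

    def __init__(self, history_len=16, margin=0):
        self.history, self.history_len, self.margin = [], history_len, margin
        self.scale = 1.0                    # no history yet

    def quantize(self, xs):
        """Quantize with the scale from earlier steps, then record this step's amax."""
        used = self.scale
        deq = [round_e4m3(x * used) / used for x in xs]
        self.history = (self.history + [max(abs(x) for x in xs)])[-self.history_len:]
        self.scale = (E4M3_MAX / max(self.history)) / 2.0 ** self.margin
        return deq

s = DelayedScaler(history_len=2)
print(s.quantize([0.5, -2.0]), s.scale)     # first step uses scale 1.0
print(s.quantize([1.0]), s.scale)           # amax 2.0 is still in the window
print(s.quantize([1.0]), s.scale)           # window is now [1.0, 1.0]
print(s.quantize([10.0]), s.scale)          # unseen spike saturates on this step
```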
### CUTLASS and cuBLAS GEMMs
For a hand-tuned-equivalent GEMM, instantiate a CUTLASS template. It handles shared-memory tiling, async copies (cp.async/TMA), double buffering, and TMEM staging, typically within a few percent of a hand-written MMA kernel:
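```cpp
// Reference template (CUDA C++ / CUTLASS 2.x device API, requires nvcc + a Tensor Core GPU).
// The tiling-with-FP32-accumulation numerics are validated in the numpy block below.
// Blackwell-native kernels (tcgen05 MMA, TMEM accumulators) come from the CUTLASS 3.x
// collective-builder API with cutlass::arch::Sm100; the 2.x device::Gemm below targets
// Sm80-class Tensor Core MMA, which newer architectures still execute.
#include <cutlass/numeric_types.h>
#include <cutlass/gemm/device/gemm.h>

using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,     // A (FP16)
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B (FP16)
    cutlass::half_t, cutlass::layout::RowMajor,     // C / output (FP16)
    float,                                          // accumulator (FP32)
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80>;                           // Ampere-class Tensor Core MMA

// M, N, K, device pointers A_d/B_d/C_d, and leading dimensions are assumed to exist.
Gemm gemm_op;
Gemm::Arguments args(
    {M, N, K},      // GEMM shape
    {A_d, lda},     // A
    {B_d, ldb},     // B
    {C_d, ldc},     // C (source, scaled by beta)
    {C_d, ldc},     // D (destination; written in place here)
    {1.0f, 0.0f});  // alpha, beta
cutlass::Status status = gemm_op(args);
```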
Tile shapes are chosen empirically (e.g. 128x128 or 256x128) to fit the 256 KB per-SM TMEM budget. A 256x512 FP16 tile (2 bytes/elem) or 256x256 FP32 tile (4 bytes/elem) maxes that budget. CUTLASS templates already cover FP4/FP8/FP16/TF32 operands and can fuse bias-add and activation. cuBLAS and cuDNN ship comparably tuned Tensor Core kernels, and PyTorch dispatches to these libraries, so most of this throughput is available without writing CUDA C++.
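The budget figures above are plain byte arithmetic. The sketch below redoes them, assuming TMEM is organized as 128 lanes by 512 columns of 32-bit cells per SM (the layout the PTX ISA describes), and reports what fraction of that capacity a few accumulator tile shapes would occupy. It counts bytes only; it does not model how CUTLASS maps tiles onto TMEM lanes and columns, and `tile_fraction` is a hypothetical helper.

```python
# TMEM per SM, modelled as 128 lanes x 512 columns of 32-bit cells (assumed layout).
LANES, COLUMNS, CELL_BYTES = 128, 512, 4
TMEM_BYTES = LANES * COLUMNS * CELL_BYTES   # 262,144 bytes = 256 KB

def tile_fraction(rows, cols, bytes_per_elem):
    """Fraction of the per-SM TMEM budget a rows x cols tile would occupy (bytes only)."""
    return rows * cols * bytes_per_elem / TMEM_BYTES

for rows, cols, nbytes, label in [(128, 128, 4, "FP32"), (256, 128, 4, "FP32"),
                                  (256, 256, 4, "FP32"), (256, 512, 2, "FP16")]:
    print(f"{rows}x{cols} {label}: {tile_fraction(rows, cols, nbytes):.0%} of TMEM")
```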
What the tiling guarantees numerically, in numpy: a tiled GEMM that accumulates each K-tile's partial product in an FP32 accumulator reproduces the reference matmul, including ragged boundary tiles when the shape is not a multiple of the tile. The corruption case is a kernel that drops the ragged tail K-tile, which is silently wrong.
```python
# Runnable on system python3 (numpy). What CUTLASS does numerically: tile the GEMM over M/N/K
# and accumulate each K-tile's partial product into an FP32 accumulator. Tiling reorders the
# work but must reproduce the reference GEMM; dropping a ragged tail tile does not.
import numpy as np

def tiled_gemm(A, B, tile=64):
    M, K = A.shape; N = B.shape[1]
    C = np.zeros((M, N), dtype=np.float32)  # FP32 accumulator (like TMEM)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), dtype=np.float32)
            for k in range(0, K, tile):  # accumulate over K tiles, incl. ragged tail
                acc += A[i:i+tile, k:k+tile].astype(np.float32) @ B[k:k+tile, j:j+tile].astype(np.float32)
            C[i:i+acc.shape[0], j:j+acc.shape[1]] = acc
    return C

rng = np.random.default_rng(1)
M, K, N = 130, 70, 96  # non-multiples of tile -> ragged tiles
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)
ref = A.astype(np.float64) @ B.astype(np.float64)  # trusted reference

# (1) EQUIVALENCE incl. ragged boundary tiles: tiled FP32 GEMM matches the reference.
C = tiled_gemm(A, B, tile=64)
assert np.allclose(C, ref, rtol=1e-4, atol=1e-4), np.abs(C - ref).max()

# (2) CORRUPTION detection: a kernel that skips the ragged tail K-tile is silently wrong.
def buggy_gemm(A, B, tile=64):
    M, K = A.shape; N = B.shape[1]
    C = np.zeros((M, N), np.float32)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), np.float32)
            for k in range(0, (K // tile) * tile, tile):  # drops the ragged tail
                acc += A[i:i+tile, k:k+tile].astype(np.float32) @ B[k:k+tile, j:j+tile].astype(np.float32)
            C[i:i+acc.shape[0], j:j+acc.shape[1]] = acc
    return C

assert not np.allclose(buggy_gemm(A, B), ref, rtol=1e-2), "dropping the tail K-tile must corrupt"
print(f"tiled GEMM OK: max_abs_err={np.abs(C-ref).max():.2e}, ragged tiles handled")
# tiled GEMM OK: max_abs_err=9.02e-06, ragged tiles handled
```
## How to run it in production
- Validate accuracy per model after any FP8/FP4/INT8 quantization or 2:4 pruning, before promotion. The throughput numbers assume accuracy holds; gate the change on a held-out eval, not on the speedup alone.
- Confirm the precision actually engaged. Check the Nsight Compute roofline and Speed-of-Light view for the memory-bound to compute-bound shift, and verify Tensor Core utilization rose. A run can "speed up" for unrelated reasons while the low-precision path never dispatched. See nsight profiling and observability.
- For FP16 training paths, keep the `GradScaler` in the loop and watch for the scaler backing off (repeated inf/NaN gradients). A scale that never stabilizes signals values overflowing FP16's range, often a sign that the operand should be BF16.
## How to maintain it
- Re-check library versions when changing GPU architecture (`arch::Sm100`, etc.); NVIDIA updates CUTLASS/cuBLAS for new FP8/FP4/TMEM features each generation. Always pull a fresh CUTLASS on a new architecture.
- Keep batches and tiles large enough to feed the Tensor Cores: small batches/tiles under-amortize conversion and sparse-index overhead and erode the gains.
## How to scale it
Throughput scales with operand width and with how well the kernel is fed. Two levers dominate: structured sparsity (halve the weight traffic and let Sparse Tensor Cores skip the zeros) and granularity (keep batches and tiles large enough to amortize conversion, scaling, and sparse-index overhead).
### 2:4 structured sparsity (inference)
2:4 means exactly two of every four consecutive weights are nonzero: prune 50%, halve weight traffic, and feed Sparse Tensor Cores that skip the zeros and do twice the work in the same cycle budget (up to ~2x dense, batch-size dependent). Apply after training, then calibrate/fine-tune to recover accuracy. Convert in PyTorch:
```python
# Reference template (requires torch >= 2.1 + a GPU with Sparse Tensor Cores).
# The 2:4 structure and skip-zeros equivalence are validated in numpy below.
import torch
from torch.sparse import to_sparse_semi_structured

# dense_w (an fp16/bf16 CUDA weight) and x are placeholders; dense_w must already follow
# a 2:4 mask (2 nonzeros per 4 contiguous elements)
sparse_w = to_sparse_semi_structured(dense_w)
y = torch.mm(x, sparse_w.t())  # dispatches a sparse GEMM on Sparse Tensor Cores
```
Pruning and format conversion run through cuSPARSELt and framework tooling; the Transformer Engine accelerates supported sparse executions but does not enforce sparsity (PyTorch semi-structured sparsity). Keep batches large (128-256) so sparse-index overhead is amortized.
The 2:4 contract, in numpy: exactly two of every four weights survive (the two largest by magnitude), and the sparse product equals the dense masked matmul while touching only the kept weights. The corruption cases are masks that keep three or one per group, and the boundary case is a tie where magnitudes are equal.
```python
# Runnable on system python3 (numpy). 2:4 structured sparsity keeps exactly 2 of every 4
# contiguous weights (the 2 largest by magnitude). A Sparse Tensor Core skips the two zeros,
# so its result must equal the dense matmul of the masked weight, proven here against a
# gather-only reference that never touches the zeros. A mask that is not 2-per-4 is rejected.
import numpy as np

def mask_2to4(W):
    G = W.reshape(W.shape[0], -1, 4)
    keep = np.argsort(-np.abs(G), axis=2)[..., :2]  # 2 largest magnitudes per group
    m = np.zeros_like(G, dtype=bool)
    np.put_along_axis(m, keep, True, axis=2)
    return m.reshape(W.shape)

def is_2to4(mask):
    return bool((mask.reshape(mask.shape[0], -1, 4).sum(2) == 2).all())

rng = np.random.default_rng(5)
R, C = 8, 64
W = rng.standard_normal((R, C)).astype(np.float32)
x = rng.standard_normal((32, C)).astype(np.float32)
mask = mask_2to4(W)
Wp = W * mask

# (1) STRUCTURE: exactly 2 nonzeros per group of 4.
assert is_2to4(mask), "2:4 mask must keep exactly 2 of every 4"

# (2) SKIP-ZEROS EQUIVALENCE: a gather multiplying ONLY the 2 kept weights per group equals
# the dense masked matmul, so skipping the two zeros changes nothing.
G = Wp.reshape(R, -1, 4); xg = x.reshape(x.shape[0], -1, 4); idx = mask.reshape(R, -1, 4)
y_gather = np.zeros((x.shape[0], R), np.float32)
for r in range(R):
    for g in range(G.shape[1]):
        kept = np.flatnonzero(idx[r, g])  # the 2 kept lanes
        y_gather[:, r] += xg[:, g, kept] @ G[r, g, kept]
y_dense = x @ Wp.T
assert np.allclose(y_gather, y_dense, rtol=1e-4, atol=1e-4), np.abs(y_gather - y_dense).max()

# (3) CORRUPTION: masks keeping 3 or 1 per group are not valid 2:4 and must be rejected.
bad = mask.copy(); bv = bad.reshape(R, -1, 4); bv[0, 0] = [True, True, True, False]
assert not is_2to4(bad), "3-of-4 must be rejected"
bad2 = mask.copy(); bv2 = bad2.reshape(R, -1, 4); bv2[0, 0] = [True, False, False, False]
assert not is_2to4(bad2), "1-of-4 must be rejected"

# (4) BOUNDARY / ties: a group of equal magnitudes still yields exactly 2 kept.
tie = np.ones((1, 4), np.float32); tie[0, 1] = -1.0  # equal |.|, mixed signs
assert int(mask_2to4(tie).sum()) == 2, "ties must still keep exactly 2"
print("2:4 sparsity OK: 2-per-4 structure, skip-zeros equivalence, corruption + ties handled")
# 2:4 sparsity OK: 2-per-4 structure, skip-zeros equivalence, corruption + ties handled
```
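Halving weight traffic relies on never storing the zeros. The sketch below shows the usual shape of a compressed 2:4 row: the two kept values per group plus a 2-bit position index for each, which is enough to rebuild the dense row. The exact hardware metadata layout and ordering are handled by cuSPARSELt and are not modeled; `compress_2to4`, `decompress_2to4`, and `storage_bits` are hypothetical helpers, and 2 bits of metadata per kept value is an assumption matching how the 16-bit format is usually described.

```python
def compress_2to4(row, mask):
    """Keep the 2 masked values per group of 4, plus each one's position (0..3)."""
    if len(row) % 4 or len(row) != len(mask):
        raise ValueError("row and mask must have the same length, a multiple of 4")
    values, meta = [], []
    for g in range(0, len(row), 4):
        idx = [i for i in range(4) if mask[g + i]]
        if len(idx) != 2:
            raise ValueError(f"group {g // 4} keeps {len(idx)} of 4, not 2")
        values.extend(row[g + i] for i in idx)
        meta.extend(idx)                     # each position fits in 2 bits
    return values, meta

def decompress_2to4(values, meta):
    """Rebuild the dense row: value k belongs to group k // 2 at position meta[k]."""
    row = [0.0] * (2 * len(values))
    for k, (v, pos) in enumerate(zip(values, meta)):
        row[4 * (k // 2) + pos] = v
    return row

def storage_bits(n_dense, value_bits, meta_bits=2):
    """Compressed size in bits: half of the values plus meta_bits per kept value."""
    n_kept = n_dense // 2
    return n_kept * (value_bits + meta_bits)

row = [0.5, 0.0, -1.0, 0.0, 0.0, 2.0, 0.0, 3.0]
values, meta = compress_2to4(row, [v != 0 for v in row])
assert decompress_2to4(values, meta) == row
print(values, meta, f"FP16 storage ratio: {storage_bits(64, 16) / (64 * 16):.4f}")
```

Under this assumption an FP16 weight shrinks to about 56% of its dense size rather than exactly 50%, because the index metadata travels with the values.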
Beyond a single GEMM, throughput scales by keeping format conversion and scaling amortized: batch and tile large (128-256), keep GEMM dimensions large enough to feed the Tensor Cores, and let cuBLAS/CUTLASS pick tile shapes that fit the 256 KB per-SM TMEM budget. Small batches under-amortize conversion and sparse-index overhead and erode the gains.
## Failure modes
- Accumulating in the operand format. Low-precision floating-point partial sums swamp once the running total outgrows its own ULP, and narrow integer sums overflow. Always accumulate at a higher precision than the operands: FP32 for floating-point formats when in doubt, INT32 for INT8; never in the operand format (validated in the Architecture block).
- FP16 gradient underflow. Gradients below FP16's 5-bit-exponent subnormal floor flush to zero and the signal is lost. Use `GradScaler`, or switch to BF16, whose 8-bit exponent avoids the underflow entirely.
- Low-precision overflow without scaling. FP8 and NVFP4 have a narrow range (NVFP4 saturates above +/-6); casting without per-block/per-tensor micro-scaling saturates the large values. Use the Transformer Engine paths rather than hand-rolled FP8/FP4 scaling.
- Accuracy-sensitive ops in low precision. Layer norm, softmax, and reductions degrade in FP16/BF16. Keep them in FP32; AMP already does this automatically.
- Small batch/tile under-amortization. FP8 or 2:4 at batch 1 sees little benefit because format conversion, scaling, and sparse-index overhead dominate. Keep batch/tile granularity large (128-256).
- 2:4 sparsity in training. Gradients do not benefit, maintaining sparsity in updates is complex, and 2:4 is primarily an inference feature. Training support is limited and framework-dependent; verify before relying on it.
- Assuming throughput without confirming the path engaged. A kernel can appear faster while the low-precision Tensor Core path never dispatched. Verify in Nsight Compute that Tensor Core utilization rose and the kernel shifted memory-bound to compute-bound; the reported multipliers assume both the path and the accuracy hold.
- TF32 silently reducing precision. `high`/`medium` truncate fp32 matmuls to a 10-bit mantissa, which can change results for precision-sensitive fp32 code. Keep `highest` where full FP32 is required.
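For INT8, the first failure mode above takes a different form: integer partial sums do not lose precision, they wrap. The sketch below accumulates the same INT8 dot product in 8-, 16-, and 32-bit two's-complement registers; only the 32-bit accumulator (the width DP4A and integer MMA use) returns the exact result. Wraparound is an assumption of this model; a saturating accumulator would clip instead, which is just as wrong. `wrap` and `int8_dot` are illustrative helpers.

```python
def wrap(v, bits):
    """Two's-complement wraparound of an integer into a bits-wide register."""
    half = 1 << (bits - 1)
    return (v + half) % (1 << bits) - half

def int8_dot(a, b, acc_bits):
    """INT8 dot product whose running sum lives in an acc_bits-wide register."""
    acc = 0
    for x, y in zip(a, b):
        acc = wrap(acc + x * y, acc_bits)
    return acc

a = [100] * 64   # INT8 operands
b = [100] * 64   # exact dot product: 64 * 10,000 = 640,000
for bits in (8, 16, 32):
    print(f"{bits:>2}-bit accumulator: {int8_dot(a, b, bits)}")
```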
## References
- Chris Fregly, AI Systems Performance Engineering (O'Reilly), Chapter 9: Increasing CUDA Kernel Efficiency and Arithmetic Intensity (Mixed Precision and Utilizing Tensor Cores; Structured Sparsity; Transformer Engine and TMEM; CUTLASS).
- NVIDIA, Accelerating AI Training with TF32 Tensor Cores: TF32 8-bit exponent / 10-bit mantissa.
- NVIDIA, Ampere GPU Architecture Tuning Guide: BF16/FP16/TF32 Tensor Core formats.
- NVIDIA, Using FP8 and FP4 with Transformer Engine and NVFP4 format: E2M1, 16-element microblock E4M3 scale + per-tensor FP32 scale.
- PyTorch, torch.set_float32_matmul_precision: `highest`/`high`/`medium` and TF32 equivalence.
- PyTorch, Automatic Mixed Precision (torch.amp): `autocast`, `GradScaler`, BF16 vs FP16.
- PyTorch, Accelerating BERT with semi-structured (2:4) sparsity: `to_sparse_semi_structured`.
Related: Tensor Core Programming · Roofline Model and Arithmetic Intensity · Kernel Fusion · Shared Memory, Bank Conflicts, and Tiling · Memory Coalescing and Vectorized Access · CUDA Occupancy Tuning · GPU Memory Hierarchy · GPU Execution Model: SMs, Warps, and SIMT · NVIDIA Blackwell Datacenter Platform · NVIDIA Hopper Platform · Glossary