Kernel fusion

Scope: fusing a chain of operations into one CUDA kernel to raise arithmetic intensity, eliminate HBM round-trips for intermediate tensors, and remove per-launch overhead, including when fusion pays off (memory-bound elementwise/reduction chains) and when it does not (register pressure, compute-bound GEMMs).

What it is

Kernel fusion combines multiple operations (sequential math, loop iterations, or independent ops on the same input) into a single kernel launch, so a value loaded from global memory feeds several computations before being written back. The intermediate results never leave on-chip storage (registers or shared memory), so they cost no HBM bandwidth. [Fregly, Ch. 9]

The canonical example: two elementwise ops.

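// Naive: two kernels, two HBM round-trips for the intermediate y.
__global__ void sinKernel(const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = sinf(x[i]);            // read x, write y to HBM
}
__global__ void sqrtKernel(const float* y, float* z, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) z[i] = sqrtf(y[i]);           // read y back from HBM, write z
}

// Fused: one kernel. x loaded once, sin and sqrt applied in registers,
// only z written. The intermediate never touches global memory.
__global__ void sinSqrtFused(const float* x, float* z, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) z[i] = sqrtf(sinf(x[i]));
}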

In the naive form the intermediate y makes a full round-trip through HBM (written by kernel 1, read back by kernel 2), and each kernel does a single expensive math op per load/store pair: very low arithmetic intensity (FLOPs/byte). The fused kernel performs both ops per element loaded while moving half the bytes (two 4-byte transfers per element instead of four), so FLOPs/byte doubles relative to the two-kernel pipeline and the kernel moves toward the compute roof on the Roofline model. [Fregly, Ch. 9]
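To make the traffic difference concrete, here is a minimal numpy sketch under a toy assumption: a hypothetical HBM class stands in for global memory and counts every element loaded or stored, while plain numpy arrays play the role of registers. It runs the sin then sqrt chain both ways and counts transfers.

```python
import numpy as np

class HBM:
    """Hypothetical global-memory model: counts float32 element transfers."""
    def __init__(self):
        self.transfers = 0

    def load(self, arr):
        self.transfers += arr.size      # one element read from HBM each
        return arr.copy()               # the copy now lives "on chip"

    def store(self, values):
        self.transfers += values.size   # one element written to HBM each
        return values.copy()            # the copy now lives "in HBM"

def sin_sqrt_unfused(mem, x):
    y = mem.store(np.sin(mem.load(x)))        # kernel 1: read x, write y
    return mem.store(np.sqrt(mem.load(y)))    # kernel 2: read y back, write z

def sin_sqrt_fused(mem, x):
    return mem.store(np.sqrt(np.sin(mem.load(x))))  # y never leaves "registers"

# x in [0, 3) keeps sin(x) >= 0, so sqrt stays real.
x = np.random.default_rng(0).uniform(0.0, 3.0, 1024).astype(np.float32)
mem_unfused, mem_fused = HBM(), HBM()
z_unfused = sin_sqrt_unfused(mem_unfused, x)
z_fused = sin_sqrt_fused(mem_fused, x)
print(4 * mem_unfused.transfers, 4 * mem_fused.transfers)  # 16384 8192 bytes
```

Each transfer is one 4-byte float32, so the counts correspond to 16 versus 8 bytes per element, and both versions produce the same z.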

Two structural flavors, both used by state-of-the-art kernels:

  • Vertical fusion combines sequential operations on the same data (e.g. sin then sqrt; a reduction with the ops before and after it; a matmul with its bias+activation epilogue). [Fregly, Ch. 9] [1]
  • Horizontal fusion combines independent operations that read the same input, loading it once (e.g. computing sin(x) and cos(x) in one kernel). [Fregly, Ch. 9] [1]
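Horizontal fusion can be modeled the same way. In this minimal numpy sketch (the function names are hypothetical), two separate kernels each read x, while the fused version loads x once and feeds both independent ops from that single load; the returned counts are element reads and writes.

```python
import numpy as np

def sin_cos_two_kernels(x):
    """Two independent kernels: each one loads x from global memory."""
    s = np.sin(x)                     # kernel A: read x, write s
    c = np.cos(x)                     # kernel B: read x again, write c
    reads, writes = 2 * x.size, s.size + c.size
    return s, c, reads, writes

def sin_cos_fused(x):
    """Horizontally fused: one load of x feeds both independent ops."""
    v = x.copy()                      # the single load into "registers"
    s, c = np.sin(v), np.cos(v)       # both ops reuse the loaded value
    reads, writes = x.size, s.size + c.size
    return s, c, reads, writes

x = np.linspace(-np.pi, np.pi, 4096, dtype=np.float32)
s1, c1, r1, w1 = sin_cos_two_kernels(x)
s2, c2, r2, w2 = sin_cos_fused(x)
print(r1, r2)  # 8192 4096: the input is read once instead of twice
```

Writes are unchanged (both outputs still go to HBM); the saving is purely on the shared input.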

Architecture

The whole benefit comes from where the intermediate lives. Unfused, the intermediate is materialized in HBM (global memory) and read back over the same bandwidth-limited path. Fused, it stays in registers or shared memory: on-chip storage with far lower latency and far higher aggregate bandwidth than HBM, which consumes no HBM bandwidth at all. The diagram contrasts the two data paths for the sin then sqrt chain (per-element counts for float32).

flowchart LR
  X["Input x (HBM)"]
  subgraph UNFUSED["Unfused: 2 kernels, 2 FLOPs / 16 bytes per element"]
    K1["Kernel 1: y = sin(x)"]
    HBM["Intermediate y<br/>written to HBM, read back"]
    K2["Kernel 2: z = sqrt(y)"]
  end
  subgraph FUSED["Fused: 1 kernel, 2 FLOPs / 8 bytes per element"]
    KF["Single kernel<br/>z = sqrt(sin(x))<br/>intermediate stays in registers"]
  end
  Z["Output z (HBM)"]
  X --> K1 --> HBM --> K2 --> Z
  X --> KF --> Z

The load-bearing structural choice is the memory tier the intermediate occupies. A fused kernel loads each input once from HBM, does all the chained math against on-chip storage, and writes each output once. The three tiers it plays across, slowest to fastest:

  • HBM (global memory): where inputs arrive and outputs leave. Every unfused kernel boundary adds one write plus one read here. This is the traffic fusion removes.
  • Shared memory: a per-block scratchpad. It stages data that several threads in a block need, so redundant global loads disappear, and it lets threads exchange partial results (the reduction buffer sdata below).
  • Registers: per-thread on-chip storage where the fused intermediate (sinf(x[i]), the running sum of squares) actually lives. Cheapest of all, but scarce, which is exactly what caps how much you can fuse (see Failure modes).

Why use it

Fusion attacks all three costs of a multi-kernel pipeline at once:

  1. HBM traffic. Every unfused boundary forces a write of the intermediate to global memory and a read back in the next kernel. Fusion keeps intermediates in registers/shared memory. For the L2-norm example below, fusion improves the ratio from ~4 FLOPs per 36 bytes (naive, with intermediate writes and reads) to ~4 FLOPs per 12 bytes (fused, two reads + one write): the same math for a third of the traffic. [Fregly, Ch. 9]
  2. Arithmetic intensity. More FLOPs per byte moves the kernel rightward on the Roofline model, out of the memory-bound regime where the GPU stalls waiting on data. Modern GPUs advance compute throughput faster than HBM bandwidth, so raising FLOPs/byte is increasingly the dominant lever. [Fregly, Ch. 9]
  3. Launch overhead and barriers. Collapsing N kernels into one removes N−1 launch latencies and the implicit barriers between them, and improves cache locality because data stays resident.
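Claims 1 and 3 can be combined into a first-order cost model. The sketch below is an illustration, not a measurement: LAUNCH_US and HBM_GB_PER_S are assumed placeholder values, and compute time is ignored, which is only reasonable while the chain stays memory-bound.

```python
# First-order cost model for a chain of elementwise float32 ops.
# LAUNCH_US and HBM_GB_PER_S are ASSUMED placeholder values, not measurements.
LAUNCH_US = 5.0          # assumed per-kernel launch + inter-kernel barrier cost (us)
HBM_GB_PER_S = 2000.0    # assumed sustained HBM bandwidth (GB/s)
F32_BYTES = 4

def chain_time_us(n_elems, n_ops, fused):
    """Estimated time (us) for n_ops chained elementwise ops over n_elems values."""
    kernels = 1 if fused else n_ops
    # Each kernel reads its input and writes its output once per element.
    transfers = 2 * kernels
    bytes_moved = transfers * n_elems * F32_BYTES
    bytes_per_us = HBM_GB_PER_S * 1e3        # 1 GB/s = 1e3 bytes per microsecond
    return kernels * LAUNCH_US + bytes_moved / bytes_per_us

for n_ops in (1, 2, 4, 8):
    unfused = chain_time_us(1 << 20, n_ops, fused=False)
    fused = chain_time_us(1 << 20, n_ops, fused=True)
    print(f"{n_ops} ops: unfused {unfused:.1f} us, fused {fused:.1f} us")
```

Both terms scale with the number of kernel boundaries, so fusion removes N−1 launches and N−1 intermediate round-trips at once; with a single op the two versions tie.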

Fusion is one of the standard arithmetic-intensity techniques alongside shared-memory tiling, memory coalescing, and reduced precision on Tensor Cores.

The byte-and-FLOP accounting behind the first two claims is small enough to check exactly. The block below models the sin then sqrt chain and the L2-norm, and asserts the invariant fusion must satisfy: strictly fewer bytes moved and strictly higher FLOPs/byte whenever there is an intermediate, with the boundary case of zero intermediates where fusion can only tie, never lose. It runs on numpy alone.

# numpy-only accounting of arithmetic intensity (FLOPs/byte) for the two
# examples this page teaches: the sin->sqrt chain and the L2-norm.
# Invariant: fusion must strictly lower bytes moved and raise FLOPs/byte
# whenever there is an intermediate, and never move more bytes than unfused.
import numpy as np

F32 = 4  # bytes per float32 element

def ai(flops, bytes_moved):
    return flops / bytes_moved

# sin(x) -> sqrt(x): unfused reads x, writes y, reads y, writes z (4 transfers);
# fused reads x, writes z (2), the intermediate y staying in registers.
sin_unfused_bytes = 4 * F32
sin_fused_bytes = 2 * F32
sin_flops = 2  # sin + sqrt
assert sin_fused_bytes < sin_unfused_bytes
assert ai(sin_flops, sin_fused_bytes) > ai(sin_flops, sin_unfused_bytes)
assert sin_unfused_bytes - sin_fused_bytes == 2 * F32  # exactly the y round-trip

# L2-norm per element, fused: read (accumulate pass) + read (normalize pass)
# + write = 3 transfers = 12 B, ~3 FLOPs -> ~0.25 FLOPs/byte, matching the kernel notes.
l2_fused_bytes = 3 * F32
l2_fused_flops = 3
assert l2_fused_bytes == 12
assert abs(ai(l2_fused_flops, l2_fused_bytes) - 0.25) < 1e-9
# unfused pipeline materializes sq[B,H] and re-reads it: >= 5 transfers/elem.
l2_unfused_bytes = 5 * F32
assert l2_unfused_bytes > l2_fused_bytes
assert ai(l2_fused_flops, l2_fused_bytes) > ai(l2_fused_flops, l2_unfused_bytes)

# Boundary case: for a linear chain of elementwise ops the byte model must
# never claim fusion moves MORE bytes; with >=1 intermediate it is strictly
# cheaper, and at 0 intermediates (a single op) it exactly ties.
for n_intermediates in range(0, 8):
    unfused = (2 + 2 * n_intermediates) * F32
    fused = 2 * F32
    assert fused <= unfused
    assert (fused < unfused) == (n_intermediates > 0)

print("OK: fusion lowers bytes and raises FLOPs/byte; L2 fused = 12 B/elem, ~0.25 FLOPs/B")

When to use it (and when not)

Fuse when:

  • The chain is memory-bound: elementwise/pointwise chains, normalizations (LayerNorm, RMSNorm, L2-norm), activations, bias-add, dropout, scale, anything whose intermediates are large relative to the math performed. These are exactly the chains TorchInductor targets first. [Fregly, Ch. 9] [1]
  • Data is read more than once by threads in the same block: stage it in shared memory to kill redundant global loads. [Fregly, Ch. 9]
  • You are amortizing launch overhead across many small ops.
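The shared-memory point can be quantified with a small example that is not from the source: a 3-point stencil, where each input element is needed by three neighboring outputs. In this numpy model (function names hypothetical; block stands in for the thread-block size), the direct version loads every neighbor from global memory, while the staged version has each block load its tile plus a two-element halo once and read neighbors from that copy.

```python
import numpy as np

def stencil_direct(x):
    """y[i] = x[i] + x[i+1] + x[i+2]: every output does 3 global loads."""
    n = x.size - 2
    return x[:-2] + x[1:-1] + x[2:], 3 * n

def stencil_staged(x, block=256):
    """Each 'block' loads its tile plus a 2-element halo once into shared memory."""
    n = x.size - 2
    y = np.empty(n, dtype=x.dtype)
    loads = 0
    for start in range(0, n, block):
        stop = min(start + block, n)
        tile = x[start:stop + 2].copy()                     # one global load per staged element
        loads += tile.size
        y[start:stop] = tile[:-2] + tile[1:-1] + tile[2:]   # neighbors read from the tile
    return y, loads

x = np.random.default_rng(0).standard_normal(10_002).astype(np.float32)
y_direct, loads_direct = stencil_direct(x)
y_staged, loads_staged = stencil_staged(x)
print(loads_direct, loads_staged)   # 30000 10080
```

Global loads drop from 3n to n plus two halo elements per block, with identical results.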

Do not fuse (or profile carefully first) when:

  • Register pressure would cause spills. Fusion and loop unrolling consume registers; too much fusion lowers occupancy and, if registers spill to local memory, the added traffic can erase the win. Always profile fused kernels: if register usage is excessive and the kernel spills, the fusion may be net-negative. [Fregly, Ch. 9]
  • The kernel is already compute-bound (e.g. a large dense GEMM saturating Tensor Cores). Fusion targets memory traffic; it does not help a kernel already at the compute roof. Use epilogue fusion to attach cheap ops to such a GEMM, not to merge two heavy GEMMs.
  • The op is nonstandard or a custom CUDA op the compiler cannot see: TorchInductor compiles and fuses most ATen ops but not arbitrary custom kernels, so manual fusion may be required. [Fregly, Ch. 9]

Verify the regime first with the Nsight profiling workflow: a kernel whose Speed of Light section shows memory throughput near peak while compute throughput stays low, with warp stalls dominated by memory ("Memory Throttle", long scoreboard), is a fusion candidate; one already at the compute roof is not. [Fregly, Ch. 9]
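The same regime check can be done on paper with the Roofline bound, attainable = min(peak, AI × bandwidth). The sketch below uses assumed example values for PEAK_TFLOPS and HBM_TB_PER_S (substitute your GPU's datasheet numbers), and square_gemm_ai assumes the ideal case where A, B, and C each cross HBM exactly once.

```python
# Roofline regime check. PEAK_TFLOPS and HBM_TB_PER_S are ASSUMED example
# numbers; substitute the datasheet values for your GPU and precision.
PEAK_TFLOPS = 60.0
HBM_TB_PER_S = 3.0
RIDGE = PEAK_TFLOPS / HBM_TB_PER_S      # FLOPs/byte where the two roofs meet (20 here)

def attainable_tflops(ai):
    """Roofline bound: min(compute roof, arithmetic intensity * bandwidth)."""
    return min(PEAK_TFLOPS, ai * HBM_TB_PER_S)

def regime(ai):
    return "memory-bound: fusion candidate" if ai < RIDGE else "compute-bound: fusion won't help"

def square_gemm_ai(n, bytes_per_elem=4):
    """Ideal AI of an n x n x n GEMM: 2n^3 FLOPs over reading A, B and writing C once."""
    return (2 * n**3) / (3 * n**2 * bytes_per_elem)

for name, ai in [("L2-norm (fused)", 0.25), ("GEMM n=4096", square_gemm_ai(4096))]:
    print(f"{name}: AI={ai:.2f} FLOPs/B, bound={attainable_tflops(ai):.2f} TFLOP/s, {regime(ai)}")
```

The fused L2-norm sits far left of the ridge, so removing bytes is what matters; the large GEMM sits far right, which is why fusion only attaches epilogues to it.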

How to use it

In PyTorch the first move is torch.compile: its TorchInductor backend automatically fuses chains of elementwise (pointwise) operations and some reductions into generated Triton/C++ kernels, so you get fusion without writing a kernel. [Fregly, Ch. 9] [1]

The snippet below is a reference template: it requires PyTorch (and a GPU for the Triton code path), so it is not executed by this page's numpy checks. The byte/FLOP math it relies on is validated by the numpy block in Why use it and the equivalence block in How to run it in production.

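# Reference template (requires torch; not executed in this page's checks).
import torch
import torch.nn.functional as F

@torch.compile
def block(x, w, b):
    # bias-add -> GELU -> dropout: a pointwise chain Inductor fuses into a
    # single generated kernel that runs after the matmul.
    return F.dropout(F.gelu(x @ w + b), p=0.1, training=True)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(64, 1024, device=device)
    w = torch.randn(1024, 1024, device=device)
    b = torch.randn(1024, device=device)
    print(block(x, w, b).shape)  # first call triggers compilation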

Inspect the generated kernels to confirm fusion happened:

TORCH_COMPILE_DEBUG=1 python block.py
# read output_code.py in the dumped debug dir to see the fused Triton kernel
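What to look for in that output can be reasoned about with a toy model. The sketch below is a hypothetical greedy planner, deliberately much simpler than TorchInductor's real scheduler: consecutive pointwise ops share a kernel, a reduction closes the kernel it joins, and an op the compiler cannot see (the hypothetical my_cuda_op) always runs alone and splits the chain.

```python
from dataclasses import dataclass

@dataclass
class Op:
    name: str
    kind: str  # "pointwise", "reduction", or "opaque" (a custom op the compiler can't see)

def plan_kernels(ops):
    """Toy greedy vertical fusion (NOT Inductor's algorithm):
    consecutive pointwise ops share a kernel, a reduction closes the current
    kernel (absorbing the pointwise ops before it), and an opaque op always
    gets a kernel of its own, breaking the chain."""
    kernels, current = [], []
    for op in ops:
        if op.kind == "opaque":
            if current:
                kernels.append(current)
            kernels.append([op])
            current = []
        elif op.kind == "reduction":
            kernels.append(current + [op])
            current = []
        else:
            current.append(op)
    if current:
        kernels.append(current)
    return [[op.name for op in kernel] for kernel in kernels]

chain = [Op("bias_add", "pointwise"), Op("gelu", "pointwise"), Op("dropout", "pointwise")]
print(plan_kernels(chain))  # [['bias_add', 'gelu', 'dropout']]
print(plan_kernels([Op("mul", "pointwise"), Op("my_cuda_op", "opaque"), Op("add", "pointwise")]))
```

The second call yields three kernels, which is the split you would see in output_code.py around an opaque custom op.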

How to integrate with it

Fusion integrates into an existing model best through the ops you already call, not by rewriting hot paths as custom kernels. Prefer fused library calls and high-level ops (torch.matmul, torch.nn.functional activations) over long sequences of small Python-level elementwise ops; the libraries call efficient kernels and the compiler fuses them with surrounding work. For attention, scaled_dot_product_attention dispatches to a fused FlashAttention/cuDNN/memory-efficient backend by shape and dtype; pin it explicitly when needed. [Fregly, Ch. 9] [2]

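# Reference template (requires torch and a CUDA GPU; not executed here).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# [batch, heads, seq_len, head_dim]; the flash backend needs fp16/bf16 on CUDA.
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # pin the fused flash backend
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)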

See FlashAttention / MLA for the fused-attention case in depth.
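Attention fuses well because the softmax can be computed online. The numpy sketch below illustrates that idea only (it is not FlashAttention's implementation, and block=64 is an arbitrary choice): it walks K and V in blocks, keeps a running row max and denominator, rescales earlier partial sums, and never materializes the full score matrix, yet matches the naive softmax(QKᵀ/√d)V.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[1])                  # full [n_q, n_k] score matrix
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def blocked_attention(q, k, v, block=64):
    """Online softmax: only one [n_q, block] score tile exists at a time."""
    scale = 1.0 / np.sqrt(q.shape[1])
    m = np.full((q.shape[0], 1), -np.inf)              # running row max
    l = np.zeros((q.shape[0], 1))                      # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[1]))           # running unnormalized output
    for start in range(0, k.shape[0], block):
        s = q @ k[start:start + block].T * scale       # one score tile
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        correction = np.exp(m - m_new)                 # rescale earlier partial sums
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=1, keepdims=True)
        acc = acc * correction + p @ v[start:start + block]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((16, 32)), rng.standard_normal((200, 32)), rng.standard_normal((200, 32))
print(np.abs(blocked_attention(q, k, v) - naive_attention(q, k, v)).max())
```

A real fused kernel keeps m, l, and acc in registers/shared memory per query tile, which is what removes the score matrix's HBM round-trip.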

How to run it in production

When the compiler cannot fuse a chain (a custom op it cannot see, or a shape it will not target), write one kernel by hand that loads once, computes in registers/shared memory, and writes once. Concrete example: L2-normalize each row of x (shape [batch, hidden]), where norm_b = sqrt(Σ_i x[b,i]^2 + ε) and y[b,i] = x[b,i] / norm_b. A naive pipeline runs square → reduce → divide as separate kernels with intermediate HBM writes; the fused kernel collapses them into one. [Fregly, Ch. 9]

__global__ void fusedL2Norm(const float* __restrict__ x,
                            float* __restrict__ y,
                            int hidden) {
  extern __shared__ float sdata[];        // reduction buffer
  const int batch = blockIdx.x;           // one block per batch row
  const int tid   = threadIdx.x;
  const float* batch_ptr = x + size_t(batch) * hidden;

  // 1) Each thread accumulates a partial sum of squares in a register,
  //    then publishes it to shared memory.
  float local = 0.f;
  for (int i = tid; i < hidden; i += blockDim.x) {
    float v = batch_ptr[i];
    local = fmaf(v, v, local);            // v*v + local, single rounding
  }
  sdata[tid] = local;
  __syncthreads();

  // 2) Tree reduction to sdata[0]; assumes blockDim.x is a power of two.
  for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
    if (tid < offset) sdata[tid] += sdata[tid + offset];
    __syncthreads();
  }

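  // 3) Normalize: multiply by the reciprocal norm (rsqrt) instead of dividing
  //    by sqrt. eps keeps an all-zero row finite (0 * rsqrt(eps) = 0, not NaN).
  const float eps = 1e-6f;
  float inv = rsqrtf(sdata[0] + eps);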
  float* out_batch = y + size_t(batch) * hidden;
  for (int i = tid; i < hidden; i += blockDim.x) {
    out_batch[i] = batch_ptr[i] * inv;
  }
}

Key points, all load-bearing:

  • __restrict__ on the pointers tells the compiler x and y do not alias, enabling reuse of loaded values and reordering. [4]
  • fmaf(v, v, local) is a fused multiply-add: v*v + local computed with a single rounding step, in one instruction. [3]
  • rsqrtf (reciprocal square root) is evaluated once per thread rather than once per element, and the per-element divide in the write loop becomes a multiply. Reciprocal square root runs on the SM's special-function units, so this path has much higher throughput than an IEEE sqrt followed by a division, a standard hot-loop micro-optimization that trades a small accuracy loss for speed. [Fregly, Ch. 9] [3]
  • Each thread walks its slice twice (accumulate, then write), so global traffic is ~12 bytes per element (two reads + one write). A conservative arithmetic intensity is ≈3 FLOPs / 12 bytes ≈ 0.25 FLOPs/byte. The pipeline stays memory-bound; the win is removing the intermediate HBM round-trips and launch barriers of the 3-kernel version, plus cache locality, not crossing into compute-bound. [Fregly, Ch. 9]
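Steps 1 and 2 of the kernel are easy to get subtly wrong, so here is a minimal numpy model of them (the function name is hypothetical, and it accumulates in float64 where the kernel uses float32). Each trip through the tree loop stands for one __syncthreads() round. It also shows why the halving loop assumes a power-of-two blockDim.x: with 96 threads, the partial left in sdata[2] after the offset-3 round is never folded in.

```python
import numpy as np

def block_sum_of_squares(row, block_dim):
    """Model fusedL2Norm steps 1-2 for one row with block_dim 'threads'."""
    sdata = np.zeros(block_dim)
    for tid in range(block_dim):                     # step 1: strided per-thread partials
        sdata[tid] = np.sum(row[tid::block_dim] ** 2)
    offset = block_dim >> 1
    while offset > 0:                                # step 2: one loop trip per __syncthreads
        sdata[:offset] += sdata[offset:2 * offset]   # every tid < offset adds its partner
        offset >>= 1
    return sdata[0]

row = np.random.default_rng(0).standard_normal(1000)
print(block_sum_of_squares(row, 256), np.sum(row ** 2))   # equal up to rounding
print(block_sum_of_squares(row, 96))                      # too small: sdata[2] is dropped
```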

Launch with dynamic shared memory sized to the reduction buffer (blockDim.x * sizeof(float)):

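int threads = 256;                          // power of two: the tree reduction requires it
size_t smem = threads * sizeof(float);      // one float of sdata per thread
// One block per row; x and y are device pointers to [batch, hidden] floats.
fusedL2Norm<<<batch, threads, smem>>>(x, y, hidden);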
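If hidden varies across models, the block size has to respect the kernel's assumptions. A minimal sketch of a hypothetical host-side helper, assuming you want the largest power-of-two block that does not exceed hidden (a simple heuristic, not a tuned choice): it keeps the tree reduction correct, stays within the 1024-threads-per-block limit, and checks that sdata fits in the 48 KB of dynamic shared memory available without an explicit cudaFuncSetAttribute opt-in.

```python
# Hypothetical host-side helper for sizing the fusedL2Norm launch.
MAX_THREADS_PER_BLOCK = 1024              # CUDA per-block thread limit
DEFAULT_DYNAMIC_SMEM_BYTES = 48 * 1024    # usable without cudaFuncSetAttribute opt-in
SIZEOF_FLOAT = 4

def l2norm_launch_config(hidden, min_threads=32):
    """Largest power-of-two block size <= min(hidden, 1024), at least one warp."""
    threads = min_threads
    while threads * 2 <= min(hidden, MAX_THREADS_PER_BLOCK):
        threads *= 2
    smem = threads * SIZEOF_FLOAT          # one float of sdata per thread
    assert smem <= DEFAULT_DYNAMIC_SMEM_BYTES
    return threads, smem

for hidden in (10, 100, 768, 4096):
    print(hidden, l2norm_launch_config(hidden))
```

When hidden is smaller than one warp, the extra threads simply find no elements in their stride and contribute zero to the reduction.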

The CUDA kernel needs a GPU, so the part to lock down before shipping is its numerics: the fused single-pass form must equal the naive square → reduce → divide pipeline it replaces, and rsqrt-then-multiply must equal divide-by-sqrt. The block below is that equivalence-to-reference test in numpy. It includes the adversarial all-zero row, where a missing ε floor turns 0 · rsqrt(0) = 0 · inf into NaN. It runs on numpy alone.

# numpy-only model of the fusedL2Norm kernel above. Proves the fused single
# pass equals the naive 3-kernel pipeline (square -> reduce -> divide), that
# rsqrt*x equals x/sqrt, and that the epsilon floor keeps an all-zero row finite.
import numpy as np

def l2norm_unfused(x, eps=1e-6):
    sq = x * x                          # kernel 1: square      -> [B,H]
    ssum = sq.sum(axis=1, keepdims=True)  # kernel 2: row reduce  -> [B,1]
    return x / np.sqrt(ssum + eps)      # kernel 3: divide      -> [B,H]

def l2norm_fused(x, eps=1e-6):
    ssum = np.einsum("bh,bh->b", x, x)[:, None]  # accumulate in "registers"
    inv = 1.0 / np.sqrt(ssum + eps)              # rsqrt: one reciprocal
    return x * inv                               # multiply, not divide

rng = np.random.default_rng(0)
x = rng.standard_normal((7, 512)).astype(np.float32)

# 1) Fused equals the unfused reference pipeline within float32 tolerance.
a, b = l2norm_fused(x), l2norm_unfused(x)
assert np.allclose(a, b, atol=1e-5), np.abs(a - b).max()

# 2) The divide -> rsqrt swap the kernel makes is numerically equivalent.
v = np.abs(rng.standard_normal(4096)).astype(np.float32) + 0.1
assert np.allclose(v * (1.0 / np.sqrt(v)), v / np.sqrt(v), atol=1e-5)

# 3) Adversarial: an all-zero row. Without the eps floor this is 0/0 -> NaN.
#    The floor must keep every output finite and send the zero row to zero.
xz = x.copy(); xz[3] = 0.0
out = l2norm_fused(xz)
assert np.isfinite(out).all(), "eps floor failed: non-finite output on zero row"
assert np.all(out[3] == 0.0)
with np.errstate(divide="ignore", invalid="ignore"):  # the floor genuinely fires:
    no_eps = l2norm_fused(xz, eps=0.0)                 # same math, floor removed
assert np.isnan(no_eps[3]).all()                       # 0 * (1/sqrt(0)) = 0 * inf = NaN

# 4) Non-zero rows come out unit-norm (the operation's contract), up to eps.
assert np.allclose(np.linalg.norm(a, axis=1), 1.0, atol=1e-3)

print("OK: fused==unfused, rsqrt==divide, eps floor finite on zero row, rows unit-norm")

How to maintain it

Fusion is a trade-off, so verify each fused kernel on real hardware rather than assuming a win:

ncu --set roofline --section SpeedOfLight \
    --metrics launch__registers_per_thread,smsp__sass_average_branch_targets_threads_uniform.pct \
    ./fused_app
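  • Check launch__registers_per_thread: a jump after fusion that drops occupancy, or any local-memory spill, signals over-fusion; back off. Building with nvcc -Xptxas -v also prints registers and spill stores/loads per kernel at compile time. [Fregly, Ch. 9]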
  • On the Roofline / Speed-of-Light view, confirm the arithmetic-intensity point moved right and memory-bound stall reasons ("Memory Throttle", cache misses) dropped. [Fregly, Ch. 9]

See the Nsight profiling workflow for the full counter-driven loop and performance optimization for where fusion sits among the other levers.

How to scale it

Beyond one kernel, scaling fusion means not hand-writing the hardest cases. For library-grade vertical and horizontal fusion (epilogue fusion, GEMM+activation, Tensor Core paths), reach for NVIDIA CUTLASS or OpenAI Triton (the language TorchInductor generates GPU kernels in) rather than hand-writing MMA pipelines; both give you tuned fused kernels across shapes and architectures. [Fregly, Ch. 9] At model scale the same principle drives horizontal fusion of many independent small ops that share an input, and epilogue fusion, which applies cheap pointwise work to a GEMM's output tile while it is still on chip, so the GEMM result is never written to HBM and read back just to apply those ops.
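Epilogue fusion can be sketched in numpy as a tiled GEMM that applies bias + ReLU to each output tile before its single store. This is an illustration of the pattern under simple assumptions (tile=32 is arbitrary, the function names are hypothetical), not how CUTLASS or Triton structure their kernels.

```python
import numpy as np

def gemm_then_epilogue(a, b, bias):
    c = a @ b                                    # GEMM output materialized in full
    return np.maximum(c + bias, 0.0)             # separate bias + ReLU pass re-reads it

def gemm_fused_epilogue(a, b, bias, tile=32):
    """Apply bias + ReLU to each output tile before its single store."""
    m, n = a.shape[0], b.shape[1]
    out = np.empty((m, n), dtype=np.result_type(a, b))
    for i in range(0, m, tile):
        for j in range(0, n, tile):
            acc = a[i:i + tile] @ b[:, j:j + tile]           # accumulator tile ("registers")
            out[i:i + tile, j:j + tile] = np.maximum(acc + bias[j:j + tile], 0.0)
    return out

rng = np.random.default_rng(0)
a, b, bias = rng.standard_normal((100, 64)), rng.standard_normal((64, 70)), rng.standard_normal(70)
print(np.abs(gemm_fused_epilogue(a, b, bias) - gemm_then_epilogue(a, b, bias)).max())
```

The epilogue adds almost no FLOPs relative to the GEMM, which is why attaching it costs nothing on a kernel already at the compute roof.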

Failure modes

Fusion is not free; the same mechanism that removes HBM traffic can backfire. Each entry below names a failure, the symptom you see in the profiler, and the fix.

  • Register-pressure spill. Fusion and loop unrolling consume registers; too much fusion drops occupancy and, once registers spill to local (HBM-backed) memory, the added traffic can erase the win and leave the kernel net-negative. Symptom: launch__registers_per_thread jumps and local-memory load/store counters appear. Fix: back off the fusion or split the kernel; always profile before shipping. [Fregly, Ch. 9]
  • Already compute-bound. A large dense GEMM saturating Tensor Cores is at the compute roof; fusion targets memory traffic and does nothing here. Symptom: Speed-of-Light shows high compute utilization, low memory-throttle. Fix: use epilogue fusion to attach cheap ops to the GEMM, do not merge two heavy GEMMs.
  • Op the compiler cannot see. TorchInductor compiles and fuses most ATen ops but not arbitrary custom CUDA kernels, so a custom op silently breaks the fusion chain around it. Symptom: the debug output_code.py shows separate kernels across the boundary. Fix: express the op in fusible primitives, or hand-write the fused kernel. [Fregly, Ch. 9]
  • Missing epsilon floor in a normalization. A fused norm that multiplies by rsqrt(sum_of_squares), or divides by its sqrt, without an ε floor produces NaN/inf on an all-zero (or fully underflowing) row. Symptom: NaNs propagate downstream. Fix: add ε to the sum of squares before the rsqrt/sqrt (validated in the numpy block under How to run it in production).

The register/occupancy cliff and the spill-erases-the-win failure are exact enough to model. The block below asserts occupancy is monotonically non-increasing in registers/thread, shows the concrete cliff (32 to 128 registers quarters resident warps), and models the spill point at which the fused kernel moves more bytes than the unfused one. It also confirms horizontal fusion (sin and cos from one load) is numerically equivalent to two passes. It runs on numpy alone.

# numpy-only model of the register-pressure occupancy cliff and the
# "spilling erases the win" failure mode, plus horizontal-fusion equivalence.
import numpy as np

def resident_warps(regs_per_thread, regfile=65536, warp=32, max_warps=64):
    # Warps resident on an SM = register file / (regs/thread * 32), capped.
    # Simplified: ignores register-allocation granularity and other occupancy limits.
    if regs_per_thread <= 0:
        return max_warps
    return min(max_warps, regfile // (regs_per_thread * warp))

# Occupancy is monotonically non-increasing as fusion raises regs/thread.
prev = resident_warps(1)
for r in range(1, 129):
    cur = resident_warps(r)
    assert cur <= prev, (r, cur, prev)
    prev = cur

# Concrete cliff: a light kernel at 32 regs is full; over-fused at 128 is 1/4.
assert resident_warps(32) == 64          # 65536/(32*32) = 64, hits the cap
assert resident_warps(128) == 16         # 65536/(128*32) = 16
assert resident_warps(128) < resident_warps(32)

# Spill failure: effective bytes = fused base + spilled-register HBM traffic.
# Enough spill makes the fused kernel move MORE bytes than the unfused pipeline.
def fused_effective_bytes(base=12, spill_regs=0, bytes_each=4):
    return base + spill_regs * bytes_each

unfused_bytes = 20
assert fused_effective_bytes(spill_regs=0) < unfused_bytes   # win, no spill
assert fused_effective_bytes(spill_regs=3) > unfused_bytes   # spill erases win
assert fused_effective_bytes(spill_regs=5) > fused_effective_bytes(spill_regs=2)

# Horizontal fusion: sin(x) and cos(x) from ONE load equal two separate passes,
# and read the input once instead of twice.
rng = np.random.default_rng(1)
x = rng.standard_normal(1000).astype(np.float32)
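v = x.copy()                                   # the single load of x
s_fused, c_fused = np.sin(v), np.cos(v)        # both independent ops consume that load
assert np.array_equal(s_fused, np.sin(x)) and np.array_equal(c_fused, np.cos(x))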
assert (1 + 2) * 4 < (2 + 2) * 4          # 1 read + 2 writes < 2 reads + 2 writes

print("OK: occupancy monotone in regs/thread, 128->16 warp cliff, spill erases win, horizontal fusion equivalent")

References

  • Chris Fregly, AI Systems Performance Engineering (O'Reilly), Ch. 9, Increasing CUDA Kernel Efficiency and Arithmetic Intensity ("Kernel Fusion", "PyTorch and Arithmetic Intensity"). Primary source for the fusion technique, the L2-norm fused-kernel example, the FLOPs/byte figures, and the vertical/horizontal fusion taxonomy.

Related: Roofline Model and Arithmetic Intensity · Shared Memory, Bank Conflicts, and Tiling · Memory Coalescing and Vectorized Access · CUDA Occupancy Tuning · Tensor Cores and Mixed Precision · FlashAttention and Multi-Head Latent Attention · Profiling GPUs: Nsight Systems and Nsight Compute · Performance Optimization and Tuning · Frameworks · Glossary


  1. PyTorch Blog, "Why Is PyTorch Compile So Fast: Kernel Fusion." Confirms TorchInductor pointwise/vertical/horizontal/epilogue fusion and the read/write-reduction rationale. https://pytorch.org/blog/why-is-pytorch-compile-so-fast-kernel-fusion/ 

  2. PyTorch docs, torch.nn.attention.sdpa_kernel. Backend selection (SDPBackend.FLASH_ATTENTION) for fused scaled dot-product attention. https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html 

  3. NVIDIA CUDA Math API Reference, Single Precision Intrinsics. fmaf (single-rounding multiply-add) and rsqrtf/intrinsic reciprocal-sqrt throughput on the SFUs. https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__INTRINSIC__SINGLE.html 

  4. NVIDIA CUDA C++ Programming Guide, __restrict__ pointer aliasing and the optimizations it enables. https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#restrict