
OpenAI Triton: authoring GPU kernels in Python

Scope: Triton's block-based programming model for writing fused GPU kernels in Python (the language torch.compile emits), its autotuner, and when a Triton kernel beats eager PyTorch or a chain of composed CUDA libraries.

What it is

Triton is an open-source Python DSL and compiler for GPU kernels. You write a kernel as a @triton.jit-decorated Python function; the compiler lowers it through Triton IR to LLVM IR and finally to PTX/SASS for NVIDIA targets. The book lists Triton alongside cuBLASLt, cuDNN, and CUTLASS as the kernel libraries that drive cp.async/TMA transfers into shared memory and feed the Tensor Cores (Fregly, Ch. 9). Unlike those C++ libraries, Triton's surface is Python, and its unit of work is a block (tile), not a thread.

The defining abstraction: a Triton program instance operates on a block of elements, addressed by a program id, and the compiler handles the intra-block thread mapping, memory coalescing, and shared-memory staging for you. The canonical vector-add kernel shows the whole model, verbatim from the official tutorial (Triton vector-add tutorial). It needs torch and triton (neither installed in this environment), so it is a reference template, not executed here; the numpy block just below validates the block/tile addressing and the mask guard it depends on.

# Reference template (requires torch + triton; not runnable here).
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements,
               BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements           # guard the ragged tail
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

Key elements, each from the official API: tl.program_id(axis=0) returns this instance's block index; tl.arange(0, BLOCK_SIZE) builds the per-block offset vector; mask = offsets < n_elements and the mask= argument to tl.load/tl.store guard out-of-bounds lanes; triton.cdiv(n, b) is ceiling division for the launch grid; and kernel[grid](...) is the launch, where indexing the kernel with the grid yields a callable (Triton vector-add tutorial, triton Python API).
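The grid argument can be a tuple or a callable. When it is callable, the launch calls it with the kernel's arguments, including constexpr meta-parameters such as BLOCK_SIZE. That is what lets an autotuner change BLOCK_SIZE without the caller recomputing the grid. The plain-Python sketch below models only that contract, under stated assumptions. FakeKernel, add_body, and the serial launch loop are hypothetical stand-ins: real program instances run in parallel on the GPU, not in a Python for-loop, and this fake passes only the keyword arguments to the grid callable.

```python
def cdiv(n: int, b: int) -> int:
    """Ceiling division for positive ints (what triton.cdiv computes)."""
    return -(-n // b)


class FakeKernel:
    """Hypothetical stand-in for a @triton.jit function; NOT Triton's launcher."""

    def __init__(self, body):
        self.body = body          # body(pid, *args, **meta) = one program instance
        self.last_grid = None

    def __getitem__(self, grid):
        def launch(*args, **meta):
            # A callable grid is called with the launch's meta-parameters (here:
            # only the keyword args), so the instance count can follow BLOCK_SIZE.
            dims = grid(meta) if callable(grid) else grid
            self.last_grid = dims
            for pid in range(dims[0]):             # 1-D grid: pid = program_id(0)
                self.body(pid, *args, **meta)
        return launch


def add_body(pid, x, y, out, n_elements, BLOCK_SIZE):
    # Same addressing as add_kernel: one BLOCK_SIZE tile per pid, masked tail.
    for off in range(pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE):
        if off < n_elements:                       # mask = offsets < n_elements
            out[off] = x[off] + y[off]


add_kernel = FakeKernel(add_body)


def add(x, y, block_size):
    out = [0.0] * len(x)
    n_elements = len(out)
    grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=block_size)
    return out


x = [float(i) for i in range(1000)]
y = [2.0 * i for i in range(1000)]
for bs in (64, 128, 1024):
    assert add(x, y, bs) == [3.0 * i for i in range(1000)]
    print(bs, add_kernel.last_grid)
```

The call site never mentions the instance count. Changing block_size changes the grid from (16,) to (8,) to (1,) through the lambda alone.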

The block/tile model above is small enough to check exactly in numpy: a 1-D grid of program instances, each owning one BLOCK_SIZE tile, with mask = offsets < n_elements on the ragged final tile. The block below proves the tiled+masked launch equals a plain elementwise add, that the grid is ceil(n / BLOCK), and (the adversarial case) that dropping the mask lets the last block read past the end. It runs on numpy alone.

# numpy-only model of the vector-add kernel's block/tile programming model:
# one program instance per BLOCK_SIZE tile, masked to n_elements on the tail.
# Proves tiled+masked add == elementwise add and that the mask is what stops the
# final block from reading/writing out of bounds.
import numpy as np


def cdiv(n, b):
    return -(-n // b)                     # triton.cdiv: ceiling division


def add_tiled(x, y, BLOCK_SIZE):
    n = x.shape[0]
    out = np.empty_like(x)
    grid = cdiv(n, BLOCK_SIZE)
    touched = np.zeros(n, dtype=bool)
    for pid in range(grid):               # one instance per program_id
        offs = pid * BLOCK_SIZE + np.arange(BLOCK_SIZE)
        mask = offs < n                   # guard the ragged tail
        valid = offs[mask]
        out[valid] = x[valid] + y[valid]
        touched[valid] = True
    return out, touched, grid


rng = np.random.default_rng(0)

# 1) Ragged case: n not a multiple of BLOCK_SIZE. Tiled+masked == plain add,
#    every element written exactly once, grid == ceil(n / BLOCK).
n, B = 10_000, 1024
x = rng.standard_normal(n).astype(np.float32)
y = rng.standard_normal(n).astype(np.float32)
out, touched, grid = add_tiled(x, y, B)
assert grid == cdiv(n, B) == 10           # ceil(10000/1024) = 10 blocks
assert np.array_equal(out, x + y)         # equals the reference elementwise add
assert touched.all()                      # no element skipped
assert grid * B >= n and (grid - 1) * B < n

# 2) Exact-multiple boundary: the tail block is full, mask is all-true.
n2, B2 = 4096, 1024
x2, y2 = rng.standard_normal(n2).astype(np.float32), rng.standard_normal(n2).astype(np.float32)
out2, touched2, grid2 = add_tiled(x2, y2, B2)
assert grid2 == 4 and touched2.all()
assert np.array_equal(out2, x2 + y2)

# 3) Adversarial: without the mask the last block strides past the end. Model the
#    out-of-bounds lanes as NaN padding; a correct masked read must exclude them.
n3, B3 = 1000, 256                        # grid = 4 covers 1024 > 1000
pad = cdiv(n3, B3) * B3
xb = np.concatenate([rng.standard_normal(n3), np.full(pad - n3, np.nan)])
last_pid = cdiv(n3, B3) - 1
offs = last_pid * B3 + np.arange(B3)
mask = offs < n3
unmasked = xb[offs]                       # would pull in the NaN tail
masked = xb[offs[mask]]                   # must not
assert np.isnan(unmasked).any(), "tail must contain OOB (NaN) lanes to guard"
assert np.isfinite(masked).all(), "mask failed: OOB NaN lanes leaked into the tile"
assert mask.sum() == n3 - last_pid * B3   # exactly the valid remainder

print("OK: tiled+masked add == elementwise add; grid=ceil(n/B); mask suppresses OOB tail")

The reason this matters for an infra engineer: on GPU targets, Triton is the language behind torch.compile. Its default backend, TorchInductor, decomposes eager PyTorch and re-assembles it into "a high percentage of Triton kernels with PyTorch connecting code" (PyTorch: Accelerating Triton Dequantization). When you read the generated code under TORCH_LOGS=output_code, you are reading Triton. Understanding the model is how you debug a torch.compile regression down to the kernel.

Why use it

The win is the same arithmetic-intensity win as Kernel Fusion, reached from Python instead of C++. A composed sequence of PyTorch ops launches one kernel per op, and each kernel reads its inputs from HBM and writes its result back to HBM. For a memory-bound chain (elementwise, normalization, softmax, attention), that traffic, not the math, sets the runtime. See Roofline Model and Arithmetic Intensity.

Triton's fused-softmax tutorial quantifies it. A naive PyTorch softmax over an M x N matrix reads 5MN + 2M elements and writes 3MN + 2M (each op round-trips through DRAM); a single fused kernel "reads X once and does all the necessary computations on-chip," reading and writing MN, a theoretical ~4x reduction in memory traffic (Triton fused-softmax tutorial). The block-level model is what makes the fusion natural: the whole row lives in registers/SRAM across the max, subtract, exp, sum, and divide, so the intermediates never touch HBM.
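Summing the tutorial's counts makes the ~4x explicit. For finite N the ratio sits slightly above 4 and tends to 4 as rows widen. For the tutorial's N = 781 it is about 4.003:

```latex
\text{naive} = \underbrace{(5MN + 2M)}_{\text{reads}} + \underbrace{(3MN + 2M)}_{\text{writes}} = 8MN + 4M,
\qquad
\text{fused} = MN + MN = 2MN,
\qquad
\frac{8MN + 4M}{2MN} = 4 + \frac{2}{N} \;\xrightarrow{\;N \to \infty\;}\; 4 .
```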

Against composed CUDA libraries (cuBLAS + a separate elementwise kernel), Triton's advantage is the same fusion plus author productivity: you express a custom fused tail in Python instead of writing a CUTLASS epilogue. The book frames CUTLASS and Triton as the two "fuse it yourself" paths from opposite ends: CUTLASS for control and the performance ceiling, Triton for iteration speed (Fregly, Ch. 9).

The numbers above are the tutorial's analytical model, not a benchmark measured on this KB's hardware. Treat the ~4x as the traffic-reduction ceiling, not a guaranteed wall-clock speedup; the realized speedup depends on the GPU's HBM bandwidth and the kernel's occupancy.

Both claims (the traffic model and the on-chip stable softmax the fusion computes) are exact enough to check in numpy. The block below asserts three things. First, the fused kernel reads/writes exactly MN while the naive form reads 5MN + 2M and writes 3MN + 2M. Second, the total-traffic ratio is ~4x, approaching 4 from just above as N grows. Third, the adversarial case: the max-subtract softmax the fusion keeps in SRAM stays finite on large logits, where a non-stable exp overflows to inf. It runs on numpy alone.

# numpy-only model of the fused-softmax tutorial: (1) the memory-traffic model
# naive 5MN+2M read / 3MN+2M write vs fused MN/MN (~4x ceiling); (2) the
# numerically-stable softmax (subtract row max) that the fusion computes on-chip.
# Adversarial: large logits where non-stable exp overflows but stable stays finite.
import numpy as np


def naive_traffic(M, N):
    return 5 * M * N + 2 * M, 3 * M * N + 2 * M     # reads, writes


def fused_traffic(M, N):
    return M * N, M * N                             # read X once, write Y once


def softmax_stable(x):
    z = x - x.max(axis=1, keepdims=True)            # max, subtract (on-chip)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def softmax_naive(x):
    e = np.exp(x)                                   # no max-subtract: overflows
    return e / e.sum(axis=1, keepdims=True)


M, N = 1823, 781                          # tutorial's non-round benchmark shape

# 1) Traffic model: fused reads/writes exactly MN; naive matches 5MN+2M / 3MN+2M.
rn, wn = naive_traffic(M, N)
rf, wf = fused_traffic(M, N)
assert rn == 5 * M * N + 2 * M and wn == 3 * M * N + 2 * M
assert rf == M * N and wf == M * N
ratio = (rn + wn) / (rf + wf)             # -> (5+3)/(1+1) = 4 as N grows
assert 3.9 < ratio < 4.01, ratio
assert abs((5 + 3) / (1 + 1) - 4.0) < 1e-12       # the asymptotic ceiling is 4x

# 2) Equivalence to reference: stable softmax rows are valid distributions and
#    softmax is shift-invariant (why subtracting the row max is free).
rng = np.random.default_rng(0)
x = rng.standard_normal((M, N)).astype(np.float32)
p = softmax_stable(x)
assert np.allclose(p.sum(axis=1), 1.0, atol=1e-5)
assert (p >= 0).all() and (p <= 1).all()
assert np.allclose(softmax_stable(x), softmax_stable(x + 12.5), atol=1e-5)

# 3) Adversarial: huge logits. Naive exp overflows -> inf/NaN; the max-subtract
#    form the fusion computes stays finite and still sums to 1.
big = np.array([[1000.0, 1001.0, 1002.0]], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(big)            # deliberately overflows: that is the point
stable = softmax_stable(big)
assert not np.isfinite(naive).all(), "expected naive exp to overflow on 1e3 logits"
assert np.isfinite(stable).all(), "stable softmax must stay finite"
assert np.allclose(stable.sum(axis=1), 1.0, atol=1e-6)
assert stable.argmax() == big.argmax()

# 4) Edge: a row of all-equal logits -> uniform 1/N.
uni = softmax_stable(np.zeros((1, 64), dtype=np.float32))
assert np.allclose(uni, 1.0 / 64, atol=1e-6)

print("OK: fused traffic MN/MN, naive 5MN+2M/3MN+2M, ~4x ceiling; "
      "stable softmax finite where naive overflows; uniform on equal logits")

When to use it (and when not)

Reach for library kernels first. For a standard GEMM use cuBLAS/cuBLASLt; for convolution and standard attention/normalization use cuDNN. These are heavily tuned and are the PyTorch default. A hand-written Triton matmul will, for regular shapes, usually tie or lose to cuBLAS at much higher engineering cost.

Reach for Triton when:

  • You have a memory-bound op chain (fused elementwise + reduction, RMSNorm, softmax, fused attention, dequantize-then-matmul, MoE routing) where killing DRAM round-trips via fusion is the win, and you want it in Python.
  • You need a kernel torch.compile does not already generate well: a custom fused op you can drop in as a torch.library custom op, or a shape where Inductor's autotuner underperforms.
  • You want fast iteration over a custom kernel: change the Python, re-autotune, re-profile, in seconds, versus a CUTLASS template recompile.
  • You are debugging a torch.compile regression and need to read, modify, or pin the generated Triton.

Reach for CUTLASS or inline PTX/SASS (Inline PTX and SASS-Level Tuning) instead when you need the last few percent of Tensor Core throughput, fine-grained control of the pipeline (explicit warp specialization, thread-block clusters, DSMEM, TMEM), or features Triton does not expose for your target architecture. Triton trades some peak performance and low-level control for productivity.

Do not reach for Triton to beat cuBLAS on a one-off regular-shaped matmul, and do not hand-write a Triton kernel for an op torch.compile already fuses. Inspect the generated code first.

Architecture

Triton's architecture is a lowering pipeline plus one load-bearing abstraction. The abstraction is the block (tile): the unit of work is not a thread but a block of elements addressed by a program id, and the compiler owns everything below that line. From the @triton.jit Python source, the compiler produces Triton IR, lowers that to LLVM IR, and emits PTX/SASS for the NVIDIA target; along the way it performs the thread-to-element mapping, memory coalescing, shared-memory staging, and software pipelining that a CUDA author writes by hand. That is the trade: you keep the block-level algorithm in Python and give up the per-thread control that CUTLASS and inline PTX retain, which is also where their performance ceiling comes from.

flowchart TB
    K["@triton.jit kernel (Python)<br/>unit of work = block/tile"]
    TIR["Triton IR"]
    LIR["LLVM IR"]
    PTX["PTX / SASS"]
    GPU["NVIDIA GPU<br/>Tensor Cores, shared memory"]
    AUTO["Compiler handles automatically:<br/>thread mapping, coalescing,<br/>shared-memory staging, pipelining"]
    CUDA["Hand-written CUDA / CUTLASS (C++)<br/>manual control, performance ceiling"]

    K --> TIR --> LIR --> PTX --> GPU
    AUTO -.-> TIR
    CUDA -->|"fuse it yourself: manual coalescing & SMEM"| GPU

The two boxes that feed the GPU are the two "fuse it yourself" paths the book contrasts (Fregly, Ch. 9): the Triton path (top) auto-generates the coalescing and SMEM staging, while the CUDA/CUTLASS path (bottom) makes you write them for maximum control. Because both terminate in PTX/SASS on the same hardware, the same Nsight profiling workflow reads either one, which is what lets you debug a torch.compile regression down to the SASS regardless of which path emitted it.
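To make "the compiler owns the thread mapping" concrete, the sketch below uses a deliberately simplified blocked layout. A BLOCK_SIZE tile is split across num_warps * 32 threads. Each thread owns one contiguous run of elements, and consecutive threads own consecutive runs, which is the access shape that lets a warp's loads coalesce. The blocked_layout and warp_footprint helpers and the even split are assumptions for illustration only. Triton's real layouts, and how the compiler chooses them, are compiler-internal and more general (replication, 2-D tiles, MMA layouts).

```python
WARP_SIZE = 32


def blocked_layout(block_size, num_warps):
    """Simplified model: thread t owns one contiguous run of the tile.

    Assumes block_size is a multiple of the thread count. Real Triton layouts
    also cover replication, 2-D tiles, and many other cases.
    """
    n_threads = num_warps * WARP_SIZE
    if block_size % n_threads:
        raise ValueError("model assumes block_size % (num_warps * 32) == 0")
    per_thread = block_size // n_threads
    return {t: range(t * per_thread, (t + 1) * per_thread) for t in range(n_threads)}


def warp_footprint(layout, warp):
    """All tile offsets one warp touches, in thread order."""
    offs = []
    for t in range(warp * WARP_SIZE, (warp + 1) * WARP_SIZE):
        offs.extend(layout[t])
    return offs


layout = blocked_layout(block_size=1024, num_warps=8)    # 256 threads
per = len(layout[0])                                     # elements per thread

# Every element of the tile is owned by exactly one thread.
assert sorted(o for r in layout.values() for o in r) == list(range(1024))

# Each warp covers one contiguous span of 32 * per elements; for fp32 with
# per == 4, that is 32 threads x 16 bytes = 512 contiguous bytes per warp.
for w in range(8):
    fp = warp_footprint(layout, w)
    assert fp == list(range(fp[0], fp[0] + WARP_SIZE * per))
print(per, WARP_SIZE * per * 4)
```

In CUDA you would write this mapping by hand. In Triton you write `tl.arange(0, BLOCK_SIZE)` and `num_warps`, and the compiler picks a layout with this coalescing property.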

How to use it

The production pattern is an autotuned block-GEMM. The matmul tutorial shows it: tl.dot for the Tensor Core MMA, FP32 accumulation, L2-aware block ordering, and mask/other guards on the K-loop boundary. The lines below are taken from the official tutorial (Triton matmul tutorial), trimmed. The config list is a representative subset, and the pointer setup, output addressing, and optional activation epilogue are elided. It needs torch and triton, so it is a reference template, not executed here. The numpy block just below validates the blocked-GEMM math, the other=0.0 K-tail guard, and the GROUP_SIZE_M reordering the template relies on.

# Reference template (requires torch + triton; not runnable here).
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256,
                       "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
                      num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 64,
                       "BLOCK_SIZE_K": 32,  "GROUP_SIZE_M": 8},
                      num_stages=5, num_warps=2),
    ],
    key=["M", "N", "K"],     # re-tune when these change
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    # ... grouped pid -> (pid_m, pid_n) ordering for L2 reuse (GROUP_SIZE_M),
    #     then offs_k and the a_ptrs / b_ptrs tile pointers (elided) ...

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)   # Tensor Core MMA
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float16)
    # ... c_ptrs and the output-bounds mask c_mask (elided) ...
    tl.store(c_ptrs, c, mask=c_mask)

Load-bearing points, each verified against the docs:

  • tl.dot(a, b, accumulator) computes a @ b on the tiles, adds it to the third argument, and returns the updated accumulator (the loop reassigns it). On supported GPUs and dtypes it lowers to Tensor Core MMA instructions. The accumulator is initialized in FP32 (tl.zeros(..., dtype=tl.float32)) for numerical fidelity even with FP16 operands, mirroring the accumulate-high-precision pattern of Tensor Cores and Mixed Precision (Triton matmul tutorial).
  • mask=... , other=0.0 on each tl.load zero-fills the ragged K tail so partial K-blocks do not corrupt the accumulator.
  • GROUP_SIZE_M reorders program ids into super-blocks so neighboring CTAs reuse the same A/B tiles in L2, an L2-locality optimization, not a correctness requirement.
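The K-loop mask compares local lane offsets against the K that remains, rather than forming a global index. The plain-Python sketch below checks two things. First, offs_k < K - k * BLOCK_SIZE_K selects exactly the lanes whose global K index k * BLOCK_SIZE_K + offs_k is in bounds. Second, advancing a pointer by BLOCK_SIZE_K * stride_ak visits each row's K elements exactly once, in order, for both row-major and column-major A. Because the template elides how a_ptrs is built, the flat-offset model here (row * stride_am + offs_k * stride_ak) is an assumption, as are the helper names k_masks and a_tile_offsets.

```python
def cdiv(n, b):
    return -(-n // b)


def k_masks(K, BLOCK_K):
    """Lane masks per K iteration, written as the tutorial writes them:
    offs_k < K - k * BLOCK_K (local offsets vs. remaining K)."""
    return [[o < K - k * BLOCK_K for o in range(BLOCK_K)]
            for k in range(cdiv(K, BLOCK_K))]


def a_tile_offsets(rows, K, BLOCK_K, stride_am, stride_ak):
    """Hypothetical flat-offset model of a_ptrs for the given rows of A.

    Starts at row * stride_am + offs_k * stride_ak and advances by
    BLOCK_K * stride_ak per iteration (a_ptrs += ...); masked lanes -> None.
    """
    ptrs = [[m * stride_am + o * stride_ak for o in range(BLOCK_K)] for m in rows]
    tiles = []
    for mask in k_masks(K, BLOCK_K):
        tiles.append([[p if ok else None for p, ok in zip(row, mask)] for row in ptrs])
        ptrs = [[p + BLOCK_K * stride_ak for p in row] for row in ptrs]
    return tiles


K, BK, M = 250, 64, 3                     # ragged K: 250 = 3 * 64 + 58

# 1) The local mask equals a global bounds check on k * BK + offs_k.
for k, mask in enumerate(k_masks(K, BK)):
    assert mask == [k * BK + o < K for o in range(BK)]

# 2) The pointer advance walks exactly K elements per row, in order, for a
#    row-major A (strides K, 1) and a column-major A (strides 1, M).
for stride_am, stride_ak in ((K, 1), (1, M)):
    tiles = a_tile_offsets(range(M), K, BK, stride_am, stride_ak)
    for m in range(M):
        seen = [p for tile in tiles for p in tile[m] if p is not None]
        assert seen == [m * stride_am + kk * stride_ak for kk in range(K)]

print(len(k_masks(K, BK)), sum(k_masks(K, BK)[-1]))   # 4 iterations, 58 live lanes
```

Only the strides differ between the two layouts. The loop body, the mask, and the `BLOCK_SIZE_K * stride_ak` advance are unchanged, which is why the kernel takes strides as arguments instead of assuming a layout.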

Those three choices are exactly checkable in numpy. The block below proves a blocked GEMM that tiles K and accumulates in FP32 equals a reference matmul (including a ragged K where the zero-fill is required), demonstrates the corruption an unmasked tail causes, shows GROUP_SIZE_M is a covering permutation of the output tiles (so it cannot change the result), and (the adversarial precision case) that FP16 accumulation drifts measurably versus FP32. It runs on numpy alone.

# numpy-only model of the matmul tutorial's three load-bearing choices:
# (1) blocked GEMM tiling K with FP32 accumulation == reference matmul;
# (2) mask=..., other=0.0 zero-fills the ragged K tail (required for correctness);
# (3) GROUP_SIZE_M reorders program ids for L2 reuse but is a pure permutation.
# Adversarial: FP16 accumulation drifts vs FP32 (why tl.zeros is float32).
import numpy as np


def cdiv(n, b):
    return -(-n // b)


def matmul_blocked(a, b, BK, acc_dtype=np.float32):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    c = np.zeros((M, N), dtype=np.float32)
    for k0 in range(0, K, BK):
        kk = min(k0 + BK, K) - k0
        a_tile = np.zeros((M, BK), dtype=acc_dtype)   # other=0.0 zero-fill
        b_tile = np.zeros((BK, N), dtype=acc_dtype)
        a_tile[:, :kk] = a[:, k0:k0 + kk]
        b_tile[:kk, :] = b[k0:k0 + kk, :]
        c += (a_tile @ b_tile).astype(np.float32)
    return c


def grouped_pids(M_tiles, N_tiles, GROUP_SIZE_M):
    order = []
    num_pid_in_group = GROUP_SIZE_M * N_tiles
    for pid in range(M_tiles * N_tiles):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_rows = min(M_tiles - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_rows)
        pid_n = (pid % num_pid_in_group) // group_rows
        order.append((pid_m, pid_n))
    return order


rng = np.random.default_rng(0)

# 1) Blocked FP32-accumulate GEMM == reference matmul, K a multiple of BK.
M, K, N, BK = 128, 256, 96, 64
a = rng.standard_normal((M, K)).astype(np.float16)
b = rng.standard_normal((K, N)).astype(np.float16)
ref = a.astype(np.float32) @ b.astype(np.float32)
got = matmul_blocked(a, b, BK)
assert np.allclose(got, ref, atol=1e-2, rtol=1e-2), np.abs(got - ref).max()

# 2) Adversarial ragged K (250 = 3*64 + 58): other=0.0 must match the reference,
#    and injecting garbage into the tail lanes (an unmasked read) must corrupt it.
Kr, BKr = 250, 64
ar = rng.standard_normal((64, Kr)).astype(np.float16)
br = rng.standard_normal((Kr, 48)).astype(np.float16)
ref_r = ar.astype(np.float32) @ br.astype(np.float32)
assert np.allclose(matmul_blocked(ar, br, BKr), ref_r, atol=1e-2, rtol=1e-2)
kk = Kr % BKr
corrupt = ref_r + np.full((64, BKr - kk), 7.0) @ np.full((BKr - kk, 48), 7.0)
assert not np.allclose(corrupt, ref_r, atol=1e-2), "unmasked tail must corrupt result"

# 3) GROUP_SIZE_M ordering is a permutation: covers every output tile once.
M_tiles, N_tiles = cdiv(M, 32), cdiv(N, 32)
for G in (1, 4, 8):
    order = grouped_pids(M_tiles, N_tiles, G)
    assert set(order) == {(m, n) for m in range(M_tiles) for n in range(N_tiles)}
    assert len(order) == len(set(order)) == M_tiles * N_tiles

# 4) Adversarial precision: FP16 accumulation drifts vs FP32 (why tl.zeros=float32).
err32 = np.abs(matmul_blocked(a, b, BK, np.float32) - ref).max()
err16 = np.abs(matmul_blocked(a, b, BK, np.float16) - ref).max()
assert err32 < err16, (err32, err16)

print("OK: blocked FP32-accum GEMM == reference (K ragged too); other=0.0 required; "
      "GROUP_SIZE_M is a covering permutation; FP16 accum drifts vs FP32")

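The numpy block shows that grouping cannot change the result, but not why grouping helps. The sketch below counts how many distinct A and B tiles the first W output tiles need under row-major and grouped launch order. It uses a 9 x 9 output-tile grid with 9 K tiles and GROUP_SIZE_M = 3 (the tutorial motivates grouping with a similar small example). Treating "distinct tiles needed by the first W instances" as the L2 working set is a simplifying assumption. Real reuse also depends on how many CTAs are resident at once, on tile sizes, and on L2 capacity. The helper names are hypothetical.

```python
def row_major(M_tiles, N_tiles):
    """Plain launch order: pid walks the output tile grid row by row."""
    return [(pid // N_tiles, pid % N_tiles) for pid in range(M_tiles * N_tiles)]


def grouped(M_tiles, N_tiles, G):
    """GROUP_SIZE_M order: the same pid -> (pid_m, pid_n) math as grouped_pids."""
    order, per_group = [], G * N_tiles
    for pid in range(M_tiles * N_tiles):
        first_m = (pid // per_group) * G
        rows = min(M_tiles - first_m, G)
        order.append((first_m + (pid % per_group) % rows, (pid % per_group) // rows))
    return order


def tiles_touched(order, W, K_tiles):
    """Distinct A and B tiles the first W output tiles need (L2-set proxy).

    Output tile (m, n) needs the K_tiles tiles of A-row m and of B-column n.
    """
    a_rows = {m for m, _ in order[:W]}
    b_cols = {n for _, n in order[:W]}
    return (len(a_rows) + len(b_cols)) * K_tiles


T = 9                                            # 9 x 9 output tiles, 9 K tiles
rm = tiles_touched(row_major(T, T), W=9, K_tiles=T)
gr = tiles_touched(grouped(T, T, G=3), W=9, K_tiles=T)
print(rm, gr)   # row-major: 1 A-row + 9 B-cols; grouped: 3 A-rows + 3 B-cols
```

The same nine output tiles cost 90 tile loads in row-major order and 54 grouped. The saving comes from the first W instances sharing A-rows and B-columns, not from doing less arithmetic.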
How to autotune it

@triton.autotune benchmarks a list of triton.Config objects and caches the winner per distinct key tuple. The official signatures:

# triton.autotune(configs, key, prune_configs_by=None, reset_to_zero=None,
#                 restore_value=None, pre_hook=None, post_hook=None,
#                 warmup=None, rep=None, use_cuda_graph=False,
#                 do_bench=None, cache_results=False)
#
# triton.Config(kwargs, num_warps=4, num_stages=3, num_ctas=1,
#               maxnreg=None, pre_hook=None, ir_override=None)

Source: triton.autotune, triton.Config. The two knobs that matter most:

  • num_warps: the number of warps (groups of 32 threads on NVIDIA GPUs) that execute one program instance (default 4). More warps spread a block's work across more threads; the softmax tutorial sets num_warps=8 to distribute wide rows. Too many can cut occupancy through register pressure.
  • num_stages: software-pipeline depth (default 3). Deeper pipelines overlap more cp.async/TMA loads with compute but cost more shared memory. The softmax tutorial picks num_stages = 4 when the device's per-SM shared-memory capacity exceeds ~200 KB (SIZE_SMEM > 200000 bytes), else 2, so the depth is bounded by the SM's SMEM budget (Triton fused-softmax tutorial).
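The softmax tutorial derives its launch parameters from the problem and the device rather than hard-coding them. It launches one program per row and rounds BLOCK_SIZE up to the next power of two at or above the row length, because tl.arange needs a power-of-two extent and the mask switches off the padding. It sets num_warps to 8 and picks num_stages from the device's shared-memory capacity. The sketch below reproduces that arithmetic in plain Python. next_power_of_2 mirrors what triton.next_power_of_2 returns for positive integers. softmax_launch_params is a hypothetical helper, and the SMEM capacities passed in are illustrative inputs, not values queried from a device.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def softmax_launch_params(n_cols: int, smem_bytes_per_sm: int) -> dict:
    block = next_power_of_2(n_cols)                    # whole row in one tile
    return {
        "BLOCK_SIZE": block,
        "masked_lanes": block - n_cols,                # padding the mask turns off
        "num_warps": 8,                                # the tutorial's fixed choice
        "num_stages": 4 if smem_bytes_per_sm > 200_000 else 2,
    }


p = softmax_launch_params(n_cols=781, smem_bytes_per_sm=228 * 1024)
print(p)
```

The power-of-two rounding is why the mask matters even for a "whole row per program" kernel. A 781-wide row runs in a 1024-lane tile with 243 lanes masked off.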

key=["M", "N", "K"] means the autotuner re-runs only when those argument values change; keep the key to the shape dimensions that actually shift the optimal config, or autotuning thrashes. The first call for a new key pays the benchmarking cost, so warm the cache before measuring steady-state latency.
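Key granularity is easiest to see on a request trace. Suppose a linear layer's N and K are fixed by the weights while M (the token or batch dimension) changes on almost every request. The sketch below counts how many autotuning runs a hypothetical trace would trigger under three keys: (M, N, K), (N, K) only, and a power-of-two bucket of M. The trace, the tuning_runs helper, and the bucketing scheme are illustrative assumptions. Whether M belongs in the key is something you settle by benchmarking. Bucketing in Triton would mean passing the bucket as an extra kernel argument named in key, which is a wiring choice you would make yourself, not a built-in option.

```python
import random


def next_pow2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def tuning_runs(trace, key_fn):
    """Autotune benchmark passes = number of distinct keys the trace produces."""
    return len({key_fn(shape) for shape in trace})


random.seed(0)
N, K = 4096, 4096                                   # fixed by the weight matrix
trace = [(random.randint(1, 2048), N, K) for _ in range(10_000)]   # M varies

full_key = tuning_runs(trace, lambda s: s)                          # key=["M", "N", "K"]
nk_key = tuning_runs(trace, lambda s: (s[1], s[2]))                 # key=["N", "K"]
bucketed = tuning_runs(trace, lambda s: (next_pow2(s[0]), s[1], s[2]))
print(full_key, nk_key, bucketed)
```

The tradeoff runs both ways. Dropping M tunes once but serves one config to every batch size, which may be poor for small M. Keying on raw M re-benchmarks for nearly every distinct batch size. Bucketing caps the number of tuning passes at the number of buckets.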

Triton does not free you from CUDA Occupancy Tuning; it just moves the levers to num_warps/num_stages/BLOCK_SIZE. The softmax tutorial computes occupancy explicitly as the min of two limits, registers (NUM_REGS // (n_regs * WARP_SIZE * num_warps)) and shared memory (SIZE_SMEM // size_smem), after pre-compiling the kernel to read back its actual n_regs and size_smem (Triton fused-softmax tutorial). Larger blocks and deeper pipelines raise per-instance register/SMEM use and can lower the number of resident program instances; autotuning searches this tradeoff, but you must supply candidate configs that span it.

Both the occupancy formula and the cache semantics are exact enough to check in numpy. The block below asserts occupancy is min(register limit, smem limit) with the binding limit switching between the two, that it is monotonically non-increasing as num_warps and num_stages grow, and that the autotuner benchmarks once per key, cache-hits on a repeat key, and re-benchmarks on a new key (the thrash you cause with over-fine keys). It runs on numpy alone.

# numpy-only model of (1) the softmax tutorial's occupancy computation:
# resident instances = min(reg_limit, smem_limit), non-increasing in num_warps
# and num_stages; (2) @triton.autotune caching: bench once per key, reuse on
# repeat, re-bench on a new key (why key granularity matters).
import numpy as np

WARP_SIZE = 32
NUM_REGS = 65536          # 64K 32-bit registers per SM (Ampere-class)
SIZE_SMEM = 228 * 1024    # ~228 KB shared memory per SM (Hopper-class budget)


def occupancy(n_regs, num_warps, size_smem):
    reg_limit = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    smem_limit = SIZE_SMEM // size_smem
    return min(reg_limit, smem_limit)     # the tutorial's two-limit min


# 1) Occupancy is min(reg, smem) and the binding limit switches.
reg_bound = occupancy(n_regs=100, num_warps=8, size_smem=2048)
assert reg_bound == NUM_REGS // (100 * 32 * 8)      # register-limited
assert reg_bound < SIZE_SMEM // 2048
smem_bound = occupancy(n_regs=32, num_warps=4, size_smem=200 * 1024)
assert smem_bound == SIZE_SMEM // (200 * 1024)      # smem-limited
assert smem_bound < NUM_REGS // (32 * 32 * 4)

# 2) Occupancy non-increasing as num_warps rises (register-pressure tradeoff).
prev = occupancy(48, 1, 4096)
for w in range(1, 33):
    cur = occupancy(48, w, 4096)
    assert cur <= prev, (w, cur, prev)
    prev = cur

# 3) Deeper pipelines (num_stages) cost more SMEM: occupancy non-increasing.
prev = occupancy(32, 4, 1 * 32 * 1024)
for s in range(1, 8):
    cur = occupancy(32, 4, s * 32 * 1024)
    assert cur <= prev, (s, cur, prev)
    prev = cur


class AutotuneCache:                      # emulates @triton.autotune(key=[...])
    def __init__(self, configs, bench):
        self.configs, self.bench = configs, bench
        self.cache, self.bench_calls = {}, 0

    def run(self, key):
        if key not in self.cache:
            best, best_t = None, float("inf")
            for cfg in self.configs:
                self.bench_calls += 1     # each candidate benchmarked once
                t = self.bench(cfg, key)
                if t < best_t:
                    best, best_t = cfg, t
            self.cache[key] = best
        return self.cache[key]


configs = ["BM128_BN256", "BM64_BN64", "BM32_BN32"]
base = {"BM128_BN256": 1.0, "BM64_BN64": 1.5, "BM32_BN32": 2.0}
def bench(cfg, key):
    M = key[0]
    return base[cfg] * (1.0 + 4096 / max(M, 1))     # deterministic, key-dependent

at = AutotuneCache(configs, bench)
k_big = (4096, 4096, 4096)
w1 = at.run(k_big)
assert at.bench_calls == len(configs)               # first key benches every config
# 4) Repeat SAME key: served from cache, zero extra benchmarking.
assert at.run(k_big) == w1 and at.bench_calls == len(configs), "same key must cache-hit"
# 5) NEW key: pays the benchmarking cost again (over-fine keys thrash this).
at.run((512, 512, 512))
assert at.bench_calls == 2 * len(configs), "new key must re-benchmark"
assert w1 == "BM128_BN256" and bench(w1, k_big) == min(bench(c, k_big) for c in configs)

print("OK: occupancy = min(reg,smem) limits, non-increasing in warps/stages; "
      "autotune benches once per key, cache-hits on repeat, re-benches on new key")

How to integrate with it

  • As a custom op: wrap the kernel behind a torch.library custom op so torch.compile treats it as an opaque, composable node rather than re-tracing into it.
  • Inspect what torch.compile already emits before writing your own: TORCH_LOGS="output_code" dumps the generated Triton. Often Inductor has already fused your chain. Confirm before hand-rolling (PyTorch: Accelerating Triton Dequantization).

Because Triton is the TorchInductor backend, integration is usually a question of whether to let the compiler emit the kernel or to drop in your own as a custom op it will not re-trace. Read the generated code first (TORCH_LOGS=output_code), and only hand-author the node the compiler misses.

How to run it in production

  • Profile, do not guess. Confirm the fused kernel actually cut DRAM traffic and moved toward the compute-bound side of the roofline, and, for tl.dot kernels, that Tensor Core utilization rose. Triton lowers to SASS, so the same Nsight workflow applies. See Profiling GPUs: Nsight Systems and Nsight Compute and GPU Diagnostics and Validation. Validate that tl.dot engaged the Tensor Cores and that occupancy matches the autotuner's pick.
  • Warm the autotune cache before you measure or serve. The first call for each new key pays the full benchmarking cost; a cold cache in the serving path shows up as a latency spike on the first request of a new shape.
  • Numerical care: keep reductions and accumulators in tl.float32 even for FP16/BF16 inputs (as the matmul accumulator does); low-precision accumulation drifts. The FP16-vs-FP32 accumulation drift is the assertion in the matmul block above.

How to maintain it

  • Pin Triton/PyTorch versions: the autotuner cache and generated SASS depend on both the Triton compiler and the target architecture (SM80 Ampere, SM90 Hopper, SM100 Blackwell). Re-autotune when moving GPUs; a config tuned for one architecture's SMEM/register budget is not optimal on another.
  • Re-autotune on any shape or hardware change that shifts the optimal config. The key=["M", "N", "K"] on the kernel is the contract for when the cache is allowed to reuse a config; if you change the kernel body, the operand dtype, or the GPU, invalidate the cache and re-benchmark.
  • Keep the config list spanning the occupancy tradeoff. Autotuning only searches the configs you supply, so the candidate list must include both register-heavy/deep-pipeline and lighter options; otherwise the "winner" is just the best of a bad set. The occupancy block above shows why the register and SMEM limits pull in opposite directions.
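If you persist tuning results yourself, one way to make the "invalidate on kernel, dtype, or GPU change" rule mechanical is to fold every input that can move the optimal config into the lookup key. The sketch below is a hypothetical persisted tuning table, not Triton's own cache format. The key hashes the kernel source and combines it with the Triton version string, the GPU's compute capability, the operand dtype, and the shape key, so a change to any of them forces a miss and a re-benchmark. The function names, field names, and the "x.y.z" version placeholder are all illustrative.

```python
import hashlib
import json


def tuning_key(kernel_src, triton_version, capability, dtype, shape_key):
    """Hypothetical cache key: any field changing means 're-autotune'."""
    src_hash = hashlib.sha256(kernel_src.encode()).hexdigest()[:16]
    return json.dumps(
        {"src": src_hash, "triton": triton_version, "sm": capability,
         "dtype": dtype, "shape": list(shape_key)},
        sort_keys=True,
    )


table = {}      # key -> winning config (e.g. loaded from disk at startup)
calls = []      # records each time we pay the benchmarking cost


def lookup_or_tune(key, tune):
    if key not in table:
        table[key] = tune()          # cold entry: benchmark once, then reuse
    return table[key]


def fake_tune():
    calls.append(1)
    return {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}


src = "def matmul_kernel(...): ..."          # stands in for the kernel's source
base = dict(kernel_src=src, triton_version="x.y.z", capability="sm90",
            dtype="float16", shape_key=(4096, 4096, 4096))

lookup_or_tune(tuning_key(**base), fake_tune)                               # cold
lookup_or_tune(tuning_key(**base), fake_tune)                               # hit
lookup_or_tune(tuning_key(**{**base, "capability": "sm80"}), fake_tune)     # new GPU
lookup_or_tune(tuning_key(**{**base, "dtype": "bfloat16"}), fake_tune)      # new dtype
lookup_or_tune(tuning_key(**{**base, "kernel_src": src + " "}), fake_tune)  # edited body
print(len(calls))
```

Hashing the raw source is deliberately conservative: even a whitespace edit forces a re-benchmark, which costs time but never silently reuses a config tuned for different code.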

How to scale it

At model scale the leverage is not hand-writing more kernels, it is letting torch.compile emit Triton for the whole graph and hand-authoring only the few nodes Inductor misses. TorchInductor already re-assembles eager PyTorch into "a high percentage of Triton kernels with PyTorch connecting code" (PyTorch: Accelerating Triton Dequantization), so scaling a workload means widening what the compiler fuses (larger fused regions, more of the graph under compile) and pinning the handful of custom Triton ops behind torch.library boundaries so they compose. Across a fleet of mixed GPUs, scaling also means re-autotuning per architecture: the same source recompiles and re-benchmarks for SM80/SM90/SM100, and the autotune cache is per-architecture, so a config tuned on one is not carried to another.

Failure modes

Triton moves the low-level levers into Python. It does not remove the low-level failures, and the block model and the autotuner can hide them. Each bullet names a failure, its symptom, and the fix. The ones that can be checked exactly are asserted in the numpy blocks above.

  • Dropped or wrong mask on the ragged tail. Omitting mask=offsets < n_elements (vector-add) or mask=..., other=0.0 (matmul K-loop) lets the final block read past the end or fold garbage into the accumulator. Symptom: wrong results only for sizes that are not a multiple of BLOCK_SIZE, or NaNs from reading uninitialized memory. Fix: mask every boundary load/store and zero-fill with other=0.0. The mask block (vector-add) and the corrupted-tail assertion (matmul) above are this failure.
  • Low-precision accumulation drift. Initializing the accumulator in FP16/BF16 instead of tl.float32 makes a long K reduction drift from the reference. Symptom: matmul error grows with K and only in reduced precision. Fix: tl.zeros(..., dtype=tl.float32) and cast the result out at the end. The FP16-vs-FP32 assertion above quantifies the drift.
  • Softmax / reduction overflow. A reduction that skips the max-subtract overflows exp on large logits and produces inf/NaN, even though the block model would otherwise keep the row on-chip. Symptom: NaNs downstream of an attention or softmax kernel. Fix: subtract the row max before exp (shift-invariant, so free). The overflow assertion in the softmax block above is this case.
  • Autotune cache thrash. Keying the autotuner on a value that changes every call (or leaving the cache cold in the serving path) re-benchmarks on every request. Symptom: a latency spike on the first call of each new shape, or continuous re-tuning. Fix: key only on the shape dimensions that move the optimal config, and warm the cache before serving. The bench-once-per-key / re-bench-on-new-key assertions above model this.
  • Occupancy cliff from over-large blocks or deep pipelines. Raising BLOCK_SIZE, num_warps, or num_stages past the SM's register or SMEM budget drops the number of resident program instances and can slow the kernel despite doing "more per instance". Symptom: a bigger config benchmarks slower; Nsight shows low occupancy bound by registers or shared memory. Fix: supply configs that span the tradeoff and let the autotuner pick; read back n_regs/size_smem. The min(reg, smem) monotonicity assertions above show the cliff.
  • Beating cuBLAS on a regular matmul. Hand-writing a Triton GEMM for a standard shape usually ties or loses to cuBLAS at much higher engineering cost, and hand-rolling an op torch.compile already fuses duplicates work. Symptom: weeks spent to match, not beat, the library. Fix: use library kernels for regular shapes and inspect the generated Triton (TORCH_LOGS=output_code) before hand-authoring.

References

  • Chris Fregly, AI Systems Performance Engineering (O'Reilly), Chapter 9, Increasing CUDA Kernel Efficiency and Arithmetic Intensity. Lists OpenAI Triton alongside cuBLASLt, cuDNN, CUTLASS as the kernels driving cp.async/TMA into shared memory, and frames Triton vs CUTLASS as the two "fuse it yourself" paths.
  • Triton, Vector Add tutorial. @triton.jit, tl.program_id, tl.arange, tl.load/tl.store with mask=, triton.cdiv, grid lambda and kernel[grid](...) launch.
  • Triton, Fused Softmax tutorial. The 5MN+2M read / 3MN+2M write naive-vs-fused traffic model and ~4x ceiling, num_warps/num_stages choice, and the explicit register/SMEM occupancy computation.
  • Triton, Matrix Multiplication tutorial. @triton.autotune configs/key, matmul_kernel signature, FP32 tl.zeros accumulator, tl.dot, GROUP_SIZE_M L2 ordering, and K-loop mask/other guards.
  • Triton, Python API index, triton.autotune, triton.Config. Verbatim autotune and Config constructor signatures.
  • PyTorch, Accelerating Triton Dequantization Kernels for GPTQ. torch.compile/TorchInductor re-assembles eager PyTorch into Triton kernels; reading generated code down to SASS.

Related: PyTorch Custom CUDA Extensions · Kernel Fusion · CUTLASS: Templated GEMM and Kernel Building Blocks · Tensor Cores and Mixed Precision · CUDA Occupancy Tuning · Memory Coalescing and Vectorized Access · Shared Memory, Bank Conflicts, and Tiling · Roofline Model and Arithmetic Intensity · Profiling GPUs: Nsight Systems and Nsight Compute · Frameworks · Inline PTX and SASS-Level Tuning · Glossary