
PyTorch attention APIs: SDPA and FlexAttention

Scope: scaled_dot_product_attention backend selection and forcing, FlexAttention for custom masks/biases compiled to fused kernels, and how both map onto the FlashAttention kernel family.

What it is

PyTorch exposes two first-class attention entry points.

torch.nn.functional.scaled_dot_product_attention (SDPA) is a single fused operator covering the canonical softmax attention. It dispatches at runtime to one of several CUDA backends and "attempts to automatically select the most optimal implementation based on the inputs" (PyTorch SDPA docs). The same docs list the candidate backends: FlashAttention-2, memory-efficient attention (the xFormers implementation), a C++ math fallback, and cuDNN attention. The book frames SDPA the same way: it "automatically uses the fastest available attention kernel for the given hardware (e.g., FlashAttention)" and "if it's not supported, it will fall back to the standard attention implementation" (Fregly, Ch. 13).

Signature (PyTorch SDPA docs):

# API signature (torch reference, not executed here)
torch.nn.functional.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
) -> torch.Tensor

FlexAttention (torch.nn.attention.flex_attention.flex_attention) is a programmable attention operator. You supply a Python score_mod (transforms the pre-softmax score per element) and/or a mask_mod (decides which positions are attended), and "Leveraging torch.compile, we automatically lower your function into a single fused FlexAttention kernel" (PyTorch FlexAttention blog). The book classifies it as "a compiler-based approach for custom sparsity patterns in attention" that "can be substantially faster for specific sparse attention patterns (e.g., block-sparse or sliding-window attention) by generating optimized kernels," to be used "for special cases that scaled_dot_product_attention does not support" (Fregly, Ch. 13).

Why use it

Naive attention materializes the full S = QK^T score matrix in HBM, so memory traffic and footprint grow with the square of the sequence length. The FlashAttention algorithm avoids that materialization through tiling and online softmax, and SDPA exposes that kernel through a stable functional API so you never hand-write it. The payoff is lower peak memory and, because far less data moves through HBM, lower latency. FlexAttention's generated kernels keep the same property: the lowering "doesn't materialize any extra memory and has performance competitive with handwritten ones" (PyTorch FlexAttention blog).
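To put the quadratic term in numbers, here is a back-of-envelope sketch. It assumes 2-byte (fp16/bf16) scores, batch 1, 32 heads, and a 128-wide KV tile; the helper names are illustrative. The "tiled" figure models an [S, tile] slice of scores, like the numpy flash_tiled function later on this page. Production kernels also tile the queries and keep tiles in on-chip SRAM, so treat that figure as a loose upper bound on the fused path's per-step working set.

```python
# Back-of-envelope memory for materializing the attention score matrix S = Q K^T.
# Assumptions (illustrative, not measured): fp16/bf16 scores (2 bytes each),
# one score matrix per (batch, head), and a 128-wide KV tile for the tiled path.

BYTES_PER_ELEM = 2  # fp16 / bf16
MiB = 1024 ** 2

def naive_score_bytes(batch, heads, seq_len, bytes_per_elem=BYTES_PER_ELEM):
    """Bytes to hold the full [seq_len, seq_len] score matrix for every (batch, head)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

def tiled_score_bytes(batch, heads, seq_len, kv_tile=128, bytes_per_elem=BYTES_PER_ELEM):
    """Bytes for one [seq_len, kv_tile] slice of scores per (batch, head): linear in seq_len."""
    return batch * heads * seq_len * kv_tile * bytes_per_elem

for seq_len in (2048, 8192, 32768):
    naive = naive_score_bytes(1, 32, seq_len) / MiB
    tiled = tiled_score_bytes(1, 32, seq_len) / MiB
    print(f"S={seq_len:>6}: full scores {naive:>9.0f} MiB, one 128-wide tile {tiled:>5.0f} MiB")
```

At S = 8192 the full score matrix alone is 4 GiB across 32 heads, while the tile stays at 64 MiB. Doubling S quadruples the first number and only doubles the second.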

The two-API split matters because the FlashAttention fast path is narrow. It accepts the standard causal/dense pattern but not arbitrary additive biases or irregular sparsity. Historically, every new attention variant (ALiBi, sliding window, document masking, PrefixLM) required a bespoke CUDA kernel. FlexAttention collapses that into a few lines of Python that compile down to a fused kernel, and BlockMask lets the kernel skip fully-masked blocks so block-sparse patterns get real speedups rather than just correctness (PyTorch FlexAttention blog).

When to use it (and when not)

Use plain SDPA when:

  • Attention is dense or simple-causal, dtype is fp16/bf16, and head dims fit the FlashAttention path. This is the "no-hassle speedup" case (Fregly, Ch. 13).
  • You want a kernel that works in eager mode without torch.compile.

Force a specific SDPA backend when:

  • You need deterministic kernel choice for benchmarking or numerical reproducibility (the auto-selector can change across shapes/versions).
  • A backend has a correctness or shape limitation on your inputs and you must steer around it.

Use FlexAttention when:

  • The pattern is unsupported by SDPA: arbitrary score biases (ALiBi, soft-capping), sliding-window, document/block-diagonal masks, PrefixLM. Per the book, this is exactly "for special cases that scaled_dot_product_attention does not support" (Fregly, Ch. 13).
  • Sparsity is exploitable at block granularity: build a BlockMask so the kernel skips empty blocks.

Prefer not to reach for FlexAttention when a dense SDPA call already hits FlashAttention; you add a torch.compile dependency and compile latency for no kernel-level gain. The book's general guidance applies: try the high-level fused path first, and reserve custom-kernel routes for genuine gaps: "even for specialized attention patterns, PyTorch provides the FlexAttention API (prefill) and FlexDecoding API (decode), which are the preferred ways to implement custom attention kernels in PyTorch" (Fregly, Ch. 13).

For the autoregressive decode phase, the book points to FlexDecoding rather than FlexAttention: it "optimizes the decoding or text generation phases," "integrates with torch.compile and dynamic cache layouts," and "does not change training-time attention semantics" (Fregly, Ch. 13). Scaling a single long sequence across GPUs is a separate lever (context parallelism), covered under How to scale it.
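Decode has a different shape from prefill. Each generated token issues a single query (a query length of 1) against the whole key/value cache, so no causal mask is needed, yet its output must equal the matching row of full causal attention. The numpy sketch below checks that equivalence. The helper names `causal_attention` and `decode_step` are hypothetical, and the sketch makes no claim about how FlexDecoding itself is implemented.

```python
import numpy as np

def causal_attention(q, k, v):
    """Full causal (prefill) attention: query i attends to keys 0..i."""
    d = q.shape[-1]
    s = (q @ k.T) / np.sqrt(d)
    s = np.where(np.tril(np.ones(s.shape, dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def decode_step(q_new, k_cache, v_cache):
    """One decode step: a single query attends to every cached key, including its own."""
    d = q_new.shape[-1]
    s = (q_new @ k_cache.T) / np.sqrt(d)              # shape [1, T]; nothing to mask
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v_cache

rng = np.random.default_rng(3)
T, D = 9, 8
q, k, v = (rng.standard_normal((T, D)) for _ in range(3))

prefill = causal_attention(q, k, v)
# Generating token t: its key/value are appended to the cache, then it queries the cache.
for t in range(T):
    out_t = decode_step(q[t:t + 1], k[: t + 1], v[: t + 1])
    assert np.allclose(out_t[0], prefill[t])
print("decode step == matching row of causal prefill")
```

The long, thin [1, T] score row is one reason decode is treated as a separate workload from prefill.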

Architecture

SDPA and FlexAttention are two front doors onto the same FlashAttention-style kernel family. SDPA dispatches a fixed menu of backends; FlexAttention compiles your Python score_mod/mask_mod into a fresh fused kernel.

flowchart TD
    A["Attention call"] --> B{"Pattern supported by SDPA fast path?"}
    B -->|"yes (dense / causal)"| C["F.scaled_dot_product_attention"]
    B -->|"no (custom mask / bias / sparsity)"| D["flex_attention + torch.compile"]
    C --> E{"sdpa_kernel backend selection"}
    E --> F["FLASH_ATTENTION (FlashAttention-2)"]
    E --> G["CUDNN_ATTENTION"]
    E --> H["EFFICIENT_ATTENTION (xFormers mem-efficient)"]
    E --> I["MATH (C++ fallback)"]
    D --> J["fused FlashAttention-style kernel (Triton via TorchInductor)"]
    J --> K["BlockMask skips fully-masked blocks"]

SDPA's FLASH_ATTENTION backend is FlashAttention-2 (PyTorch SDPA docs); EFFICIENT_ATTENTION is the xFormers memory-efficient kernel; MATH is the unfused reference. FlexAttention does not call those backends; it generates its own fused FlashAttention-style kernel through torch.compile/TorchInductor (which emits Triton), producing a kernel that "doesn't materialize any extra memory" (PyTorch FlexAttention blog). For the algorithm and Multi-head Latent Attention variants behind these kernels see FlashAttention and Multi-Head Latent Attention; for the Triton codegen path see OpenAI Triton: Authoring GPU Kernels in Python.

Underneath every one of these kernels is the trick the previous section named: tiling plus online softmax. The kernel streams the key/value blocks and carries a running row max and denominator, rescaling the partial result as it goes, so the full score matrix is never held in HBM. The numpy block below reproduces that recurrence, checks it against the naive full-materialization reference across tile sizes, and then shows why the running max is not optional: without it, large logits overflow exp.

import numpy as np

def naive_attention(q, k, v, scale):                    # materializes full S x S scores
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def flash_tiled(q, k, v, scale, block):
    # FlashAttention core trick: stream KV in tiles, keep a running max m and
    # denominator l, never materialize the full score matrix. Rescale on the fly.
    S, D = q.shape[0], v.shape[1]
    out = np.zeros((S, D)); m = np.full((S, 1), -np.inf); l = np.zeros((S, 1))
    for j in range(0, k.shape[0], block):
        s = (q @ k[j:j + block].T) * scale              # only a [S, block] tile
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        alpha = np.exp(m - m_new)                        # correct the prior partial sum
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        out = alpha * out + p @ v[j:j + block]
        m = m_new
    return out / l

rng = np.random.default_rng(1)
S, D = 7, 4
q, k, v = (rng.standard_normal((S, D)) for _ in range(3))
scale = 1.0 / np.sqrt(D)

# 1. equivalence to the slow full-materialization reference, across every tile size
ref = naive_attention(q, k, v, scale)
for block in (1, 2, 3, 7):
    assert np.allclose(flash_tiled(q, k, v, scale, block), ref, atol=1e-12)

# 2. adversarial numerical stability: scale logits up so a naive exp WITHOUT the
#    max subtraction overflows to +inf, while the max-tracking tiled path stays
#    finite AND correct. This is why online softmax exists.
q_big = q * 400.0
with np.errstate(over="ignore"):
    unstable = np.exp((q_big @ k.T) * scale)            # no max-subtraction
assert np.isinf(unstable).any()                          # overflow really happens
flash_big = flash_tiled(q_big, k, v, scale, block=2)
assert np.isfinite(flash_big).all()
assert np.allclose(flash_big, naive_attention(q_big, k, v, scale), atol=1e-9)

print("online softmax == naive (tiles 1,2,3,7); overflow-safe: all asserts passed")

How to use it

The default call is one line: hand SDPA your query, key, and value tensors and let it pick a backend. Blocks below marked "reference template" require torch and a CUDA GPU (not installed in this docs environment); each core algorithm they teach is validated underneath in a runnable numpy block.

SDPA, default (auto-selected backend)

# reference template: requires torch + a CUDA GPU
import torch
import torch.nn.functional as F

# q, k, v: [batch, heads, seq, head_dim], bf16 on CUDA
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

is_causal=True applies a causal mask without you constructing an attn_mask tensor: query i attends only to keys 0..i, so the lower triangle is kept and the strict upper triangle is masked out. It is the cheap path for decoder self-attention.

The core math SDPA fuses (scaled scores, causal masking, softmax, value mix) validated in numpy, including the default 1/sqrt(head_dim) scale, the is_causal shortcut equivalence, and the fully-masked-row failure:

import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def sdpa(q, k, v, is_causal=False, attn_mask=None, scale=None):
    d = q.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(d)                       # SDPA default scale
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    if is_causal:
        sq, sk = q.shape[-2], k.shape[-2]
        keep = np.tril(np.ones((sq, sk), dtype=bool))
        scores = np.where(keep, scores, -np.inf)
    if attn_mask is not None:
        scores = scores + attn_mask
    return softmax(scores, axis=-1) @ v

rng = np.random.default_rng(0)
S, D = 6, 8
q, k, v = (rng.standard_normal((S, D)) for _ in range(3))

# 1. omitting scale must equal passing the documented default 1/sqrt(head_dim)
assert np.allclose(sdpa(q, k, v), sdpa(q, k, v, scale=1.0 / np.sqrt(D)))

# 2. is_causal=True equals supplying the equivalent additive -inf mask (reference)
causal_bias = np.where(np.tril(np.ones((S, S), bool)), 0.0, -np.inf)
assert np.allclose(sdpa(q, k, v, is_causal=True), sdpa(q, k, v, attn_mask=causal_bias))

# 3. under causal masking, query 0 sees only key 0, so out row 0 == value row 0
assert np.allclose(sdpa(q, k, v, is_causal=True)[0], v[0])

# 4. attention weights are a probability distribution (each row sums to 1)
w = softmax((q @ k.T) / np.sqrt(D), axis=-1)
assert np.allclose(w.sum(axis=-1), 1.0)

# 5. adversarial: a query row masked to attend to NOTHING yields NaN, not a
#    silent wrong answer. This is the fully-masked-row footgun (see Failure modes).
bad = causal_bias.copy(); bad[2, :] = -np.inf
with np.errstate(invalid="ignore"):                    # the all -inf row is expected
    out_bad = sdpa(q, k, v, attn_mask=bad)
assert np.isnan(out_bad[2]).all() and not np.isnan(out_bad[0]).any()

print("SDPA core math: all asserts passed")
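One subtlety the block above does not exercise is non-square inputs, for example a few new queries against a longer KV cache. Per the SDPA docs, is_causal uses a top-left-aligned mask when the mask is non-square, and torch.nn.attention.bias offers a lower-right variant (causal_lower_right) for the other convention. The numpy sketch below only contrasts the two alignments. The helper names are hypothetical, and the torch helpers are not called here.

```python
import numpy as np

def mask_upper_left(sq, sk):
    """Top-left aligned causal mask: query i keeps keys 0..i (np.tril on [sq, sk])."""
    return np.tril(np.ones((sq, sk), dtype=bool))

def mask_lower_right(sq, sk):
    """Bottom-right aligned: the LAST query lines up with the LAST key,
    so query i keeps keys 0..i + (sk - sq)."""
    return np.tril(np.ones((sq, sk), dtype=bool), k=sk - sq)

# Two new queries against a cache of five keys (sq < sk), as in chunked decode.
ul = mask_upper_left(2, 5)
lr = mask_lower_right(2, 5)
print(ul.astype(int))
print(lr.astype(int))

# For square inputs the two conventions coincide.
assert np.array_equal(mask_upper_left(4, 4), mask_lower_right(4, 4))
```

With a cache, the lower-right form is usually what you want: the newest query should see every cached key. The top-left form would hide most of the cache from it.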

How to integrate custom masks and biases

FlexAttention is the integration point for attention patterns SDPA cannot express. score_mod receives the scalar score plus index tensors and returns the modified score; mask_mod returns a boolean keep/drop decision (PyTorch FlexAttention docs):

# API signatures (torch reference, not executed here)
def score_mod(score, batch, head, q_idx, kv_idx): ...   # -> Tensor
def mask_mod(batch, head, q_idx, kv_idx): ...            # -> Tensor (bool)
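The dense numpy model below shows how the two hooks compose, for one (batch, head). It is a slow reference, not how the fused kernel works. `flex_attention_reference` is a hypothetical name, and it takes a mask_mod directly for simplicity, whereas the real flex_attention receives the mask through a BlockMask. The model also shows why the split exists. Writing -inf from a score_mod reproduces a mask numerically, but only a mask_mod can be turned into a BlockMask that lets the kernel skip work.

```python
import numpy as np

def flex_attention_reference(q, k, v, score_mod=None, mask_mod=None, scale=None):
    """Slow, dense numpy model of FlexAttention semantics for one (batch, head).

    score_mod(score, b, h, q_idx, kv_idx) rewrites each pre-softmax score;
    mask_mod(b, h, q_idx, kv_idx) -> bool decides which entries survive.
    b and h are fixed to 0 here; index arrays broadcast to the full [Sq, Sk] grid.
    """
    sq, sk, d = q.shape[0], k.shape[0], q.shape[1]
    scale = 1.0 / np.sqrt(d) if scale is None else scale
    scores = (q @ k.T) * scale
    q_idx = np.arange(sq)[:, None]
    kv_idx = np.arange(sk)[None, :]
    if score_mod is not None:
        scores = score_mod(scores, 0, 0, q_idx, kv_idx)     # vectorized over every (q, kv)
    if mask_mod is not None:
        scores = np.where(mask_mod(0, 0, q_idx, kv_idx), scores, -np.inf)
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def causal(b, h, q_idx, kv_idx):                  # mask_mod form
    return q_idx >= kv_idx

def causal_as_score_mod(score, b, h, q_idx, kv_idx):  # same pattern written as a score_mod
    return np.where(q_idx >= kv_idx, score, -np.inf)

rng = np.random.default_rng(4)
q, k, v = (rng.standard_normal((6, 8)) for _ in range(3))

via_mask = flex_attention_reference(q, k, v, mask_mod=causal)
via_score = flex_attention_reference(q, k, v, score_mod=causal_as_score_mod)
assert np.allclose(via_mask, via_score)   # same numbers; only the mask_mod is block-skippable
```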

Custom mask (mask_mod) compiled to a BlockMask

Sliding-window causal attention, expressed as a mask_mod compiled into a sparse BlockMask:

# reference template: requires torch + a CUDA GPU
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

WINDOW = 1024

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    windowed = q_idx - kv_idx <= WINDOW
    return causal & windowed

# B/H set to None broadcasts the mask across batch and heads.
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=8192, KV_LEN=8192
)

# torch.compile lowers score_mod/mask_mod into one fused FlashAttention-style kernel.
flex = torch.compile(flex_attention)
out = flex(q, k, v, block_mask=block_mask)

The mask_mod predicate validated in numpy: causal exclusion of the future, the exact window boundary (a key WINDOW back is kept, WINDOW + 1 back is dropped), and equivalence to a slow scalar reference.

import numpy as np

WINDOW = 4

def sliding_window_causal(q_idx, kv_idx):               # the page's mask_mod, vectorized
    causal = q_idx >= kv_idx
    windowed = q_idx - kv_idx <= WINDOW
    return causal & windowed

S = 12
q_idx = np.arange(S)[:, None]
kv_idx = np.arange(S)[None, :]
mask = sliding_window_causal(q_idx, kv_idx)             # [S, S] bool

# 1. causal: no query may attend to a future key
future = kv_idx > q_idx
assert not mask[future].any()

# 2. boundary values: exactly WINDOW back is kept, WINDOW+1 back is dropped
i = 10
assert mask[i, i - WINDOW] and not mask[i, i - WINDOW - 1]

# 3. each row keeps at most WINDOW + 1 keys (the window plus the diagonal)
assert mask.sum(axis=1).max() == WINDOW + 1

# 4. equivalence to a slow scalar double-loop reference
ref = np.array([[(i >= j) and (i - j <= WINDOW) for j in range(S)] for i in range(S)])
assert np.array_equal(mask, ref)

print("sliding-window causal mask_mod: all asserts passed")
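The same predicate style covers two other patterns named under When to use it: document masking for packed sequences and PrefixLM. These are vectorized numpy sketches with b and h dropped, as in the block above. The document boundaries (`doc_id`) and `PREFIX_LEN` are hypothetical example data. In real FlexAttention these predicates take (b, h, q_idx, kv_idx) and would be passed to create_block_mask.

```python
import numpy as np

# Hypothetical example data: three packed documents of lengths 3, 4, 2 in one sequence.
doc_id = np.array([0, 0, 0, 1, 1, 1, 1, 2, 2])
PREFIX_LEN = 3  # hypothetical PrefixLM prefix length

def document_causal(q_idx, kv_idx):
    """Document masking: attend causally, but only within the same packed document."""
    return (doc_id[q_idx] == doc_id[kv_idx]) & (q_idx >= kv_idx)

def prefix_lm(q_idx, kv_idx):
    """PrefixLM: the prefix is fully bidirectional; everything after it is causal."""
    return (kv_idx < PREFIX_LEN) | (q_idx >= kv_idx)

S = len(doc_id)
q_idx = np.arange(S)[:, None]
kv_idx = np.arange(S)[None, :]
doc_mask = document_causal(q_idx, kv_idx)      # block-diagonal and causal, [S, S]
plm_mask = prefix_lm(q_idx, kv_idx)

# A packed document never sees its neighbor: no kept entry crosses a document boundary.
assert not (doc_mask & (doc_id[:, None] != doc_id[None, :])).any()
# Every query keeps at least its own position, so no row is fully masked.
assert doc_mask.diagonal().all() and plm_mask.diagonal().all()
print("document-masking and PrefixLM mask_mods: all asserts passed")
```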

The sparsity a BlockMask exploits, validated by computing attention twice: dense with the full mask, and block-sparse skipping any fully-masked tile. The two agree exactly, and the sparse path touches fewer tiles.

import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x); return e / e.sum(axis=-1, keepdims=True)

WINDOW, BLOCK = 4, 4

def mask_mod(q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

S = 16
q_idx = np.arange(S)[:, None]; kv_idx = np.arange(S)[None, :]
full_mask = mask_mod(q_idx, kv_idx)                     # [S, S]

# BlockMask idea: a BLOCK x BLOCK tile is computed only if it holds ANY kept element.
nqb = nkb = S // BLOCK
active = np.zeros((nqb, nkb), bool)
for bi in range(nqb):
    for bj in range(nkb):
        active[bi, bj] = full_mask[bi*BLOCK:(bi+1)*BLOCK, bj*BLOCK:(bj+1)*BLOCK].any()

# 1. sparsity is real and lands where expected for a sliding window
assert active.sum() < nqb * nkb
assert not active[0, 1]                                  # upper-right tile: fully masked
assert active[3, 3] and active[3, 2]                     # diagonal band: active

# 2. equivalence: dense masked attention == block-sparse attention that SKIPS
#    fully-masked tiles. Same numbers, fewer tiles touched.
rng = np.random.default_rng(2)
D = 4
q, k, v = (rng.standard_normal((S, D)) for _ in range(3))
scale = 1.0 / np.sqrt(D)

dense = softmax(np.where(full_mask, (q @ k.T) * scale, -np.inf)) @ v

sparse_scores = np.full((S, S), -np.inf); computed = 0
for bi in range(nqb):
    for bj in range(nkb):
        if not active[bi, bj]:
            continue                                     # skip the fully-masked tile
        computed += 1
        rs, cs = bi*BLOCK, bj*BLOCK
        tile = (q[rs:rs+BLOCK] @ k[cs:cs+BLOCK].T) * scale
        m = full_mask[rs:rs+BLOCK, cs:cs+BLOCK]
        sparse_scores[rs:rs+BLOCK, cs:cs+BLOCK] = np.where(m, tile, -np.inf)
sparse = softmax(sparse_scores) @ v

assert computed < nqb * nkb                              # tiles were actually skipped
assert np.allclose(dense, sparse)                        # identical result

print(f"BlockMask sparsity: computed {computed}/{nqb*nkb} tiles, dense==sparse: passed")
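The same tile-activity rule can be applied at the sliding-window template's shapes (Q_LEN = KV_LEN = 8192, WINDOW = 1024). The sketch assumes create_block_mask's default BLOCK_SIZE of 128, as listed in its signature later on this page. It is pure arithmetic and allocates no tensors. Inside a square tile, q - kv takes every integer between its minimum and maximum, so comparing that interval with [0, WINDOW] gives the same answer as `.any()` on the tile.

```python
# Active BlockMask tiles for the sliding-window template's shapes (arithmetic only).
# Assumption: 128-token blocks, matching create_block_mask's default BLOCK_SIZE.
Q_LEN = KV_LEN = 8192
WINDOW = 1024
BLOCK = 128

def tile_is_active(bi, bj):
    """True if some (q, kv) in tile (bi, bj) satisfies 0 <= q - kv <= WINDOW."""
    min_diff = bi * BLOCK - (bj * BLOCK + BLOCK - 1)
    max_diff = (bi * BLOCK + BLOCK - 1) - bj * BLOCK
    return max_diff >= 0 and min_diff <= WINDOW   # interval [min, max] overlaps [0, WINDOW]

nqb, nkb = Q_LEN // BLOCK, KV_LEN // BLOCK
active = sum(tile_is_active(bi, bj) for bi in range(nqb) for bj in range(nkb))
causal_only = nqb * (nqb + 1) // 2               # tiles a causal-only mask would keep
print(f"{active}/{nqb * nkb} tiles active ({active / (nqb * nkb):.1%}); "
      f"causal-only keeps {causal_only}")
```

Under these assumptions the window keeps 540 of 4,096 tiles, about a quarter of the 2,080 a plain causal mask would keep. The active band is nine tiles wide, so the count grows linearly with sequence length rather than quadratically.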

Custom bias (score_mod): ALiBi

An additive bias (e.g. ALiBi) is a score_mod instead:

# reference template: requires torch + a CUDA GPU (flex defined in the block above)
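def alibi(score, b, h, q_idx, kv_idx):
    # alibi_slopes: per-head positive slopes ([H] tensor on q's device). The bias
    # slope * (kv_idx - q_idx) is 0 on the diagonal and more negative for older keys.
    return score + alibi_slopes[h] * (kv_idx - q_idx)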

out = flex(q, k, v, score_mod=alibi)

The additive-bias math validated in numpy: zero bias on the diagonal, a monotone penalty as the key recedes, the nearer key outweighing the farther one after softmax, and slope 0 collapsing to uniform causal attention.

import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x); return e / e.sum(axis=-1, keepdims=True)

# ALiBi as a score_mod: add slope * (kv_idx - q_idx) to the score, then softmax.
S = 8
slope = 0.25                                            # ALiBi slopes are positive per-head constants
q_idx = np.arange(S)[:, None]; kv_idx = np.arange(S)[None, :]
bias = slope * (kv_idx - q_idx)                         # [S, S]; <= 0 on and below the diagonal

# 1. a token's bias to itself is zero (diagonal)
assert np.allclose(np.diag(bias), 0.0)

# 2. along a causal row, penalty grows (bias falls) monotonically as the key recedes
row = 7
back = bias[row, : row + 1][::-1]                        # diagonal -> oldest key
assert np.all(np.diff(back) <= 0)

# 3. effect on attention: with equal raw scores, the nearest key wins and a nearer
#    key always outweighs a farther one
causal = q_idx >= kv_idx
w = softmax(np.where(causal, bias, -np.inf))
assert np.argmax(w[row, : row + 1]) == row
assert w[row, row] > w[row, 0]

# 4. slope 0 collapses ALiBi to plain uniform causal attention (reference equivalence)
w0 = softmax(np.where(causal, (q_idx - kv_idx) * 0.0, -np.inf))
assert np.allclose(w0[row, : row + 1], 1.0 / (row + 1))

print("ALiBi score_mod: all asserts passed")
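The ALiBi template indexes an `alibi_slopes` tensor that it never defines. The sketch below follows the ALiBi paper's recipe for a power-of-two head count: head h (0-based) gets the slope 2^(-8(h+1)/H), a geometric sequence ending at 2^-8. The helper name is hypothetical, and non-power-of-two head counts, which the paper handles differently, are deliberately rejected rather than guessed. Because the slopes are positive, the bias must be written as slope * (kv_idx - q_idx) so that older keys are penalized.

```python
import numpy as np

def alibi_slopes(num_heads):
    """Per-head ALiBi slopes for a power-of-two head count (ALiBi paper recipe).

    Head h (0-based) gets 2 ** (-8 * (h + 1) / num_heads): a geometric sequence
    from 2 ** (-8 / num_heads) down to 2 ** -8.
    """
    if num_heads < 1 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only covers power-of-two head counts")
    start = 2.0 ** (-8.0 / num_heads)
    return np.array([start ** (h + 1) for h in range(num_heads)])

slopes = alibi_slopes(8)
print(slopes)
```

For use with the template, the array would be moved to the query's device as a torch tensor. Steeper slopes (early heads) focus on nearby tokens, while shallow ones let later heads look further back.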

flex_attention signature (PyTorch FlexAttention docs):

# API signature (torch reference, not executed here)
flex_attention(query, key, value,
               score_mod=None, block_mask=None, scale=None,
               enable_gqa=False, return_lse=False,
               kernel_options=None, *, return_aux=None)

create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device=None, BLOCK_SIZE=128, _compile=False) produces a BlockMask, a block-level sparse format "somewhat of a cross in-between BCSR and a non-sparse format" (PyTorch FlexAttention docs). FlexAttention "can then use BlockMask to take advantage of the sparsity" by skipping fully-masked blocks (PyTorch FlexAttention blog). Wrap flex_attention (and ideally create_block_mask) in torch.compile: this is the lowering step that produces the fused kernel; eager FlexAttention does not give you the fused-kernel performance.
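The "cross between BCSR and a non-sparse format" can be sketched in numpy. For each query block, record how many KV blocks to visit and which ones. Separate partial tiles, where the mask_mod must still be evaluated element-wise, from full tiles, where every element is kept and the mask_mod can be skipped. The names kv_num_blocks, kv_indices, full_kv_num_blocks, and full_kv_indices are chosen to echo BlockMask's fields, but the layout here is simplified: no batch/head dimensions, and Python lists instead of padded index tensors. The shapes are hypothetical, and WINDOW = 8 is picked so that both tile kinds appear.

```python
import numpy as np

# Hypothetical shapes: 16 tokens, 4-token blocks, window 8 (so some tiles are "full").
S, BLOCK, WINDOW = 16, 4, 8

def mask_mod(q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

mask = mask_mod(np.arange(S)[:, None], np.arange(S)[None, :])
nb = S // BLOCK
# [q_block, kv_block, BLOCK, BLOCK]: element [qb, kb, a, c] is mask[qb*BLOCK + a, kb*BLOCK + c]
tiles = mask.reshape(nb, BLOCK, nb, BLOCK).transpose(0, 2, 1, 3)
any_kept = tiles.any(axis=(2, 3))
all_kept = tiles.all(axis=(2, 3))

partial = any_kept & ~all_kept    # visit, and apply mask_mod inside the tile
full = all_kept                   # visit, mask_mod can be skipped
# Tiles in neither set are never visited: that is where the speedup comes from.

kv_num_blocks = partial.sum(axis=1)
kv_indices = [np.flatnonzero(row).tolist() for row in partial]
full_kv_num_blocks = full.sum(axis=1)
full_kv_indices = [np.flatnonzero(row).tolist() for row in full]
print(kv_num_blocks, kv_indices)
print(full_kv_num_blocks, full_kv_indices)
```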

How to run it in production

In production you usually want deterministic kernel selection: a fixed backend so a PyTorch upgrade or a shape change cannot silently move your numerics or latency, and hard errors that surface a capability gap early rather than a quiet fallback to a slow kernel.

Forcing or restricting the backend

The current API is the torch.nn.attention.sdpa_kernel context manager, driven by the SDPBackend enum, whose kernel members include MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION, and CUDNN_ATTENTION (PyTorch sdpa_kernel docs).

# reference template: requires torch + a CUDA GPU (backend-control API, no new math)
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Force FlashAttention only; raises if inputs are unsupported.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Supply a priority-ordered list: try cuDNN, then Flash.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION],
                 set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

On exit the manager restores the prior backend flags (PyTorch sdpa_kernel docs). Note: the older torch.backends.cuda.sdp_kernel(enable_flash=..., enable_math=..., enable_mem_efficient=...) context manager still exists but is the legacy form; prefer torch.nn.attention.sdpa_kernel with the enum on current PyTorch.

If the listed backend cannot serve the given shapes, dtype, or mask, the call raises an error instead of silently falling back. That is exactly what you want when pinning a backend for benchmarking. In a serving loop where you would rather degrade than crash, pass a priority-ordered list (as above) so that a supported backend takes over. Ending the list with SDPBackend.MATH, the unfused reference, guards against the case where none of the fast backends accept the inputs.
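The difference between pinning one backend and passing an ordered list can be modeled in plain Python: walk the list, take the first backend that accepts the inputs, and raise if none does. The capability rules in `supports` are hypothetical stand-ins chosen for illustration. PyTorch's real dispatch checks cover many more conditions and are not reproduced here.

```python
from dataclasses import dataclass

# Toy model of pinned vs priority-list backend selection.
# The capability rules below are HYPOTHETICAL, not PyTorch's real dispatch checks.

@dataclass(frozen=True)
class Inputs:
    dtype: str             # e.g. "bf16", "fp16", "fp32"
    head_dim: int
    has_float_mask: bool   # an arbitrary additive attn_mask

def supports(backend: str, x: Inputs) -> bool:
    if backend == "FLASH_ATTENTION":
        return x.dtype in ("fp16", "bf16") and x.head_dim <= 256 and not x.has_float_mask
    if backend == "EFFICIENT_ATTENTION":
        return x.head_dim <= 256      # hypothetical simplification
    if backend == "MATH":
        return True                   # unfused reference: accepts everything
    return False

def choose(backends, x: Inputs) -> str:
    """Return the first listed backend that accepts the inputs; raise if none does."""
    for b in backends:
        if supports(b, x):
            return b
    raise RuntimeError(f"no listed backend supports {x}")

bf16 = Inputs("bf16", 64, False)
fp32 = Inputs("fp32", 64, False)
print("pinned, bf16:", choose(["FLASH_ATTENTION"], bf16))
try:
    choose(["FLASH_ATTENTION"], fp32)          # pinned: errors instead of degrading
except RuntimeError as err:
    print("pinned, fp32:", err)
print("listed, fp32:", choose(["FLASH_ATTENTION", "EFFICIENT_ATTENTION", "MATH"], fp32))
```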

How to maintain it

  • Pin a backend with sdpa_kernel in regression CI so a PyTorch upgrade that changes the auto-selector cannot silently shift your numerics or latency. Wire this into PyTorch Performance Regression Testing in CI.
  • FlexAttention recompiles when score_mod/mask_mod capture new shapes or Python globals; keep them pure and shape-stable. A BlockMask can be rebuilt for changed sparsity without recompiling the kernel (PyTorch FlexAttention blog).
  • Verify the realized backend in a profiler trace (the kernel name reveals Flash vs. memory-efficient vs. math). See Profiling GPUs: Nsight Systems and Nsight Compute.
  • The API surface (flex_attention, SDPBackend) is marked prototype/beta and can change across releases; pin your PyTorch version.

How to scale it

Longer sequences and multi-GPU scaling do not come from picking a different attention entry point; they come from three levers around it:

  • Block sparsity for long context: a BlockMask lets FlexAttention skip fully-masked blocks, so a sliding-window or block-diagonal pattern turns quadratic work into work proportional to the kept blocks. The BlockMask numpy block above computes only 7 of 16 tiles for a window mask while matching the dense result exactly.
  • Decode-phase scaling: for the autoregressive decode phase the book points to FlexDecoding, which "optimizes the decoding or text generation phases" and "integrates with torch.compile and dynamic cache layouts" (Fregly, Ch. 13). See When to use it.
  • Scaling a single long sequence across GPUs is context parallelism (context_parallel()), not an attention-API choice. See Distributed Training Platform.
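Splitting one sequence's keys across devices works because partial attention results can be merged exactly, provided each shard also reports its per-query log-sum-exp (LSE). This is the quantity flex_attention's return_lse flag exposes. The numpy sketch below shows that merge step for unmasked attention. It is the general idea that sequence-sharded schemes such as ring attention rely on. It does not claim that context_parallel() uses exactly this formulation, and the helper names are hypothetical.

```python
import numpy as np

def partial_attention(q, k, v, scale):
    """Attention over one KV shard, plus the per-query log-sum-exp of its scores."""
    s = (q @ k.T) * scale
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)
    return (p / l) @ v, m + np.log(l)            # normalized output, LSE of shape [Sq, 1]

def merge(outs, lses):
    """Combine per-shard outputs exactly: weight shard i by exp(lse_i - lse_total)."""
    lse_all = np.logaddexp.reduce(np.concatenate(lses, axis=-1), axis=-1, keepdims=True)
    return sum(np.exp(lse - lse_all) * out for out, lse in zip(outs, lses))

rng = np.random.default_rng(5)
Sq, Sk, D = 4, 12, 8
q = rng.standard_normal((Sq, D))
k, v = rng.standard_normal((Sk, D)), rng.standard_normal((Sk, D))
scale = 1.0 / np.sqrt(D)

full, _ = partial_attention(q, k, v, scale)
# Pretend each of three "ranks" holds a contiguous third of the keys/values.
shards = [slice(0, 4), slice(4, 8), slice(8, 12)]
parts = [partial_attention(q, k[s], v[s], scale) for s in shards]
merged = merge([o for o, _ in parts], [l for _, l in parts])
assert np.allclose(merged, full)
print("sharded attention merged via LSE == full attention")
```

The weight exp(lse_i - lse_total) is shard i's share of the softmax normalizer, Z_i / Z. That is why no shard ever needs the other shards' raw scores.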

Failure modes

| Failure mode | Cause | Mitigation |
|---|---|---|
| Silent backend drift | SDPA's auto-selector can pick a different backend across shapes or PyTorch versions, shifting numerics and latency | Pin a backend with sdpa_kernel and assert it in regression CI (see PyTorch Performance Regression Testing in CI) |
| Forced backend raises | sdpa_kernel(FLASH_ATTENTION) errors instead of falling back when inputs are unsupported | Intended when benchmarking; in serving, pass a priority-ordered list with set_priority=True so a supported backend takes over |
| Falls off the Flash path | FlashAttention needs fp16/bf16 and head dims it supports; fp32 or an unsupported head dim quietly routes to a slower backend | Use bf16/fp16 and confirm the realized kernel in a profiler trace (see Profiling GPUs: Nsight Systems and Nsight Compute) |
| No speedup from FlexAttention | Eager FlexAttention does not emit the fused kernel | Wrap flex_attention (and ideally create_block_mask) in torch.compile |
| Recompilation storms | score_mod/mask_mod capturing new shapes or Python globals retrigger torch.compile | Keep the mods pure and shape-stable; rebuild a BlockMask for changed sparsity without recompiling the kernel |
| Fully-masked query row yields NaN | A query whose mask_mod rejects every key has no valid target; softmax over an all -inf row divides by a zero normalizer | Guarantee every query keeps at least one key (for example the diagonal). The SDPA numpy block above asserts that this NaN appears rather than a silent wrong number |
| Prototype/beta API churn | flex_attention and SDPBackend are marked prototype/beta and can change across releases | Pin your PyTorch version |

References

Related: torch.compile: Graph Capture, Backends, and Recompiles · PyTorch CUDA Caching Allocator Tuning · PyTorch/XLA and the XLA Compiler · Activation Checkpointing and Memory Offloading · PyTorch Performance Regression Testing in CI · FlashAttention and Multi-Head Latent Attention · OpenAI Triton: Authoring GPU Kernels in Python · Tensor Cores and Mixed Precision · Profiling GPUs: Nsight Systems and Nsight Compute · Inference Serving and Optimization · Glossary