FlashAttention and multi-head latent attention¶
Scope: IO-aware exact attention (FlashAttention tiling, online softmax, versions 2 and 3) and DeepSeek's Multi-Head Latent Attention (MLA), which compresses the KV cache via low-rank joint projection.
What it is¶
Transformer attention computes softmax(QKᵀ / √d) V. The naive implementation materializes the full N × N score matrix S = QKᵀ (and the probabilities P = softmax(S)) in HBM, costing O(N²) memory and O(N²) extra HBM reads/writes. For long sequences this makes attention memory-bandwidth bound, not compute bound.
FlashAttention is an IO-aware, exact (no approximation) reformulation. It tiles Q, K, and V into blocks, streams them from HBM into on-chip SRAM, and computes the output one K/V block at a time using online softmax: it keeps a running max and running sum per row and rescales the partial output whenever the max changes, so the final result matches standard softmax attention up to floating-point rounding, without ever writing S or P to HBM. Memory becomes linear in sequence length, and HBM traffic drops from Θ(Nd + N²) to O(N²d²M⁻¹), where d is the head dimension and M is the SRAM size (FlashAttention paper). The book frames this as a canonical case of "mechanical sympathy": FlashAttention "tiles GPU computations, which minimizes the number of reads and writes issued to the GPU's memory" and yields a "2×–4× speedup in training and inference for long sequences while also reducing the overall memory footprint" (Fregly, Ch. 1).
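To make the asymptotic comparison concrete, the sketch below plugs illustrative numbers into the two access-count expressions. The function names and the values of N, d, and M are assumptions chosen for illustration (M is counted in elements, not bytes), and constant factors are dropped, so the ratio is an order-of-magnitude estimate rather than a measured speedup.

```python
def standard_hbm_accesses(n: int, d: int) -> float:
    """Asymptotic HBM accesses for standard attention: N*d + N^2 (constants dropped)."""
    return n * d + n ** 2


def flash_hbm_accesses(n: int, d: int, m: int) -> float:
    """Asymptotic HBM accesses for FlashAttention: N^2 * d^2 / M (constants dropped)."""
    return n ** 2 * d ** 2 / m


# Illustrative assumptions, not measurements: 4k tokens, head dim 64,
# and an SRAM budget of 100_000 elements.
N, D, M = 4096, 64, 100_000

std = standard_hbm_accesses(N, D)
fla = flash_hbm_accesses(N, D, M)
print(f"standard ~{std:.3g}, flash ~{fla:.3g}, ratio ~{std / fla:.1f}x")
```

The advantage depends on d² being much smaller than M; as the head dimension grows relative to SRAM, the ratio shrinks.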
Multi-Head Latent Attention (MLA), introduced by DeepSeek in DeepSeek-V2 and used in DeepSeek-V3/R1, attacks a different bottleneck: the size of the KV cache during decode. Instead of caching full per-head keys and values, MLA caches a single low-rank latent vector per token and reconstructs head-specific K/V on the fly via learned up-projection matrices (DeepSeek-V2, arXiv:2405.04434). The book describes MLA as restructuring "the attention computations to better utilize NVIDIA's memory hierarchy and dedicated GPU Tensor Cores," achieving "higher throughput at a fraction of the cost, surpassing even FlashAttention's performance on those same H800 systems" (Fregly, Ch. 1).
FlashAttention and MLA are complementary: FlashAttention is an exact kernel-level rewrite of the same math (it changes how attention is computed); MLA is a model-architecture change (it changes what is stored and projected). The two coexist. MLA's compute can itself run through a tiled, FlashAttention-style kernel (DeepSeek's decode kernel is FlashMLA).
Why use it¶
- Attention used to dominate runtime. Replacing default attention with FlashAttention "reduces what used to be a major bottleneck (attention) down to a fraction of overall runtime" and "became the default in many libraries almost overnight" (Fregly, Ch. 1). PyTorch's scaled_dot_product_attention now dispatches to a FlashAttention backend automatically for eligible shapes/dtypes.
- Long context is gated by memory. Linear-memory attention is what makes 100k+ token contexts feasible without quadratic activation blow-up. See roofline / arithmetic intensity for why reducing HBM traffic moves a kernel from memory-bound toward compute-bound.
- KV cache dominates decode memory. The book works a concrete example: a 70B model (80 layers, 32 heads, head dim 128, hidden 4096) needs ~1.31 MB of FP16 KV per token, ~328 GB for a 250k-token context (Fregly, Ch. 18). MLA's low-rank compression cuts the cached bytes per token sharply, which directly raises decode batch size and goodput. See goodput for AI systems.
- Hardware-software codesign. FA3 exploits Hopper-specific units (WGMMA Tensor Cores, TMA, FP8); MLA was tuned to extract throughput from bandwidth-constrained H800 GPUs. Both are examples of the codesign theme throughout the book.
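The KV-cache figure in the third bullet follows from simple arithmetic, sketched below. The function name kv_bytes_per_token is hypothetical; the model dimensions are the ones quoted from the book, and MB/GB are taken as 10^6/10^9 bytes, which is the convention that reproduces the quoted numbers.

```python
def kv_bytes_per_token(layers: int, heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes of K plus V cached per token across all layers (plain MHA, no compression)."""
    return 2 * layers * heads * head_dim * bytes_per_elem  # leading 2 = one K and one V


per_token = kv_bytes_per_token(layers=80, heads=32, head_dim=128)  # FP16 = 2 bytes/elem
context_tokens = 250_000
total = per_token * context_tokens

print(f"{per_token / 1e6:.2f} MB per token, {total / 1e9:.0f} GB for {context_tokens:,} tokens")
```

Every factor is linear, so halving any one of layers, heads, head dimension, or bytes per element halves the cache.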
When to use it (and when not)¶
Use FlashAttention when:
- Sequences are long (hundreds to hundreds of thousands of tokens) and attention is a measurable fraction of step time (verify with Nsight profiling).
- You want exact attention (FlashAttention is exact; it is not an approximation like sparse/linear attention).
- Head dimension and dtype are supported by the installed backend (FA2/FA3 support specific head dims and fp16/bf16/fp8).
Skip or deprioritize FlashAttention when:
- Sequences are very short: the N² term is small and the tiling overhead and launch cost may not pay off; the fused SDPA math/efficient backends may already suffice.
- Your bottleneck is elsewhere (MLP/GEMM, communication, data loading). Profile first.
Use MLA when:
- You control the model architecture (MLA is trained in, not a drop-in kernel swap for an existing MHA/GQA model) and decode is KV-cache / memory-bandwidth bound (long contexts, large batches, high concurrency).
MLA does not help if:
- You cannot retrain/convert the model. For an existing pretrained MHA model, prefer grouped-query attention, MQA, or FP8/INT8 KV-cache quantization to shrink the cache without architectural change.
Architecture¶
Two independent levers on the same attention. FlashAttention keeps the math exact (up to floating-point rounding) but changes the dataflow: nothing of size N × N ever touches HBM, so long-context attention becomes linear in memory and far cheaper in bandwidth. MLA keeps the dataflow familiar but changes what is stored: one low-rank latent per token instead of full per-head K/V, which shrinks the KV cache that bounds decode. They stack, and MLA's decode math itself runs through a FlashAttention-style tiled kernel (FlashMLA).
flowchart LR
Q["Query block Q (HBM)"] --> SRAM["On-chip SRAM tile"]
K["Key blocks K_j (HBM, streamed)"] --> SRAM
V["Value blocks V_j (HBM, streamed)"] --> SRAM
SRAM --> Score["S_j = Q @ K_j.T * scale (in SRAM)"]
Score --> Online["Online softmax: update running max m & sum l"]
Online --> Acc["Rescale & accumulate O = O * alpha + p @ V_j"]
Acc -->|"loop next K/V block"| SRAM
Acc --> Out["Normalize O = O / l, write output to HBM"]
Note["No full N x N matrix S or P ever materialized in HBM"] -.-> Online
MLA["MLA: cache one low-rank KV latent per token (complementary)"] -.-> K
The SRAM/HBM split and the block streaming build directly on shared-memory tiling and the GPU memory hierarchy; the FP8/Tensor-Core paths in FA3 build on tensor cores and mixed precision.
FlashAttention versions¶
| Version | Target GPU | Key techniques | Reported result |
|---|---|---|---|
| FlashAttention (v1) | Ampere+ | Tiling, online softmax, recompute in backward | 2–4× wall-clock speedup vs standard attention (paper) |
| FlashAttention-2 | Ampere/Hopper | Better work partitioning, fewer non-matmul FLOPs, parallelize over seqlen | ~2× over FA1; ~35% of H100 peak FP16 FLOPs (FA2, arXiv:2307.08691; FA3 blog) |
| FlashAttention-3 | Hopper (H100) | WGMMA Tensor Cores, TMA async copy, warp-specialization, FP8 | FP16 up to ~740 TFLOPS (~75% of H100 peak), ~1.5–2.0× over FA2; FP8 close to ~1.2 PFLOPS (FA3 blog) |
The 2–4× figure in the table aligns with the book's stated "2×–4× speedup" for FlashAttention on long sequences (Fregly, Ch. 1). FA2/FA3 numbers are from official sources, not hardware-tested here.
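As a rough sanity check on the percentages in the table, the sketch below divides the reported throughput by the H100 SXM dense FP16/BF16 Tensor Core peak. The ~989 TFLOPS peak is an assumption taken from NVIDIA's published H100 SXM specification, not from the sources cited above, and the helper name utilization is hypothetical.

```python
H100_SXM_FP16_DENSE_TFLOPS = 989.0  # assumption: NVIDIA's published dense (non-sparse) peak


def utilization(achieved_tflops: float, peak_tflops: float = H100_SXM_FP16_DENSE_TFLOPS) -> float:
    """Fraction of theoretical peak reached by a kernel."""
    return achieved_tflops / peak_tflops


fa3_util = utilization(740.0)                     # FA3 FP16 figure from the table
fa2_tflops = 0.35 * H100_SXM_FP16_DENSE_TFLOPS    # FA2's ~35% expressed in TFLOPS
print(f"FA3 ~{fa3_util:.0%} of peak; FA2 ~{fa2_tflops:.0f} TFLOPS")
```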
MLA: low-rank KV compression¶
MLA replaces the per-head K/V cache with one shared latent vector per token. For input h_t, a down-projection produces a compressed latent c_t^{KV} (cached), and up-projection matrices reconstruct per-head keys and values; queries are similarly compressed. To stay compatible with rotary position embeddings, MLA carries a small decoupled RoPE key alongside the compressed latent (DeepSeek-V2, arXiv:2405.04434).
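The projection structure described above can be sketched in numpy. This is a toy illustration under stated assumptions, not DeepSeek's implementation: the dimensions are small and arbitrary, the weights are random stand-ins for learned matrices (named after the paper's W^DKV, W^UK, W^UV), and the decoupled RoPE key and the query compression are omitted. It also checks the property the paper relies on at inference time: the key up-projection can be folded into the query, so scores can be computed directly against the cached latent.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_h, d_h, d_c = 64, 4, 8, 16    # toy sizes (assumptions), with d_c < 2 * n_h * d_h
T = 5                                    # number of cached tokens

# Random stand-ins for learned weights.
W_dkv = rng.standard_normal((d_c, d_model))   # down-projection to the shared latent
W_uk = rng.standard_normal((n_h, d_h, d_c))   # per-head key up-projection
W_uv = rng.standard_normal((n_h, d_h, d_c))   # per-head value up-projection

h = rng.standard_normal((T, d_model))    # hidden states of the cached tokens
c_kv = h @ W_dkv.T                       # [T, d_c]: the only tensor that is cached

# Reconstruct per-head keys and values on the fly from the latent.
k = np.einsum("hdc,tc->htd", W_uk, c_kv)  # [n_h, T, d_h]
v = np.einsum("hdc,tc->htd", W_uv, c_kv)  # [n_h, T, d_h]

# Attention scores for one query per head, computed two ways.
q = rng.standard_normal((n_h, d_h))
scores_explicit = np.einsum("hd,htd->ht", q, k)             # against reconstructed keys
q_absorbed = np.einsum("hd,hdc->hc", q, W_uk)               # fold W_uk into the query
scores_absorbed = np.einsum("hc,tc->ht", q_absorbed, c_kv)  # against the latent directly

assert np.allclose(scores_explicit, scores_absorbed)
print("cached elems/token:", c_kv.shape[1], "vs MHA:", 2 * n_h * d_h)
```

Because the two score computations agree, a decode kernel never has to materialize per-head keys; it can read the latent once and share it across heads.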
KV cache stored per token, in number of elements (DeepSeek-V2 paper, Table 1):
| Attention | KV cache per token |
|---|---|
| MHA | 2 · n_h · d_h |
| GQA (n_g groups) | 2 · n_g · d_h |
| MQA | 2 · d_h |
| MLA | d_c + d_h^R (= 9/2 · d_h in DeepSeek-V2) |
DeepSeek-V2 sets the compressed latent dimension d_c = 4 · d_h and the decoupled RoPE key dimension d_h^R = d_h / 2, so the cache is 4.5 · d_h elements per token, far below MHA's 2 · n_h · d_h for the model's many heads, while the paper reports MLA matching or beating MHA quality (DeepSeek-V2, arXiv:2405.04434). The book notes the headline effect: MLA "exploit[s] the constrained H800 GPUs' architecture, achieving higher throughput at a fraction of the cost" (Fregly, Ch. 1).
The online softmax core (how the exactness holds)¶
Standard softmax over a row needs the global max and sum before normalizing. FlashAttention processes K/V in blocks and keeps a running max m and running sum l, rescaling the accumulated output O whenever a new block raises the max. The result equals dense softmax attention exactly, and the full score matrix S and probabilities P are never written to HBM. The block below runs the loop over K/V blocks j, checks it against a dense reference, checks that any tiling gives the identical answer, and checks the max-subtraction stability that keeps huge logits from overflowing (numpy only, runnable):
import numpy as np

def flash_attention_online(q, K, V, scale, block=4):
    """Online-softmax attention over K/V blocks (the FlashAttention core loop).
    Matches dense softmax attention up to floating-point rounding.
    Never materializes the full score row."""
    n = K.shape[0]
    m, l = -np.inf, 0.0                             # running max, running denominator
    O = np.zeros(V.shape[1], dtype=np.float64)      # running output accumulator
    for j in range(0, n, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S_j = (q @ Kj.T) * scale                    # block scores, computed "in SRAM"
        m_new = max(m, float(S_j.max()))
        p = np.exp(S_j - m_new)                     # rescale current block to new max
        alpha = np.exp(m - m_new)                   # correction factor for prior state
        l = l * alpha + float(p.sum())
        O = O * alpha + p @ Vj                      # rescale prior output, add this block
        m = m_new
    return O / l                                    # final normalization

def dense_attention(q, K, V, scale):
    """Reference: standard softmax attention, full score row materialized."""
    s = (q @ K.T) * scale
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

rng = np.random.default_rng(0)
n, d = 37, 16                                       # non-multiple of block on purpose
q = rng.standard_normal(d)
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))
scale = 1.0 / np.sqrt(d)

# 1. Happy path: online == dense reference (exactness, not approximation).
out = flash_attention_online(q, K, V, scale, block=4)
ref = dense_attention(q, K, V, scale)
assert np.allclose(out, ref, atol=1e-12), f"online != dense: max|d|={np.abs(out-ref).max()}"

# 2. Block-size invariance: any tiling gives the same result within tolerance.
for b in (1, 3, 8, n, n + 5):
    assert np.allclose(flash_attention_online(q, K, V, scale, b), ref, atol=1e-12), f"block {b} diverged"

# 3. Adversarial overflow: one huge logit overflows a naive exp(S), but online softmax
#    (max-subtraction) stays finite and correct. This is the numerical-stability guarantee.
big = 800.0                                         # exp(800) overflows float64 (limit is near exp(709.78))
Kbig = K.copy()
Kbig[0] = q * big / (scale * (q @ q))               # align key 0 with q so its scaled score equals `big`
with np.errstate(over="ignore"):                    # the overflow is the point we assert on
    naive = np.exp((q @ Kbig.T) * scale)            # naive path overflows to +inf
assert not np.isfinite(naive).all(), "expected naive exp to overflow for the adversarial case"
out_big = flash_attention_online(q, Kbig, V, scale, block=5)
assert np.isfinite(out_big).all(), "online softmax must stay finite under overflow"
ref_big = dense_attention(q, Kbig, V, scale)        # dense with its own max-subtraction
assert np.allclose(out_big, ref_big, atol=1e-10), "online != dense on adversarial logits"
print("online softmax: exactness, block-invariance, and overflow-stability all PASS")
Running the block under python3 prints: online softmax: exactness, block-invariance, and overflow-stability all PASS. The overflow case is the load-bearing one: it shows that the running-max subtraction, not luck, is what keeps FlashAttention finite on peaked logits.
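Written as recurrences, the loop in flash_attention_online can be summarized as follows. This is a sketch in the notation of the code (m, l, O, block scores S_j) for a single query row; the symbol α_j for the correction factor mirrors the alpha variable, and J denotes the last block.

```latex
\begin{aligned}
m_j &= \max\bigl(m_{j-1},\ \max_i S_{j,i}\bigr), \qquad \alpha_j = e^{\,m_{j-1} - m_j},\\
\ell_j &= \alpha_j\,\ell_{j-1} + \sum_i e^{\,S_{j,i} - m_j},\\
O_j &= \alpha_j\,O_{j-1} + \sum_i e^{\,S_{j,i} - m_j}\,V_{j,i},\\
\text{output} &= O_J / \ell_J, \qquad m_0 = -\infty,\ \ell_0 = 0,\ O_0 = 0.
\end{aligned}
```

By induction, ℓ_j is the sum of e^{s − m_j} over every score s seen so far, and O_j is the same sum weighted by the corresponding value rows. The common factor e^{−m_J} cancels in O_J / ℓ_J, which is why the result agrees with dense softmax for any block size.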
How to use it¶
The most portable way to get FlashAttention is torch.nn.functional.scaled_dot_product_attention, which auto-dispatches to a Flash backend when shapes/dtypes qualify. To force the backend, use the sdpa_kernel context manager (PyTorch docs). Reference template (needs torch + CUDA; not run here):
# Reference template: requires torch + a CUDA GPU (not installed in this doc's test env).
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
# q, k, v: [batch, heads, seq, head_dim], fp16/bf16 on CUDA
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
Valid SDPBackend members are FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION, and MATH. You can pass a priority-ordered list and let PyTorch fall back:
# Reference template (torch): try cuDNN attention first, then FlashAttention, in that order.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION], set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
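On exit, the context manager restores the previous backend flags (by default, all backends enabled). Confirm the Flash kernel actually ran with Nsight Systems: with a single pinned backend PyTorch raises an error when a constraint (head dim, dtype, masking) is unmet, but with a priority list or the default dispatcher a different backend, including the math kernel, can be selected silently.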
All four backends compute the same scaled dot-product attention math; they differ only in dataflow and precision, so the value contract (softmax-normalized weights, causal masking that never leaks the future) is identical across them. The numpy block below validates that contract, including the causal edge case where query 0 sees exactly one key and the adversarial check that corrupting future values leaves earlier rows untouched (runnable):
import numpy as np

def sdpa(q, k, v, scale, causal=False):
    """Scaled dot-product attention, the math every SDPA / Flash backend computes.
    q,k,v: [heads, seq, head_dim]. Returns [heads, seq, head_dim]."""
    scores = np.einsum("hqd,hkd->hqk", q, k) * scale
    if causal:                                         # position i attends to <= i only
        seq = scores.shape[-1]
        mask = np.triu(np.ones((seq, seq), bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    scores = scores - scores.max(-1, keepdims=True)    # stable softmax
    p = np.exp(scores)
    p = p / p.sum(-1, keepdims=True)
    return np.einsum("hqk,hkd->hqd", p, v)

rng = np.random.default_rng(1)
H, S, D = 3, 6, 8
q = rng.standard_normal((H, S, D))
k = rng.standard_normal((H, S, D))
v = rng.standard_normal((H, S, D))
scale = 1.0 / np.sqrt(D)

# 1. Rows are probabilities: each attention row sums to 1.
out = sdpa(q, k, v, scale)
raw = np.einsum("hqd,hkd->hqk", q, k) * scale
p_rows = np.exp(raw - raw.max(-1, keepdims=True))
p_rows = p_rows / p_rows.sum(-1, keepdims=True)
assert np.allclose(p_rows.sum(-1), 1.0), "attention weights must sum to 1 per row"
assert out.shape == (H, S, D)

# 2. Causal edge case: query 0 attends ONLY to key 0, so its output must equal
#    v[:, 0] exactly (single unmasked key => weight 1.0).
out_c = sdpa(q, k, v, scale, causal=True)
assert np.allclose(out_c[:, 0, :], v[:, 0, :], atol=1e-12), "row 0 causal output must equal v[0]"

# 3. Adversarial no-future-leak: corrupt every value after position 0. Row 0 must be
#    unchanged (the mask blocks its future); the last row must change (it attends to
#    the corrupted positions 1..S-1).
v_corrupt = v.copy()
v_corrupt[:, 1:, :] += 1000.0
out_c2 = sdpa(q, k, v_corrupt, scale, causal=True)
assert np.allclose(out_c2[:, 0, :], out_c[:, 0, :], atol=1e-12), "causal mask leaked future info into row 0"
assert not np.allclose(out_c2[:, -1, :], out_c[:, -1, :]), "last row should reflect the corrupted values it attends to"
print("sdpa math: softmax-normalization, causal row-0 identity, and no-future-leak all PASS")
Running the block under python3 prints: sdpa math: softmax-normalization, causal row-0 identity, and no-future-leak all PASS.
How to integrate it¶
FlashAttention rarely enters your code as a raw kernel. It arrives through one of three integration surfaces:
- PyTorch SDPA (above): the default path for training and custom models. Eligible shapes/dtypes dispatch to the Flash backend automatically; sdpa_kernel pins it explicitly.
- Serving engines. vLLM and SGLang select a Flash/PagedAttention kernel internally per model and hardware. You do not call it; you pick the engine and confirm the kernel via profiling. See inference serving and serving OSS models.
- MLA models are trained in, not flagged on. MLA is a model-architecture decision, not a runtime switch. To use it, serve a model designed with MLA (DeepSeek-V2/V3/R1). Both vLLM and SGLang provide first-class support for DeepSeek MLA models (Fregly, Ch. 18).
The cost that makes MLA worth integrating is KV-cache bytes per token. The numpy block below reproduces DeepSeek-V2 Table 1, asserts the paper's 4.5 · d_h identity, quantifies the reduction versus MHA at the model's real 128-head configuration, and pins the crossover boundary: the fixed MLA latent only wins once heads are plentiful, so the "far below MHA" claim is stated with its precondition (runnable):
import numpy as np

def kv_cache_elems(kind, n_h, d_h, n_g=1, d_c=None, d_hR=None):
    """KV-cache elements per token per layer (DeepSeek-V2 Table 1)."""
    if kind == "MHA": return 2 * n_h * d_h
    if kind == "GQA": return 2 * n_g * d_h
    if kind == "MQA": return 2 * d_h
    if kind == "MLA": return d_c + d_hR            # one latent + decoupled RoPE key
    raise ValueError(kind)

d_h = 128                                          # head dim
n_h = 128                                          # DeepSeek-V2: 128 attention heads
d_c = 4 * d_h                                      # compressed latent dim = 4 * d_h
d_hR = d_h // 2                                    # decoupled RoPE key dim = d_h / 2
mla = kv_cache_elems("MLA", n_h, d_h, d_c=d_c, d_hR=d_hR)
mha = kv_cache_elems("MHA", n_h, d_h)

# 1. Paper identity: MLA cache == 4.5 * d_h elements per token.
assert mla == d_c + d_hR == 4 * d_h + d_h // 2, "MLA cache formula mismatch"
assert mla == int(4.5 * d_h), f"MLA should be 4.5*d_h, got {mla}"

# 2. Reduction vs MHA at the real head count (the whole point of MLA).
assert mla < mha, "MLA must be smaller than MHA"
ratio = mha / mla
assert abs(ratio - (2 * n_h) / 4.5) < 1e-9
assert ratio > 50, f"expected >50x MHA/MLA reduction at n_h={n_h}, got {ratio:.1f}"

# 3. Adversarial / boundary: with a single head MLA's fixed 4.5*d_h latent is LARGER
#    than MHA's 2*d_h. MLA only wins when heads are many; assert the crossover so the
#    "far below MHA" claim carries its precondition instead of reading as universal.
mha_1head = kv_cache_elems("MHA", 1, d_h)
assert mla > mha_1head, "at 1 head MLA latent should exceed MHA (crossover boundary)"
crossover = mla / (2 * d_h)                        # heads at which MHA overtakes MLA
assert 2 <= crossover <= 3, f"crossover near ~2.25 heads, got {crossover}"

# Concrete FP16 byte figures for the page (2 bytes/elem), per token per layer.
mha_bytes, mla_bytes = mha * 2, mla * 2
assert (mha_bytes, mla_bytes) == (65536, 1152)
print(f"MLA cache = {mla} elems (4.5*d_h); MHA = {mha}; reduction {ratio:.1f}x; "
      f"crossover ~{crossover:.2f} heads -- all PASS")
Running the block under python3 prints: MLA cache = 576 elems (4.5*d_h); MHA = 32768; reduction 56.9x; crossover ~2.25 heads -- all PASS. At d_h = 128 and 128 heads that is 1152 bytes/token/layer for MLA against 65536 for MHA in FP16, a ~57× cut, and the crossover assertion documents that MLA is the wrong choice for few-head models.
How to run it in production¶
Prefer serving through vLLM or SGLang, which integrate FlashAttention, FlashMLA, and PagedAttention rather than calling any kernel directly. For DeepSeek MLA decode specifically:
DeepSeek open-sourced FlashMLA (CUDA C++) during Open-Source Week, February 2025 (Fregly, Ch. 1). It is to decode what FlashAttention is to prefill: an IO-aware fused kernel for the single-token decode step. The book states FlashMLA "increases arithmetic intensity by fusing multiple attention operations into one ... process[ing] multiple heads and multiple time steps in one fused kernel launch," which keeps the math units busy despite small decode batch sizes, and it "pages KV cache and allocates the cache in fixed-size blocks (pages) so that contiguous memory accesses can happen for active sequences," reducing cache misses and DRAM traffic (Fregly, Ch. 18).
# Reference template: DeepSeek's open-source FlashMLA decode kernel (Hopper).
git clone https://github.com/deepseek-ai/FlashMLA
cd FlashMLA
python setup.py install
Build FlashMLA directly only for kernel-level experimentation; for the engine-level setup, see serving OSS models and inference serving.
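The paged KV-cache idea quoted from the book (fixed-size pages handed out to active sequences) can be illustrated with a toy block table. This is a minimal sketch under stated assumptions: the class PagedKVCache and its methods are hypothetical names for illustration, it tracks only slot bookkeeping (no tensors), and it is not the data structure used by FlashMLA, vLLM, or SGLang.

```python
class PagedKVCache:
    """Toy block table: maps each sequence's logical token index to a physical slot."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))   # physical pages not yet assigned
        self.block_tables = {}                     # sequence id -> list of physical page ids
        self.lengths = {}                          # sequence id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one new token; returns (physical_page, offset_in_page)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.page_size == 0:           # current page is full (or none allocated yet)
            if not self.free_pages:
                raise MemoryError("KV cache exhausted")
            table.append(self.free_pages.pop(0))
        self.lengths[seq_id] = length + 1
        return table[-1], length % self.page_size

    def release(self, seq_id):
        """Return a finished sequence's pages to the free list."""
        self.free_pages.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_pages=4, page_size=2)
slots_a = [cache.append_token("a") for _ in range(3)]  # spans two pages
slots_b = [cache.append_token("b") for _ in range(2)]  # fills exactly one page
print(slots_a, slots_b, "free pages:", cache.free_pages)
```

Allocating in whole pages keeps each sequence's tokens contiguous within a page and wastes at most one partially filled page per sequence, which is the property the quoted passage attributes to paging.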
How to scale it¶
- Longer context. FlashAttention's linear memory is the enabler: the N² score matrix never lands in HBM, so context grows without quadratic activation blow-up. Bandwidth, not capacity, becomes the limit (roofline / arithmetic intensity).
- Higher decode concurrency. MLA's 4.5 · d_h cache per token (versus MHA's 2 · n_h · d_h) is what raises the batch size a fixed HBM budget can hold, and larger decode batches raise arithmetic intensity and goodput (goodput for AI systems).
- Fuse harder at decode. Stanford's ThunderMLA builds on FlashMLA as a fused decode "megakernel" and reports "20–35% faster decode throughput compared to FlashMLA across different workloads" (Fregly, Ch. 18).
- Exploit the newest silicon. FA3 exploits Hopper WGMMA Tensor Cores, TMA async copy, warp-specialization, and FP8 to reach ~740 TFLOPS FP16 (~75% of H100 peak) and close to ~1.2 PFLOPS in FP8 (FA3 blog); scaling throughput per GPU is a matter of matching kernel version to hardware generation.
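The concurrency point in the list above can be put in numbers with a small capacity estimate. The sketch assumes a hypothetical 40 GB of HBM left for KV cache and an illustrative 60-layer model in FP16; only the head count, head dimension, and the 4.5 · d_h MLA size come from the text, and the helper name max_cached_tokens is hypothetical.

```python
def max_cached_tokens(budget_bytes, layers, elems_per_token_per_layer, bytes_per_elem=2):
    """Total tokens (summed over all sequences) whose KV cache fits in the budget."""
    return int(budget_bytes // (layers * elems_per_token_per_layer * bytes_per_elem))


d_h, n_h = 128, 128        # head dim and head count used earlier on this page
budget = 40e9              # assumption: 40 GB of HBM available for KV cache
layers = 60                # assumption: illustrative layer count

mha_tokens = max_cached_tokens(budget, layers, 2 * n_h * d_h)
mla_tokens = max_cached_tokens(budget, layers, 4 * d_h + d_h // 2)
print(f"MHA: {mha_tokens:,} tokens; MLA: {mla_tokens:,} tokens; "
      f"{mla_tokens / mha_tokens:.1f}x more tokens in the same budget")
```

The token count is what bounds batch size times context length, so the same ratio applies to how many concurrent sequences of a given length fit.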
How to maintain it¶
- Verify the backend. Profile with Nsight; confirm the Flash/FlashMLA kernel ran and that attention is no longer the dominant kernel. A silent fallback to the math SDPA backend erases the win.
- Match dtype/head-dim constraints. FA2/FA3 and the Flash SDPA backend support specific head dimensions and dtypes; an unsupported config falls back. FA3's FP8 path targets Hopper specifically.
- Do not expect MLA to retrofit. For existing MHA checkpoints you cannot retrain, use GQA/MQA or FP8/INT8 KV quantization to shrink the cache instead. Both are model/serving-level decisions, validated against accuracy on your max context.
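For the KV-quantization fallback mentioned in the last bullet, "validated against accuracy" can be as simple as comparing attention outputs with and without a quantized cache. The sketch below is a minimal illustration on random data: symmetric per-token INT8 quantization is one assumed scheme among several, the function names are hypothetical, and production engines use their own formats and calibration.

```python
import numpy as np


def quantize_int8(x):
    """Symmetric per-token (per-row) INT8 quantization; returns (codes, scales)."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)           # avoid dividing by zero on all-zero rows
    codes = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return codes, scale


def dequantize(codes, scale):
    """Map INT8 codes back to floats using the per-row scales."""
    return codes.astype(np.float64) * scale


def attention(q, K, V):
    """Single-query softmax attention with max-subtraction for stability."""
    s = (q @ K.T) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V


rng = np.random.default_rng(0)
n, d = 256, 64
q = rng.standard_normal(d)
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

K_q = dequantize(*quantize_int8(K))                    # what the kernel would read back
V_q = dequantize(*quantize_int8(V))
ref, approx = attention(q, K, V), attention(q, K_q, V_q)
rel_err = np.linalg.norm(approx - ref) / np.linalg.norm(ref)
print(f"relative output error with INT8 K/V: {rel_err:.2e}")
```

On a real model the same comparison would be run on task metrics at the maximum context length, since random Gaussian data understates the effect of outlier channels.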
Failure modes¶
- Silent math fallback. The most common failure is invisible: scaled_dot_product_attention (or a serving engine) quietly drops to the MATH backend when a constraint is unmet, so you keep the quadratic HBM traffic while believing FlashAttention ran. Confirm the actual kernel in Nsight Systems; do not trust the API call alone.
- Unsupported head dim / dtype. FA2/FA3 accept specific head dimensions and only fp16/bf16 (and FP8 on Hopper for FA3). An unsupported head dim, an fp32 tensor, or an exotic mask forces the fallback above. Match the config to the installed backend before profiling.
- Wrong hardware for the kernel version. FA3 and its FP8 path are Hopper-specific (WGMMA, TMA). Running the FA3 path on Ampere silently loses the async-copy and Tensor-Core gains, or fails to build.
- Numerical overflow if online softmax is reimplemented naively. A hand-rolled attention kernel that computes exp(S) before subtracting the running max overflows on peaked logits (the adversarial case in the online-softmax block above overflows a naive exp). Always carry the running max and the alpha rescale.
- Expecting MLA as a runtime flag. MLA is trained into the architecture. Pointing a serving engine at an MHA/GQA checkpoint will not "enable MLA"; there is no latent projection to cache. Use GQA/MQA or KV-cache quantization for models you cannot retrain.
- MLA on few-head models. MLA's cache is a fixed 4.5 · d_h regardless of head count, so below the ~2.25-head crossover (validated above) it is larger than MHA. MLA pays off only for the many-head models it was designed for.
- FlashMLA outside its niche. FlashMLA is a decode kernel for MLA models on Hopper. It is not a general prefill attention kernel and not a substitute for FlashAttention in training; prefer vLLM/SGLang integration over calling it directly.
References¶
- Chris Fregly, AI Systems Performance Engineering (O'Reilly). Ch. 1 (FlashAttention 2×–4×, MLA, FlashMLA open-source); Ch. 18 (FlashMLA decode kernel, ThunderMLA 20–35%, KV cache sizing, PagedAttention).
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness: https://openreview.net/pdf?id=H4DqfPSibmx
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, arXiv:2307.08691: https://arxiv.org/pdf/2307.08691
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Tri Dao, 2024): https://tridao.me/blog/2024/flash3/
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model, arXiv:2405.04434 (MLA, low-rank KV joint compression, Table 1): https://arxiv.org/pdf/2405.04434
- PyTorch torch.nn.attention.sdpa_kernel / SDPBackend: https://docs.pytorch.org/docs/2.12/generated/torch.nn.attention.sdpa_kernel.html
- DeepSeek FlashMLA (open source): https://github.com/deepseek-ai/FlashMLA
Related: Roofline Model and Arithmetic Intensity · Shared Memory, Bank Conflicts, and Tiling · GPU Memory Hierarchy · Tensor Cores and Mixed Precision · Inference Serving and Optimization · Serving OSS Models · Nsight Profiling Workflow · Goodput for AI Systems · Glossary