torch.compile: graph capture, backends, and recompiles¶
Scope: how torch.compile turns eager PyTorch into fused kernels, covering TorchDynamo bytecode capture, AOTAutograd, the TorchInductor (Triton) backend, the four compile modes, guard-driven recompiles and graph breaks, dynamic shapes, and how to diagnose recompiles with TORCH_LOGS, plus when compilation pays off and when it does not.
What it is¶
torch.compile(model) is a JIT compiler for PyTorch. It wraps a callable and, on first execution, traces the Python code into a graph, optimizes and fuses it, and generates target-specific kernels (Triton on NVIDIA GPUs). The result runs in place of the original eager code. Wrapping is the only required change, and you do not rewrite the model. (Fregly, Ch. 13)
The stack has three layers, each owning a distinct job (Fregly, Ch. 13):
- TorchDynamo is a CPython frame evaluation hook. It symbolically traces Python bytecode into an FX graph, installing guards (assumptions about tensor shape, dtype, device, stride, and Python globals). When it hits something it cannot trace, it inserts a graph break: it compiles the traceable region, runs the untraced part in eager Python, then resumes tracing. (Dynamo Overview)
- AOTAutograd traces the backward pass ahead of time, so both forward and backward are captured and fusible, not just forward.
- TorchInductor is the default backend. It lowers the captured graph to fused kernels, emitting Triton for GPU and C++/OpenMP for CPU. It performs pointwise/vertical/horizontal fusion plus prologue (e.g. bias-add) and epilogue (activation, dropout, residual) fusion around GEMMs, then autotunes kernel variants over tile sizes and access patterns and picks the fastest for the hardware. (TorchInductor / Inductor backend; Fregly, Ch. 13)
The signature options that matter here, shown with their default values, come from the official API. This is a reference template (needs torch; not run here) (torch.compile):
torch.compile(
    model,
    mode=None,            # None == "default"; also "reduce-overhead",
                          # "max-autotune", "max-autotune-no-cudagraphs"
    fullgraph=False,      # True => error on any graph break instead of falling back
    dynamic=None,         # None => auto-detect; True => assume dynamic shapes;
                          # False => assume static, specialize on first shape
    backend="inductor",
)
fullgraph=True is a correctness/perf assertion: it forces a hard error on the first graph break rather than silently falling back to eager, which is how you guarantee a single captured graph. (torch.compile)
Why use it¶
Compilation removes two costs that eager mode pays every iteration: Python interpreter overhead and per-kernel launch latency. Fusing many small ops into fewer, larger kernels keeps intermediates in registers and shared memory instead of round-tripping through HBM, which also raises arithmetic intensity (see Kernel Fusion, Roofline Model and Arithmetic Intensity). (Fregly, Ch. 13)
That fusion claim is the page's central mechanism, and it is testable in isolation. The block below fuses a pointwise chain into one kernel, then asserts it equals a slow per-op reference, that an adversarial mis-fused variant does not, and that fusion cuts kernel launches and HBM traffic while raising arithmetic intensity. Validated with numpy 2.4.6 under python3:
# Fusion equivalence + memory-traffic / arithmetic-intensity model.
# Core claim (Fregly Ch. 13): fusing a chain of pointwise ops into ONE kernel
# keeps intermediates in registers instead of round-tripping through HBM, so it
# cuts kernel launches and HBM traffic and raises arithmetic intensity, WITHOUT
# changing the result. We assert equivalence to a slow per-op reference, an
# adversarial mis-fused variant that must NOT match, and the traffic drop / AI rise.
import numpy as np
rng = np.random.default_rng(0)
N = 100_000
x = rng.standard_normal(N).astype(np.float32)
a, b, s = np.float32(1.5), np.float32(0.25), np.float32(2.0)
class HBM:
    """Counts elements moved to/from HBM and kernel launches."""
    def __init__(self):
        self.elems = 0
        self.launches = 0
    def kernel(self, fn, *reads, out_len):
        for r in reads:
            self.elems += r.size          # read inputs from HBM
        y = fn()
        self.elems += out_len             # write output to HBM
        self.launches += 1
        return y
def unfused(x):
    m = HBM()
    t1 = m.kernel(lambda: x * a, x, out_len=N)               # each op is its own kernel
    t2 = m.kernel(lambda: t1 + b, t1, out_len=N)             # and spills to HBM
    t3 = m.kernel(lambda: np.maximum(t2, 0.0), t2, out_len=N)
    y = m.kernel(lambda: t3 * s, t3, out_len=N)
    return y, m
def fused(x):
    m = HBM()
    # one kernel: read x once, write y once; t1..t3 never leave registers
    y = m.kernel(lambda: np.maximum(x * a + b, 0.0) * s, x, out_len=N)
    return y, m
yu, mu = unfused(x)
yf, mf = fused(x)
# 1. Equivalence: identical ops in identical order => bit-identical result.
assert np.array_equal(yu, yf), "fused must equal unfused reference"
# 2. Adversarial mis-fusion (relu BEFORE bias) must NOT match on a sign-crossing
# input. x*a = -1.5 -> relu 0 -> +b = 0.25 -> *s = 0.5 (wrong); correct = relu(-1.25)*s = 0.
xc = np.array([-1.0], dtype=np.float32)
wrong = (np.maximum(xc * a, 0.0) + b) * s
right = np.maximum(xc * a + b, 0.0) * s
assert not np.allclose(wrong, right), "test must distinguish fusion order"
assert right[0] == 0.0 and wrong[0] == 0.5
# 3. Fusion cuts launches and HBM traffic and raises arithmetic intensity.
assert (mu.launches, mf.launches) == (4, 1)
assert mu.elems == 8 * N and mf.elems == 2 * N # 4*(r+w) vs (r x + w y)
flops = 4 * N # mul, add, max, mul per elem
ai_unfused = flops / (mu.elems * 4) # float32 = 4 bytes
ai_fused = flops / (mf.elems * 4)
assert ai_fused == 4 * ai_unfused # 0.5 vs 0.125 flops/byte
print(f"fusion: launches {mu.launches}->{mf.launches}, traffic "
f"{mu.elems//N}N->{mf.elems//N}N, AI {ai_unfused}->{ai_fused} flops/byte")
print("block1 OK: fusion equivalence + traffic/AI reduction, mis-fusion rejected")
The book's MoE example times one training iteration at roughly 248 ms eager vs about 173 ms compiled with max-autotune, about 30% faster. These numbers are illustrative of one model on one machine, not hardware verified here; the book states actual speedup varies with model structure, batch size, and dynamic shapes. (Fregly, Ch. 13)
The gain is structural, not universal:
- Sparse / many-small-kernel models (MoE) win big: hundreds of medium GEMMs plus token dispatch/combine and per-token activations fuse away, cutting launch count and Python overhead. (Fregly, Ch. 13)
- Dense models dominated by one massive GEMM see modest gains: a single large matmul already saturates Tensor Cores (Tensor Cores and Mixed Precision), leaving little to fuse or de-overhead. (Fregly, Ch. 13)
Compile time is paid once (seconds to minutes for large models) and amortized over a long run. TorchInductor caches compiled kernels so later runs skip the cost; torch.compiler.save_cache_artifacts() / load_cache_artifacts() persist Inductor outputs across runs or nodes (keep CUDA, PyTorch, and Triton versions compatible across nodes). (Fregly, Ch. 13)
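The amortization argument can be made concrete with a little arithmetic. The sketch below is an illustration, not a benchmark: it reuses the book's illustrative 248 ms and 173 ms iteration times, while the 120-second compile cost and the helper name break_even_iters are assumptions chosen for the example.

```python
# Break-even sketch: iterations needed before a one-time compile pays for itself.
# 248 ms / 173 ms are the book's illustrative MoE timings; the 120 s compile
# cost is an ASSUMED placeholder, not a measurement.
import math

def break_even_iters(eager_ms, compiled_ms, compile_s):
    """Iterations until cumulative compiled time drops below cumulative eager time."""
    saved_ms = eager_ms - compiled_ms
    if saved_ms <= 0:
        return math.inf                      # compiled is not faster: never recovered
    return math.ceil(compile_s * 1000.0 / saved_ms)

eager_ms, compiled_ms = 248.0, 173.0
time_saved = (eager_ms - compiled_ms) / eager_ms   # fraction of iteration time removed
speedup = eager_ms / compiled_ms                   # throughput ratio
n = break_even_iters(eager_ms, compiled_ms, compile_s=120.0)
print(f"time saved {time_saved:.1%}, speedup {speedup:.2f}x, break-even after {n} iters")
```

Under these assumptions the compile cost is recovered after a few thousand iterations at most, which is why a long training run absorbs it easily while a short script may not.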
When to use it (and when not)¶
Use torch.compile as the first-resort performance lever; it is low effort and usually "free." Reach for custom Triton/CUDA kernels only for hotspots Inductor handles poorly or cannot capture; even then you can register a custom op and let Inductor fuse the surrounding graph around it. (Fregly, Ch. 13)
Pick a mode by workload (torch.compile docs; Fregly, Ch. 13, Table 13-5):
| Mode | Compile time | Extra memory | What it does |
|---|---|---|---|
| default | low–medium | none | General fusion, basic autotuning; may use CUDA Graphs for stable segments. Best when compile time or memory is tight (large models). |
| reduce-overhead | medium | yes (workspace caching) | Uses CUDA Graphs to kill per-iteration launch overhead. Best for inference / small batches. Auto-skips graphs and falls back to eager if it detects dynamic shapes. |
| max-autotune | high (slow) | maybe (if graphs used) | Aggressive Triton autotuning over many tilings; enables CUDA Graphs on GPU. Compile once, run many. Can occasionally regress latency, so profile it. |
| max-autotune-no-cudagraphs | high | none | Same tuning as max-autotune but no CUDA Graph capture. For varying input shapes, memory-constrained cases, or debugging issues masked by graphs. |
Guidance from the book: start with default. Use reduce-overhead when Python overhead in a tight loop of small ops is the bottleneck; it removes nearly all launch overhead via CUDA Graphs, but only when per-iteration work is consistent (no dynamic shapes, no new allocations). For highly dynamic MoE token routing, prefer default or max-autotune-no-cudagraphs, then switch to max-autotune once shapes stabilize. (Fregly, Ch. 13)
Do not expect much when: the model is one big GEMM (already Tensor-Core bound), shapes change every iteration (recompile churn dominates), or the hot region is data-dependent control flow Dynamo cannot trace. CUDA Graph modes also allocate large static buffers, so avoid them under tight memory. (Fregly, Ch. 13)
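The mode guidance above can be condensed into a small decision helper. This is a sketch under stated assumptions: pick_mode and its four boolean traits are hypothetical names invented here, not a PyTorch API, and the rules only restate the guidance in this section. Profile the result on your own model before trusting it.

```python
# Hypothetical helper that encodes this section's mode guidance.
# It returns one of the four real torch.compile mode strings.
def pick_mode(*, dynamic_shapes, memory_tight, launch_bound, long_run):
    """Map coarse workload traits to a torch.compile mode string."""
    # CUDA Graphs need static shapes and pin static buffers, so rule them
    # out when shapes vary or memory is tight.
    cuda_graphs_ok = not dynamic_shapes and not memory_tight
    if launch_bound and cuda_graphs_ok:
        return "reduce-overhead"             # launch overhead is the bottleneck
    if long_run:                             # compile once, run many
        return "max-autotune" if cuda_graphs_ok else "max-autotune-no-cudagraphs"
    return "default"                         # the recommended starting point

print(pick_mode(dynamic_shapes=True, memory_tight=False,
                launch_bound=True, long_run=True))
```

The example call models highly dynamic MoE routing in a long run, for which the helper avoids every CUDA Graph mode.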
Architecture¶
The three layers form a pipeline. TorchDynamo hooks CPython frame evaluation and rewrites bytecode into an FX graph, attaching guards on shape, dtype, device, stride, and Python globals. A region Dynamo cannot trace becomes a graph break: the traced part compiles, the untraceable part runs in eager Python, and tracing resumes after it. The captured graph passes to AOTAutograd, which traces the backward pass ahead of time so forward and backward are one joint graph. TorchInductor lowers that joint graph, fuses it, autotunes, and emits Triton kernels on NVIDIA GPUs or C++/OpenMP kernels on CPU. On a later call, if any guard fails (a new shape, dtype, device, stride, or changed global), Dynamo recompiles that frame.
flowchart TB
Eager["Eager PyTorch model<br/>torch.compile(model)"]
Dynamo["TorchDynamo<br/>(CPython frame hook)"]
Capture["Bytecode -> FX graph<br/>(install guards: shape/dtype/device)"]
Break["Graph break:<br/>untraceable region runs eager, then resumes"]
AOT["AOTAutograd<br/>(joint forward + backward graph)"]
Inductor["TorchInductor backend<br/>(lower + fuse + autotune)"]
Triton["Triton kernels (NVIDIA GPU)"]
Cpp["C++ / OpenMP kernels (CPU)"]
Recompile["Guard fails on new shape/dtype:<br/>Dynamo recompiles the frame"]
Eager --> Dynamo
Dynamo --> Capture
Capture -->|"cannot trace"| Break
Break --> Capture
Capture --> AOT
AOT --> Inductor
Inductor --> Triton
Inductor --> Cpp
Capture -.->|"later call"| Recompile
Recompile --> Dynamo
The dashed recompile edge is the load-bearing invariant: a compiled frame is valid only while its guards hold, so shape or dtype drift silently sends execution back through Dynamo. The guard-and-recompile block under How to maintain it makes that behaviour executable, and the graph-break edge is modelled under How to integrate it into a model.
How to use it¶
Basic use and mode selection need no model changes beyond the wrap. Reference template (needs torch; not run here):
import torch
model = torch.compile(model) # mode="default"
model = torch.compile(model, mode="reduce-overhead") # CUDA Graph Trees
model = torch.compile(model, mode="max-autotune") # longest compile, best runtime
Always warm up before timing: the first call(s) pay tracing, autotuning, and JIT compilation. Time steady state with CUDA events around a measured loop, not the first iteration. Reference template (needs torch + CUDA; not run here) (Fregly, Ch. 13):
torch.cuda.synchronize()
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
for _ in range(warmup):                     # absorb compile + autotune
    out = compiled_model(x); out.loss.backward()
    optimizer.step(); optimizer.zero_grad(set_to_none=True)
torch.cuda.synchronize(); start.record()
for _ in range(iters):
    out = compiled_model(x); out.loss.backward()
    optimizer.step(); optimizer.zero_grad(set_to_none=True)
end.record(); torch.cuda.synchronize()
print(f"{start.elapsed_time(end)/iters:.3f} ms/iter")
Why the warm-up matters is a small, checkable claim: the one-time compile spike lands on the first iteration, so a mean that includes it overestimates steady-state cost. The block below models the iteration timeline and asserts that timing after warm-up recovers the true per-iteration cost, while unwarmed timing does not. Validated with numpy 2.4.6 under python3:
# Warm-up timing model: why the page insists you time STEADY STATE, not the first
# iteration. The first call to a compiled model pays tracing + autotuning + JIT
# compile (a one-time spike); measuring it inflates the per-iter estimate. We
# assert the steady-state mean recovers the true per-iter cost, that a naive mean
# over all iters overestimates it, and the adversarial case: timing with ZERO
# warm-up leaks the compile spike into the measured window (a wrong benchmark).
import numpy as np
def iter_times(n_warmup, n_iter, steady=1.0, compile_cost=50.0):
    """Per-iteration wall time. The compile spike lands on the very first call."""
    return np.array([steady + (compile_cost if i == 0 else 0.0)
                     for i in range(n_warmup + n_iter)])
t = iter_times(n_warmup=3, n_iter=10)
naive = t.mean() # includes the compile spike -> wrong (too high)
steady = t[3:].mean() # measured AFTER warm-up -> true per-iter cost
assert steady == 1.0
assert naive > steady # spike inflates the naive mean
assert np.isclose(t[3:].sum(), 10 * 1.0) # compile cost is one-time only
# Adversarial: no warm-up leaks the spike into the timed window.
leaked = iter_times(n_warmup=0, n_iter=10).mean()
assert leaked > steady
assert np.isclose(leaked, (50.0 + 10 * 1.0) / 10) # = 6.0, six-fold overestimate
# Warming up longer never changes the steady estimate (spike already absorbed).
for w in (1, 5, 50):
    tt = iter_times(n_warmup=w, n_iter=8)
    assert tt[w:].mean() == 1.0, w
print(f"block4 OK: steady={steady:.3f} (true), naive={naive:.3f}, "
f"unwarmed={leaked:.3f} -> warm-up required before timing")
How to integrate it into a model¶
Graph breaks. A break splits the model into multiple compiled graphs with eager glue between them, eroding the fusion benefit. torch._dynamo.explain reports the graph count, break count, op count, and the reason for each break. Reference template (needs torch; not run here) (torch.compiler FAQ; Fregly, Ch. 13):
explanation = torch._dynamo.explain(model)(x) # call the returned fn with real inputs
print(explanation) # graph/break counts + break reasons
To forbid breaks entirely (and get a hard error pointing at the offending op), compile with fullgraph=True. (torch.compile)
The break-versus-graph accounting and the fullgraph=True contract are both testable without a GPU. The block below models untraceable ops as breaks, then asserts that a program starting and ending traceable yields breaks + 1 compiled subgraphs, that the eager glue is numerically identical to a pure-eager run, and that fullgraph=True hard-errors on the first break. Validated with numpy 2.4.6 under python3:
# Graph-break model: what torch._dynamo.explain reports and what fullgraph=True
# enforces. An untraceable op is a GRAPH BREAK: Dynamo compiles the traceable
# region, runs the untraceable op in eager Python, then resumes tracing. So the
# number of compiled subgraphs = (breaks + 1) when the program starts and ends
# traceable, fullgraph=True must HARD ERROR on the first break, and the eager
# glue must be numerically identical to a pure-eager run (a break costs fusion,
# never correctness). We assert all of these, including the fullgraph adversary.
import numpy as np
def run_compiled(ops, x, fullgraph=False):
    """ops: list of (fn, traceable). Returns (result, n_graphs, n_breaks)."""
    graphs = breaks = 0
    in_graph = False
    y = x
    for fn, traceable in ops:
        if traceable:
            if not in_graph:
                graphs += 1               # open a new compiled subgraph
                in_graph = True
            y = fn(y)                     # (fused into the current subgraph)
        else:
            if fullgraph:
                raise RuntimeError("graph break encountered with fullgraph=True")
            breaks += 1
            in_graph = False              # close subgraph; run this op eager
            y = fn(y)
    return y, graphs, breaks
def eager(ops, x):
    y = x
    for fn, _ in ops:
        y = fn(y)
    return y
sq = (lambda v: v * v, True)
inc = (lambda v: v + 1.0, True)
brk = (lambda v: v - 3.0, False) # models an untraceable op (.item(), print, ...)
x = np.arange(5, dtype=np.float64)
# 1. Two breaks separating three traceable regions -> 3 graphs, 2 breaks.
ops = [sq, inc, brk, sq, brk, inc]
y, g, b = run_compiled(ops, x)
assert (g, b) == (3, 2), (g, b)
# 2. Eager glue is numerically identical to a pure-eager run.
assert np.array_equal(y, eager(ops, x))
# 3. No untraceable op -> a single fused graph, zero breaks.
_, g0, b0 = run_compiled([sq, inc, sq], x)
assert (g0, b0) == (1, 0)
# 4. Adversarial: fullgraph=True must raise on the first break.
raised = False
try:
    run_compiled(ops, x, fullgraph=True)
except RuntimeError as exc:
    raised = "fullgraph=True" in str(exc)
assert raised, "fullgraph=True must hard-error on a graph break"
# 5. Invariant graphs == breaks + 1 for a program that starts/ends traceable.
for probe in ([sq, brk, inc, brk, sq, brk, inc], [sq, brk, sq], [inc]):
    _, gg, bb = run_compiled(probe, x)
    assert gg == bb + 1, (probe, gg, bb)
print("block3 OK: 3 graphs/2 breaks, eager-equivalent glue, fullgraph raises, "
"graphs==breaks+1")
How to run it in production¶
CUDA Graphs interaction. reduce-overhead and on-GPU max-autotune capture CUDA Graphs automatically (no boilerplate), using CUDA Graph Trees to cache one static graph per input shape (see CUDA Graphs: Capture, Replay, and Launch Overhead). CUDA Graphs require static shapes and stable memory addresses; new shapes trigger fresh captures, so keep shapes consistent to maximize cache hits. Use max-autotune-no-cudagraphs when shapes vary or when graph capture masks a bug you are chasing. (Fregly, Ch. 13)
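The one-graph-per-shape behaviour described above can be sketched as a toy cache. This is a model under stated assumptions, not the real CUDA Graph Trees implementation: GraphCache, its fields, and the rule that each capture pins one output-sized float32 buffer are all hypothetical simplifications.

```python
# Toy per-shape graph cache (hypothetical; real CUDA Graph Trees are more involved).
# Assumption: each captured shape pins one static buffer of prod(shape) float32s.
import math

class GraphCache:
    def __init__(self, bytes_per_elem=4):
        self.pinned = {}                     # shape -> bytes pinned by that capture
        self.captures = 0
        self.replays = 0
        self.bytes_per_elem = bytes_per_elem

    def run(self, shape):
        if shape in self.pinned:             # seen shape: replay the captured graph
            self.replays += 1
            return "replay"
        self.pinned[shape] = math.prod(shape) * self.bytes_per_elem
        self.captures += 1                   # new shape: fresh capture, new buffers
        return "capture"

    @property
    def pinned_bytes(self):
        return sum(self.pinned.values())

stable, varying = GraphCache(), GraphCache()
for _ in range(100):
    stable.run((32, 1024))                   # one shape, repeated
for batch in range(1, 33):
    varying.run((batch, 1024))               # 32 distinct batch sizes
print(stable.captures, stable.replays, stable.pinned_bytes)
print(varying.captures, varying.replays, varying.pinned_bytes)
```

In this model the stable workload captures once and replays every later call, while the varying workload captures on every call and pins memory for each shape it has seen.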
Persisting compiled artifacts. Compile time is paid once and TorchInductor caches kernels, but a fresh process or a new node starts cold. For long jobs and serving, persist the Inductor outputs with torch.compiler.save_cache_artifacts() and reload them with load_cache_artifacts() so later runs and other nodes skip compilation. Keep CUDA, PyTorch, and Triton versions compatible across nodes, or the cache will not load. Warm up once at startup (a throwaway forward/backward at the serving shapes) so the first real request does not pay tracing and autotuning. (Fregly, Ch. 13)
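Why a version-skewed node starts cold can be illustrated with a toy artifact store. The sketch assumes, purely for illustration, that artifacts are keyed by a (CUDA, PyTorch, Triton) version tuple; the real Inductor cache key is more involved, and ArtifactStore, startup, and the version strings are hypothetical placeholders.

```python
# Toy artifact store keyed by toolchain versions (ASSUMED keying scheme; the
# version strings below are placeholders, not recommendations).
class ArtifactStore:
    def __init__(self):
        self._blobs = {}
    def save(self, key, blob):
        self._blobs[key] = blob
    def load(self, key):
        return self._blobs.get(key)          # None => no compatible artifact

def startup(store, key, compile_fn):
    """Load a matching artifact if present, otherwise compile and persist it."""
    blob = store.load(key)
    if blob is not None:
        return blob, "loaded"
    blob = compile_fn()                      # cold path: pay the compile cost
    store.save(key, blob)
    return blob, "compiled"

store = ArtifactStore()
node_a = ("12.8", "2.8.0", "3.4.0")          # (CUDA, PyTorch, Triton), placeholders
node_b = ("12.8", "2.7.1", "3.3.0")          # version-skewed node
compile_fn = lambda: b"kernels"

print(startup(store, node_a, compile_fn)[1])  # first process on node A
print(startup(store, node_a, compile_fn)[1])  # later process, same toolchain
print(startup(store, node_b, compile_fn)[1])  # skewed toolchain
```

A matching toolchain reuses the persisted artifact, while the skewed node finds nothing it can load and compiles from scratch.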
How to maintain it¶
Recompiles and guards. Each compiled frame is protected by guards. If a guard fails on a later call (a new shape, dtype, device, stride, or changed Python global), Dynamo recompiles that frame (up to torch._dynamo.config.recompile_limit, after which it falls back to eager). Unintended recompiles are a top cause of "compiled but slow." (torch.compiler FAQ, Dynamo Overview)
Diagnose with TORCH_LOGS. The env var is a comma-separated list of [+-]<component> entries: a + prefix sets that component's log level to DEBUG (more output) and a - prefix sets it to ERROR (less output). The recompiles artifact logs each recompilation with the guard that failed; recompiles_verbose prints all failing guard checks; graph_breaks and guards cover the rest. (torch._logging, TORCH_LOGS recipe)
TORCH_LOGS="recompiles" python train.py # why each recompile fired
TORCH_LOGS="recompiles_verbose,graph_breaks" python train.py
TORCH_LOGS="+dynamo,+inductor" python train.py # full verbose pipeline trace
The book uses TORCH_LOGS="+dynamo" / "+dynamo,+inductor" to surface where execution exits the compiled graph; the recompiles artifact above is the official, narrower lens for recompile reasons. (Fregly, Ch. 13; torch._logging)
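The entry syntax is simple enough to sketch as a parser. This is a simplified illustration, not PyTorch's own parsing code: parse_torch_logs is a hypothetical helper, and it does not distinguish components from artifacts, so a bare name is reported only as "on".

```python
# Simplified TORCH_LOGS parser (hypothetical helper, not PyTorch's implementation).
# "+name" -> debug, "-name" -> error, bare "name" -> on (the default level for a
# component, or simply enabled for an artifact such as "recompiles").
def parse_torch_logs(value):
    settings = {}
    for entry in value.split(","):
        entry = entry.strip()
        if not entry:
            continue                         # tolerate empty entries
        if entry[0] == "+":
            settings[entry[1:]] = "debug"
        elif entry[0] == "-":
            settings[entry[1:]] = "error"
        else:
            settings[entry] = "on"
    return settings

for spec in ("recompiles", "recompiles_verbose,graph_breaks", "+dynamo,+inductor"):
    print(spec, "->", parse_torch_logs(spec))
```

Reading a setting this way makes it easier to see at a glance which entries turn on a narrow artifact and which turn on a full verbose trace.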
Dynamic shapes. Modern PyTorch supports partial dynamic shapes via shape guards, which eliminates some unnecessary graph breaks; truly dynamic workloads may still fall back to eager (or to max-autotune-no-cudagraphs) for correctness. To pre-empt shape-driven recompiles, mark a dimension dynamic so Dynamo compiles a shape-generic graph instead of specializing. Reference template (needs torch; not run here):
torch._dynamo.mark_dynamic(x, 0) # dim 0 (e.g. batch / seq) may vary
model = torch.compile(model, dynamic=True)
mark_dynamic tells the compiler to expect varying sizes on that dim; dynamic=True assumes dynamic shapes globally, dynamic=False specializes on the first shape seen, dynamic=None auto-detects. (Fregly, Ch. 13; torch.compile)
The guard-and-recompile behaviour those knobs control is the model below. It reproduces a guard-keyed compile cache and asserts that distinct static shapes each recompile, that a dynamic dimension collapses them to a single compile, that the dtype guard still fires, that exceeding recompile_limit falls back to eager, and that recompiling never changes the result. Validated with numpy 2.4.6 under python3:
# Guard-driven recompile model: the core TorchDynamo mechanism the page teaches.
# Each compiled frame is protected by GUARDS on (per-dim shape, dtype, device,
# stride). A guard miss triggers a (re)compile keyed by that guard; past
# torch._dynamo.config.recompile_limit distinct guards the frame falls back to
# EAGER. Marking a dim dynamic drops it from the guard so varying that dim
# reuses ONE compiled graph. We assert every one of these behaviours plus the
# adversarial cases (limit->eager, non-dynamic dim still recompiles, dtype guard
# still fires, and recompiling never changes the result).
import numpy as np
class Frame:
    def __init__(self, fn, recompile_limit=8, dynamic_dims=()):
        self.fn = fn
        self.limit = recompile_limit
        self.dynamic = set(dynamic_dims)
        self.cache = {}                   # guard_key -> compiled variant id
        self.compiles = 0                 # cold compile + recompiles
        self.next_id = 0
    def guard_key(self, x):
        shape = tuple(None if i in self.dynamic else s
                      for i, s in enumerate(x.shape))
        return (shape, x.dtype.str, "cuda", x.strides)  # what Dynamo specializes on
    def __call__(self, x):
        key = self.guard_key(x)
        if key in self.cache:             # guard passes -> reuse variant
            return self.fn(x), f"compiled#{self.cache[key]}"
        if len(self.cache) >= self.limit: # over recompile_limit -> eager
            return self.fn(x), "eager"
        self.cache[key] = self.next_id    # guard miss -> (re)compile
        self.next_id += 1
        self.compiles += 1
        return self.fn(x), f"compiled#{self.cache[key]}"
fn = lambda x: x * 2.0 + 1.0 # the frame's traced body
# 1. Same shape repeated -> compiled ONCE, then reused (this is also why a
# repeated transformer block / nested_compile_region compiles once).
f = Frame(fn)
outs = [f(np.ones((4, 8), np.float32)) for _ in range(100)]
assert f.compiles == 1
assert all(tag == "compiled#0" for _, tag in outs)
# 2. K distinct static shapes -> K compiles (recompile churn).
f = Frame(fn)
for n in (1, 2, 4, 8, 16):
    f(np.ones((n, 8), np.float32))
assert f.compiles == 5
# 3. recompile_limit: past the limit, new shapes fall back to EAGER, not compile.
f = Frame(fn, recompile_limit=8)
for n in range(1, 9):                     # fill cache with 8 variants
    _, tag = f(np.ones((n, 8), np.float32))
    assert tag.startswith("compiled")
_, tag = f(np.ones((99, 8), np.float32))  # 9th distinct shape
assert f.compiles == 8 and tag == "eager", (f.compiles, tag)
# 4. dynamic dim 0: varying batch collapses to ONE compile (mark_dynamic).
f = Frame(fn, dynamic_dims=(0,))
for n in (1, 2, 4, 8, 16, 999):
    f(np.ones((n, 8), np.float32))
assert f.compiles == 1
# 5. Adversarial: a NON-dynamic dim still recompiles even with dim 0 dynamic.
f = Frame(fn, dynamic_dims=(0,))
for w in (8, 16, 32):
    f(np.ones((4, w), np.float32))
assert f.compiles == 3
# 6. Adversarial: the dtype guard still fires under dynamic shapes.
f = Frame(fn, dynamic_dims=(0,))
f(np.ones((4, 8), np.float32))
f(np.ones((4, 8), np.float64))            # same shape, new dtype -> recompile
assert f.compiles == 2
# 7. Recompiling is a perf mechanism, not a correctness change: output == eager.
f = Frame(fn, dynamic_dims=(0,))
for n in (1, 5, 5, 7):
    x = np.arange(n * 8, dtype=np.float32).reshape(n, 8)
    y, _ = f(x)
    assert np.array_equal(y, fn(x))
print("block2 OK: guards, per-shape recompiles, limit->eager, dynamic=1 compile, "
"dtype guard, correctness preserved")
Profiling and verification. Profile every mode for your specific model and GPU; there is no universal winner, and max-autotune can regress on some models (Fregly, Ch. 13). Verify the win in torch.profiler / Nsight Systems: a correctly compiled run shows fewer, longer GPU kernels and the trace marks "Compiled Function" regions plus any graph breaks. Watch memory when using CUDA-Graph modes (static buffers inflate the footprint). Allocator tuning (PYTORCH_CUDA_ALLOC_CONF, cudaMallocAsync) is covered in PyTorch CUDA Caching Allocator Tuning.
How to scale it¶
Regional compilation. For models built from many identical blocks (transformer/MoE layers), torch.compiler.nested_compile_region marks a repeated block so it is compiled once and reused across all occurrences, cutting cold-start compile time and recompile churn without losing fusion. Runtime throughput matches full-model compilation; the region transparently recompiles if input shape, dtype, device, stride, or globals change. Regional compilation of repeated blocks has been available since PyTorch 2.5; the nested_compile_region marker appears to be a later addition, so confirm it exists in your installed release. Reference template (needs torch; not run here) (Fregly, Ch. 13; nested_compile_region, regional compilation recipe):
class TransformerBlock(torch.nn.Module):
    @torch.compiler.nested_compile_region   # decorate the callable, not the class
    def forward(self, x):
        ...                                  # compiled once, reused per layer
Compile-once-reuse is the same guard-hit mechanism the guard model above validates: identical blocks share a guard key, so the second and later occurrences reuse the first compilation instead of paying it again (block2 case 1: 100 calls, one compile). For a model with L identical layers, that turns L compilations into one.
Across processes and nodes. Persist Inductor artifacts (the save_cache_artifacts / load_cache_artifacts pair from How to run it in production) so a scaled-out job compiles once and every worker loads the same cache, provided CUDA, PyTorch, and Triton versions match. Under CUDA-Graph modes, remember each captured shape pins its own static buffers, so wide shape coverage trades HBM for fewer recaptures; keep the shape set small at scale.
Failure modes¶
- Silent eager fallback past recompile_limit. Once a frame exceeds torch._dynamo.config.recompile_limit distinct guard sets, Dynamo stops compiling and runs it eager, so a "compiled" model quietly loses the speedup (reproduced in the block2 limit case). Cap shape variety or mark dimensions dynamic. (torch.compiler FAQ)
- Recompile churn from drifting guards. A new shape, dtype, device, stride, or changed Python global fails a guard and recompiles the frame; if it happens every iteration, compile cost dominates (block2). mark_dynamic / dynamic=True is the fix for shape drift; TORCH_LOGS="recompiles" names the failing guard.
- Graph breaks erode fusion. An untraceable op splits the model into multiple graphs with eager glue between them, so the fusion win shrinks (block3). Find them with torch._dynamo.explain; forbid them with fullgraph=True.
- fullgraph=True on untraceable control flow. Data-dependent Python control flow Dynamo cannot trace makes fullgraph=True hard-error (block3). Either rewrite the region to be traceable or accept the break without fullgraph.
- max-autotune regresses latency. Aggressive autotuning occasionally picks a slower configuration for a given model; it is not a guaranteed win. Profile it against default before shipping. (Fregly, Ch. 13)
- CUDA-Graph modes exhaust memory. reduce-overhead and on-GPU max-autotune allocate large static buffers per captured shape; under tight memory they OOM. Use max-autotune-no-cudagraphs, or hold the shape set small. (Fregly, Ch. 13)
- One big GEMM sees almost no gain. A dense model dominated by a single large matmul is already Tensor-Core bound, so there is little to fuse or de-overhead and the speedup is minimal. Compile is still cheap, but do not expect much. (Fregly, Ch. 13)
- Cache artifacts do not load across mismatched nodes. save_cache_artifacts output is tied to the CUDA/PyTorch/Triton versions that produced it, so a version-skewed node silently recompiles instead of loading. Pin the toolchain across the fleet. (Fregly, Ch. 13)
Reference templates (the torch.compile, timing, explain, mark_dynamic, and nested_compile_region blocks) need torch and are not run here; their APIs and numbers are grounded in the cited book chapter and official PyTorch docs. The four numpy blocks are executed and asserted under python3 (numpy 2.4.6). Benchmark on your target before relying on any latency figure.
References¶
- Chris Fregly, AI Systems Performance Engineering (O'Reilly), Ch. 13: "PyTorch Compiler (torch.compile)", "Compilation Modes and Trade-Offs", "Regional Compilation", "Profiling and Debugging Compiler Performance Issues", "CUDA Graph Trees". Primary source for the compiler-stack description, the eager-vs-compiled MoE timing (~248 ms vs ~173 ms), the mode trade-off table (Table 13-5), the TORCH_LOGS / explain / mark_dynamic workflow, and the sparse-vs-dense speedup guidance.
- PyTorch, torch.compile API: exact mode strings, fullgraph, dynamic, default backend inductor.
- PyTorch, Dynamo Overview: bytecode capture, guards, graph breaks, recompile-on-guard-failure.
- PyTorch, torch.compiler (Inductor/backends): TorchInductor as default backend, Triton/C++ codegen.
- PyTorch, torch.compiler FAQ: torch._dynamo.explain, recompiles, recompile_limit.
- PyTorch, torch._logging and TORCH_LOGS recipe: recompiles, recompiles_verbose, graph_breaks, guards artifacts; [+-]<component> syntax.
- PyTorch, torch.compiler.nested_compile_region and regional compilation recipe: repeated-block compile-once reuse (since 2.5).
Related: PyTorch CUDA Caching Allocator Tuning · PyTorch/XLA and the XLA Compiler · Activation Checkpointing and Memory Offloading · PyTorch Attention APIs: SDPA and FlexAttention · PyTorch Performance Regression Testing in CI · Kernel Fusion · OpenAI Triton: Authoring GPU Kernels in Python · CUDA Graphs: Capture, Replay, and Launch Overhead · Tensor Cores and Mixed Precision · Roofline Model and Arithmetic Intensity · Profiling GPUs: Nsight Systems and Nsight Compute · Frameworks · Performance Optimization and Tuning · Glossary