PyTorch/XLA and the XLA compiler¶
Scope: lazy-tensor tracing into an XLA/HLO graph, the XLA fusion/layout compiler, mark_step/torch_xla.sync() graph boundaries, per-shape-signature compilation caching and recompiles from dynamic shapes, and when XLA beats eager or Inductor.
What it is¶
PyTorch/XLA is an alternate compiler backend that maps PyTorch operations into XLA's graph IR and executes them on the target hardware's runtime. Where TorchInductor is the default backend for NVIDIA/AMD GPUs and CPUs, XLA targets Google Cloud TPUs and other accelerators that adopt XLA IR. The book is explicit that "XLA isn't commonly used with NVIDIA GPUs" but is "a powerful backend for non-NVIDIA hardware": Meta's MTIA inference ASIC uses XLA, and AWS Inferentia/Trainium run PyTorch through an XLA compiler in the AWS Neuron SDK.12
The execution model is lazy tensors. An XLA tensor records operations into an IR graph rather than executing them eagerly; "they record operations in a graph until the results are needed."8 This deferred trace is the unit the compiler optimizes: separate ops can be fused into a single optimized operation. The recorded IR is lowered to HLO (High Level Operations), XLA's machine-readable computation format, which the XLA compiler turns into hardware-specific code with its own fusion and layout-assignment passes.9
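Why deferral enables fusion can be sketched with a toy model. This is not XLA's fusion pass. It only contrasts running a recorded chain of elementwise ops one op at a time (one pass and one intermediate buffer per op) with running the whole chain per element in a single pass. The op chain and data are illustrative assumptions.

```python
# Toy illustration (assumption): a recorded trace of elementwise ops executed
# eagerly (one pass + one intermediate list per op) vs. fused (one pass total).
from collections.abc import Callable

Op = Callable[[float], float]

def eager(xs: list[float], ops: list[Op]) -> tuple[list[float], int]:
    passes = 0
    cur = list(xs)
    for op in ops:
        cur = [op(v) for v in cur]  # materializes an intermediate per op
        passes += 1
    return cur, passes

def fused(xs: list[float], ops: list[Op]) -> tuple[list[float], int]:
    def chain(v: float) -> float:
        for op in ops:  # apply the whole recorded chain to one element
            v = op(v)
        return v
    return [chain(v) for v in xs], 1  # single pass, no intermediates

# A "recorded" trace: add, scale, relu.
trace: list[Op] = [lambda v: v + 1.0, lambda v: v * 2.0, lambda v: max(v, 0.0)]
data = [-3.0, -0.5, 0.0, 1.5]
e_out, e_passes = eager(data, trace)
f_out, f_passes = fused(data, trace)
print(e_out == f_out, e_passes, f_passes)  # True 3 1
```

The results are identical; only the number of passes over memory changes, which is the kind of saving a fusing compiler can claim once it sees the whole deferred trace.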
Two activation paths exist:
- Lazy-tensor mode (move tensors/model .to(device)), where the framework auto-constructs graphs and you mark execution boundaries explicitly.8
- torch.compile(..., backend="openxla"), which routes the TorchDynamo FX graph through the same Lazy Tensor machinery so the model is traced once at init instead of every step.31213
Why use it¶
The book frames the trade-off against Inductor directly. XLA "compiles whole programs ahead of time because XLA is designed to generate static graphs," and is "optimized for static shapes or bounded dynamic shapes."4 That ahead-of-time, whole-program posture is what lets XLA do aggressive cross-op fusion and layout selection for static graphs, but it is also why its dynamic-shape behavior is more brittle than Inductor's symbolic-shape approach.
The single load-bearing operational difference: "XLA will not incrementally compile mid-run." The graph is built statically ahead of time; a new shape triggers a fresh whole-graph compilation, "which is very expensive and impacts latency-sensitive workloads like inference."5 TorchInductor, by contrast, generalizes shapes after the first recompile and can emit one kernel covering a range of sequence lengths (see torch.compile: Graph Capture, Backends, and Recompiles). So with XLA, recompiles are coarser and more costly, and shape discipline is not optional.
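The difference in recompile behavior can be counted with a simplified model. Assumptions: the "xla" policy compiles once per exact input length; the "inductor" policy loosely mimics Dynamo's automatic-dynamic behavior, where the first shape change triggers one recompile that treats the dimension as symbolic. Real Dynamo has extra guards (for example 0/1 specialization) that this sketch ignores.

```python
# Simplified compile-count model (assumption, not either compiler's real logic).
def count_compiles(lengths, policy):
    seen = set()
    dynamic = False
    compiles = 0
    for n in lengths:
        if policy == "xla":
            if n not in seen:           # every new exact shape: whole-graph compile
                compiles += 1
                seen.add(n)
        elif policy == "inductor":
            if not seen:                # first call: static specialization
                compiles += 1
                seen.add(n)
            elif not dynamic and n not in seen:
                compiles += 1           # first shape change: recompile as symbolic
                dynamic = True
            # afterwards one kernel covers the range, so no further compiles
        else:
            raise ValueError(f"unknown policy {policy!r}")
    return compiles

lengths = [7, 7, 13, 100, 42, 7, 512]
print(count_compiles(lengths, "xla"), count_compiles(lengths, "inductor"))  # 5 2
```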
When to use it (and when not)¶
Use XLA when:
- The device is not supported by TorchInductor but does support XLA IR (TPU, MTIA, Inferentia/Trainium via Neuron). The book: "if you're running on a hardware device not currently supported by the TorchInductor backend, you can potentially use the XLA backend if the device supports XLA."7
- Shapes are static or tightly bounded (fixed-size inference, padded/bucketed training) so the per-shape compile cost amortizes over many steps.
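"Amortizes over many steps" is simple arithmetic: each distinct shape pays one compile, spread across all steps. The step and compile times below are hypothetical placeholders, not measurements.

```python
# Back-of-envelope amortization; all costs are hypothetical placeholders.
def amortized_step_ms(step_ms, compile_ms, distinct_shapes, steps):
    """Average wall time per step when each distinct shape compiles once."""
    return step_ms + compile_ms * distinct_shapes / steps

# 20 ms steps, 30 s per compile, 4 fixed buckets, 10,000 steps
bucketed = amortized_step_ms(20.0, 30_000.0, 4, 10_000)       # 32.0
# same job, but every step arrives with a never-seen shape
unbounded = amortized_step_ms(20.0, 30_000.0, 10_000, 10_000)  # 30020.0
print(bucketed, unbounded)
```

With a handful of shapes the compile cost nearly vanishes; with a new shape per step it dominates every step.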
Do not reach for XLA when:
- You are on NVIDIA GPUs running normal dynamic-shape PyTorch; TorchInductor is the default and is the better-trodden path. XLA "isn't commonly used with NVIDIA GPUs."7
- Your workload is latency-sensitive with frequently varying shapes. New shapes force whole-program recompilation; the book recommends you "pad or use fixed-size buckets for your inputs" because XLA recompiles for new shapes rather than handling them symbolically.6
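A minimal sketch of the padding/bucketing recommendation, assuming token-ID sequences and a hypothetical bucket list. Each input is padded up to the smallest bucket that fits and returned with an attention mask so the padded positions can be ignored downstream. Inputs longer than the largest bucket are rejected here; truncating or chunking them are alternatives.

```python
# Sketch: pad a variable-length sequence to the smallest fitting bucket.
# BUCKETS and pad_id are hypothetical; choose them from your length distribution.
import bisect

BUCKETS = (128, 256, 512, 1024)

def pad_to_bucket(tokens, pad_id=0, buckets=BUCKETS):
    """Return (padded_tokens, attention_mask), both of length == chosen bucket."""
    i = bisect.bisect_left(buckets, len(tokens))  # first bucket >= len(tokens)
    if i == len(buckets):
        raise ValueError(f"length {len(tokens)} exceeds largest bucket {buckets[-1]}")
    size = buckets[i]
    pad = size - len(tokens)
    mask = [1] * len(tokens) + [0] * pad  # 1 = real token, 0 = padding
    return list(tokens) + [pad_id] * pad, mask

padded, mask = pad_to_bucket(list(range(1, 201)))
print(len(padded), sum(mask))  # 256 200
```

However many distinct lengths arrive, the compiled graph only ever sees the four bucket shapes.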
Many of the same compiler hygiene rules carry over from the Inductor path: minimize graph breaks, and you can still use data- and model-parallel distributed strategies.7
Architecture¶
Both entry paths converge on one pipeline. Eager PyTorch ops, or the FX graph that TorchDynamo hands over under torch.compile(..., backend="openxla"), feed the same LazyTensor machinery.123 The LazyTensor layer records operations into an IR graph and runs nothing until a barrier.8 At the barrier (torch_xla.sync() or xm.mark_step()), the trace is lowered to HLO, which the XLA compiler optimizes with fusion and layout-assignment passes into hardware code.9 The compiled executable is cached per unique input shape and signature: a matching signature is a cache hit, a new shape triggers a fresh whole-graph compilation.511 The runtime then executes on the target accelerator (TPU, MTIA, Trainium, Inferentia).12
```mermaid
flowchart LR
    EAGER["Eager PyTorch ops"] --> LT["LazyTensor IR graph (deferred trace)"]
    DYN["torch.compile backend=openxla (Dynamo FX graph)"] --> LT
    LT -->|"torch_xla.sync() / mark_step()"| HLO["HLO graph"]
    HLO --> XLA["XLA compiler (fusion + layout)"]
    XLA --> CACHE["Per-shape-signature executable cache"]
    CACHE --> DEV["TPU / MTIA / Trainium / Inferentia runtime"]
    DEV -.->|"new shape: whole-graph recompile"| XLA
```
The dashed edge is the cost center: anything that changes the HLO (a new shape, a different signature, non-identical model code) loops back through the compiler instead of hitting the cache.511
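Why non-identical code recompiles even at a fixed shape can be sketched with a cache keyed on a fingerprint of the traced graph plus input shapes, standing in for "same HLO". The Tracer class and op names are hypothetical. The sketch also assumes a Python scalar is baked into the graph as a constant; whether a particular scalar becomes a constant or device data depends on the framework.

```python
# Toy model (assumption): cache key = hash of recorded ops (including any
# baked-in Python constants) + input shapes, as a stand-in for the HLO.
import hashlib

class Tracer:  # hypothetical, for illustration only
    def __init__(self):
        self.ops = []

    def record(self, name, *consts):
        self.ops.append((name, consts))

def fingerprint(tracer, shapes):
    return hashlib.sha256(repr((tracer.ops, shapes)).encode()).hexdigest()

compiled = set()  # fingerprints that already have a compiled executable

def run_step(scale, shape):
    t = Tracer()
    t.record("matmul")
    t.record("mul_scalar", scale)  # assumed to be baked into the graph
    t.record("relu")
    key = fingerprint(t, (shape,))
    hit = key in compiled
    compiled.add(key)  # "compile" on a miss
    return hit

hits = [run_step(0.5, (8, 16)) for _ in range(3)]
print(hits)                      # [False, True, True]
print(run_step(0.25, (8, 16)))   # False: constant changed, so the graph changed
print(len(compiled))             # 2
```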
How to use it¶
Lazy-tensor mode and the graph boundary¶
In lazy mode you place tensors on the XLA device and mark where the accumulated graph should be cut and executed. The newer barrier API is torch_xla.sync(); xm.mark_step() is the historical equivalent.108 The barrier is what converts the traced IR to HLO and dispatches it.
Reference template (requires a torch_xla install and an XLA device, so it is not runnable in this environment). The numpy model below validates the core execution semantics it teaches.
```python
# Reference template: needs torch_xla + an XLA device (TPU/Trainium/etc.).
# MyModel, loader, loss_fn and optimizer are placeholders for your own objects.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()  # was xm.xla_device()
model = MyModel().to(device)

for data, target in loader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    out = model(data)            # recorded into the lazy graph, not run yet
    loss = loss_fn(out, target)
    loss.backward()
    optimizer.step()
    torch_xla.sync()             # cut graph -> lower to HLO -> compile/execute
```
torch_xla.sync() does not block further tracing while the device runs the graph; it only blocks access to the materialized tensor values.10 Avoid pulling tensor values to the host mid-step (printing a loss, .item(), host-side if): each materialization forces an implicit barrier and a graph cut. Reduce value access frequency, or place a deliberate sync() where the cut belongs.10
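The non-blocking barrier can be modeled with a thread pool standing in for the device. This is an assumption-laden sketch, not torch_xla's implementation: sync() here submits the pending "graph" and returns at once, the host keeps tracing, and only reading the value waits. The Event gate exists solely to make the demo's ordering deterministic.

```python
# Sketch (assumption: a single-worker thread pool stands in for the device).
import threading
from concurrent.futures import ThreadPoolExecutor

device = ThreadPoolExecutor(max_workers=1)
gate = threading.Event()  # only here to make the demo ordering deterministic
log = []

def execute_graph(step):
    gate.wait()  # pretend the device is busy until released
    log.append(f"device finished step {step}")
    return step * 10

def sync(step):
    """Launch the pending graph and return immediately (cf. wait=False)."""
    log.append(f"sync step {step}")
    return device.submit(execute_graph, step)

pending = sync(1)
launched_without_waiting = not pending.done()  # sync() did not block
log.append("host traces step 2")               # tracing overlaps device work
gate.set()                                     # let the device finish
value = pending.result()                       # reading the value blocks
log.append(f"host read {value}")
device.shutdown()
print(launched_without_waiting, log)
```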
The numpy model below (runnable with the system python3 plus numpy, no torch) reproduces and asserts that core behavior: operations are recorded and nothing runs until a barrier, a host-value access forces an implicit barrier, and the deferred result is identical to eager evaluation.
```python
import numpy as np

# Reference model of PyTorch/XLA lazy-tensor execution (numpy-only).
# Ops are RECORDED into a deferred graph; nothing runs until a barrier
# (torch_xla.sync / xm.mark_step) or a host-value access (.item()) cuts the
# graph and executes it. Validates deferral, barrier-triggered execution,
# host-access barriers, and exact equivalence to eager evaluation.
class LazyGraph:
    def __init__(self):
        self.barriers = 0  # times the graph was cut + executed
        self.evals = 0     # primitive op evaluations actually performed

class LazyTensor:
    def __init__(self, g, compute, inputs):
        self.g = g
        self._compute = compute  # fn(*materialized_inputs) -> ndarray
        self._inputs = inputs    # parent LazyTensors
        self._value = None       # None until materialized

    def __add__(self, o): return LazyTensor(self.g, np.add, [self, o])
    def __mul__(self, o): return LazyTensor(self.g, np.multiply, [self, o])

    def _run(self):
        if self._value is None:
            ins = [t._run() for t in self._inputs]
            self.g.evals += 1
            self._value = self._compute(*ins)
        return self._value

    def item(self):
        # Pulling a value to the host forces an implicit barrier + graph cut.
        return float(np.asarray(sync(self.g, self)).reshape(-1)[0])

def leaf(g, arr):
    t = LazyTensor(g, None, [])
    t._value = np.asarray(arr, dtype=np.float64)  # inputs already materialized
    return t

def sync(g, tensor):
    g.barriers += 1  # cut graph -> lower to HLO -> execute
    return tensor._run()

# --- happy path: deferral + equivalence to eager ---
g = LazyGraph()
a = leaf(g, [1.0, 2.0, 3.0])
b = leaf(g, [10.0, 20.0, 30.0])
expr = (a + b) * b            # recorded, NOT executed
assert g.barriers == 0        # nothing ran yet
assert expr._value is None    # not materialized
assert g.evals == 0           # no primitive op evaluated
out = sync(g, expr)           # barrier: execute the traced graph
ref = (np.array([1., 2., 3.]) + np.array([10., 20., 30.])) * np.array([10., 20., 30.])
assert np.array_equal(out, ref)  # deferred result == eager reference
assert g.barriers == 1
assert g.evals == 2           # exactly one add + one mul, run once

# --- laziness holds for a long chain: zero barriers before sync ---
g2 = LazyGraph()
x = leaf(g2, np.ones(4))
acc = x
for _ in range(50):
    acc = acc + x
assert g2.barriers == 0 and acc._value is None  # 50 ops recorded, none run
assert np.array_equal(sync(g2, acc), np.ones(4) * 51.0)

# --- adversarial: host access (.item()) forces an implicit barrier ---
g3 = LazyGraph()
p = leaf(g3, [2.0])
q = leaf(g3, [5.0])
loss = p * q + p              # recorded
assert g3.barriers == 0       # not yet cut
val = loss.item()             # host access -> implicit graph cut
assert val == 12.0            # 2*5 + 2
assert g3.barriers == 1       # the .item() caused a barrier

# --- edge: random-array equivalence vs eager reference (fuzz) ---
rng = np.random.default_rng(0)
for _ in range(100):
    g4 = LazyGraph()
    u, v = rng.standard_normal(8), rng.standard_normal(8)
    lu, lv = leaf(g4, u), leaf(g4, v)
    assert np.allclose(sync(g4, (lu + lv) * lu), (u + v) * u)

print("BLOCK1 OK: deferral, barrier count, host-access barrier, eager-equivalence")
```
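The "reduce value access frequency" advice can be quantified with a counting sketch. Simplifying assumptions: one deliberate sync() per step executes one graph, and calling loss.item() before that sync cuts the step into two graph executions. The step counts and logging intervals are illustrative.

```python
# Counting sketch (simplified assumption): a mid-step .item() adds one extra
# graph cut/execution on the steps where the loss is logged.
def graphs_executed(steps, log_every=None):
    total = 0
    for step in range(steps):
        if log_every and step % log_every == 0:
            total += 1  # implicit barrier from the mid-step .item()
        total += 1      # the deliberate torch_xla.sync() at step end
    return total

never = graphs_executed(1000)
every_step = graphs_executed(1000, log_every=1)
every_100 = graphs_executed(1000, log_every=100)
print(never, every_step, every_100)  # 1000 2000 1010
```

Logging every step doubles the number of graph executions in this model, while logging every 100 steps adds only 1%.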
How to integrate with it¶
torch.compile with the OpenXLA backend¶
backend="openxla" lets TorchDynamo hand its FX graph to PyTorch/XLA, which compiles it once and replays the compiled binary; the model is traced once at init rather than every iteration.1213 This is the recommended path for inference.
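"Traced once at init rather than every iteration" is a difference in host-side tracing work, separate from compilation. A counting sketch, assuming lazy mode re-runs the Python model every call while a trace-once wrapper traces per input signature and then replays. The wrapper and the stand-in "compiled binary" are hypothetical.

```python
# Counting sketch (assumption): lazy mode re-traces Python every call even when
# the executable is cached; a trace-once wrapper traces per signature, then replays.
traces = {"lazy": 0, "compiled": 0}
compiled_cache = {}

def model(x):
    return [2 * v + 1 for v in x]

def lazy_call(x):
    traces["lazy"] += 1              # Python model re-traced on every step
    return model(x)

def trace_once_call(x):
    sig = len(x)                     # stand-in for the input shape signature
    if sig not in compiled_cache:
        traces["compiled"] += 1      # traced once for this signature
        compiled_cache[sig] = model  # stand-in for the compiled binary
    return compiled_cache[sig](x)

batch = [1, 2, 3, 4]
for _ in range(100):
    assert lazy_call(batch) == trace_once_call(batch)
print(traces)  # {'lazy': 100, 'compiled': 1}
```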
Reference template (requires torch_xla and torchvision, so it is not runnable here). The numpy model below validates the per-shape compilation-cache behavior it relies on.
```python
# Reference template: needs torch_xla + torchvision + an XLA device.
# `loader` is a placeholder for your own data loader.
import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()
resnet18 = torchvision.models.resnet18().to(device)
resnet18.eval()

# Backend string is supported by the PyTorch/XLA project and activates
# OpenXLA-based compilation.
dynamo_resnet18 = torch.compile(resnet18, backend="openxla")

for data, _ in loader:
    data = data.to(device)  # inputs must live on the XLA device too
    with torch.no_grad():
        output = dynamo_resnet18(data)  # executes cached compiled binary
```
Training under backend="openxla" exists but the upstream docs mark it experimental.12
Trace-once-then-replay only pays off if inputs keep hitting the same cached executable. The numpy model below (runnable with the system python3 plus numpy, no torch) reproduces and asserts the cache rule: one compile per distinct shape/signature, dynamic shapes force recompiles, and padding to a fixed bucket collapses everything to a single compile.
```python
import numpy as np

# Reference model of XLA's per-shape-signature compilation cache (numpy-only).
# A new (shape, dtype) signature triggers an expensive whole-graph "compile";
# a seen signature reuses the cached executable. Validates: compile once per
# signature, dynamic shapes cause recompiles, padding/bucketing bounds them.
class XLACache:
    def __init__(self):
        self.cache = {}
        self.compiles = 0  # whole-graph compilations (expensive)

    def _signature(self, x):
        return (tuple(x.shape), x.dtype.str)  # shape + dtype

    def run(self, x):
        sig = self._signature(x)
        if sig not in self.cache:
            self.compiles += 1  # recompile for a new signature
            self.cache[sig] = lambda a: (a * 2.0 + 1.0).sum()  # fused executable
        return self.cache[sig](x)

def reference(x):  # slow eager reference
    return float((x * 2.0 + 1.0).sum())

# --- correctness: cached executable matches eager reference ---
c = XLACache()
x = np.arange(16, dtype=np.float64)
assert c.run(x) == reference(x)

# --- dynamic shapes: one compile per DISTINCT shape signature ---
c = XLACache()
lengths = [7, 7, 13, 7, 100, 13, 7]  # ragged sequence lengths
for n in lengths:
    seq = np.ones(n, dtype=np.float64)
    assert c.run(seq) == reference(seq)  # result always correct
assert c.compiles == len(set(lengths)) == 3  # 3 distinct shapes -> 3 compiles

# --- padding/bucketing: fixed shape collapses to a SINGLE compile ---
c = XLACache()
BUCKET = 128
for n in lengths:
    padded = np.zeros(BUCKET, dtype=np.float64)
    padded[:n] = 1.0
    c.run(padded)
assert c.compiles == 1  # every input shares one signature

# --- adversarial edge: same shape, different dtype -> still a recompile ---
c = XLACache()
c.run(np.ones(8, dtype=np.float64))
c.run(np.ones(8, dtype=np.float64))  # cache hit, no recompile
assert c.compiles == 1
c.run(np.ones(8, dtype=np.float32))  # signature differs by dtype
assert c.compiles == 2               # dtype is part of the signature

# --- adversarial: worst case = every input a unique shape (recompile storm) ---
c = XLACache()
for n in range(1, 51):  # 50 distinct lengths
    c.run(np.ones(n, dtype=np.float64))
assert c.compiles == 50  # pathological: compile every step

print("BLOCK2 OK: signature cache, dynamic recompiles, bucketing=1, dtype edge, storm")
```
How to run it in production¶
Inference is where the trade-offs bite hardest, so treat the compile cache as part of the serving contract:
- Prefer the compiled path. Serve through torch.compile(..., backend="openxla") so the model is traced once at init and each request replays the compiled binary instead of re-tracing.1312
- Pin input shapes. New shapes force whole-program recompilation, "which is very expensive and impacts latency-sensitive workloads like inference."5 Pad or use fixed-size buckets so requests share a small set of signatures.6
- Warm up the cache before taking traffic. Performance improves after a few warm-up steps once the executable is cached, so run every expected shape once at startup and keep the hot path as cache hits.59
- Keep the serving code path identical every step so the same HLO is produced and no recompile fires; "if the HLO changes between executions, a recompilation will still occur."911
- Keep values on device on the hot path. Host-side access (.item(), logging a per-request scalar, host-side branching) forces implicit barriers and graph cuts, so minimize it inside the serving loop.10
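The warm-up and shape-pinning rules combine into a simple startup check: run one dummy request per bucket, then confirm live traffic causes zero new compiles. ShapeCache, bucket_for, the bucket sizes, and the request lengths are all hypothetical stand-ins for the real executable cache and traffic.

```python
# Sketch: warm every expected bucket at startup so live traffic only hits the
# cache. ShapeCache is a hypothetical stand-in for XLA's executable cache.
import bisect

BUCKETS = (64, 128, 256)

class ShapeCache:
    def __init__(self):
        self.executables = {}
        self.compiles = 0

    def run(self, padded_len):
        if padded_len not in self.executables:
            self.compiles += 1  # expensive whole-graph compile
            self.executables[padded_len] = True
        return padded_len

def bucket_for(n):
    i = bisect.bisect_left(BUCKETS, n)  # smallest bucket >= n
    if i == len(BUCKETS):
        raise ValueError(f"request length {n} exceeds largest bucket")
    return BUCKETS[i]

cache = ShapeCache()
for size in BUCKETS:  # warm-up: one dummy request per bucket
    cache.run(size)
warm_compiles = cache.compiles

for n in [5, 70, 200, 64, 129, 1]:  # live traffic after warm-up
    cache.run(bucket_for(n))
print(warm_compiles, cache.compiles - warm_compiles)  # 3 0
```

A nonzero second number at startup would mean some expected shape was missed by the warm-up list.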
How to maintain it¶
Compilation cache and recompiles¶
XLA caches each compiled executable per unique input shape and signature, so performance improves after a few warm-up steps, similar to Inductor.59 At the barrier, if the resulting HLO matches a previously seen graph, the cached compiled program is reused; "if the HLO changes between executions, a recompilation will still occur."11
Maintenance rules that keep the cache hot:
- Run identical model code every step so the same graph is produced and compilation happens once per graph.9
- Stabilize shapes. New shapes recompile the whole program. Pad to fixed sizes or bucket inputs to bound the number of distinct shape signatures.6
- Watch for accidental dynamism: host-dependent control flow, varying sequence lengths, or per-iteration Python constants change the HLO and force recompiles, the same failure class flagged for the Inductor path in torch.compile: Graph Capture, Backends, and Recompiles.
A note on sources: the book says torch.compile(..., backend="openxla") activates PyTorch XLA "based on OpenXLA."3 The official PyTorch/XLA docs use the same backend string (spelled backend='openxla') and describe the Dynamo to FX to LazyTensor bridge in more detail than the book, including that training support is experimental; where the two differ in detail, prefer the official docs.12 Both describe the same trace-once behavior.
How to scale it¶
Scaling XLA is mostly about spreading work across accelerators while keeping the number of distinct compilations bounded:
- Distributed strategies. You can use distributed strategies with XLA, "such as data and model parallelism," and the same hygiene rule applies at scale: minimize graph breaks.7
- Bucketing bounds compile count. The cache model above shows that total compilations equal the number of distinct shape signatures. A fixed set of buckets turns a per-step recompile risk (order of the number of steps) into a fixed, small compile budget (order of the number of buckets), which is what makes throughput predictable across a fleet.6
- Non-NVIDIA hardware reach. XLA is the practical route to scaling on accelerators outside the TorchInductor path: Google Cloud TPUs, Meta's MTIA, and AWS Trainium/Inferentia via the Neuron SDK.12
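The bucketing trade-off at scale is compile budget versus padding waste. A sketch comparing bucket sets on synthetic, deterministically generated lengths (uniform in 1 to 1024, an assumption; real length distributions are usually skewed): exact shapes compile once per distinct length with no waste, one large bucket compiles once but pads heavily, and power-of-two buckets sit in between.

```python
# Sketch: compile budget (executables actually used) vs. fraction of padded
# tokens for different bucket sets. Lengths are synthetic and illustrative.
import bisect
import random

def evaluate(buckets, lengths):
    used, padded = set(), 0
    for n in lengths:
        b = buckets[bisect.bisect_left(buckets, n)]  # assumes n <= max bucket
        used.add(b)
        padded += b - n
    waste = padded / (padded + sum(lengths))  # share of compute spent on padding
    return len(used), round(waste, 3)

rng = random.Random(0)
lengths = [rng.randint(1, 1024) for _ in range(10_000)]

print("exact shapes:", evaluate(tuple(sorted(set(lengths))), lengths))
print("one bucket:  ", evaluate((1024,), lengths))
print("pow2 buckets:", evaluate((64, 128, 256, 512, 1024), lengths))
```

Adding buckets lowers waste at the cost of a few more compiles, and the compile count stays fixed no matter how many requests arrive.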
Failure modes¶
| Failure mode | Trigger | Symptom | Mitigation |
|---|---|---|---|
| Recompile storm | Input shapes vary every step (ragged sequence lengths) | Repeated whole-graph compiles, latency spikes | Pad or bucket to fixed sizes 6 |
| Mid-run compile stall | A never-before-seen shape appears at run time; "XLA will not incrementally compile mid-run" | A single step blocks on a full whole-graph compile | Warm up all expected shapes, keep shapes static or bounded 5 |
| Host-access barrier thrash | .item(), printing a loss, host-side if mid-step | Implicit barriers cut the graph, extra syncs, throughput drop | Reduce value access, place one deliberate sync() per step 10 |
| HLO drift | Model code differs between steps | Recompilation even at the same input shape | Run identical model code every step 911 |
| Wrong backend for the device | Reaching for XLA on NVIDIA GPUs | Off the well-trodden path, weaker tooling | Use TorchInductor on NVIDIA, XLA for non-NVIDIA hardware 7 |
| Unvalidated training path | Relying on backend="openxla" for training | Experimental, potentially unstable behavior | Treat training support as experimental, validate before production 12 |
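A recompile storm is easy to miss until latency degrades, so a guardrail that counts new-signature compiles in a sliding window of steps can help. This StormDetector is hypothetical. In a real job the compile events would come from PyTorch/XLA's own metrics (for example the torch_xla.debug.metrics module); here they are simulated from a stream of input shapes, and the window and threshold are illustrative.

```python
# Hypothetical guardrail: flag a recompile storm when too many new-signature
# compiles land inside a sliding window of steps.
from collections import deque

class StormDetector:
    def __init__(self, window=100, max_compiles=3):
        self.window = window
        self.max_compiles = max_compiles
        self.seen = set()
        self.recent = deque()  # step indices at which a compile happened

    def observe(self, step, signature):
        if signature not in self.seen:  # new signature -> compile event
            self.seen.add(signature)
            self.recent.append(step)
        while self.recent and self.recent[0] <= step - self.window:
            self.recent.popleft()  # drop compiles that left the window
        return len(self.recent) > self.max_compiles  # True => storm

det = StormDetector(window=100, max_compiles=3)
bucketed = [det.observe(s, (8, 128)) for s in range(200)]
print(any(bucketed))  # False: one fixed shape, one compile

det = StormDetector(window=100, max_compiles=3)
ragged = [det.observe(s, (8, 100 + s)) for s in range(10)]
print(ragged.index(True))  # 3: the fourth new shape trips the alarm
```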
References¶
- Chris Fregly, AI Systems Performance Engineering (O'Reilly), Chapter 14: "PyTorch Compiler, OpenAI Triton, and XLA Backends", section "PyTorch XLA Backend".
- PyTorch on XLA Devices (lazy tensors, xm.mark_step, device API): https://docs.pytorch.org/xla/release/r2.7/learn/pytorch-on-xla-devices.html
- PyTorch/XLA Overview (LazyTensor tracing, HLO, compilation cache): https://docs.pytorch.org/xla/master/learn/xla-overview.html
- PyTorch/XLA API guide (torch_xla.sync, torch_xla.device): https://docs.pytorch.org/xla/master/learn/api-guide.html
- TorchDynamo (torch.compile) integration in PyTorch/XLA (backend="openxla"): https://docs.pytorch.org/xla/master/torch_compile.html
Related: torch.compile: Graph Capture, Backends, and Recompiles · PyTorch CUDA Caching Allocator Tuning · Activation Checkpointing and Memory Offloading · PyTorch Attention APIs: SDPA and FlexAttention · PyTorch Performance Regression Testing in CI · Kernel Fusion · OpenAI Triton: Authoring GPU Kernels in Python · CUDA Graphs: Capture, Replay, and Launch Overhead · Frameworks · Distributed Training Platform · Inference Serving and Optimization · Glossary
1. Fregly, Ch. 14, "TorchInductor Backend Code Generation": XLA "is an alternate backend targeting non-CUDA hardware ... mainly used for Google Cloud TPUs through the OpenXLA project ... Meta's ... MTIA, uses XLA ... AWS's custom Inferentia and Trainium ... run PyTorch with an XLA compiler in its open source AWS Neuron SDK. NVIDIA GPUs typically use TorchInductor, however."
2. Fregly, Ch. 14, "PyTorch XLA Backend": "While TorchInductor is the PyTorch default backend for GPUs and CPUs, PyTorch XLA is a separate backend-compilation option that targets Google Cloud TPUs and other accelerators ... by mapping PyTorch operations into XLA's graph IR and executing them using the target hardware's runtime."
3. Fregly, Ch. 14, "PyTorch XLA Backend": "To activate the XLA backend, you can use torch.compile(..., backend="openxla"), which activates PyTorch XLA based on OpenXLA. This backend string is supported by the PyTorch XLA project and activates OpenXLA-based compilation."
4. Fregly, Ch. 14, "PyTorch XLA Backend": "XLA captures the graph of computations. However, it compiles whole programs ahead of time because XLA is designed to generate static graphs. XLA is optimized for static shapes or bounded dynamic shapes."
5. Fregly, Ch. 14, "PyTorch XLA Backend": "The XLA compiler will cache each compiled graph per unique input shape and signature ... The major difference is that XLA will not incrementally compile mid-run. The graph is built statically ahead of time. And if a new shape is encountered, it will trigger a new whole-graph compilation, which is very expensive and impacts latency-sensitive workloads like inference."
6. Fregly, Ch. 14, "PyTorch XLA Backend": "New shapes trigger whole-program recompilation, which is expensive for latency-sensitive inference ... You may need to pad or use fixed-size buckets for your inputs. This is because XLA will recompile for new shapes rather than handling them symbolically."
7. Fregly, Ch. 14, "PyTorch XLA Backend": "if you're running on a hardware device not currently supported by the TorchInductor backend, you can potentially use the XLA backend if the device supports XLA. Many of the same principles apply, such as minimizing graph breaks. You can also use some distributed strategies with XLA, such as data and model parallelism. While XLA isn't commonly used with NVIDIA GPUs, it's a powerful backend for non-NVIDIA hardware."
8. PyTorch on XLA Devices: "XLA tensors are lazy. They record operations in a graph until the results are needed ... Calling xm.mark_step() at the end of each training iteration causes XLA to execute its current graph and update the model's parameters." https://docs.pytorch.org/xla/release/r2.7/learn/pytorch-on-xla-devices.html
9. PyTorch/XLA Overview: LazyTensor tracing records ops into an IR graph until a barrier; PyTorch/XLA converts the IR to HLO (High Level Operations) for the XLA compiler; compilation is cached and reused for previously seen graphs; "the same model code should be run for every step and compilation should only happen once for every graph." https://docs.pytorch.org/xla/master/learn/xla-overview.html
10. PyTorch/XLA API guide: torch_xla.sync(wait=False, reset_scope=True) "Launches all pending graph operations." Unlike a synchronous op it does not block further tracing while the device executes, but blocks access to the materializing tensor values. https://docs.pytorch.org/xla/master/learn/api-guide.html
11. PyTorch on XLA Devices, Compilation Caching: "if the HLO changes between executions, a recompilation will still occur." https://docs.pytorch.org/xla/release/r2.7/learn/pytorch-on-xla-devices.html
12. TorchDynamo (torch.compile) integration in PyTorch/XLA: "Support for PyTorch/XLA and Dynamo currently exists by adding the backend='openxla' argument to torch.compile." Dynamo provides a TorchFX graph that PyTorch/XLA compiles via existing Lazy Tensor technology; training support is experimental. https://docs.pytorch.org/xla/master/torch_compile.html
13. TorchDynamo integration: with torch.compile, "PyTorch/XLA only traces the resnet18 model once during the init time and executes the compiled binary every time dynamo_resnet18 is invoked, instead of tracing the model every time." https://docs.pytorch.org/xla/master/torch_compile.html