
PyTorch/XLA and the XLA compiler

Scope: lazy-tensor tracing into an XLA/HLO graph, the XLA fusion/layout compiler, mark_step/torch_xla.sync() graph boundaries, per-shape-signature compilation caching and recompiles from dynamic shapes, and when XLA beats eager or Inductor.

What it is

PyTorch/XLA is an alternate compiler backend that maps PyTorch operations into XLA's graph IR and executes them on the target hardware's runtime. Whereas TorchInductor is the default backend for NVIDIA/AMD GPUs and CPUs, XLA targets Google Cloud TPUs and other accelerators that adopt XLA IR. The book is explicit that "XLA isn't commonly used with NVIDIA GPUs" but is "a powerful backend for non-NVIDIA hardware": Meta's MTIA inference ASIC uses XLA, and AWS Inferentia/Trainium run PyTorch through an XLA compiler in the AWS Neuron SDK.[1][2]

The execution model is lazy tensors. An XLA tensor records operations into an IR graph rather than executing them eagerly; "they record operations in a graph until the results are needed."[8] This deferred trace is the unit the compiler optimizes: because the whole sequence of ops is visible before anything runs, separate ops can be fused into a single optimized operation. The recorded IR is lowered to HLO (High Level Operations), XLA's machine-readable computation format, which the XLA compiler turns into hardware-specific code with its own fusion and layout-assignment passes.[9]
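As a toy illustration of why a deferred trace enables fusion (plain Python, not the PyTorch/XLA API; `run_eager` and `run_fused` are hypothetical names): eager execution makes one pass over the data per op and materializes every intermediate, while the fused version composes the recorded ops and makes a single pass. Real XLA fusion works on HLO and saves device memory traffic rather than Python loop iterations, so treat the pass count only as a stand-in for separate kernels.

```python
from typing import Callable, List, Tuple

Op = Callable[[float], float]

def run_eager(xs: List[float], ops: List[Op]) -> Tuple[List[float], int]:
    """One pass over the data per op, materializing every intermediate list."""
    passes = 0
    for op in ops:
        xs = [op(x) for x in xs]
        passes += 1
    return xs, passes

def run_fused(xs: List[float], ops: List[Op]) -> Tuple[List[float], int]:
    """Compose the recorded ops and apply them in a single pass."""
    def fused_body(x: float) -> float:
        for op in ops:
            x = op(x)
        return x
    return [fused_body(x) for x in xs], 1

# A recorded trace: add 1, scale by 2, then ReLU.
trace: List[Op] = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: max(x, 0.0)]
data = [-3.0, -1.0, 0.5, 2.0]

eager_out, eager_passes = run_eager(data, trace)
fused_out, fused_passes = run_fused(data, trace)
print(eager_out, eager_passes, fused_passes)  # [0.0, 0.0, 3.0, 6.0] 3 1
```

The results are identical; only the number of passes over the data changes, which is the point of fusing a recorded graph.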

Two activation paths exist:

  • Lazy-tensor mode (move tensors/model .to(device)), where the framework auto-constructs graphs and you mark execution boundaries explicitly.[8]
  • torch.compile(..., backend="openxla"), which routes the TorchDynamo FX graph through the same Lazy Tensor machinery so the model is traced once at init instead of every step.[3][12][13]

Why use it

The book frames the trade-off against Inductor directly. XLA "compiles whole programs ahead of time because XLA is designed to generate static graphs," and is "optimized for static shapes or bounded dynamic shapes."[4] That ahead-of-time, whole-program posture is what lets XLA do aggressive cross-op fusion and layout selection for static graphs, but it is also why its dynamic-shape behavior is more brittle than Inductor's symbolic-shape approach.

The key operational difference: "XLA will not incrementally compile mid-run." The graph is built statically ahead of time; a new shape triggers a fresh whole-graph compilation, "which is very expensive and impacts latency-sensitive workloads like inference."[5] TorchInductor, by contrast, generalizes shapes after the first recompile and can emit one kernel covering a range of sequence lengths (see torch.compile: Graph Capture, Backends, and Recompiles). So with XLA, recompiles are coarser and more costly, and shape discipline is not optional.
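A toy count makes the difference concrete. Both functions are hypothetical models, not real APIs: `xla_compiles` compiles once per distinct static length, while `inductor_compiles` models the "generalize after the first recompile" behavior described above (in practice this is Dynamo's automatic dynamic shapes, and real guards can still recompile in cases this sketch ignores, such as sizes specialized to 0 or 1).

```python
from typing import Iterable, Optional

def xla_compiles(lengths: Iterable[int]) -> int:
    """Toy XLA model: one whole-graph compile per distinct static shape."""
    seen = set()
    compiles = 0
    for n in lengths:
        if n not in seen:
            seen.add(n)
            compiles += 1
    return compiles

def inductor_compiles(lengths: Iterable[int]) -> int:
    """Toy model: specialize on the first length; the first mismatch
    triggers one recompile with the dimension marked dynamic, and that
    symbolic kernel covers every later length."""
    static_len: Optional[int] = None
    dynamic = False
    compiles = 0
    for n in lengths:
        if dynamic:
            continue                  # symbolic kernel already covers n
        if static_len is None:
            static_len = n
            compiles += 1             # first, shape-specialized compile
        elif n != static_len:
            dynamic = True
            compiles += 1             # one recompile with a symbolic dim
    return compiles

lengths = [7, 7, 13, 7, 100, 13, 42, 7]
print(xla_compiles(lengths), inductor_compiles(lengths))  # 4 2
```

Under this model the XLA count grows with the number of distinct lengths, while the Inductor-style count stays at two no matter how many lengths appear.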

When to use it (and when not)

Use XLA when:

  • The device is not supported by TorchInductor but does support XLA IR (TPU, MTIA, Inferentia/Trainium via Neuron). The book: "if you're running on a hardware device not currently supported by the TorchInductor backend, you can potentially use the XLA backend if the device supports XLA."[7]
  • Shapes are static or tightly bounded (fixed-size inference, padded/bucketed training) so the per-shape compile cost amortizes over many steps.

Do not reach for XLA when:

  • You are on NVIDIA GPUs running normal dynamic-shape PyTorch; TorchInductor is the default and is the better-trodden path. XLA "isn't commonly used with NVIDIA GPUs."[7]
  • Your workload is latency-sensitive with frequently varying shapes. New shapes force whole-program recompilation; the book recommends you "pad or use fixed-size buckets for your inputs" because XLA recompiles for new shapes rather than handling them symbolically.[6]
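Padding to fixed-size buckets can be sketched as a small helper. `pick_bucket`, `pad_to_bucket`, and the bucket sizes are hypothetical choices for illustration, and the sketch assumes padded positions are masked out downstream so they do not change the result.

```python
import bisect
from typing import List, Sequence, Tuple

def pick_bucket(n: int, buckets: Sequence[int]) -> int:
    """Smallest bucket >= n; `buckets` must be sorted ascending."""
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"length {n} exceeds largest bucket {buckets[-1]}")
    return buckets[i]

def pad_to_bucket(seq: List[int], buckets: Sequence[int],
                  pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Right-pad `seq` to its bucket and return a 1/0 validity mask."""
    size = pick_bucket(len(seq), buckets)
    pad = size - len(seq)
    return seq + [pad_id] * pad, [1] * len(seq) + [0] * pad

BUCKETS = [32, 64, 128, 256]          # only four shapes XLA ever sees
padded, mask = pad_to_bucket([5, 6, 7], BUCKETS)
print(len(padded), sum(mask))         # 32 3
```

Every input now maps to one of four shapes, so the compile count is bounded by the bucket list rather than by the spread of real lengths.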

The same compiler hygiene carries over from the Inductor path (minimize graph breaks), and some distributed strategies, such as data and model parallelism, remain available.[7]

Architecture

Both entry paths converge on one pipeline. Eager PyTorch ops, or the FX graph that TorchDynamo hands over under torch.compile(..., backend="openxla"), feed the same LazyTensor machinery.[3][12] The LazyTensor layer records operations into an IR graph and runs nothing until a barrier.[8] At the barrier (torch_xla.sync() or xm.mark_step()), the trace is lowered to HLO, which the XLA compiler optimizes with fusion and layout-assignment passes into hardware code.[9] The compiled executable is cached per unique input shape and signature: a matching signature is a cache hit, and a new shape triggers a fresh whole-graph compilation.[5][11] The runtime then executes on the target accelerator (TPU, MTIA, Trainium, Inferentia).[1][2]

```mermaid
flowchart LR
  EAGER["Eager PyTorch ops"] --> LT["LazyTensor IR graph (deferred trace)"]
  DYN["torch.compile backend=openxla (Dynamo FX graph)"] --> LT
  LT -->|"torch_xla.sync() / mark_step()"| HLO["HLO graph"]
  HLO --> XLA["XLA compiler (fusion + layout)"]
  XLA --> CACHE["Per-shape-signature executable cache"]
  CACHE --> DEV["TPU / MTIA / Trainium / Inferentia runtime"]
  DEV -.->|"new shape: whole-graph recompile"| XLA
```

The dashed edge is the cost center: anything that changes the HLO (a new shape, a different signature, non-identical model code) loops back through the compiler instead of hitting the cache.[5][11]
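The rule "anything that changes the HLO misses the cache" can be modeled by keying executables on both the graph and the input signature. This is a toy sketch: `cache_key`, `ExecutableCache`, and the tuple-of-ops "trace" are hypothetical stand-ins, not how PyTorch/XLA actually hashes graphs.

```python
import hashlib
from typing import Dict, Tuple

Trace = Tuple[Tuple[str, ...], ...]     # toy "HLO": ordered (op, attrs...) records
Shapes = Tuple[Tuple[int, ...], ...]

def cache_key(trace: Trace, shapes: Shapes, dtypes: Tuple[str, ...]) -> str:
    """Toy stand-in for 'same HLO graph + same input signature'."""
    return hashlib.sha256(repr((trace, shapes, dtypes)).encode()).hexdigest()

class ExecutableCache:
    def __init__(self) -> None:
        self.entries: Dict[str, str] = {}
        self.compiles = 0

    def get(self, trace: Trace, shapes: Shapes, dtypes: Tuple[str, ...]) -> str:
        key = cache_key(trace, shapes, dtypes)
        if key not in self.entries:
            self.compiles += 1                    # new graph or new signature
            self.entries[key] = f"executable-{self.compiles}"
        return self.entries[key]

cache = ExecutableCache()
sig_shapes, sig_dtypes = ((8, 16),), ("f32",)
base: Trace = (("matmul",), ("relu",))

first = cache.get(base, sig_shapes, sig_dtypes)
again = cache.get(base, sig_shapes, sig_dtypes)                        # cache hit
drift = cache.get(base + (("dropout", "p=0.1"),), sig_shapes, sig_dtypes)  # same shape, new HLO
print(first == again, cache.compiles)  # True 2
```

The third call has the same input shape and dtype as the first two but still compiles, which is the HLO-drift case.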

How to use it

Lazy-tensor mode and the graph boundary

In lazy mode you place tensors on the XLA device and mark where the accumulated graph should be cut and executed. The newer barrier API is torch_xla.sync(); xm.mark_step() is the historical equivalent.[8][10] The barrier is what converts the traced IR to HLO and dispatches it.

Reference template (requires a torch_xla install and an XLA device such as a TPU or Trainium instance):

```python
# Reference template: needs torch_xla + an XLA device (TPU/Trainium/etc.).
# MyModel, loader, loss_fn, and the SGD optimizer are placeholders.
import torch
import torch_xla
import torch_xla.core.xla_model as xm  # legacy APIs: xm.xla_device(), xm.mark_step()

device = torch_xla.device()          # older releases: xm.xla_device()
model = MyModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # build after .to(device)

for data, target in loader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    out = model(data)                # recorded into the lazy graph, not run yet
    loss = loss_fn(out, target)
    loss.backward()
    optimizer.step()
    torch_xla.sync()                 # cut graph -> lower to HLO -> compile/execute
```

torch_xla.sync() does not block further tracing while the device runs the graph; it only blocks access to the materialized tensor values.[10] Avoid pulling tensor values to the host mid-step (printing a loss, .item(), a host-side if): each materialization forces an implicit barrier and a graph cut. Access values less often, or place a deliberate sync() where the cut belongs.[10]
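A toy counter makes the barrier arithmetic concrete. `ToyLazyRuntime` and `train` are hypothetical stand-ins, not PyTorch/XLA APIs: ops accumulate until a barrier, and a host read forces one only if work is still pending. In real PyTorch/XLA a read after the sync still waits for the device and copies the value to the host; in this model it simply does not split the step's graph.

```python
class ToyLazyRuntime:
    """Counts graph executions; recorded ops accumulate until a barrier."""

    def __init__(self) -> None:
        self.pending_ops = 0
        self.executions = 0

    def record(self, n_ops: int) -> None:
        self.pending_ops += n_ops          # traced, not executed

    def sync(self) -> None:
        if self.pending_ops:               # cut + execute only if there is work
            self.executions += 1
            self.pending_ops = 0

    def item(self) -> None:
        self.sync()                        # host read needs materialized values

def train(steps: int, log_mode: str, log_every: int = 1) -> int:
    """log_mode: 'none', 'mid_step' (read loss before sync), or 'after_sync'."""
    rt = ToyLazyRuntime()
    for step in range(steps):
        log_now = log_mode != "none" and step % log_every == 0
        rt.record(3)                       # forward + loss
        if log_now and log_mode == "mid_step":
            rt.item()                      # implicit barrier: forward runs alone
        rt.record(5)                       # backward + optimizer.step
        rt.sync()                          # deliberate end-of-step barrier
        if log_now and log_mode == "after_sync":
            rt.item()                      # nothing pending: no extra execution
    return rt.executions

print(train(100, "none"), train(100, "mid_step"),
      train(100, "mid_step", log_every=10), train(100, "after_sync"))
# 100 200 110 100
```

Reading the loss mid-step doubles graph executions and splits forward from backward; logging every tenth step costs 10% extra; reading after the deliberate sync costs no extra graph.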

The numpy model below (runnable with plain Python and NumPy; no torch or XLA device needed) reproduces and asserts that core behavior: operations are recorded and nothing runs until a barrier, a host-value access forces an implicit barrier, and the deferred result is identical to eager evaluation.

```python
import numpy as np

# Reference model of PyTorch/XLA lazy-tensor execution (numpy-only).
# Ops are RECORDED into a deferred graph; nothing runs until a barrier
# (torch_xla.sync / xm.mark_step) or a host-value access (.item()) cuts the
# graph and executes it. Validates deferral, barrier-triggered execution,
# host-access barriers, and exact equivalence to eager evaluation.

class LazyGraph:
    def __init__(self):
        self.barriers = 0      # times the graph was cut + executed
        self.evals = 0         # primitive op evaluations actually performed

class LazyTensor:
    def __init__(self, g, compute, inputs):
        self.g = g
        self._compute = compute       # fn(*materialized_inputs) -> ndarray
        self._inputs = inputs         # parent LazyTensors
        self._value = None            # None until materialized

    def __add__(self, o): return LazyTensor(self.g, np.add, [self, o])
    def __mul__(self, o): return LazyTensor(self.g, np.multiply, [self, o])

    def _run(self):
        if self._value is None:
            ins = [t._run() for t in self._inputs]
            self.g.evals += 1
            self._value = self._compute(*ins)
        return self._value

    def item(self):
        # Pulling a value to the host forces an implicit barrier + graph cut.
        return float(np.asarray(sync(self.g, self)).reshape(-1)[0])

def leaf(g, arr):
    t = LazyTensor(g, None, [])
    t._value = np.asarray(arr, dtype=np.float64)   # inputs already materialized
    return t

def sync(g, tensor):
    g.barriers += 1            # cut graph -> lower to HLO -> execute
    return tensor._run()

# --- happy path: deferral + equivalence to eager ---
g = LazyGraph()
a = leaf(g, [1.0, 2.0, 3.0])
b = leaf(g, [10.0, 20.0, 30.0])
expr = (a + b) * b                       # recorded, NOT executed
assert g.barriers == 0                   # nothing ran yet
assert expr._value is None               # not materialized
assert g.evals == 0                      # no primitive op evaluated

out = sync(g, expr)                       # barrier: execute the traced graph
ref = (np.array([1., 2., 3.]) + np.array([10., 20., 30.])) * np.array([10., 20., 30.])
assert np.array_equal(out, ref)          # deferred result == eager reference
assert g.barriers == 1
assert g.evals == 2                       # exactly one add + one mul, run once

# --- laziness holds for a long chain: zero barriers before sync ---
g2 = LazyGraph()
x = leaf(g2, np.ones(4))
acc = x
for _ in range(50):
    acc = acc + x
assert g2.barriers == 0 and acc._value is None    # 50 ops recorded, none run
assert np.array_equal(sync(g2, acc), np.ones(4) * 51.0)

# --- adversarial: host access (.item()) forces an implicit barrier ---
g3 = LazyGraph()
p = leaf(g3, [2.0])
q = leaf(g3, [5.0])
loss = p * q + p                          # recorded
assert g3.barriers == 0                   # not yet cut
val = loss.item()                         # host access -> implicit graph cut
assert val == 12.0                        # 2*5 + 2
assert g3.barriers == 1                   # the .item() caused a barrier

# --- edge: random-array equivalence vs eager reference (fuzz) ---
rng = np.random.default_rng(0)
for _ in range(100):
    g4 = LazyGraph()
    u, v = rng.standard_normal(8), rng.standard_normal(8)
    lu, lv = leaf(g4, u), leaf(g4, v)
    assert np.allclose(sync(g4, (lu + lv) * lu), (u + v) * u)

print("BLOCK1 OK: deferral, barrier count, host-access barrier, eager-equivalence")
```

How to integrate with it

torch.compile with the OpenXLA backend

backend="openxla" lets TorchDynamo hand its FX graph to PyTorch/XLA, which compiles it once and replays the compiled binary; the model is traced once at init rather than every iteration.[12][13] This is the recommended path for inference.

Reference template (requires torch_xla, torchvision, and an XLA device):

```python
# Reference template: needs torch_xla + torchvision + an XLA device.
# `loader` is a placeholder for your evaluation data loader.
import torch
import torchvision
import torch_xla

device = torch_xla.device()          # older releases: xm.xla_device()
resnet18 = torchvision.models.resnet18().to(device)
resnet18.eval()

# Backend string is supported by the PyTorch/XLA project and activates
# OpenXLA-based compilation.
dynamo_resnet18 = torch.compile(resnet18, backend="openxla")

with torch.no_grad():
    for data, _ in loader:
        data = data.to(device)             # inputs must live on the XLA device
        output = dynamo_resnet18(data)     # first call compiles; later calls replay
```

Training under backend="openxla" exists, but the upstream docs mark it experimental.[12]
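Where the two paths differ is how often the model's Python code is re-traced. The toy counters below (hypothetical `ToyXLA`, `lazy_mode`, `openxla_mode`; not the real Dynamo guard mechanism) model lazy mode as tracing every step while compiling once per graph, and the openxla path as tracing and compiling once per shape, then replaying.

```python
from typing import Dict, List, Set, Tuple

Graph = Tuple[str, int]

class ToyXLA:
    """Counts Python-side traces and whole-graph compiles (toy model)."""

    def __init__(self) -> None:
        self.traces = 0
        self.compiles = 0
        self._compiled: Set[Graph] = set()

    def trace(self, length: int) -> Graph:
        self.traces += 1                  # re-running the model's Python code
        return ("graph", length)

    def compile_or_hit(self, graph: Graph) -> None:
        if graph not in self._compiled:
            self.compiles += 1
            self._compiled.add(graph)

def lazy_mode(rt: ToyXLA, lengths: List[int]) -> None:
    for n in lengths:
        rt.compile_or_hit(rt.trace(n))    # traced every step, compiled once per graph

def openxla_mode(rt: ToyXLA, lengths: List[int]) -> None:
    replay: Dict[int, Graph] = {}
    for n in lengths:
        if n not in replay:               # first call with this shape: trace + compile
            graph = rt.trace(n)
            rt.compile_or_hit(graph)
            replay[n] = graph
        # later calls with a known shape replay the binary without tracing

requests = [128] * 1000
lazy, dyn = ToyXLA(), ToyXLA()
lazy_mode(lazy, requests)
openxla_mode(dyn, requests)
print(lazy.traces, lazy.compiles, dyn.traces, dyn.compiles)  # 1000 1 1 1
```

Both paths compile once here; the saving from the compiled path is the 999 avoided traces, which is tracing overhead taken off the hot path.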

Trace-once-then-replay only pays off if inputs keep hitting the same cached executable. The numpy model below (runnable with plain Python and NumPy; no torch needed) reproduces and asserts the cache rule: one compile per distinct shape/signature, dynamic shapes force recompiles, and padding to a fixed bucket collapses everything to a single compile.

```python
import numpy as np

# Reference model of XLA's per-shape-signature compilation cache (numpy-only).
# A new (shape, dtype) signature triggers an expensive whole-graph "compile";
# a seen signature reuses the cached executable. Validates: compile once per
# signature, dynamic shapes cause recompiles, padding/bucketing bounds them.

class XLACache:
    def __init__(self):
        self.cache = {}
        self.compiles = 0          # whole-graph compilations (expensive)

    def _signature(self, x):
        return (tuple(x.shape), x.dtype.str)   # shape + dtype

    def run(self, x):
        sig = self._signature(x)
        if sig not in self.cache:
            self.compiles += 1                 # recompile for a new signature
            self.cache[sig] = lambda a: (a * 2.0 + 1.0).sum()   # fused executable
        return self.cache[sig](x)

def reference(x):                              # slow eager reference
    return float((x * 2.0 + 1.0).sum())

# --- correctness: cached executable matches eager reference ---
c = XLACache()
x = np.arange(16, dtype=np.float64)
assert c.run(x) == reference(x)

# --- dynamic shapes: one compile per DISTINCT shape signature ---
c = XLACache()
lengths = [7, 7, 13, 7, 100, 13, 7]            # ragged sequence lengths
for n in lengths:
    seq = np.ones(n, dtype=np.float64)
    assert c.run(seq) == reference(seq)        # result always correct
assert c.compiles == len(set(lengths)) == 3    # 3 distinct shapes -> 3 compiles

# --- padding/bucketing: fixed shape collapses to a SINGLE compile ---
c = XLACache()
BUCKET = 128
for n in lengths:
    padded = np.zeros(BUCKET, dtype=np.float64)
    padded[:n] = 1.0
    c.run(padded)
assert c.compiles == 1                          # every input shares one signature

# --- adversarial edge: same shape, different dtype -> still a recompile ---
c = XLACache()
c.run(np.ones(8, dtype=np.float64))
c.run(np.ones(8, dtype=np.float64))             # cache hit, no recompile
assert c.compiles == 1
c.run(np.ones(8, dtype=np.float32))             # signature differs by dtype
assert c.compiles == 2                          # dtype is part of the signature

# --- adversarial: worst case = every input a unique shape (recompile storm) ---
c = XLACache()
for n in range(1, 51):                           # 50 distinct lengths
    c.run(np.ones(n, dtype=np.float64))
assert c.compiles == 50                          # pathological: compile every step

print("BLOCK2 OK: signature cache, dynamic recompiles, bucketing=1, dtype edge, storm")
```

How to run it in production

Inference is where the trade-offs bite hardest, so treat the compile cache as part of the serving contract:

  • Prefer the compiled path. Serve through torch.compile(..., backend="openxla") so the model is traced once at init and each request replays the compiled binary instead of re-tracing.[12][13]
  • Pin input shapes. New shapes force whole-program recompilation, "which is very expensive and impacts latency-sensitive workloads like inference."[5] Pad or use fixed-size buckets so requests share a small set of signatures.[6]
  • Warm up the cache before taking traffic. Performance improves after a few warm-up steps once the executable is cached, so run every expected shape once at startup and keep the hot path as cache hits.[5][9]
  • Keep the serving code path identical every step so the same HLO is produced and no recompile fires; "if the HLO changes between executions, a recompilation will still occur."[9][11]
  • Keep values on device on the hot path. Host-side access (.item(), logging a per-request scalar, host-side branching) forces implicit barriers and graph cuts, so minimize it inside the serving loop.[10]
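The warm-up and shape-pinning rules combine into a simple startup routine. `ToyServer` is a hypothetical sketch: a `sum` over a zero-padded list stands in for the model, and a dict stands in for the executable cache. Warming every bucket at startup means no request after that point pays for a compile.

```python
import bisect
from typing import Callable, Dict, List, Sequence

class ToyServer:
    """Toy serving loop with a per-bucket executable cache (hypothetical)."""

    def __init__(self, buckets: Sequence[int]) -> None:
        self.buckets = sorted(buckets)
        self.cache: Dict[int, Callable[[List[float]], float]] = {}
        self.compiles = 0

    def _bucket(self, n: int) -> int:
        i = bisect.bisect_left(self.buckets, n)
        if i == len(self.buckets):
            raise ValueError(f"request length {n} exceeds largest bucket")
        return self.buckets[i]

    def _executable(self, size: int) -> Callable[[List[float]], float]:
        if size not in self.cache:
            self.compiles += 1                   # expensive whole-graph compile
            self.cache[size] = lambda xs: sum(xs)  # stand-in for the model
        return self.cache[size]

    def warm_up(self) -> None:
        for size in self.buckets:                # one dummy request per bucket
            self._executable(size)([0.0] * size)

    def serve(self, request: List[float]) -> float:
        size = self._bucket(len(request))
        padded = request + [0.0] * (size - len(request))  # zeros are neutral for sum
        return self._executable(size)(padded)

server = ToyServer([16, 64, 256])
server.warm_up()
compiles_at_startup = server.compiles
results = [server.serve([1.0] * n) for n in (3, 17, 200, 64, 1)]
print(compiles_at_startup, server.compiles, results)
```

The compile count after serving equals the count at startup, so every production request is a cache hit; a request longer than the largest bucket is rejected instead of triggering an unplanned compile.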

How to maintain it

Compilation cache and recompiles

XLA caches each compiled executable per unique input shape and signature, so performance improves after a few warm-up steps, similar to Inductor.[5][9] At the barrier, if the resulting HLO matches a previously seen graph, the cached compiled program is reused; "if the HLO changes between executions, a recompilation will still occur."[11]

Maintenance rules that keep the cache hot:

  • Run identical model code every step so the same graph is produced and compilation happens once per graph.[9]
  • Stabilize shapes. New shapes recompile the whole program. Pad to fixed sizes or bucket inputs to bound the number of distinct shape signatures.[6]
  • Watch for accidental dynamism: host-dependent control flow, varying sequence lengths, or per-iteration Python constants change the HLO and force recompiles, the same failure class flagged for the Inductor path in torch.compile: Graph Capture, Backends, and Recompiles.
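The per-iteration-constant case can be sketched under one stated assumption: as the bullet says, a Python constant that changes every step ends up baked into the graph. `build_graph` and `count_compiles` are hypothetical; the tuple returned by `build_graph` stands in for the HLO, and "lr as input" models feeding the value as a runtime input (for example a 0-d device tensor) instead of a literal.

```python
from typing import List, Tuple

def build_graph(lr: float, lr_as_input: bool) -> Tuple[Tuple[str, str], ...]:
    """Toy trace of an SGD update; returns a hashable stand-in for the HLO."""
    if lr_as_input:
        return (("sub_scaled", "lr=<runtime input f32[]>"),)  # value not in graph
    return (("sub_scaled", f"lr=const({lr!r})"),)            # value baked in

def count_compiles(lrs: List[float], lr_as_input: bool) -> int:
    """One compile per distinct graph seen across the schedule."""
    return len({build_graph(lr, lr_as_input) for lr in lrs})

schedule = [1e-3 * 0.9 ** step for step in range(20)]   # decaying learning rate
print(count_compiles(schedule, lr_as_input=False),
      count_compiles(schedule, lr_as_input=True))        # 20 1
```

Under this model a decaying learning rate baked in as a constant produces a new graph every step, while the same schedule fed as an input keeps one graph.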

The book and the official PyTorch/XLA docs agree on the backend string (backend="openxla", which the book describes as activating PyTorch XLA "based on OpenXLA") and on the trace-once behavior.[3][12] The official docs describe the Dynamo-to-FX-to-LazyTensor bridge in more detail than the book; where the two differ, prefer the official docs.

How to scale it

Scaling XLA is mostly about spreading work across accelerators while keeping the number of distinct compilations bounded:

  • Distributed strategies. You can use distributed strategies with XLA, "such as data and model parallelism," and the same hygiene rule applies at scale: minimize graph breaks.[7]
  • Bucketing bounds compile count. The cache model above shows that total compilations equal the number of distinct shape signatures. A fixed set of buckets turns a per-step recompile risk (up to one compile per step) into a fixed, small compile budget (at most one compile per bucket), which is what makes throughput predictable across a fleet.[6]
  • Non-NVIDIA hardware reach. XLA is the practical route to scaling on accelerators outside the TorchInductor path: Google Cloud TPUs, Meta's MTIA, and AWS Trainium/Inferentia via the Neuron SDK.[1][2]
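The compile budget trades against padding waste, and the trade can be computed directly. `bucket_stats` is a hypothetical helper, the workload (every length from 1 to 512, once each) is an assumption, and "padded work" counts padded positions only, ignoring that real cost can grow faster than linearly with length.

```python
import bisect
from typing import List, Sequence, Tuple

def bucket_stats(lengths: Sequence[int], buckets: Sequence[int]) -> Tuple[int, float]:
    """Return (distinct executables compiled, fraction of work spent on padding)."""
    ordered = sorted(buckets)
    used = set()
    real = padded = 0
    for n in lengths:
        b = ordered[bisect.bisect_left(ordered, n)]   # smallest bucket >= n
        used.add(b)
        real += n
        padded += b
    return len(used), 1.0 - real / padded

lengths: List[int] = list(range(1, 513))   # assumption: each length 1..512 once
exact = sorted(set(lengths))               # no bucketing: one shape per length
pow2 = [2 ** k for k in range(10)]         # 1, 2, 4, ..., 512
single = [512]                             # pad everything to the maximum

for name, b in (("exact", exact), ("pow2", pow2), ("single", single)):
    compiles, waste = bucket_stats(lengths, b)
    print(f"{name:>6}: {compiles:4d} compiles, {waste:.1%} padded work")
```

Exact shapes waste nothing but compile 512 times; power-of-two buckets cut that to 10 compiles for roughly a quarter of the work spent on padding; a single bucket compiles once and spends about half on padding.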

Failure modes

| Failure mode | Trigger | Symptom | Mitigation |
|---|---|---|---|
| Recompile storm | Input shapes vary every step (ragged sequence lengths) | Repeated whole-graph compiles, latency spikes | Pad or bucket to fixed sizes [6] |
| Mid-run compile stall | A never-before-seen shape appears at run time; "XLA will not incrementally compile mid-run" | A single step blocks on a full whole-graph compile | Warm up all expected shapes; keep shapes static or bounded [5] |
| Host-access barrier thrash | .item(), printing a loss, or a host-side if mid-step | Implicit barriers cut the graph; extra syncs; throughput drop | Reduce value access; place one deliberate sync() per step [10] |
| HLO drift | Model code differs between steps | Recompilation even at the same input shape | Run identical model code every step [9][11] |
| Wrong backend for the device | Reaching for XLA on NVIDIA GPUs | Off the well-trodden path, weaker tooling | Use TorchInductor on NVIDIA, XLA for non-NVIDIA hardware [7] |
| Unvalidated training path | Relying on backend="openxla" for training | Experimental, potentially unstable behavior | Treat training support as experimental; validate before production [12] |
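The first two rows can be caught with a simple monitor. `RecompileStormDetector`, its thresholds, and the per-step compile counts fed to it are hypothetical; in a real deployment the counts would come from your own instrumentation (for example PyTorch/XLA's debug metrics, which this sketch does not call). The idea is to ignore compiles during warm-up and alert when too many land in a sliding window afterwards.

```python
from collections import deque
from typing import Deque

class RecompileStormDetector:
    """Flags a storm when more than `max_compiles` whole-graph compiles land
    in the last `window` steps, ignoring the first `warmup_steps`."""

    def __init__(self, window: int, max_compiles: int, warmup_steps: int) -> None:
        self.max_compiles = max_compiles
        self.warmup_steps = warmup_steps
        self.recent: Deque[int] = deque(maxlen=window)  # compiles per step
        self.step = 0

    def observe(self, compiles_this_step: int) -> bool:
        self.step += 1
        if self.step <= self.warmup_steps:
            return False                    # warm-up compiles are expected
        self.recent.append(compiles_this_step)
        return sum(self.recent) > self.max_compiles

detector = RecompileStormDetector(window=10, max_compiles=3, warmup_steps=5)
steady = [detector.observe(c) for c in [1] * 5 + [0] * 20]  # warm-up, then all hits
storm = [detector.observe(1) for _ in range(6)]            # a new shape every step
print(any(steady), storm)  # False [False, False, False, True, True, True]
```

A steady, warmed-up cache never trips the alarm; once new shapes start compiling every step, the fourth compile inside the window raises the flag.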

References

  • Chris Fregly, AI Systems Performance Engineering (O'Reilly), Chapter 14: "PyTorch Compiler, OpenAI Triton, and XLA Backends", section "PyTorch XLA Backend".
  • PyTorch on XLA Devices (lazy tensors, xm.mark_step, device API): https://docs.pytorch.org/xla/release/r2.7/learn/pytorch-on-xla-devices.html
  • PyTorch/XLA Overview (LazyTensor tracing, HLO, compilation cache): https://docs.pytorch.org/xla/master/learn/xla-overview.html
  • PyTorch/XLA API guide (torch_xla.sync, torch_xla.device): https://docs.pytorch.org/xla/master/learn/api-guide.html
  • TorchDynamo (torch.compile) integration in PyTorch/XLA (backend="openxla"): https://docs.pytorch.org/xla/master/torch_compile.html

Related: torch.compile: Graph Capture, Backends, and Recompiles · PyTorch CUDA Caching Allocator Tuning · Activation Checkpointing and Memory Offloading · PyTorch Attention APIs: SDPA and FlexAttention · PyTorch Performance Regression Testing in CI · Kernel Fusion · OpenAI Triton: Authoring GPU Kernels in Python · CUDA Graphs: Capture, Replay, and Launch Overhead · Frameworks · Distributed Training Platform · Inference Serving and Optimization · Glossary


  1. Fregly, Ch. 14, "TorchInductor Backend Code Generation": XLA "is an alternate backend targeting non-CUDA hardware ... mainly used for Google Cloud TPUs through the OpenXLA project ... Meta's ... MTIA, uses XLA ... AWS's custom Inferentia and Trainium ... run PyTorch with an XLA compiler in its open source AWS Neuron SDK. NVIDIA GPUs typically use TorchInductor, however." 

  2. Fregly, Ch. 14, "PyTorch XLA Backend": "While TorchInductor is the PyTorch default backend for GPUs and CPUs, PyTorch XLA is a separate backend-compilation option that targets Google Cloud TPUs and other accelerators ... by mapping PyTorch operations into XLA's graph IR and executing them using the target hardware's runtime." 

  3. Fregly, Ch. 14, "PyTorch XLA Backend": "To activate the XLA backend, you can use torch.compile(..., backend=\"openxla\"), which activates PyTorch XLA based on OpenXLA. This backend string is supported by the PyTorch XLA project and activates OpenXLA-based compilation." 

  4. Fregly, Ch. 14, "PyTorch XLA Backend": "XLA captures the graph of computations. However, it compiles whole programs ahead of time because XLA is designed to generate static graphs. XLA is optimized for static shapes or bounded dynamic shapes." 

  5. Fregly, Ch. 14, "PyTorch XLA Backend": "The XLA compiler will cache each compiled graph per unique input shape and signature ... The major difference is that XLA will not incrementally compile mid-run. The graph is built statically ahead of time. And if a new shape is encountered, it will trigger a new whole-graph compilation, which is very expensive and impacts latency-sensitive workloads like inference." 

  6. Fregly, Ch. 14, "PyTorch XLA Backend": "New shapes trigger whole-program recompilation, which is expensive for latency-sensitive inference ... You may need to pad or use fixed-size buckets for your inputs. This is because XLA will recompile for new shapes rather than handling them symbolically." 

  7. Fregly, Ch. 14, "PyTorch XLA Backend": "if you're running on a hardware device not currently supported by the TorchInductor backend, you can potentially use the XLA backend if the device supports XLA. Many of the same principles apply, such as minimizing graph breaks. You can also use some distributed strategies with XLA, such as data and model parallelism. While XLA isn't commonly used with NVIDIA GPUs, it's a powerful backend for non-NVIDIA hardware." 

  8. PyTorch on XLA Devices: "XLA tensors are lazy. They record operations in a graph until the results are needed ... Calling xm.mark_step() at the end of each training iteration causes XLA to execute its current graph and update the model's parameters." https://docs.pytorch.org/xla/release/r2.7/learn/pytorch-on-xla-devices.html 

  9. PyTorch/XLA Overview: LazyTensor tracing records ops into an IR graph until a barrier; PyTorch/XLA converts the IR to HLO (High Level Operations) for the XLA compiler; compilation is cached and reused for previously seen graphs; "the same model code should be run for every step and compilation should only happen once for every graph." https://docs.pytorch.org/xla/master/learn/xla-overview.html 

  10. PyTorch/XLA API guide: torch_xla.sync(wait=False, reset_scope=True) "Launches all pending graph operations." Unlike a synchronous op it does not block further tracing while the device executes, but blocks access to the materializing tensor values. https://docs.pytorch.org/xla/master/learn/api-guide.html 

  11. PyTorch on XLA Devices, Compilation Caching: "if the HLO changes between executions, a recompilation will still occur." https://docs.pytorch.org/xla/release/r2.7/learn/pytorch-on-xla-devices.html 

  12. TorchDynamo (torch.compile) integration in PyTorch/XLA: "Support for PyTorch/XLA and Dynamo currently exists by adding the backend='openxla' argument to torch.compile." Dynamo provides a TorchFX graph that PyTorch/XLA compiles via existing Lazy Tensor technology; training support is experimental. https://docs.pytorch.org/xla/master/torch_compile.html 

  13. TorchDynamo integration: with torch.compile, "PyTorch/XLA only traces the resnet18 model once during the init time and executes the compiled binary every time dynamo_resnet18 is invoked, instead of tracing the model every time." https://docs.pytorch.org/xla/master/torch_compile.html