Rollout redundancy in RL: prompt deduplication and cascade attention
Scope: why group-sampling RL post-training wastes 80-95% of its tokens on repeated prompts, and the two mathematically equivalent optimizations that recover it: prompt deduplication in the training forward/backward, and cascade attention over shared-prefix KV during rollout decode. This page covers what the redundancy is, why and when the two fixes pay off, how to apply them (with a runnable check that cascade attention equals full attention), and how they compose with the RL loop. It is the efficiency companion to async and disaggregated RL, the algorithm in GRPO, and the cache mechanics in KV cache management.
Generalizable techniques with real implementations (FlashInfer, vLLM, Arctic Platform); pin versions and verify numerical equivalence before production. The Python example is executed and asserted (numpy); framework flags are reference templates.
flowchart TB
P["Prompt (shared)"] -->|"sample G responses"| GRP["Group of G sequences: prompt + response_i"]
GRP --> TRAIN["Training forward/backward"]
GRP --> DEC["Rollout decode"]
TRAIN --> DEDUP["Prompt dedup: pack unique prompt once, reconstruct per-response logprobs"]
DEC --> CASC["Cascade attention: read shared-prefix KV once per group"]
DEDUP --> EQT["Equivalent gradients, less compute + memory"]
CASC --> EQI["Equivalent output, less KV bandwidth"]
What it is
Policy-gradient RL (PPO, GRPO) samples G responses per prompt to estimate an advantage, then trains on all G prompt+response pairs. The prompt is identical across the whole group, so a group-sampled batch is dominated by repeated prompt tokens. A 10K-token prompt with 10 responses of 1K tokens each is 110K total tokens but only 20K unique, so 82% of the tokens are duplicated prompt; in typical RL training, 80-95% of batch tokens are redundant prompt, and with attention's O(n^2) cost that shared prompt dominates the bill for long-context RL.1
That redundancy appears in two engines of the RL loop, each with its own exact fix:
- Training (the DeepSpeed/FSDP forward+backward that computes logprobs and gradients) recomputes the same prompt G times. Fix: prompt deduplication, which packs each unique prompt once and reconstructs per-response logprobs.
- Rollout decode (the vLLM/SGLang sampler) re-reads the shared prefix's KV cache once per request every step. Fix: cascade attention, which reads the shared prefix KV once per group.
Both are exact: they change the memory-access and compute pattern, not the math.
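The token accounting in the worked example above can be reproduced with a few lines of arithmetic. This is a minimal sketch; the function name `redundancy` and its return keys are illustrative, not part of any library.

```python
def redundancy(prompt_len, response_lens):
    """Token accounting for one group: one shared prompt, one entry per sampled response."""
    g = len(response_lens)
    total = g * prompt_len + sum(response_lens)   # naive batch: prompt repeated for every response
    unique = prompt_len + sum(response_lens)      # deduplicated: prompt counted once
    return {
        "total": total,
        "unique": unique,
        "redundant_frac": (total - unique) / total,
    }

# The worked example: a 10K-token prompt with 10 responses of 1K tokens each.
stats = redundancy(10_000, [1_000] * 10)
print(stats["total"], stats["unique"], round(stats["redundant_frac"], 3))
```

Because the redundant fraction is (G - 1) * prompt_len / total, it climbs toward 1 as either the prompt length or the group size grows.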
Why use it
- Long-context RL is prompt-bound. As prompts grow (repo-scale code, long documents, agent histories) the response is a small tail; nearly all compute and KV traffic is the shared prefix, so deduplicating it is the single largest lever for tokens/sec and memory.
- The two fixes are complementary. Training dedup removes recomputation of the shared prompt in the forward and backward passes; cascade attention removes redundant reads of the shared-prefix KV at decode, saving memory bandwidth. A full RL system wants both.
- Orthogonal to on-/off-policy design. These are pure efficiency transforms; they do not touch staleness, importance sampling, or reward, so they stack on top of async/disaggregated RL without changing its stability story.
- Exact, so free of quality cost at inference. FlashInfer reports up to ~31x faster shared-prefix batch decode versus baseline PagedAttention for a long (32K) prompt at large batch (256), the regime long-context group-sampling RL lives in.2
When to use it (and when not)
- Use it for long shared prompts and large group size G: long-context RL such as Text-to-SQL, long-context QA, and repo/agent tasks. Savings scale with prompt_len * G.
- Marginal for short prompts, small G, or response-dominated sequences: little shared prefix to amortize.
- Requires exact sharing. Dedup and cascade grouping key on an identical prefix; per-sample system-prompt variance, non-deterministic templating, or interleaved tool outputs fragment the groups and erase the benefit, so normalize prompt construction first.
- Verify equivalence before trusting it. A dedup or cascade path that is not exactly equal corrupts logprobs and gradients subtly; gate it with the equivalence check below before a real run.
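To judge whether a workload sits in the "use it" or the "marginal" regime, a back-of-envelope count of KV tokens read per decode step can help. The sketch below is an idealized model, assuming the baseline reads the shared prefix once per request and cascade reads it once per group; it counts tokens read, not measured speedup, and the function names are hypothetical.

```python
def kv_tokens_read_per_step(prefix_len, suffix_lens, cascade):
    """Idealized KV tokens touched by one decode step for a single group."""
    g = len(suffix_lens)
    prefix_reads = prefix_len if cascade else g * prefix_len   # once per group vs once per request
    return prefix_reads + sum(suffix_lens)                     # suffixes are always per request

def read_reduction(prefix_len, suffix_lens):
    """Ratio of baseline reads to cascade reads (1.0 means no benefit)."""
    base = kv_tokens_read_per_step(prefix_len, suffix_lens, cascade=False)
    casc = kv_tokens_read_per_step(prefix_len, suffix_lens, cascade=True)
    return base / casc

long_ctx = read_reduction(10_000, [500] * 10)   # long shared prompt, G = 10
short_ctx = read_reduction(200, [500] * 2)      # short prompt, G = 2, response-dominated
print(f"long-context: {long_ctx:.2f}x, short-context: {short_ctx:.2f}x")
```

Real kernels add effects this model ignores (on-chip reuse, block granularity, launch overhead), so treat the ratio as an upper-level sanity check rather than a prediction.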
Architecture
Both transforms live inside the compute engines that async/disaggregated RL already defines (the trainer, the reference/log-prob engine, and the sampler), whether colocated or split into actor/learner pools. Prompt dedup patches the training and log-prob forward passes; cascade attention patches the sampler's decode attention and composes with prefix caching and continuous batching. Neither changes weight-sync, staleness, or reward.
How to use it
The inference-side fix is the more general one, and its correctness is checkable. Prefix caching already stores the shared prompt's KV once instead of G times, but standard PagedAttention still reads that copy once per request at decode. Cascade attention instead computes one partial attention over the shared prefix and one per-request partial over the unique suffix, then merges them with the online-softmax (log-sum-exp) reduction, reading the shared prefix once per group. This runnable check confirms the merge equals full dense attention to floating-point precision (asserted within 1e-10; observed max error ~1e-16):
# cascade_check.py: validated that cascade (prefix partial + suffix partial, merged) == full attention.
import numpy as np

d, Lp, Ls = 64, 40, 8
rng = np.random.default_rng(0)
q = rng.standard_normal(d) / np.sqrt(d)
Kp, Vp = rng.standard_normal((Lp, d)), rng.standard_normal((Lp, d))  # shared prefix KV (read once per group)
Ks, Vs = rng.standard_normal((Ls, d)), rng.standard_normal((Ls, d))  # per-request unique suffix KV

def partial(K, V):  # flash-style block state: (max score, denom, unnorm acc)
    s = K @ q / np.sqrt(d)
    m = s.max()
    e = np.exp(s - m)
    return m, e.sum(), e @ V

def merge(a, b):  # online-softmax combine of two partials
    (m1, l1, o1), (m2, l2, o2) = a, b
    m = max(m1, m2)
    l = l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)
    o = o1 * np.exp(m1 - m) + o2 * np.exp(m2 - m)
    return o / l

def full(K, V):  # reference: dense softmax attention
    s = K @ q / np.sqrt(d)
    p = np.exp(s - s.max())
    p /= p.sum()
    return p @ V

cascade = merge(partial(Kp, Vp), partial(Ks, Vs))
ref = full(np.vstack([Kp, Ks]), np.vstack([Vp, Vs]))
assert np.allclose(cascade, ref, atol=1e-10)  # exact, not approximate
print(f"max abs error: {np.abs(cascade - ref).max():.1e}")  # -> ~1e-16
In production you do not write this kernel: enable it through the engine. vLLM V1 uses cascade attention automatically when a batch shares a single prefix; keep prefix caching on (--enable-prefix-caching, default in V1) so the shared prompt is stored once, and cascade attention removes the redundant re-reads on top.
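The check above uses a single query. A group-level variant makes the "read once per group" idea more concrete: all G decode queries attend to the shared prefix in one batched pass, and each request then merges in its own ragged suffix. This is a numpy sketch of the data flow only, with illustrative names (`state`, `dense`); it does not model paged KV blocks or how any particular engine schedules the kernel.

```python
import numpy as np

d, Lp, G = 64, 40, 4
rng = np.random.default_rng(1)
Q = rng.standard_normal((G, d))                                       # one decode query per request
Kp, Vp = rng.standard_normal((Lp, d)), rng.standard_normal((Lp, d))   # shared prefix KV
suffix_lens = [3, 8, 1, 5]                                            # ragged per-request suffixes
Ks = [rng.standard_normal((n, d)) for n in suffix_lens]
Vs = [rng.standard_normal((n, d)) for n in suffix_lens]

def state(S, V):
    """Partial softmax state per query row: (row max, denominator, unnormalized accumulator)."""
    m = S.max(axis=-1, keepdims=True)
    e = np.exp(S - m)
    return m, e.sum(axis=-1, keepdims=True), e @ V

# Shared prefix: a single batched pass, so Kp/Vp are touched once for the whole group.
mp, lp, op = state(Q @ Kp.T / np.sqrt(d), Vp)

out = np.empty((G, d))
for i in range(G):
    # Per-request suffix partial, then the online-softmax merge with the prefix partial.
    ms, ls, os_ = state(Q[i:i + 1] @ Ks[i].T / np.sqrt(d), Vs[i])
    m = np.maximum(mp[i:i + 1], ms)
    a, b = np.exp(mp[i:i + 1] - m), np.exp(ms - m)
    out[i] = ((op[i:i + 1] * a + os_ * b) / (lp[i:i + 1] * a + ls * b))[0]

def dense(q, K, V):
    """Reference: dense softmax attention for one query over the full KV."""
    s = K @ q / np.sqrt(d)
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

ref = np.stack([dense(Q[i], np.vstack([Kp, Ks[i]]), np.vstack([Vp, Vs[i]])) for i in range(G)])
max_err = np.abs(out - ref).max()
assert np.allclose(out, ref, atol=1e-10)
print(f"group max abs error: {max_err:.1e}")
```

The prefix pass is one matrix product over `Kp` and `Vp` regardless of G, which is the access pattern a cascade kernel exploits; only the suffix work scales with the number of requests.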
How to develop with it
- Inference side. vLLM V1 triggers cascade attention only when all requests in a batch share one common prefix (a single tree); it does not yet handle different prefixes per batch. Real RL rollout batches hold many prompts at once (a forest of prefix trees), which is what Arctic Platform's Forest Cascade Attention generalizes by discovering multiple shared-prefix groups per decode batch.3 Check whether your engine's cascade path fires for multi-prompt batches, or the benefit is capped at single-tree.
- Training side. Prompt deduplication detects sequences that share a prompt, packs each unique prompt once, runs the model once, and scatters per-response logprobs/entropy back into the original order. Splitting attention into a prompt-to-prompt pass plus a response-to-full pass avoids recomputing the shared prompt attention G times. Arctic Platform's ZoRRo Train delivers this by patching the DeepSpeed training and log-prob engines, toggled per run, across dense and MoE models.1 It generalizes intra-batch deduplication (for example BatchLLM's shared-prefix grouping for offline serving) to the RL training pass.4
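The prompt-to-prompt plus response-to-full split described for the training side can be illustrated on a toy model. The sketch below uses a single causal self-attention layer with random weights in numpy; it is not ZoRRo Train's implementation, and every name in it (`causal_attn`, `response_to_full`, the weight matrices) is an assumption made for illustration. It checks that running the prompt once and attending each response over the cached prompt K/V reproduces the naive per-sequence outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
d, Lp = 16, 12
resp_lens = [4, 2, 5]                                      # G = 3 responses of different lengths
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
P = rng.standard_normal((Lp, d))                           # shared prompt embeddings
R = [rng.standard_normal((n, d)) for n in resp_lens]       # per-response embeddings

def softmax(S):
    e = np.exp(S - S.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_attn(X):
    """Naive path: causal self-attention over one full sequence."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    S = Q @ K.T / np.sqrt(d)
    S[np.triu_indices(len(X), k=1)] = -np.inf              # mask future positions
    return softmax(S) @ V

# Naive: the prompt is recomputed inside every prompt+response sequence.
naive = [causal_attn(np.vstack([P, Ri])) for Ri in R]

# Dedup: project and attend the prompt once, then reuse its K/V for each response.
Kp, Vp = P @ Wk, P @ Wv
prompt_out = causal_attn(P)                                # prompt-to-prompt pass, run once

def response_to_full(Ri):
    """Response queries attend to the cached prompt K/V plus their own causal K/V."""
    n = len(Ri)
    K, V = np.vstack([Kp, Ri @ Wk]), np.vstack([Vp, Ri @ Wv])
    S = (Ri @ Wq) @ K.T / np.sqrt(d)                       # shape (n, Lp + n)
    r, c = np.triu_indices(n, k=1)
    S[r, Lp + c] = -np.inf                                 # causal mask only inside the response block
    return softmax(S) @ V

dedup = [response_to_full(Ri) for Ri in R]

max_err_resp = max(np.abs(dedup[i] - naive[i][Lp:]).max() for i in range(len(R)))
max_err_prompt = max(np.abs(prompt_out - naive[i][:Lp]).max() for i in range(len(R)))
assert max_err_resp < 1e-10 and max_err_prompt < 1e-10
print(f"response rows: {max_err_resp:.1e}, prompt rows: {max_err_prompt:.1e}")
```

The equivalence relies on causality: prompt positions never attend to response tokens, so their outputs are identical in every sequence of the group. A real trainer must preserve this through every layer and through the backward pass, which is why a gradient-level comparison against the naive path is still needed.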
How to maintain it
Gate both transforms with the equivalence checks (the cascade check above for the sampler, and a gradient/logprob match against the naive path for training dedup) after any kernel or engine upgrade, since a changed attention path can break exact equivalence silently. Monitor the realized dedup ratio and cascade hit rate: prompts that look shared but differ by a token (a timestamp, a per-sample id, a tokenizer edge case) fragment into singleton groups and quietly erase the benefit, and stock single-tree cascade simply will not fire on a multi-prompt batch. Watch that measured savings track the theoretical prompt_len * G, and investigate when they do not.
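One way to monitor the realized dedup ratio is to group a batch by exact prompt token IDs and compare unique tokens to total tokens. The sketch below is a hypothetical monitoring helper (`dedup_report` is not an API of any engine named here) and it keys on whole prompts, which is stricter than the block-level prefix matching a KV cache performs.

```python
from collections import defaultdict

def dedup_report(batch):
    """batch: list of (prompt_token_ids, response_token_ids) pairs."""
    groups = defaultdict(list)
    for idx, (prompt, _response) in enumerate(batch):
        groups[tuple(prompt)].append(idx)                  # exact match on token IDs
    total = sum(len(p) + len(r) for p, r in batch)
    unique = sum(len(key) for key in groups) + sum(len(r) for _, r in batch)
    return {
        "groups": len(groups),
        "singletons": sum(1 for members in groups.values() if len(members) == 1),
        "dedup_ratio": 1 - unique / total,                 # fraction of tokens dedup can skip
    }

prompt = list(range(1000))
# Clean batch: 8 responses that share one identical prompt.
clean = dedup_report([(prompt, [7] * 100) for _ in range(8)])
# Drifted batch: a per-sample ID token prepended to each prompt breaks the sharing.
drifted = dedup_report([([5000 + i] + prompt, [7] * 100) for i in range(8)])
print(clean)
print(drifted)
```

A rising singleton count alongside a configured group size G greater than 1 is the signal that prompt construction has drifted, even though the dedup flag is still on.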
How to run it in production
On the inference side, keep prefix caching on so the shared prompt is stored once, and run an engine whose cascade path fires for your batch shape (forest-aware for the many-prompt batches real RL produces, not just single-tree). On the training side, enable prompt dedup in the trainer and log-prob engines so the packed forward/backward is exact. The win compounds with batch size and prefix length and needs no RL-recipe re-tuning, because both transforms are mathematically equivalent to the naive path; neither adds a second model or changes the weight-sync story of async RL, whose sibling systems optimization is delta weight sync.
Failure modes
- Silent numerical drift. A dedup or cascade path that is not exactly equivalent corrupts logprobs and gradients; gate it with the equivalence check above before trusting a run.1
- Single-tree assumption. Applying stock single-prefix cascade attention to a multi-prompt RL batch either falls back to no speedup or mixes groups; use a forest-aware implementation for real batches.
- Broken grouping. Prompts that look shared but differ by a token (timestamp, per-sample ID, tokenizer edge case) split into singletons; monitor the realized dedup ratio, not just the config flag.
- Double-counting savings. Prefix caching and prompt dedup both target the shared prompt; measure end-to-end, because their savings overlap rather than add.
- Interaction with chunked prefill / speculative decode. Cascade attention changes the decode kernel path; validate it composes with the sampler's other features rather than assuming it.
References
- Arctic Platform (Snowflake AI Research), Arctic RL, ZoRRo Train, ZoRRo Inference: https://github.com/Snowflake-AI-Research/Arctic-Platform
- ArcticInference, Forest Cascade Attention (vLLM plugin): https://github.com/snowflakedb/ArcticInference
- FlashInfer, Cascade Inference: Memory-Bandwidth-Efficient Shared Prefix Batch Decoding: https://flashinfer.ai/2024/02/02/cascade-inference.html
- vLLM, cascade-attention optimization and single-tree/forest limitation (issue #14729): https://github.com/vllm-project/vllm/issues/14729
- vLLM, shared-prefix batching discussion (BatchLLM RFC, issue #12080): https://github.com/vllm-project/vllm/issues/12080
Related: Async and disaggregated RL · Delta weight sync · GRPO · PPO · KV cache management · Continuous batching internals · DeepSpeed ZeRO · RL libraries · verl · Glossary
1. Arctic Platform, arctic_platform/rl/zorro_train/README.md: worked example (10K prompt x 10 responses = 110K total / 20K unique, 82% redundant), 80-95% redundant prompt tokens in typical RL, three-phase dedup (detect+pack, split prompt-to-prompt / response-to-full attention, transparent reconstruction of per-response logprobs/entropy in original order), gradients equivalent to baseline within numerical precision.
2. FlashInfer, Cascade Inference (2024-02-02): decouple attention over a shared prefix from per-request suffixes, keep the shared KV in on-chip memory, merge partial states; up to ~31x speedup versus baseline PagedAttention for a 32,768-token shared prompt at batch size 256.
3. Arctic Platform README, ZoRRo Inference: Forest Cascade Attention discovers groups of requests sharing a KV-cache prefix per decode batch, runs one grouped pass over shared prefix blocks plus per-request suffix passes, and reduces with weighting; reads each shared prefix block once per group instead of once per request, mathematically equivalent to standard attention.
4. vLLM issue #12080 (BatchLLM RFC): intra-batch grouping of requests by shared prefix to compute the common prefix once in offline/batch scenarios, the offline-serving analogue of RL prompt deduplication.