Recipe: memory-efficient GRPO post-training
Scope: a standalone recipe for running an RL post-training (GRPO) workload on a single node. It applies two quantization techniques to the RL loop: a 4-bit QLoRA base for the policy and reference, and FP8 rollouts for generation. Together they let the policy, reference, and inference engine fit where full-precision GRPO would need more GPUs. It is the applied, copy-paste counterpart to the algorithm in GRPO, the reward in reward design, and the systems split in async RL systems; quantization itself is covered in quantization for inference.
These are reference templates built on real APIs; pin versions and validate before production use. The commands here have not been tested on hardware.
```mermaid
flowchart LR
    BASE["NF4 4-bit base (QLoRA)"] --> POL["Policy = base + LoRA adapter (BF16)"]
    BASE --> REF["Reference = base, adapter disabled"]
    POL --> ROLL["FP8 rollouts (vLLM)"]
    ROLL --> RWD["Verifiable reward function"]
    RWD --> ADV["Group-relative advantage"]
    ADV --> UPD["LoRA update (clipped + KL)"]
    REF -.->|"KL penalty"| UPD
    UPD -->|"adapter weight sync"| ROLL
```
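The "Group-relative advantage" node in the loop above normalizes each completion's reward against the other completions sampled for the same prompt. What follows is a minimal plain-Python sketch that assumes the DeepSeekMath-style normalization A_i = (r_i - mean(r)) / (std(r) + eps). The population standard deviation and eps = 1e-4 are illustrative choices, since libraries differ on both. The helper names are hypothetical, and `frac_zero_std` only mirrors the idea behind the `frac_reward_zero_std` metric TRL reports.

```python
from statistics import fmean, pstdev


def group_advantages(rewards, group_size, eps=1e-4):
    """Group-relative advantages: each reward minus its group mean, over the group std.

    `rewards` is flat, with the `group_size` completions of each prompt stored contiguously.
    A group whose rewards are all equal (std 0) yields all-zero advantages: no learning signal.
    """
    if len(rewards) % group_size:
        raise ValueError("len(rewards) must be a multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean, std = fmean(group), pstdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


def frac_zero_std(rewards, group_size):
    """Fraction of groups whose rewards are identical (and therefore carry no signal)."""
    groups = [rewards[i:i + group_size] for i in range(0, len(rewards), group_size)]
    return sum(pstdev(g) == 0 for g in groups) / len(groups)


# Two prompts, G = 4: the first group mixes right and wrong answers, the second is all correct.
rewards = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
print(group_advantages(rewards, group_size=4))
print(frac_zero_std(rewards, group_size=4))
```

The all-correct group contributes nothing to the update. This is why a high zero-std fraction means the prompts are too easy (or too hard) for the current policy.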
What it is
GRPO is rollout-dominated and memory-heavy. It holds a policy and a frozen reference (for the KL penalty), scores completions with a reward signal, and generates a group of completions per prompt each step (GRPO). In full precision that means two model copies plus an inference engine, which typically takes several GPUs for an 8B model. This recipe shrinks all three memory consumers (policy, reference, and engine) with quantization, applying the book's inference-efficiency techniques inside the RL loop (quantization for inference):
- QLoRA base: the base weights are stored once in 4-bit NF4. The policy is that base plus a small BF16 LoRA adapter, and the reference is the same base with the adapter disabled. One 4-bit copy serves as both, so no second full-precision model sits in memory.
- FP8 rollouts: the generation engine (vLLM) serves the policy in FP8 on Hopper/Blackwell, cutting the dominant rollout cost while the LoRA update stays in BF16 for stability.
The result is an RL post-training workload that fits on hardware where the same model's full-precision GRPO run would not.
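A back-of-envelope count of weight and optimizer memory shows why the two techniques matter. This is a minimal sketch built on several assumptions:

- 16 bytes per parameter for full BF16 training with FP32 master weights and Adam.
- Roughly 4.127 bits per parameter for NF4 with double quantization, the figure from the QLoRA paper.
- 1 byte per parameter for FP8 weights.
- FP32 adapter weights and optimizer state, a conservative choice.
- Qwen3-8B-like attention shapes.

Activations, KV cache, and framework overhead are ignored, so real totals are higher. Each data-parallel trainer rank also holds its own NF4 copy. The helper names are hypothetical.

```python
# Back-of-envelope weight + optimizer memory (GB) for one model replica.
# Assumptions: 16 bytes/param for full BF16 training (FP32 master weights + Adam),
# ~4.127 bits/param for NF4 with double quantization, 1 byte/param for FP8 weights.
# Activations, KV cache, and framework overhead are ignored.
GB = 1e9


def lora_params(rank, layers, shapes):
    """LoRA adds rank * (d_in + d_out) trainable parameters per adapted (d_in -> d_out) projection."""
    return layers * sum(rank * (d_in + d_out) for d_in, d_out in shapes)


def memory_gb(n_params, n_lora):
    """Return (full-precision GRPO, this recipe) as {component: GB} dicts."""
    full = {
        "policy (BF16 weights/grads + FP32 master + Adam)": 16 * n_params / GB,
        "reference (BF16)": 2 * n_params / GB,
        "rollout engine (BF16 weights)": 2 * n_params / GB,
    }
    recipe = {
        "shared NF4 base (policy + reference)": 4.127 / 8 * n_params / GB,
        "LoRA adapter (FP32 weights/grads + Adam)": 16 * n_lora / GB,
        "rollout engine (FP8 weights)": 1 * n_params / GB,
    }
    return full, recipe


# Assumed Qwen3-8B-like shapes: hidden 4096, 32 query heads and 8 KV heads of dim 128, 36 layers.
ATTN = [(4096, 4096), (4096, 1024), (4096, 1024), (4096, 4096)]  # q, k, v, o projections
n_lora = lora_params(rank=16, layers=36, shapes=ATTN)

for label, table in zip(("full-precision GRPO", "this recipe"), memory_gb(8e9, n_lora)):
    print(f"{label}: {sum(table.values()):.1f} GB")
    for part, gb in table.items():
        print(f"  {part}: {gb:.2f} GB")
```

Under these assumptions an 8B model drops from roughly 160 GB of weights and optimizer state to about 12 GB. The rank-16 attention-only adapter is about 15M parameters.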
Why it matters
- One node instead of several. The 4-bit base plus adapters plus an FP8 engine is the difference between a single 8x GPU node and a multi-node RL job for mid-size models.
- Rollout cost is the bottleneck. Generation dominates GRPO wall-clock; FP8 rollouts attack exactly that (async RL systems).
- Stability is preserved. Training math (the LoRA gradient, the KL, the advantage) stays in BF16; only storage and generation are low-precision, the same precision split tensor cores already use (tensor cores and mixed precision).
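The KL term in that training math is usually estimated per sampled token from log-probabilities rather than over the whole vocabulary. This minimal sketch uses the estimator from DeepSeekMath's GRPO objective, KL ≈ exp(d) - d - 1 with d = log π_ref - log π_θ. The estimate is non-negative and exactly zero where the policy and reference agree. The inputs are illustrative per-token log-probs, the function names are hypothetical, and `beta` plays the role of the GRPOConfig `beta`.

```python
import math


def per_token_kl(policy_logps, ref_logps):
    """Per-token estimate exp(d) - d - 1, with d = log pi_ref(token) - log pi_theta(token).

    Both inputs are log-probabilities of the same sampled tokens; the estimate is >= 0
    and exactly 0 wherever the policy and the reference agree.
    """
    if len(policy_logps) != len(ref_logps):
        raise ValueError("log-prob sequences must have equal length")
    return [math.exp(lr - lp) - (lr - lp) - 1.0 for lp, lr in zip(policy_logps, ref_logps)]


def kl_penalty(policy_logps, ref_logps, beta):
    """Mean per-token KL estimate, scaled by beta before it enters the loss."""
    kl = per_token_kl(policy_logps, ref_logps)
    return beta * sum(kl) / len(kl)


print(per_token_kl([-1.0, -0.5], [-1.5, -0.2]))
print(kl_penalty([-1.0, -0.5], [-1.5, -0.2], beta=0.01))
```

With the adapter disabled the reference is the same NF4 base, so at step 0 the policy and reference log-probs coincide and the penalty starts at zero.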
When it is needed (and when not)
- Use it when you want to run GRPO on a verifiable-reward task (maths, code, format) but the full-precision policy + reference + rollout engine will not fit your node.
- Prefer full-precision GRPO (GRPO) when you have the GPUs: a full BF16 policy avoids the small quality cost of a 4-bit base and the train-inference gap of FP8 rollouts.
- Prefer DPO or rejection sampling when you have offline preference data or want SFT-level simplicity; neither runs an online rollout loop at all.
- Avoid if the reward is noisy or hackable; a bad reward trains a bad model fast regardless of precision (reward design).
How: implement, integrate, maintain
1. The verifiable reward
Keep it deterministic and cheap; the reward is the objective, so validate it on held-out cases first (reward design).
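```python
# reward.py: exact-match reward on the last \boxed{...} final answer
import re

BOXED = re.compile(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}")  # one level of nesting, e.g. \boxed{\frac{1}{2}}

def final(s):
    found = BOXED.findall(s)
    return found[-1].strip() if found else None

def reward_correct(completions, solution, **kwargs):
    # TRL passes dataset columns as keyword args; `solution` must match the answer column's name.
    # Conversational datasets yield completions like [{"role": "assistant", "content": "..."}].
    texts = [c[0]["content"] if isinstance(c, list) else c for c in completions]
    return [1.0 if final(t) == str(a).strip() else 0.0 for t, a in zip(texts, solution)]
```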
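Validating the reward on held-out and adversarial cases can be as simple as a table of hand-labelled completions. This is a minimal, self-contained sketch; `strict_reward`, `validate`, and `CASES` are hypothetical names. The stricter rule shown is one illustrative design choice, not the recipe's reward: it requires exactly one boxed answer, compared after whitespace stripping, so a policy cannot hedge by boxing several candidates.

```python
import re

# Matches \boxed{...} with at most one level of nested braces.
BOXED = re.compile(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}")


def strict_reward(completion, answer):
    """1.0 only if the completion contains exactly one boxed answer equal to `answer`."""
    found = BOXED.findall(completion)
    return 1.0 if len(found) == 1 and found[0].strip() == answer.strip() else 0.0


# (completion, reference answer, expected reward)
CASES = [
    (r"so the result is \boxed{42}", "42", 1.0),
    (r"\boxed{ 42 }", "42", 1.0),                    # surrounding whitespace is tolerated
    (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}", 1.0),    # one level of nested braces
    (r"the answer is 42", "42", 0.0),                 # no boxed answer
    (r"\boxed{41} or maybe \boxed{42}", "42", 0.0),   # hedging across several answers
    (r"\boxed{420}", "42", 0.0),                      # near miss
]


def validate(reward_fn, cases):
    """Return (completion, expected, got) for every case where reward_fn disagrees with the label."""
    failures = []
    for completion, answer, want in cases:
        got = reward_fn(completion, answer)
        if got != want:
            failures.append((completion, want, got))
    return failures


failures = validate(strict_reward, CASES)
print(f"{len(CASES) - len(failures)}/{len(CASES)} cases pass")
```

Any disagreement is cheaper to find here than after a few hundred GRPO steps have learned to exploit it.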
2. The GRPO config (QLoRA base + FP8 rollouts)
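TRL's GRPOTrainer takes a peft_config for the LoRA adapter. The 4-bit BitsAndBytesConfig for the NF4 base goes in GRPOConfig's model_init_kwargs, which TRL applies when the model is passed as a Hub ID string. Setting use_vllm=True runs rollouts on vLLM, which serves the policy in FP8 when the vLLM engine itself is launched with FP8 quantization. Pin versions: GRPO and quantization APIs move fast.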
```python
# train_grpo_qlora.py: TRL >= 0.18 (vllm_mode), peft, bitsandbytes
import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from trl import GRPOConfig, GRPOTrainer

from reward import reward_correct

nf4 = BitsAndBytesConfig(  # 4-bit NF4 base (QLoRA)
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

args = GRPOConfig(
    output_dir="grpo-qlora",
    num_generations=8,                    # group size G
    beta=0.01,                            # KL toward the (adapter-disabled) reference
    use_vllm=True, vllm_mode="server",    # FP8 rollout engine, separate from the trainer
    max_completion_length=2048,
    bf16=True,                            # the LoRA update stays BF16
    gradient_checkpointing=True,
    # Applied because `model` is a string: each trainer rank loads its own NF4 base.
    model_init_kwargs={
        "quantization_config": nf4,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": PartialState().local_process_index},
    },
)

trainer = GRPOTrainer(
    model="Qwen/Qwen3-8B",
    reward_funcs=[reward_correct],
    args=args,
    peft_config=LoraConfig(  # policy = NF4 base + BF16 LoRA; reference = adapter disabled
        r=16, lora_alpha=32, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    ),
    train_dataset=load_dataset("trl-lib/DeepMath-103K", split="train"),
)
trainer.train()
```
Serve the rollout engine in FP8 alongside the trainer (quantization for inference, async RL systems):
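```bash
# FP8 rollout server on its own GPUs; the trainer reaches it via vllm_mode="server".
# trl vllm-serve defaults to tensor- and data-parallel size 1 (one GPU); set those flags to use all four,
# and confirm the quantization flag your version accepts with `trl vllm-serve --help`.
CUDA_VISIBLE_DEVICES=4,5,6,7 trl vllm-serve --model Qwen/Qwen3-8B --quantization fp8
# Trainer on the remaining GPUs, in a second shell:
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch train_grpo_qlora.py
```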
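With group size G = num_generations, each generation batch covers batch / G distinct prompts, and TRL rejects configurations where that division is not exact. The exact check has moved between TRL versions, so treat this as a hedged sketch of the arithmetic, not TRL's code. `prompts_per_batch` is a hypothetical helper, and its `steps_per_generation` argument is modeled on the newer GRPOConfig option of the same name. Four trainer processes match the launch above; the per-device batch of 8 is an assumed value.

```python
def prompts_per_batch(num_processes, per_device_batch, num_generations, steps_per_generation=1):
    """Distinct prompts per generation batch when each prompt gets `num_generations` completions.

    `steps_per_generation` models newer TRL releases that generate one batch for several
    optimizer steps; earlier releases behave like steps_per_generation=1 (an assumption).
    """
    completions = num_processes * per_device_batch * steps_per_generation
    if completions % num_generations:
        raise ValueError(
            f"{completions} completions per generation batch is not divisible by G={num_generations}"
        )
    return completions // num_generations


# Four trainer processes (GPUs 0-3 above), an assumed per-device batch of 8, G = 8.
print(prompts_per_batch(4, 8, 8))  # 32 completions -> 4 prompts
```

Raising G sharpens each group's advantage estimate but, at a fixed batch, spends the rollout budget on fewer distinct prompts.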
For larger, multi-node runs, verl drives the same GRPO with LoRA and vLLM rollouts on Ray (algorithm.adv_estimator=grpo, actor_rollout_ref.model.lora_rank=16, actor_rollout_ref.rollout.name=vllm).
3. Apply and verify
Watch the three RL health metrics, not just loss (observability):
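```bash
# Reward should trend up; entropy should not collapse to ~0; KL should stay bounded.
# Assumes the launch output was captured, e.g. `accelerate launch train_grpo_qlora.py 2>&1 | tee train.log`;
# the Trainer prints each logging step as a Python dict ({'reward': ..., 'kl': ...}).
# Key names (reward, reward_std, frac_reward_zero_std, kl, entropy) vary across TRL versions.
tail -f train.log | grep -oE "'(reward|kl|entropy|frac_reward_zero_std)': [0-9.e+-]+"
```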
- `reward` climbing and `frac_reward_zero_std` low (most groups have signal) means the loop is learning.
- `kl` flat and bounded means the 4-bit policy is not drifting off the reference; if it climbs, raise `beta`.
- Rollout throughput (completions per second) is the FP8 win; compare it against a BF16 rollout baseline.
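The same checks can be scripted over the Trainer's log history, which the Hugging Face Trainer saves as `log_history` in `trainer_state.json` inside each checkpoint directory. This minimal sketch assumes the TRL key names `reward`, `frac_reward_zero_std`, `entropy`, and `kl`; these vary across versions, and a missing key makes its check fail. The thresholds are illustrative, and `load_log_history` and `rl_health` are hypothetical helpers.

```python
import json
import math
from pathlib import Path


def load_log_history(checkpoint_dir):
    """Read the per-step log dicts the HF Trainer stores in <checkpoint>/trainer_state.json."""
    state = json.loads((Path(checkpoint_dir) / "trainer_state.json").read_text())
    return state["log_history"]


def _mean(rows, key):
    vals = [r[key] for r in rows if key in r]
    return sum(vals) / len(vals) if vals else math.nan  # missing key -> NaN -> check fails


def rl_health(log_history, window=10, kl_max=0.5, entropy_min=0.05, zero_std_max=0.8):
    """Compare the first and last `window` logged training steps; thresholds are illustrative."""
    steps = [e for e in log_history if "reward" in e]  # skip eval/summary entries
    if len(steps) < 2 * window:
        raise ValueError("need at least 2 * window logged training steps")
    early, late = steps[:window], steps[-window:]
    return {
        "reward_climbing": _mean(late, "reward") > _mean(early, "reward"),
        "groups_have_signal": _mean(late, "frac_reward_zero_std") <= zero_std_max,
        "entropy_not_collapsed": _mean(late, "entropy") >= entropy_min,
        "kl_bounded": _mean(late, "kl") <= kl_max,
    }
```

A typical call is `rl_health(load_log_history("<output_dir>/checkpoint-<step>"))`. Every value should be `True` before the run is worth evaluating further.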
4. Maintain
- Before promoting the merged model, gate it on a held-out eval that was never used for training or reward design (SRE and MLOps practices).
- Merge for serving: merge the LoRA adapter into the dequantised BF16 base (not the 4-bit base), then quantise the merged model for deployment (quantization for inference, serving open-weight models).
- Checkpoint the adapter every N steps; resuming is cheap because only the adapter changes (checkpoint recovery).
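The merge step can be sketched with peft's `PeftModel.from_pretrained` and `merge_and_unload`. This sketch makes several assumptions:

- The adapter directory is one saved with `save_pretrained`, which Trainer checkpoints do for peft models.
- It gets its BF16 base by reloading the original checkpoint in BF16, not by dequantizing the NF4 copy.
- The function name is hypothetical, and the imports are local so the module loads without the heavy dependencies.
- It needs enough memory for the full BF16 model.

```python
def merge_for_serving(base_id, adapter_dir, out_dir):
    """Merge a LoRA adapter into a BF16 copy of the base model and save the merged model.

    Heavy imports are local so this module can be imported without torch/peft installed.
    """
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # BF16 base with no quantization_config: never merge into the NF4 weights.
    base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)
    merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
    merged.save_pretrained(out_dir)
    AutoTokenizer.from_pretrained(base_id).save_pretrained(out_dir)
    return out_dir
```

For example, `merge_for_serving("Qwen/Qwen3-8B", "<output_dir>/checkpoint-<step>", "<merged_dir>")`. The merged directory is then what gets quantized for deployment.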
Failure modes
- Entropy collapse: outputs go deterministic and reward plateaus. Add KL (`beta > 0`) or a DAPO-style fix (GRPO).
- Train-inference gap widened by FP8: FP8 rollouts diverge slightly more from the BF16 trainer than BF16 rollouts do; keep TRL's importance-sampling correction on and watch KL (async RL systems).
- 4-bit base too lossy: a heavily quantised base can cap final quality; if the eval plateaus low, try a BF16 or 8-bit base before blaming the reward.
- Merging into the 4-bit base: merging an adapter into the NF4 base re-quantises the merged weights to 4 bits, and the rounding error can change generations; always merge into the dequantised BF16 base.
- Reward hacking: the policy games a loophole; test the reward adversarially first (reward design).
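The importance-sampling correction for the FP8 train-inference gap reweights each sampled token by how much more (or less) likely the BF16 trainer finds it than the FP8 rollout engine did. This minimal sketch shows one such scheme, truncated importance sampling, with weight min(exp(log π_train - log π_rollout), C). It also includes a simple mismatch monitor: the mean of log π_rollout - log π_train over rollout-sampled tokens, which estimates KL(rollout || train). The cap C = 2.0 and both function names are illustrative assumptions, not TRL's implementation or defaults.

```python
import math


def truncated_is_weights(train_logps, rollout_logps, cap=2.0):
    """Per-token weights min(pi_train / pi_rollout, cap) from log-probs of the sampled tokens."""
    if len(train_logps) != len(rollout_logps):
        raise ValueError("log-prob sequences must have equal length")
    return [min(math.exp(lt - lr), cap) for lt, lr in zip(train_logps, rollout_logps)]


def mismatch_kl(train_logps, rollout_logps):
    """Estimate KL(rollout || train): mean of log pi_rollout - log pi_train over rollout samples."""
    if len(train_logps) != len(rollout_logps):
        raise ValueError("log-prob sequences must have equal length")
    diffs = [lr - lt for lt, lr in zip(train_logps, rollout_logps)]
    return sum(diffs) / len(diffs)


train = [-1.0, -1.0, -0.1]    # log-probs under the BF16 trainer
rollout = [-1.0, -1.2, -3.0]  # log-probs the FP8 engine reported for the same tokens
print(truncated_is_weights(train, rollout))
print(mismatch_kl(train, rollout))
```

The cap keeps a few tokens that the FP8 engine badly under-weighted from dominating the gradient. A rising mismatch estimate is the signal to compare against a BF16 rollout run.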
References
- DeepSeekMath (GRPO): https://arxiv.org/abs/2402.03300
- QLoRA (NF4, 4-bit fine-tuning): https://arxiv.org/abs/2305.14314
- TRL GRPO Trainer (peft / vLLM): https://huggingface.co/docs/trl/grpo_trainer
- vLLM quantization (FP8): https://docs.vllm.ai/en/latest/features/quantization/index.html
- verl GRPO: https://verl.readthedocs.io/en/latest/algo/grpo.html
Related: GRPO · Quantization for inference · Reward design · Async RL systems · SFT and LoRA · Fine-tuning and post-training · verl · Glossary