Recipe: memory-efficient GRPO post-training

Scope: a standalone recipe to run an RL post-training (GRPO) workload on a single node by applying two quantization techniques to the RL loop (a 4-bit QLoRA base for the policy and reference, and FP8 rollouts for generation) so the policy, reference, and inference engine fit where full-precision GRPO would need several GPUs. The applied, copy-paste counterpart to the algorithm in GRPO, the reward in reward design, and the systems split in async RL systems; quantization itself is in quantization for inference.

Reference templates on real APIs; pin versions and validate before production use. The commands here are not hardware-tested.

flowchart LR
  BASE["NF4 4-bit base (QLoRA)"] --> POL["Policy = base + LoRA adapter (BF16)"]
  BASE --> REF["Reference = base, adapter disabled"]
  POL --> ROLL["FP8 rollouts (vLLM)"]
  ROLL --> RWD["Verifiable reward function"]
  RWD --> ADV["Group-relative advantage"]
  ADV --> UPD["LoRA update (clipped + KL)"]
  REF -.->|"KL penalty"| UPD
  UPD -->|"adapter weight sync"| ROLL

What it is

GRPO is rollout-dominated and memory-heavy: it holds a policy, a frozen reference (for the KL penalty), and a reward signal, and it generates a group of completions per prompt each step (GRPO). In full precision that means two model copies plus an inference engine, which typically takes several GPUs for an 8B model. This recipe shrinks all three with quantization, applying the book's inference-efficiency techniques inside the RL loop (quantization for inference):

  • QLoRA base: the base weights are stored once in 4-bit NF4; the policy is that base plus a small BF16 LoRA adapter, and the reference is the same base with the adapter disabled. One 4-bit copy is both, so there is no second full-precision model in memory.
  • FP8 rollouts: the generation engine (vLLM) serves the policy in FP8 on Hopper/Blackwell, cutting the dominant rollout cost while the LoRA update stays in BF16 for stability.

The result is an RL post-training workload that fits a model whose full-precision GRPO run would not.
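
A back-of-envelope sketch of the weight memory behind that claim. It assumes an 8B-parameter model, treats BF16 as the full-precision baseline, and uses the roughly 0.127 bits per parameter that the QLoRA paper reports for double-quantized constants. It counts weights only: activations, KV cache, gradients, and optimizer state are excluded, and the helper names are illustrative.

```python
# Weight-only memory estimate for policy + reference + rollout engine.
N_PARAMS = 8e9  # assumed parameter count (an "8B" model)


def weight_gb(n_params: float, bits_per_param: float) -> float:
    """Weight storage in GB (1e9 bytes) at a given bit width."""
    return n_params * bits_per_param / 8 / 1e9


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in one LoRA pair: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


bf16_copy = weight_gb(N_PARAMS, 16)           # one BF16 model copy
two_bf16_copies = 2 * bf16_copy               # separate policy and reference
nf4_shared = weight_gb(N_PARAMS, 4 + 0.127)   # one NF4 base serving as both
fp8_engine = weight_gb(N_PARAMS, 8)           # rollout engine weights in FP8

print(f"policy + reference in BF16: {two_bf16_copies:.1f} GB")
print(f"shared NF4 base:            {nf4_shared:.1f} GB")
print(f"FP8 rollout engine:         {fp8_engine:.1f} GB")
print(f"one r=16 LoRA pair on a 4096x4096 layer: {lora_params(4096, 4096, 16):,} params")
```

The adapter adds only r * (d_in + d_out) parameters per targeted matrix, which is why the BF16 adapter barely changes the total.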

Why it matters

  • One node instead of several. The 4-bit base plus adapters plus an FP8 engine is the difference between a single 8x GPU node and a multi-node RL job for mid-size models.
  • Rollout cost is the bottleneck. Generation dominates GRPO wall-clock; FP8 rollouts attack exactly that (async RL systems).
  • Stability is preserved. Training math (the LoRA gradient, the KL, the advantage) stays in BF16; only storage and generation are low-precision, the same precision split tensor cores already use (tensor cores and mixed precision).

When it is needed (and when not)

  • Use it when you want to run GRPO on a verifiable-reward task (maths, code, format) but the full-precision policy + reference + rollout engine will not fit your node.
  • Prefer full-precision GRPO (GRPO) when you have the GPUs: a full BF16 policy avoids the small quality cost of a 4-bit base and the train-inference gap of FP8 rollouts.
  • Prefer DPO or rejection sampling when you have offline preference data or want SFT-level simplicity; neither runs an online rollout loop at all.
  • Avoid if the reward is noisy or hackable; a bad reward trains a bad model fast regardless of precision (reward design).

How: implement, integrate, maintain

1. The verifiable reward

Keep it deterministic and cheap; the reward is the objective, so validate it on held-out cases first (reward design).

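# reward.py: exact-match reward on a boxed final answer
import re

# First \boxed{...} in the text; nested braces such as \boxed{\frac{1}{2}} are not handled.
BOXED = re.compile(r"\\boxed\{([^}]*)\}")

def reward_correct(prompts, completions, answer, **kw):
    # `answer` must be a dataset column (TRL forwards columns as keyword arguments);
    # rename the column or this parameter if your dataset uses another name.
    def final(c):
        # Completions are strings for plain-text datasets, message lists for conversational ones.
        s = c if isinstance(c, str) else c[-1]["content"]
        m = BOXED.search(s)
        return m.group(1).strip() if m else None
    return [1.0 if final(c) == str(a).strip() else 0.0 for c, a in zip(completions, answer)]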
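
One way to validate the reward on held-out cases before training is a small table of hand-labelled examples. This is a minimal sketch: the cases are hypothetical, and the block carries a compact string-based copy of the reward so it runs on its own. The last case records a known limitation (nested braces) as an expected false negative rather than hiding it.

```python
import re


def reward_correct(prompts, completions, answer, **kw):
    # Compact string-based copy of the reward so this check is self-contained.
    def final(s):
        m = re.search(r"\\boxed\{([^}]*)\}", s)
        return m.group(1).strip() if m else None
    return [1.0 if final(c) == a else 0.0 for c, a in zip(completions, answer)]


# Hypothetical hand-labelled cases: (completion, reference answer, expected reward).
CASES = [
    ("So the result is \\boxed{42}.", "42", 1.0),
    ("So the result is \\boxed{ 42 }.", "42", 1.0),   # whitespace inside the box is stripped
    ("The answer is 42.", "42", 0.0),                  # no box, no reward
    ("\\boxed{41}", "42", 0.0),                        # wrong answer
    ("\\boxed{\\frac{1}{2}}", "\\frac{1}{2}", 0.0),    # nested braces: known false negative
]


def check(cases):
    """Return the cases where the reward disagrees with the expected label."""
    completions = [c for c, _, _ in cases]
    answers = [a for _, a, _ in cases]
    got = reward_correct([""] * len(cases), completions, answers)
    return [(c, g, e) for (c, _, e), g in zip(cases, got) if g != e]


mismatches = check(CASES)
print("mismatches:", mismatches)
```

If the nested-brace false negative matters for your data, fix the extraction before training; the policy cannot earn reward for answers the parser cannot read.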

2. The GRPO config (QLoRA base + FP8 rollouts)

TRL's GRPOTrainer takes a peft_config for the LoRA adapter and a 4-bit BitsAndBytesConfig for the NF4 base; use_vllm=True runs rollouts on vLLM, which serves the policy in FP8 when pointed at an FP8 engine. Pin versions. GRPO/quantization APIs move fast.

# train_grpo_qlora.py  — TRL >= 0.16, peft, bitsandbytes
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from trl import GRPOConfig, GRPOTrainer
from reward import reward_correct

nf4 = BitsAndBytesConfig(                       # 4-bit NF4 base (QLoRA)
    load_in_4bit=True, bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)

trainer = GRPOTrainer(
    model="Qwen/Qwen3-8B",
    reward_funcs=[reward_correct],
    args=GRPOConfig(
        num_generations=8,                      # group size G
        beta=0.01,                              # KL toward the (adapter-disabled) reference
        use_vllm=True, vllm_mode="server",      # FP8 rollout engine, separate from the trainer
        max_completion_length=2048, bf16=True,  # update stays BF16
        gradient_checkpointing=True,
        # model_init_kwargs is a GRPOConfig field, not a GRPOTrainer argument
        model_init_kwargs={"quantization_config": nf4, "torch_dtype": torch.bfloat16}),
    peft_config=LoraConfig(                      # policy = NF4 base + BF16 LoRA; ref = adapter off
        r=16, lora_alpha=32, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]),
    train_dataset=load_dataset("trl-lib/DeepMath-103K", split="train"))
trainer.train()

Serve the rollout engine in FP8 alongside the trainer (quantization for inference, async RL systems):

# FP8 rollout server on its own GPUs; the trainer points at it via vllm_mode="server".
CUDA_VISIBLE_DEVICES=4,5,6,7 trl vllm-serve --model Qwen/Qwen3-8B --quantization fp8
# trainer on the remaining GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch train_grpo_qlora.py

For frontier scale, verl drives the same GRPO with LoRA and vLLM rollouts on Ray (algorithm.adv_estimator=grpo, actor_rollout_ref.model.lora_rank=16, actor_rollout_ref.rollout.name=vllm).
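
What beta weights in the config above is a per-token KL estimate between the policy and the adapter-disabled reference. A minimal sketch of the estimator given in the DeepSeekMath paper follows; whether your pinned trainer version uses exactly this form is an assumption to check, and the function names are illustrative.

```python
import math


def kl_estimate(logp: float, ref_logp: float) -> float:
    """Per-token KL estimate: pi_ref/pi - log(pi_ref/pi) - 1.

    logp is the policy log-probability of the sampled token,
    ref_logp the reference log-probability of the same token.
    The estimate is zero when the two agree and positive otherwise.
    """
    d = ref_logp - logp
    return math.exp(d) - d - 1.0


def penalised_term(advantage_term: float, logp: float, ref_logp: float,
                   beta: float = 0.01) -> float:
    """Per-token objective: the (clipped) advantage term minus beta * KL."""
    return advantage_term - beta * kl_estimate(logp, ref_logp)


print(kl_estimate(-1.2, -1.2))   # policy matches reference
print(kl_estimate(-1.0, -2.0))   # policy more confident than reference
print(kl_estimate(-2.0, -1.0))   # policy less confident than reference
```

Because the reference is the same NF4 base with the adapter switched off, ref_logp costs one extra forward pass but no extra copy of the weights.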

3. Apply and verify

Watch the three RL health metrics, not just loss (observability):

# reward should trend up; entropy should not collapse to ~0; KL should stay bounded.
# TRL logs reward, reward_std, frac_reward_zero_std, kl, completions/s.
# The pattern tolerates a space after the colon, negative values, and exponents (e.g. 1e-05).
tail -f trainer_log.jsonl | grep -oE '"(reward|kl|entropy|frac_reward_zero_std)": *-?[0-9.]+(e[+-]?[0-9]+)?'

  • reward climbing and frac_reward_zero_std low (most groups have signal) means the loop is learning.
  • kl flat-and-bounded means the 4-bit policy is not drifting off the reference; if it climbs, raise beta.
  • rollout throughput (completions/s) is the FP8 win. Compare against a BF16 rollout baseline.
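
Why frac_reward_zero_std matters can be seen from the group-relative advantage itself. The sketch below is illustrative only: it uses the population standard deviation and a small epsilon, and real implementations may differ in both. A group whose completions all receive the same reward gets zero advantage everywhere, so it contributes no gradient.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Normalise each reward against its own group's mean and std."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def frac_zero_std(groups):
    """Fraction of groups whose rewards are all identical (no learning signal)."""
    return sum(1 for g in groups if pstdev(g) == 0) / len(groups)


# Two hypothetical groups of G=4 completions with 0/1 rewards.
groups = [
    [1.0, 0.0, 0.0, 1.0],   # mixed outcomes: informative
    [0.0, 0.0, 0.0, 0.0],   # all wrong: advantages are all zero
]
advs = [group_advantages(g) for g in groups]
print(advs[0])
print(advs[1])
print("frac_zero_std:", frac_zero_std(groups))
```

A persistently high value usually means the prompts are too easy or too hard for the current policy, so most of the rollout budget produces no update.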

4. Maintain

Failure modes

  • Entropy collapse: outputs go deterministic, reward plateaus. Add KL (beta>0) or a DAPO-style fix (GRPO).
  • Train-inference gap widened by FP8: FP8 rollouts diverge slightly more from the BF16 trainer than BF16 rollouts; keep TRL's importance-sampling correction on and watch KL (async RL systems).
  • 4-bit base too lossy: a heavily quantized base can cap final quality; if the eval plateaus low, try a BF16 base or 8-bit before blaming the reward.
  • Merging into the 4-bit base: merging an adapter directly into NF4 weights forces the merged result back through 4-bit quantization, which adds rounding error; always merge into a dequantized (BF16) copy of the base.
  • Reward hacking: the policy games a loophole; test the reward adversarially first (reward design).
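
For the merge failure mode, a minimal sketch of merging outside the 4-bit base. It reloads the original BF16 checkpoint as a stand-in for a dequantized base, which is an assumption: the adapter was trained against NF4 weights, so validate the merged model on your eval before shipping it. The function name and the paths in the usage comment are hypothetical; the imports sit inside the function so the module loads without the heavy dependencies.

```python
def merge_adapter_bf16(base_id: str, adapter_dir: str, out_dir: str) -> None:
    """Merge a trained LoRA adapter into a BF16 copy of the base and save it."""
    # Imported lazily: requires torch, transformers, and peft at call time.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Load the base WITHOUT a quantization_config so the weights are plain BF16.
    base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)
    # Attach the adapter, fold its low-rank update into the BF16 weights, drop the wrappers.
    merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
    merged.save_pretrained(out_dir)


# Hypothetical usage (paths are placeholders):
# merge_adapter_bf16("Qwen/Qwen3-8B", "outputs/checkpoint-500", "merged-bf16")
```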

References

  • DeepSeekMath (GRPO): https://arxiv.org/abs/2402.03300
  • QLoRA (NF4, 4-bit fine-tuning): https://arxiv.org/abs/2305.14314
  • TRL GRPO Trainer (peft / vLLM): https://huggingface.co/docs/trl/grpo_trainer
  • vLLM quantization (FP8): https://docs.vllm.ai/en/latest/features/quantization/index.html
  • verl GRPO: https://verl.readthedocs.io/en/latest/algo/grpo.html

Related: GRPO · Quantization for inference · Reward design · Async RL systems · SFT and LoRA · Fine-tuning and post-training · verl · Glossary