Model weight loading in inference engines
Scope: how an engine like SGLang or vLLM turns checkpoint files into a running model, and what you must implement when adding or debugging a model. This page covers what decides which module loads a checkpoint (the architectures field and the model registry), how PyTorch names every parameter so a weight key finds its tensor, how the safetensors index maps keys to shards, and why engines write custom load_weights (to skip, remap, transform, and fuse weights), with a runnable fused-QKV loader. It underpins delta weight sync (which reuses load_weights), multi-LoRA serving (which replaces loaded modules), and quantization (which repacks at load).
Mechanics track SGLang/vLLM internals and can shift across versions; verify class and file names on the pinned engine. The Python examples are executed and asserted (numpy and stdlib); framework snippets are reference templates.
flowchart LR
CFG["config.json: architectures"] --> REG["Model registry: EntryClass lookup"]
REG --> MOD["Instantiate module tree"]
SFT["safetensors + index.json weight_map"] --> KEYS["Tensor keys: dotted names"]
MOD --> LOAD["load_weights: match key to parameter"]
KEYS --> LOAD
LOAD --> CUSTOM["Custom loader: skip / remap / transform / fuse Q,K,V"]
CUSTOM --> READY["Weights resident, model ready"]
What it is
Loading a model is matching every tensor in the checkpoint to a parameter in an instantiated module tree. Four pieces make that work:
- The architecture selector. A model repo ships config.json, whose architectures field (for example Qwen3ForCausalLM) names the top-level class. The engine keeps a model registry that scans its model files, reads the entry class each declares (an EntryClass = ... line), and looks up the right one for that config.
- Parameter names. PyTorch names each parameter by chaining the variable names from the top module down, separated by dots; nn.ModuleList inserts the list index. So model.layers.0.self_attn.q_proj.weight is just the nested variable path. Renaming a module variable silently renames its weights, so the names are load-bearing, not cosmetic.
- The shard map. For multi-file checkpoints, model.safetensors.index.json holds a weight_map from each tensor key to its shard file; each shard stores per-tensor dtype, shape, and byte offsets. Single-file checkpoints skip the index.
- The loader. A default loader can match keys to parameters directly, but real models ship a custom load_weights to handle the cases the default cannot (below).
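The shard map can be made concrete with a small sketch. It assumes the published safetensors layout (an 8-byte little-endian header length, a JSON header giving dtype, shape, and data_offsets per tensor, then the raw bytes) and omits the optional metadata entry. The helper names, shard file names, and tiny tensors are hypothetical and built in memory; this is not the safetensors library API.

```python
import json
import struct

def build_shard(tensors):
    """Serialize {key: (dtype, shape, raw_bytes)} in the safetensors layout."""
    header, blob = {}, b""
    for key, (dtype, shape, raw) in tensors.items():
        header[key] = {"dtype": dtype, "shape": shape,
                       "data_offsets": [len(blob), len(blob) + len(raw)]}
        blob += raw
    head = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(head)) + head + blob  # u64 length, JSON, bytes

def read_header(shard):
    """Return (header dict, absolute offset where the tensor bytes begin)."""
    (n,) = struct.unpack("<Q", shard[:8])
    return json.loads(shard[8:8 + n]), 8 + n

# Hypothetical two-shard checkpoint held in memory instead of on disk.
shards = {
    "model-00001-of-00002.safetensors": build_shard(
        {"model.embed_tokens.weight": ("F32", [2, 2], struct.pack("<4f", 1, 2, 3, 4))}),
    "model-00002-of-00002.safetensors": build_shard(
        {"lm_head.weight": ("F32", [1, 2], struct.pack("<2f", 5, 6))}),
}

# The index is just key -> shard file, like weight_map in model.safetensors.index.json.
index = {"weight_map": {key: fname
                        for fname, blob in shards.items()
                        for key in read_header(blob)[0]}}

def fetch(key):
    """Find the shard through weight_map, then slice the tensor by its byte offsets."""
    fname = index["weight_map"][key]
    header, base = read_header(shards[fname])
    start, end = header[key]["data_offsets"]  # relative to the end of the header
    raw = shards[fname][base + start:base + end]
    return fname, struct.unpack(f"<{len(raw) // 4}f", raw)

print(fetch("lm_head.weight"))
```

Because the offsets are stored per tensor, a loader can read one key without touching the rest of the shard, which is what makes per-key and per-rank loading cheap.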
Why it matters
- It is the interface everything else reuses. Delta weight sync applies sparse updates through load_weights; multi-LoRA discovers and replaces loaded modules; quantization repacks tensors during load. Understanding the loader is prerequisite to all three.
- It is where new-model support lives. Adding a model to an engine is mostly writing its module tree with the right variable names and a load_weights that reconciles the checkpoint's key layout with the engine's fused, parallel layers.
- Silent mismatches are common. A wrong variable name, an unhandled fused layer, or a skipped tensor produces a model that loads but generates garbage, so knowing the matching rules is what makes the failure debuggable.
When to use it (and when not)
- You write or read this code when adding a new model, porting one across engines, or debugging weights that load without error but produce wrong output.
- The default path is enough when the checkpoint's key layout already matches the engine's module tree (no fused layers, no renamed variables, no on-the-fly transforms); then you do not touch load_weights.
- This is engine-side, not training-side. Sharded training checkpoints (Megatron/FSDP) have their own save/load; this page is about loading a Hugging Face-style checkpoint into an inference engine.
Architecture
The load pipeline is: read config.json, resolve the architecture through the registry, instantiate the module tree (which fixes every parameter name), read the safetensors index to know which shard holds each key, then run load_weights, which iterates the checkpoint tensors and, for each, either copies it into the matching parameter or routes it through a custom rule (skip, remap, transform, or a fused weight_loader). The result is every parameter populated and the model ready to run.
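The first two pipeline steps (read the config, resolve the class) can be sketched as a dictionary lookup. This is a simplified model under stated assumptions: the registry here is a hand-written list rather than a scan of model files, and the class bodies, ENTRY_CLASSES, and resolve_model_class are hypothetical names, not engine APIs.

```python
import json

class Qwen3ForCausalLM:  # stand-in for the class a model file would declare
    def __init__(self, config):
        self.config = config

class LlamaForCausalLM:  # second stand-in, to show the lookup is selective
    def __init__(self, config):
        self.config = config

# Assumption: each model file exposes one entry class; the registry indexes
# those classes by their class name.
ENTRY_CLASSES = [Qwen3ForCausalLM, LlamaForCausalLM]
REGISTRY = {cls.__name__: cls for cls in ENTRY_CLASSES}

def resolve_model_class(config):
    """Return the first registered class named in config['architectures']."""
    for arch in config.get("architectures", []):
        if arch in REGISTRY:
            return REGISTRY[arch]
    raise ValueError(f"no registered model for {config.get('architectures')}")

# Contents of a hypothetical config.json.
config = json.loads('{"architectures": ["Qwen3ForCausalLM"], "num_hidden_layers": 2}')
model = resolve_model_class(config)(config)  # instantiating fixes the parameter names
print(type(model).__name__)
```

An unregistered architecture fails at this step, before any tensor is read, which is why a missing model file shows up as a lookup error rather than a weight mismatch.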
How to use it
The interesting case is a fused layer. For efficiency and simpler tensor parallelism, engines fuse Q, K, and V into one QKVParallelLinear parameter, but the checkpoint stores Q, K, V separately, so a custom weight_loader is called once per shard to write each into its slice of the fused tensor. This runnable model of that loader is executed and asserted, including an adversarial check that mislabeling a shard corrupts the result (shard identity is load-bearing):
# fused_qkv_load.py — validated model of a fused QKV weight_loader; numpy only.
import numpy as np
D, Hq, Hkv = 8, 4, 2 # GQA: fewer KV rows than query rows
rng = np.random.default_rng(0)
q = rng.standard_normal((Hq, D)).astype(np.float32) # Q, K, V are stored SEPARATELY on disk
k = rng.standard_normal((Hkv, D)).astype(np.float32)
v = rng.standard_normal((Hkv, D)).astype(np.float32)
fused = np.full((Hq + 2 * Hkv, D), np.nan, np.float32) # one fused QKVParallelLinear param
slots = {"q": (0, Hq), "k": (Hq, Hq + Hkv), "v": (Hq + Hkv, Hq + 2 * Hkv)}
def weight_loader(param, shard, shard_id):  # called once per shard: "q", then "k", then "v"
    s, e = slots[shard_id]
    assert shard.shape[0] == e - s, "shard does not fit its slot"
    param[s:e] = shard  # write only this shard's rows

for shard, sid in [(q, "q"), (k, "k"), (v, "v")]:
    weight_loader(fused, shard, sid)
assert np.array_equal(fused, np.vstack([q, k, v])) # fused == concat of q,k,v on the output dim
# adversarial: mislabel the k/v shards and the fused tensor is wrong -> shard identity matters
bad = np.full_like(fused, np.nan)
for shard, sid in [(q, "q"), (v, "k"), (k, "v")]:  # v and k intentionally swapped
    s, e = slots[sid]
    bad[s:e] = shard
assert not np.array_equal(bad, np.vstack([q, k, v]))
print("fused QKV load OK; mislabeled shards detected")
In the engine, this is QKVParallelLinear.weight_loader(param, loaded_weight, loaded_shard_id), and the model's top-level load_weights routes the three checkpoint tensors to it. The same pattern fuses gate_proj/up_proj into gate_up_proj.
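The routing step in the top-level load_weights can be sketched with a mapping table of (fused name, checkpoint name, shard id) triples, modeled loosely on the stacked-parameter tables found in engine model files. Everything here is an illustrative assumption in numpy: the STACKED and SLOTS tables, the key names, and the sizes are hypothetical, and a real engine delegates the slice write to each parameter's own weight_loader.

```python
import numpy as np

D, Hq, Hkv, I = 8, 4, 2, 6  # hidden size, Q rows, KV rows, MLP intermediate rows
rng = np.random.default_rng(1)
ckpt = {  # the checkpoint stores every projection separately
    "self_attn.q_proj.weight": rng.standard_normal((Hq, D)),
    "self_attn.k_proj.weight": rng.standard_normal((Hkv, D)),
    "self_attn.v_proj.weight": rng.standard_normal((Hkv, D)),
    "mlp.gate_proj.weight": rng.standard_normal((I, D)),
    "mlp.up_proj.weight": rng.standard_normal((I, D)),
    "mlp.down_proj.weight": rng.standard_normal((D, I)),
}
params = {  # the module tree holds fused parameters (NaN marks "not yet written")
    "self_attn.qkv_proj.weight": np.full((Hq + 2 * Hkv, D), np.nan),
    "mlp.gate_up_proj.weight": np.full((2 * I, D), np.nan),
    "mlp.down_proj.weight": np.full((D, I), np.nan),
}
# (fused name, checkpoint name, shard id): hypothetical routing table
STACKED = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]
# Row range each shard occupies inside its fused parameter.
SLOTS = {
    ("qkv_proj", "q"): (0, Hq),
    ("qkv_proj", "k"): (Hq, Hq + Hkv),
    ("qkv_proj", "v"): (Hq + Hkv, Hq + 2 * Hkv),
    ("gate_up_proj", 0): (0, I),
    ("gate_up_proj", 1): (I, 2 * I),
}

def load_weights(params, weights):
    for name, tensor in weights:
        for fused, part, shard_id in STACKED:
            if f".{part}." in name:  # checkpoint key belongs to a fused layer
                s, e = SLOTS[(fused, shard_id)]
                params[name.replace(part, fused)][s:e] = tensor
                break
        else:  # no fused rule matched: plain one-to-one copy
            params[name][...] = tensor

load_weights(params, ckpt.items())
print({name: bool(np.isnan(p).any()) for name, p in params.items()})
```

Keeping the routing in one table means adding a fused layer is a data change, and the per-shard slice logic stays in one place.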
How to develop with it
First, get the names right. PyTorch derives every key by chaining variable names, so a faithful model of the rule (executed) shows why layers.0 appears and why renaming a variable breaks loading:
# naming.py — validated: dotted parameter names, ModuleList index disambiguates instances.
def named_parameters(module, prefix=""):
    out = {}
    for name, child in module.items():
        key = prefix + name
        if isinstance(child, list):  # nn.ModuleList -> insert the index
            for i, item in enumerate(child):
                out.update(named_parameters(item, f"{key}.{i}."))
        elif isinstance(child, dict):  # submodule -> recurse
            out.update(named_parameters(child, key + "."))
        else:
            out[key] = child  # leaf parameter
    return out

model = {"model": {"layers": [{"self_attn": {"q_proj": {"weight": 0}}},
                              {"self_attn": {"q_proj": {"weight": 0}}}]}}
keys = set(named_parameters(model))
assert "model.layers.0.self_attn.q_proj.weight" in keys
assert "model.layers.1.self_attn.q_proj.weight" in keys # per-instance keys via the list index
print("sample key:", sorted(keys)[0])
A custom load_weights exists for four recurring reasons:
- Skip parameters that are recomputed, not stored (for example rotary_emb.inv_freq from the RoPE config).
- Remap keys when the checkpoint's names differ from the module's (rename ?.conv.conv.weight to ?.conv.conv_weight).
- Transform a tensor on the way in (for example squeeze a (D, 1, K) conv weight to (D, K)).
- Fuse separately-stored Q/K/V (or gate/up) into one parallel layer via the per-shard weight_loader above.
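The first three reasons (skip, remap, transform) can be sketched as branches in one loop. This is a minimal numpy model, not an engine's loader: the key names follow the examples in the list, while the layers.0 prefix, the sizes, and the returned loaded set are assumptions added for illustration. Fusing is left out because it needs the per-shard loader.

```python
import numpy as np

D, K = 4, 3
ckpt = {
    "layers.0.rotary_emb.inv_freq": np.ones(2, np.float32),  # recomputed, never loaded
    "layers.0.conv.conv.weight": np.arange(D * K, dtype=np.float32).reshape(D, 1, K),
    "layers.0.norm.weight": np.ones(D, np.float32),
}
params = {  # names and shapes as the (hypothetical) module tree defines them
    "layers.0.conv.conv_weight": np.zeros((D, K), np.float32),
    "layers.0.norm.weight": np.zeros(D, np.float32),
}

def load_weights(params, weights):
    """Copy checkpoint tensors into params; return the set of parameters written."""
    loaded = set()
    for name, tensor in weights:
        if "rotary_emb.inv_freq" in name:  # skip: rebuilt from the RoPE config
            continue
        if name.endswith(".conv.conv.weight"):
            name = name.replace(".conv.conv.weight", ".conv.conv_weight")  # remap
            tensor = tensor.squeeze(1)  # transform: (D, 1, K) -> (D, K)
        if name not in params:
            raise KeyError(f"unexpected checkpoint key: {name}")
        if params[name].shape != tensor.shape:
            raise ValueError(f"shape mismatch for {name}")
        params[name][...] = tensor  # default path: copy into the matching parameter
        loaded.add(name)
    return loaded

loaded = load_weights(params, ckpt.items())
print(sorted(loaded))
```

Raising on an unknown key or a shape mismatch keeps a forgotten rule from turning into a silently wrong model.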
Debugging tip: after a load, assert that the set of checkpoint keys and the set of module parameter names reconcile exactly once the skip, remap, and fuse rules are applied; unexpected or missing keys are the usual cause of a model that loads but is wrong.
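One way to run that reconciliation is a pure set comparison over names, sketched here with the standard library only. The reconcile function and its skip, rename, and fused arguments are hypothetical helpers, not an engine API, and the key names are illustrative.

```python
def reconcile(ckpt_keys, param_names, skip=(), rename=None, fused=None):
    """Return (missing, unexpected) parameter names after applying the loader's rules."""
    rename = rename or {}
    fused = fused or {}
    expected = set()
    for key in ckpt_keys:
        if any(fragment in key for fragment in skip):
            continue  # keys the loader drops on purpose
        for old, new in rename.items():
            key = key.replace(old, new)
        for part, target in fused.items():
            key = key.replace(f".{part}.", f".{target}.")  # q_proj -> qkv_proj, etc.
        expected.add(key)
    params = set(param_names)
    return sorted(params - expected), sorted(expected - params)

ckpt_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.self_attn.rotary_emb.inv_freq",
    "model.norm.weight",
]
rules = dict(skip=("rotary_emb.inv_freq",),
             fused={"q_proj": "qkv_proj", "k_proj": "qkv_proj", "v_proj": "qkv_proj"})

# Case 1: a parameter with no checkpoint tensor behind it.
params_ok = ["model.layers.0.self_attn.qkv_proj.weight", "model.norm.weight", "lm_head.weight"]
print(reconcile(ckpt_keys, params_ok, **rules))

# Case 2: the module variable was renamed from self_attn to attn.
params_renamed = ["model.layers.0.attn.qkv_proj.weight", "model.norm.weight"]
print(reconcile(ckpt_keys, params_renamed, **rules))
```

A renamed variable shows up as one missing name paired with one unexpected name. A name-level check cannot see a fused parameter that received only some of its shards; that needs a value-level check on the loaded tensor.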
How to maintain it
Treat name reconciliation as a test, not an assumption. After any model-code or engine upgrade, re-run a load and assert no parameter is left uninitialized and no checkpoint tensor is dropped (a strict loader raises on either). Because a custom loader also handles quantization, tensor parallelism, and load formats, its branches grow over time, so keep a small fixture checkpoint and a load-time equivalence check (loaded tensor equals the on-disk tensor for a sample of keys, after any documented transform). When a checkpoint format or the engine's fused-layer set changes, the shard-to-slot mapping is the first thing to re-verify.
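A fixture-based version of that check might look like the following sketch. It assumes a NaN sentinel is an acceptable marker for "never written" (only valid for float parameters whose real weights contain no NaN); the fixture, the slot table, and the helper names are hypothetical, and an engine may track initialization differently.

```python
import numpy as np

Hq, Hkv, D = 4, 2, 8
rng = np.random.default_rng(2)
fixture = {  # tiny stand-in for the on-disk checkpoint
    "q": rng.standard_normal((Hq, D)),
    "k": rng.standard_normal((Hkv, D)),
    "v": rng.standard_normal((Hkv, D)),
}
slots = {"q": (0, Hq), "k": (Hq, Hq + Hkv), "v": (Hq + Hkv, Hq + 2 * Hkv)}

def load(shard_ids):
    """Load only the listed shards into a sentinel-filled fused parameter."""
    params = {"qkv_proj.weight": np.full((Hq + 2 * Hkv, D), np.nan)}
    for sid in shard_ids:
        s, e = slots[sid]
        params["qkv_proj.weight"][s:e] = fixture[sid]
    return params

def uninitialized(params):
    """Names of parameters that still contain the NaN sentinel."""
    return sorted(name for name, p in params.items() if np.isnan(p).any())

def equivalent(params):
    """True when every loaded slice equals its on-disk tensor."""
    return all(np.array_equal(params["qkv_proj.weight"][s:e], fixture[sid])
               for sid, (s, e) in slots.items())

good = load(["q", "k", "v"])
partial = load(["q", "k"])  # simulates a loader branch that dropped the V shard
print(uninitialized(good), equivalent(good))
print(uninitialized(partial), equivalent(partial))
```

Running both checks on a small fixture after each upgrade catches a dropped branch before it reaches a full-size checkpoint.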
How to run it in production
Loading is where tensor and pipeline parallelism and quantization are applied, so the loader is on the startup critical path and touches every GPU. Practical levers: safetensors are memory-mapped so shards load without a full host copy; the loader shards each tensor to its rank as it reads, so a fused QKVParallelLinear is split across TP ranks during load rather than after; and quantized checkpoints are repacked into the runtime layout at load time (the reason a quantized weight in memory is often not byte-identical to the canonical form, which is what makes delta weight sync hard under NVFP4). Reuse of this native load_weights interface is exactly what lets delta sync and multi-LoRA inject their behavior without reimplementing per-model logic; keep the loader the single path so those features compose. Pin the engine build, because the model registry, fused-layer set, and loader branches are engine-version-specific.
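The claim that a fused QKV is split across TP ranks during load can be sketched in numpy. The sketch assumes a contiguous split along the output rows and head counts divisible by the TP size; KV-head replication (when the TP size exceeds the number of KV heads) is not modeled, and rank_rows and load_rank are hypothetical names.

```python
import numpy as np

Hq, Hkv, Dh, hidden, tp = 4, 2, 2, 8, 2  # Q heads, KV heads, head dim, hidden size, TP size
rng = np.random.default_rng(3)
q = rng.standard_normal((Hq * Dh, hidden))   # full tensors as stored in the checkpoint
k = rng.standard_normal((Hkv * Dh, hidden))
v = rng.standard_normal((Hkv * Dh, hidden))

def rank_rows(full, rank, tp):
    """Rows of a full projection owned by this rank (contiguous split by heads)."""
    per = full.shape[0] // tp
    return full[rank * per:(rank + 1) * per]

def load_rank(rank):
    """Build this rank's fused parameter directly from its slices of Q, K, and V."""
    parts = [rank_rows(t, rank, tp) for t in (q, k, v)]
    fused = np.empty((sum(p.shape[0] for p in parts), hidden))
    offset = 0
    for part in parts:  # same per-shard write as the unsharded loader
        fused[offset:offset + part.shape[0]] = part
        offset += part.shape[0]
    return fused

ranks = [load_rank(r) for r in range(tp)]
nq, nkv = Hq * Dh // tp, Hkv * Dh // tp  # rows per rank for Q and for K/V
print([r.shape for r in ranks])
```

Each rank's fused tensor is laid out as its own Q rows, then K rows, then V rows, so it is not a contiguous slice of the unsharded fused tensor; that is why the split has to happen per shard at load time.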
Failure modes
- Renamed module variable. Changing a layer's variable name silently changes its parameter key, so the checkpoint tensor no longer matches; keep names aligned with the reference implementation.
- Unhandled fused layer. If Q/K/V are stored separately but the module fuses them and no weight_loader routes the shards, the fused parameter is left partly unset; wire the per-shard loader.
- Silently skipped tensor. A key the loader ignores by mistake leaves a parameter at its init value and the model generates garbage without erroring; assert full coverage.
- Wrong shard order. Writing a shard into the wrong slice (or mislabeling q/k/v) corrupts attention; the adversarial check above is exactly this case.
- Quantization or dtype mismatch. Loading a canonical tensor into a runtime-quantized parameter without the repack step produces wrong numerics; route quantized weights through the engine's post-load processing.
References
- Changyi, "How Does SGLang Actually Load Model Weights?": https://changyi.fun/posts/model-weight-loading/
- SGLang model definitions and registry (python/sglang/srt/models/): https://github.com/sgl-project/sglang
- Hugging Face safetensors (format, index, weight_map): https://github.com/huggingface/safetensors
- PyTorch nn.Module.named_parameters (parameter naming): https://pytorch.org/docs/stable/generated/torch.nn.Module.html
- Hugging Face Transformers (config architectures): https://huggingface.co/docs/transformers/main/en/model_doc/auto
Related: Delta weight sync · Multi-LoRA serving · Quantization for inference · Serving open-weight models · Inference serving · SFT/LoRA · Glossary