12 (Appendix). Quantisation and Quantisation-Aware Training#
⚡Compute Note: You can run this notebook on CPU.
This tutorial-style appendix introduces the motivations, numerics, and tooling behind model quantisation with a focus on transformers. We provide conceptual explanations and concise PyTorch code snippets so you can reason about precision trade-offs without running large-scale LLM training.
By the end of this appendix you should be able to:
Explain why quantisation is attractive for deployment and how it differs from pruning or distillation.
Describe the most common integer formats (e.g. int8, int4) and how scale/zero-point pairs represent floating-point ranges.
Compare post-training quantisation (PTQ) with quantisation-aware training (QAT).
Prototype a minimal PTQ and QAT workflow in PyTorch using toy data.
For an accessible video introduction to quantisation, see the one by Julia Turc.
1. Why quantise large language models?#
Quantisation maps high-precision parameters and activations (usually 16-bit or 32-bit floats) to lower-precision integer representations. The key benefits are:
Latency: Integer matrix multiplications are faster on CPUs and accelerators that ship with dedicated int8/int4 instructions.
Memory footprint: Reducing precision from 16-bit to 8-bit halves the storage requirements for weights, activations, and gradients.
Bandwidth: Smaller tensors move more quickly across device memory hierarchies, improving throughput for autoregressive decoding.
The trade-off is reduced representational fidelity. Effective quantisation strategies aim to minimise task degradation by carefully choosing calibration data, numeric formats, and training procedures.
Quantisation versus other efficiency techniques#
| Technique | Core idea | Typical gain | Trade-offs |
|---|---|---|---|
| Pruning | Remove parameters entirely | Smaller models, sparse compute | Requires sparse kernels, may harm accuracy |
| Distillation | Train a smaller student on teacher outputs | Compact model with similar behaviour | Needs extra training, may miss rare behaviours |
| Quantisation | Lower the precision of weights/activations | Faster + smaller with the same architecture | Numerical noise can degrade quality |
Quantisation composes well with pruning and distillation, but each solves a distinct optimisation problem.
2. Quantisation fundamentals#
A uniform affine quantiser represents a real value \(x\) with:
\[
q = \operatorname{clip}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right),
\]
where \(s\) is the scale, \(z\) is the zero-point, and \(q_{\min}, q_{\max}\) bound the integer range (e.g. \([-128, 127]\) for signed int8).
To recover an approximate float, we dequantise via:
\[
\hat{x} = s\,(q - z) \approx x.
\]
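As a sanity check of these two formulas, here is a minimal sketch (the helper names `affine_quantise`/`affine_dequantise` are ours, not a library API) that quantises and dequantises a small tensor:

```python
import torch

def affine_quantise(x, s, z, q_min=-128, q_max=127):
    # q = clip(round(x / s) + z, q_min, q_max)
    return torch.clamp(torch.round(x / s) + z, q_min, q_max).to(torch.int8)

def affine_dequantise(q, s, z):
    # x_hat = s * (q - z)
    return s * (q.float() - z)

x = torch.randn(6)
s = x.abs().max() / 127   # symmetric choice of scale, so the zero-point is 0
z = 0
q = affine_quantise(x, s, z)
print(x)
print(q)
print(affine_dequantise(q, s, z))  # close to x, up to rounding error
```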
Two practical choices dominate transformer quantisation:
Per-tensor vs. per-channel scales: LayerNorm and attention projections often benefit from per-channel scaling to preserve dynamic ranges across attention heads (the sketch after this list compares the two numerically).
Symmetric vs. asymmetric: Symmetric quantisers fix \(z = 0\) and work well when data are zero-centred (common for weights). Asymmetric quantisers add a zero-point to better capture skewed activation distributions.
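The short experiment below is our own toy sketch (made-up weights, no library calls): when rows of a weight matrix span very different magnitudes, per-channel scales reconstruct them far more faithfully than a single per-tensor scale.

```python
import torch

# Rows deliberately span very different magnitudes, mimicking heterogeneous channels
W = torch.randn(4, 64) * torch.tensor([[0.01], [0.1], [1.0], [10.0]])

def sym_quant_dequant(w, scale):
    # Symmetric int8 quantise + dequantise with the given scale(s)
    return torch.clamp(torch.round(w / scale), -128, 127) * scale

s_tensor = W.abs().max() / 127                       # one scale for the whole tensor
s_channel = W.abs().amax(dim=1, keepdim=True) / 127  # one scale per output channel

print("per-tensor error: ", (W - sym_quant_dequant(W, s_tensor)).abs().mean().item())
print("per-channel error:", (W - sym_quant_dequant(W, s_channel)).abs().mean().item())
```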
Choosing bit-widths#
| Bit-width | Common usage | Notes |
|---|---|---|
| 16-bit float (FP16/BF16) | Training and mixed-precision inference | Floating formats preserve wide dynamic range. |
| 8-bit (int8; FP8 on recent GPUs) | Widely adopted for weights + activations | Best hardware support across CPU/GPU/TPU. |
| 4-bit int | Aggressive compression for decoder-only LLMs | Needs more careful calibration, often weight-only. |
| 2-bit / binary | Specialised research | Requires custom kernels; accuracy drops are larger. |
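To make the memory implications of the table concrete, here is a back-of-the-envelope calculation for a hypothetical 7-billion-parameter model (weights only; activations and KV cache are extra):

```python
params = 7e9  # hypothetical 7B-parameter model
for bits in (16, 8, 4, 2):
    print(f"{bits:>2}-bit weights: ~{params * bits / 8 / 1024**3:.1f} GiB")
```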
3. PyTorch quantisation toolchain overview#
PyTorch exposes its quantisation APIs under torch.ao.quantization (the older torch.quantization namespace used in the code below simply forwards to it). The typical workflow is:
Define a float32 model.
Prepare it with a quantisation configuration (qconfig) that describes observer types and backend kernels.
Calibrate or train to collect activation statistics.
Convert to a quantised model with integer kernels.
Depending on whether calibration happens after or during training, we distinguish between PTQ and QAT.
4. Post-training quantisation (PTQ)#
PTQ transforms a pre-trained float model to a quantised one using a small calibration dataset. No gradient updates are required, making PTQ attractive when the original training data are unavailable or expensive to reproduce.
Below we quantise a toy convolutional network on MNIST. Although tiny, the code mirrors what you would do with transformer blocks: prepare, calibrate, convert.
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor()
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
import torch
import torch.nn as nn
import torch.quantization
import torch.optim as optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.quant = torch.quantization.QuantStub()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 14 * 14, 128)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.relu1(self.conv1(x))
x = self.pool(self.relu2(self.conv2(x)))
# Flatten all non-batch dimensions; torch.flatten is safe for channels-last tensors
x = torch.flatten(x, 1)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
x = self.dequant(x)
return x
model_fp32 = SimpleCNN()
model_fp32.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
print("Training (just a few epochs for demo)...")
for epoch in range(1): # Just 1 epoch for time; increase for better accuracy!
model_fp32.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model_fp32(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
Training (just a few epochs for demo)...
Fusing Conv + ReLU (or Conv + BN + ReLU) combines adjacent modules into a single operation before quantisation, which improves both the accuracy and the speed of the quantised model.
We then set the quantisation config (qconfig), which tells PyTorch how to observe and quantise the model: 'fbgemm' targets x86 CPUs, while 'qnnpack' targets ARM CPUs.
We then calibrate. Feeding real data through the prepared model lets the observers collect min/max statistics, from which the int8 scale and zero-point are derived.
def fuse_model(model):
torch.quantization.fuse_modules(model,
[['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']], inplace=True)
fuse_model(model_fp32)
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32.cpu() # Eager-mode quantised kernels (fbgemm/qnnpack) run on CPU
torch.quantization.prepare(model_fp32, inplace=True)
print("Calibrating...")
model_fp32.eval()
with torch.no_grad():
for images, _ in train_loader:
model_fp32(images)
break # In practice use many calibration batches; a single batch keeps the demo fast
quantized_model = torch.quantization.convert(model_fp32, inplace=False)
quantized_model.eval()
Calibrating...
SimpleCNN(
(quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
(conv1): QuantizedConvReLU2d(1, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.013325857929885387, zero_point=0, padding=(1, 1))
(relu1): Identity()
(conv2): QuantizedConvReLU2d(16, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.02389773167669773, zero_point=0, padding=(1, 1))
(relu2): Identity()
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): QuantizedLinearReLU(in_features=6272, out_features=128, scale=0.18056835234165192, zero_point=0, qscheme=torch.per_channel_affine)
(relu3): Identity()
(fc2): QuantizedLinear(in_features=128, out_features=10, scale=0.2785327434539795, zero_point=60, qscheme=torch.per_channel_affine)
(dequant): DeQuantize()
)
Let’s see how this model performs.
import os
def evaluate(model, test_loader, device="cpu"):
model.eval()
correct = total = 0
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy:', 100.0 * correct / total, '%')
return correct / total
print("\nEvaluating original (float32) model:")
evaluate(model_fp32, test_loader)
print("\nEvaluating quantized model:")
evaluate(quantized_model, test_loader)
torch.save(model_fp32.state_dict(), "float_model.pth")
torch.save(quantized_model.state_dict(), "quantized_model.pth")
float_size = os.path.getsize("float_model.pth") / 1024
quant_size = os.path.getsize("quantized_model.pth") / 1024
os.remove("float_model.pth"); os.remove("quantized_model.pth")
print(f"\nModel size (float32): {float_size:.1f} KB")
print(f"Model size (quantized): {quant_size:.1f} KB")
Evaluating original (float32) model:
Accuracy: 97.84 %
Evaluating quantized model:
Accuracy: 97.78 %
Model size (float32): 3209.8 KB
Model size (quantized): 802.0 KB
Even without additional fine-tuning, int8 PTQ often introduces only mild accuracy degradation (under 0.1 percentage points here) and significantly reduces model size (by roughly 75% here).
5. Quantisation-aware training (QAT)#
PTQ struggles when activation distributions shift significantly between calibration data and real workloads, or when you push to 4-bit precision. Quantisation-aware training mitigates this by simulating quantisation effects during fine-tuning.
The high-level idea:
Insert fake-quantisation modules that emulate quant/dequant in the forward pass.
Backpropagate through straight-through estimators so gradients still flow to the float weights (a minimal sketch follows this list).
After training, convert to real integer kernels.
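The snippet below is our own minimal sketch of that idea (a hand-rolled FakeQuantSTE class, not the observer-based fake-quant modules that torch.ao inserts for you): the forward pass rounds values onto the int8 grid, while the backward pass treats the rounding as the identity.

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    """Simulated int8 quantisation with a straight-through gradient."""

    @staticmethod
    def forward(ctx, x, scale):
        # Quantise onto the int8 grid, then immediately dequantise
        q = torch.clamp(torch.round(x / scale), -128, 127)
        return q * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pretend round() has derivative 1
        return grad_output, None

x = torch.randn(4, requires_grad=True)
y = FakeQuantSTE.apply(x, torch.tensor(0.05))
y.sum().backward()
print(x.grad)  # all ones: autograd ignored the rounding step
```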
Below we fine-tune our toy network for one epoch with QAT on the same MNIST data, this time using the FX graph mode API so that fusion and fake-quant insertion happen automatically.
import torch, torch.nn as nn, torch.optim as optim
import torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import fuse_fx, prepare_qat_fx, convert_fx
class SimpleCNNFX(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 16, 3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2) # -> 14x14
self.fc1 = nn.Linear(32 * 14 * 14, 128)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.pool(self.relu2(self.conv2(x)))
x = torch.flatten(x, 1) # safe flatten (channels-last OK)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return x
float_model = SimpleCNNFX().to(device)
# Example input for FX tracing
example_inputs = (torch.randn(1, 1, 28, 28),)
# prepare_qat_fx fuses known patterns such as (conv, relu) and (linear, relu)
# automatically in FX graph mode, so no explicit fuse_fx call is needed here
fused_float = float_model.cpu()
backend = "fbgemm" # x86 backend
qconfig = get_default_qat_qconfig(backend)
qconfig_mapping = QConfigMapping().set_global(qconfig)
prepared_qat = prepare_qat_fx(fused_float, qconfig_mapping, example_inputs)
prepared_qat.to(device)
prepared_qat.train()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(prepared_qat.parameters(), lr=1e-3)
num_epochs = 1
num_steps_freeze = 200 # after N steps disable observers & freeze BN (QAT best-practice)
step = 0
for epoch in range(num_epochs):
prepared_qat.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad(set_to_none=True)
logits = prepared_qat(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
step += 1
# Best practice: after a warmup, stop updating quantisation ranges
if step == num_steps_freeze:
prepared_qat.apply(torch.ao.quantization.disable_observer)
# A model with BatchNorm would also freeze its running statistics here;
# this toy CNN has no BN layers.
print(f"Epoch {epoch+1}/{num_epochs} done. Last batch loss: {loss.item():.4f}")
print("\nEval float (QAT-prepared) before convert:")
_ = evaluate(prepared_qat, test_loader, device=device)
prepared_qat.eval()
quantized_model = convert_fx(prepared_qat.cpu()) # convert on CPU
print(quantized_model)
print("\nQuantized int8 accuracy:")
_ = evaluate(quantized_model, test_loader, device='cpu')
Epoch 1/1 done. Last batch loss: 0.0231
Eval float (QAT-prepared) before convert:
Accuracy: 98.19 %
GraphModule(
(conv1): QuantizedConvReLU2d(1, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.015383397229015827, zero_point=0, padding=(1, 1))
(conv2): QuantizedConvReLU2d(16, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.020312752574682236, zero_point=0, padding=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): QuantizedLinearReLU(in_features=6272, out_features=128, scale=0.34092840552330017, zero_point=0, qscheme=torch.per_channel_affine)
(fc2): QuantizedLinear(in_features=128, out_features=10, scale=0.26298820972442627, zero_point=112, qscheme=torch.per_channel_affine)
)
def forward(self, x):
conv1_input_scale_0 = self.conv1_input_scale_0
conv1_input_zero_point_0 = self.conv1_input_zero_point_0
quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8); x = conv1_input_scale_0 = conv1_input_zero_point_0 = None
conv1 = self.conv1(quantize_per_tensor); quantize_per_tensor = None
conv2 = self.conv2(conv1); conv1 = None
pool = self.pool(conv2); conv2 = None
flatten = torch.flatten(pool, 1); pool = None
fc1 = self.fc1(flatten); flatten = None
fc2 = self.fc2(fc1); fc1 = None
dequantize_6 = fc2.dequantize(); fc2 = None
return dequantize_6
# To see more debug info, please use `graph_module.print_readable()`
Quantized int8 accuracy:
Accuracy: 98.25 %
In a real LLM fine-tuning session you would:
Start from a float checkpoint, insert QAT observers in linear/attention modules, and resume training with a small learning rate.
Use task-representative prompts so the model adapts to quantisation noise where it matters most (e.g. long-context decoding).
Optionally freeze embedding layers or layer norms if they destabilise under fake-quantisation.
Practical tips#
Gradient scaling: Lower precisions amplify rounding error, so gradient clipping and slightly higher weight decay can stabilise QAT.
Progressive bit-width schedules: Begin with int8 fake quantisers and anneal to int4 once the model adapts.
Layer-wise decisions: Some components (e.g. the logits projection) may remain in higher precision to safeguard perplexity (see the qconfig sketch after this list).
Evaluation: Always compare quantised and float baselines on downstream metrics (perplexity, accuracy) and monitor decoding latency to ensure the trade-off is worthwhile.
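As an illustration of the layer-wise point, FX graph mode lets you opt individual modules out of quantisation by assigning them a None qconfig. The sketch below reuses the module names from the QAT example above, treating fc2 as the logits projection kept in float:

```python
from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping

qconfig = get_default_qat_qconfig("fbgemm")
qconfig_mapping = (
    QConfigMapping()
    .set_global(qconfig)           # quantise everything by default...
    .set_module_name("fc2", None)  # ...but leave the final projection in float
)
# Pass this mapping to prepare_qat_fx exactly as in the QAT example above.
```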
6. Beyond standard QAT#
Research is rapidly evolving:
GPTQ/AWQ: Optimise weight quantisation with second-order approximations or activation-aware scaling.
LLM.int8(): Hybrid scheme keeping outlier activations in FP16 while quantising the rest (a toy sketch of the decomposition appears at the end of this section).
FP8 training/inference: Uses floating-point formats with learned scaling, now supported on some GPUs.
Structured sparsity + quantisation: Combining 2:4 sparsity with int8 kernels for maximal throughput.
These techniques build upon the PTQ/QAT primitives described above.
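To give a flavour of the LLM.int8() idea, here is a toy sketch of the outlier decomposition. It is a simplification: only the activations are quantised (in simulation), whereas the real scheme also quantises the non-outlier weights and uses genuine int8 kernels.

```python
import torch

def mixed_precision_matmul(X, W, threshold=6.0):
    # Feature columns of X whose magnitude exceeds the threshold are treated as outliers
    outliers = X.abs().amax(dim=0) > threshold

    # Outlier features stay in float
    y_float = X[:, outliers] @ W[outliers, :]

    # Remaining features: simulated per-column int8 quantise -> matmul -> dequantise
    X_rest, W_rest = X[:, ~outliers], W[~outliers, :]
    s = X_rest.abs().amax(dim=0, keepdim=True) / 127 + 1e-8
    X_q = torch.clamp(torch.round(X_rest / s), -128, 127) * s
    return y_float + X_q @ W_rest

X = torch.randn(8, 16)
X[:, 3] *= 50                      # inject one outlier feature dimension
W = torch.randn(16, 4)
print((mixed_precision_matmul(X, W) - X @ W).abs().max())  # small residual error
```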
7. Key takeaways#
Quantisation trades numerical precision for latency and memory gains, making it crucial for serving LLMs on cost-sensitive hardware.
PTQ is simple and works well when you have decent calibration data and stick to 8-bit weights/activations.
QAT injects quantisation noise during training, enabling more aggressive bit-widths or challenging domains at the cost of additional fine-tuning.
PyTorch’s torch.ao.quantization module provides batteries-included utilities to prepare, calibrate, and convert models, so you can experiment without rewriting your architecture. Recent releases are migrating this API to the standalone torchao project (https://github.com/pytorch/ao), so expect deprecation warnings when running the snippets above.