llm-qat-algorithm
Description
LLM Quantization-Aware Training (QAT) Algorithm
Research Question
Design a quantization-aware training (QAT) algorithm that finetunes Mistral-7B-v0.1 with fake quantization to minimize perplexity degradation after INT4/INT3/INT2 weight quantization.
Background
Quantization-aware training (QAT) inserts fake quantization operations into the forward pass during training, so the model learns to be robust to quantization noise. Unlike post-training quantization (PTQ), which operates on frozen weights, QAT adjusts the weights through gradient descent while simulating the quantization noise. This allows the model to compensate for precision loss, especially at very low bit-widths (INT3, INT2) where PTQ methods degrade severely.
The core challenge is designing the fake quantization function: how to quantize and dequantize during forward (to simulate real quantization), and how to pass gradients through the non-differentiable rounding operation during backward. Common approaches include:
- Straight-Through Estimator (STE): Pass gradients through the rounding operation as-is (identity gradient). Simple but effective.
- Learned Step Size Quantization (LSQ): Learn the quantization step size (scale) as a trainable parameter, with a specially derived gradient for the scale.
- PACT: Learn the clipping range for activations to minimize quantization error.
- Differentiable soft quantization: Replace hard rounding with a smooth approximation that has well-defined gradients.
- Adaptive rounding (AdaRound-style): Learn per-weight rounding decisions (round up or down) instead of always rounding to nearest.
- Progressive quantization: Start with higher precision and gradually reduce bit-width during training.
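As a concrete illustration of the first approach, here is a minimal STE sketch: symmetric per-group quantization with hard rounding in the forward pass and an identity gradient in the backward pass. The function name `fake_quantize_weight_ste` and the per-group symmetric scheme are illustrative choices, not the required implementation.

```python
import torch

def fake_quantize_weight_ste(weight, num_bits=4, group_size=128):
    """STE fake quantization sketch: symmetric per-group scales,
    hard rounding forward, identity gradient backward."""
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // group_size, group_size)
    # Symmetric scale per group: the largest |w| maps to the top integer level.
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    # Straight-through estimator: forward computes q, backward sees identity.
    w_fq = w + (q - w).detach()
    return w_fq.reshape(out_features, in_features)
```

Because the quantization error is wrapped in `.detach()`, gradients flow to `weight` exactly as if no rounding had occurred.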
The finetuning starts from pretrained Mistral-7B-v0.1 and uses WikiText-2 training data for a small number of steps (controlled via CONFIG_OVERRIDES). After finetuning, a real quantize-dequantize roundtrip is applied to all weights and the model is evaluated on WikiText-2 test perplexity.
What You Can Modify
The editable region in custom_qat.py (lines 27-185) contains:
- fake_quantize_weight(weight, num_bits, group_size): The fake quantization function applied to weights during the forward pass of training. Must simulate quantization noise while allowing gradient flow.
- fake_quantize_activation(x, num_bits): Optional activation quantization during training (can be identity if not needed).
- quantize_dequantize_weight(weight, num_bits, group_size): The real quantize-dequantize used for final evaluation (applied after training). This should match real INT quantization.
- QATWrapper(nn.Module): A wrapper around nn.Linear that applies fake quantization during training. It wraps the original layer's weights and applies fake_quantize_weight in forward. Can hold additional learnable parameters (e.g., learned scales, clipping ranges).
- prepare_qat_model(model, num_bits, group_size): Replaces nn.Linear layers with QATWrapper. Can initialize any extra learnable parameters.
- CONFIG_OVERRIDES: A dict to override training hyperparameters (learning_rate, num_steps, batch_size, gradient_accumulation_steps, max_grad_norm, warmup_steps, weight_decay).
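For instance, a CONFIG_OVERRIDES dict using the allowed keys might look like the following; the values are placeholders, not recommended settings.

```python
# Illustrative CONFIG_OVERRIDES -- only the listed keys are accepted,
# and every value here is a placeholder, not a tuned recommendation.
CONFIG_OVERRIDES = {
    "learning_rate": 2e-5,
    "num_steps": 200,
    "batch_size": 1,
    "gradient_accumulation_steps": 8,
    "max_grad_norm": 1.0,
    "warmup_steps": 20,
    "weight_decay": 0.0,
}
```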
You can implement any approach:
- STE variants: Different gradient estimators for the rounding function
- Learned scales/clipping: Per-channel or per-group learnable quantization parameters
- Soft quantization: Differentiable approximations to rounding
- Adaptive rounding: Learn rounding decisions for each weight
- Progressive/scheduled quantization: Vary quantization parameters during training
- Mixed-precision awareness: Different strategies for different bit-widths
- Loss augmentation: Add regularization terms to encourage quantization-friendly weights
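To make the wrapper/replacement mechanics concrete, here is a minimal sketch of QATWrapper and prepare_qat_model. The stand-in STE quantizer inside is only there so the sketch runs; any of the approaches above could be substituted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quantize_weight(weight, num_bits=4, group_size=128):
    # Stand-in STE quantizer so the wrapper below is runnable.
    o, i = weight.shape
    w = weight.reshape(o, i // group_size, group_size)
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    return (w + (q - w).detach()).reshape(o, i)

class QATWrapper(nn.Module):
    """Wraps an nn.Linear and fake-quantizes its weight on every forward."""
    def __init__(self, linear, num_bits=4, group_size=128):
        super().__init__()
        self.linear = linear            # reuse the original layer's parameters
        self.num_bits = num_bits
        self.group_size = group_size

    def forward(self, x):
        w_fq = fake_quantize_weight(self.linear.weight,
                                    self.num_bits, self.group_size)
        return F.linear(x, w_fq, self.linear.bias)

def prepare_qat_model(model, num_bits=4, group_size=128):
    """Recursively swap every nn.Linear for a QATWrapper."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, QATWrapper(child, num_bits, group_size))
        else:
            prepare_qat_model(child, num_bits, group_size)
    return model
```

Note the wrapper holds a reference to the original nn.Linear rather than copying its weights, which satisfies the constraint below that no new weight tensors be created.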
Architecture
The task loads real Mistral-7B-v0.1 weights (7.24B parameters, pre-downloaded) and finetunes them with fake quantization inserted. The model runs on a single 80GB GPU using gradient checkpointing and small batch sizes to fit in memory.
Mistral-7B-v0.1 specs: 32 layers, 32 attention heads, 8 KV heads (GQA), 4096 hidden dimension, 14336 intermediate size, ~7.24B parameters.
The script (custom_qat.py):
- Loads Mistral-7B-v0.1 weights from /data/mistral-7b-v01
- Evaluates the FP16 (unquantized) model as baseline
- Loads WikiText-2 training data for QAT finetuning
- Replaces all nn.Linear layers with QATWrapper (your fake quantization)
- Finetunes with AdamW for N steps (default 200) with gradient checkpointing
- After training: applies real quantize_dequantize_weight to all layer weights
- Evaluates the quantized model on WikiText-2 test perplexity
Interface
def fake_quantize_weight(weight, num_bits=4, group_size=128):
"""Fake quantization for training -- must allow gradient flow.
Args:
weight: float tensor of shape (out_features, in_features)
num_bits: target bit width (4, 3, or 2)
group_size: columns per quantization group (128)
Returns:
Fake-quantized weight tensor (same shape, still float, but simulates quantization)
"""
def fake_quantize_activation(x, num_bits=8):
"""Optional activation quantization during training.
Args:
x: activation tensor
num_bits: activation bit width
Returns:
Fake-quantized activation (or identity if not used)
"""
def quantize_dequantize_weight(weight, num_bits=4, group_size=128):
"""Real quantize-dequantize for evaluation after training.
Args:
weight: float tensor of shape (out_features, in_features)
num_bits: target bit width
group_size: columns per quantization group
Returns:
Quantized-then-dequantized weight tensor
"""
class QATWrapper(nn.Module):
"""Wraps an nn.Linear layer for quantization-aware training.
Holds a reference to the original layer. Applies fake_quantize_weight
to the weight during forward pass. Can hold extra learnable parameters.
"""
CONFIG_OVERRIDES = {}
# Allowed keys: learning_rate, num_steps, batch_size,
# gradient_accumulation_steps, max_grad_norm, warmup_steps, weight_decay
Constraints:
- The training loop is FIXED (not editable) -- you only design the quantization algorithm
- QATWrapper must wrap the original nn.Linear (use its weight, don't create new weights)
- fake_quantize_weight must allow gradients to flow (e.g., via STE or differentiable approximation)
- quantize_dequantize_weight must implement real symmetric integer quantization (hard rounding)
- The model must fit on a single 80GB GPU during training (use gradient checkpointing)
- copy, math, torch, torch.nn, F, np, os, time are available
- Your algorithm must work for INT4, INT3, and INT2, and for group size 128
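A minimal sketch of the required "real symmetric integer quantization (hard rounding)" for the evaluation roundtrip, assuming per-group symmetric scales (the simplest scheme satisfying the constraint; richer schemes are possible as long as rounding is hard):

```python
import torch

def quantize_dequantize_weight(weight, num_bits=4, group_size=128):
    """Real symmetric per-group quantize-dequantize: hard rounding,
    no gradient tricks -- the roundtrip applied before final evaluation."""
    o, i = weight.shape
    w = weight.reshape(o, i // group_size, group_size)
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)  # true integers
    return (q * scale).reshape(o, i)
```

At INT2 this collapses each group onto at most four representable values, which is why the training-time fake quantizer should mirror this function as closely as possible.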
Evaluation
The algorithm is evaluated across three quantization settings of increasing difficulty:
- qat-7b-int4: W4 group-128 -- standard setting, QAT should easily beat PTQ here
- qat-7b-int3: W3 group-128 -- harder setting, QAT advantage becomes more significant
- qat-7b-int2: W2 group-128 -- extreme setting (only 4 discrete levels!), where QAT is essential
Primary metric: wikitext2_ppl -- WikiText-2 perplexity after real quantization (lower is better)
Secondary metrics: degradation -- perplexity increase over FP16 baseline; qat_time -- finetuning wall-clock time
Model: Mistral-7B-v0.1 (32 layers, GQA, ~7.24B params)
Weights: Real Mistral-7B-v0.1 from HuggingFace (pre-downloaded)
Training data: WikiText-2 train split (128 sequences, 2048 tokens each)
Evaluation data: WikiText-2 test split
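The primary metric is just the exponentiated average token-level negative log-likelihood over the test split; a minimal sketch of that computation (the function name `perplexity` is illustrative, not part of the fixed script):

```python
import math
import torch
import torch.nn.functional as F

def perplexity(logits, targets):
    """PPL = exp(mean token-level negative log-likelihood)."""
    nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                          targets.reshape(-1), reduction="mean")
    return math.exp(nll.item())
```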
Code
"""Quantization-Aware Training (QAT) for LLMs -- finetune + quantize + evaluate.

This script loads a pretrained LLM (Mistral-7B-v0.1), finetunes it with fake
quantization inserted in the forward pass, then applies real quantize-dequantize
roundtrip and evaluates perplexity on WikiText-2.

The QAT algorithm is defined in the EDITABLE REGION below.
Everything else (model loading, training loop, evaluation) is fixed.
"""

import argparse
import copy
import math
import os
import time
Results
No results available yet.