llm-qat-algorithm

Tags: Language Models · GPTQ · Rigorous Codebase

Description

LLM Quantization-Aware Training (QAT) Algorithm

Research Question

Design a quantization-aware training (QAT) algorithm that finetunes Mistral-7B-v0.1 with fake quantization to minimize perplexity degradation after INT4/INT3/INT2 weight quantization.

Background

Quantization-aware training (QAT) inserts fake quantization operations into the forward pass during training, so the model learns to be robust to quantization noise. Unlike post-training quantization (PTQ), which operates on frozen weights, QAT adjusts the weights through gradient descent while simulating the quantization noise. This allows the model to compensate for precision loss, especially at very low bit-widths (INT3, INT2) where PTQ methods degrade severely.

The core challenge is designing the fake quantization function: how to quantize and dequantize during forward (to simulate real quantization), and how to pass gradients through the non-differentiable rounding operation during backward. Common approaches include:

  • Straight-Through Estimator (STE): Pass gradients through the rounding operation as-is (identity gradient). Simple but effective.
  • Learned Step Size Quantization (LSQ): Learn the quantization step size (scale) as a trainable parameter, with a specially derived gradient for the scale.
  • PACT: Learn the clipping range for activations to minimize quantization error.
  • Differentiable soft quantization: Replace hard rounding with a smooth approximation that has well-defined gradients.
  • Adaptive rounding (AdaRound-style): Learn per-weight rounding decisions (round up or down) instead of always rounding to nearest.
  • Progressive quantization: Start with higher precision and gradually reduce bit-width during training.
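As a concrete illustration of the first bullet, a minimal STE-based fake quantizer might look like the following. This is a sketch, not the task's reference implementation; the name `fake_quantize_ste` and the exact grouping/clamping convention are assumptions.

```python
import torch

def fake_quantize_ste(weight, num_bits=4, group_size=128):
    """Group-wise symmetric fake quantization with a straight-through estimator.

    Each group of `group_size` columns shares one scale. The round/clamp happens
    inside .detach(), so the backward pass sees an identity gradient (STE).
    """
    out_features, in_features = weight.shape
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # e.g. -8..7 for INT4
    w = weight.reshape(out_features, in_features // group_size, group_size)
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), qmin, qmax) * scale
    w_fq = w + (q - w).detach()  # forward: quantized values; backward: identity on w
    return w_fq.reshape(out_features, in_features)
```

The `.detach()` trick is the whole of STE here: the forward pass produces genuinely quantized values, while the gradient with respect to `w` is exactly 1.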

The finetuning starts from a pretrained Mistral-7B-v0.1 and runs on WikiText-2 training data for a small number of steps (controlled via CONFIG_OVERRIDES). After finetuning, a real quantize-dequantize roundtrip is applied to all weights, and the model is evaluated on WikiText-2 test perplexity.

What You Can Modify

The editable region in custom_qat.py (lines 27-185) contains:

  • fake_quantize_weight(weight, num_bits, group_size): The fake quantization function applied to weights during the forward pass of training. Must simulate quantization noise while allowing gradient flow.
  • fake_quantize_activation(x, num_bits): Optional activation quantization during training (can be identity if not needed).
  • quantize_dequantize_weight(weight, num_bits, group_size): The real quantize-dequantize used for final evaluation (applied after training). This should match real INT quantization.
  • QATWrapper(nn.Module): A wrapper around nn.Linear that applies fake quantization during training. It wraps the original layer's weights and applies fake_quantize_weight in forward. Can hold additional learnable parameters (e.g., learned scales, clipping ranges).
  • prepare_qat_model(model, num_bits, group_size): Replaces nn.Linear layers with QATWrapper. Can initialize any extra learnable parameters.
  • CONFIG_OVERRIDES: A dict to override training hyperparameters (learning_rate, num_steps, batch_size, gradient_accumulation_steps, max_grad_norm, warmup_steps, weight_decay).
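For instance, a hypothetical override dict might look like the following; the specific values are illustrative, not tuned recommendations.

```python
# Hypothetical hyperparameter overrides -- values are illustrative only.
CONFIG_OVERRIDES = {
    "learning_rate": 2e-5,   # a lower LR is often safer for low-bit QAT
    "num_steps": 200,
    "warmup_steps": 20,
    "max_grad_norm": 1.0,
}
```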

You can implement any approach:

  • STE variants: Different gradient estimators for the rounding function
  • Learned scales/clipping: Per-channel or per-group learnable quantization parameters
  • Soft quantization: Differentiable approximations to rounding
  • Adaptive rounding: Learn rounding decisions for each weight
  • Progressive/scheduled quantization: Vary quantization parameters during training
  • Mixed-precision awareness: Different strategies for different bit-widths
  • Loss augmentation: Add regularization terms to encourage quantization-friendly weights

Architecture

The task loads real Mistral-7B-v0.1 weights (7.24B parameters, pre-downloaded) and finetunes them with fake quantization inserted. The model runs on a single 80GB GPU using gradient checkpointing and small batch sizes to fit in memory.

Mistral-7B-v0.1 specs: 32 layers, 32 attention heads, 8 KV heads (GQA), 4096 hidden dimension, 14336 intermediate size, ~7.24B parameters.
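Those specs can be sanity-checked with a quick back-of-the-envelope parameter count. This sketch ignores norm parameters and assumes untied input/output embeddings; head_dim = 128 follows from 4096 / 32 heads.

```python
# Rough parameter count for Mistral-7B-v0.1 from the specs above.
vocab, d, d_ff, n_layers = 32000, 4096, 14336, 32
head_dim, n_kv = 128, 8
attn = d * d + 2 * d * (n_kv * head_dim) + d * d  # q, k, v, o projections (GQA)
mlp = 3 * d * d_ff                                # gate, up, down projections
total = n_layers * (attn + mlp) + 2 * vocab * d   # + input embeddings, lm_head
print(f"{total / 1e9:.2f}B")                      # prints 7.24B
```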

The script (custom_qat.py):

  1. Loads Mistral-7B-v0.1 weights from /data/mistral-7b-v01
  2. Evaluates the FP16 (unquantized) model as baseline
  3. Loads WikiText-2 training data for QAT finetuning
  4. Replaces all nn.Linear layers with QATWrapper (your fake quantization)
  5. Finetunes with AdamW for N steps (default 200) with gradient checkpointing
  6. After training: applies real quantize_dequantize_weight to all layer weights
  7. Evaluates the quantized model on WikiText-2 test perplexity

Interface

def fake_quantize_weight(weight, num_bits=4, group_size=128):
    """Fake quantization for training -- must allow gradient flow.
    
    Args:
        weight: float tensor of shape (out_features, in_features)
        num_bits: target bit width (4, 3, or 2)
        group_size: columns per quantization group (128)
    Returns:
        Fake-quantized weight tensor (same shape, still float, but simulates quantization)
    """

def fake_quantize_activation(x, num_bits=8):
    """Optional activation quantization during training.
    
    Args:
        x: activation tensor
        num_bits: activation bit width
    Returns:
        Fake-quantized activation (or identity if not used)
    """

def quantize_dequantize_weight(weight, num_bits=4, group_size=128):
    """Real quantize-dequantize for evaluation after training.
    
    Args:
        weight: float tensor of shape (out_features, in_features)
        num_bits: target bit width
        group_size: columns per quantization group
    Returns:
        Quantized-then-dequantized weight tensor
    """

class QATWrapper(nn.Module):
    """Wraps an nn.Linear layer for quantization-aware training.
    
    Holds a reference to the original layer. Applies fake_quantize_weight
    to the weight during forward pass. Can hold extra learnable parameters.
    """

CONFIG_OVERRIDES = {}
# Allowed keys: learning_rate, num_steps, batch_size,
# gradient_accumulation_steps, max_grad_norm, warmup_steps, weight_decay
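A hedged sketch of what `quantize_dequantize_weight` might look like under the symmetric, group-wise scheme the constraints describe: plain hard rounding and no gradient trickery, since this runs after training. The scale convention (absmax over the positive range) is an assumption.

```python
import torch

def quantize_dequantize_sketch(weight, num_bits=4, group_size=128):
    """Real symmetric integer quantize-dequantize roundtrip (hard rounding)."""
    o, i = weight.shape
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # -2..1 for INT2
    w = weight.reshape(o, i // group_size, group_size)
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), qmin, qmax)  # integer codes
    return (q * scale).reshape(o, i)                     # dequantized floats
```

Whatever convention the fake quantizer uses during training should match this function exactly; any mismatch between the simulated and real grids shows up directly as evaluation-time perplexity degradation.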

Constraints:

  • The training loop is FIXED (not editable) -- you only design the quantization algorithm
  • QATWrapper must wrap the original nn.Linear (use its weight, don't create new weights)
  • fake_quantize_weight must allow gradients to flow (e.g., via STE or differentiable approximation)
  • quantize_dequantize_weight must implement real symmetric integer quantization (hard rounding)
  • The model must fit on a single 80GB GPU during training (use gradient checkpointing)
  • copy, math, torch, torch.nn, F, np, os, time are available
  • Your algorithm must work for INT4, INT3, and INT2, and for group size 128

Evaluation

The algorithm is evaluated across three quantization settings of increasing difficulty:

  • qat-7b-int4: W4 group-128 -- standard setting, QAT should easily beat PTQ here
  • qat-7b-int3: W3 group-128 -- harder setting, QAT advantage becomes more significant
  • qat-7b-int2: W2 group-128 -- extreme setting (only 4 discrete levels!), where QAT is essential

  • Primary metric: wikitext2_ppl -- WikiText-2 perplexity after real quantization (lower is better)
  • Secondary metrics: degradation -- perplexity increase over the FP16 baseline; qat_time -- finetuning wall-clock time
  • Model: Mistral-7B-v0.1 (32 layers, GQA, ~7.24B params)
  • Weights: Real Mistral-7B-v0.1 from HuggingFace (pre-downloaded)
  • Training data: WikiText-2 train split (128 sequences, 2048 tokens each)
  • Evaluation data: WikiText-2 test split

Code

custom_qat.py
"""Quantization-Aware Training (QAT) for LLMs -- finetune + quantize + evaluate.

This script loads a pretrained LLM (Mistral-7B-v0.1), finetunes it with fake
quantization inserted in the forward pass, then applies real quantize-dequantize
roundtrip and evaluates perplexity on WikiText-2.

The QAT algorithm is defined in the EDITABLE REGION below.
Everything else (model loading, training loop, evaluation) is fixed.
"""

import argparse
import copy
import math
import os
import time

Results

No results available yet.