llm-kv-structural-reduction
ML Systems · nanoGPT · rigorous codebase
Description
LLM Pretraining: KV-Structural Reduction
Research Question
Design a more KV-efficient causal attention structure for GPT-style pretraining, with the primary focus on the tradeoff between KV head sharing and latent KV compression.
- How much quality can be preserved while reducing the realized KV state?
- Do grouped/shared KV heads or latent KV bottlenecks give the better quality-memory tradeoff under a fixed small-scale pretraining budget?
What You Can Modify
One editable region in custom_pretrain.py:
- Attention-structure region (lines 35-154), including:
  - `build_kv_heads(...)`: how many KV heads are materialized relative to query heads
  - `cross_layer_share(...)`: optional structural sharing hook inside the attention stack
  - `latent_kv_project(...)`: whether K/V are compressed into a lower-rank latent space
  - `CausalSelfAttention`: how the above choices are instantiated inside the attention block, including the internal query/KV projection and attention mixing path
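As an illustration of the first two knobs, the grouped/shared KV-head idea can be sketched as below. The function name `build_kv_heads` mirrors the editable region's helper, but this body is a hypothetical GQA-style sketch, not the task's actual implementation; the geometry (12 query heads, head_dim 64, n_embd 768) matches GPT-2 small.

```python
import torch
import torch.nn as nn

def build_kv_heads(n_query_heads: int, n_kv_heads: int, head_dim: int, n_embd: int):
    """Hypothetical sketch: materialize fewer KV heads than query heads (GQA-style).
    n_kv_heads=n_query_heads recovers dense MHA; n_kv_heads=1 recovers MQA."""
    assert n_query_heads % n_kv_heads == 0, "query heads must split evenly into KV groups"
    q_proj = nn.Linear(n_embd, n_query_heads * head_dim, bias=False)
    kv_proj = nn.Linear(n_embd, 2 * n_kv_heads * head_dim, bias=False)
    return q_proj, kv_proj

# Example: 12 query heads sharing 3 KV heads (group size 4)
q_proj, kv_proj = build_kv_heads(12, 3, 64, 768)
x = torch.randn(2, 16, 768)                      # (batch, seq, n_embd)
q = q_proj(x).view(2, 16, 12, 64)
k, v = kv_proj(x).split(3 * 64, dim=-1)          # only 3 KV heads are cached
k = k.view(2, 16, 3, 64)
# At attention time each cached KV head is expanded across its query group
k_expanded = k.repeat_interleave(12 // 3, dim=2)
print(k_expanded.shape)  # torch.Size([2, 16, 12, 64])
```

The KV cache only ever holds the 3-head tensors, which is where the memory saving comes from; the expansion to 12 heads happens transiently inside the attention computation.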
Intended Task Boundary
- This task studies KV-state reduction inside the attention block.
- The main comparison axes are:
- dense MHA vs grouped/shared KV heads
- grouped/shared KV heads vs latent KV compression
- `cross_layer_share(...)` remains available as an auxiliary structural hook inside the same block.
- The evaluator enforces the top-level boundary of this region with an AST validator: only the allowed helper functions plus `CausalSelfAttention` may appear in the editable span. That keeps edits inside the attention block, even though the internal contents of `CausalSelfAttention` remain flexible.
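The task does not show the validator itself, but a minimal sketch of this kind of boundary check, assuming it simply walks the top-level statements of the editable span with the standard-library `ast` module, might look like:

```python
import ast

# Names taken from the task description; the validator logic itself is an assumption.
ALLOWED_TOP_LEVEL = {
    "build_kv_heads",
    "cross_layer_share",
    "latent_kv_project",
    "CausalSelfAttention",
}

def validate_editable_region(source: str) -> bool:
    """Accept the span only if every top-level statement is an allowed def/class."""
    tree = ast.parse(source)
    for node in tree.body:
        if not isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            return False  # e.g. stray imports or module-level code
        if node.name not in ALLOWED_TOP_LEVEL:
            return False  # an unexpected helper escaped the attention block
    return True

ok = validate_editable_region(
    "def build_kv_heads(n_head):\n    return n_head\n"
    "class CausalSelfAttention:\n    pass\n"
)
bad = validate_editable_region("import os\n")
print(ok, bad)  # True False
```

A check like this constrains only the top-level shape of the span, which is consistent with the note that the internals of `CausalSelfAttention` remain flexible.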
Evaluation
- Primary metric: validation loss (cross-entropy, lower is better)
- Secondary metrics:
  - `kv_bytes_per_token` (lower is better; evaluator-derived analytic KV footprint from the realized attention structure)
  - `head_sharing_ratio` (higher means more aggressive KV head sharing; evaluator-derived from realized head topology)
  - `latent_rank_ratio` (lower is better, only for methods that truly use latent KV compression; evaluator-derived from latent projection dimensions)
  - `heldout_loss` (lower is better; perplexity-style held-out evaluation on the packaged eval corpus, reported consistently for every visible benchmark)
  - `generation_toks_per_s` (higher is better; lightweight autoregressive generation throughput on held-out tokens, reported consistently for every visible benchmark)
- Visible benchmark regimes:
  - `train-ctx1024`: 124M pretraining at `block_size=1024`
  - `train-ctx256`: 124M short-context pretraining at `block_size=256`
  - `eval-wikitext2`: 124M pretraining followed by held-out eval and generation-side diagnostics on WikiText-2
  - `eval-wikitext103`: 124M pretraining followed by held-out eval and generation-side diagnostics on WikiText-103
  - `eval-lambada`: 124M pretraining followed by held-out eval and generation-side diagnostics on LAMBADA
- Training data: ClimbMix tokenized training split
- Held-out eval data: packaged `eval` dependency (WikiText-2, WikiText-103, LAMBADA)
- Training schedule: same budget family as the existing nanoGPT MLS-Bench pretraining tasks
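For intuition, the head-sharing baselines' `kv_bytes_per_token` values in the Results table are consistent with a simple per-layer fp16 accounting: K and V each store `n_kv_heads * head_dim` elements per token. This formula is an inference from the reported numbers, not the evaluator's actual code, and it does not cover the `mla` row (640), whose footprint presumably comes from the latent projection dimensions instead.

```python
def kv_bytes_per_token(n_kv_heads: int, head_dim: int = 64, bytes_per_elem: int = 2) -> int:
    """Assumed per-layer KV footprint: K and V caches, fp16 (2 bytes per element).
    head_dim=64 matches the 124M GPT-2 geometry (12 heads, n_embd=768)."""
    return 2 * n_kv_heads * head_dim * bytes_per_elem

print(kv_bytes_per_token(12))  # mha: 3072
print(kv_bytes_per_token(3))   # gqa: 768
print(kv_bytes_per_token(1))   # mqa: 256
```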
Code
custom_pretrain.py
```python
"""Custom GPT-2 pretraining script for KV-structural reduction tasks.

Based on Andrej Karpathy's nanoGPT, with a narrow editable region for KV
structure changes such as grouped KV heads and latent KV compression.
"""

import ast
import inspect
import json
import math
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass
```
Results
| Model | Type | val loss train-ctx1024 ↓ | kv bytes per token train-ctx1024 ↓ | heldout loss train-ctx1024 ↓ | generation toks per s train-ctx1024 ↑ | val loss train-ctx256 ↓ | kv bytes per token train-ctx256 ↓ | heldout loss train-ctx256 ↓ | generation toks per s train-ctx256 ↑ | val loss eval-wikitext2 ↓ | kv bytes per token eval-wikitext2 ↓ | heldout loss eval-wikitext2 ↓ | generation toks per s eval-wikitext2 ↑ | val loss eval-wikitext103 ↓ | kv bytes per token eval-wikitext103 ↓ | heldout loss eval-wikitext103 ↓ | generation toks per s eval-wikitext103 ↑ | val loss eval-lambada ↓ | kv bytes per token eval-lambada ↓ | heldout loss eval-lambada ↓ | generation toks per s eval-lambada ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| gqa | baseline | 4.970 | 768.000 | 6.453 | 255.047 | 4.976 | 768.000 | 6.477 | 241.103 | 4.999 | 768.000 | 6.915 | 224.089 | 4.949 | 768.000 | 6.878 | 218.951 | 4.995 | 768.000 | 5.460 | 223.799 |
| mha | baseline | 5.048 | 3072.000 | 6.411 | 269.104 | 5.083 | 3072.000 | 6.556 | 280.265 | 5.011 | 3072.000 | 6.841 | 234.101 | 4.989 | 3072.000 | 6.798 | 239.934 | 4.996 | 3072.000 | 5.556 | 246.772 |
| mla | baseline | 5.054 | 640.000 | 6.639 | 98.563 | 5.152 | 640.000 | 6.650 | 93.963 | 5.054 | 640.000 | 7.231 | 89.563 | 5.054 | 640.000 | 7.231 | 93.176 | 5.054 | 640.000 | 5.609 | 95.500 |
| mqa | baseline | 5.039 | 256.000 | 6.425 | 209.550 | 5.074 | 256.000 | 6.466 | 224.450 | 5.040 | 256.000 | 6.951 | 191.742 | 5.105 | 256.000 | 6.989 | 198.993 | 5.049 | 256.000 | 5.437 | 200.152 |