llm-kv-structural-reduction

ML Systems · nanoGPT · rigorous codebase

Description

LLM Pretraining: KV-Structural Reduction

Research Question

Design a more KV-efficient causal attention structure for GPT-style pretraining, with the primary focus on the tradeoff between KV head sharing and latent KV compression.

  • How much quality can be preserved while shrinking the realized KV state?
  • Do grouped/shared KV heads or latent KV bottlenecks give the better quality-memory tradeoff under a fixed small-scale pretraining budget?

What You Can Modify

One editable region in custom_pretrain.py:

  1. Attention-structure region (lines 35-154), including:
    • build_kv_heads(...): how many KV heads are materialized relative to query heads
    • cross_layer_share(...): optional structural sharing hook inside the attention stack
    • latent_kv_project(...): whether K/V are compressed into a lower-rank latent space
    • CausalSelfAttention: how the above choices are instantiated inside the attention block, including the internal query/KV projection and attention mixing path
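The helper names above come from the task; the bodies below are an illustrative, dependency-free sketch of the two mechanisms being compared (GQA-style KV head sharing and MLA-style latent compression), not the evaluator's implementation.

```python
def build_kv_heads(kv_heads, n_query_heads):
    # kv_heads: list of per-head K (or V) tensors, length n_kv_heads.
    # Expand shared KV heads so each group of query heads reads the same
    # K/V (GQA-style sharing; MQA is the n_kv_heads == 1 special case).
    n_kv = len(kv_heads)
    assert n_query_heads % n_kv == 0, "query heads must divide KV heads evenly"
    group = n_query_heads // n_kv
    return [h for h in kv_heads for _ in range(group)]

def latent_kv_project(x, w_down, w_up):
    # Compress a d_model vector into a rank-r latent and re-expand it
    # (MLA-style); at inference only the latent would be cached.
    latent = [sum(xi * w_down[i][j] for i, xi in enumerate(x))
              for j in range(len(w_down[0]))]
    return [sum(lj * w_up[j][k] for j, lj in enumerate(latent))
            for k in range(len(w_up[0]))]
```

In a real attention block these would operate on batched tensors; the list version only makes the head-sharing and down/up-projection structure explicit.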

Intended Task Boundary

  • This task studies KV-state reduction inside the attention block.
  • The main comparison axes are:
    • dense MHA vs grouped/shared KV heads
    • grouped/shared KV heads vs latent KV compression
  • cross_layer_share(...) remains available as an auxiliary structural hook inside the same block.
  • The evaluator enforces the top-level boundary of this region with an AST validator: only the allowed helper functions plus CausalSelfAttention may appear in the editable span. That keeps edits inside the attention block, even though the internal contents of CausalSelfAttention remain flexible.
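A minimal sketch of how such an AST check could work, using the standard `ast` module; the evaluator's actual rules may differ (for example, it may also permit imports or module-level constants inside the span).

```python
import ast

# Allowed top-level definitions inside the editable span (from the task
# description); anything else fails validation.
ALLOWED_NAMES = {
    "build_kv_heads",
    "cross_layer_share",
    "latent_kv_project",
    "CausalSelfAttention",
}

def validate_editable_span(source: str) -> bool:
    # Parse the editable span and require that every top-level node is
    # one of the allowed helper functions or the attention class.
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            if node.name not in ALLOWED_NAMES:
                return False
        else:
            return False  # no stray top-level statements
    return True
```

This keeps the edit surface to the named callables while leaving their internal contents unconstrained, which matches the boundary described above.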

Evaluation

  • Primary metric: validation loss (cross-entropy, lower is better)
  • Secondary metrics:
    • kv_bytes_per_token (lower is better; evaluator-derived analytic KV footprint from the realized attention structure)
    • head_sharing_ratio (higher means more aggressive KV head sharing; evaluator-derived from realized head topology)
    • latent_rank_ratio (lower is better; only meaningful for methods that actually use latent KV compression; evaluator-derived from latent projection dimensions)
    • heldout_loss (lower is better; perplexity-style held-out evaluation on the packaged eval corpus, reported consistently for every visible benchmark)
    • generation_toks_per_s (higher is better; lightweight autoregressive generation throughput on held-out tokens, reported consistently for every visible benchmark)
  • Visible benchmark regimes:
    • train-ctx1024: 124M pretraining at block_size=1024
    • train-ctx256: 124M short-context pretraining at block_size=256
    • eval-wikitext2: 124M pretraining followed by held-out eval and generation-side diagnostics on WikiText-2
    • eval-wikitext103: 124M pretraining followed by held-out eval and generation-side diagnostics on WikiText-103
    • eval-lambada: 124M pretraining followed by held-out eval and generation-side diagnostics on LAMBADA
  • Training data: ClimbMix tokenized training split
  • Held-out eval data: packaged eval dependency (WikiText-2, WikiText-103, LAMBADA)
  • Training schedule: same budget family as the existing nanoGPT MLS-Bench pretraining tasks

Code

custom_pretrain.py
"""Custom GPT-2 pretraining script for KV-structural reduction tasks.

Based on Andrej Karpathy's nanoGPT, with a narrow editable region for KV
structure changes such as grouped KV heads and latent KV compression.
"""

import ast
import inspect
import json
import math
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

Results

train-ctx1024

| Model | Type | val loss | kv bytes per token | heldout loss | generation toks per s |
| --- | --- | --- | --- | --- | --- |
| gqa | baseline | 4.970 | 768.000 | 6.453 | 255.047 |
| mha | baseline | 5.048 | 3072.000 | 6.411 | 269.104 |
| mla | baseline | 5.054 | 640.000 | 6.639 | 98.563 |
| mqa | baseline | 5.039 | 256.000 | 6.425 | 209.550 |

train-ctx256

| Model | Type | val loss | kv bytes per token | heldout loss | generation toks per s |
| --- | --- | --- | --- | --- | --- |
| gqa | baseline | 4.976 | 768.000 | 6.477 | 241.103 |
| mha | baseline | 5.083 | 3072.000 | 6.556 | 280.265 |
| mla | baseline | 5.152 | 640.000 | 6.650 | 93.963 |
| mqa | baseline | 5.074 | 256.000 | 6.466 | 224.450 |

eval-wikitext2

| Model | Type | val loss | kv bytes per token | heldout loss | generation toks per s |
| --- | --- | --- | --- | --- | --- |
| gqa | baseline | 4.999 | 768.000 | 6.915 | 224.089 |
| mha | baseline | 5.011 | 3072.000 | 6.841 | 234.101 |
| mla | baseline | 5.054 | 640.000 | 7.231 | 89.563 |
| mqa | baseline | 5.040 | 256.000 | 6.951 | 191.742 |

eval-wikitext103

| Model | Type | val loss | kv bytes per token | heldout loss | generation toks per s |
| --- | --- | --- | --- | --- | --- |
| gqa | baseline | 4.949 | 768.000 | 6.878 | 218.951 |
| mha | baseline | 4.989 | 3072.000 | 6.798 | 239.934 |
| mla | baseline | 5.054 | 640.000 | 7.231 | 93.176 |
| mqa | baseline | 5.105 | 256.000 | 6.989 | 198.993 |

eval-lambada

| Model | Type | val loss | kv bytes per token | heldout loss | generation toks per s |
| --- | --- | --- | --- | --- | --- |
| gqa | baseline | 4.995 | 768.000 | 5.460 | 223.799 |
| mha | baseline | 4.996 | 3072.000 | 5.556 | 246.772 |
| mla | baseline | 5.054 | 640.000 | 5.609 | 95.500 |
| mqa | baseline | 5.049 | 256.000 | 5.437 | 200.152 |