llm-pretrain-linear-attention
Description
LLM Pretraining: Linear/Subquadratic Attention Mechanism
Research Question
Design a novel linear or subquadratic attention mechanism for GPT-2 language model pretraining that achieves competitive validation loss while replacing standard softmax attention. The mechanism should scale better than O(n^2) in sequence length.
What You Can Modify
Two editable regions in custom_pretrain.py:
- `CausalSelfAttention` class (lines 33-70): The attention mechanism itself, including:
  - The attention computation (replace softmax attention with linear/subquadratic alternatives)
  - Feature maps, gating mechanisms, decay factors
  - Query/Key/Value projections and transformations
  - Internal state management (recurrent states, convolutions, etc.)
- `Block` class (lines 88-100): The transformer block structure, including:
  - How attention and MLP sublayers are composed
  - Normalization placement (pre-norm, post-norm)
  - Residual connection patterns
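As a point of reference for the `Block` editable region, a standard pre-norm composition (the nanoGPT default) looks like the following. This is a minimal sketch, not the script's actual code; the constructor arguments (`attn`, `mlp` passed in as modules) are an illustrative simplification.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).
    Alternatives to explore: post-norm (normalize after the residual add),
    parallel attention+MLP branches, or scaled residual connections."""
    def __init__(self, n_embd, attn, mlp):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.attn = attn  # e.g. a CausalSelfAttention instance
        self.mlp = mlp

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # attention sublayer with residual
        x = x + self.mlp(self.ln_2(x))   # MLP sublayer with residual
        return x
```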
Note: The flash-linear-attention (FLA) library is pre-installed and provides 27+ optimized linear attention implementations with Triton kernels. You can import from fla.layers (e.g., GatedLinearAttention, DeltaNet, MultiScaleRetention, LinearAttention, HGRN2, Mamba2, etc.) or implement your own mechanism from scratch.
Note: If your attention mechanism does not use learned absolute position embeddings, set self.use_pos_emb = False in __init__ — the model will then skip adding position embeddings in the forward pass.
Note: torch.compile is disabled for this task since FLA's Triton kernels are not compatible with it.
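To make the `CausalSelfAttention` region concrete, here is a minimal sketch of one subquadratic alternative: kernelized linear attention with an ELU+1 feature map (in the style of Katharopoulos et al.), written in plain PyTorch. This is not the baseline implementation and not an FLA kernel; the constructor signature, the feature map, and the `cumsum` formulation (which materializes per-step outer products and is memory-hungry; FLA's Triton kernels avoid this) are all illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearCausalSelfAttention(nn.Module):
    """Causal linear attention: softmax(QK^T)V is replaced by
    phi(Q) @ cumsum(phi(K)^T V), which is O(n) in sequence length."""
    def __init__(self, n_embd=1024, n_head=16):
        super().__init__()
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        # Optional, per the note above: skip learned absolute position embeddings.
        self.use_pos_emb = False

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        # Reshape to (B, H, T, D)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Positive feature map phi(x) = elu(x) + 1
        q, k = F.elu(q) + 1, F.elu(k) + 1
        # Running sums over time enforce causality (position t sees only <= t)
        kv = torch.einsum('bhtd,bhte->bhtde', k, v).cumsum(dim=2)  # sum_t k_t v_t^T
        z = k.cumsum(dim=2)                                        # normalizer sum_t k_t
        num = torch.einsum('bhtd,bhtde->bhte', q, kv)
        den = torch.einsum('bhtd,bhtd->bht', q, z).unsqueeze(-1).clamp(min=1e-6)
        y = (num / den).transpose(1, 2).reshape(B, T, C)
        return self.proj(y)
```

In practice the FLA imports listed above (e.g. `GatedLinearAttention`, `DeltaNet`) are drop-in starting points with far better kernels; a from-scratch module like this is mainly useful for prototyping new recurrences.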
Evaluation
- Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
- Training: 13535 iterations, batch size 32, gradient accumulation 16, 2-GPU DDP
- Hardware: H200 GPU
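The training schedule and the stated ~7.1B-token budget can be sanity-checked with quick arithmetic, assuming the GPT-2 default sequence length of 1024 (not stated above):

```python
# Back-of-the-envelope token budget for the training run.
# seq_len=1024 is an assumption (the GPT-2 default context length).
iters, bsz, grad_accum, seq_len = 13535, 32, 16, 1024
tokens = iters * bsz * grad_accum * seq_len
print(f"{tokens / 1e9:.2f}B tokens")  # → 7.10B, matching D = 20N for N ≈ 355M params
```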
Code
```python
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
# flash-linear-attention is available: from fla.layers import GatedLinearAttention, DeltaNet, MultiScaleRetention, etc.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
```
Results
| Model | Type | Val loss (GPT-2 345M) ↓ | WikiText-2 PPL ↓ | LAMBADA PPL ↓ |
|---|---|---|---|---|
| deltanet | baseline | 2.348 | 49.880 | 70.480 |
| gla | baseline | 2.448 | 64.320 | 84.730 |
| retnet | baseline | 2.479 | 66.670 | 82.360 |
| claude-opus-4.6 | vanilla | 2.309 | 47.460 | 78.450 |
| deepseek-reasoner | vanilla | 2.344 | 54.150 | 71.810 |
| gemini-3.1-pro-preview | vanilla | 2.201 | 40.440 | 60.100 |
| gpt-5.4 | vanilla | 2.369 | 57.230 | 74.770 |
| qwen3.6-plus | vanilla | 2.260 | 43.880 | 63.700 |
| claude-opus-4.6 | agent | 2.309 | 47.460 | 78.450 |
| deepseek-reasoner | agent | 2.286 | 46.350 | 65.230 |
| gemini-3.1-pro-preview | agent | 2.201 | 40.440 | 60.100 |
| gpt-5.4 | agent | 2.369 | 57.230 | 74.770 |
| qwen3.6-plus | agent | 2.260 | 43.880 | 63.700 |