llm-pretrain-linear-attention

Tags: Language Models, lm-evaluation-harness, nanoGPT, rigorous codebase

Description

LLM Pretraining: Linear/Subquadratic Attention Mechanism

Research Question

Design a novel linear or subquadratic attention mechanism for GPT-2 language model pretraining that achieves competitive validation loss while replacing standard softmax attention. The mechanism should scale better than O(n^2) in sequence length.
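To make the sub-quadratic requirement concrete, here is a minimal NumPy sketch of causal linear attention. It assumes the elu(x)+1 feature map, which is one common choice and not something the task prescribes. Replacing exp(q·k) with a product of feature maps lets the masked O(n^2) computation be rewritten as a running d_k × d_v state, so each new token costs O(d_k·d_v) no matter how long the sequence is. The function names are illustrative.

```python
import numpy as np

def feature_map(x):
    # elu(x) + 1: x + 1 for x > 0, exp(x) otherwise (always positive).
    # np.minimum avoids overflow warnings in the unused exp branch.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_parallel(q, k, v):
    # O(n^2) form: materialize the causally masked similarity matrix.
    fq, fk = feature_map(q), feature_map(k)
    scores = np.tril(fq @ fk.T)                      # (n, n), zero above diagonal
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def linear_attention_recurrent(q, k, v):
    # O(n) form: carry a fixed-size state instead of the n x n matrix.
    fq, fk = feature_map(q), feature_map(k)
    n, d_k = fq.shape
    S = np.zeros((d_k, v.shape[1]))                  # running sum of phi(k_j) v_j^T
    z = np.zeros(d_k)                                # running sum of phi(k_j)
    out = np.empty_like(v)
    for t in range(n):
        S += np.outer(fk[t], v[t])
        z += fk[t]
        out[t] = (fq[t] @ S) / (fq[t] @ z)
    return out
```

Both functions produce the same outputs; only the recurrent one avoids the n × n matrix. Optimized libraries such as FLA typically offer chunked modes that combine the two forms for GPU efficiency.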

What You Can Modify

Two editable regions in custom_pretrain.py:

  1. CausalSelfAttention class (lines 33-70): The attention mechanism itself, including:

    • The attention computation (replace softmax attention with linear/subquadratic alternatives)
    • Feature maps, gating mechanisms, decay factors
    • Query/Key/Value projections and transformations
    • Internal state management (recurrent states, convolutions, etc.)
  2. Block class (lines 88-100): The transformer block structure, including:

    • How attention and MLP sublayers are composed
    • Normalization placement (pre-norm, post-norm)
    • Residual connection patterns
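The following PyTorch sketch illustrates both editable regions: it replaces softmax attention with elu(x)+1 linear attention and wraps that in a pre-norm residual block. The task page does not show the real constructor signatures, so the class names (LinearCausalSelfAttention, PreNormBlock, HypotheticalConfig) and the config fields (n_embd, n_head, bias, dropout, borrowed from nanoGPT conventions) are assumptions. Adapt them to the actual CausalSelfAttention and Block interfaces in custom_pretrain.py.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class HypotheticalConfig:
    # Hypothetical stand-in for the script's config (nanoGPT-style field names).
    n_embd: int = 128
    n_head: int = 4
    bias: bool = False
    dropout: float = 0.0

class LinearCausalSelfAttention(nn.Module):
    """Causal linear attention with an elu+1 feature map (illustrative sketch)."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        # Plain linear attention has no position signal beyond causality,
        # so this sketch keeps the learned absolute position embeddings.
        self.use_pos_emb = True

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)
        # (B, T, C) -> (B, H, T, D)
        q, k, v = [t.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) for t in (q, k, v)]
        q, k = F.elu(q) + 1, F.elu(k) + 1
        # Prefix sums of k_j v_j^T and k_j give each position its causal state.
        # This stores a D x D state per position; chunked kernels avoid that.
        kv = torch.einsum("bhtd,bhte->bhtde", k, v).cumsum(dim=2)
        z = k.cumsum(dim=2)
        num = torch.einsum("bhtd,bhtde->bhte", q, kv)
        den = (q * z).sum(-1, keepdim=True) + 1e-6
        y = (num / den).transpose(1, 2).contiguous().view(B, T, C)
        return self.dropout(self.c_proj(y))

class PreNormBlock(nn.Module):
    """Pre-norm residual block: x + attn(LN(x)), then x + mlp(LN(x))."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = LinearCausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```

The cumulative-sum formulation keeps the sketch short, but at GPT-2 Medium scale (head dimension 1024/16 = 64) a chunked kernel, for example one of the fla.layers implementations, would be the practical choice.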

Note: The flash-linear-attention (FLA) library is pre-installed and provides 27+ optimized linear attention implementations with Triton kernels. You can import from fla.layers (e.g., GatedLinearAttention, DeltaNet, MultiScaleRetention, LinearAttention, HGRN2, Mamba2, etc.) or implement your own mechanism from scratch.
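DeltaNet is the strongest baseline in the Results section, so it may help to see the rule it is named after. Below is a naive, token-by-token NumPy reference of the delta-rule update S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T with readout o_t = S_t q_t. It is a didactic sketch for checking a from-scratch implementation, not the fla.layers.DeltaNet layer or its Triton kernel. It assumes L2-normalized keys and beta_t in (0, 1], and the function name is illustrative.

```python
import numpy as np

def delta_rule_recurrent(q, k, v, beta):
    """Naive delta-rule linear attention.

    q, k: (n, d_k) queries and (assumed unit-norm) keys
    v:    (n, d_v) values
    beta: (n,) per-token write strengths
    Returns per-token outputs (n, d_v) and the final state (d_v, d_k).
    """
    d_k = k.shape[1]
    S = np.zeros((v.shape[1], d_k))                  # associative memory: keys -> values
    out = np.empty_like(v)
    for t in range(k.shape[0]):
        pred = S @ k[t]                              # value currently stored for k_t
        S = S + beta[t] * np.outer(v[t] - pred, k[t])  # move prediction toward v_t
        out[t] = S @ q[t]                            # read out with the query
    return out, S
```

Unlike plain linear attention, which only adds k_t v_t^T, the delta rule first subtracts what the memory already predicts for k_t. Writing the same key twice with beta = 1 therefore overwrites the stored value instead of accumulating.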

Note: If your attention mechanism does not use learned absolute position embeddings, set self.use_pos_emb = False in __init__ — the model will then skip adding position embeddings in the forward pass.
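The use_pos_emb switch can be pictured with the hypothetical sketch below. The note does not say how the model reads the flag, so this assumes the model checks the attention module's attribute (defaulting to True) and adds the learned position table only when the flag is set. HypotheticalEmbedding and DecayAttentionStub are illustrative names. The wte and wpe attribute names follow nanoGPT's naming.

```python
import torch
import torch.nn as nn

class DecayAttentionStub(nn.Module):
    """Hypothetical attention module whose mechanism encodes position itself."""

    def __init__(self):
        super().__init__()
        self.use_pos_emb = False   # opt out of learned absolute position embeddings

class HypotheticalEmbedding(nn.Module):
    """Hypothetical sketch of an input embedding that honors use_pos_emb."""

    def __init__(self, vocab_size, block_size, n_embd, use_pos_emb):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)  # token embeddings
        self.wpe = nn.Embedding(block_size, n_embd)  # learned absolute positions
        self.use_pos_emb = use_pos_emb

    def forward(self, idx):
        x = self.wte(idx)
        if self.use_pos_emb:
            pos = torch.arange(idx.size(1), device=idx.device)
            x = x + self.wpe(pos)  # (T, C) broadcasts over the batch dimension
        return x

# Assumed wiring: read the flag from the attention module, default True.
attn = DecayAttentionStub()
emb = HypotheticalEmbedding(50, 16, 8, use_pos_emb=getattr(attn, "use_pos_emb", True))
```

A mechanism that carries its own positional signal, for example a per-head decay or rotary embeddings, is the typical case for opting out.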

Note: torch.compile is disabled for this task since FLA's Triton kernels are not compatible with it.

Evaluation

  • Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
  • Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
  • Training: 13535 iterations, BSZ=32 (batch size), GA=16 (gradient accumulation steps), 2-GPU DDP
  • Hardware: H200 GPUs
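The training numbers can be cross-checked with a little arithmetic. The check assumes a 1,024-token context (GPT-2's block size) and assumes that GA=16 counts micro-batches per optimizer step summed over both GPUs. Under those assumptions, each iteration processes 32 × 16 × 1,024 tokens, and 13,535 iterations land on the ~7.1B-token D=20N budget. Neither assumption is confirmed by the task page. If GA were instead per GPU, the total would double to about 14.2B.

```python
# Hypothetical consistency check. Assumptions: block_size = 1024 and
# GA = 16 micro-batches per optimizer step in total across both GPUs.
iters = 13535
micro_batch = 32        # BSZ: sequences per micro-batch
grad_accum = 16         # GA
block_size = 1024       # assumed GPT-2 context length

tokens_per_iter = micro_batch * grad_accum * block_size
total_tokens = iters * tokens_per_iter
chinchilla_tokens = 20 * 355e6   # D = 20N with N ~ 355M parameters

print(f"{tokens_per_iter:,} tokens/iter")
print(f"{total_tokens / 1e9:.2f}B tokens total vs {chinchilla_tokens / 1e9:.2f}B target")
```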

Code

custom_pretrain.py
1"""Custom GPT-2 Pretraining Script
2Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3# flash-linear-attention is available: from fla.layers import GatedLinearAttention, DeltaNet, MultiScaleRetention, etc.
4"""
5
6import math
7import inspect
8import os
9import time
10from contextlib import nullcontext
11from dataclasses import dataclass
12
13import numpy as np
14import torch
15import torch.nn as nn

Results

All metrics are for gpt-345m; lower is better in every column.

Model                   Type      Val loss  WikiText-2 ppl  LAMBADA ppl
deltanet                baseline  2.348     49.880          70.480
gla                     baseline  2.448     64.320          84.730
retnet                  baseline  2.479     66.670          82.360
claude-opus-4.6         vanilla   2.309     47.460          78.450
deepseek-reasoner       vanilla   2.344     54.150          71.810
gemini-3.1-pro-preview  vanilla   2.201     40.440          60.100
gpt-5.4                 vanilla   2.369     57.230          74.770
qwen3.6-plus            vanilla   2.260     43.880          63.700
claude-opus-4.6         agent     2.309     47.460          78.450
deepseek-reasoner       agent     2.286     46.350          65.230
gemini-3.1-pro-preview  agent     2.201     40.440          60.100
gpt-5.4                 agent     2.369     57.230          74.770
qwen3.6-plus            agent     2.260     43.880          63.700

Agent Conversations