llm-pretrain-attention

Tags: Language Models · lm-evaluation-harness · nanoGPT · rigorous codebase

Description

LLM Pretraining: Attention Mechanism Optimization

Research Question

Design an improved self-attention mechanism for GPT-2 language model pretraining. Your modifications should reduce validation loss compared to the standard multi-head attention with learned absolute position embeddings.

What You Can Modify

The CausalSelfAttention class (lines 34-70 in custom_pretrain.py), including:

  • Position encoding scheme (the default uses learned absolute position embeddings via wpe)
  • Query/Key/Value computation and projection
  • Attention score computation and masking
  • Any attention-related hyperparameters

Note: If your attention mechanism implements its own position encoding (replacing the learned wpe), set self.use_pos_emb = False in __init__ — the model will then skip adding position embeddings in the forward pass.
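As one illustration of replacing the learned wpe, rotary position embeddings (RoPE) applied to queries and keys inside the attention module are a common choice. Below is a minimal sketch of the standard RoPE rotation; the helper name `apply_rope` is illustrative and not part of the task's code:

```python
import torch

def apply_rope(x, theta=10000.0):
    """Rotate channel pairs of x by position-dependent angles (standard RoPE).

    x: (B, n_head, T, head_dim); head_dim must be even.
    """
    B, H, T, D = x.shape
    half = D // 2
    # Per-pair rotation frequencies, decaying geometrically with channel index.
    freqs = theta ** (-torch.arange(0, half, dtype=torch.float32) / half)
    # angles[t, i] = t * freqs[i], shape (T, half); broadcasts over (B, H).
    angles = torch.arange(T, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

If q and k are rotated this way before the attention scores are computed, positions enter the model through the scores themselves, and `self.use_pos_emb = False` lets the model skip the learned wpe.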

Evaluation

  • Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
  • Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
  • Training: 13535 iterations, BSZ=64, GA=8, 2-GPU DDP
  • Hardware: 2× H200 GPUs
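The stated token budget is consistent with the Chinchilla rule of thumb (D ≈ 20N). A quick check, assuming a context length of 1024 and that iterations × BSZ × GA counts the total sequences consumed:

```python
# Sanity check of the training token budget.
# Assumptions: context length 1024; BSZ=64 sequences per accumulation step,
# GA=8 accumulation steps per iteration (global, across the DDP setup).
n_params = 355e6                      # GPT-2 Medium, ~355M parameters
chinchilla_tokens = 20 * n_params     # D = 20N rule of thumb

iters, bsz, ga, ctx = 13535, 64, 8, 1024
trained_tokens = iters * bsz * ga * ctx

print(f"{chinchilla_tokens / 1e9:.1f}B tokens (Chinchilla)")  # 7.1B
print(f"{trained_tokens / 1e9:.1f}B tokens (trained)")        # 7.1B
```

Both come out to roughly 7.1B tokens, matching the dataset description above.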

Code

custom_pretrain.py
```python
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on the FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
```
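The excerpt above shows only the script's imports. For orientation, a nanoGPT-style CausalSelfAttention of the kind the task exposes typically looks like the following; this is a sketch based on the public nanoGPT code, not the exact contents of custom_pretrain.py:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):
    """Sketch of a nanoGPT-style attention block (illustrative, not the task file)."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused projection producing queries, keys, and values in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Set to False if the module encodes positions itself (e.g. RoPE).
        self.use_pos_emb = True

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape to (B, n_head, T, head_dim) for multi-head attention.
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # Fused causal attention (PyTorch >= 2.0).
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
```

Modifications for this task would go inside this class, e.g. changing how q/k/v are computed or how scores are formed, per the "What You Can Modify" list above.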

Additional context files (read-only):

  • nanoGPT/model.py

Results

| Model | Type | Val loss (gpt-345m) | WikiText-2 ppl (gpt-345m) | LAMBADA ppl (gpt-345m) | ARC-Easy (lm-eval-345m) | HellaSwag (lm-eval-345m) |
|---|---|---|---|---|---|---|
| qk_norm | baseline | 2.288 | 43.650 | 69.990 | 55.640 | 33.410 |
| rope | baseline | 2.257 | 43.170 | 65.810 | 57.320 | 34.480 |
| rope_qk_norm | baseline | 2.259 | 43.440 | 67.200 | 57.830 | 34.240 |
| claude-opus-4.6 | vanilla | 2.246 | 42.220 | 66.130 | 58.380 | 34.600 |
| deepseek-reasoner | vanilla | 2.221 | 40.080 | 61.810 | 57.490 | 35.370 |
| gemini-3.1-pro-preview | vanilla | 2.260 | 43.060 | 65.370 | 56.570 | 34.570 |
| gpt-5.4 | vanilla | - | - | - | 25.080 | 25.040 |
| qwen3.6-plus | vanilla | 2.246 | 42.570 | 66.150 | 56.690 | 34.630 |
| claude-opus-4.6 | agent | 2.246 | 42.220 | 66.130 | 58.380 | 34.600 |
| deepseek-reasoner | agent | 2.221 | 40.080 | 61.810 | 57.490 | 35.370 |
| gemini-3.1-pro-preview | agent | 2.255 | 41.450 | 64.800 | 57.700 | 34.550 |
| gpt-5.4 | agent | - | - | - | 25.080 | 25.040 |
| qwen3.6-plus | agent | 2.246 | 42.570 | 66.150 | 56.690 | 34.630 |

Agent Conversations