llm-pretrain-residual

Tags: Language Models, lm-evaluation-harness, nanoGPT, rigorous codebase

Description

LLM Pretraining: Residual Connection Strategy

Research Question

Improve the residual connection strategy in a GPT-style language model. The current architecture uses standard Pre-LN residual connections (x + sublayer(x)) in each transformer block. Your task is to redesign how information flows through the residual stream across layers to achieve lower validation loss.

Background

Standard Residual Connections

The default GPT architecture uses simple additive residual connections in each block:

x = x + self.attn(self.ln_1(x))   # attention sublayer
x = x + self.mlp(self.ln_2(x))    # MLP sublayer

While effective, this fixed accumulation pattern may not be optimal for deep networks. The residual stream is the primary information highway through the model, and its design critically affects gradient flow, feature reuse, and training dynamics.
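One reason the additive pattern is hard to beat: every residual step contributes an identity term to the Jacobian, so gradients reach early layers even when a sublayer's own gradient is small. A minimal sketch of this property in plain PyTorch (the `sublayer` here is a toy stand-in, not the benchmark's attention or MLP):

```python
import torch

# A residual step y = x + f(x) has Jacobian I + f'(x), so even if the
# sublayer's gradient shrinks, dL/dx keeps an identity component and
# gradients flow cleanly to early layers.
x = torch.ones(3, requires_grad=True)

def sublayer(x):
    # toy stand-in for an attention/MLP sublayer
    return 0.1 * torch.tanh(x)

y = x + sublayer(x)
y.sum().backward()

# each gradient entry is 1 + 0.1 * (1 - tanh(x)^2), strictly above 1
print(x.grad)
```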

Research Directions

Several recent works have proposed improvements to residual connections:

  1. Per-layer residual scaling: Learnable scalars that modulate the residual stream at each layer (inspired by modded-nanogpt, ReZero, SkipInit).
  2. Initial embedding blending: Blending the initial token embedding back at each layer to preserve token identity (x0 residual connections).
  3. Hyper-Connections: Maintaining m parallel residual streams with learned transition matrices for richer information flow across layers (Zhu et al., 2025).
  4. Attention Residuals: Using softmax attention over all previous layer outputs to dynamically select which representations to combine (Kimi Team, 2026).
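Directions 1 and 2 can be prototyped in a few lines. Below is a hedged sketch of a modified block, with plain `Linear` layers standing in for the fixed `CausalSelfAttention` and `MLP` classes so the snippet is self-contained; the parameter names (`attn_scale`, `mlp_scale`, `x0_gate`) are illustrative, not from the benchmark code. Note that the task fixes the `Block.forward` signature to accept only `x`, so in the actual script the initial embedding `x0` would need to be threaded through `GPT.forward` or stored on the module instead of passed as a second argument.

```python
import torch
import torch.nn as nn

class ScaledBlock(nn.Module):
    """Pre-LN block with a learnable residual scale per sublayer
    (ReZero/SkipInit-style) plus a learned blend of the initial
    token embedding x0 back into the stream at every layer."""

    def __init__(self, dim):
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim)
        self.ln_2 = nn.LayerNorm(dim)
        # stand-ins for the fixed CausalSelfAttention / MLP classes
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)
        # scales start at 0, so each block is the identity at init
        self.attn_scale = nn.Parameter(torch.zeros(1))
        self.mlp_scale = nn.Parameter(torch.zeros(1))
        # x0 blend weight also starts at 0 (no blending at init)
        self.x0_gate = nn.Parameter(torch.zeros(1))

    def forward(self, x, x0):
        x = x + self.attn_scale * self.attn(self.ln_1(x))
        x = x + self.mlp_scale * self.mlp(self.ln_2(x))
        # mix a learned fraction of the original embedding back in
        x = x + self.x0_gate * x0
        return x
```

Initializing the scales at zero makes the whole stack the identity map at step 0, which is the key trick behind ReZero-style stable training of deep stacks.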

What You Can Modify

Block Class (lines 88-99)

The Block class defines per-block residual behavior. You can change how attention and MLP outputs are combined with the residual stream within each block.
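Direction 3 (Hyper-Connections) changes what a block consumes and produces: instead of one stream, the block reads from and writes to `m` parallel residual streams. A simplified, self-contained sketch of the static variant, with a `Linear` standing in for the sublayer; the read/write/mix parameterization below is an assumption, not the paper's exact formulation:

```python
import torch
import torch.nn as nn

class HyperConnBlock(nn.Module):
    """Maintain m parallel residual streams. A learned width
    connection mixes the streams, the sublayer reads a learned
    combination of them, and its output is written back with
    learned depth weights (simplified static variant)."""

    def __init__(self, dim, m=2):
        super().__init__()
        self.sublayer = nn.Linear(dim, dim)  # stand-in for attn/mlp
        self.read = nn.Parameter(torch.full((m,), 1.0 / m))  # sublayer input weights
        self.write = nn.Parameter(torch.zeros(m))            # sublayer output weights
        self.mix = nn.Parameter(torch.eye(m))                # stream-to-stream mixing

    def forward(self, h):
        # h: (m, B, T, C) parallel residual streams
        x = (self.read.view(-1, 1, 1, 1) * h).sum(0)      # read step
        out = self.sublayer(x)
        h = torch.einsum("ij,jbtc->ibtc", self.mix, h)    # width connection
        h = h + self.write.view(-1, 1, 1, 1) * out        # depth connection
        return h
```

With `write` at zero and `mix` at identity, the block leaves the streams untouched at initialization; `GPT.forward` would expand the embedding into `m` streams before the block loop and sum (or average) them afterwards.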

Residual Stream Parameters (lines 128-130)

Add custom parameters to GPT.__init__ for your residual connection strategy (e.g., per-layer scalars, transition matrices, query vectors).

Block Loop in GPT.forward (lines 162-164)

The main loop that iterates through transformer blocks. You can modify how blocks are called and how their outputs are accumulated (e.g., multi-stream processing, attention over layer outputs).
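Direction 4 (attention over previous layer outputs) lives mostly in this loop. A self-contained sketch with plain `Linear` layers standing in for the transformer blocks; using static per-layer logits (rather than a content-dependent query) is a simplification chosen here, not the cited method:

```python
import torch
import torch.nn as nn

class LayerAttnStack(nn.Module):
    """After each block, recombine the residual stream as a
    softmax-weighted mixture of all layer outputs seen so far.
    Per-layer learned logits; a content-dependent query vector
    would be the natural next step."""

    def __init__(self, dim, n_layer):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layer))
        # one logit per (layer, history slot); unused slots are never indexed
        self.mix_logits = nn.Parameter(torch.zeros(n_layer, n_layer + 1))

    def forward(self, x):
        history = [x]  # the embedding counts as "layer 0" output
        for i, block in enumerate(self.blocks):
            x = x + block(x)
            history.append(x)
            stacked = torch.stack(history)                # (i+2, B, T, C)
            w = torch.softmax(self.mix_logits[i, : len(history)], dim=0)
            x = (w.view(-1, 1, 1, 1) * stacked).sum(0)    # attention over layers
        return x
```

Keeping `history` grows activation memory linearly in depth, which is worth budgeting for at 24 layers and 1024 dims.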

Optimizer Configuration (lines 175-192)

The configure_optimizers method. If you add new parameters, you may want to assign them to appropriate optimizer groups with custom learning rates and weight decay.
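A common pattern is to exempt the new scalars from weight decay and give them their own learning rate. A sketch of the grouping logic, assuming (as a naming convention of this example, not the benchmark code) that new parameter names contain "scale" or "gate":

```python
import torch
import torch.nn as nn

def build_optimizer(model, lr=6e-4, weight_decay=0.1, scalar_lr=6e-3):
    """Split parameters into three groups: decayed weight matrices,
    undecayed biases/norms, and new residual scalars with their own
    learning rate and no decay."""
    decay, no_decay, scalars = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if "scale" in name or "gate" in name:
            scalars.append(p)         # new residual-stream parameters
        elif p.dim() >= 2:
            decay.append(p)           # weight matrices: apply decay
        else:
            no_decay.append(p)        # biases, LayerNorm params
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
        {"params": scalars, "weight_decay": 0.0, "lr": scalar_lr},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95))
```

Decaying a scalar that gates a residual branch pulls the branch toward zero every step, which silently fights the very parameter you added, so `weight_decay=0.0` for that group is usually the safe default.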

Training Hyperparameters (line 251)

The CONFIG_OVERRIDES dictionary for adjusting learning rate, weight decay, etc.

Note: The CausalSelfAttention, MLP, LayerNorm, and GPTConfig classes are fixed. Block.forward must accept x and return a tensor of the same shape, and GPT.forward must accept (idx, targets=None) and return (logits, loss).

Evaluation

  • Primary metric: Validation loss (val_loss, lower is better)
  • Secondary metrics: Perplexity on WikiText-2 and LAMBADA, plus downstream task accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
  • Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
  • Training: 13,535 iterations, batch size 32, gradient accumulation 16, 2-GPU DDP
  • Hardware: H200 GPU

Code

custom_pretrain.py
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

Additional context files (read-only):

  • nanoGPT/model.py

Results

| Model | Type | val loss (gpt-345m) | WikiText-2 ppl (gpt-345m) | LAMBADA ppl (gpt-345m) | ARC-Easy (lm-eval-345m) | HellaSwag (lm-eval-345m) |
|---|---|---|---|---|---|---|
| full_attnres | baseline | 2.254 | 41.820 | 64.320 | 55.010 | 34.050 |
| learned_scaling | baseline | 2.268 | 43.910 | 68.760 | 55.850 | 33.900 |
| prores | baseline | 2.271 | 44.110 | 67.210 | 55.350 | 33.910 |
| vanilla | baseline | 2.276 | 44.280 | 70.090 | 54.120 | 33.820 |
| claude-opus-4.6 | vanilla | 2.316 | 46.160 | 71.040 | 54.080 | 32.480 |
| deepseek-reasoner | vanilla | 2.273 | 44.160 | 68.380 | 54.500 | 33.280 |
| gemini-3.1-pro-preview | vanilla | 2.257 | 41.090 | 63.340 | 54.970 | 34.350 |
| gpt-5.4 | vanilla | 3.585 | 315.190 | 267.530 | 40.070 | 26.220 |
| qwen3.6-plus | vanilla | 10.016 | 45715.280 | 26108.760 | 26.050 | 25.970 |
| claude-opus-4.6 | agent | 2.316 | 46.160 | 71.040 | 54.080 | 32.480 |
| deepseek-reasoner | agent | 2.295 | 43.600 | 68.090 | 54.920 | 33.190 |
| gemini-3.1-pro-preview | agent | 2.257 | 41.090 | 63.340 | 54.970 | 34.350 |
| gpt-5.4 | agent | 3.022 | 110.870 | 148.820 | 40.360 | 27.050 |
| qwen3.6-plus | agent | 10.016 | 45715.280 | 26108.760 | 26.050 | 25.970 |
