llm-pretrain-mlp

Tags: Language Models · lm-evaluation-harness · nanoGPT · rigorous codebase

Description

LLM Pretraining: Feed-Forward Network Optimization

Research Question

Design an improved feed-forward network (MLP) for GPT-2 language model pretraining. Your modifications should reduce validation loss compared to the standard GELU MLP.

What You Can Modify

The MLP class (lines 73-86 in custom_pretrain.py), including:

  • Activation function (default: GELU)
  • Network architecture (default: two linear layers with 4x expansion)
  • Gating mechanisms
  • Hidden dimension sizing

Constraint: The MLP must accept input of shape (B, T, n_embd) and return output of the same shape.
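One modification that satisfies this constraint is a gated MLP in the SwiGLU style (one of the baselines in the results below). The sketch here is illustrative, not the benchmark's reference implementation; the `SwiGLUMLP` name, the bias-free layers, and the 2/3 hidden-size scaling are all assumptions chosen to roughly match the parameter count of the standard 4x GELU MLP.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    """Gated feed-forward block: out = W_down(SiLU(W_gate x) * (W_up x)).

    Hidden size is scaled by 2/3 relative to the usual 4x expansion so the
    three weight matrices hold roughly as many parameters as the two
    matrices of a standard GELU MLP. (Illustrative sketch only.)
    """

    def __init__(self, n_embd: int):
        super().__init__()
        hidden = int(2 * (4 * n_embd) / 3)
        self.w_gate = nn.Linear(n_embd, hidden, bias=False)
        self.w_up = nn.Linear(n_embd, hidden, bias=False)
        self.w_down = nn.Linear(hidden, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input (B, T, n_embd) -> output (B, T, n_embd), as required.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```

Any drop-in replacement along these lines must preserve the `(B, T, n_embd)` input/output shapes so the rest of the transformer block is untouched.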

Evaluation

  • Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
  • Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
  • Training: 12,030 iterations, batch size 96, gradient accumulation 6, 2-GPU DDP
  • Hardware: NVIDIA H200 GPUs

Code

custom_pretrain.py
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on the FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
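The editable `MLP` class itself is not shown in this excerpt. For orientation, the standard GELU MLP in nanoGPT, which the baseline presumably mirrors, has the following shape (the standalone `n_embd`/`dropout` constructor arguments here are a simplification; nanoGPT passes a config object):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Standard GPT-2 feed-forward block: linear 4x expansion,
    GELU nonlinearity, then linear projection back to n_embd."""

    def __init__(self, n_embd: int, dropout: float = 0.0):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, n_embd) -> (B, T, n_embd)
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
```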

Additional context files (read-only):

  • nanoGPT/model.py

Results

All columns are measured on the GPT-2 Medium (~345M) model; accuracies are from lm-eval.

| Model | Type | Val Loss | WikiText-2 PPL | LAMBADA PPL | ARC-Easy Acc (%) | HellaSwag Acc (%) |
|---|---|---|---|---|---|---|
| geglu | baseline | 2.295 | 44.130 | 68.730 | 54.880 | 32.900 |
| relu_squared | baseline | 2.283 | 43.330 | 66.560 | 55.260 | 33.860 |
| swiglu | baseline | 2.292 | 44.330 | 66.810 | 54.710 | 33.400 |
| claude-opus-4.6 | vanilla | 2.303 | 44.110 | 71.720 | 54.760 | 32.670 |
| deepseek-reasoner | vanilla | 2.313 | 44.300 | 68.150 | 52.650 | 33.320 |
| gemini-3.1-pro-preview | vanilla | 2.286 | 44.760 | 69.100 | 55.770 | 33.610 |
| gpt-5.4 | vanilla | 2.284 | 43.230 | 67.220 | 52.860 | 33.190 |
| qwen3.6-plus | vanilla | 2.300 | 43.710 | 66.460 | 54.420 | 33.340 |
| claude-opus-4.6 | agent | 2.299 | 43.970 | 68.070 | 54.120 | 33.620 |
| deepseek-reasoner | agent | 2.214 | 38.920 | 61.740 | 57.370 | 35.210 |
| gemini-3.1-pro-preview | agent | 2.292 | 43.350 | 66.230 | 54.420 | 33.340 |
| gpt-5.4 | agent | 2.321 | 45.330 | 70.550 | 54.670 | 32.840 |
| qwen3.6-plus | agent | 2.300 | 43.710 | 66.460 | 54.420 | 33.340 |

Agent Conversations