llm-pretrain-optimizer

Language Models · lm-evaluation-harness · nanoGPT · rigorous codebase

Description

LLM Pretraining: Optimizer & Learning Rate Schedule Optimization

Research Question

Design an improved optimizer and/or learning rate schedule for GPT-2 language model pretraining. Your modifications should reduce validation loss relative to the baseline of AdamW with a cosine annealing schedule.

What You Can Modify

Two editable regions in custom_pretrain.py:

  1. configure_optimizers method (lines 172-189): Optimizer creation and parameter grouping
  2. get_lr function (lines 192-201): Learning rate schedule

You can modify:

  • The optimization algorithm (default: AdamW with fused implementation)
  • Parameter grouping strategy (default: weight decay for 2D params, no decay for 1D)
  • Learning rate schedule shape (default: cosine with linear warmup)
  • Any optimizer hyperparameters

Note: The training loop calls get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr) — keep your implementation compatible with this signature. The optimizer returned by configure_optimizers must support .zero_grad(), .step(), and .param_groups.
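
For orientation, here is a minimal sketch of the two default regions as characterized above. Only the get_lr signature is given by the task; the configure_optimizers signature is an assumption following nanoGPT's convention, so check the actual method in custom_pretrain.py before reusing it.

```python
import inspect
import math

import torch


def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
    # Default grouping: weight decay on 2D params (matmul weights,
    # embeddings), no decay on 1D params (biases, LayerNorm gains).
    # NOTE: this signature is assumed from nanoGPT, not given by the task.
    params = {n: p for n, p in self.named_parameters() if p.requires_grad}
    decay = [p for p in params.values() if p.dim() >= 2]
    no_decay = [p for p in params.values() if p.dim() < 2]
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    # Default optimizer: AdamW, using the fused CUDA kernel when available.
    fused_ok = (device_type == "cuda"
                and "fused" in inspect.signature(torch.optim.AdamW).parameters)
    return torch.optim.AdamW(groups, lr=learning_rate, betas=betas,
                             **({"fused": True} if fused_ok else {}))


def get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr):
    # Default schedule: linear warmup to learning_rate, then cosine decay
    # down to min_lr, held constant after lr_decay_iters.
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    if it > lr_decay_iters:
        return min_lr
    ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # decays from 1 to 0
    return min_lr + coeff * (learning_rate - min_lr)
```

Each iteration the training loop applies the returned value to every parameter group, roughly `for g in optimizer.param_groups: g["lr"] = get_lr(it, ...)` — which is why the returned optimizer must expose .param_groups.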

Evaluation

  • Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
  • Model: GPT-2 Medium (24 layers, 16 heads, 1024-dim embeddings; ~355M parameters)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B training tokens (D = 20N, Chinchilla-optimal)
  • Training: 12,030 iterations, batch size 96, gradient accumulation 6, 2-GPU DDP
  • Hardware: 2× NVIDIA H200 GPUs
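
As a sanity check, these settings reproduce the stated ~7.1B-token budget under two assumptions flagged in the comments: a 1024-token context (GPT-2's native block size) and gradient accumulation counted as micro-steps summed across both DDP ranks.

```python
# Hypothetical back-of-envelope check; block_size=1024 is GPT-2's context
# length, and grad-accum is assumed to be counted across both DDP ranks.
seqs_per_iter = 96 * 6                    # batch size x grad-accum = 576
tokens_per_iter = seqs_per_iter * 1024    # = 589,824 tokens
total_tokens = tokens_per_iter * 12_030   # ~7.10e9, matching the stated ~7.1B
chinchilla_target = 20 * 355e6            # D = 20N ~ 7.1e9 tokens
```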

Code

custom_pretrain.py
1"""Custom GPT-2 Pretraining Script
2Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3"""
4
5import math
6import inspect
7import os
8import time
9from contextlib import nullcontext
10from dataclasses import dataclass
11
12import numpy as np
13import torch
14import torch.nn as nn
15from torch.nn import functional as F

Additional context files (read-only):

  • nanoGPT/model.py
  • nanoGPT/train.py

Results

| Model | Type | Val loss | WikiText-2 PPL | LAMBADA PPL | ARC-Easy acc (%) | HellaSwag acc (%) |
|---|---|---|---|---|---|---|
| adamw_nesterov | baseline | 2.323 | 46.960 | 71.820 | 55.180 | 32.750 |
| lion | baseline | 2.203 | 38.960 | 60.050 | 58.210 | 35.640 |
| muon | baseline | 2.200 | 37.980 | 60.080 | 60.190 | 36.850 |
| claude-opus-4.6 | vanilla | 2.200 | 37.630 | 59.710 | 60.140 | 36.880 |
| deepseek-reasoner | vanilla | 2.310 | 45.120 | 69.370 | 53.620 | 32.810 |
| gemini-3.1-pro-preview | vanilla | 2.222 | 39.110 | 62.340 | 60.100 | 35.410 |
| gpt-5.4 | vanilla | 2.255 | 42.420 | 67.940 | 57.910 | 34.050 |
| qwen3.6-plus | vanilla | 6.981 | 5585.830 | 4826.540 | 29.000 | 25.310 |
| claude-opus-4.6 | agent | 2.221 | 39.780 | 61.800 | 58.710 | 35.950 |
| deepseek-reasoner | agent | 2.310 | 45.120 | 69.370 | 53.620 | 32.810 |
| gemini-3.1-pro-preview | agent | 2.198 | 38.200 | 59.710 | 59.470 | 36.770 |
| gpt-5.4 | agent | 2.247 | 42.020 | 64.970 | 57.150 | 33.980 |
| qwen3.6-plus | agent | 2.173 | 37.140 | 59.580 | 59.640 | 37.010 |

All columns report the GPT-2 Medium (345M) model; accuracy columns come from lm-evaluation-harness.

Agent Conversations