llm-pretrain-kernel

Tags: Language Models · lm-evaluation-harness · nanoGPT · rigorous codebase

Description

LLM Pretraining: Custom GPU Kernel Optimization

Research Question

Write a custom GPU kernel (Triton, or CUDA via PyTorch) to implement a fused MLP operation for GPT-2 pretraining. Your kernel should fuse multiple operations to reduce memory traffic and improve throughput while maintaining or improving model quality.
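For orientation, here is a minimal sketch of the unfused baseline computation that fusion targets, assuming the standard GPT-2 MLP structure from nanoGPT (linear, GELU, linear); the actual baseline in custom_pretrain.py may differ in detail.

import torch
import torch.nn.functional as F

def mlp_forward_unfused(x, w_fc, w_proj):
    """Baseline GPT-2 MLP with separate ops: each step materializes
    a (B*T, 4*n_embd) intermediate in GPU memory, which is the
    traffic a fused kernel tries to avoid."""
    h = x @ w_fc.t()        # up-projection to (B*T, 4*n_embd)
    h = F.gelu(h)           # extra read/write of the large intermediate
    return h @ w_proj.t()   # down-projection back to (B*T, n_embd)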

What You Can Modify

The fused_mlp_forward function (lines 34-48) in custom_pretrain.py:

  • The MLP activation function (default: GELU via separate PyTorch ops)
  • Kernel fusion strategy (e.g., fusing linear + activation, choosing which intermediate values to save)
  • Memory optimization (avoid materializing intermediate tensors)
  • Custom autograd Functions for efficient backward pass

Note: The function signature fused_mlp_forward(x, w_fc, w_proj) must be preserved.

  • x: input tensor (B*T, n_embd)
  • w_fc: first linear weight (4*n_embd, n_embd)
  • w_proj: second linear weight (n_embd, 4*n_embd)
  • Returns: output tensor (B*T, n_embd)

The MLP class calls this function and handles dropout separately.
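As one illustration of the custom-autograd route above, here is a minimal recomputation-based sketch that keeps the required fused_mlp_forward(x, w_fc, w_proj) signature. The class name _FusedMLP is a placeholder, and the derivative assumes F.gelu's default exact (erf-based) form; this saves memory by recomputing the hidden activation in backward, while a competitive submission would additionally replace the matmul + activation pair with a Triton kernel.

import math
import torch
import torch.nn.functional as F

class _FusedMLP(torch.autograd.Function):
    """Save only x and the weights; recompute the hidden activation
    in backward instead of keeping the (B*T, 4*n_embd) tensor alive."""

    @staticmethod
    def forward(ctx, x, w_fc, w_proj):
        h = F.gelu(x @ w_fc.t())              # hidden activation, not saved
        ctx.save_for_backward(x, w_fc, w_proj)
        return h @ w_proj.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w_fc, w_proj = ctx.saved_tensors
        pre = x @ w_fc.t()                    # recompute pre-activation
        h = F.gelu(pre)
        # Derivative of exact GELU z*Phi(z): Phi(z) + z * pdf(z)
        cdf = 0.5 * (1.0 + torch.erf(pre / math.sqrt(2.0)))
        pdf = torch.exp(-0.5 * pre * pre) / math.sqrt(2.0 * math.pi)
        grad_h = grad_out @ w_proj            # (B*T, 4*n_embd)
        grad_pre = grad_h * (cdf + pre * pdf)
        return (grad_pre @ w_fc,              # grad w.r.t. x
                grad_pre.t() @ x,             # grad w.r.t. w_fc
                grad_out.t() @ h)             # grad w.r.t. w_proj

def fused_mlp_forward(x, w_fc, w_proj):
    return _FusedMLP.apply(x, w_fc, w_proj)

Correctness of a sketch like this can be checked against the unfused version with torch.autograd.gradcheck on small double-precision inputs.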

Evaluation

  • Metrics: validation loss (cross-entropy, lower is better) and training speed (elapsed wall-clock time, lower is better); note that kernel optimizations which also change the activation function may improve loss, not just speed
  • Model: GPT-2 Medium (24 layers, 16 attention heads, 1024-dim embeddings; ~355M parameters)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B training tokens (D = 20N, Chinchilla-optimal)
  • Training: 13,535 iterations, batch size 64 with 8 gradient-accumulation steps, 2-GPU DDP (see the budget check below)
  • Hardware: H200 GPU with Triton support
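A quick arithmetic check ties these numbers together, assuming a 1024-token context and that BSZ × GA gives the global batch in sequences (split across the two DDP ranks); these assumptions are not stated in the source.

seqs_per_iter = 64 * 8                    # BSZ * GA = 512 sequences
tokens_per_iter = seqs_per_iter * 1024    # 524,288 tokens per iteration
total_tokens = tokens_per_iter * 13535    # ~7.10e9 tokens trained on
chinchilla_optimal = 20 * 355e6           # D = 20N ~= 7.1e9, matching above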

Code

custom_pretrain.py
1"""Custom GPT-2 Pretraining Script
2Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3"""
4
5import math
6import inspect
7import os
8import time
9from contextlib import nullcontext
10from dataclasses import dataclass
11
12import numpy as np
13import torch
14import torch.nn as nn
15from torch.nn import functional as F

Additional context files (read-only):

  • nanoGPT/model.py

Results

All metrics are from the GPT-2 Medium (345M) evaluation runs; ARC-Easy and HellaSwag are lm-eval accuracies.

| Model | Type | Val loss | WikiText-2 ppl | LAMBADA ppl | Elapsed (s) | ARC-Easy | HellaSwag |
| --- | --- | --- | --- | --- | --- | --- | --- |
| relu_sq_torch | baseline | 2.274 | 42.810 | 67.050 | 21776.000 | 55.430 | 33.840 |
| triton_gelu | baseline | 2.287 | 43.940 | 68.200 | 20035.000 | 54.670 | 33.610 |
| triton_relu_sq_fused | baseline | 2.275 | 43.520 | 66.880 | 30344.000 | 54.800 | 33.770 |
| claude-opus-4.6 | vanilla | 2.301 | 44.680 | 68.670 | 20455.000 | 54.840 | 32.950 |
| deepseek-reasoner | vanilla | 2.357 | 51.840 | 78.990 | 21437.000 | 53.410 | 32.010 |
| gemini-3.1-pro-preview | vanilla | 2.300 | 44.820 | 68.470 | 19910.000 | 54.290 | 32.830 |
| gpt-5.4 | vanilla | 2.325 | 45.750 | 72.280 | 24492.000 | 55.300 | 32.830 |
| qwen3.6-plus | vanilla | 2.322 | 45.540 | 70.170 | 17943.000 | 54.340 | 32.640 |
| claude-opus-4.6 | agent | 2.301 | 44.680 | 68.670 | 20455.000 | 54.840 | 32.950 |
| deepseek-reasoner | agent | 2.357 | 51.840 | 78.990 | 21437.000 | 53.410 | 32.010 |
| gemini-3.1-pro-preview | agent | 2.300 | 44.820 | 68.470 | 19910.000 | 54.290 | 32.830 |
| gpt-5.4 | agent | 2.311 | 44.220 | 67.930 | 38487.000 | 56.190 | 33.150 |
| qwen3.6-plus | agent | 2.322 | 45.540 | 70.170 | 17943.000 | 54.340 | 32.640 |

Agent Conversations