# llm-pretrain-kernel
Tags: Language Models, lm-evaluation-harness, nanoGPT, rigorous codebase
## Description
LLM Pretraining: Custom GPU Kernel Optimization
## Research Question
Write a custom GPU kernel (Triton or CUDA via PyTorch) to implement a fused MLP operation for GPT-2 pretraining. Your kernel should fuse multiple operations to reduce memory bandwidth and improve throughput while maintaining or improving model quality.
## What You Can Modify
The `fused_mlp_forward` function (lines 34-48) in `custom_pretrain.py`:
- The MLP activation function (default: GELU via separate PyTorch ops)
- Kernel fusion strategy (fuse linear + activation, save intermediate values)
- Memory optimization (avoid materializing intermediate tensors)
- Custom autograd Functions for efficient backward pass
Note: The function signature `fused_mlp_forward(x, w_fc, w_proj)` must be preserved.
- `x`: input tensor `(B*T, n_embd)`
- `w_fc`: first linear weight `(4*n_embd, n_embd)`
- `w_proj`: second linear weight `(n_embd, 4*n_embd)`
- Returns: output tensor `(B*T, n_embd)`
The MLP class calls this function and handles dropout separately.
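As a point of reference, an unfused PyTorch version matching the required signature might look like the sketch below. It assumes GPT-2's tanh-approximate GELU and the `nn.Linear` weight convention `(out_features, in_features)` implied by the shapes above:

```python
import torch
import torch.nn.functional as F

def fused_mlp_forward(x: torch.Tensor, w_fc: torch.Tensor,
                      w_proj: torch.Tensor) -> torch.Tensor:
    """Reference MLP: (B*T, n_embd) -> (B*T, 4*n_embd) -> (B*T, n_embd).

    Weights follow the nn.Linear layout (out_features, in_features),
    so each projection multiplies by the transpose.
    """
    h = F.gelu(x @ w_fc.t(), approximate="tanh")  # up-projection + activation
    return h @ w_proj.t()                         # down-projection
```

A fused kernel replaces the two separate ops in the middle line (matmul output written to memory, then re-read by the activation) with a single pass.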
## Evaluation
- Metric: Validation loss (cross-entropy, lower is better) and training throughput (elapsed wall-clock time, lower is better); kernel optimizations that also change the activation function may improve loss
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B training tokens (Chinchilla-optimal D = 20N: 20 × 355M ≈ 7.1B)
- Training: 13,535 iterations, batch size 64, gradient accumulation 8, 2-GPU DDP
- Hardware: H200 GPU with Triton support
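The memory-optimization and custom-autograd strategies listed under "What You Can Modify" can be sketched even without Triton: a `torch.autograd.Function` can avoid saving the large `(B*T, 4*n_embd)` intermediate and recompute it in the backward pass. The sketch below is illustrative only, using squared ReLU as in the `relu_sq_torch` baseline, and is not the task's reference kernel:

```python
import torch

class ReluSqMLP(torch.autograd.Function):
    """MLP with squared-ReLU activation that recomputes the hidden
    activation in backward instead of saving it (trades one extra
    matmul for ~B*T*4*n_embd floats of activation memory)."""

    @staticmethod
    def forward(ctx, x, w_fc, w_proj):
        h = torch.relu(x @ w_fc.t())            # (B*T, 4*n_embd)
        ctx.save_for_backward(x, w_fc, w_proj)  # note: h is NOT saved
        return (h * h) @ w_proj.t()             # squared ReLU + down-projection

    @staticmethod
    def backward(ctx, grad_out):
        x, w_fc, w_proj = ctx.saved_tensors
        h = torch.relu(x @ w_fc.t())      # recompute hidden activation
        grad_act = grad_out @ w_proj      # gradient w.r.t. h*h
        grad_w_proj = grad_out.t() @ (h * h)
        grad_pre = grad_act * (2.0 * h)   # d(relu(z)^2)/dz = 2*relu(z)
        return grad_pre @ w_fc, grad_pre.t() @ x, grad_w_proj
```

`ReluSqMLP.apply(x, w_fc, w_proj)` drops into the same call site as `fused_mlp_forward`, and `torch.autograd.gradcheck` in double precision can verify the hand-written backward against numerical gradients.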
## Code
`custom_pretrain.py` (excerpt):
```python
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
```
Additional context files (read-only):
nanoGPT/model.py
## Results
| Model | Type | val loss gpt-345m ↓ | wikitext2 ppl gpt-345m ↓ | lambada ppl gpt-345m ↓ | elapsed gpt-345m ↓ | arc easy lm-eval-345m ↑ | hellaswag lm-eval-345m ↑ |
|---|---|---|---|---|---|---|---|
| relu_sq_torch | baseline | 2.274 | 42.810 | 67.050 | 21776.000 | 55.430 | 33.840 |
| triton_gelu | baseline | 2.287 | 43.940 | 68.200 | 20035.000 | 54.670 | 33.610 |
| triton_relu_sq_fused | baseline | 2.275 | 43.520 | 66.880 | 30344.000 | 54.800 | 33.770 |
| claude-opus-4.6 | vanilla | 2.301 | 44.680 | 68.670 | 20455.000 | 54.840 | 32.950 |
| deepseek-reasoner | vanilla | 2.357 | 51.840 | 78.990 | 21437.000 | 53.410 | 32.010 |
| gemini-3.1-pro-preview | vanilla | 2.300 | 44.820 | 68.470 | 19910.000 | 54.290 | 32.830 |
| gpt-5.4 | vanilla | 2.325 | 45.750 | 72.280 | 24492.000 | 55.300 | 32.830 |
| qwen3.6-plus | vanilla | 2.322 | 45.540 | 70.170 | 17943.000 | 54.340 | 32.640 |
| claude-opus-4.6 | agent | 2.301 | 44.680 | 68.670 | 20455.000 | 54.840 | 32.950 |
| deepseek-reasoner | agent | 2.357 | 51.840 | 78.990 | 21437.000 | 53.410 | 32.010 |
| gemini-3.1-pro-preview | agent | 2.300 | 44.820 | 68.470 | 19910.000 | 54.290 | 32.830 |
| gpt-5.4 | agent | 2.311 | 44.220 | 67.930 | 38487.000 | 56.190 | 33.150 |
| qwen3.6-plus | agent | 2.322 | 45.540 | 70.170 | 17943.000 | 54.340 | 32.640 |