llm-pretrain-loss
Tags: Language Models · lm-evaluation-harness · nanoGPT · rigorous codebase
Description
LLM Pretraining: Loss Function Optimization
Research Question
Design an improved loss function for GPT-2 language model pretraining. Your modifications should reduce validation loss compared to standard cross-entropy.
What You Can Modify
The compute_loss function (lines 189-191) in custom_pretrain.py:
- Loss function formulation (default: standard cross-entropy)
- Logit processing (e.g., softcapping, temperature scaling)
- Regularization terms (e.g., z-loss, entropy penalties)
- Label distribution modifications (e.g., label smoothing)
Note: The function signature compute_loss(logits, targets) must be preserved. logits has shape (B, T, V) and targets has shape (B, T). The function is called inside the model's forward pass during training.
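As a starting point, here is a hedged sketch of what a modified `compute_loss` could look like, combining two of the techniques listed above (logit softcapping and a z-loss regularizer). The `softcap` and `z_weight` values are illustrative assumptions, not tuned settings from this benchmark, and the `ignore_index=-1` convention follows nanoGPT's default.

```python
import torch
import torch.nn.functional as F

def compute_loss(logits, targets, softcap=30.0, z_weight=1e-4):
    """Cross-entropy with logit softcapping and a z-loss term (illustrative sketch).

    logits: (B, T, V) raw model outputs; targets: (B, T) token ids.
    """
    # Softcapping: smoothly bound logits to (-softcap, softcap) via tanh.
    logits = softcap * torch.tanh(logits / softcap)

    logits_flat = logits.view(-1, logits.size(-1))   # (B*T, V)
    targets_flat = targets.view(-1)                  # (B*T,)

    # Standard token-level cross-entropy (nanoGPT masks padding with -1).
    ce = F.cross_entropy(logits_flat, targets_flat, ignore_index=-1)

    # z-loss: penalize the squared log-partition log Z = logsumexp(logits),
    # which discourages logit drift without changing the softmax distribution much.
    log_z = torch.logsumexp(logits_flat, dim=-1)
    return ce + z_weight * (log_z ** 2).mean()
```

Because the signature `compute_loss(logits, targets)` is fixed, any extra hyperparameters must be baked in as defaults like this rather than passed from the training loop.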
Evaluation
- Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
- Training: 13,535 iterations, batch size 64, gradient accumulation 8, 2-GPU DDP
- Hardware: H200 GPU
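The training figures above are self-consistent, which can be checked with a little arithmetic. A context length of 1024 is an assumption here (GPT-2's standard block size; it is not stated explicitly in the spec):

```python
# Sanity-check the token budget: iterations x batch size x grad accumulation x context length.
iters, bsz, grad_accum, block_size = 13535, 64, 8, 1024  # block_size is assumed
tokens = iters * bsz * grad_accum * block_size
print(f"{tokens / 1e9:.2f}B tokens")            # ~7.10B, matching the stated ~7.1B

n_params = 355e6                                 # GPT-2 Medium, ~355M parameters
print(f"D/N ratio: {tokens / n_params:.1f}")     # ~20.0, i.e. D = 20N (Chinchilla-optimal)
```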
Code
custom_pretrain.py
```python
"""Custom GPT-2 Pretraining Script

Based on Andrej Karpathy's nanoGPT, evaluated on the FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
```
Additional context files (read-only):
nanoGPT/model.py
Results
| Model | Type | Val loss ↓ | WikiText-2 PPL ↓ | LAMBADA PPL ↓ | ARC-Easy acc ↑ | HellaSwag acc ↑ |
|---|---|---|---|---|---|---|
| label_smoothing | baseline | 2.338 | 47.130 | 71.800 | 54.040 | 33.630 |
| softcap_ce | baseline | 2.270 | 43.460 | 67.100 | 56.480 | 31.820 |
| z_loss | baseline | 2.293 | 44.090 | 68.250 | 54.550 | 33.850 |
| claude-opus-4.6 | vanilla | 2.338 | 47.910 | 72.300 | 53.200 | 33.200 |
| deepseek-reasoner | vanilla | 2.341 | 46.950 | 71.320 | 55.180 | 33.350 |
| gemini-3.1-pro-preview | vanilla | 2.293 | 45.130 | 69.780 | 56.140 | 33.480 |
| gpt-5.4 | vanilla | 2.338 | 47.200 | 71.060 | 55.180 | 33.200 |
| qwen3.6-plus | vanilla | 3.491 | 125.290 | 178.800 | 52.570 | 33.250 |
| claude-opus-4.6 | agent | 2.338 | 47.910 | 72.300 | 53.200 | 33.200 |
| deepseek-reasoner | agent | 2.340 | 47.530 | 72.240 | 54.340 | 33.300 |
| gemini-3.1-pro-preview | agent | 2.293 | 45.130 | 69.780 | 56.140 | 33.480 |
| gpt-5.4 | agent | 2.288 | 44.840 | 69.420 | 54.670 | 33.530 |
| qwen3.6-plus | agent | 2.301 | 45.420 | 69.020 | 53.750 | 33.100 |