llm-pretrain-embedding
Tags: Language Models · lm-evaluation-harness · nanoGPT · rigorous codebase
Description
LLM Pretraining: Embedding Strategy Optimization
Research Question
Design an improved embedding strategy for GPT-2 language model pretraining. Your modifications should reduce validation loss compared to standard token + position embeddings with weight tying.
What You Can Modify
The TokenEmbedding class (lines 116-140) in custom_pretrain.py:
- Token embedding representation (default: learned token + position embeddings)
- Weight tying strategy (default: input embedding = output lm_head weight)
- Additional embedding sources (e.g., n-gram embeddings, hash-based embeddings)
- Value embeddings injected into transformer layers via `get_value_embed(layer_idx)`
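As one illustration of an "additional embedding source", a hashed-bigram table maps each (previous token, current token) pair into a fixed number of buckets and returns a learned residual that can be added to the standard token embedding. This is a minimal sketch, not the benchmark's `bigram_hash` baseline; the bucket count, hash multiplier, and zero initialization are all assumptions.

```python
import torch
import torch.nn as nn


class HashedBigramEmbedding(nn.Module):
    """Sketch: hashed (prev_token, token) bigram embedding, used as a residual
    on top of the usual token embedding. Bucket count and hash are assumptions."""

    def __init__(self, vocab_size: int = 50304, n_embd: int = 1024, n_buckets: int = 2 ** 18):
        super().__init__()
        self.n_buckets = n_buckets
        self.bigram = nn.Embedding(n_buckets, n_embd)
        nn.init.zeros_(self.bigram.weight)  # zero-init so the residual starts as a no-op

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) token indices
        prev = torch.roll(idx, shifts=1, dims=1)
        prev[:, 0] = 0  # first position has no predecessor; pair it with token 0
        # cheap multiplicative hash of the (prev, current) pair into a bucket id
        h = (prev * 1000003 + idx) % self.n_buckets
        return self.bigram(h)  # (B, T, n_embd) residual, added upstream
```

Hash collisions are tolerated by design: with ~2.5B possible GPT-2 bigrams and 2^18 buckets, many pairs share an embedding, trading exactness for a small parameter budget.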
Interface: Your TokenEmbedding class must implement:
- `forward(idx) -> x`: takes token indices `(B, T)`, returns embeddings `(B, T, n_embd)`
- `get_lm_head_weight() -> weight`: returns the weight tensor for the output projection
- `get_num_pos_params() -> int`: returns the count of position parameters (excluded from the reported parameter count)
- `get_value_embed(layer_idx) -> Optional[Tensor]`: (optional) returns a per-layer value-embedding residual `(B, T, n_embd)`, or `None`
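The interface above can be sketched as a minimal baseline class: learned token + position embeddings with the lm_head weight tied to the input embedding, and no per-layer value embeddings. The actual `TokenEmbedding` in `custom_pretrain.py` is not reproduced here; `vocab_size=50304` and `block_size=1024` are assumptions.

```python
from typing import Optional

import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    """Sketch of the required interface: learned token + position embeddings
    with weight tying (the stated baseline). Defaults are assumptions."""

    def __init__(self, vocab_size: int = 50304, block_size: int = 1024, n_embd: int = 1024):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)   # token embeddings
        self.wpe = nn.Embedding(block_size, n_embd)   # position embeddings

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) -> (B, T, n_embd)
        pos = torch.arange(idx.size(1), device=idx.device)
        return self.wte(idx) + self.wpe(pos)

    def get_lm_head_weight(self) -> torch.Tensor:
        # Weight tying: reuse the input embedding matrix as the output projection.
        return self.wte.weight

    def get_num_pos_params(self) -> int:
        # Position parameters are excluded from the reported parameter count.
        return self.wpe.weight.numel()

    def get_value_embed(self, layer_idx: int) -> Optional[torch.Tensor]:
        # The baseline injects nothing into the transformer layers.
        return None
```

An untied variant would keep a separate `nn.Parameter` for the output projection and return it from `get_lm_head_weight()` instead of `self.wte.weight`.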
Evaluation
- Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
- Training: 12030 iterations, batch size 96, gradient accumulation 6, 2-GPU DDP
- Hardware: H200 GPU
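As a sanity check on the training budget, the per-iteration token count (batch size × gradient-accumulation steps × sequence length, assuming a global batch of 96 and T=1024) times the iteration count recovers the stated ~7.1B tokens, which is ~20× the ~355M parameters (the D=20N Chinchilla-optimal ratio):

```python
# Token-budget arithmetic for the setup above (global BSZ=96 and T=1024 are assumptions).
bsz, grad_accum, seq_len, iters = 96, 6, 1024, 12030
tokens_per_iter = bsz * grad_accum * seq_len      # 589,824 tokens per iteration
total_tokens = tokens_per_iter * iters            # ~7.1B tokens
params = 355e6
print(tokens_per_iter, total_tokens, total_tokens / params)  # ratio ~ 20, i.e. D = 20N
```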
Code
custom_pretrain.py
```python
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
```
Additional context files (read-only):
nanoGPT/model.py
Results
| Model | Type | val loss ↓ | WikiText-2 ppl ↓ | LAMBADA ppl ↓ | ARC-Easy acc ↑ | HellaSwag acc ↑ |
|---|---|---|---|---|---|---|
| bigram_hash | baseline | 2.288 | 44.970 | 71.430 | 56.400 | 33.710 |
| untied | baseline | 2.306 | 45.700 | 71.110 | 54.800 | 33.050 |
| value_embed | baseline | 2.277 | 43.870 | 68.330 | 54.880 | 33.670 |
| claude-opus-4.6 | vanilla | 2.279 | 43.890 | 68.900 | 56.020 | 33.460 |
| deepseek-reasoner | vanilla | 2.308 | 45.250 | 70.290 | 55.350 | 33.350 |
| gemini-3.1-pro-preview | vanilla | 2.256 | 43.160 | 67.500 | 56.360 | 34.340 |
| gpt-5.4 | vanilla | 2.299 | 45.310 | 69.340 | 55.090 | 33.300 |
| qwen3.6-plus | vanilla | 2.300 | 44.710 | 69.840 | 55.130 | 33.190 |
| claude-opus-4.6 | agent | 2.276 | 44.110 | 70.360 | 56.270 | 33.680 |
| deepseek-reasoner | agent | 2.308 | 45.250 | 70.290 | 55.350 | 33.350 |
| gemini-3.1-pro-preview | agent | 2.255 | 43.130 | 66.020 | 56.690 | 34.070 |
| gpt-5.4 | agent | 2.299 | 45.310 | 69.340 | 55.090 | 33.300 |
| qwen3.6-plus | agent | 2.286 | 45.050 | 70.260 | 56.270 | 33.560 |