llm-pretrain-embedding

Tags: Language Models, lm-evaluation-harness, nanoGPT, rigorous codebase

Description

LLM Pretraining: Embedding Strategy Optimization

Research Question

Design an improved embedding strategy for GPT-2 language model pretraining. Your modifications should reduce validation loss compared to standard token + position embeddings with weight tying.

What You Can Modify

The TokenEmbedding class (lines 116-140) in custom_pretrain.py:

  • Token embedding representation (default: learned token + position embeddings)
  • Weight tying strategy (default: input embedding = output lm_head weight)
  • Additional embedding sources (e.g., n-gram embeddings, hash-based embeddings)
  • Value embeddings that inject into transformer layers via get_value_embed(layer_idx)
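To illustrate the "additional embedding sources" bullet, here is a hypothetical hashed-bigram embedding module. It is a sketch of the general idea only: the hash multiplier, bucket count, and class name are illustrative choices, not the benchmark's `bigram_hash` baseline implementation.

```python
import torch
import torch.nn as nn

class BigramHashEmbedding(nn.Module):
    """Illustrative sketch: an extra embedding source keyed on hashed bigrams.

    The hash constants (multiplier, bucket count) are arbitrary choices for
    demonstration, not taken from custom_pretrain.py.
    """

    def __init__(self, n_embd: int, n_buckets: int = 1 << 16):
        super().__init__()
        self.n_buckets = n_buckets
        self.bigram = nn.Embedding(n_buckets, n_embd)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # Pair each token with its predecessor (the first position pairs
        # with itself), hash the pair into a fixed number of buckets,
        # and look up a learned embedding for that bucket.
        prev = torch.roll(idx, shifts=1, dims=1)
        prev[:, 0] = idx[:, 0]
        h = (prev * 1000003 + idx) % self.n_buckets
        return self.bigram(h)  # (B, T, n_embd)
```

Such a module would typically be added to (not replace) the standard token embedding inside `TokenEmbedding.forward`.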

Interface: Your TokenEmbedding class must implement:

  • forward(idx) -> x: Takes token indices (B, T), returns embeddings (B, T, n_embd)
  • get_lm_head_weight() -> weight: Returns the weight tensor for the output projection
  • get_num_pos_params() -> int: Returns count of position parameters (excluded from reported param count)
  • get_value_embed(layer_idx) -> Optional[Tensor]: (Optional) Returns per-layer value embedding residual (B, T, n_embd) or None
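The interface above can be sketched as the default baseline: learned token + position embeddings with weight tying. This is a reconstruction from the description in this page, assuming nanoGPT-style constructor arguments, not a copy of the actual `TokenEmbedding` in custom_pretrain.py (lines 116-140).

```python
import torch
import torch.nn as nn
from typing import Optional

class TokenEmbedding(nn.Module):
    """Sketch of the required interface with the default baseline behavior:
    learned token + position embeddings, weight-tied output projection."""

    def __init__(self, vocab_size: int, block_size: int, n_embd: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)   # token embeddings
        self.wpe = nn.Embedding(block_size, n_embd)   # position embeddings

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) token indices -> (B, T, n_embd) embeddings
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        return self.wte(idx) + self.wpe(pos)

    def get_lm_head_weight(self) -> torch.Tensor:
        # Weight tying: the output projection reuses the input embedding matrix.
        return self.wte.weight

    def get_num_pos_params(self) -> int:
        # Position parameters are excluded from the reported parameter count.
        return self.wpe.weight.numel()

    def get_value_embed(self, layer_idx: int) -> Optional[torch.Tensor]:
        # The baseline injects no per-layer value embeddings.
        return None
```

An improved strategy would keep these four methods but change what happens inside them, e.g. mixing in extra embedding sources in `forward` or returning learned residuals from `get_value_embed`.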

Evaluation

  • Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
  • Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
  • Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
  • Training: 12,030 iterations, batch size 96, gradient accumulation 6, 2-GPU DDP
  • Hardware: NVIDIA H200 GPUs
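As a note on the metric, the validation loss is a mean cross-entropy in nats per token, so it converts to a validation-set perplexity by exponentiation (the WikiText-2 and LAMBADA perplexities in the results are computed separately by lm-evaluation-harness):

```python
import math

def perplexity(mean_cross_entropy_nats: float) -> float:
    # Perplexity is the exponential of the mean per-token cross-entropy (nats).
    return math.exp(mean_cross_entropy_nats)

# e.g. the bigram_hash baseline's val loss of 2.288 corresponds to a
# FineWeb validation perplexity of roughly 9.9
```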

Code

custom_pretrain.py
"""Custom GPT-2 Pretraining Script
Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
"""

import math
import inspect
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

Additional context files (read-only):

  • nanoGPT/model.py

Results

| Model | Type | Val loss (gpt-345m) | WikiText-2 ppl (gpt-345m) | LAMBADA ppl (gpt-345m) | ARC-Easy % (lm-eval-345m) | HellaSwag % (lm-eval-345m) |
|---|---|---|---|---|---|---|
| bigram_hash | baseline | 2.288 | 44.970 | 71.430 | 56.400 | 33.710 |
| untied | baseline | 2.306 | 45.700 | 71.110 | 54.800 | 33.050 |
| value_embed | baseline | 2.277 | 43.870 | 68.330 | 54.880 | 33.670 |
| claude-opus-4.6 | vanilla | 2.279 | 43.890 | 68.900 | 56.020 | 33.460 |
| deepseek-reasoner | vanilla | 2.308 | 45.250 | 70.290 | 55.350 | 33.350 |
| gemini-3.1-pro-preview | vanilla | 2.256 | 43.160 | 67.500 | 56.360 | 34.340 |
| gpt-5.4 | vanilla | 2.299 | 45.310 | 69.340 | 55.090 | 33.300 |
| qwen3.6-plus | vanilla | 2.300 | 44.710 | 69.840 | 55.130 | 33.190 |
| claude-opus-4.6 | agent | 2.276 | 44.110 | 70.360 | 56.270 | 33.680 |
| deepseek-reasoner | agent | 2.308 | 45.250 | 70.290 | 55.350 | 33.350 |
| gemini-3.1-pro-preview | agent | 2.255 | 43.130 | 66.020 | 56.690 | 34.070 |
| gpt-5.4 | agent | 2.299 | 45.310 | 69.340 | 55.090 | 33.300 |
| qwen3.6-plus | agent | 2.286 | 45.050 | 70.260 | 56.270 | 33.560 |
