Autoregressive Embedding Strategy

Studies how token embeddings, position embeddings, and weight tying affect autoregressive language-model pretraining loss.

Language Models · lm-evaluation-harness · nanoGPT
llm-pretrain-embedding

Description

LLM Pretraining: Embedding Strategy Optimization

Research Question

Design an improved embedding strategy for GPT-style language model pretraining. The change should reduce validation loss compared to standard learned token + position embeddings with weight tying, while remaining a modular embedding-level intervention.

Background

The default scheme uses:

  • Learned token embedding (wte) of shape (vocab_size, n_embd).
  • Learned absolute position embedding (wpe) of shape (block_size, n_embd).
  • Tied weights between the input token embedding and the output lm_head projection (Press & Wolf, "Using the Output Embedding to Improve Language Models", arXiv:1608.05859, 2016; EACL 2017).
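Weight tying means the input table and the output projection are one and the same parameter. The following is a minimal PyTorch sketch, not the benchmark's code. The sizes are small hypothetical values; GPT-2 Medium would use n_embd=1024.

```python
import torch
import torch.nn as nn

# Small illustrative sizes (hypothetical).
vocab_size, n_embd = 1000, 64

wte = nn.Embedding(vocab_size, n_embd)               # input token embedding, weight (vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # output projection, weight (vocab_size, n_embd)

# Tie: both modules now hold the same Parameter object.
lm_head.weight = wte.weight

idx = torch.randint(0, vocab_size, (2, 8))  # (B, T) token indices
x = wte(idx)                                # (B, T, n_embd)
logits = lm_head(x)                         # (B, T, vocab_size), equal to x @ wte.weight.T

# Count unique parameters: tying removes one (vocab_size x n_embd) matrix.
unique = {id(p): p for p in list(wte.parameters()) + list(lm_head.parameters())}
n_params = sum(p.numel() for p in unique.values())
```

Because the lm_head reuses the embedding matrix, gradients from the output softmax also update the input embeddings. This is the coupling that the untied baseline removes.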

Common alternatives studied at this layer:

  • Untied input/output embeddings.
  • Hash-based / bigram / n-gram embeddings that inject local multi-token co-occurrence statistics.
  • Value embeddings (popularized in the modded-nanogpt speedrun, originally inspired by Zhou et al., 2024): a separate embedding table whose output is added to the value projections inside attention layers — typically gated and inserted at a few specific layers.
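The value-embedding idea can be sketched as a separate table whose lookup is mixed into an attention layer's value projection. This is a minimal sketch under stated assumptions: the scalar sigmoid gate, the single-head causal attention, and the names ValueEmbedAttention and ve_gate are all hypothetical. It is not modded-nanogpt's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ValueEmbedAttention(nn.Module):
    """Single-head causal self-attention whose values receive an extra embedding term."""

    def __init__(self, vocab_size: int, n_embd: int):
        super().__init__()
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.value_embed = nn.Embedding(vocab_size, n_embd)  # separate table, not wte
        nn.init.zeros_(self.value_embed.weight)              # starts as a no-op
        self.ve_gate = nn.Parameter(torch.zeros(1))          # gate logit; sigmoid(0) = 0.5

    def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)               # each (B, T, C)
        # Inject the token-indexed value embedding into the values only.
        v = v + torch.sigmoid(self.ve_gate) * self.value_embed(idx)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y)
```

Zero-initializing the table makes the layer start out identical to plain attention. The gate then controls how much token-identity information reaches the values as training proceeds.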

What you can modify

The TokenEmbedding class in nanoGPT/custom_pretrain.py:

  • Token embedding representation (default: learned token + position embeddings).
  • Weight-tying strategy (default: input embedding shares weights with output lm_head).
  • Additional embedding sources (e.g., n-gram, hash-based).
  • Per-layer value embeddings injected via get_value_embed(layer_idx).

Interface

Your TokenEmbedding class must implement:

  • forward(idx) -> x — takes token indices (B, T), returns embeddings (B, T, n_embd).
  • get_lm_head_weight() -> Tensor — returns the weight tensor used for the output projection.
  • get_num_pos_params() -> int — returns the count of position parameters (excluded from the reported parameter count).
  • get_value_embed(layer_idx) -> Optional[Tensor] — optional per-layer value-embedding residual (B, T, n_embd) or None.
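A class that satisfies this interface with the default scheme (learned token and position embeddings, tied lm_head) might look like the following. This is a minimal sketch: the EmbedConfig dataclass is a hypothetical stand-in for whatever config object the benchmark passes in, 50304 is assumed to be nanoGPT's padded GPT-2 vocabulary, and dropout is omitted.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class EmbedConfig:  # hypothetical stand-in for the benchmark's config object
    vocab_size: int = 50304
    block_size: int = 1024
    n_embd: int = 1024


class TokenEmbedding(nn.Module):
    """Default scheme: learned token + absolute position embeddings, tied lm_head."""

    def __init__(self, config: EmbedConfig):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        _, T = idx.shape
        pos = torch.arange(T, device=idx.device)  # (T,)
        return self.wte(idx) + self.wpe(pos)      # (B, T, n_embd) via broadcasting

    def get_lm_head_weight(self) -> torch.Tensor:
        return self.wte.weight                    # tied: reuse the input table

    def get_num_pos_params(self) -> int:
        return self.wpe.weight.numel()            # excluded from the reported count

    def get_value_embed(self, layer_idx: int) -> Optional[torch.Tensor]:
        return None                               # no value embeddings in the default scheme
```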

Reference baselines

  • untied — break weight tying between input embedding and lm_head.
  • bigram_hash — hash-based bigram embeddings additive to the token embedding.
  • value_embed — value-style per-layer embedding injection.
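The additive bigram-hash idea can be sketched as follows. Each position looks up a row indexed by a hash of (previous token, current token), and that row is summed with the token embedding. This is a minimal sketch only. The hash constants, table size, zero initialization, and the handling of position 0 are assumptions, not the benchmark's bigram_hash baseline.

```python
import torch
import torch.nn as nn


class BigramHashEmbedding(nn.Module):
    """Hashed (previous token, current token) embedding, meant to be added to wte(idx)."""

    def __init__(self, table_size: int, n_embd: int):
        super().__init__()
        self.table_size = table_size
        self.table = nn.Embedding(table_size, n_embd)
        nn.init.zeros_(self.table.weight)  # zero init: a no-op on top of wte at the start

    def bucket(self, idx: torch.Tensor) -> torch.Tensor:
        # Previous token for each position; position 0 has no predecessor, so use 0 (assumed).
        prev = torch.zeros_like(idx)
        prev[:, 1:] = idx[:, :-1]
        # Hypothetical multiplicative hash; the constants are illustrative only.
        return (idx * 1_000_003 + prev * 999_983) % self.table_size

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.table(self.bucket(idx))  # (B, T, n_embd)
```

Hashing keeps the table size fixed instead of growing as vocab_size². The cost is that unrelated bigrams can collide in the same bucket.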

Fixed Pipeline

  • Model: GPT-2 Medium (24 layers, 16 heads, d=1024, ~355M params).
  • Dataset: FineWeb 10B (Hugging Face dataset HuggingFaceFW/fineweb, sample-10BT subset), GPT-2 tokenizer, ~7.1B training tokens.
  • Training: 12,030 iterations, micro-batch 96, gradient accumulation 6, 2-GPU DDP.
  • The corpus, tokenizer, training loop, optimizer, and unrelated transformer blocks are fixed.
  • The benchmark's parameter accounting subtracts get_num_pos_params() from the reported count. Capacity added through positional parameters would therefore not show up in that count, and scaling capacity this way is not a valid escape from the parameter budget.
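The ~7.1B token figure is consistent with the training settings under two conditions. First, the block size is 1024 (GPT-2's context length, assumed here because it is not stated). Second, the gradient-accumulation count of 6 is a global total split across the two DDP ranks, as in nanoGPT's train.py. The arithmetic below checks this.

```python
# Token-budget check; block_size=1024 is an assumption.
iters = 12_030
micro_batch = 96   # sequences per forward pass on one GPU
grad_accum = 6     # global micro-steps per iteration (nanoGPT divides this across ranks)
block_size = 1024  # tokens per sequence (assumed)

tokens_per_iter = grad_accum * micro_batch * block_size
total_tokens = iters * tokens_per_iter

print(f"{tokens_per_iter:,} tokens/iter, {total_tokens / 1e9:.2f}B tokens total")
# prints: 589,824 tokens/iter, 7.10B tokens total
```

If the 6 accumulation steps were instead applied per rank, the total would double to about 14.2B tokens.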

Evaluation

  • Validation loss — cross-entropy on FineWeb (lower is better, primary).
  • Perplexity — WikiText-2, LAMBADA (lower is better).
  • Downstream accuracy — ARC-Easy, HellaSwag, PIQA, WinoGrande (higher is better).
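Validation loss and perplexity are linked: token-level perplexity is the exponential of the mean per-token cross-entropy in nats. The sketch below shows this relationship. Note that lm-evaluation-harness may report WikiText-2 perplexity normalized per word or per byte rather than per token, so those numbers are not directly exp(validation loss).

```python
import math


def perplexity_from_loss(mean_ce_nats: float) -> float:
    """Token-level perplexity from mean cross-entropy measured in nats."""
    return math.exp(mean_ce_nats)


def loss_from_perplexity(ppl: float) -> float:
    """Inverse mapping: mean cross-entropy (nats) from token-level perplexity."""
    return math.log(ppl)


# A 0.01-nat loss reduction shrinks token-level perplexity by about 1% (exp(-0.01) ≈ 0.990).
print(perplexity_from_loss(3.00), perplexity_from_loss(2.99))
```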

Code

custom_pretrain.py
1"""Custom GPT-2 Pretraining Script
2Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3"""
4
5import math
6import inspect
7import os
8import time
9from contextlib import nullcontext
10from dataclasses import dataclass
11
12import numpy as np
13import torch
14import torch.nn as nn
15from torch.nn import functional as F
model.py
1"""
2Full definition of a GPT Language Model, all of it in this single file.
3References:
41) the official GPT-2 TensorFlow implementation released by OpenAI:
5https://github.com/openai/gpt-2/blob/master/src/model.py
62) huggingface/transformers PyTorch implementation:
7https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8"""
9
10import math
11import inspect
12from dataclasses import dataclass
13
14import torch
15import torch.nn as nn

Method Summary

Auto-summarized from each method's code by an LLM reviewer; this is not the model's original output. The Code section is independent of these summaries.
GPT-5.4 · Formula · high

Dual-lexicon + depth-decayed n-gram

Per-token convex mix of two lexical embeddings (alt_wte and wte) via a sigmoid gate, plus token-confidence-scaled bigram/trigram residuals injected at the input and at three layers with decreasing scale.

$$
\begin{aligned}
\mathrm{tok}(x) &= E_{\text{wte}}(x) + \sigma(g(x))\,\bigl(E_{\text{alt}}(x) - E_{\text{wte}}(x)\bigr) \\
\mathrm{ngram}(x) &= \sigma(c(x))\,\bigl[\sigma(m(x))\,T(h_3(x)) + \bigl(1-\sigma(m(x))\bigr)\,B(h_2(x))\bigr] \\
h^{(0)} &= \mathrm{tok}(x) + E_{\text{wpe}}(t) + s_0\,\mathrm{ngram}(x), \qquad \Delta^{(\ell_i)} = s_i\,\mathrm{ngram}(x)
\end{aligned}
$$
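The three equations can be transcribed into a small PyTorch module. This is a sketch of the formulas as summarized, not the agent's actual code. The summary does not say how g, c, and m are computed, so modeling each as a learned logit per token id is an assumption. The bucket indices h2 and h3 are also assumed to be precomputed by the caller. The class name DualLexiconNgram is hypothetical, and the module omits how the lm_head weight is formed, since the summary does not specify it.

```python
import torch
import torch.nn as nn


class DualLexiconNgram(nn.Module):
    """Transcription of tok(x), ngram(x), h^(0) and the residuals Δ^(ℓ_i)."""

    def __init__(self, vocab_size, block_size, n_embd, ngram_size,
                 s0=0.10, layer_scales=(0.050, 0.035, 0.020)):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)     # E_wte
        self.alt = nn.Embedding(vocab_size, n_embd)     # E_alt
        self.wpe = nn.Embedding(block_size, n_embd)     # E_wpe
        self.bigram = nn.Embedding(ngram_size, n_embd)  # B
        self.trigram = nn.Embedding(ngram_size, n_embd) # T
        nn.init.normal_(self.alt.weight, std=0.02)
        nn.init.zeros_(self.bigram.weight)
        nn.init.zeros_(self.trigram.weight)
        # Assumption: g, c, m are one learned logit per token id, zero-initialized.
        self.g = nn.Embedding(vocab_size, 1)
        self.c = nn.Embedding(vocab_size, 1)
        self.m = nn.Embedding(vocab_size, 1)
        for gate in (self.g, self.c, self.m):
            nn.init.zeros_(gate.weight)
        self.s0 = nn.Parameter(torch.tensor(s0))          # input n-gram scale s_0
        self.s = nn.Parameter(torch.tensor(layer_scales)) # per-layer scales s_i

    def tok(self, x):
        w = torch.sigmoid(self.g(x))  # (B, T, 1) mix weight σ(g(x))
        e = self.wte(x)
        return e + w * (self.alt(x) - e)

    def ngram(self, x, h2, h3):
        mix = torch.sigmoid(self.m(x))  # σ(m(x)): trigram vs. bigram
        inner = mix * self.trigram(h3) + (1 - mix) * self.bigram(h2)
        return torch.sigmoid(self.c(x)) * inner  # σ(c(x)): per-token confidence

    def forward(self, x, h2, h3):
        t = torch.arange(x.shape[1], device=x.device)
        ng = self.ngram(x, h2, h3)
        h0 = self.tok(x) + self.wpe(t) + self.s0 * ng
        deltas = [s_i * ng for s_i in self.s]  # Δ^(ℓ_i), one per injection layer
        return h0, deltas
```

With g zero-initialized, σ(g) = 0.5, so the two tables start equally weighted. tok(x) reduces to E_wte(x) only when σ(g(x)) = 0. With zero-initialized n-gram tables, every Δ^(ℓ_i) starts at exactly zero.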
Δ vs. baseline: combines and extends the untied, value_embed, and bigram_hash baselines. A token-wise convex mix of two embedding tables gives the lm_head a softly untied weight, and depth-decayed bigram and trigram residuals are gated by per-token confidence and source-mixing scalars.
  • ngram_vocab_mult = 2x
  • input_ngram_scale = 0.10 (learnable)
  • layer_scales = [0.050, 0.035, 0.020] (learnable)
  • inject_layers = {1, n_layer/2, n_layer-1}
  • alt_wte_init_std = 0.02 (learnable)
  • mix_gate_init = 0 (learnable)
  • bigram_hash = 65537·x + 8191·x_{t-1}
  • trigram_hash = 131071·x + 4099·x_{t-1} + 257·x_{t-2}

Recovers the tied baseline at mix_gate = 0 with zero-initialized n-gram tables.
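The two listed hash functions can be written out directly in plain Python. Several details are not stated in the summary and are assumptions here:

  • x is the current token x_t.
  • Hashes are reduced modulo a table of 2 × vocab_size rows (from ngram_vocab_mult).
  • vocab_size is 50304, nanoGPT's padded GPT-2 vocabulary.
  • Missing predecessors at the start of a sequence are padded with 0.

```python
VOCAB_SIZE = 50304           # assumed: nanoGPT's padded GPT-2 vocabulary
TABLE_SIZE = 2 * VOCAB_SIZE  # ngram_vocab_mult = 2x


def bigram_hash(tokens, t):
    """Bucket for (x_{t-1}, x_t); a missing predecessor is padded with 0 (assumed)."""
    prev1 = tokens[t - 1] if t >= 1 else 0
    return (65537 * tokens[t] + 8191 * prev1) % TABLE_SIZE


def trigram_hash(tokens, t):
    """Bucket for (x_{t-2}, x_{t-1}, x_t); missing predecessors padded with 0 (assumed)."""
    prev1 = tokens[t - 1] if t >= 1 else 0
    prev2 = tokens[t - 2] if t >= 2 else 0
    return (131071 * tokens[t] + 4099 * prev1 + 257 * prev2) % TABLE_SIZE


tokens = [464, 3290, 318]  # arbitrary example token ids
h2 = [bigram_hash(tokens, t) for t in range(len(tokens))]
h3 = [trigram_hash(tokens, t) for t in range(len(tokens))]
```

Using different multipliers for each context slot makes the hash order-sensitive, so (a, b) and (b, a) generally land in different buckets.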

Results