Agent Conversation: deepseek-reasoner
Task: llm-pretrain-linear-attention
LLM Pretraining: Linear/Subquadratic Attention Mechanism
Research Question
Design a novel linear or subquadratic attention mechanism for GPT-2 language model pretraining that achieves competitive validation loss while replacing standard softmax attention. The mechanism should scale better than O(n^2) in sequence length.
What You Can Modify
Two editable regions in custom_pretrain.py:
- CausalSelfAttention class (lines 33-70): the attention mechanism itself, including:
  - The attention computation (replace softmax attention with linear/subquadratic alternatives)
  - Feature maps, gating mechanisms, decay factors
  - Query/Key/Value projections and transformations
  - Internal state management (recurrent states, convolutions, etc.)
- Block class (lines 88-100): the transformer block structure, including:
  - How attention and MLP sublayers are composed
  - Normalization placement (pre-norm, post-norm)
  - Residual connection patterns
Note: The flash-linear-attention (FLA) library is pre-installed and provides 27+ optimized linear attention implementations with Triton kernels. You can import from fla.layers (e.g., GatedLinearAttention, DeltaNet, MultiScaleRetention, LinearAttention, HGRN2, Mamba2, etc.) or implement your own mechanism from scratch.
Note: If your attention mechanism does not use learned absolute position embeddings, set self.use_pos_emb = False in __init__ — the model will then skip adding position embeddings in the forward pass.
Note: torch.compile is disabled for this task since FLA's Triton kernels are not compatible with it.
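The core trick shared by the linear attention variants in fla.layers is replacing the softmax kernel with a positive feature map so the causal sum can be carried as a running state, giving O(T) rather than O(T^2) cost in sequence length. A minimal NumPy sketch of this idea (illustrative only — not the library's chunked Triton implementation; the feature map here is an arbitrary positive kernel chosen for the example):

```python
import numpy as np

def causal_linear_attention(q, k, v, feature_map=lambda x: np.maximum(x, 0) + 1.0):
    """Causal linear attention via a running k^T v state -- O(T) in sequence length.

    q, k: (T, d_k); v: (T, d_v). feature_map is a positive kernel standing in
    for exp() in softmax attention. Illustrative sketch, not FLA's implementation.
    """
    q, k = feature_map(q), feature_map(k)
    T, d_v = v.shape
    state = np.zeros((q.shape[1], d_v))   # running sum of outer(k_t, v_t)
    z = np.zeros(q.shape[1])              # running sum of k_t (normalizer)
    out = np.empty((T, d_v))
    for t in range(T):
        state += np.outer(k[t], v[t])
        z += k[t]
        out[t] = (q[t] @ state) / (q[t] @ z + 1e-6)
    return out
```

The recurrence produces exactly the same output as materializing the full causally-masked T-by-T kernel matrix, which is what makes the subquadratic formulation a drop-in replacement.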
Evaluation
- Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
- Training: 13535 iterations, BSZ=32, GA=16, 2-GPU DDP
- Hardware: H200 GPU
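The token budget implied by these settings can be sanity-checked against D=20N, assuming nanoGPT's default 1024-token block size (an assumption — the context length is not stated above):

```python
# Sanity check of the training token budget. block_size=1024 is an assumption
# (nanoGPT's GPT-2 default), not stated in the task description.
iters, bsz, grad_accum, block_size = 13535, 32, 16, 1024
tokens = iters * bsz * grad_accum * block_size
chinchilla = 20 * 355e6          # D = 20N for a ~355M-parameter model
print(f"trained tokens: {tokens/1e9:.2f}B")       # ~7.10B
print(f"Chinchilla-optimal: {chinchilla/1e9:.2f}B")
```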
nanoGPT/custom_pretrain.py [EDITABLE — lines 33–70, lines 88–100, lines 246–248 only]
1: """Custom GPT-2 Pretraining Script
2: Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3: # flash-linear-attention is available: from fla.layers import GatedLinearAttention, DeltaNet, MultiScaleRetention, etc.
4: """
5:
6: import math
7: import inspect
8: import os
9: import time
10: from contextlib import nullcontext
11: from dataclasses import dataclass
12:
13: import numpy as np
14: import torch
15: import torch.nn as nn
16: from torch.nn import functional as F
17:
18: # ============================================================================
19: # Model Components
20: # ============================================================================
21:
22: # ── Normalization ──────────────────────────────────────────────────────
...

Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 33,
5 "end_line": 70,
6 "content": "# ── Self-Attention ─────────────────────────────────────────────────────────\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n from fla.layers import GatedLinearAttention\n hidden_size = config.n_embd\n num_heads = config.n_head\n # Gated Linear Attention base\n self.gla = GatedLinearAttention(\n mode='chunk',\n hidden_size=hidden_size,\n num_heads=num_heads,\n expand_k=0.5,\n expand_v=1.0,\n use_output_gate=True,\n gate_fn='swish',\n )\n # Short convolution branch (local pattern)\n self.conv = nn.Conv1d(\n in_channels=hidden_size,\n out_channels=hidden_size,\n kernel_size=4,\n groups=hidden_size, # depthwise\n padding=3, # causal padding: (kernel_size - 1)\n )\n # Learnable per-head decay factor (initialized as zeros -> no decay initially)\n self.decay = nn.Parameter(torch.zeros(num_heads))\n # Gating network: computes blend weight between linear attention and conv\n self.gate_proj = nn.Linear(hidden_size, hidden_size)\n self.gate_out = nn.Linear(hidden_size, hidden_size)\n # Output projection (standard)\n self.c_proj = nn.Linear(hidden_size, hidden_size, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.n_head = num_heads\n self.n_embd = hidden_size\n self.dropout = config.dropout\n # We'll use relative positional encoding inside GLA, so disable learned absolute pos emb\n self.use_pos_emb = False\n\n @torch.compiler.disable\n def _gla_forward(self, x):\n # GLA returns (output, attn, lse)\n o, _, _ = self.gla(x)\n return o\n\n def forward(self, x):\n B, T, C = x.size()\n # 1. Gated Linear Attention with per-head decay modulation\n # Apply decay factor to the hidden state? GLA internal recurrence may not expose decay.\n # We'll instead modulate the input to GLA? Actually GLA uses chunk-wise recurrence.\n # We'll skip decay modulation for now (can be added later).\n gla_out = self._gla_forward(x)\n \n # 2. 
Depthwise convolution branch (causal)\n # Convert to (B, C, T) for conv1d\n conv_in = x.transpose(1, 2) # (B, C, T)\n # Causal padding: we padded with kernel_size-1 on the left, but conv1d padding adds both sides.\n # Use manual padding? Use Conv1d with padding='same'? Not causal.\n # We'll implement causal conv manually via F.conv1d with shifted output.\n # Simpler: use the conv layer with padding=3 and then discard first 3 timesteps?\n # Actually we want causal: output at t depends only on t-3, t-2, t-1, t.\n # With padding=3, conv output includes future information (since padding both sides).\n # Let's implement custom causal conv.\n # Use F.conv1d with groups=C, weight from self.conv, bias=None, padding=0, then shift.\n weight = self.conv.weight # (C, 1, kernel_size)\n bias = self.conv.bias\n # Apply conv\n conv_out = F.conv1d(conv_in, weight, bias, groups=C, padding=0)\n # conv_out shape (B, C, T - kernel_size + 1)\n # Pad left with zeros to match length T (causal: first kernel_size-1 outputs zero)\n pad_left = 3 # kernel_size - 1\n conv_out = F.pad(conv_out, (pad_left, 0)) # pad left only\n conv_out = conv_out.transpose(1, 2) # (B, T, C)\n \n # 3. Gating network\n gate_input = x # original input\n gate = torch.sigmoid(self.gate_proj(gate_input))\n # Blend\n blended = gate * gla_out + (1 - gate) * conv_out\n \n # 4. Output projection\n out = self.gate_out(blended)\n out = self.resid_dropout(self.c_proj(out))\n return out"
7}
OK: Replaced lines 33..70 with 84 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–116, 134–146, 292–294.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–116, 134–146, 292–294 | total: 484 lines]
33 # ── Self-Attention ─────────────────────────────────────────────────────────
34 class CausalSelfAttention(nn.Module):
35 def __init__(self, config):
... (78 more lines) ...
114 out = self.gate_out(blended)
115 out = self.resid_dropout(self.c_proj(out))
116 return out
...
134 # ── Transformer Block ─────────────────────────────────────────────────────
135 class Block(nn.Module):
136 def __init__(self, config):
... (7 more lines) ...
144 x = x + self.attn(self.ln_1(x))
145 x = x + self.mlp(self.ln_2(x))
146 return x
...
292 # CONFIG_OVERRIDES: override training hyperparameters for your method.
293 # Allowed keys: learning_rate, weight_dec
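The sigmoid-gated blend in the edit above (mixing the GLA branch with the depthwise-conv branch) is a per-token, per-channel interpolation. A minimal NumPy sketch, with W_gate and b_gate standing in for the learned gate_proj parameters (hypothetical names for this example):

```python
import numpy as np

def gated_blend(x, branch_a, branch_b, W_gate, b_gate):
    """Blend two branch outputs with a learned sigmoid gate.

    gate = sigmoid(x @ W_gate + b_gate) lies in (0, 1) per token and channel;
    the output interpolates gate * branch_a + (1 - gate) * branch_b.
    Shapes: x, branch_a, branch_b are (T, C); W_gate is (C, C); b_gate is (C,).
    """
    gate = 1.0 / (1.0 + np.exp(-(x @ W_gate + b_gate)))
    return gate * branch_a + (1.0 - gate) * branch_b
```

With zero-initialized gate parameters the gate is exactly 0.5 everywhere, so training starts from an even mix of the two branches.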
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 33,
5 "end_line": 116,
6 "content": "# ── Self-Attention ─────────────────────────────────────────────────────────\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n from fla.layers import GatedLinearAttention\n hidden_size = config.n_embd\n num_heads = config.n_head\n # Gated Linear Attention base\n self.gla = GatedLinearAttention(\n mode='chunk',\n hidden_size=hidden_size,\n num_heads=num_heads,\n expand_k=0.5,\n expand_v=1.0,\n use_output_gate=True,\n gate_fn='swish',\n )\n # Short depthwise causal convolution branch (kernel size 4)\n self.conv = nn.Conv1d(\n in_channels=hidden_size,\n out_channels=hidden_size,\n kernel_size=4,\n groups=hidden_size, # depthwise\n bias=False,\n )\n # Learnable per-head decay factor (initialized as zeros)\n self.decay = nn.Parameter(torch.zeros(num_heads))\n # Gating network: computes blend weight between linear attention and conv\n self.gate_proj = nn.Linear(hidden_size, hidden_size)\n self.gate_out = nn.Linear(hidden_size, hidden_size)\n # Output projection (standard)\n self.c_proj = nn.Linear(hidden_size, hidden_size, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.n_head = num_heads\n self.n_embd = hidden_size\n self.dropout = config.dropout\n # We'll use relative positional encoding inside GLA, so disable learned absolute pos emb\n self.use_pos_emb = False\n\n @torch.compiler.disable\n def _gla_forward(self, x):\n # GLA returns (output, attn, lse)\n o, _, _ = self.gla(x)\n return o\n\n def forward(self, x):\n B, T, C = x.size()\n # 1. Gated Linear Attention\n gla_out = self._gla_forward(x)\n \n # 2. Depthwise causal convolution branch\n # Convert to (B, C, T) for conv1d\n conv_in = x.transpose(1, 2) # (B, C, T)\n # Causal padding: pad left with kernel_size-1 zeros\n conv_in_padded = F.pad(conv_in, (3, 0)) # (left, right)\n conv_out = self.conv(conv_in_padded) # (B, C, T)\n conv_out = conv_out.transpose(1, 2) # (B, T, C)\n \n # 3. 
Gating network (per token, per channel)\n gate = torch.sigmoid(self.gate_proj(x)) # (B, T, C)\n # Blend linear attention and convolution outputs\n blended = gate * gla_out + (1 - gate) * conv_out\n \n # 4. Output projection\n out = self.gate_out(blended)\n out = self.resid_dropout(self.c_proj(out))\n return out"
7}
OK: Replaced lines 33..116 with 67 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–99, 117–129, 275–277.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–99, 117–129, 275–277 | total: 467 lines]
33 # ── Self-Attention ─────────────────────────────────────────────────────────
34 class CausalSelfAttention(nn.Module):
35 def __init__(self, config):
... (61 more lines) ...
97 out = self.gate_out(blended)
98 out = self.resid_dropout(self.c_proj(out))
99 return out
...
117 # ── Transformer Block ─────────────────────────────────────────────────────
118 class Block(nn.Module):
119 def __init__(self, config):
... (7 more lines) ...
127 x = x + self.attn(self.ln_1(x))
128 x = x + self.mlp(self.ln_2(x))
129 return x
...
275 # CONFIG_OVERRIDES: override training hyperparameters for your method.
276 # Allowed keys: learning_rate, weight_deca
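The left-only padding pattern in the revised edit (F.pad(conv_in, (3, 0)) before the depthwise conv) is what makes the convolution causal: output t sees only inputs t-3 through t. A NumPy sketch that makes the property directly checkable (illustrative, mirroring PyTorch's cross-correlation convention):

```python
import numpy as np

def causal_depthwise_conv1d(x, weight):
    """Depthwise causal 1-D convolution via left-only zero padding.

    x: (T, C); weight: (C, K). Output at step t uses x[t-K+1 .. t] only,
    matching the pad-left-by-(K-1) pattern used in the edit above.
    """
    T, C = x.shape
    K = weight.shape[1]
    xp = np.concatenate([np.zeros((K - 1, C)), x], axis=0)  # pad left only
    out = np.empty((T, C))
    for t in range(T):
        # window xp[t:t+K] corresponds to original indices t-K+1 .. t
        out[t] = np.sum(xp[t:t + K] * weight.T, axis=0)
    return out
```

Perturbing the input at one timestep changes only that timestep's output and later ones, which is the causality guarantee the edit relies on.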
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 117,
5 "end_line": 129,
6 "content": "# ── Transformer Block ─────────────────────────────────────────────────────\nclass Block(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)\n self.attn = CausalSelfAttention(config)\n self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)\n self.mlp = MLP(config)\n # Optional scaling factor for parallel branch\n self.parallel_scale = nn.Parameter(torch.ones(1))\n\n def forward(self, x):\n # Parallel execution: compute attention and MLP simultaneously\n # Normalize input\n x_norm = self.ln_1(x)\n # Compute attention and MLP in parallel (they depend on same normalized input)\n attn_out = self.attn(x_norm)\n mlp_out = self.mlp(self.ln_2(x_norm))\n # Combine with residual connection\n x = x + attn_out + self.parallel_scale * mlp_out\n return x"
7}
OK: Replaced lines 117..129 with 21 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–99, 117–137, 283–285.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–99, 117–137, 283–285 | total: 475 lines]
33 # ── Self-Attention ─────────────────────────────────────────────────────────
34 class CausalSelfAttention(nn.Module):
35 def __init__(self, config):
... (61 more lines) ...
97 out = self.gate_out(blended)
98 out = self.resid_dropout(self.c_proj(out))
99 return out
...
117 # ── Transformer Block ─────────────────────────────────────────────────────
118 class Block(nn.Module):
119 def __init__(self, config):
... (15 more lines) ...
135 # Combine with residual connection
136 x = x + attn_out + self.parallel_scale * mlp_out
137 return x
...
283 # CONFIG_OVERRIDES: override training hyperparameters for your method.
284 # Allowed keys: lea
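The rewritten Block computes both sublayers from the same residual stream (GPT-J-style parallel composition) instead of nanoGPT's sequential pre-norm pattern. A schematic NumPy sketch of the difference, with attn and mlp as stand-ins for the normalized sublayers (hypothetical callables, normalization omitted for clarity):

```python
import numpy as np

def parallel_block(x, attn, mlp, scale=1.0):
    """Parallel-branch residual block: both sublayers read the same input
    and their outputs are summed onto the residual, as in the edit above."""
    return x + attn(x) + scale * mlp(x)

def sequential_block(x, attn, mlp):
    """Standard nanoGPT composition: MLP sees the attention-updated stream."""
    x = x + attn(x)
    x = x + mlp(x)
    return x
```

The two compositions are not equivalent: in the sequential form the MLP operates on the attention-updated residual, while the parallel form trades that dependency for the ability to compute both branches concurrently.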
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.6842, val_loss=2.4153
TRAIN_METRICS: step=10000, train_loss=2.6690, val_loss=2.3899
TRAIN_METRICS: step=11000, train_loss=2.6595, val_loss=2.3636
TRAIN_METRICS: step=12000, train_loss=2.6374, val_loss=2.3603
TRAIN_METRICS: step=13000, train_loss=2.6222, val_loss=2.3602
Final metrics (gpt-345m): val_loss_gpt-345m=2.3563, wikitext2_ppl_gpt-345m=54.1100, lambada_ppl_gpt-345m=73.0500
### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=56.4400, hellaswag_lm-eval-345m=33.3700, piqa_lm-eval-345m=65.0200, winogrande_lm-eval-345m=51.5400
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.3563, 'wikitext2_ppl_gpt-345m': 54.11, 'lambada_ppl_gpt-345m': 73.05, 'elapsed_gpt-345m': 36594, 'arc_easy_lm-eval-345m': 56.44, 'hellaswag_lm-eval-345m': 33.37, 'piq
Tool Input
1{
2 "n": 1
3}
Undo complete: Restored: nanoGPT/custom_pretrain.py
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 33,
5 "end_line": 99,
6 "content": "# ── Self-Attention ─────────────────────────────────────────────────────────\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n from fla.layers import DeltaNet, LinearAttention\n hidden_size = config.n_embd\n num_heads = config.n_head\n \n # DeltaNet branch (local + linear)\n self.delta = DeltaNet(\n hidden_size=hidden_size,\n num_heads=num_heads,\n use_beta=True,\n use_short_conv=True,\n conv_size=4,\n qk_activation='silu',\n qk_norm='l2',\n )\n # Global linear attention branch (no convolution)\n self.global_lin = LinearAttention(\n hidden_size=hidden_size,\n num_heads=num_heads,\n expand_k=1.0,\n expand_v=1.0,\n feature_map='silu',\n output_proj=False,\n )\n # Learnable per-head decay factors for each branch\n self.delta_decay = nn.Parameter(torch.zeros(num_heads))\n self.global_decay = nn.Parameter(torch.zeros(num_heads))\n \n # Gating network: computes mixing weight per token per head\n self.gate_proj = nn.Linear(hidden_size, num_heads)\n # Output projection (shared)\n self.c_proj = nn.Linear(hidden_size, hidden_size, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.n_head = num_heads\n self.n_embd = hidden_size\n self.dropout = config.dropout\n # Disable learned absolute position embeddings (use relative)\n self.use_pos_emb = False\n\n @torch.compiler.disable\n def _delta_forward(self, x):\n o, _, _ = self.delta(x)\n return o\n\n @torch.compiler.disable\n def _global_forward(self, x):\n o, _, _ = self.global_lin(x)\n return o\n\n def forward(self, x):\n B, T, C = x.size()\n # Compute both branches in parallel\n delta_out = self._delta_forward(x)\n global_out = self._global_forward(x)\n \n # Compute gating weights per token per head\n gate_logits = self.gate_proj(x) # (B, T, num_heads)\n gate_weights = torch.softmax(gate_logits, dim=-1) # (B, T, num_heads)\n # Expand gate weights to match head dimension\n # Each head has dimension C // num_heads\n head_dim = C // self.n_head\n 
gate_weights = gate_weights.unsqueeze(-1) # (B, T, num_heads, 1)\n gate_weights = gate_weights.expand(-1, -1, -1, head_dim)\n gate_weights = gate_weights.reshape(B, T, C)\n \n # Combine branches using gating\n combined = gate_weights * delta_out + (1 - gate_weights) * global_out\n \n # Output projection\n out = self.resid_dropout(self.c_proj(combined))\n return out"
7}
OK: Replaced lines 33..99 with 74 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–106, 124–136, 282–284.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–106, 124–136, 282–284 | total: 474 lines]
33 # ── Self-Attention ─────────────────────────────────────────────────────────
34 class CausalSelfAttention(nn.Module):
35 def __init__(self, config):
... (68 more lines) ...
104 # Output projection
105 out = self.resid_dropout(self.c_proj(combined))
106 return out
...
124 # ── Transformer Block ─────────────────────────────────────────────────────
125 class Block(nn.Module):
126 def __init__(self, config):
... (7 more lines) ...
134 x = x + self.attn(self.ln_1(x))
135 x = x + self.mlp(self.ln_2(x))
136 return x
...
282 # CONFIG_OVERRIDES: override training hyperparameters for your method.
283 # Allowed keys: learning_rate, weight_decay,
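The per-token, per-head branch mixing in this attempt can be sketched as follows. Note a property of that edit: softmax is taken over the head axis, so each token's gates sum to 1 across heads rather than lying independently in (0, 1) per head. A NumPy sketch (W_gate stands in for the learned gate_proj weight; illustrative only):

```python
import numpy as np

def head_gated_mix(x, a, b, W_gate, head_dim):
    """Mix two branch outputs a and b (each (T, C)) with per-head gates.

    Gate logits (T, H) are softmax-normalized over the head axis -- as in the
    Test #2 edit -- then repeated across each head's channels before blending,
    mirroring the unsqueeze/expand/reshape sequence above.
    """
    logits = x @ W_gate                               # (T, H)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    g = e / e.sum(axis=-1, keepdims=True)             # softmax over heads
    g = np.repeat(g, head_dim, axis=1)                # broadcast to (T, C)
    return g * a + (1.0 - g) * b
```

One consequence of normalizing over heads: with a single head the gate is identically 1, so the second branch is ignored entirely.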
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh) [COMMAND FAILED exit=1]
W0414 20:44:28.153000 438558 site-packages/torch/distributed/run.py:792]
W0414 20:44:28.153000 438558 site-packages/torch/distributed/run.py:792] *****************************************
W0414 20:44:28.153000 438558 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0414 20:44:28.153000 438558 site-packages/torch/distributed/run.py:792] *****************************************
[W414 20:44:28.129934952 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W414 20:44:28.130249742 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W414 20:44:28.170243115 socket.cpp:204
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 33,
5 "end_line": 106,
6 "content": "# ── Self-Attention ─────────────────────────────────────────────────────────\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n from fla.layers import DeltaNet\n hidden_size = config.n_embd\n num_heads = config.n_head\n \n # DeltaNet branch (local + linear)\n self.delta = DeltaNet(\n hidden_size=hidden_size,\n num_heads=num_heads,\n use_beta=True,\n use_short_conv=True,\n conv_size=4,\n qk_activation='silu',\n qk_norm='l2',\n )\n # Learnable per-head decay factor (initialized as zero)\n self.decay = nn.Parameter(torch.zeros(num_heads))\n # Gating network for residual blending\n self.gate_proj = nn.Linear(hidden_size, hidden_size)\n # Output projection (standard)\n self.c_proj = nn.Linear(hidden_size, hidden_size, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.n_head = num_heads\n self.n_embd = hidden_size\n self.dropout = config.dropout\n # Disable learned absolute position embeddings (use relative)\n self.use_pos_emb = False\n\n @torch.compiler.disable\n def _delta_forward(self, x):\n o, _, _ = self.delta(x)\n return o\n\n def forward(self, x):\n B, T, C = x.size()\n # DeltaNet forward\n delta_out = self._delta_forward(x)\n \n # Apply per-head decay (scale each head's output)\n # Reshape delta_out to (B, T, num_heads, head_dim)\n head_dim = C // self.n_head\n delta_out = delta_out.view(B, T, self.n_head, head_dim)\n decay = torch.sigmoid(self.decay) # (num_heads)\n delta_out = delta_out * decay.view(1, 1, self.n_head, 1)\n delta_out = delta_out.view(B, T, C)\n \n # Compute gating weights per token per channel\n gate = torch.sigmoid(self.gate_proj(x)) # (B, T, C)\n # Blend delta output with skip connection (input)\n blended = gate * delta_out + (1 - gate) * x\n \n # Output projection\n out = self.resid_dropout(self.c_proj(blended))\n return out"
7}
OK: Replaced lines 33..106 with 57 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–89.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–89 | total: 457 lines]
33 # ── Self-Attention ─────────────────────────────────────────────────────────
34 class CausalSelfAttention(nn.Module):
35 def __init__(self, config):
... (51 more lines) ...
87 # Output projection
88 out = self.resid_dropout(self.c_proj(blended))
89 return out
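The per-head decay in this final edit scales each head's DeltaNet output by sigmoid(decay). Since sigmoid(0) = 0.5, the zero initialization starts every head at half scale rather than identity. A NumPy sketch of that scaling (illustrative only):

```python
import numpy as np

def per_head_decay_scale(o, decay_logits):
    """Scale each head's output by sigmoid(decay), as in the final edit.

    o: (T, H, Dh) per-head outputs; decay_logits: (H,) learnable logits.
    Note: zero-initialized logits give a 0.5 scale on every head, so this
    halves the branch output at init rather than passing it through unchanged.
    """
    decay = 1.0 / (1.0 + np.exp(-decay_logits))   # (H,), each in (0, 1)
    return o * decay[None, :, None]
```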