Agent Conversation: gemini-3.1-pro-preview
Task: llm-pretrain-linear-attention
LLM Pretraining: Linear/Subquadratic Attention Mechanism
Research Question
Design a novel linear or subquadratic attention mechanism for GPT-2 language model pretraining that achieves competitive validation loss while replacing standard softmax attention. The mechanism should scale better than O(n^2) in sequence length.
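For orientation, the core idea behind most linear attention variants fits in a few lines: replace the softmax kernel with a positive feature map φ, so the causal attention sum collapses into a running state that is updated in O(1) per token, giving O(n) total cost instead of O(n²). The sketch below is illustrative only — the feature map (ReLU + 1) is an arbitrary choice for the demo, not the mechanism of any particular FLA layer:

```python
import numpy as np

def softmax_attn(Q, K, V):
    """Reference causal softmax attention: O(n^2) time and memory."""
    n = Q.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

def linear_attn(Q, K, V, phi=lambda x: np.maximum(x, 0) + 1.0):
    """Causal linear attention: a (d x d_v) running state replaces the n x n
    score matrix, so cost is O(n) in sequence length."""
    n, d = Q.shape
    S = np.zeros((d, V.shape[1]))   # running sum of phi(k) outer v
    z = np.zeros(d)                 # running sum of phi(k), for normalization
    out = np.empty_like(V)
    for t in range(n):
        q, k = phi(Q[t]), phi(K[t])
        S += np.outer(k, V[t])
        z += k
        out[t] = (q @ S) / (q @ z + 1e-6)
    return out
```

Because the state only accumulates past tokens, causality holds by construction: perturbing token t changes outputs at positions ≥ t but never earlier ones.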
What You Can Modify
Two editable regions in custom_pretrain.py:
- CausalSelfAttention class (lines 33-70): the attention mechanism itself, including:
  - The attention computation (replace softmax attention with linear/subquadratic alternatives)
  - Feature maps, gating mechanisms, decay factors
  - Query/Key/Value projections and transformations
  - Internal state management (recurrent states, convolutions, etc.)
- Block class (lines 88-100): the transformer block structure, including:
  - How attention and MLP sublayers are composed
  - Normalization placement (pre-norm, post-norm)
  - Residual connection patterns
Note: The flash-linear-attention (FLA) library is pre-installed and provides 27+ optimized linear attention implementations with Triton kernels. You can import from fla.layers (e.g., GatedLinearAttention, DeltaNet, MultiScaleRetention, LinearAttention, HGRN2, Mamba2, etc.) or implement your own mechanism from scratch.
Note: If your attention mechanism does not use learned absolute position embeddings, set self.use_pos_emb = False in __init__ — the model will then skip adding position embeddings in the forward pass.
Note: torch.compile is disabled for this task since FLA's Triton kernels are not compatible with it.
Evaluation
- Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
- Training: 13535 iterations, BSZ=32, GA=16, 2-GPU DDP
- Hardware: H200 GPU
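The stated budget is internally consistent: 13535 iterations at an effective batch of BSZ × GA = 512 sequences. Assuming the standard GPT-2 context length of 1024 tokens (an assumption — the spec above does not state it), the totals reproduce both the ~7.1B token count and the D ≈ 20N Chinchilla ratio:

```python
# Sanity check on the training budget. block_size=1024 is an assumption
# (GPT-2's standard context length); it is not stated in the task spec.
iters, bsz, grad_accum, block_size = 13535, 32, 16, 1024
total_tokens = iters * bsz * grad_accum * block_size
params = 355e6  # GPT-2 Medium, ~355M parameters

print(f"total tokens: {total_tokens / 1e9:.2f}B")    # ~7.10B
print(f"tokens/param: {total_tokens / params:.1f}")  # ~20.0 (Chinchilla D=20N)
```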
nanoGPT/custom_pretrain.py [EDITABLE — lines 33–70, lines 88–100, lines 246–248 only]
1: """Custom GPT-2 Pretraining Script
2: Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3: # flash-linear-attention is available: from fla.layers import GatedLinearAttention, DeltaNet, MultiScaleRetention, etc.
4: """
5:
6: import math
7: import inspect
8: import os
9: import time
10: from contextlib import nullcontext
11: from dataclasses import dataclass
12:
13: import numpy as np
14: import torch
15: import torch.nn as nn
16: from torch.nn import functional as F
17:
18: # ============================================================================
19: # Model Components
20: # ============================================================================
21:
22: # ── Normalization ────────────────────────────────────────────────────────
Tool Input
{
  "content": "class CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.n_embd = config.n_embd\n self.d1 = config.n_embd // 2\n self.d2 = config.n_embd - self.d1\n self.h1 = config.n_head // 2\n self.h2 = config.n_head - self.h1\n \n from fla.layers import GatedLinearAttention, DeltaNet\n self.attn1 = GatedLinearAttention(\n mode='chunk',\n hidden_size=self.d1,\n num_heads=self.h1,\n expand_k=0.5,\n expand_v=1.0,\n use_output_gate=True,\n gate_fn='swish',\n )\n self.attn2 = DeltaNet(\n hidden_size=self.d2,\n num_heads=self.h2,\n use_beta=True,\n use_short_conv=True,\n conv_size=4,\n qk_activation='silu',\n qk_norm='l2',\n )\n \n self.gate_conv = nn.Conv1d(config.n_embd, config.n_embd, kernel_size=3, padding=2, groups=config.n_embd)\n self.gate = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.use_pos_emb = False\n\n def forward(self, x):\n x1, x2 = torch.split(x, [self.d1, self.d2], dim=-1)\n o1, _, _ = self.attn1(x1)\n o2, _, _ = self.attn2(x2)\n o = torch.cat([o1, o2], dim=-1)\n \n gx = x.transpose(1, 2)\n gx = self.gate_conv(gx)[..., :-2].transpose(1, 2)\n gate_val = F.silu(self.gate(gx))\n \n o = o * gate_val\n return self.resid_dropout(self.c_proj(o))",
  "start_line": 33,
  "op": "replace",
  "filename": "nanoGPT/custom_pretrain.py",
  "end_line": 70
}
OK: Replaced lines 33..70 with 47 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–79, 97–109, 255–257.
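One subtle point in the layer above is the gate's causality: `nn.Conv1d(kernel_size=3, padding=2)` pads both ends of the sequence, so trimming the last two timesteps (`[..., :-2]`) leaves every output position depending only on the current and the two previous inputs. A NumPy sketch of that depthwise causal convolution (the filter weights here are random placeholders, not the trained parameters):

```python
import numpy as np

def causal_depthwise_conv(x, w, pad=2):
    """Depthwise 1D conv with kernel 3 and padding 2, then trim the last
    two steps -- mirroring gate_conv(...)[..., :-2] in the layer above.
    x: (T, C) sequence, w: (C, K) per-channel filters."""
    T, C = x.shape
    K = w.shape[1]
    zeros = np.zeros((pad, C))
    xp = np.concatenate([zeros, x, zeros], axis=0)   # pad both sides, as Conv1d does
    full = np.zeros((T + 2 * pad - K + 1, C))        # Conv1d output length: T + 2
    for i in range(full.shape[0]):
        full[i] = (xp[i:i + K].T * w).sum(axis=1)    # per-channel cross-correlation
    return full[:T]                                  # drop the trailing `pad` steps
```

With pad = 2 and K = 3, output position i reads padded inputs i..i+2, i.e. raw inputs i-2..i, so no future leakage occurs.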
[Current file: nanoGPT/custom_pretrain.py | editable: 33–79, 97–109, 255–257 | total: 447 lines]
33 class CausalSelfAttention(nn.Module):
34 def __init__(self, config):
35 super().__init__()
... (41 more lines) ...
77
78 o = o * gate_val
79 return self.resid_dropout(self.c_proj(o))
...
97 # ── Transformer Block ─────────────────────────────────────────────────────
98 class Block(nn.Module):
99 def __init__(self, config):
... (7 more lines) ...
107 x = x + self.attn(self.ln_1(x))
108 x = x + self.mlp(self.ln_2(x))
109 return x
...
255 # CONFIG_OVERRIDES: override training hyperparameters for your method.
256 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
257 CONFIG_OVERRIDES = {}
Tool Input
{
  "start_line": 255,
  "content": " # CONFIG_OVERRIDES: override training hyperparameters for your method.\n # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.\n CONFIG_OVERRIDES = {'learning_rate': 1e-3, 'weight_decay': 0.05, 'warmup_iters': 500}",
  "filename": "nanoGPT/custom_pretrain.py",
  "end_line": 257,
  "op": "replace"
}
OK: Replaced lines 255..257 with 3 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–79, 97–109, 255–257.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–79, 97–109, 255–257 | total: 447 lines]
33 class CausalSelfAttention(nn.Module):
34 def __init__(self, config):
35 super().__init__()
... (41 more lines) ...
77
78 o = o * gate_val
79 return self.resid_dropout(self.c_proj(o))
...
97 # ── Transformer Block ─────────────────────────────────────────────────────
98 class Block(nn.Module):
99 def __init__(self, config):
... (7 more lines) ...
107 x = x + self.attn(self.ln_1(x))
108 x = x + self.mlp(self.ln_2(x))
109 return x
...
255 # CONFIG_OVERRIDES: override training hyperparameters for your method.
256 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
257 CONFIG_OVERRIDES = {'learning_rate': 1e-3, 'weight_decay': 0.05, 'warmup_iters': 500}
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.5288, val_loss=2.2711
TRAIN_METRICS: step=10000, train_loss=2.5171, val_loss=2.2586
TRAIN_METRICS: step=11000, train_loss=2.4971, val_loss=2.2489
TRAIN_METRICS: step=12000, train_loss=2.4697, val_loss=2.2067
TRAIN_METRICS: step=13000, train_loss=2.4695, val_loss=2.2218
Final metrics (gpt-345m): val_loss_gpt-345m=2.2138, wikitext2_ppl_gpt-345m=40.0700, lambada_ppl_gpt-345m=60.2700
### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
Evaluating checkpoint: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-linear-attention/gemini-3.1-pro-preview_20260413_120858/seed_42/ckpt_gpt-345m.pt
Model source: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-linear-attention/gemini-3.1-pro-preview_20260413_120858/seed_42/model_source_gpt-345m.py
Using the latest cached version
Tool Input
{
  "filename": "nanoGPT/custom_pretrain.py",
  "end_line": 79,
  "op": "replace",
  "content": "class CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n import os\n if \"TRITON_CACHE_DIR\" not in os.environ:\n os.environ[\"TRITON_CACHE_DIR\"] = os.path.join(os.environ.get('OUTPUT_DIR', '.'), \"triton_cache\")\n\n self.n_embd = config.n_embd\n self.d1 = config.n_embd // 2\n self.d2 = config.n_embd - self.d1\n self.h1 = config.n_head // 2\n self.h2 = config.n_head - self.h1\n \n self.in_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n \n from fla.layers import GatedLinearAttention, DeltaNet\n self.attn1 = GatedLinearAttention(\n mode='chunk',\n hidden_size=self.d1,\n num_heads=self.h1,\n expand_k=0.5,\n expand_v=1.0,\n use_output_gate=True,\n gate_fn='swish',\n )\n self.attn2 = DeltaNet(\n hidden_size=self.d2,\n num_heads=self.h2,\n use_beta=True,\n use_short_conv=True,\n conv_size=4,\n qk_activation='silu',\n qk_norm='l2',\n )\n \n self.gate_conv = nn.Conv1d(config.n_embd, config.n_embd, kernel_size=3, padding=2, groups=config.n_embd)\n self.gate = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.use_pos_emb = False\n\n def forward(self, x):\n hx = self.in_proj(x)\n x1, x2 = torch.split(hx, [self.d1, self.d2], dim=-1)\n \n o1, _, _ = self.attn1(x1)\n o2, _, _ = self.attn2(x2)\n o = torch.cat([o1, o2], dim=-1)\n \n gx = x.transpose(1, 2)\n gx = self.gate_conv(gx)[..., :-2].transpose(1, 2)\n gate_val = F.silu(self.gate(gx))\n \n o = o * gate_val\n return self.resid_dropout(self.c_proj(o))",
  "start_line": 33
}
OK: Replaced lines 33..79 with 55 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–87, 105–117, 263–265.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–87, 105–117, 263–265 | total: 455 lines]
33 class CausalSelfAttention(nn.Module):
34 def __init__(self, config):
35 super().__init__()
... (49 more lines) ...
85
86 o = o * gate_val
87 return self.resid_dropout(self.c_proj(o))
...
105 # ── Transformer Block ─────────────────────────────────────────────────────
106 class Block(nn.Module):
107 def __init__(self, config):
... (7 more lines) ...
115 x = x + self.attn(self.ln_1(x))
116 x = x + self.mlp(self.ln_2(x))
117 return x
...
263 # CONFIG_OVERRIDES: override training hyperparameters for your method.
264 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
265 CONFIG_OVERRIDES = {'learning_rate': 1e-3, 'weight_decay': 0.05, 'warmup_iters': 500}
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.5395, val_loss=2.2684
TRAIN_METRICS: step=10000, train_loss=2.5184, val_loss=2.2544
TRAIN_METRICS: step=11000, train_loss=2.4845, val_loss=2.2538
TRAIN_METRICS: step=12000, train_loss=2.4780, val_loss=2.2089
TRAIN_METRICS: step=13000, train_loss=2.4657, val_loss=2.2149
Final metrics (gpt-345m): val_loss_gpt-345m=2.2014, wikitext2_ppl_gpt-345m=40.4400, lambada_ppl_gpt-345m=60.0900
### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
Evaluating checkpoint: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-linear-attention/gemini-3.1-pro-preview_20260413_120858/seed_42/ckpt_gpt-345m.pt
Model source: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-linear-attention/gemini-3.1-pro-preview_20260413_120858/seed_42/model_source_gpt-345m.py
Using the latest cached version
Tool Input
{
  "filename": "nanoGPT/custom_pretrain.py",
  "start_line": 33,
  "end_line": 87,
  "op": "replace",
  "content": "class CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n import os\n import tempfile\n triton_cache = os.path.join(tempfile.gettempdir(), f\"triton_cache_{os.getpid()}\")\n os.makedirs(triton_cache, exist_ok=True)\n os.environ[\"TRITON_CACHE_DIR\"] = triton_cache\n\n self.n_embd = config.n_embd\n self.d1 = config.n_embd // 2\n self.d2 = config.n_embd - self.d1\n self.h1 = config.n_head // 2\n self.h2 = config.n_head - self.h1\n \n self.in_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n \n from fla.layers import GatedLinearAttention, DeltaNet\n self.attn1 = GatedLinearAttention(\n mode='chunk',\n hidden_size=self.d1,\n num_heads=self.h1,\n expand_k=0.5,\n expand_v=1.0,\n use_output_gate=True,\n gate_fn='swish',\n )\n self.attn2 = DeltaNet(\n hidden_size=self.d2,\n num_heads=self.h2,\n use_beta=True,\n use_short_conv=True,\n conv_size=4,\n qk_activation='silu',\n qk_norm='l2',\n )\n \n self.gate_conv = nn.Conv1d(config.n_embd, config.n_embd, kernel_size=3, padding=2, groups=config.n_embd)\n self.gate = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n self.resid_dropout = nn.Dropout(config.dropout)\n self.use_pos_emb = False\n\n def forward(self, x):\n hx = self.in_proj(x)\n x1, x2 = torch.split(hx, [self.d1, self.d2], dim=-1)\n \n o1, _, _ = self.attn1(x1)\n o2, _, _ = self.attn2(x2)\n o = torch.cat([o1, o2], dim=-1)\n \n gx = x.transpose(1, 2)\n gx = self.gate_conv(gx)[..., :-2].transpose(1, 2)\n gate_val = F.silu(self.gate(gx))\n \n o = o * gate_val\n return self.resid_dropout(self.c_proj(o))"
}
OK: Replaced lines 33..87 with 57 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–89.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–89 | total: 457 lines]
33 class CausalSelfAttention(nn.Module):
34 def __init__(self, config):
35 super().__init__()
... (51 more lines) ...
87
88 o = o * gate_val
89 return self.resid_dropout(self.c_proj(o))
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.5395, val_loss=2.2684
TRAIN_METRICS: step=10000, train_loss=2.5184, val_loss=2.2544
TRAIN_METRICS: step=11000, train_loss=2.4845, val_loss=2.2538
TRAIN_METRICS: step=12000, train_loss=2.4780, val_loss=2.2089
TRAIN_METRICS: step=13000, train_loss=2.4657, val_loss=2.2149
Final metrics (gpt-345m): val_loss_gpt-345m=2.2014, wikitext2_ppl_gpt-345m=40.4400, lambada_ppl_gpt-345m=60.1000
### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=59.6000, hellaswag_lm-eval-345m=35.9500, piqa_lm-eval-345m=65.7200, winogrande_lm-eval-345m=50.9100
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.2014, 'wikitext2_ppl_gpt-345m': 40.44, 'lam
Tool Input
{
  "n": 3
}
[submit] Submitting result from test #3 as final.
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.5395, val_loss=2.2684
TRAIN_METRICS: step=10000, train_loss=2.5184, val_loss=2.2544
TRAIN_METRICS: step=11000, train_loss=2.4845, val_loss=2.2538
TRAIN_METRICS: step=12000, train_loss=2.4780, val_loss=2.2089
TRAIN_METRICS: step=13000, train_loss=2.4657, val_loss=2.2149
Final metrics (gpt-345m): val_loss_gpt-345m=2.2014, wikitext2_ppl_gpt-345m=40.4400, lambada_ppl_gpt-345m=60.1000
### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=59.6000, hellaswag_lm-eval-345m=35.9500, piqa_lm-eval-345m=65.7200, winogrande_lm-eval-345m=50.9100
[submit] Finalized test #3 as final.
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.2014, 'wikitext2_ppl_gpt-345m': 40.44, 'lambada_ppl_gpt-345m': 60.1, 'elapsed_gpt-345m': 35051, 'arc_easy_lm-eval-345m': 59.6, 'hellaswag_lm-eval-345m': 35.95, 'piqa_lm-eval
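As a closing sanity check, the submitted metrics are mutually consistent. Validation cross-entropy is reported in nats per token, so the implied perplexity on the pretraining distribution is exp(val_loss); the WikiText-2 and LAMBADA perplexities are much higher because those corpora are out-of-distribution for a FineWeb-trained model:

```python
import math

val_loss = 2.2014  # final FineWeb validation cross-entropy (nats/token)
print(f"implied FineWeb val perplexity: {math.exp(val_loss):.2f}")  # ~9.04

# Conversely, the WikiText-2 perplexity of 40.44 corresponds to a
# per-token loss of ln(40.44):
print(f"WikiText-2 loss: {math.log(40.44):.2f} nats")  # ~3.70
```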