Agent Conversation: claude-opus-4.6
Model: claude-opus-4.6
Total steps: 10
Messages: 23
User Step 0
Task: llm-pretrain-optimizer
LLM Pretraining: Optimizer & Learning Rate Schedule Optimization
Research Question
Design an improved optimizer and/or learning rate schedule for GPT-2 language model pretraining. Your modifications should reduce validation loss compared to the standard AdamW with cosine annealing schedule.
What You Can Modify
Two regions in custom_pretrain.py:
- configure_optimizers method (lines 172-189): Optimizer creation and parameter grouping
- get_lr function (lines 192-201): Learning rate schedule
You can modify:
- The optimization algorithm (default: AdamW with fused implementation)
- Parameter grouping strategy (default: weight decay for 2D params, no decay for 1D)
- Learning rate schedule shape (default: cosine with linear warmup)
- Any optimizer hyperparameters
Note: The training loop calls get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr) — keep this signature compatible. The optimizer returned by configure_optimizers must support .zero_grad(), .step(), and .param_groups.
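For orientation, the default schedule named here (linear warmup into cosine annealing down to min_lr) can be sketched against that exact signature. This is an illustrative reconstruction, not the repository's code:

```python
import math

def get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr):
    # 1) Linear warmup from near zero up to learning_rate over warmup_iters steps.
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) Past the decay horizon, hold at the floor.
    if it > lr_decay_iters:
        return min_lr
    # 3) Cosine anneal from learning_rate down to min_lr in between.
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)
```

Any replacement schedule only has to preserve this signature and return a float learning rate per iteration.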
Evaluation
- Metric: Validation loss (cross-entropy, lower is better), plus perplexity (WikiText-2, LAMBADA) and downstream accuracy (ARC-Easy, HellaSwag, PIQA, WinoGrande)
- Model: GPT-2 Medium (24L/16H/1024D, ~355M params)
- Dataset: FineWeb 10B (GPT-2 tokenizer), ~7.1B tokens (D=20N Chinchilla-optimal)
- Training: 12030 iterations, BSZ=96, GA=6, 2-GPU DDP
- Hardware: H200 GPU
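The quoted token budget can be sanity-checked from the numbers above. The 1024-token context length is an assumption (it is not stated here), and BSZ=96 is treated as the global batch across both GPUs:

```python
# Token budget sanity check for the run described above.
n_params = 355e6                    # GPT-2 Medium, ~355M parameters
chinchilla_tokens = 20 * n_params   # D = 20N  ->  7.1e9 tokens

# Tokens actually consumed by training, assuming the GPT-2 default
# context length of 1024 (an assumption, not stated above) and BSZ=96
# as the global batch size:
iters, bsz, grad_accum, ctx = 12030, 96, 6, 1024
tokens_seen = iters * bsz * grad_accum * ctx

print(f"budget {chinchilla_tokens/1e9:.2f}B, seen {tokens_seen/1e9:.2f}B")
```

The two figures agree to within a fraction of a percent, consistent with the ~7.1B-token, D=20N description.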
nanoGPT/custom_pretrain.py [EDITABLE — lines 171–189, lines 245–247 only]
1: """Custom GPT-2 Pretraining Script
2: Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3: """
4:
5: import math
6: import inspect
7: import os
8: import time
9: from contextlib import nullcontext
10: from dataclasses import dataclass
11:
12: import numpy as np
13: import torch
14: import torch.nn as nn
15: from torch.nn import functional as F
16:
17: # ============================================================================
18: # Model Components
19: # ============================================================================
20:
21: # ── Normalization ──────────────────────────────────────────────────────────
22: class LayerNorm(nn.Module):
23: """LayerNorm but with an optional bias."""
24: def __init__(self, ndim, bias):
25: super().__init__()
26: self.weight = nn.Parameter(torch.ones(ndim))
27: self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
28:
29: def forward(self, input):
30: return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31:
32: # ── Self-Attention ─────────────────────────────────────────────────────────
33: class CausalSelfAttention(nn.Module):
34: def __init__(self, config):
35: super().__init__()
36: assert config.n
Assistant Step 1 (edit)
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 171,
5 "end_line": 189,
6 "content": " def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]\n nodecay_params = [(n, p) for n, p in param_dict.items() if p.dim() < 2]\n num_decay_params = sum(p.numel() for _, p in decay_params)\n num_nodecay_params = sum(p.numel() for _, p in nodecay_params)\n print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n\n # Separate 2D projection weights (for MuonGC) from rest (for AdamW)\n muon_params = [p for n, p in decay_params\n if 'wte' not in n and 'wpe' not in n and 'lm_head' not in n]\n adam_decay_params = [p for n, p in decay_params\n if 'wte' in n or 'wpe' in n or 'lm_head' in n]\n adam_nodecay_params = [p for _, p in nodecay_params]\n\n class MuonGC(torch.optim.Optimizer):\n \"\"\"Muon with Gradient Centralization — MomentUm Orthogonalized by\n Newton-Schulz with row-wise gradient centralization.\n\n Novel contribution: Before Newton-Schulz orthogonalization, we subtract\n the row-wise mean from the Nesterov momentum estimate. This projects\n the update direction onto a zero-mean subspace, acting as an implicit\n regularizer complementary to NS spectral normalization. 
GC removes\n directional bias (mean component) while NS equalizes singular values,\n producing more balanced, better-conditioned weight updates that improve\n generalization in language model pretraining.\n \"\"\"\n def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, weight_decay=0.0):\n defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, weight_decay=weight_decay)\n super().__init__(params, defaults)\n\n @staticmethod\n def zeroth_power_via_newtonschulz5(G, steps=5):\n \"\"\"Approximate G @ (G^T G)^{-1/2} via Newton-Schulz iterations.\"\"\"\n assert G.ndim == 2\n a, b, c = (3.4445, -4.7750, 2.0315)\n X = G.bfloat16()\n X = X / (X.norm() + 1e-7)\n if G.size(0) > G.size(1):\n X = X.T\n for _ in range(steps):\n A = X @ X.T\n X = a * X + b * (A @ X) + c * (A @ (A @ X))\n if G.size(0) > G.size(1):\n X = X.T\n return X\n\n @torch.no_grad()\n def step(self):\n for group in self.param_groups:\n lr = group['lr']\n momentum = group['momentum']\n wd = group.get('weight_decay', 0.0)\n for p in group['params']:\n if p.grad is None:\n continue\n # Decoupled weight decay (before update)\n if wd > 0:\n p.mul_(1 - lr * wd)\n g = p.grad\n state = self.state[p]\n if len(state) == 0:\n state['momentum_buffer'] = torch.zeros_like(g)\n buf = state['momentum_buffer']\n # EMA momentum\n buf.lerp_(g, 1.0 - momentum)\n # Nesterov momentum estimate\n nesterov_g = g.lerp(buf, momentum)\n if nesterov_g.dim() == 2:\n # Gradient Centralization: subtract row-wise mean\n # This removes the mean component, regularizing updates\n nesterov_g = nesterov_g - nesterov_g.mean(dim=1, keepdim=True)\n orig_shape = nesterov_g.shape\n # Split fused QKV (c_attn: 3*n_embd x n_embd) into 3 parts\n if orig_shape[0] == 3 * orig_shape[1]:\n parts = nesterov_g.split(orig_shape[1])\n update = torch.cat([\n self.zeroth_power_via_newtonschulz5(part, steps=group['ns_steps'])\n for part in parts\n ])\n scale = max(1, orig_shape[0] // orig_shape[1]) ** 0.5\n else:\n update = 
self.zeroth_power_via_newtonschulz5(nesterov_g, steps=group['ns_steps'])\n scale = max(1, orig_shape[0] / orig_shape[1]) ** 0.5\n p.data.add_(update.to(p.dtype), alpha=-lr * scale)\n else:\n # Fallback: plain SGD with momentum for non-2D\n p.add_(buf, alpha=-lr)\n\n fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n use_fused = fused_available and device_type == 'cuda'\n extra_args = dict(fused=True) if use_fused else dict()\n\n muon_base_lr = 0.02\n muon_lr_scale = muon_base_lr / learning_rate\n muon_opt = MuonGC([{'params': muon_params, 'lr_scale': muon_lr_scale}],\n lr=muon_base_lr, momentum=0.95, weight_decay=0.1)\n adam_groups = [\n {'params': adam_decay_params, 'weight_decay': weight_decay},\n {'params': adam_nodecay_params, 'weight_decay': 0.0},\n ]\n adam_opt = torch.optim.AdamW(adam_groups, lr=learning_rate, betas=betas, **extra_args)\n\n class CombinedOptimizer:\n \"\"\"Combines MuonGC (for projections) with AdamW (for embeddings/norms).\"\"\"\n def __init__(self, optimizers):\n self.optimizers = optimizers\n self.param_groups = []\n for opt in optimizers:\n self.param_groups.extend(opt.param_groups)\n def zero_grad(self, set_to_none=True):\n for opt in self.optimizers:\n opt.zero_grad(set_to_none=set_to_none)\n def step(self):\n for opt in self.optimizers:\n opt.step()\n def state_dict(self):\n return [opt.state_dict() for opt in self.optimizers]\n\n print(f\"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}) + AdamW combined optimizer\")\n return CombinedOptimizer([muon_opt, adam_opt])\n"
7}Tool ResultStep 1
OK: Replaced lines 171..189 with 122 line(s) in nanoGPT/custom_pretrain.py. Editable range: 171–292, 348–350.
[Current file: nanoGPT/custom_pretrain.py | editable: 171–292, 348–350 | total: 540 lines]
171 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
172 param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
173 decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]
... (116 more lines) ...
290
291 print(f"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}) + AdamW combined optimizer")
292 return CombinedOptimizer([muon_opt, adam_opt])
...
348 # CONFIG_OVERRIDES: override training hyperparameters for your method.
349 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
350 CONFIG_OVERRIDES = {}
Assistant Step 2 (edit)
Tool Input
1{
2 "content": " # CONFIG_OVERRIDES: override training hyperparameters for your method.\n # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.\n CONFIG_OVERRIDES = {'learning_rate': 3e-4}",
3 "end_line": 350,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "op": "replace",
6 "start_line": 348
7}
Tool Result (Step 2)
OK: Replaced lines 348..350 with 3 line(s) in nanoGPT/custom_pretrain.py. Editable range: 171–292, 348–350.
[Current file: nanoGPT/custom_pretrain.py | editable: 171–292, 348–350 | total: 540 lines]
171 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
172 param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
173 decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]
... (116 more lines) ...
290
291 print(f"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}) + AdamW combined optimizer")
292 return CombinedOptimizer([muon_opt, adam_opt])
...
348 # CONFIG_OVERRIDES: override training hyperparameters for your method.
349 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
350 CONFIG_OVERRIDES = {'learning_rate': 3e-4}
Assistant Step 3 (test)
Tool Result (Step 3)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=0, train_loss=11.0307, val_loss=11.0296
TRAIN_METRICS: step=1000, train_loss=3.2877, val_loss=3.0116
TRAIN_METRICS: step=2000, train_loss=3.0832, val_loss=2.7805
TRAIN_METRICS: step=3000, train_loss=2.9872, val_loss=2.7007
### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
ERROR: Checkpoint not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-optimizer/claude-opus-4.6_20260413_120858/seed_42/ckpt_gpt-345m.pt
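The MuonGC optimizer introduced in Step 1 orthogonalizes momentum with a quintic Newton-Schulz iteration. A standalone NumPy sketch of that iteration (same coefficients as in the tool input; the random test matrix is only illustrative) shows it driving all singular values into a band around 1:

```python
import numpy as np

def newton_schulz5(G, steps=5):
    # Odd quintic iteration X <- aX + b(XX^T)X + c(XX^T)^2 X that pushes
    # the singular values of X toward 1, approximating the nearest
    # semi-orthogonal matrix G (G^T G)^{-1/2}.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)   # shrink so all singular values <= 1
    transposed = G.shape[0] > G.shape[1]
    if transposed:                        # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + b * (A @ X) + c * (A @ (A @ X))
    return X.T if transposed else X

rng = np.random.default_rng(0)
G = rng.standard_normal((8, 16))
sv_after = np.linalg.svd(newton_schulz5(G, steps=10), compute_uv=False)
# sv_after clusters in a band around 1 rather than exactly at 1: the
# coefficients trade exact convergence for speed
```

In the optimizer this runs in bfloat16 on the Nesterov momentum estimate rather than on the raw gradient.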
Assistant Step 4 (edit)
Tool Input
1{
2 "content": " # CONFIG_OVERRIDES: override training hyperparameters for your method.\n # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.\n CONFIG_OVERRIDES = {}",
3 "end_line": 350,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "op": "replace",
6 "start_line": 348
7}
Tool Result (Step 4)
OK: Replaced lines 348..350 with 3 line(s) in nanoGPT/custom_pretrain.py. Editable range: 171–292, 348–350.
[Current file: nanoGPT/custom_pretrain.py | editable: 171–292, 348–350 | total: 540 lines]
171 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
172 param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
173 decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]
... (116 more lines) ...
290
291 print(f"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}) + AdamW combined optimizer")
292 return CombinedOptimizer([muon_opt, adam_opt])
...
348 # CONFIG_OVERRIDES: override training hyperparameters for your method.
349 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
350 CONFIG_OVERRIDES = {}
Assistant Step 5 (edit)
Tool Input
1{
2 "content": " # CONFIG_OVERRIDES: override training hyperparameters for your method.\n # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.\n CONFIG_OVERRIDES = {'learning_rate': 3e-4}",
3 "end_line": 350,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "op": "replace",
6 "start_line": 348
7}
Tool Result (Step 5)
OK: Replaced lines 348..350 with 3 line(s) in nanoGPT/custom_pretrain.py. Editable range: 171–292, 348–350.
[Current file: nanoGPT/custom_pretrain.py | editable: 171–292, 348–350 | total: 540 lines]
171 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
172 param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
173 decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]
... (116 more lines) ...
290
291 print(f"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}) + AdamW combined optimizer")
292 return CombinedOptimizer([muon_opt, adam_opt])
...
348 # CONFIG_OVERRIDES: override training hyperparameters for your method.
349 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
350 CONFIG_OVERRIDES = {'learning_rate': 3e-4}
Assistant Step 6 (edit)
Tool Input
1{
2 "content": " def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]\n nodecay_params = [(n, p) for n, p in param_dict.items() if p.dim() < 2]\n num_decay_params = sum(p.numel() for _, p in decay_params)\n num_nodecay_params = sum(p.numel() for _, p in nodecay_params)\n print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n\n # Separate 2D projection weights (for MuonGC) from rest (for AdamW)\n muon_params = [p for n, p in decay_params\n if 'wte' not in n and 'wpe' not in n and 'lm_head' not in n]\n adam_decay_params = [p for n, p in decay_params\n if 'wte' in n or 'wpe' in n or 'lm_head' in n]\n adam_nodecay_params = [p for _, p in nodecay_params]\n\n class MuonGC(torch.optim.Optimizer):\n \"\"\"Muon with Gradient Centralization — MomentUm Orthogonalized by\n Newton-Schulz with row-wise gradient centralization.\n\n Novel contribution: Before Newton-Schulz orthogonalization, we subtract\n the row-wise mean from the Nesterov momentum estimate. This projects\n the update direction onto a zero-mean subspace, acting as an implicit\n regularizer complementary to NS spectral normalization. 
GC removes\n directional bias (mean component) while NS equalizes singular values,\n producing more balanced, better-conditioned weight updates that improve\n generalization in language model pretraining.\n \"\"\"\n def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, weight_decay=0.0):\n defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, weight_decay=weight_decay)\n super().__init__(params, defaults)\n\n @staticmethod\n def zeroth_power_via_newtonschulz5(G, steps=5):\n \"\"\"Approximate G @ (G^T G)^{-1/2} via Newton-Schulz iterations.\"\"\"\n assert G.ndim == 2\n a, b, c = (3.4445, -4.7750, 2.0315)\n X = G.bfloat16()\n X = X / (X.norm() + 1e-7)\n if G.size(0) > G.size(1):\n X = X.T\n for _ in range(steps):\n A = X @ X.T\n X = a * X + b * (A @ X) + c * (A @ (A @ X))\n if G.size(0) > G.size(1):\n X = X.T\n return X\n\n @torch.no_grad()\n def step(self):\n for group in self.param_groups:\n lr = group['lr']\n momentum = group['momentum']\n wd = group.get('weight_decay', 0.0)\n for p in group['params']:\n if p.grad is None:\n continue\n # Decoupled weight decay (before update)\n if wd > 0:\n p.mul_(1 - lr * wd)\n g = p.grad\n state = self.state[p]\n if len(state) == 0:\n state['momentum_buffer'] = torch.zeros_like(g)\n buf = state['momentum_buffer']\n # EMA momentum\n buf.lerp_(g, 1.0 - momentum)\n # Nesterov momentum estimate\n nesterov_g = g.lerp(buf, momentum)\n if nesterov_g.dim() == 2:\n # Gradient Centralization: subtract row-wise mean\n # This removes the mean component, regularizing updates\n nesterov_g = nesterov_g - nesterov_g.mean(dim=1, keepdim=True)\n orig_shape = nesterov_g.shape\n # Split fused QKV (c_attn: 3*n_embd x n_embd) into 3 parts\n if orig_shape[0] == 3 * orig_shape[1]:\n parts = nesterov_g.split(orig_shape[1])\n update = torch.cat([\n self.zeroth_power_via_newtonschulz5(part, steps=group['ns_steps'])\n for part in parts\n ])\n scale = max(1, orig_shape[0] // orig_shape[1]) ** 0.5\n else:\n update = 
self.zeroth_power_via_newtonschulz5(nesterov_g, steps=group['ns_steps'])\n scale = max(1, orig_shape[0] / orig_shape[1]) ** 0.5\n p.data.add_(update.to(p.dtype), alpha=-lr * scale)\n else:\n # Fallback: plain SGD with momentum for non-2D\n p.add_(buf, alpha=-lr)\n\n fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n use_fused = fused_available and device_type == 'cuda'\n extra_args = dict(fused=True) if use_fused else dict()\n\n muon_base_lr = 0.02\n muon_lr_scale = muon_base_lr / learning_rate\n # Use ns_steps=3 for efficiency at 345M scale while retaining convergence quality\n muon_opt = MuonGC([{'params': muon_params, 'lr_scale': muon_lr_scale}],\n lr=muon_base_lr, momentum=0.95, ns_steps=3, weight_decay=0.1)\n adam_groups = [\n {'params': adam_decay_params, 'weight_decay': weight_decay},\n {'params': adam_nodecay_params, 'weight_decay': 0.0},\n ]\n adam_opt = torch.optim.AdamW(adam_groups, lr=learning_rate, betas=betas, **extra_args)\n\n class CombinedOptimizer:\n \"\"\"Combines MuonGC (for projections) with AdamW (for embeddings/norms).\"\"\"\n def __init__(self, optimizers):\n self.optimizers = optimizers\n self.param_groups = []\n for opt in optimizers:\n self.param_groups.extend(opt.param_groups)\n def zero_grad(self, set_to_none=True):\n for opt in self.optimizers:\n opt.zero_grad(set_to_none=set_to_none)\n def step(self):\n for opt in self.optimizers:\n opt.step()\n def state_dict(self):\n return [opt.state_dict() for opt in self.optimizers]\n\n print(f\"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}, ns=3) + AdamW combined optimizer\")\n return CombinedOptimizer([muon_opt, adam_opt])",
3 "end_line": 292,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "op": "replace",
6 "start_line": 171
7}
Tool Result (Step 6)
OK: Replaced lines 171..292 with 123 line(s) in nanoGPT/custom_pretrain.py. Editable range: 171–293, 349–351.
[Current file: nanoGPT/custom_pretrain.py | editable: 171–293, 349–351 | total: 541 lines]
171 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
172 param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
173 decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]
... (117 more lines) ...
291
292 print(f"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}, ns=3) + AdamW combined optimizer")
293 return CombinedOptimizer([muon_opt, adam_opt])
...
349 # CONFIG_OVERRIDES: override training hyperparameters for your method.
350 # Allowed keys: learning_rate, weight_decay, warmup_iters, min_lr, grad_clip.
351 CONFIG_OVERRIDES = {'learning_rate': 3e-4}
Assistant Step 7 (test)
Tool Result (Step 7)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=8000, train_loss=2.7003, val_loss=2.4058
TRAIN_METRICS: step=9000, train_loss=2.6347, val_loss=2.3439
TRAIN_METRICS: step=10000, train_loss=2.5918, val_loss=2.3018
TRAIN_METRICS: step=11000, train_loss=2.5416, val_loss=2.2488
TRAIN_METRICS: step=12000, train_loss=2.5244, val_loss=2.2344
Final metrics (gpt-345m): val_loss_gpt-345m=2.2272, wikitext2_ppl_gpt-345m=40.1800, lambada_ppl_gpt-345m=61.8800
### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
ERROR: Checkpoint not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-optimizer/claude-opus-4.6_20260413_120858/seed_42/ckpt_gpt-345m.pt
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.2272, 'wikitext2_ppl_gpt-345m': 40.18, 'lambada_ppl_gpt-345m': 61.88, 'elapsed_gpt-345m': 20772, 'elapsed_lm-eval-345m': 1}
Assistant Step 8 (edit)
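Aside from orthogonalization, the only other transform MuonGC applies is the row-wise gradient centralization described in its docstring, which reduces to a per-row mean subtraction. A minimal NumPy sketch with illustrative values:

```python
import numpy as np

# Row-wise gradient centralization: subtract each row's mean so every row
# of the resulting update sums to zero, mirroring the optimizer's
# `nesterov_g - nesterov_g.mean(dim=1, keepdim=True)` step.
def centralize_rows(g):
    return g - g.mean(axis=1, keepdims=True)

g = np.array([[1.0, 2.0, 3.0],
              [4.0, 4.0, 4.0]])
c = centralize_rows(g)
# rows of c: [-1, 0, 1] and [0, 0, 0]; each sums to zero
```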
Tool Input
1{
2 "content": " def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]\n nodecay_params = [(n, p) for n, p in param_dict.items() if p.dim() < 2]\n num_decay_params = sum(p.numel() for _, p in decay_params)\n num_nodecay_params = sum(p.numel() for _, p in nodecay_params)\n print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n\n # Separate 2D projection weights (for MuonGC) from rest (for AdamW)\n muon_params = [p for n, p in decay_params\n if 'wte' not in n and 'wpe' not in n and 'lm_head' not in n]\n adam_decay_params = [p for n, p in decay_params\n if 'wte' in n or 'wpe' in n or 'lm_head' in n]\n adam_nodecay_params = [p for _, p in nodecay_params]\n\n class MuonGC(torch.optim.Optimizer):\n \"\"\"Muon with Gradient Centralization — MomentUm Orthogonalized by\n Newton-Schulz with row-wise gradient centralization.\n\n Novel contribution: Before Newton-Schulz orthogonalization, we apply\n row-wise gradient centralization to the Nesterov momentum estimate.\n This projects updates onto a zero-mean subspace, acting as an implicit\n regularizer complementary to NS spectral normalization. 
GC removes\n directional bias (mean component) while NS equalizes singular values,\n producing more balanced, better-conditioned weight updates.\n \"\"\"\n def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, weight_decay=0.0):\n defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, weight_decay=weight_decay)\n super().__init__(params, defaults)\n\n @staticmethod\n def zeroth_power_via_newtonschulz5(G, steps=5):\n \"\"\"Approximate G @ (G^T G)^{-1/2} via Newton-Schulz iterations.\"\"\"\n assert G.ndim == 2\n a, b, c = (3.4445, -4.7750, 2.0315)\n X = G.bfloat16()\n X = X / (X.norm() + 1e-7)\n if G.size(0) > G.size(1):\n X = X.T\n for _ in range(steps):\n A = X @ X.T\n X = a * X + b * (A @ X) + c * (A @ (A @ X))\n if G.size(0) > G.size(1):\n X = X.T\n return X\n\n @torch.no_grad()\n def step(self):\n for group in self.param_groups:\n lr = group['lr']\n momentum = group['momentum']\n wd = group.get('weight_decay', 0.0)\n for p in group['params']:\n if p.grad is None:\n continue\n # Decoupled weight decay (before update)\n if wd > 0:\n p.mul_(1 - lr * wd)\n g = p.grad\n state = self.state[p]\n if len(state) == 0:\n state['momentum_buffer'] = torch.zeros_like(g)\n buf = state['momentum_buffer']\n # EMA momentum\n buf.lerp_(g, 1.0 - momentum)\n # Nesterov momentum estimate\n nesterov_g = g.lerp(buf, momentum)\n if nesterov_g.dim() == 2:\n # Gradient Centralization: subtract row-wise mean\n nesterov_g = nesterov_g - nesterov_g.mean(dim=1, keepdim=True)\n orig_shape = nesterov_g.shape\n if orig_shape[0] == 3 * orig_shape[1]:\n parts = nesterov_g.split(orig_shape[1])\n update = torch.cat([\n self.zeroth_power_via_newtonschulz5(part, steps=group['ns_steps'])\n for part in parts\n ])\n scale = max(1, orig_shape[0] // orig_shape[1]) ** 0.5\n else:\n update = self.zeroth_power_via_newtonschulz5(nesterov_g, steps=group['ns_steps'])\n scale = max(1, orig_shape[0] / orig_shape[1]) ** 0.5\n p.data.add_(update.to(p.dtype), alpha=-lr * scale)\n else:\n p.add_(buf, 
alpha=-lr)\n\n fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n use_fused = fused_available and device_type == 'cuda'\n extra_args = dict(fused=True) if use_fused else dict()\n\n muon_base_lr = 0.02\n muon_lr_scale = muon_base_lr / learning_rate\n muon_opt = MuonGC([{'params': muon_params, 'lr_scale': muon_lr_scale}],\n lr=muon_base_lr, momentum=0.95, ns_steps=5, weight_decay=0.1)\n adam_groups = [\n {'params': adam_decay_params, 'weight_decay': weight_decay},\n {'params': adam_nodecay_params, 'weight_decay': 0.0},\n ]\n adam_opt = torch.optim.AdamW(adam_groups, lr=learning_rate, betas=betas, **extra_args)\n\n class CombinedOptimizer:\n \"\"\"Combines MuonGC (for projections) with AdamW (for embeddings/norms).\"\"\"\n def __init__(self, optimizers):\n self.optimizers = optimizers\n self.param_groups = []\n for opt in optimizers:\n self.param_groups.extend(opt.param_groups)\n def zero_grad(self, set_to_none=True):\n for opt in self.optimizers:\n opt.zero_grad(set_to_none=set_to_none)\n def step(self):\n for opt in self.optimizers:\n opt.step()\n def state_dict(self):\n return [opt.state_dict() for opt in self.optimizers]\n\n print(f\"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}, ns=5) + AdamW combined optimizer\")\n return CombinedOptimizer([muon_opt, adam_opt])",
3 "end_line": 293,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "op": "replace",
6 "start_line": 171
7}
Tool Result (Step 8)
OK: Replaced lines 171..293 with 118 line(s) in nanoGPT/custom_pretrain.py. Editable range: 171–288.
[Current file: nanoGPT/custom_pretrain.py | editable: 171–288 | total: 536 lines]
171 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
172 param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
173 decay_params = [(n, p) for n, p in param_dict.items() if p.dim() >= 2]
... (112 more lines) ...
286
287 print(f"using MuonGC (lr={muon_base_lr}, scale={muon_lr_scale:.1f}, ns=5) + AdamW combined optimizer")
288 return CombinedOptimizer([muon_opt, adam_opt])
Assistant Step 9 (test)
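The CombinedOptimizer returned at line 288 exists to satisfy the interface the task requires (.zero_grad(), .step(), .param_groups). Below is a torch-free sketch of the same facade pattern; the inner StubOpt stand-ins are hypothetical, not torch.optim classes:

```python
# StubOpt is a hypothetical stand-in exposing just enough of the
# torch.optim.Optimizer surface for the facade to fan out to.
class StubOpt:
    def __init__(self, params, lr):
        self.param_groups = [{'params': params, 'lr': lr}]
        self.step_count = 0
    def zero_grad(self, set_to_none=True):
        pass  # a real optimizer would clear gradients here
    def step(self):
        self.step_count += 1

class CombinedOptimizer:
    """Fan out zero_grad()/step() to several inner optimizers and expose
    a merged param_groups list, as in the transcript's code."""
    def __init__(self, optimizers):
        self.optimizers = optimizers
        self.param_groups = []
        for opt in optimizers:
            self.param_groups.extend(opt.param_groups)
    def zero_grad(self, set_to_none=True):
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)
    def step(self):
        for opt in self.optimizers:
            opt.step()

combo = CombinedOptimizer([StubOpt(['proj.w'], 0.02), StubOpt(['wte'], 3e-4)])
combo.zero_grad()
combo.step()  # both inner optimizers advance once
```

In the transcript, the Muon group additionally carries an lr_scale entry, presumably so a schedule that writes a single learning rate into every param_group can still drive the two optimizers at different effective rates.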