Agent Conversation: qwen3.6-plus
Model: qwen3.6-plus
Total steps: 10
Messages: 22
User · Step 0
Task: llm-pretrain-kernel
LLM Pretraining: Custom GPU Kernel Optimization
Research Question
Write a custom GPU kernel (Triton or CUDA via PyTorch) to implement a fused MLP operation for GPT-2 pretraining. Your kernel should fuse multiple operations to reduce memory bandwidth and improve throughput while maintaining or improving model quality.
What You Can Modify
The fused_mlp_forward function (lines 34-48) in custom_pretrain.py:
- The MLP activation function (default: GELU via separate PyTorch ops)
- Kernel fusion strategy (fuse linear + activation, save intermediate values)
- Memory optimization (avoid materializing intermediate tensors)
- Custom autograd Functions for efficient backward pass
Note: The function signature fused_mlp_forward(x, w_fc, w_proj) must be preserved.
- x: input tensor (B*T, n_embd)
- w_fc: first linear weight (4*n_embd, n_embd)
- w_proj: second linear weight (n_embd, 4*n_embd)
- Returns: output tensor (B*T, n_embd)
The MLP class calls this function and handles dropout separately.
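For orientation, the default unfused path described above (two matmuls with a separate GELU in between) can be sketched in plain PyTorch; `reference_mlp_forward` is an illustrative name for this note, not a function from the repository:

```python
import torch
import torch.nn.functional as F

def reference_mlp_forward(x, w_fc, w_proj):
    """Unfused GPT-2 MLP baseline: linear -> GELU -> linear (illustrative)."""
    h = x @ w_fc.t()       # (B*T, 4*n_embd): intermediate tensor is materialized
    h = F.gelu(h)          # separate element-wise op, extra memory traffic
    return h @ w_proj.t()  # (B*T, n_embd)

# toy shapes: n_embd=16, so w_fc is (64, 16) and w_proj is (16, 64)
x = torch.randn(8, 16)
out = reference_mlp_forward(x, torch.randn(64, 16), torch.randn(16, 64))
print(out.shape)  # torch.Size([8, 16])
```

Fusing means replacing this three-kernel sequence (and its materialized intermediates) with fewer launches and less global-memory traffic.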
Evaluation
- Metric: Validation loss (cross-entropy, lower is better) — kernel optimizations that also change the activation function may improve loss
- Model sizes: GPT-2 124M (12L/12H/768D) and GPT-2 1.5B (48L/25H/1600D, 4-GPU DDP)
- Dataset: FineWeb 10B (GPT-2 tokenizer)
- Training: 5000 iterations, batch_size=12, block_size=1024, grad_accum=5
- Hardware: H100 GPU with Triton support
nanoGPT/custom_pretrain.py [EDITABLE — lines 33–48, lines 257–259 only]
1: """Custom GPT-2 Pretraining Script
2: Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3: """
4:
5: import math
6: import inspect
7: import os
8: import time
9: from contextlib import nullcontext
10: from dataclasses import dataclass
11:
12: import numpy as np
13: import torch
14: import torch.nn as nn
15: from torch.nn import functional as F
16:
17: # ============================================================================
18: # Model Components
19: # ============================================================================
20:
21: # ── Normalization ──────────────────────────────────────────────────────────
22: class LayerNorm(nn.Module):
23: """LayerNorm but with an optional bias."""
24: def __init__(self, ndim, bias):
25: super().__init__()
26: self.weight = nn.Parameter(torch.ones(ndim))
27: self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
28:
29: def forward(self, input):
30: return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31:
32: # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33: def fused_mlp_forward(x, w_fc, w_proj):
34: """MLP forward pass: linear -> activation -> linear.
35:
36: Default implementation…

Assistant · Step 1 · edit
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "start_line": 33,
5 "end_line": 48,
6 "content": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _swiglu_kernel(\n gate_ptr, value_ptr, out_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n g = tl.load(gate_ptr + offsets, mask=mask)\n v = tl.load(value_ptr + offsets, mask=mask)\n\n # Swish(x) = x * sigmoid(x), computed in fp32 for stability\n g_f = g.to(tl.float32)\n sig = tl.sigmoid(g_f)\n swish = g_f * sig\n\n # SwiGLU output = Swish(gate) * value\n out = (swish * v.to(tl.float32)).to(g.dtype)\n tl.store(out_ptr + offsets, out, mask=mask)\n\n\nclass _SwiGLUMLP(torch.autograd.Function):\n \"\"\"SwiGLU activation with Triton kernel for the element-wise operation.\n\n Architecture:\n - Split w_fc into gate (top 2*n_embd) and value (bottom 2*n_embd)\n - gate = x @ w_gate.T, value = x @ w_val.T\n - output = Swish(gate) * value (dimension: 2*n_embd)\n - final = output @ w_proj_half.T (dimension: n_embd)\n\n SwiGLU provides superior expressivity through multiplicative gating,\n as demonstrated in LLaMA, PaLM, and other modern LLMs.\n \"\"\"\n @staticmethod\n def forward(ctx, x, w_fc, w_proj):\n D = w_fc.shape[1] # n_embd\n\n # Split weights into gate and value\n w_fc_gate = w_fc[:2 * D, :]\n w_fc_val = w_fc[2 * D:, :]\n\n # First matmul (PyTorch optimized)\n gate = x @ w_fc_gate.t() # (M, 2D)\n value = x @ w_fc_val.t() # (M, 2D)\n\n # Apply SwiGLU via Triton kernel\n n = gate.numel()\n BLOCK = min(triton.next_power_of_2(n), 65536)\n BLOCK = max(BLOCK, 256)\n out = torch.empty_like(gate)\n _swiglu_kernel[(n + BLOCK - 1) // BLOCK,](\n gate, value, out, n, BLOCK_SIZE=BLOCK\n )\n\n # Use only the half of w_proj that matches SwiGLU's output dimension\n w_proj_half = w_proj[:, :2 * D] # (D, 2D)\n final = out @ w_proj_half.t() # (M, D)\n\n ctx.save_for_backward(x, w_fc_gate, w_fc_val, w_proj_half, gate, value, out)\n return final\n\n @staticmethod\n def backward(ctx, grad_output):\n x, w_fc_gate, w_fc_val, w_proj_half, gate, value, out = ctx.saved_tensors\n dtype = grad_output.dtype\n\n # Gradient through second linear: d_act = grad_output @ w_proj_half\n d_act = grad_output @ w_proj_half # (M, 2D)\n\n # Split gradient for gate and value branches\n d_gate = d_act # flows through Swish\n d_val = d_act # flows directly\n\n # SwiGLU backward: out = Swish(gate) * value\n # d(Swish)/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))\n gate_f = gate.float()\n sig = torch.sigmoid(gate_f)\n swish_grad = sig * (1.0 + gate_f * (1.0 - sig))\n swish_val = (gate_f * sig).to(dtype)\n\n d_gate_pre = (d_gate.float() * swish_grad).to(dtype)\n d_val_pre = d_val * swish_val\n\n # Gradient through first linear\n grad_x = d_gate_pre @ w_fc_gate + d_val_pre @ w_fc_val\n\n # Weight gradients\n grad_w_fc_gate = d_gate_pre.t() @ x\n grad_w_fc_val = d_val_pre.t() @ x\n grad_w_fc = torch.cat([grad_w_fc_gate, grad_w_fc_val], dim=0)\n\n grad_w_proj = grad_output.t() @ out\n\n return grad_x, grad_w_fc, grad_w_proj\n\n\ndef fused_mlp_forward(x, w_fc, w_proj):\n \"\"\"MLP forward pass with SwiGLU activation and Triton kernel.\n\n SwiGLU (Swish-Gated Linear Unit) splits the MLP into gated and value paths,\n applying Swish(x) = x * sigmoid(x) to the gate before element-wise multiplication.\n This provides superior representational capacity compared to GELU/ReLU².\n\n Args:\n x: input tensor (B*T, n_embd)\n w_fc: first linear weight (4*n_embd, n_embd) - split into gate+value\n w_proj: second linear weight (n_embd, 4*n_embd) - half used for projection\n Returns:\n output tensor (B*T, n_embd)\n \"\"\"\n return _SwiGLUMLP.apply(x, w_fc, w_proj)"
7}
Tool Result · Step 1
OK: Replaced lines 33..48 with 116 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–148, 357–359.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–148 | total: 549 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _swiglu_kernel(
38 gate_ptr, value_ptr, out_ptr,
39 n_elements,
40 BLOCK_SIZE: tl.constexpr,
41 ):
42 pid = tl.program_id(0)
43 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
44 mask = offsets < n_elements
45
46 g = tl.load(gate_ptr + offsets, mask=mask)
47 v = tl.load(value_ptr + offsets, mask=mask)
48
49 # Swish(x) = x * sigmoid(x), computed in fp32 for stability
50 g_f = g.to(tl.float32)
51 sig =
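Step 1's backward relies on the derivative identity d/dx[x · sigmoid(x)] = sigmoid(x) · (1 + x · (1 - sigmoid(x))). As a standalone sanity check (plain PyTorch, no Triton needed), the closed form can be compared against autograd:

```python
import torch

# compare the closed-form Swish/SiLU derivative against autograd, in float64
x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
y = x * torch.sigmoid(x)   # Swish(x) = x * sigmoid(x)
y.sum().backward()

xd = x.detach()
sig = torch.sigmoid(xd)
manual = sig * (1.0 + xd * (1.0 - sig))  # the identity used in the kernel's backward
assert torch.allclose(x.grad, manual)
```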
Assistant · Step 2 · edit
Tool Input
1{
2 "filename": "nanoGPT/custom_pretrain.py",
3 "op": "replace",
4 "start_line": 99,
5 "end_line": 131,
6 "content": " @staticmethod\n def backward(ctx, grad_output):\n x, w_fc_gate, w_fc_val, w_proj_half, gate, value, out = ctx.saved_tensors\n dtype = grad_output.dtype\n\n # Gradient through second linear: d_out = grad_output @ w_proj_half (M, 2D)\n d_out = grad_output @ w_proj_half\n\n # SwiGLU backward: out = Swish(gate) * value\n # d(L)/d(gate) = d_out * value * d(Swish)/d(gate)\n # d(L)/d(value) = d_out * Swish(gate)\n # d(Swish)/d(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))\n\n # Compute in float32 for numerical stability\n gate_f = gate.float()\n value_f = value.float()\n d_out_f = d_out.float()\n\n sig = torch.sigmoid(gate_f)\n swish_grad = sig * (1.0 + gate_f * (1.0 - sig))\n swish_val = gate_f * sig\n\n d_gate_pre = (d_out_f * value_f * swish_grad).to(dtype)\n d_val_pre = (d_out_f * swish_val).to(dtype)\n\n # Gradient through first linear\n grad_x = d_gate_pre @ w_fc_gate + d_val_pre @ w_fc_val\n\n # Weight gradients\n grad_w_fc_gate = d_gate_pre.t() @ x\n grad_w_fc_val = d_val_pre.t() @ x\n grad_w_fc = torch.cat([grad_w_fc_gate, grad_w_fc_val], dim=0)\n\n grad_w_proj = grad_output.t() @ out\n\n return grad_x, grad_w_fc, grad_w_proj"
7}
Tool Result · Step 2
OK: Replaced lines 99..131 with 36 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–151, 360–362.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–151 | total: 552 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _swiglu_kernel(
38 gate_ptr, value_ptr, out_ptr,
39 n_elements,
40 BLOCK_SIZE: tl.constexpr,
41 ):
42 pid = tl.program_id(0)
43 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
44 mask = offsets < n_elements
45
46 g = tl.load(gate_ptr + offsets, mask=mask)
47 v = tl.load(value_ptr + offsets, mask=mask)
48
49 # Swish(x) = x * sigmoid(x), computed in fp32 for stability
50 g_f = g.to(tl.float32)
51 sig =
Assistant · Step 3 · edit
Tool Input
1{
2 "filename": "nanoGPT/custom_pretrain.py",
3 "op": "replace",
4 "start_line": 99,
5 "end_line": 134,
6 "content": " @staticmethod\n def backward(ctx, grad_output):\n x, w_fc_gate, w_fc_val, w_proj_half, gate, value, out = ctx.saved_tensors\n dtype = grad_output.dtype\n D = w_fc_gate.shape[1]\n\n # Gradient through second linear: d_out = grad_output @ w_proj_half (M, 2D)\n d_out = grad_output @ w_proj_half\n\n # SwiGLU backward: out = Swish(gate) * value\n # d(L)/d(gate) = d_out * value * d(Swish)/d(gate)\n # d(L)/d(value) = d_out * Swish(gate)\n # d(Swish)/d(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))\n\n # Compute in float32 for numerical stability\n gate_f = gate.float()\n value_f = value.float()\n d_out_f = d_out.float()\n\n sig = torch.sigmoid(gate_f)\n swish_grad = sig * (1.0 + gate_f * (1.0 - sig))\n swish_val = gate_f * sig\n\n d_gate_pre = (d_out_f * value_f * swish_grad).to(dtype)\n d_val_pre = (d_out_f * swish_val).to(dtype)\n\n # Gradient through first linear\n grad_x = d_gate_pre @ w_fc_gate + d_val_pre @ w_fc_val\n\n # Weight gradients for fc (concatenate gate and value portions)\n grad_w_fc_gate = d_gate_pre.t() @ x\n grad_w_fc_val = d_val_pre.t() @ x\n grad_w_fc = torch.cat([grad_w_fc_gate, grad_w_fc_val], dim=0)\n\n # Weight gradient for proj - must match full w_proj shape (D, 4D)\n # Only first 2D columns receive gradients; rest are zero\n grad_w_proj = torch.zeros(w_fc_gate.shape[0], 4 * D, device=grad_output.device, dtype=dtype)\n grad_w_proj[:, :2 * D] = grad_output.t() @ out\n\n return grad_x, grad_w_fc, grad_w_proj"
7}
Tool Result · Step 3
OK: Replaced lines 99..134 with 40 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–155, 364–366.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–155 | total: 556 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _swiglu_kernel(
38 gate_ptr, value_ptr, out_ptr,
39 n_elements,
40 BLOCK_SIZE: tl.constexpr,
41 ):
42 pid = tl.program_id(0)
43 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
44 mask = offsets < n_elements
45
46 g = tl.load(gate_ptr + offsets, mask=mask)
47 v = tl.load(value_ptr + offsets, mask=mask)
48
49 # Swish(x) = x * sigmoid(x), computed in fp32 for stability
50 g_f = g.to(tl.float32)
51 sig =
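Step 3's key change is padding grad_w_proj with zeros so it matches the full (n_embd, 4*n_embd) parameter shape even though only the first 2*n_embd columns participate in the forward. This mirrors what autograd itself produces when a slice of a parameter is used; a small standalone check (toy sizes, names chosen for this note):

```python
import torch

D, M = 4, 8
act = torch.randn(M, 2 * D, dtype=torch.float64)  # SwiGLU output, (M, 2D)
w_proj = torch.randn(D, 4 * D, dtype=torch.float64, requires_grad=True)

# forward uses only the first 2*D columns, as in the transcript's kernel
final = act @ w_proj[:, :2 * D].t()
final.sum().backward()

# autograd zero-pads the gradient for the unused columns...
assert torch.all(w_proj.grad[:, 2 * D:] == 0)
# ...and the used columns receive grad_output.t() @ act (grad_output = ones here)
expected = torch.ones(M, D, dtype=torch.float64).t() @ act
assert torch.allclose(w_proj.grad[:, :2 * D], expected)
```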
Assistant · Step 4 · edit
Tool Input
1{
2 "filename": "nanoGPT/custom_pretrain.py",
3 "op": "replace",
4 "start_line": 59,
5 "end_line": 155,
6 "content": "class _SwiGLUMLP(torch.autograd.Function):\n \"\"\"SwiGLU activation with Triton kernel for the element-wise operation.\n\n Architecture:\n - h = x @ w_fc.t() (M, 4D) - single efficient matmul\n - gate, value = h.chunk(2) (M, 2D) each\n - output = Swish(gate) * value (M, 2D) via Triton kernel\n - final = output @ w_proj_half.t() (M, D)\n\n SwiGLU provides superior expressivity through multiplicative gating,\n as demonstrated in LLaMA, PaLM, and other modern LLMs.\n \"\"\"\n @staticmethod\n def forward(ctx, x, w_fc, w_proj):\n D = x.shape[1] # n_embd\n\n # Single matmul: h = x @ w_fc.t() (M, 4D)\n h = x @ w_fc.t()\n\n # Split into gate and value halves\n gate = h[:, :2 * D] # (M, 2D)\n value = h[:, 2 * D:] # (M, 2D)\n\n # Apply SwiGLU via Triton kernel\n n = gate.numel()\n BLOCK = min(triton.next_power_of_2(n), 65536)\n BLOCK = max(BLOCK, 256)\n out = torch.empty_like(gate)\n _swiglu_kernel[(n + BLOCK - 1) // BLOCK,](\n gate, value, out, n, BLOCK_SIZE=BLOCK\n )\n\n # Project back: only first 2D columns of w_proj used\n w_proj_half = w_proj[:, :2 * D] # (D, 2D)\n final = out @ w_proj_half.t() # (M, D)\n\n ctx.save_for_backward(x, w_fc, w_proj_half, gate, value, out)\n return final\n\n @staticmethod\n def backward(ctx, grad_output):\n x, w_fc, w_proj_half, gate, value, out = ctx.saved_tensors\n D = x.shape[1]\n\n dtype = grad_output.dtype\n\n # Gradient through second linear: d_out = grad_output @ w_proj_half (M, 2D)\n d_out = grad_output @ w_proj_half\n\n # SwiGLU backward: out = Swish(gate) * value\n # d(L)/d(gate) = d_out * value * d(Swish)/d(gate)\n # d(L)/d(value) = d_out * Swish(gate)\n # d(Swish)/d(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))\n\n gate_f = gate.float()\n value_f = value.float()\n d_out_f = d_out.float()\n\n sig = torch.sigmoid(gate_f)\n swish_grad = sig * (1.0 + gate_f * (1.0 - sig))\n swish_val = gate_f * sig\n\n d_gate = (d_out_f * value_f * swish_grad).to(dtype)\n d_val = (d_out_f * swish_val).to(dtype)\n\n # Concatenate gradients for the full h tensor\n d_h = torch.cat([d_gate, d_val], dim=-1) # (M, 4D)\n\n # Gradient through first linear\n grad_x = d_h @ w_fc\n grad_w_fc = d_h.t() @ x\n\n # Weight gradient for proj - match full w_proj shape (D, 4D)\n grad_w_proj = torch.zeros(w_proj_half.shape[0], 4 * D,\n device=grad_output.device, dtype=dtype)\n grad_w_proj[:, :2 * D] = grad_output.t() @ out\n\n return grad_x, grad_w_fc, grad_w_proj"
7}
Tool Result · Step 4
OK: Replaced lines 59..155 with 78 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–136, 345–347.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–136 | total: 537 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _swiglu_kernel(
38 gate_ptr, value_ptr, out_ptr,
39 n_elements,
40 BLOCK_SIZE: tl.constexpr,
41 ):
42 pid = tl.program_id(0)
43 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
44 mask = offsets < n_elements
45
46 g = tl.load(gate_ptr + offsets, mask=mask)
47 v = tl.load(value_ptr + offsets, mask=mask)
48
49 # Swish(x) = x * sigmoid(x), computed in fp32 for stability
50 g_f = g.to(tl.float32)
51 sig =
Assistant · Step 5 · test
Tool Result · Step 5
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
[COMMAND FAILED exit=-1]
[output file not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/logs/llm-pretrain-kernel/qwen3.6-plus_20260412_102247/20260412_103220/group_1/gpt-345m_s42.out]
### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
ERROR: Checkpoint not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-kernel/qwen3.6-plus_20260412_102247/seed_42/ckpt_gpt-345m.pt
Assistant · Step 6 · edit
Tool Input
1{
2 "filename": "nanoGPT/custom_pretrain.py",
3 "op": "replace",
4 "start_line": 33,
5 "end_line": 136,
6 "content": "class _SwiGLUMLP(torch.autograd.Function):\n \"\"\"SwiGLU MLP: h=x@w_fc.t(), gate,value=h.split(2), out=Swish(gate)*value, final=out@w_proj_half.t()\n SwiGLU (Swish-Gated Linear Unit) provides superior expressivity via multiplicative gating.\"\"\"\n\n @staticmethod\n def forward(ctx, x, w_fc, w_proj):\n D = x.shape[1]\n h = torch.matmul(x, w_fc.t())\n gate = h[:, :2 * D]\n value = h[:, 2 * D:]\n\n # Swish(gate) = gate * sigmoid(gate)\n sig = torch.sigmoid(gate)\n swish = gate * sig\n out = swish * value\n\n w_proj_half = w_proj[:, :2 * D]\n final = torch.matmul(out, w_proj_half.t())\n\n ctx.save_for_backward(x, w_fc, w_proj_half, gate, value, out)\n return final\n\n @staticmethod\n def backward(ctx, grad_output):\n x, w_fc, w_proj_half, gate, value, out = ctx.saved_tensors\n D = x.shape[1]\n\n d_out = torch.matmul(grad_output, w_proj_half)\n\n gate_f = gate.float()\n value_f = value.float()\n d_out_f = d_out.float()\n\n sig = torch.sigmoid(gate_f)\n swish_grad = sig * (1.0 + gate_f * (1.0 - sig))\n swish_val = gate_f * sig\n\n d_gate = (d_out_f * value_f * swish_grad).to(d_out.dtype)\n d_val = (d_out_f * swish_val).to(d_out.dtype)\n\n d_h = torch.cat([d_gate, d_val], dim=-1)\n\n grad_x = torch.matmul(d_h, w_fc)\n grad_w_fc = torch.matmul(d_h.t(), x)\n\n grad_w_proj = torch.empty(w_proj_half.shape[0], 4 * D,\n device=grad_output.device, dtype=grad_output.dtype)\n grad_w_proj.zero_()\n grad_w_proj[:, :2 * D] = torch.matmul(grad_output.t(), out)\n\n return grad_x, grad_w_fc, grad_w_proj\n\n\ndef fused_mlp_forward(x, w_fc, w_proj):\n \"\"\"MLP forward pass with SwiGLU activation.\n\n SwiGLU (Swish-Gated Linear Unit) splits the MLP into gated and value paths,\n applying Swish(x) = x * sigmoid(x) to the gate before element-wise multiplication.\n This provides superior representational capacity compared to GELU/ReLU^2.\n\n Args:\n x: input tensor (B*T, n_embd)\n w_fc: first linear weight (4*n_embd, n_embd) - split into gate+value\n w_proj: second linear weight (n_embd, 4*n_embd) - first 2*n_embd cols used\n Returns:\n output tensor (B*T, n_embd)\n \"\"\"\n return _SwiGLUMLP.apply(x, w_fc, w_proj)"
7}
Tool Result · Step 6
OK: Replaced lines 33..136 with 68 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–100, 309–311.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–100 | total: 501 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 class _SwiGLUMLP(torch.autograd.Function):
34 """SwiGLU MLP: h=x@w_fc.t(), gate,value=h.split(2), out=Swish(gate)*value, final=out@w_proj_half.t()
35 SwiGLU (Swish-Gated Linear Unit) provides superior expressivity via multiplicative gating."""
36
37 @staticmethod
38 def forward(ctx, x, w_fc, w_proj):
39 D = x.shape[1]
40 h = torch.matmul(x, w_fc.t())
41 gate = h[:, :2 * D]
42 value = h[:, 2 * D:]
43
44 # Swish(gate) = gate * sigmoid(gate)
45 sig = torch.sigmoid(gate)
46 swi
Assistant · Step 7 · edit
Tool Input
1{
2 "content": "class _SwiGLUMLP(torch.autograd.Function):\n \"\"\"SwiGLU MLP with custom autograd for fused activation.\n h = x @ w_fc.t(), gate,value = split(h), out = Swish(gate) * value,\n final = out @ w_proj_half.t()\n \"\"\"\n\n @staticmethod\n def forward(ctx, x, w_fc, w_proj):\n D = x.shape[1]\n # h: (M, 4D)\n h = x @ w_fc.t()\n gate, value = h.split([2 * D, 2 * D], dim=-1)\n # Swish(x) = x * sigmoid(x), computed in fp32 for stability\n gate_f = gate.float()\n swish = gate_f * torch.sigmoid(gate_f)\n out = (swish * value.float()).to(h.dtype)\n # Use first 2D cols of w_proj\n w_proj_half = w_proj[:, :2 * D]\n final = out @ w_proj_half.t()\n ctx.save_for_backward(x, w_fc, w_proj_half, gate, value, out)\n return final\n\n @staticmethod\n def backward(ctx, grad_output):\n x, w_fc, w_proj_half, gate, value, out = ctx.saved_tensors\n D = x.shape[1]\n d_out = grad_output @ w_proj_half # (M, 2D)\n # d(Swish)/dx = sigmoid(x)*(1 + x*(1-sigmoid(x)))\n gate_f, value_f, d_out_f = gate.float(), value.float(), d_out.float()\n sig = torch.sigmoid(gate_f)\n swish_grad = sig * (1.0 + gate_f * (1.0 - sig))\n swish_val = gate_f * sig\n d_gate = (d_out_f * value_f * swish_grad).to(d_out.dtype)\n d_val = (d_out_f * swish_val).to(d_out.dtype)\n d_h = torch.cat([d_gate, d_val], dim=-1)\n grad_x = d_h @ w_fc\n grad_w_fc = d_h.t() @ x\n grad_w_proj = torch.zeros(w_proj_half.shape[0], 4 * D,\n device=grad_output.device, dtype=grad_output.dtype)\n grad_w_proj[:, :2 * D] = grad_output.t() @ out\n return grad_x, grad_w_fc, grad_w_proj\n\n\ndef fused_mlp_forward(x, w_fc, w_proj):\n \"\"\"MLP forward with SwiGLU activation.\n Swish-gated linear unit splits 4*n_embd into gate+value (2*n_embd each),\n applies Swish(x)=x*sigmoid(x) to gate then element-wise multiplies with value.\n Proven superior to GELU/ReLU^2 in LLaMA/PaLM/Modern LLMs.\n \"\"\"\n return _SwiGLUMLP.apply(x, w_fc, w_proj)",
3 "end_line": 100,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "start_line": 33,
6 "op": "replace"
7}
Tool Result · Step 7
OK: Replaced lines 33..100 with 50 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–82, 291–293.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–82 | total: 483 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 class _SwiGLUMLP(torch.autograd.Function):
34 """SwiGLU MLP with custom autograd for fused activation.
35 h = x @ w_fc.t(), gate,value = split(h), out = Swish(gate) * value,
36 final = out @ w_proj_half.t()
37 """
38
39 @staticmethod
40 def forward(ctx, x, w_fc, w_proj):
41 D = x.shape[1]
42 # h: (M, 4D)
43 h = x @ w_fc.t()
44 gate, value = h.split([2 * D, 2 * D], dim=-1)
45 # Swish(x) = x * sigmoid(x), computed in fp32 for stability
46 gate_f = gate.float()
47 sw
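Step 7's hand-written backward derives d_gate and d_val from d_out with closed-form expressions. The element-wise part of those formulas can be cross-checked against autograd in isolation; a standalone sketch in float64:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gate = torch.randn(8, 6, dtype=torch.float64, requires_grad=True)
value = torch.randn(8, 6, dtype=torch.float64, requires_grad=True)
d_out = torch.randn(8, 6, dtype=torch.float64)

# autograd reference for out = Swish(gate) * value
(F.silu(gate) * value).backward(d_out)

# manual formulas, as in the transcript's backward()
g, v = gate.detach(), value.detach()
sig = torch.sigmoid(g)
d_gate = d_out * v * sig * (1.0 + g * (1.0 - sig))  # d_out * value * Swish'(gate)
d_val = d_out * (g * sig)                            # d_out * Swish(gate)

assert torch.allclose(gate.grad, d_gate)
assert torch.allclose(value.grad, d_val)
```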
Assistant · Step 8 · test
Tool Result · Step 8
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.6691, val_loss=2.3907
TRAIN_METRICS: step=10000, train_loss=2.6482, val_loss=2.3635
TRAIN_METRICS: step=11000, train_loss=2.6314, val_loss=2.3553
TRAIN_METRICS: step=12000, train_loss=2.6133, val_loss=2.3395
TRAIN_METRICS: step=13000, train_loss=2.5983, val_loss=2.3286
Final metrics (gpt-345m): val_loss_gpt-345m=2.3225, wikitext2_ppl_gpt-345m=45.5400, lambada_ppl_gpt-345m=70.1700
### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=54.3400, hellaswag_lm-eval-345m=32.6400, piqa_lm-eval-345m=64.0400, winogrande_lm-eval-345m=50.4300
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.3225, 'wikitext2_ppl_gpt-345m': 45.54, 'lambada_ppl_gpt-345m': 70.17, 'elapsed_gpt-345m': 17943, 'arc_easy_lm-eval-345m': 54.34, 'hellaswag_lm-eval-345m': 32.64, 'piqa
Assistant · Step 9 · edit
Tool Input
1{
2 "content": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _silu_activation_kernel(\n x_ptr, out_ptr, n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n xf = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)\n sig = tl.sigmoid(xf)\n out = (xf * sig).to(tl.load(x_ptr + offsets, mask=mask).dtype)\n tl.store(out_ptr + offsets, out, mask=mask)\n\n\n@triton.jit\ndef _silu_backward_kernel(\n grad_ptr, x_ptr, out_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n \"\"\"d(L)/dx = d(L)/d(silu) * sigmoid(x) * (1 + x*(1-sigmoid(x)))\"\"\"\n pid = tl.program_id(0)\n offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n xf = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)\n grad_f = tl.load(grad_ptr + offsets, mask=mask).to(tl.float32)\n sig = tl.sigmoid(xf)\n silu_grad = sig * (1.0 + xf * (1.0 - sig))\n dx = (grad_f * silu_grad).to(tl.load(x_ptr + offsets, mask=mask).dtype)\n tl.store(out_ptr + offsets, dx, mask=mask)\n\n\nclass _SiLUMMLP(torch.autograd.Function):\n \"\"\"SiLU (Swish) activation with Triton kernel for element-wise operation.\n SiLU(x) = x * sigmoid(x) - smooth non-linearity superior to GELU in many cases.\n Uses all parameters (4*n_embd -> 4*n_embd -> n_embd), unlike SwiGLU variants.\n \"\"\"\n\n @staticmethod\n def forward(ctx, x, w_fc, w_proj):\n h = x @ w_fc.t() # (M, 4D)\n act = torch.empty_like(h)\n n = h.numel()\n BLOCK = min(triton.next_power_of_2(n), 65536)\n BLOCK = max(BLOCK, 256)\n _silu_activation_kernel[(n + BLOCK - 1) // BLOCK,](\n h, act, n, BLOCK_SIZE=BLOCK\n )\n final = act @ w_proj.t() # (M, D)\n ctx.save_for_backward(x, w_fc, w_proj, h, act)\n return final\n\n @staticmethod\n def backward(ctx, grad_output):\n x, w_fc, w_proj, h, act = ctx.saved_tensors\n d_act = grad_output @ w_proj # (M, 4D)\n dtype = grad_output.dtype\n\n # SiLU gradient via Triton kernel\n d_h = torch.empty_like(d_act)\n n = d_act.numel()\n BLOCK = min(triton.next_power_of_2(n), 65536)\n BLOCK = max(BLOCK, 256)\n _silu_backward_kernel[(n + BLOCK - 1) // BLOCK,](\n d_act, h, d_h, n, BLOCK_SIZE=BLOCK\n )\n\n grad_x = d_h @ w_fc\n grad_w_fc = d_h.t() @ x\n grad_w_proj = grad_output.t() @ act\n return grad_x, grad_w_fc, grad_w_proj\n\n\ndef fused_mlp_forward(x, w_fc, w_proj):\n \"\"\"MLP forward with SiLU (Swish) activation via custom Triton kernel.\n SiLU(x) = x * sigmoid(x) provides smooth gradients and has been shown\n to outperform GELU in modern LLM architectures (e.g., LLaMA).\n Full parameter utilization: (M,4D) -> SiLU -> (M,4D) -> linear -> (M,D).\n \"\"\"\n return _SiLUMMLP.apply(x, w_fc, w_proj)",
3 "end_line": 82,
4 "filename": "nanoGPT/custom_pretrain.py",
5 "start_line": 33,
6 "op": "replace"
7}
Tool Result · Step 9
OK: Replaced lines 33..82 with 83 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _silu_activation_kernel(
38 x_ptr, out_ptr, n_elements,
39 BLOCK_SIZE: tl.constexpr,
40 ):
41 pid = tl.program_id(0)
42 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
43 mask = offsets < n_elements
44 xf = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
45 sig = tl.sigmoid(xf)
46 out = (xf * sig).to(tl.load(x_ptr + offsets, mask=mask).dtype)
47 tl.store(out_ptr + offsets, out, mask=mask)
48
49
50 @triton.jit
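Step 9 switches to a full-parameter SiLU MLP (no gating; all 4*n_embd hidden units and the full w_proj are used). Its math reduces to the plain composition below, which `torch.autograd.gradcheck` can validate end-to-end; this is an editorial reference sketch, not code from the run:

```python
import torch
import torch.nn.functional as F

def silu_mlp_reference(x, w_fc, w_proj):
    """Same math as Step 9's _SiLUMMLP, left entirely to autograd."""
    return F.silu(x @ w_fc.t()) @ w_proj.t()

# gradcheck requires float64 inputs; toy n_embd=4
x = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
w_fc = torch.randn(16, 4, dtype=torch.float64, requires_grad=True)
w_proj = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(silu_mlp_reference, (x, w_fc, w_proj))
```

A custom autograd.Function implementing the same forward should produce gradients matching this reference to within mixed-precision tolerance.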