Agent Conversation: gemini-3.1-pro-preview
Model: gemini-3.1-pro-preview
Total steps: 6
Messages: 14
User (Step 0)
Task: llm-pretrain-kernel
LLM Pretraining: Custom GPU Kernel Optimization
Research Question
Write a custom GPU kernel (Triton or CUDA via PyTorch) to implement a fused MLP operation for GPT-2 pretraining. Your kernel should fuse multiple operations to reduce memory bandwidth and improve throughput while maintaining or improving model quality.
What You Can Modify
The fused_mlp_forward function (lines 34-48) in custom_pretrain.py:
- The MLP activation function (default: GELU via separate PyTorch ops)
- Kernel fusion strategy (fuse linear + activation, save intermediate values)
- Memory optimization (avoid materializing intermediate tensors)
- Custom autograd Functions for efficient backward pass
Note: The function signature fused_mlp_forward(x, w_fc, w_proj) must be preserved.
- x: input tensor (B*T, n_embd)
- w_fc: first linear weight (4*n_embd, n_embd)
- w_proj: second linear weight (n_embd, 4*n_embd)
- Returns: output tensor (B*T, n_embd)
The MLP class calls this function and handles dropout separately.
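The baseline the task describes (two bias-free linears with a GELU between them, intermediate tensor fully materialized) can be sketched in plain PyTorch; the toy sizes below are illustrative, not the benchmark's:

```python
import torch
import torch.nn.functional as F

def fused_mlp_forward(x, w_fc, w_proj):
    """Baseline MLP: linear -> GELU -> linear (no biases, no fusion)."""
    h = F.gelu(x @ w_fc.t())   # (B*T, 4*n_embd) intermediate is materialized
    return h @ w_proj.t()      # back to (B*T, n_embd)

# toy shapes: n_embd = 8, so w_fc is (32, 8) and w_proj is (8, 32)
x = torch.randn(4, 8)
out = fused_mlp_forward(x, torch.randn(32, 8), torch.randn(8, 32))
print(out.shape)  # torch.Size([4, 8])
```

Fusing the activation into a kernel avoids writing and re-reading that `(B*T, 4*n_embd)` intermediate, which is the bandwidth the task asks to save.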
Evaluation
- Metric: Validation loss (cross-entropy, lower is better) — kernel optimizations that also change the activation function may improve loss
- Model sizes: GPT-2 124M (12L/12H/768D) and GPT-2 1.5B (48L/25H/1600D, 4-GPU DDP)
- Dataset: FineWeb 10B (GPT-2 tokenizer)
- Training: 5000 iterations, batch_size=12, block_size=1024, grad_accum=5
- Hardware: H100 GPU with Triton support
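The token budget implied by these training settings is simple arithmetic (per GPU; the 1.5B run multiplies throughput by 4-way DDP):

```python
# evaluation settings from the task brief
batch_size, block_size, grad_accum, iters = 12, 1024, 5, 5000

tokens_per_opt_step = batch_size * block_size * grad_accum
total_tokens = tokens_per_opt_step * iters
print(tokens_per_opt_step)  # 61440
print(total_tokens)         # 307200000, a small fraction of the 10B-token FineWeb shard
```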
nanoGPT/custom_pretrain.py [EDITABLE — lines 33–48, lines 257–259 only]
1: """Custom GPT-2 Pretraining Script
2: Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
3: """
4:
5: import math
6: import inspect
7: import os
8: import time
9: from contextlib import nullcontext
10: from dataclasses import dataclass
11:
12: import numpy as np
13: import torch
14: import torch.nn as nn
15: from torch.nn import functional as F
16:
17: # ============================================================================
18: # Model Components
19: # ============================================================================
20:
21: # ── Normalization ──────────────────────────────────────────────────────────
22: class LayerNorm(nn.Module):
23: """LayerNorm but with an optional bias."""
24: def __init__(self, ndim, bias):
25: super().__init__()
26: self.weight = nn.Parameter(torch.ones(ndim))
27: self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
28:
29: def forward(self, input):
30: return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31:
32: # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33: def fused_mlp_forward(x, w_fc, w_proj):
34: """MLP forward pass: linear -> activation -> linear.
35:
36: Default implementatio…

Assistant (Step 1): edit
Tool Input
{
  "content": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _interleaved_sym_swiglu_fwd_kernel(\n    h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr\n):\n    pid = tl.program_id(0)\n    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_pairs\n    idx1 = 2 * offsets\n    idx2 = 2 * offsets + 1\n    h1_load = tl.load(h_ptr + idx1, mask=mask)\n    h2_load = tl.load(h_ptr + idx2, mask=mask)\n    h1 = h1_load.to(tl.float32)\n    h2 = h2_load.to(tl.float32)\n    sig_h1 = 1.0 / (1.0 + tl.exp(-h1))\n    sig_h2 = 1.0 / (1.0 + tl.exp(-h2))\n    silu_h1 = h1 * sig_h1\n    silu_h2 = h2 * sig_h2\n    act1 = h1 * silu_h2\n    act2 = h2 * silu_h1\n    tl.store(act_ptr + idx1, act1.to(h1_load.dtype), mask=mask)\n    tl.store(act_ptr + idx2, act2.to(h2_load.dtype), mask=mask)\n\n@triton.jit\ndef _interleaved_sym_swiglu_bwd_kernel(\n    dact_ptr, h_ptr, dh_ptr, n_pairs, BLOCK_SIZE: tl.constexpr\n):\n    pid = tl.program_id(0)\n    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_pairs\n    idx1 = 2 * offsets\n    idx2 = 2 * offsets + 1\n    dact1_load = tl.load(dact_ptr + idx1, mask=mask)\n    dact2_load = tl.load(dact_ptr + idx2, mask=mask)\n    h1_load = tl.load(h_ptr + idx1, mask=mask)\n    h2_load = tl.load(h_ptr + idx2, mask=mask)\n    dact1 = dact1_load.to(tl.float32)\n    dact2 = dact2_load.to(tl.float32)\n    h1 = h1_load.to(tl.float32)\n    h2 = h2_load.to(tl.float32)\n    sig_h1 = 1.0 / (1.0 + tl.exp(-h1))\n    sig_h2 = 1.0 / (1.0 + tl.exp(-h2))\n    silu_h1 = h1 * sig_h1\n    silu_h2 = h2 * sig_h2\n    silu_prime_h1 = sig_h1 * (1.0 + h1 * (1.0 - sig_h1))\n    silu_prime_h2 = sig_h2 * (1.0 + h2 * (1.0 - sig_h2))\n    dh1 = dact1 * silu_h2 + dact2 * (h2 * silu_prime_h1)\n    dh2 = dact2 * silu_h1 + dact1 * (h1 * silu_prime_h2)\n    tl.store(dh_ptr + idx1, dh1.to(h1_load.dtype), mask=mask)\n    tl.store(dh_ptr + idx2, dh2.to(h2_load.dtype), mask=mask)\n\nclass SymSwiGLU_MLP(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, w_fc, w_proj):\n        h = (x @ w_fc.t()).contiguous()\n        act = torch.empty_like(h)\n        n_pairs = h.numel() // 2\n        BLOCK = 1024\n        grid = ((n_pairs + BLOCK - 1) // BLOCK,)\n        _interleaved_sym_swiglu_fwd_kernel[grid](h, act, n_pairs, BLOCK_SIZE=BLOCK)\n        out = act @ w_proj.t()\n        ctx.save_for_backward(x, w_fc, w_proj, h, act)\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x, w_fc, w_proj, h, act = ctx.saved_tensors\n        dtype = grad_output.dtype\n        d_act = (grad_output @ w_proj.to(dtype)).contiguous()\n        dh = torch.empty_like(h)\n        n_pairs = h.numel() // 2\n        BLOCK = 1024\n        grid = ((n_pairs + BLOCK - 1) // BLOCK,)\n        _interleaved_sym_swiglu_bwd_kernel[grid](d_act, h, dh, n_pairs, BLOCK_SIZE=BLOCK)\n        grad_w_proj = grad_output.reshape(-1, grad_output.shape[-1]).t() @ act.to(dtype).reshape(-1, act.shape[-1])\n        grad_x = dh @ w_fc.to(dtype)\n        grad_w_fc = dh.reshape(-1, dh.shape[-1]).t() @ x.to(dtype).reshape(-1, x.shape[-1])\n        return grad_x, grad_w_fc, grad_w_proj\n\ndef fused_mlp_forward(x, w_fc, w_proj):\n    return SymSwiGLU_MLP.apply(x, w_fc, w_proj)",
  "start_line": 33,
  "end_line": 48,
  "op": "replace",
  "filename": "nanoGPT/custom_pretrain.py"
}

Tool Result (Step 1)
OK: Replaced lines 33..48 with 83 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _interleaved_sym_swiglu_fwd_kernel(
38 h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr
39 ):
40 pid = tl.program_id(0)
41 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
42 mask = offsets < n_pairs
43 idx1 = 2 * offsets
44 idx2 = 2 * offsets + 1
45 h1_load = tl.load(h_ptr + idx1, mask=mask)
46 h2_load = tl.load(h_ptr + idx2, mask=mask)
47 h1 = h1_load.to(tl.float32)
48 h2 = h2_load.to(tl.float32)
49 sig_h1 = 1.
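For reference, the activation the forward kernel computes over interleaved (even, odd) channel pairs, act1 = h1 * silu(h2) and act2 = h2 * silu(h1), has a direct eager-mode equivalent; this sketch is only for checking the kernel and is not part of the submitted file:

```python
import torch
import torch.nn.functional as F

def interleaved_sym_swiglu_ref(h):
    """Eager reference for _interleaved_sym_swiglu_fwd_kernel: channel 2i is
    gated by the SiLU of channel 2i+1, and vice versa."""
    h1, h2 = h[..., 0::2], h[..., 1::2]
    out = torch.empty_like(h)
    out[..., 0::2] = h1 * F.silu(h2)
    out[..., 1::2] = h2 * F.silu(h1)
    return out

# one pair: act1 = 1*silu(2), act2 = 2*silu(1)
print(interleaved_sym_swiglu_ref(torch.tensor([[1.0, 2.0]])))  # tensor([[1.7616, 1.4621]])
```

Unlike a standard SwiGLU, this keeps the hidden width unchanged (no separate gate projection), so the `(4*n_embd, n_embd)` weight shapes in the task brief still fit.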
Assistant (Step 2): test
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
[COMMAND FAILED exit=-1]
[output file not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/logs/llm-pretrain-kernel/gemini-3.1-pro-preview_20260412_102249/20260412_103127/group_1/gpt-345m_s42.out]
### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
ERROR: Checkpoint not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-kernel/gemini-3.1-pro-preview_20260412_102249/seed_42/ckpt_gpt-345m.pt
Assistant (Step 3): edit
Tool Input
{
  "start_line": 49,
  "op": "replace",
  "end_line": 50,
  "filename": "nanoGPT/custom_pretrain.py",
  "content": " sig_h1 = tl.sigmoid(h1)\n sig_h2 = tl.sigmoid(h2)"
}

Tool Result (Step 3)
OK: Replaced lines 49..50 with 2 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _interleaved_sym_swiglu_fwd_kernel(
38 h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr
39 ):
40 pid = tl.program_id(0)
41 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
42 mask = offsets < n_pairs
43 idx1 = 2 * offsets
44 idx2 = 2 * offsets + 1
45 h1_load = tl.load(h_ptr + idx1, mask=mask)
46 h2_load = tl.load(h_ptr + idx2, mask=mask)
47 h1 = h1_load.to(tl.float32)
48 h2 = h2_load.to(tl.float32)
49 sig_h1 = tl.
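This edit swaps the hand-written `1.0 / (1.0 + tl.exp(-h1))` for Triton's built-in `tl.sigmoid` in the forward kernel (the backward kernel gets the same change in the next step). The two expressions are mathematically identical, which a quick eager-mode check confirms; `torch.sigmoid` stands in for `tl.sigmoid` here, and whether the built-in is what fixed the failed run is not shown in the transcript:

```python
import torch

h = torch.linspace(-30.0, 30.0, steps=201, dtype=torch.float32)
manual = 1.0 / (1.0 + torch.exp(-h))   # formulation used in the first edit
builtin = torch.sigmoid(h)             # what tl.sigmoid computes
assert torch.allclose(manual, builtin, atol=1e-7)
print("sigmoid formulations agree")
```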
Assistant (Step 4): edit
Tool Input
{
  "filename": "nanoGPT/custom_pretrain.py",
  "op": "replace",
  "end_line": 76,
  "start_line": 75,
  "content": " sig_h1 = tl.sigmoid(h1)\n sig_h2 = tl.sigmoid(h2)"
}

Tool Result (Step 4)
OK: Replaced lines 75..76 with 2 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.
[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
30| return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31|
32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
33 import triton
34 import triton.language as tl
35
36 @triton.jit
37 def _interleaved_sym_swiglu_fwd_kernel(
38 h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr
39 ):
40 pid = tl.program_id(0)
41 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
42 mask = offsets < n_pairs
43 idx1 = 2 * offsets
44 idx2 = 2 * offsets + 1
45 h1_load = tl.load(h_ptr + idx1, mask=mask)
46 h2_load = tl.load(h_ptr + idx2, mask=mask)
47 h1 = h1_load.to(tl.float32)
48 h2 = h2_load.to(tl.float32)
49 sig_h1 = tl.
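With the same fix applied to the backward kernel, the hand-derived pair gradients can be sanity-checked against autograd before rerunning training; this is a standalone float64 sketch mirroring the formulas in `_interleaved_sym_swiglu_bwd_kernel`, not part of the edited file:

```python
import torch
from torch.nn.functional import silu

torch.manual_seed(0)
h1 = torch.randn(16, dtype=torch.float64, requires_grad=True)
h2 = torch.randn(16, dtype=torch.float64, requires_grad=True)
d1 = torch.randn(16, dtype=torch.float64)   # upstream grads for act1
d2 = torch.randn(16, dtype=torch.float64)   # upstream grads for act2

# forward as in the kernel: act1 = h1*silu(h2), act2 = h2*silu(h1)
(d1 * h1 * silu(h2) + d2 * h2 * silu(h1)).sum().backward()

# hand-written gradients, term-for-term the backward kernel's formulas
sig1, sig2 = torch.sigmoid(h1), torch.sigmoid(h2)
sp1 = sig1 * (1 + h1 * (1 - sig1))          # silu'(h1)
sp2 = sig2 * (1 + h2 * (1 - sig2))          # silu'(h2)
dh1 = d1 * (h2 * sig2) + d2 * (h2 * sp1)    # silu(h2) == h2*sig2
dh2 = d2 * (h1 * sig1) + d1 * (h1 * sp2)

assert torch.allclose(h1.grad, dh1.detach())
assert torch.allclose(h2.grad, dh2.detach())
print("backward formulas match autograd")
```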
Assistant (Step 5): test
Tool Result (Step 5)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.6427, val_loss=2.3681
TRAIN_METRICS: step=10000, train_loss=2.6222, val_loss=2.3410
TRAIN_METRICS: step=11000, train_loss=2.6051, val_loss=2.3330
TRAIN_METRICS: step=12000, train_loss=2.5864, val_loss=2.3163
TRAIN_METRICS: step=13000, train_loss=2.5713, val_loss=2.3062
Final metrics (gpt-345m): val_loss_gpt-345m=2.3002, wikitext2_ppl_gpt-345m=44.8200, lambada_ppl_gpt-345m=68.4700
### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=54.2900, hellaswag_lm-eval-345m=32.8300, piqa_lm-eval-345m=64.1500, winogrande_lm-eval-345m=51.6200
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.3002, 'wikitext2_ppl_gpt-345m': 44.82, 'lambada_ppl_gpt-345m': 68.47, 'elapsed_gpt-345m': 19910, 'arc_easy_lm-eval-345m': 54.29, 'hellaswag_lm-eval-345m': 32.83, 'piqa
Assistant (Step 6): submit
Tool Input
{
  "n": 2
}

Tool Result (Step 6)
[submit] Submitting result from test #2 as final.
### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.6427, val_loss=2.3681
TRAIN_METRICS: step=10000, train_loss=2.6222, val_loss=2.3410
TRAIN_METRICS: step=11000, train_loss=2.6051, val_loss=2.3330
TRAIN_METRICS: step=12000, train_loss=2.5864, val_loss=2.3163
TRAIN_METRICS: step=13000, train_loss=2.5713, val_loss=2.3062
Final metrics (gpt-345m): val_loss_gpt-345m=2.3002, wikitext2_ppl_gpt-345m=44.8200, lambada_ppl_gpt-345m=68.4700
### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=54.2900, hellaswag_lm-eval-345m=32.8300, piqa_lm-eval-345m=64.1500, winogrande_lm-eval-345m=51.6200
[submit] Finalized test #2 as final.
[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.3002, 'wikitext2_ppl_gpt-345m': 44.82, 'lambada_ppl_gpt-345m': 68.47, 'elapsed_gpt-345m': 19910, 'arc_easy_lm-eval-345m': 54.29, 'hellaswag_lm-eval-345m': 32.83, 'piqa_lm-ev