Agent Conversation: gemini-3.1-pro-preview

Model: gemini-3.1-pro-preview
Total steps: 6
Messages: 14
User (Step 0)

Task: llm-pretrain-kernel

LLM Pretraining: Custom GPU Kernel Optimization

Research Question

Write a custom GPU kernel (Triton or CUDA via PyTorch) to implement a fused MLP operation for GPT-2 pretraining. Your kernel should fuse multiple operations to reduce memory bandwidth and improve throughput while maintaining or improving model quality.

What You Can Modify

The fused_mlp_forward function (lines 33–48) in custom_pretrain.py:

  • The MLP activation function (default: GELU via separate PyTorch ops)
  • Kernel fusion strategy (fuse linear + activation, save intermediate values)
  • Memory optimization (avoid materializing intermediate tensors)
  • Custom autograd Functions for efficient backward pass

Note: The function signature fused_mlp_forward(x, w_fc, w_proj) must be preserved.

  • x: input tensor (B*T, n_embd)
  • w_fc: first linear weight (4*n_embd, n_embd)
  • w_proj: second linear weight (n_embd, 4*n_embd)
  • Returns: output tensor (B*T, n_embd)

The MLP class calls this function and handles dropout separately.
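The default (unfused) path this signature describes is linear → GELU → linear. A minimal NumPy stand-in, purely to pin down shapes and semantics (the function name `fused_mlp_forward_ref` and the tanh-approximate GELU are illustrative, not the repo's exact code):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU, as used in GPT-2
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def fused_mlp_forward_ref(x, w_fc, w_proj):
    # x: (B*T, n_embd); w_fc: (4*n_embd, n_embd); w_proj: (n_embd, 4*n_embd)
    h = x @ w_fc.T           # expand to (B*T, 4*n_embd)
    act = gelu(h)            # elementwise activation
    return act @ w_proj.T    # project back to (B*T, n_embd)
```

A fused kernel replaces the three separate ops with fewer global-memory round trips while keeping this same input/output contract.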

Evaluation

  • Metric: Validation loss (cross-entropy, lower is better) — kernel optimizations that also change the activation function may improve loss
  • Model sizes: GPT-2 124M (12L/12H/768D) and GPT-2 1.5B (48L/25H/1600D, 4-GPU DDP)
  • Dataset: FineWeb 10B (GPT-2 tokenizer)
  • Training: 5000 iterations, batch_size=12, block_size=1024, grad_accum=5
  • Hardware: H100 GPU with Triton support
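The training settings above imply a fixed token budget; a quick back-of-the-envelope check (per process — the 1.5B 4-GPU DDP run would scale this by world size):

```python
# Tokens consumed per optimizer step and over the full run,
# from the stated hyperparameters.
batch_size, block_size, grad_accum = 12, 1024, 5
iterations = 5000

tokens_per_iter = batch_size * block_size * grad_accum
total_tokens = tokens_per_iter * iterations
print(tokens_per_iter)  # 61440
print(total_tokens)     # 307200000, i.e. ~0.3B of the 10B-token corpus
```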

nanoGPT/custom_pretrain.py [EDITABLE — lines 33–48, lines 257–259 only]

     1: """Custom GPT-2 Pretraining Script
     2: Based on Andrej Karpathy's nanoGPT, evaluated on FineWeb dataset.
     3: """
     4: 
     5: import math
     6: import inspect
     7: import os
     8: import time
     9: from contextlib import nullcontext
    10: from dataclasses import dataclass
    11: 
    12: import numpy as np
    13: import torch
    14: import torch.nn as nn
    15: from torch.nn import functional as F
    16: 
    17: # ============================================================================
    18: # Model Components
    19: # ============================================================================
    20: 
    21: # ── Normalization ──────────────────────────────────────────────────────────
    22: class LayerNorm(nn.Module):
    23:     """LayerNorm but with an optional bias."""
    24:     def __init__(self, ndim, bias):
    25:         super().__init__()
    26:         self.weight = nn.Parameter(torch.ones(ndim))
    27:         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
    28: 
    29:     def forward(self, input):
    30:         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
    31: 
    32: # ── Custom Kernel / Fused Operation ───────────────────────────────────────
    33: def fused_mlp_forward(x, w_fc, w_proj):
    34:     """MLP forward pass: linear -> activation -> linear.
    35: 
    36:     Default implementatio
Assistant (Step 1, edit)
Tool Input
{
  "content": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _interleaved_sym_swiglu_fwd_kernel(\n    h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr\n):\n    pid = tl.program_id(0)\n    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_pairs\n    idx1 = 2 * offsets\n    idx2 = 2 * offsets + 1\n    h1_load = tl.load(h_ptr + idx1, mask=mask)\n    h2_load = tl.load(h_ptr + idx2, mask=mask)\n    h1 = h1_load.to(tl.float32)\n    h2 = h2_load.to(tl.float32)\n    sig_h1 = 1.0 / (1.0 + tl.exp(-h1))\n    sig_h2 = 1.0 / (1.0 + tl.exp(-h2))\n    silu_h1 = h1 * sig_h1\n    silu_h2 = h2 * sig_h2\n    act1 = h1 * silu_h2\n    act2 = h2 * silu_h1\n    tl.store(act_ptr + idx1, act1.to(h1_load.dtype), mask=mask)\n    tl.store(act_ptr + idx2, act2.to(h2_load.dtype), mask=mask)\n\n@triton.jit\ndef _interleaved_sym_swiglu_bwd_kernel(\n    dact_ptr, h_ptr, dh_ptr, n_pairs, BLOCK_SIZE: tl.constexpr\n):\n    pid = tl.program_id(0)\n    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_pairs\n    idx1 = 2 * offsets\n    idx2 = 2 * offsets + 1\n    dact1_load = tl.load(dact_ptr + idx1, mask=mask)\n    dact2_load = tl.load(dact_ptr + idx2, mask=mask)\n    h1_load = tl.load(h_ptr + idx1, mask=mask)\n    h2_load = tl.load(h_ptr + idx2, mask=mask)\n    dact1 = dact1_load.to(tl.float32)\n    dact2 = dact2_load.to(tl.float32)\n    h1 = h1_load.to(tl.float32)\n    h2 = h2_load.to(tl.float32)\n    sig_h1 = 1.0 / (1.0 + tl.exp(-h1))\n    sig_h2 = 1.0 / (1.0 + tl.exp(-h2))\n    silu_h1 = h1 * sig_h1\n    silu_h2 = h2 * sig_h2\n    silu_prime_h1 = sig_h1 * (1.0 + h1 * (1.0 - sig_h1))\n    silu_prime_h2 = sig_h2 * (1.0 + h2 * (1.0 - sig_h2))\n    dh1 = dact1 * silu_h2 + dact2 * (h2 * silu_prime_h1)\n    dh2 = dact2 * silu_h1 + dact1 * (h1 * silu_prime_h2)\n    tl.store(dh_ptr + idx1, dh1.to(h1_load.dtype), mask=mask)\n    tl.store(dh_ptr + idx2, dh2.to(h2_load.dtype), mask=mask)\n\nclass SymSwiGLU_MLP(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, w_fc, w_proj):\n        h = (x @ w_fc.t()).contiguous()\n        act = torch.empty_like(h)\n        n_pairs = h.numel() // 2\n        BLOCK = 1024\n        grid = ((n_pairs + BLOCK - 1) // BLOCK,)\n        _interleaved_sym_swiglu_fwd_kernel[grid](h, act, n_pairs, BLOCK_SIZE=BLOCK)\n        out = act @ w_proj.t()\n        ctx.save_for_backward(x, w_fc, w_proj, h, act)\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x, w_fc, w_proj, h, act = ctx.saved_tensors\n        dtype = grad_output.dtype\n        d_act = (grad_output @ w_proj.to(dtype)).contiguous()\n        dh = torch.empty_like(h)\n        n_pairs = h.numel() // 2\n        BLOCK = 1024\n        grid = ((n_pairs + BLOCK - 1) // BLOCK,)\n        _interleaved_sym_swiglu_bwd_kernel[grid](d_act, h, dh, n_pairs, BLOCK_SIZE=BLOCK)\n        grad_w_proj = grad_output.reshape(-1, grad_output.shape[-1]).t() @ act.to(dtype).reshape(-1, act.shape[-1])\n        grad_x = dh @ w_fc.to(dtype)\n        grad_w_fc = dh.reshape(-1, dh.shape[-1]).t() @ x.to(dtype).reshape(-1, x.shape[-1])\n        return grad_x, grad_w_fc, grad_w_proj\n\ndef fused_mlp_forward(x, w_fc, w_proj):\n    return SymSwiGLU_MLP.apply(x, w_fc, w_proj)",
  "start_line": 33,
  "end_line": 48,
  "op": "replace",
  "filename": "nanoGPT/custom_pretrain.py"
}
Tool Result (Step 1)
OK: Replaced lines 33..48 with 83 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.

[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
    30|         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
    31| 
    32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
    33  import triton
    34  import triton.language as tl
    35  
    36  @triton.jit
    37  def _interleaved_sym_swiglu_fwd_kernel(
    38      h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr
    39  ):
    40      pid = tl.program_id(0)
    41      offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    42      mask = offsets < n_pairs
    43      idx1 = 2 * offsets
    44      idx2 = 2 * offsets + 1
    45      h1_load = tl.load(h_ptr + idx1, mask=mask)
    46      h2_load = tl.load(h_ptr + idx2, mask=mask)
    47      h1 = h1_load.to(tl.float32)
    48      h2 = h2_load.to(tl.float32)
    49      sig_h1 = 1.
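The activation introduced here acts on interleaved pairs of the hidden vector: act1 = h1·silu(h2), act2 = h2·silu(h1), a symmetric SwiGLU-style gating that keeps the hidden width unchanged, so the required (4·n_embd, n_embd) weight shapes are preserved. A plain-Python scalar reference of the forward and of the backward formulas used in the kernels (illustrative only, not the Triton code itself):

```python
import math

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def sym_swiglu_pair(h1, h2):
    # Each interleaved pair (h1, h2) gates its partner symmetrically.
    return h1 * silu(h2), h2 * silu(h1)

def sym_swiglu_pair_grad(h1, h2, d1, d2):
    # Analytic gradients w.r.t. (h1, h2) given upstream grads (d1, d2),
    # matching the silu_prime expressions in the backward kernel.
    sig1 = 1.0 / (1.0 + math.exp(-h1))
    sig2 = 1.0 / (1.0 + math.exp(-h2))
    dsilu1 = sig1 * (1.0 + h1 * (1.0 - sig1))  # silu'(h1)
    dsilu2 = sig2 * (1.0 + h2 * (1.0 - sig2))  # silu'(h2)
    dh1 = d1 * silu(h2) + d2 * h2 * dsilu1
    dh2 = d2 * silu(h1) + d1 * h1 * dsilu2
    return dh1, dh2
```

A finite-difference check on these scalar formulas is a cheap way to validate the backward kernel's math before a full training run.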
Assistant (Step 2, test)
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### gpt-345m (scripts/gpt_345m.sh)
[COMMAND FAILED exit=-1]
[output file not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/logs/llm-pretrain-kernel/gemini-3.1-pro-preview_20260412_102249/20260412_103127/group_1/gpt-345m_s42.out]

### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
ERROR: Checkpoint not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-kernel/gemini-3.1-pro-preview_20260412_102249/seed_42/ckpt_gpt-345m.pt
Assistant (Step 3, edit)
Tool Input
{
  "start_line": 49,
  "op": "replace",
  "end_line": 50,
  "filename": "nanoGPT/custom_pretrain.py",
  "content": "    sig_h1 = tl.sigmoid(h1)\n    sig_h2 = tl.sigmoid(h2)"
}
Tool Result (Step 3)
OK: Replaced lines 49..50 with 2 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.

[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
    30|         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
    31| 
    32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
    33  import triton
    34  import triton.language as tl
    35  
    36  @triton.jit
    37  def _interleaved_sym_swiglu_fwd_kernel(
    38      h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr
    39  ):
    40      pid = tl.program_id(0)
    41      offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    42      mask = offsets < n_pairs
    43      idx1 = 2 * offsets
    44      idx2 = 2 * offsets + 1
    45      h1_load = tl.load(h_ptr + idx1, mask=mask)
    46      h2_load = tl.load(h_ptr + idx2, mask=mask)
    47      h1 = h1_load.to(tl.float32)
    48      h2 = h2_load.to(tl.float32)
    49      sig_h1 = tl.
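Swapping the hand-written `1.0 / (1.0 + tl.exp(-h))` for `tl.sigmoid` delegates the numerical handling to the library. The general concern with the naive form is overflow in the intermediate `exp(-x)` for large negative `x`; the standard stable formulation branches so that `exp` only ever sees a non-positive argument. A plain-Python illustration of the idea (not Triton's actual implementation):

```python
import math

def sigmoid_naive(x):
    # math.exp(-x) overflows for sufficiently negative x (e.g. x = -800)
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_stable(x):
    # Branch so exp() is only evaluated on a non-positive argument.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)
```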
Assistant (Step 4, edit)
Tool Input
{
  "filename": "nanoGPT/custom_pretrain.py",
  "op": "replace",
  "end_line": 76,
  "start_line": 75,
  "content": "    sig_h1 = tl.sigmoid(h1)\n    sig_h2 = tl.sigmoid(h2)"
}
Tool Result (Step 4)
OK: Replaced lines 75..76 with 2 line(s) in nanoGPT/custom_pretrain.py. Editable range: 33–115, 324–326.

[Current file: nanoGPT/custom_pretrain.py | editable: 33–115 | total: 516 lines]
    30|         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
    31| 
    32| # ── Custom Kernel / Fused Operation ───────────────────────────────────────
    33  import triton
    34  import triton.language as tl
    35  
    36  @triton.jit
    37  def _interleaved_sym_swiglu_fwd_kernel(
    38      h_ptr, act_ptr, n_pairs, BLOCK_SIZE: tl.constexpr
    39  ):
    40      pid = tl.program_id(0)
    41      offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    42      mask = offsets < n_pairs
    43      idx1 = 2 * offsets
    44      idx2 = 2 * offsets + 1
    45      h1_load = tl.load(h_ptr + idx1, mask=mask)
    46      h2_load = tl.load(h_ptr + idx2, mask=mask)
    47      h1 = h1_load.to(tl.float32)
    48      h2 = h2_load.to(tl.float32)
    49      sig_h1 = tl.
Assistant (Step 5, test)
Tool Result (Step 5)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.6427, val_loss=2.3681
TRAIN_METRICS: step=10000, train_loss=2.6222, val_loss=2.3410
TRAIN_METRICS: step=11000, train_loss=2.6051, val_loss=2.3330
TRAIN_METRICS: step=12000, train_loss=2.5864, val_loss=2.3163
TRAIN_METRICS: step=13000, train_loss=2.5713, val_loss=2.3062
Final metrics (gpt-345m): val_loss_gpt-345m=2.3002, wikitext2_ppl_gpt-345m=44.8200, lambada_ppl_gpt-345m=68.4700

### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=54.2900, hellaswag_lm-eval-345m=32.8300, piqa_lm-eval-345m=64.1500, winogrande_lm-eval-345m=51.6200

[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.3002, 'wikitext2_ppl_gpt-345m': 44.82, 'lambada_ppl_gpt-345m': 68.47, 'elapsed_gpt-345m': 19910, 'arc_easy_lm-eval-345m': 54.29, 'hellaswag_lm-eval-345m': 32.83, 'piqa
Assistant (Step 6, submit)
Tool Input
{
  "n": 2
}
Tool Result (Step 6)
[submit] Submitting result from test #2 as final.

### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=9000, train_loss=2.6427, val_loss=2.3681
TRAIN_METRICS: step=10000, train_loss=2.6222, val_loss=2.3410
TRAIN_METRICS: step=11000, train_loss=2.6051, val_loss=2.3330
TRAIN_METRICS: step=12000, train_loss=2.5864, val_loss=2.3163
TRAIN_METRICS: step=13000, train_loss=2.5713, val_loss=2.3062
Final metrics (gpt-345m): val_loss_gpt-345m=2.3002, wikitext2_ppl_gpt-345m=44.8200, lambada_ppl_gpt-345m=68.4700

### lm-eval-345m (scripts/lm_eval_345m.sh)
Final metrics (lm-eval-345m): arc_easy_lm-eval-345m=54.2900, hellaswag_lm-eval-345m=32.8300, piqa_lm-eval-345m=64.1500, winogrande_lm-eval-345m=51.6200

[submit] Finalized test #2 as final.

[Leaderboard] Results saved: {'val_loss_gpt-345m': 2.3002, 'wikitext2_ppl_gpt-345m': 44.82, 'lambada_ppl_gpt-345m': 68.47, 'elapsed_gpt-345m': 19910, 'arc_easy_lm-eval-345m': 54.29, 'hellaswag_lm-eval-345m': 32.83, 'piqa_lm-ev