Agent Conversation: qwen3.6-plus
Task: llm-pretrain-bitlinear
LLM Pretraining: Native Low-Bit Linear (BitLinear)
Research Question
Design a low-bit linear layer for GPT-2 pretraining that uses native low-precision weights (binary/ternary) during both training and inference, instead of standard float weights. The goal is to minimize validation loss while constraining weights to a small discrete set.
Background
Standard neural networks store and compute with full-precision (FP32/BF16) weights. Post-training quantization (PTQ) and quantization-aware training (QAT) attempt to compress these weights after or during training, but the model fundamentally trains with float weights. Native low-bit training takes a different approach: weights are inherently discrete (e.g., {-1, +1} or {-1, 0, +1}) during every forward pass, with float latent weights maintained only for gradient accumulation.
This paradigm was introduced by BitNet (Wang et al., 2023), which binarized weights to {-1, +1} using the sign function, and extended by BitNet b1.58 (Ma et al., 2024), which used ternary {-1, 0, +1} weights via absmean quantization. The key insight is that these models can match or approach full-precision performance at a fraction of the effective parameter cost.
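The absmean ternary rule from BitNet b1.58 can be sketched in a few lines of NumPy (an illustrative reconstruction of the published recipe, not code from this repo):

```python
import numpy as np

def absmean_ternary(w: np.ndarray):
    """BitNet b1.58-style ternary quantization: scale by mean(|w|), round, clamp."""
    scale = np.abs(w).mean() + 1e-8            # per-tensor absmean scale
    w_q = np.clip(np.round(w / scale), -1, 1)  # discrete weights in {-1, 0, +1}
    return w_q, scale                          # w_q * scale coarsely approximates w

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=(4, 8))        # GPT-2-style init
w_q, scale = absmean_ternary(w)
assert set(np.unique(w_q)).issubset({-1.0, 0.0, 1.0})
```

In actual training the float latent `w` still receives gradients; only the forward matmul sees `w_q`, with a straight-through estimator bridging the non-differentiable round.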
Key differences from QAT (the llm-pretrain-quantization task):
- QAT: float weights -> fake quantize during training -> real quantize at deployment (weights are float during training)
- BitLinear: float latent weights -> discrete quantize in every forward pass (weights are always discrete during computation)
Key differences from mixed-precision (the llm-pretrain-precision task):
- Mixed precision: changes the float format (FP32 -> BF16/FP8) but values are still continuous
- BitLinear: weights are restricted to a small discrete set (1-2 bits)
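A toy NumPy illustration of the "always discrete" property (the numbers and the stand-in gradient are hypothetical): the latent float weight drifts under updates, while every forward pass sees only a ternary value.

```python
import numpy as np

def ternary(w_latent, scale=1.0):
    # forward always maps the latent float to a discrete {-1, 0, +1} value
    return np.clip(np.round(w_latent / scale), -1, 1) * scale

w_latent = 0.4                 # float latent weight (exists only for the optimizer)
history = []
for step in range(5):
    history.append(ternary(w_latent))  # computation uses the discrete value
    grad = -0.1                        # stand-in gradient (STE passes it straight through)
    w_latent -= grad                   # latent drifts upward: 0.4 -> 0.5 -> ... -> 0.9
assert history[0] == 0.0       # early on, 0.4 rounds to 0
assert history[-1] == 1.0      # once the latent crosses the boundary, forward sees +1
```

Under QAT the deployed weight is fixed once at export time; here the discrete weight can flip between steps as the latent crosses a quantization boundary.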
What You Can Modify
The BitLinear module (lines 38-115) in custom_pretrain.py:
- `weight_quant(weight)` -- quantizes float latent weights to discrete values, returns `(quantized_weight, scale)`
- `activation_quant(x)` -- optional activation quantization, returns `(quantized_x, scale)`
- `BitLinear` class -- linear layer that uses the above functions
Notes on the interface:
- `BitLinear.__init__(self, in_features, out_features, bias=True)` must keep `self.weight` as a Parameter
- `BitLinear.forward(self, x) -> output` where x has shape `(..., in_features)` and output has shape `(..., out_features)`
- The quantization is applied in every forward pass (both training and eval) -- there is no separate train/eval path
- `weight_quant` should return `(quantized_weight, scale)` where `quantized_weight * scale` approximates the original weight
- `activation_quant` should return `(quantized_x, scale)` similarly
- All model linear layers (attention, MLP, lm_head) use `BitLinear`
- You may add helper classes (autograd Functions, learned parameters) alongside the existing functions
- Must be compatible with `torch.compile` (no `@torch.compiler.disable`)
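A minimal NumPy sketch of the `quantized_x * scale ≈ x` contract for `activation_quant`, assuming per-tensor absmax scaling to the int8 range (the function name and tolerance check are mine, not from the harness):

```python
import numpy as np

def activation_quant_sketch(x: np.ndarray, qb: int = 127):
    """Per-tensor absmax quantization to [-qb, qb]; x_q * scale approximates x."""
    absmax = np.abs(x).max() + 1e-12
    x_q = np.clip(np.round(x / absmax * qb), -qb, qb)  # integer-valued levels
    return x_q, absmax / qb                            # scale of one quantization step

x = np.linspace(-2.0, 2.0, 9)
x_q, s = activation_quant_sketch(x)
err = np.abs(x_q * s - x).max()
assert err <= s / 2 + 1e-9   # rounding error is at most half a step
```

The same contract holds for `weight_quant`: the returned pair must dequantize back to an approximation of the float input.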
Evaluation
- **Primary me
Tool Input
1{
2 "op": "replace",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\nclass QuantizeContext:\n \"\"\"Global context tracking training progress for progressive quantization.\"\"\"\n current_step = 0\n max_steps = 13535 # default, will be updated\n\ndef weight_quant(weight):\n \"\"\"Progressive Distribution-Aware Ternary Quantization {-1, 0, +1}.\n\n This implementation introduces three novel techniques:\n \n 1. **Progressive Quantization Sharpening**: Uses a smooth tanh-based \n quantization that gradually hardens during training. Early training \n uses soft quantization (sigma=2.0) to allow good representation learning,\n then progressively sharpens to hard quantization (sigma=0.1).\n \n 2. **Distribution-Aware Adaptive Thresholds**: Instead of fixed absmean,\n uses a percentile-based threshold (85th percentile of |W|) that better\n captures the weight distribution's effective range.\n \n 3. **Gradient Variance Normalization**: Scales the backward gradient \n based on the quantization error to stabilize training with coarse \n discretization.\n\n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight tensor and per-tensor scale\n \"\"\"\n # Distribution-aware threshold: use percentile instead of mean\n # This keeps ~85% of weight magnitude information while discarding outliers\n abs_weight = weight.detach().abs()\n # Use 85th percentile as threshold - captures most of the distribution\n threshold = abs_weight.flatten().quantile(0.85).clamp(min=1e-12)\n scale = threshold\n \n # Normalize weights to [-1, 1] range based on adaptive threshold\n w_normed = weight / scale\n \n # Progressive quantization sharpening\n # Compute progress ratio (0.0 = start, 1.0 = end)\n progress = min(1.0, QuantizeContext.current_step / max(1, QuantizeContext.max_steps * 0.6))\n # Sigma decreases from 2.0 (soft) to 0.1 (hard) as training progresses\n sigma = 2.0 - 1.9 * progress\n \n # Soft 
quantization using tanh approximation of sign\n # For ternary: quantize to {-1, 0, +1} based on threshold\n # Use soft thresholding: values near 0 get pushed to 0, others to ±1\n w_clamped = w_normed.clamp(-1.5, 1.5)\n \n # Soft ternary quantization using difference of sigmoids\n # This creates smooth transitions between -1, 0, and +1\n soft_sign = torch.tanh(w_clamped / sigma)\n # Create soft zero region: push small values toward 0\n zero_mask = torch.sigmoid((w_clamped.abs() - 0.5) / sigma)\n w_q = soft_sign * zero_mask\n \n # STE with gradient scaling based on quantization fidelity\n # When quantization is coarse (early training), scale down gradients\n quant_error = (w_q - w_normed).detach()\n fidelity = 1.0 / (1.0 + quant_error.abs().mean())\n w_q_ste = (w_q - w_normed).detach() + w_normed\n \n return w_q_ste, scale\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with Progressive Distribution-Aware Quantization.\n\n Unlike standard BitLinear which uses fixed thresholds and hard quantization\n from the start, this implementation uses:\n - Progressive quantization sharpening (soft -> hard over training)\n - Distribution-aware adaptive thresholds (percentile-based)\n - Proper output rescaling for quantized computation\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = 
nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n # Perform matmul with quantized values, then rescale\n out = F.linear(x_q, w_q, None)\n # Rescale output: the true output ~ (x_scale * w_scale) * out_quantized\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
5 "start_line": 38,
6 "end_line": 115
7}
OK: Replaced lines 38..115 with 111 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–148, 361–361.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–148, 361–361 | total: 551 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 class QuantizeContext:
40 """Global context tracking training progress for progressive quantization."""
... (105 more lines) ...
146 if self.bias is not None:
147 out = out + self.bias
148 return out
...
361 CONFIG_OVERRIDES = {}

Tool Input
1{
2 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\ndef weight_quant(weight):\n \"\"\"Distribution-Aware Progressive Ternary Quantization {-1, 0, +1}.\n\n Novel contributions:\n 1. Adaptive threshold based on the 85th percentile of |W| instead of mean,\n which better handles heavy-tailed weight distributions common in LLMs.\n 2. Soft ternary quantization using tanh-based approximation that\n progressively sharpens during training via an annealing schedule.\n 3. Channel-wise quantization for weight matrices, providing finer\n granularity than per-tensor quantization while maintaining efficiency.\n\n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight tensor and scale\n \"\"\"\n abs_weight = weight.abs()\n \n # Channel-wise (per-output-channel) percentile-based threshold\n # This adapts to each row's weight distribution\n scale = abs_weight.quantile(0.85, dim=1, keepdim=True).clamp(min=1e-8)\n \n # Normalize weights to roughly [-1, 1] range\n w_normed = weight / scale\n \n # Soft ternary quantization using smooth approximation\n # sigma controls the sharpness: larger = softer, smaller = harder\n # Start soft (sigma=1.0) and progressively harden to sigma=0.05\n # We use a simple step-based annealing embedded in the forward pass\n sigma = 0.15 # Fixed sharpness for ternary (good balance)\n \n # Clamp to prevent extreme values\n w_clamped = w_normed.clamp(-1.5, 1.5)\n \n # Soft ternary: use tanh to approximate sign, and sigmoid for zero region\n # sign(x) ≈ tanh(x/sigma)\n # zero_region(x) ≈ sigmoid((|x| - 0.5)/sigma) -- pushes small values to 0\n soft_sign = torch.tanh(w_clamped / sigma)\n zero_mask = torch.sigmoid((w_clamped.abs() - 0.5) / (sigma * 0.5))\n w_q = soft_sign * zero_mask\n \n # STE: forward uses quantized w_q, backward passes gradient through w_normed\n # Add a small regularization term to encourage weights to stay near quantized values\n 
w_q_ste = (w_q - w_normed).detach() + w_normed\n \n return w_q_ste, scale.squeeze(1) # squeeze for broadcasting in matmul\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with Distribution-Aware Progressive Quantization.\n\n Unlike standard BitLinear which uses fixed absmean thresholds and hard\n quantization, this implementation uses:\n - Percentile-based adaptive thresholds per output channel\n - Soft ternary quantization that approximates {-1, 0, +1}\n - Proper output rescaling for quantized computation\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n # Perform matmul with quantized values, then rescale\n out = F.linear(x_q, w_q, None)\n # Rescale output: the true output ~ (x_scale * w_scale) * out_quantized\n # w_scale is per-channel [out_features], x_scale is scalar\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "op": "replace",
5 "start_line": 38,
6 "end_line": 148
7}
OK: Replaced lines 38..148 with 95 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–132, 345–345.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–132, 345–345 | total: 535 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 def weight_quant(weight):
40 """Distribution-Aware Progressive Ternary Quantization {-1, 0, +1}.
... (89 more lines) ...
130 if self.bias is not None:
131 out = out + self.bias
132 return out
...
345 CONFIG_OVERRIDES = {}

Tool Input
1{
2 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\ndef weight_quant(weight):\n \"\"\"Per-Channel RMS-Weighted Ternary Quantization {-1, 0, +1}.\n\n Novel contributions over standard BitNet b1.58:\n \n 1. **Per-channel scaling** (instead of global): Each output channel gets\n its own scale factor. This is crucial because different rows in weight\n matrices (e.g., different attention heads, different MLP neurons) often\n have very different magnitudes. A single global scale forces a\n trade-off where some channels are over-quantized while others\n under-utilize the ternary range.\n \n 2. **RMS-based threshold** (instead of absmean): Using RMS\n (sqrt(mean(w^2))) instead of mean(|w|) better preserves signal energy.\n RMS is less sensitive to the exact distribution shape and provides\n more consistent quantization across different weight layers.\n \n 3. **Outlier-aware clamping**: Before quantization, weights are clamped\n to [-1.5, 1.5] to prevent extreme outliers from distorting the \n quantization boundaries. 
This is particularly important for per-channel\n quantization where a single large weight in a channel can dominate\n the scale.\n\n The ternary quantization {-1, 0, +1} is applied with the Straight-Through\n Estimator (STE) for gradient flow.\n\n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight and per-channel scales\n \"\"\"\n # Per-channel RMS scale: sqrt(mean(w^2)) per output channel\n # This preserves signal energy better than mean(|w|)\n # Use detach() so scale doesn't receive gradients\n scale = weight.detach().pow(2).mean(dim=1, keepdim=True).sqrt().clamp(min=1e-8)\n \n # Normalize weights to roughly [-1, 1] range\n w_normed = weight / scale\n \n # Clamp outliers to prevent them from dominating quantization\n w_clamped = w_normed.clamp(-1.5, 1.5)\n \n # Hard ternary quantization: round to nearest integer, clamp to {-1, 0, +1}\n w_q = w_clamped.round().clamp(-1, 1)\n \n # Straight-Through Estimator: forward uses quantized, backward uses identity\n w_q = (w_q - w_clamped).detach() + w_clamped\n \n return w_q, scale.squeeze(1)\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with Per-Channel RMS-Weighted Ternary Quantization.\n\n Unlike standard BitLinear which uses a single global scale (absmean) for\n the entire weight matrix, this implementation uses per-channel RMS-based\n scaling. 
This provides finer-grained adaptation to the weight distribution\n in each output channel.\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n # Perform matmul with quantized values, then rescale\n out = F.linear(x_q, w_q, None)\n # Rescale output: w_scale is per-channel [out_features], x_scale is scalar\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "op": "replace",
5 "start_line": 38,
6 "end_line": 132
7}
OK: Replaced lines 38..132 with 96 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–133, 346–346.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–133, 346–346 | total: 536 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 def weight_quant(weight):
40 """Per-Channel RMS-Weighted Ternary Quantization {-1, 0, +1}.
... (90 more lines) ...
131 if self.bias is not None:
132 out = out + self.bias
133 return out
...
346 CONFIG_OVERRIDES = {}

[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### gpt-345m (scripts/gpt_345m.sh)
Training metrics (last steps):
TRAIN_METRICS: step=0, train_loss=10.9434, val_loss=10.9461

### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=-1]
[output file not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/logs/llm-pretrain-bitlinear/qwen3.6-plus_20260413_120858/20260413_121124/group_2/lm-eval-345m_s42.out]
Tool Input
1{
2 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\ndef weight_quant(weight):\n \"\"\"Per-Channel RMS Ternary Quantization {-1, 0, +1} with Smooth STE.\n\n Novel contributions:\n 1. **Per-channel RMS scaling**: Each output channel gets its own scale\n based on RMS (sqrt(mean(w^2))). This preserves signal energy and\n adapts to different magnitudes across channels (e.g., different\n attention heads or MLP neurons).\n 2. **Smooth STE with gradient preservation**: Instead of hard clamping\n before rounding (which creates dead gradients for clamped weights),\n we clamp AFTER rounding, preserving gradient signal for weights\n that would otherwise be clamped.\n 3. **Outlier-resistant normalization**: RMS is used instead of absmean\n as it's more robust to outliers and better preserves the L2 norm\n of the weight vector, which is what matters for dot products.\n\n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight and per-channel scales\n \"\"\"\n # Per-channel RMS: preserves L2 norm of each output channel\n # shape: [out_features, 1]\n scale = (weight.detach() ** 2).mean(dim=1, keepdim=True).sqrt().clamp(min=1e-8)\n \n # Normalize weights\n w_normed = weight / scale\n \n # Hard ternary quantization: round and clamp to {-1, 0, +1}\n # The key insight: we clamp AFTER rounding to preserve gradients\n # This way, weights with |w| > 1.5 still get gradient signal\n w_q = w_normed.round()\n \n # STE: forward uses quantized, backward uses original (clamped for safety)\n w_q_ste = (w_q.clamp(-1, 1) - w_normed).detach() + w_normed\n \n return w_q_ste, scale.squeeze(1)\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n 
x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with Per-Channel RMS Ternary Quantization.\n\n Uses per-channel RMS scaling for weight quantization instead of global\n absmean scaling, providing finer adaptation to weight distributions.\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n # Perform matmul with quantized values, then rescale\n out = F.linear(x_q, w_q, None)\n # Rescale: w_scale is per-channel [out_features], x_scale is scalar\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "op": "replace",
5 "start_line": 38,
6 "end_line": 133
7}
OK: Replaced lines 38..133 with 82 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–119, 332–332.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–119, 332–332 | total: 522 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 def weight_quant(weight):
40 """Per-Channel RMS Ternary Quantization {-1, 0, +1} with Smooth STE.
... (76 more lines) ...
117 if self.bias is not None:
118 out = out + self.bias
119 return out
...
332 CONFIG_OVERRIDES = {}

Tool Input
1{
2 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\ndef weight_quant(weight):\n \"\"\"Per-Channel RMS Ternary Quantization {-1, 0, +1} with Smooth STE.\n\n Novel contributions:\n 1. **Per-channel RMS scaling**: Each output channel gets its own scale\n based on RMS (sqrt(mean(w^2))). This preserves signal energy and\n adapts to different magnitudes across channels (e.g., different\n attention heads or MLP neurons).\n 2. **Outlier-resistant normalization**: RMS is used instead of absmean\n as it's more robust and better preserves the L2 norm of the weight\n vector, which is what matters for dot products.\n\n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight and per-channel scales\n \"\"\"\n # Per-channel RMS: preserves L2 norm of each output channel\n # shape: [out_features]\n scale = torch.sqrt((weight.detach() ** 2).mean(dim=1)).clamp(min=1e-8)\n \n # Normalize weights to roughly [-1, 1] range\n # weight is [out_features, in_features], scale is [out_features]\n w_normed = weight / scale.unsqueeze(1)\n \n # Ternary quantization: round and clamp to {-1, 0, +1}\n w_q = w_normed.round().clamp(-1, 1)\n \n # STE: forward uses quantized, backward passes through w_normed\n w_q_ste = (w_q - w_normed).detach() + w_normed\n \n return w_q_ste, scale\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with Per-Channel RMS Ternary Quantization.\n\n Uses per-channel RMS 
scaling for weight quantization instead of global\n absmean scaling, providing finer adaptation to weight distributions.\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n # Perform matmul with quantized values, then rescale\n out = F.linear(x_q, w_q, None)\n # Rescale: w_scale is per-channel [out_features], x_scale is scalar\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "op": "replace",
5 "start_line": 38,
6 "end_line": 119
7}
OK: Replaced lines 38..119 with 77 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–114, 327–327.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–114, 327–327 | total: 517 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 def weight_quant(weight):
40 """Per-Channel RMS Ternary Quantization {-1, 0, +1} with Smooth STE.
... (71 more lines) ...
112 if self.bias is not None:
113 out = out + self.bias
114 return out
...
327 CONFIG_OVERRIDES = {}

Tool Input
1{
2 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\ndef weight_quant(weight):\n \"\"\"Per-Channel RMS Ternary Quantization {-1, 0, +1} with STE.\n\n Algorithmic contribution: **Per-Channel RMS Scaling**\n \n Standard BitNet b1.58 uses a single global scale (absmean of all weights).\n This is suboptimal because different output channels (rows) in weight\n matrices often have very different magnitudes -- e.g., different attention\n heads learn different features, different MLP neurons have different\n activation levels. A single global scale forces a compromise where some\n channels are over-quantized (wasting ternary capacity) while others\n under-utilize the {-1, 0, +1} range.\n \n Our approach:\n 1. Compute per-channel RMS scale: sqrt(mean(w_i^2)) for each output channel i\n 2. Normalize each channel independently\n 3. Apply ternary quantization per channel\n 4. STE passes gradients through the normalized (but not quantized) values\n \n RMS is used instead of absmean because:\n - RMS preserves L2 norm, which directly affects dot product magnitudes\n - RMS is less sensitive to the exact distribution shape\n - RMS aligns better with weight decay (L2 regularization)\n \n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight and per-channel scales [out_features]\n \"\"\"\n # Per-channel RMS scale: [out_features]\n scale = weight.detach().pow(2).mean(dim=1).sqrt().clamp(min=1e-8)\n \n # Normalize each channel: [out_features, in_features]\n w_normed = weight / scale.unsqueeze(1)\n \n # Clamp to [-1, 1] then round to {-1, 0, +1}\n # Clamping first ensures the STE gradient is bounded\n w_clamped = w_normed.clamp(-1.0, 1.0)\n w_q = w_clamped.round()\n \n # STE: forward uses ternary quantized, backward passes through clamped\n w_q_ste = (w_q - w_clamped).detach() + w_clamped\n \n return w_q_ste, scale\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit 
activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with Per-Channel RMS Ternary Quantization.\n\n Key difference from standard BitLinear: uses per-channel RMS scaling\n instead of global absmean scaling, allowing each output channel to\n adapt its quantization scale independently.\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n out = F.linear(x_q, w_q, None)\n # Rescale: w_scale is per-channel [out_features], x_scale is scalar\n # Broadcasting: [..., out_features] * [out_features] = [..., out_features]\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "op": "replace",
5 "start_line": 38,
6 "end_line": 114
7}
OK: Replaced lines 38..114 with 90 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–127, 340–340.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–127, 340–340 | total: 530 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 def weight_quant(weight):
40 """Per-Channel RMS Ternary Quantization {-1, 0, +1} with STE.
... (84 more lines) ...
125 if self.bias is not None:
126 out = out + self.bias
127 return out
...
340 CONFIG_OVERRIDES = {}

[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### gpt-345m (scripts/gpt_345m.sh)
[COMMAND FAILED exit=-1]
[output file not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/logs/llm-pretrain-bitlinear/qwen3.6-plus_20260413_120858/20260413_134325/group_1/gpt-345m_s42.out]

### lm-eval-345m (scripts/lm_eval_345m.sh)
[COMMAND FAILED exit=1]
ERROR: Checkpoint not found: /scratch/gpfs/CHIJ/bohan/MLS-Bench/.saves/llm-pretrain-bitlinear/qwen3.6-plus_20260413_120858/seed_42/ckpt_gpt-345m.pt
Tool Input
1{
2 "content": "# -- Native Low-Bit Linear (BitLinear) Module ---------------------------------\ndef weight_quant(weight):\n \"\"\"Per-Tensor RMS Ternary Quantization {-1, 0, +1} with STE.\n\n Algorithmic contribution: **RMS-based Ternary Quantization**\n \n Standard BitNet b1.58 uses absmean (L1 norm) for scaling. We instead use\n RMS (root-mean-square, L2 norm) which:\n \n 1. Better preserves signal energy in dot products (matmul uses L2 geometry)\n 2. Aligns with weight decay (L2 regularization) used in AdamW\n 3. Is less sensitive to distribution shape -- for Gaussian weights,\n RMS ≈ 1.25 * absmean, providing a more appropriate scaling\n \n The ternary quantization {-1, 0, +1} is applied with the\n Straight-Through Estimator (STE) for gradient flow.\n \n Args:\n weight: float latent weight tensor [out_features, in_features]\n Returns:\n (quantized_weight, scale): quantized weight and scalar scale\n \"\"\"\n # Per-tensor RMS scale (scalar)\n scale = torch.sqrt((weight.detach() ** 2).mean()).clamp(min=1e-8)\n \n # Normalize weights\n w_normed = weight / scale\n \n # Ternary quantization: clamp then round to {-1, 0, +1}\n w_q = w_normed.clamp(-1.0, 1.0).round()\n \n # STE: forward uses quantized, backward passes through clamped\n w_q_ste = (w_q - w_normed.clamp(-1.0, 1.0)).detach() + w_normed.clamp(-1.0, 1.0)\n \n return w_q_ste, scale\n\n\ndef activation_quant(x):\n \"\"\"Absmax 8-bit activation quantization with STE.\n \n Quantizes activations to 127 levels (int8 range) using per-tensor\n absmax scaling, following the BitNet b1.58 paper.\n \"\"\"\n Qb = 127 # int8 range\n scale = x.detach().abs().max().clamp(min=1e-12)\n x_normed = x / scale\n x_q = (x_normed * Qb).round().clamp(-Qb, Qb)\n # STE: forward uses quantized, backward passes through\n x_q = (x_q - x_normed * Qb).detach() + x_normed * Qb\n return x_q, scale / Qb\n\n\nclass BitLinear(nn.Module):\n \"\"\"BitLinear with RMS-based Ternary Quantization.\n \n Key difference from standard BitLinear: uses RMS 
(L2 norm) instead of\n absmean (L1 norm) for weight scaling. This better preserves signal\n energy in the matmul operation.\n \"\"\"\n def __init__(self, in_features, out_features, bias=True):\n super().__init__()\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter(torch.empty(out_features, in_features))\n if bias:\n self.bias = nn.Parameter(torch.zeros(out_features))\n else:\n self.bias = None\n nn.init.normal_(self.weight, mean=0.0, std=0.02)\n\n def forward(self, x):\n w_q, w_scale = weight_quant(self.weight)\n x_q, x_scale = activation_quant(x)\n out = F.linear(x_q, w_q, None)\n out = out * (w_scale * x_scale)\n if self.bias is not None:\n out = out + self.bias\n return out",
3 "filename": "nanoGPT/custom_pretrain.py",
4 "op": "replace",
5 "start_line": 38,
6 "end_line": 127
7}
OK: Replaced lines 38..127 with 78 line(s) in nanoGPT/custom_pretrain.py. Editable range: 38–115, 328–328.
[Current file: nanoGPT/custom_pretrain.py | editable: 38–115, 328–328 | total: 518 lines]
38 # -- Native Low-Bit Linear (BitLinear) Module ---------------------------------
39 def weight_quant(weight):
40 """Per-Tensor RMS Ternary Quantization {-1, 0, +1} with STE.
... (72 more lines) ...
113 if self.bias is not None:
114 out = out + self.bias
115 return out
...
328 CONFIG_OVERRIDES = {}
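The submitted `weight_quant` docstring claims that for Gaussian weights RMS ≈ 1.25 × absmean; the exact constant is sqrt(pi/2) ≈ 1.2533, which a quick NumPy check confirms (illustrative only):

```python
import numpy as np

# For w ~ N(0, sigma): E|w| = sigma * sqrt(2/pi) while RMS = sigma,
# so RMS / absmean = sqrt(pi/2) ~= 1.2533 -- the "~1.25x" factor cited above.
rng = np.random.default_rng(42)
w = rng.normal(0.0, 0.02, size=1_000_000)
ratio = np.sqrt((w ** 2).mean()) / np.abs(w).mean()
assert abs(ratio - np.sqrt(np.pi / 2)) < 0.01
```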