optimization-diagonal-net

Tags: Optimization · RAIN · rigorous codebase

Description

Optimizer Design for Diagonal-Net Sparse Recovery

Research Question

Can you design an optimizer that recovers a sparse linear predictor from fewer training samples when the model uses a diagonal-net parameterization with noisy labels?

Background

The diagonal-net reparameterizes a linear model as w = u^2 - v^2 (element-wise), where u, v ∈ R^d are the trainable parameters. Despite the model being equivalent to a linear predictor, the squared parameterization creates a non-convex loss landscape whose geometry interacts with the optimizer's implicit bias. Classical results show that gradient-based methods on this parameterization can achieve implicit sparse regularization — the optimizer's dynamics naturally favour sparse solutions without explicit L1 penalties.
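As a concrete illustration (a minimal sketch of the parameterization, not the benchmark's model code; the dimension here is arbitrary), the induced linear predictor looks like:

```python
import torch

d = 8
alpha = 1e-3
# Initialisation used by the benchmark: u = v = alpha / sqrt(2d) * ones(d)
u = torch.full((d,), alpha / (2 * d) ** 0.5, dtype=torch.float64)
v = torch.full((d,), alpha / (2 * d) ** 0.5, dtype=torch.float64)

def predict(X: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Effective linear weights: w = u^2 - v^2 (element-wise)
    w = u ** 2 - v ** 2
    return X @ w

X = torch.randn(4, d, dtype=torch.float64)
y_hat = predict(X, u, v)  # since u = v at init, w = 0 and all predictions are 0
```

Although `predict` is linear in the effective weights w, the loss is quartic in (u, v), which is what gives gradient descent its implicit sparsity bias.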

The benchmark uses PyTorch with autograd for gradient computation. Each training step adds fresh Rademacher noise ζ_t ∈ {-delta, +delta} to the labels before computing the loss, simulating stochastic perturbations. Test evaluation always uses clean (noise-free) labels.

The critical quantity is the sample complexity of recovery: how many training examples n does the optimizer need to reliably recover a k-sparse ground truth in R^d? Different optimizers induce different implicit biases, leading to dramatically different sample requirements.

Task

Modify the three functions in RAIN/opt_diagonal_net/custom_optimizer.py (lines 23–90) to implement a novel or improved optimizer:

  1. get_hyperparameters(dim, sparsity, noise_scale, delta) — return optimizer configuration
  2. init_state(u, v, hyperparameters) — initialise optimizer state
  3. step(u, v, grad_u, grad_v, state, hyperparameters) — perform one update step

The default template implements vanilla gradient descent. Your goal is to achieve successful recovery (test MSE < 1.0) with fewer training samples than the baselines, across all test settings.

Interface

  • u, v: parameter vectors of shape (d,) as torch.Tensor (float64), initialised as alpha/sqrt(2d) * ones(d) with alpha = 1e-3
  • grad_u, grad_v: full-batch MSE gradients w.r.t. u and v (computed by PyTorch autograd)
  • state: mutable dict for optimizer internal state (momentum buffers, accumulators, etc.)
  • hyperparameters: dict returned by get_hyperparameters
  • step() must return (u_new, v_new, state_new) as a tuple of torch.Tensor and dict
  • All operations should use torch (not numpy); the benchmark provides gradients via autograd
  • The delta parameter controls the magnitude of Rademacher noise added to training labels each step
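To make the interface concrete, here is a minimal sketch of the three functions implementing plain gradient descent, mirroring the default template (the learning-rate value is illustrative, not the benchmark's default):

```python
from typing import Any

import torch

def get_hyperparameters(dim: int, sparsity: int, noise_scale: float, delta: float) -> dict[str, Any]:
    # Illustrative configuration; a real submission would tune this per setting.
    return {"lr": 0.05}

def init_state(u: torch.Tensor, v: torch.Tensor, hyperparameters: dict) -> dict:
    # Vanilla GD keeps no state; momentum buffers or accumulators would go here.
    return {}

def step(u, v, grad_u, grad_v, state, hyperparameters):
    lr = hyperparameters["lr"]
    u_new = u - lr * grad_u
    v_new = v - lr * grad_v
    return u_new, v_new, state
```

Any replacement must preserve these signatures and the `(u_new, v_new, state_new)` return tuple, since the harness calls them as-is.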

Training loop (executed by the benchmark)

model.zero_grad()
noise = delta * (2 * torch.randint(0, 2, y_train.shape) - 1).float()
y_noisy = y_train + noise
loss = 0.5 * torch.mean((model(X_train) - y_noisy) ** 2)
loss.backward()
with torch.no_grad():
    u, v = model.u, model.v
    grad_u, grad_v = model.u.grad, model.v.grad  # full-batch autograd gradients
    u_new, v_new, state = step(u, v, grad_u, grad_v, state, hparams)
    model.u.data.copy_(u_new)
    model.v.data.copy_(v_new)

Evaluation

Three problem settings are evaluated:

  • d200_k5_s01: d=200, k=5, sigma=0.1, delta=0.5
  • d500_k10_s01: d=500, k=10, sigma=0.1, delta=0.5
  • d500_k10_s02: d=500, k=10, sigma=0.2, delta=0.5

For each setting, the benchmark performs a coarse-to-fine search over training-set sizes n ∈ {50, 75, ..., 1600} to find the smallest n* where recovery succeeds on at least 4 of 5 seeds. Recovery means test MSE < 1.0 at the time training stops.

Metric: score = -log2(n*) per setting (higher is better — fewer samples needed).
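As a quick check of the metric (plain arithmetic, consistent with the baseline scores in the Results table):

```python
import math

# An optimizer whose smallest successful training-set size is n* = 50
n_star = 50
score = -math.log2(n_star)
print(round(score, 3))  # -5.644
```

Halving n* therefore gains exactly one point of score, regardless of the setting.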

Training uses full-batch gradients (with noisy labels) and a shared stopping rule: training halts when both train and test MSE have plateaued (two-window comparison over 20,000 steps), or after 1,000,000 steps.
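The exact plateau test lives in the fixed harness; a plausible sketch of such a two-window comparison (the relative tolerance here is an assumption, only the 20,000-step window comes from the task description) is:

```python
def plateaued(history, window=20_000, rel_tol=1e-3):
    # Compare the mean loss over the last `window` steps with the mean over
    # the preceding `window` steps; declare a plateau when the relative
    # improvement between the two windows falls below `rel_tol`.
    if len(history) < 2 * window:
        return False
    recent = sum(history[-window:]) / window
    previous = sum(history[-2 * window:-window]) / window
    return (previous - recent) <= rel_tol * max(abs(previous), 1e-12)
```

Under a rule of this shape, an optimizer that keeps making slow progress delays the stopping time, which can matter for slow-but-sparse dynamics.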

Baselines (16 configurations)

  • SGD (4 configs): lr ∈ {0.005, 0.01, 0.05, 0.1}
  • AdaGrad (4 configs): lr ∈ {0.005, 0.01, 0.05, 0.1}, eps=1e-6
  • Adam without bias correction (8 configs): lr ∈ {0.005, 0.01, 0.05, 0.1} × beta2 ∈ {0.95, 0.999}, beta1=0.9, eps=1e-6
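For reference, the Adam baseline's update without bias correction can be sketched as a single-tensor step (a minimal version under the stated hyperparameters, not the benchmark's exact code):

```python
import torch

def adam_step_no_bias_correction(p, grad, m, v2, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-6):
    # Exponential moving averages of the gradient and squared gradient,
    # used directly: no 1/(1 - beta^t) rescaling of m or v2.
    m = beta1 * m + (1 - beta1) * grad
    v2 = beta2 * v2 + (1 - beta2) * grad ** 2
    p_new = p - lr * m / (v2.sqrt() + eps)
    return p_new, m, v2
```

Without bias correction, early steps are shrunk by the (1 - beta1) factor in m, which effectively acts as a short warmup near the small-alpha initialisation.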

Hints

  • The diagonal-net parameterization w = u^2 - v^2 naturally biases gradient descent toward sparse solutions when initialised near zero
  • Adaptive methods (Adam, AdaGrad) change the effective geometry of this bias — this can help or hurt
  • The initialisation alpha/sqrt(2d) * ones(d) with alpha = 1e-3 means u=v at init, so w_hat=0 initially
  • The Rademacher noise (delta parameter) adds stochasticity to training — your optimizer should be robust to this
  • Consider how your optimizer interacts with the non-convex structure: coordinate-wise adaptivity, momentum, and learning rate scheduling all affect the sparsity bias
  • All 16 baselines use eps=1e-6 and Adam uses NO bias correction

Code

custom_optimizer.py (editable)

"""Editable optimizer scaffold for the opt-diagonal-net MLS-Bench task.

Implement a custom optimizer for training a diagonal-net model to recover
a sparse linear predictor. You may edit the three functions below
(get_hyperparameters, init_state, step) while the benchmark harness,
data generation, model, stopping rule, and search protocol are fixed.
"""

from __future__ import annotations

from typing import Any

import torch

from fixed_benchmark import run_cli

fixed_benchmark.py (read-only)

"""Fixed benchmark harness for diagonal-net sparse recovery.

Evaluates custom optimizers on the problem of recovering a sparse linear
predictor through a diagonal-net parameterization (w_hat = u^L - v^L, L=2).
The benchmark measures the minimum training-set size n* required for
reliable recovery (test MSE < 1.0) using a coarse-to-fine search protocol.

Metric: -log2(n*) (higher is better — fewer training samples needed).

Uses PyTorch with autograd for gradient computation; the optimizer interface
receives torch.Tensor gradients directly.
"""

from __future__ import annotations

Results

| Model | Type | n* (d200_k5_s01) | score (d200_k5_s01) | n* (d500_k10_s01) | score (d500_k10_s01) | n* (d500_k10_s02) | score (d500_k10_s02) | n* (d10000_k50) | score (d10000_k50) |
|---|---|---|---|---|---|---|---|---|---|
| adagrad | baseline | 175.000 | -7.451 | 487.000 | -8.928 | 487.000 | -8.928 | 2000.000 | -10.966 |
| adam | baseline | 50.000 | -5.644 | 56.000 | -5.807 | 56.000 | -5.807 | 350.000 | -8.451 |
| adam2 | baseline | 50.000 | -5.644 | 53.000 | -5.728 | 53.000 | -5.728 | 350.000 | -8.451 |
| sgd | baseline | 50.000 | -5.644 | 62.000 | -5.954 | 62.000 | -5.954 | 487.000 | -8.928 |