optimization-diagonal-net
Description
Optimizer Design for Diagonal-Net Sparse Recovery
Research Question
Can you design an optimizer that recovers a sparse linear predictor from fewer training samples when the model uses a diagonal-net parameterization with noisy labels?
Background
The diagonal-net reparameterizes a linear model as w = u^2 - v^2 (element-wise), where u, v ∈ R^d are the trainable parameters. Despite the model being equivalent to a linear predictor, the squared parameterization creates a non-convex loss landscape whose geometry interacts with the optimizer's implicit bias. Classical results show that gradient-based methods on this parameterization can achieve implicit sparse regularization — the optimizer's dynamics naturally favour sparse solutions without explicit L1 penalties.
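Concretely, the parameterization can be written in a few lines of PyTorch. This is a minimal sketch for illustration (the benchmark's actual model class is fixed and may differ in details); `d` and `alpha` here match the initialisation described below.

```python
import torch

d = 200
alpha = 1e-3

# Diagonal-net parameters: w_hat = u**2 - v**2 (element-wise).
u = (alpha / (2 * d) ** 0.5) * torch.ones(d, dtype=torch.float64)
v = u.clone()

def predict(X: torch.Tensor) -> torch.Tensor:
    """Linear prediction through the diagonal-net parameterization."""
    w_hat = u ** 2 - v ** 2
    return X @ w_hat

# At initialisation u == v, so w_hat == 0 and every prediction is exactly 0.
```

Although `predict` is linear in `w_hat`, the loss is quartic in `u` and `v`, which is where the optimizer's implicit bias enters.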
The benchmark uses PyTorch with autograd for gradient computation. Each training step adds fresh Rademacher noise ζ_t ∈ {-delta, +delta} to the labels before computing the loss, simulating stochastic perturbations. Test evaluation always uses clean (noise-free) labels.
The critical quantity is the sample complexity of recovery: how many training examples n does the optimizer need to reliably recover a k-sparse ground truth in R^d? Different optimizers induce different implicit biases, leading to dramatically different sample requirements.
Task
Modify the three functions in RAIN/opt_diagonal_net/custom_optimizer.py (lines 23–90) to implement a novel or improved optimizer:
- `get_hyperparameters(dim, sparsity, noise_scale, delta)`: return the optimizer configuration
- `init_state(u, v, hyperparameters)`: initialise optimizer state
- `step(u, v, grad_u, grad_v, state, hyperparameters)`: perform one update step
The default template implements vanilla gradient descent. Your goal is to achieve successful recovery (test MSE < 1.0) with fewer training samples across all test settings.
Interface
- `u`, `v`: parameter vectors of shape `(d,)` as `torch.Tensor` (float64), initialised as `alpha/sqrt(2d) * ones(d)` with `alpha = 1e-3`
- `grad_u`, `grad_v`: full-batch MSE gradients w.r.t. `u` and `v` (computed by PyTorch autograd)
- `state`: mutable dict for optimizer internal state (momentum buffers, accumulators, etc.)
- `hyperparameters`: dict returned by `get_hyperparameters`
- `step()` must return `(u_new, v_new, state_new)` as a tuple of two `torch.Tensor`s and a dict
- All operations should use `torch` (not numpy); the benchmark provides gradients via autograd
- The `delta` parameter controls the magnitude of the Rademacher noise added to the training labels each step
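As a reference point, a vanilla-gradient-descent implementation of this interface could look like the following. This is a sketch matching the signatures above, not the benchmark's actual template; the learning rate `0.01` is an illustrative choice.

```python
from typing import Any

import torch

def get_hyperparameters(dim: int, sparsity: int, noise_scale: float,
                        delta: float) -> dict[str, Any]:
    # Plain GD needs only a learning rate; 0.01 is an illustrative value.
    return {"lr": 0.01}

def init_state(u: torch.Tensor, v: torch.Tensor,
               hyperparameters: dict[str, Any]) -> dict[str, Any]:
    # Vanilla GD is stateless; richer optimizers would allocate
    # momentum buffers or accumulators here.
    return {}

def step(u, v, grad_u, grad_v, state, hyperparameters):
    lr = hyperparameters["lr"]
    u_new = u - lr * grad_u
    v_new = v - lr * grad_v
    return u_new, v_new, state
```

Any replacement optimizer keeps these three signatures and threads its buffers through `state`.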
Training loop (executed by the benchmark)
```python
model.zero_grad()
# Fresh Rademacher noise each step: each entry is +delta or -delta.
noise = delta * (2 * torch.randint(0, 2, y_train.shape) - 1).float()
y_noisy = y_train + noise
loss = 0.5 * torch.mean((model(X_train) - y_noisy) ** 2)
loss.backward()
with torch.no_grad():
    # grad_u / grad_v are the autograd gradients of the loss w.r.t. u and v.
    u_new, v_new, state = step(u, v, grad_u, grad_v, state, hparams)
    model.u.data.copy_(u_new)
    model.v.data.copy_(v_new)
```
Evaluation
Three problem settings are evaluated:
- d200_k5_s01: d=200, k=5, sigma=0.1, delta=0.5
- d500_k10_s01: d=500, k=10, sigma=0.1, delta=0.5
- d500_k10_s02: d=500, k=10, sigma=0.2, delta=0.5
For each setting, the benchmark performs a coarse-to-fine search over training-set sizes n ∈ {50, 75, ..., 1600} to find the smallest n* where recovery succeeds on at least 4 of 5 seeds. Recovery means test MSE < 1.0 at the time training stops.
Metric: score = -log2(n*) per setting (higher is better — fewer samples needed).
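The metric is a direct transform of n*, so halving the required sample size improves the score by exactly one point. A one-function sketch of the formula above:

```python
import math

def score(n_star: int) -> float:
    """Per-setting score: -log2 of the smallest n where recovery succeeds."""
    return -math.log2(n_star)

# Halving n* from 128 to 64 raises the score from -7.0 to -6.0.
```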
Training uses full-batch gradients (with noisy labels) and a shared stopping rule: training halts when both train and test MSE have plateaued (two-window comparison over 20,000 steps), or after 1,000,000 steps.
Baselines (16 configurations)
- SGD (4 configs): lr ∈ {0.005, 0.01, 0.05, 0.1}
- AdaGrad (4 configs): lr ∈ {0.005, 0.01, 0.05, 0.1}, eps = 1e-6
- Adam without bias correction (8 configs): lr ∈ {0.005, 0.01, 0.05, 0.1} × beta2 ∈ {0.95, 0.999}, beta1 = 0.9, eps = 1e-6
Hints
- The diagonal-net parameterization `w = u^2 - v^2` naturally biases gradient descent toward sparse solutions when initialised near zero
- Adaptive methods (Adam, AdaGrad) change the effective geometry of this bias; this can help or hurt
- The initialisation `alpha/sqrt(2d) * ones(d)` with `alpha = 1e-3` means `u = v` at init, so `w_hat = 0` initially
- The Rademacher noise (`delta` parameter) adds stochasticity to training; your optimizer should be robust to it
- Consider how your optimizer interacts with the non-convex structure: coordinate-wise adaptivity, momentum, and learning-rate scheduling all affect the sparsity bias
- All 16 baselines use `eps = 1e-6`, and Adam uses no bias correction
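Since the hints pin down the Adam variant exactly (beta1 = 0.9, eps = 1e-6, no bias correction), a coordinate-wise sketch of one such update on a single parameter vector might look like the following; the function name and argument layout are illustrative, not the benchmark's.

```python
import torch

def adam_no_bias_correction_step(p, grad, m, s, lr,
                                 beta1=0.9, beta2=0.999, eps=1e-6):
    """One Adam step without the usual bias-correction terms.

    m and s are the first- and second-moment buffers; all tensors
    share p's shape.
    """
    m = beta1 * m + (1 - beta1) * grad
    s = beta2 * s + (1 - beta2) * grad ** 2
    p = p - lr * m / (s.sqrt() + eps)  # no division by (1 - beta**t)
    return p, m, s
```

Omitting the bias correction shrinks the effective step size early in training, which interacts with the near-zero initialisation above.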
Code
custom_optimizer.py (editable scaffold):

```python
"""Editable optimizer scaffold for the opt-diagonal-net MLS-Bench task.

Implement a custom optimizer for training a diagonal-net model to recover
a sparse linear predictor. You may edit the three functions below
(get_hyperparameters, init_state, step) while the benchmark harness,
data generation, model, stopping rule, and search protocol are fixed.
"""

from __future__ import annotations

from typing import Any

import torch

from fixed_benchmark import run_cli
```
fixed_benchmark.py (fixed harness):

```python
"""Fixed benchmark harness for diagonal-net sparse recovery.

Evaluates custom optimizers on the problem of recovering a sparse linear
predictor through a diagonal-net parameterization (w_hat = u^L - v^L, L=2).
The benchmark measures the minimum training-set size n* required for
reliable recovery (test MSE < 1.0) using a coarse-to-fine search protocol.

Metric: -log2(n*) (higher is better — fewer training samples needed).

Uses PyTorch with autograd for gradient computation; the optimizer interface
receives torch.Tensor gradients directly.
"""

from __future__ import annotations
```
Results
| Model | Type | n* d200_k5_s01 ↓ | score d200_k5_s01 ↑ | n* d500_k10_s01 ↓ | score d500_k10_s01 ↑ | n* d500_k10_s02 ↓ | score d500_k10_s02 ↑ | n* d10000_k50 ↓ | score d10000_k50 ↑ |
|---|---|---|---|---|---|---|---|---|---|
| adagrad | baseline | 175.000 | -7.451 | 487.000 | -8.928 | 487.000 | -8.928 | 2000.000 | -10.966 |
| adam | baseline | 50.000 | -5.644 | 56.000 | -5.807 | 56.000 | -5.807 | 350.000 | -8.451 |
| adam2 | baseline | 50.000 | -5.644 | 53.000 | -5.728 | 53.000 | -5.728 | 350.000 | -8.451 |
| sgd | baseline | 50.000 | -5.644 | 62.000 | -5.954 | 62.000 | -5.954 | 487.000 | -8.928 |