optimization-variance-reduction
Description
Variance Reduction for Stochastic Optimization
Research Question
Design an improved variance reduction strategy for stochastic gradient descent on finite-sum optimization problems. Your method should accelerate convergence compared to vanilla mini-batch SGD by reducing the variance of gradient estimates.
Background
Many machine learning problems take the form of finite-sum optimization:
min_x F(x) = (1/n) * sum_{i=1}^{n} f_i(x)
Standard SGD uses a stochastic gradient from a random mini-batch, which has variance proportional to 1/b (where b is the batch size). Variance reduction methods use auxiliary information (snapshots, recursive corrections, momentum) to reduce this variance, enabling faster convergence -- often achieving linear convergence rates for strongly convex problems where SGD only achieves sublinear rates.
Key methods in this area include SVRG (periodic full gradient + control variate), SARAH (recursive gradient correction), STORM (momentum-based online variance reduction), SPIDER, and PAGE.
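The control-variate idea these methods share can be illustrated with a toy numpy experiment (illustrative only, not benchmark code). For a scalar finite-sum problem, the SVRG-style estimator `g_i(x) - g_i(x_snap) + F'(x_snap)` is still an unbiased estimate of the full gradient `F'(x)`, but its variance shrinks as the iterate `x` approaches the snapshot `x_snap`:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy finite sum: F(x) = (1/n) * sum_i 0.5 * a_i * (x - c_i)^2, scalar x.
n = 1000
a = rng.uniform(0.5, 2.0, n)
c = rng.normal(0.0, 1.0, n)

def grad_i(x, i):
    # Gradient of the i-th component f_i at x.
    return a[i] * (x - c[i])

x = 3.0       # current iterate
x_snap = 2.5  # snapshot iterate (e.g. taken at the start of the epoch)
full_grad_snap = np.mean(a * (x_snap - c))  # full gradient at the snapshot

# Draw many single-sample gradient estimates at x.
idx = rng.integers(0, n, size=20000)
plain = np.array([grad_i(x, i) for i in idx])
# SVRG control-variate estimator: g_i(x) - g_i(x_snap) + F'(x_snap).
svrg = np.array([grad_i(x, i) - grad_i(x_snap, i) + full_grad_snap
                 for i in idx])

# Both estimators have the same mean (the true F'(x)), but the
# control-variate version has much lower variance.
plain_var, svrg_var = plain.var(), svrg.var()
```

The variance gap is exactly what buys the faster (often linear) convergence rates: near a snapshot, the per-sample correction `g_i(x) - g_i(x_snap)` cancels most of the sampling noise.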
Task
Modify the `VarianceReductionOptimizer` class in `custom_vr.py` (lines 286-370). You must implement:
- `__init__(self, model, lr, l2_reg, loss_type, n_train, batch_size, device)`: Initialize any state needed for variance reduction (snapshot parameters, running gradient estimates, buffers, etc.)
- `train_one_epoch(self, X_train, y_train)`: Train for one epoch over the data, returning a dict with at least `'avg_loss'` (and optionally `'full_grad_count'` if you use full gradient computations)
The default implementation is vanilla mini-batch SGD. Your goal is to design a variance reduction mechanism that improves convergence.
Interface
Available helper functions (FIXED, use these for gradient computation):
```python
compute_full_gradient(model, X_train, y_train, loss_type, l2_reg, device)
# -> returns list of gradient tensors (one per parameter)

compute_stochastic_gradient(model, X_batch, y_batch, loss_type, l2_reg)
# -> returns list of gradient tensors for a mini-batch

compute_loss_on_batch(model, X_batch, y_batch, loss_type, l2_reg)
# -> returns scalar loss tensor
```
Constraints
- You may call `compute_full_gradient` at most once per epoch
- Parameter updates must use `p.data.add_(...)` or similar in-place operations
- Must work across all three problems with the same code
- The learning rate (`self.lr`) and L2 regularization (`self.l2_reg`) are fixed
- Do not modify the model architecture, loss function, or evaluation code
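Within these constraints, an SVRG-flavored epoch might look like the following sketch. This is not the benchmark implementation: the `SVRGSketch` class name is hypothetical, the helpers here are simplified numpy least-squares stand-ins for the benchmark's fixed PyTorch helpers (which take a model and return per-parameter gradient tensors), and the hyperparameters are illustrative.

```python
import numpy as np

# Simplified numpy stand-ins for the benchmark's fixed helpers.
def compute_full_gradient(w, X, y, l2_reg):
    return X.T @ (X @ w - y) / len(y) + l2_reg * w

def compute_stochastic_gradient(w, Xb, yb, l2_reg):
    return Xb.T @ (Xb @ w - yb) / len(yb) + l2_reg * w

class SVRGSketch:
    """SVRG-style epoch: one full gradient per epoch (within the benchmark's
    limit) plus control-variate-corrected mini-batch steps."""

    def __init__(self, w, lr, l2_reg, n_train, batch_size, rng):
        self.w, self.lr, self.l2_reg = w, lr, l2_reg
        self.n_train, self.batch_size, self.rng = n_train, batch_size, rng

    def train_one_epoch(self, X, y):
        # Snapshot the parameters and compute the full gradient once.
        w_snap = self.w.copy()
        mu = compute_full_gradient(w_snap, X, y, self.l2_reg)
        losses = []
        for _ in range(self.n_train // self.batch_size):
            idx = self.rng.integers(0, self.n_train, self.batch_size)
            Xb, yb = X[idx], y[idx]
            g = compute_stochastic_gradient(self.w, Xb, yb, self.l2_reg)
            g_snap = compute_stochastic_gradient(w_snap, Xb, yb, self.l2_reg)
            # Variance-reduced step; in the real harness this would be
            # per-parameter p.data.add_(..., alpha=-self.lr).
            self.w -= self.lr * (g - g_snap + mu)
            losses.append(0.5 * np.mean((Xb @ self.w - yb) ** 2))
        return {'avg_loss': float(np.mean(losses)), 'full_grad_count': 1}
```

On a well-conditioned synthetic least-squares problem this sketch drives the loss toward zero; the design choice worth noting is that the snapshot is refreshed every epoch, so the single allowed `compute_full_gradient` call per epoch is fully used.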
Evaluation
- Problems:
  - `logistic`: L2-regularized multinomial logistic regression on MNIST (convex, n=60K, 20 epochs)
  - `mlp`: 2-layer MLP on CIFAR-10 (non-convex, n=50K, 40 epochs)
  - `conditioned`: L2-regularized linear regression on synthetic ill-conditioned data (strongly convex, kappa=100, n=10K, 30 epochs)
- Metrics: `best_test_accuracy` (logistic, mlp; higher is better) and `best_test_mse` (conditioned; lower is better)
- All three problems run in parallel with shared compute
Code
```python
"""Variance Reduction Benchmark for Finite-Sum Optimization

Evaluates variance reduction strategies for stochastic gradient methods on
finite-sum problems: min_x F(x) = (1/n) * sum_{i=1}^{n} f_i(x)

Benchmarks:
1. logistic -- L2-regularized logistic regression on MNIST (convex)
2. mlp -- 2-layer MLP on CIFAR-10 (non-convex)
3. conditioned -- L2-regularized linear regression on synthetic
   ill-conditioned data (strongly convex)

Usage:
    python opt-vr-bench/custom_vr.py --problem <name> \
        --seed $SEED --output-dir $OUTPUT_DIR
"""
```
Results
| Model | Type | best test accuracy logistic ↑ | final test accuracy logistic ↑ | best test accuracy mlp ↑ | final test accuracy mlp ↑ | best test mse conditioned ↓ | final test mse conditioned ↓ |
|---|---|---|---|---|---|---|---|
| page | baseline | 89.600 | 88.683 | 31.743 | 26.867 | 522.076 | 693.326 |
| spider | baseline | 89.873 | 89.057 | 20.603 | 16.347 | 522.076 | 693.325 |
| spiderboost | baseline | 87.153 | 86.933 | 13.230 | 10.003 | 533.065 | 1243.025 |
| storm | baseline | 92.497 | 92.357 | 53.903 | 52.377 | 1.755 | 2.962 |
| storm_plus | baseline | 92.507 | 92.477 | 54.197 | 52.500 | 0.015 | 0.018 |
| svrg | baseline | 92.633 | 92.617 | 52.293 | 50.370 | 767.018 | 3.610e+34 |
| anthropic/claude-opus-4.6 | agent | 92.633 | 92.603 | 52.553 | 51.413 | 0.581 | 0.723 |
| deepseek-reasoner | agent | 92.517 | 92.237 | 54.127 | 52.760 | 517972.280 | 517972.280 |
| google/gemini-3.1-pro-preview | agent | 84.623 | 81.760 | 10.000 | 10.000 | - | - |
| openai/gpt-5.4-pro | agent | 92.380 | 92.377 | 52.783 | 50.990 | - | - |
| qwen3.6-plus:free | agent | 92.643 | 92.580 | 20.687 | 10.027 | 52.610 | 56.375 |