# optimization-pac-bayes-bound

## Description

**Task:** PAC-Bayes Generalization Bound Optimization

### Research Question
Design a tighter PAC-Bayes generalization bound by optimizing the bound formulation, prior/posterior parameterization, and KL divergence estimation for stochastic neural networks.
### Background

PAC-Bayes theory provides non-vacuous generalization bounds for stochastic classifiers. Given a prior distribution P over hypotheses (chosen before seeing data) and a posterior Q (learned from data), PAC-Bayes bounds certify that, with high probability (at least 1 - delta) over the draw of the training set, the true risk of a stochastic classifier sampled from Q is bounded in terms of its empirical risk and KL(Q||P).
The key components of a PAC-Bayes bound are:
- Empirical risk: estimated loss of the stochastic predictor on training data
- KL divergence: KL(Q||P) measuring complexity of the posterior relative to the prior
- Bound formula: how these terms combine to yield the final certificate
Standard bounds include:
- McAllester/Maurer: risk + sqrt(KL_term / (2n)) -- simple but loose
- Catoni/Lambda: risk/(1 - lam/2) + KL_term/(n * lam * (1 - lam/2)) -- tighter with a tuned lambda
- Quadratic: (sqrt(risk + KL_term) + sqrt(KL_term))^2 -- better at low risk
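The three baselines can be written down directly. A minimal sketch, assuming `KL_term = KL(Q||P) + ln(2*sqrt(n)/delta)` (the usual confidence penalty) and applying the same `KL_term/(2n)` normalization inside the quadratic form as in the McAllester line; the `kind` and `lam` arguments are illustrative, not the template's API:

```python
import math

def compute_bound(empirical_risk, kl, n, delta, kind="mcallester", lam=1.0):
    """Baseline PAC-Bayes bounds; kl_term = KL(Q||P) + ln(2*sqrt(n)/delta)."""
    kl_term = kl + math.log(2.0 * math.sqrt(n) / delta)
    if kind == "mcallester":
        # risk + sqrt(KL_term / (2n)) -- simple but loose
        return empirical_risk + math.sqrt(kl_term / (2.0 * n))
    if kind == "catoni":
        # risk/(1 - lam/2) + KL_term/(n * lam * (1 - lam/2)), lam in (0, 2)
        return (empirical_risk / (1.0 - lam / 2.0)
                + kl_term / (n * lam * (1.0 - lam / 2.0)))
    if kind == "quadratic":
        # (sqrt(risk + KL_term/(2n)) + sqrt(KL_term/(2n)))^2 -- better at low risk
        t = kl_term / (2.0 * n)
        return (math.sqrt(empirical_risk + t) + math.sqrt(t)) ** 2
    raise ValueError(f"unknown bound kind: {kind}")
```

At low empirical risk the quadratic form evaluates below the McAllester form for the same inputs, which is exactly the regime these certificates target.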
The bound can be further tightened through:
- Optimizing the bound functional form (beyond classical inequalities)
- Better training objectives that minimize the bound directly
- Improved risk certificate evaluation (e.g., PAC-Bayes-kl inversion)
- Data-dependent prior construction
- Tighter KL estimation or alternative divergence measures
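Of these, PAC-Bayes-kl inversion is the most mechanical to sketch: invert the binary KL by bisection, then chain two inversions, one for the Monte-Carlo sampling error and one for the bound itself. The sketch below assumes the standard `ln(2*sqrt(n)/delta)` confidence penalty; the helper names are illustrative, not the template's API:

```python
import math

def kl_bin(q, p):
    """KL(Bernoulli(q) || Bernoulli(p)), with the 0*log(0) = 0 convention."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)
    out = 0.0
    if q > 0:
        out += q * math.log(q / p)
    if q < 1:
        out += (1 - q) * math.log((1 - q) / (1 - p))
    return out

def inv_kl(q, c):
    """Largest p >= q with kl_bin(q, p) <= c, found by bisection."""
    lo, hi = q, 1.0 - 1e-12
    for _ in range(100):
        mid = 0.5 * (lo + hi)
        if kl_bin(q, mid) <= c:
            lo = mid  # mid is still feasible: move the lower end up
        else:
            hi = mid
    return lo

def risk_certificate(mc_01_risk, kl, n_bound, delta, mc_samples):
    """Chain two inversions: (1) bound the true stochastic empirical risk
    from its Monte-Carlo estimate; (2) invert the PAC-Bayes-kl bound."""
    emp = inv_kl(mc_01_risk, math.log(2.0 / delta) / mc_samples)
    rhs = (kl + math.log(2.0 * math.sqrt(n_bound) / delta)) / n_bound
    return inv_kl(emp, rhs)
```

The inverted certificate is tighter than the sqrt relaxation used by the McAllester formula, which is why it is the standard way to report the final 0-1 risk bound.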
### What to Implement

Implement the `BoundOptimizer` class in `custom_pac_bayes.py`. You must implement:

- `compute_bound(empirical_risk, kl, n, delta)`: the PAC-Bayes bound formula
- `train_step(model, data, target, device, n_bound, delta)`: the training objective
- `compute_risk_certificate(model, bound_loader, device, delta, mc_samples)`: final certificate evaluation
### Interface

- `model(x, sample=True/False)`: stochastic forward pass (`sample=True`) or posterior-mean pass (`sample=False`)
- `get_total_kl(model)`: sum of the KL divergence across all probabilistic layers
- `inv_kl(q, c)`: binary KL inversion -- find p such that KL(Ber(q)||Ber(p)) = c
- `compute_01_risk(model, loader, device, mc_samples)`: Monte-Carlo estimate of the 0-1 risk
- Available losses: `F.nll_loss` and `F.cross_entropy` on log-softmax outputs
### Evaluation
The bound optimizer is tested on three settings:
- MNIST-FCN: 4-layer fully connected network (784-600-600-600-10) on MNIST
- MNIST-CNN: 4-layer CNN (2 conv + 2 fc) on MNIST
- FashionMNIST-CNN: Same CNN architecture on FashionMNIST
Primary metric: risk_certificate (0-1 loss PAC-Bayes bound) -- lower is better (tighter bound).
Training uses data-dependent priors: 50% of training data trains a deterministic prior, 50% evaluates the bound.
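The split matters for validity: the prior may depend only on its own half of the data, otherwise the certificate evaluated on the other half no longer holds. A minimal sketch of such a partition (the `prior_frac` parameter and seeding convention are assumptions, not the harness's actual interface):

```python
import random

def prior_bound_split(n_total, prior_frac=0.5, seed=0):
    """Partition training indices: the data-dependent prior is trained on
    prior_idx only, so the bound evaluated on bound_idx stays valid."""
    idx = list(range(n_total))
    random.Random(seed).shuffle(idx)
    cut = int(n_total * prior_frac)
    return idx[:cut], idx[cut:]
```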
### Code

```python
"""PAC-Bayes Bound Optimization — custom template.

This script trains a stochastic neural network by minimizing a PAC-Bayes
bound and then evaluates the tightness of the resulting risk certificate.

The agent edits the EDITABLE section (BoundOptimizer class) which controls:

1. How the PAC-Bayes bound is computed from empirical risk + KL divergence
2. How the posterior distribution is optimized (training objective)
3. How the final risk certificate is evaluated

Fixed sections handle data loading, model architecture, stochastic layers,
and the outer training loop.
"""

import argparse
```
Additional context files (read-only):

- PBB/pbb/models.py
- PBB/pbb/utils.py
### Results
| Model | Type | risk certificate mnist-fcn ↓ | test error mnist-fcn ↓ | kl divergence mnist-fcn ↓ | ce bound mnist-fcn ↓ | empirical 01 risk mnist-fcn ↓ | risk certificate mnist-cnn ↓ | test error mnist-cnn ↓ | kl divergence mnist-cnn ↓ | ce bound mnist-cnn ↓ | empirical 01 risk mnist-cnn ↓ | risk certificate fmnist-cnn ↓ | test error fmnist-cnn ↓ | kl divergence fmnist-cnn ↓ | ce bound fmnist-cnn ↓ | empirical 01 risk fmnist-cnn ↓ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| catoni | baseline | 0.064 | 0.021 | 731.767 | 0.232 | 0.017 | 0.028 | 0.009 | 261.040 | 0.103 | 0.009 | 0.134 | 0.088 | 445.063 | 0.524 | 0.078 |
| mcallester | baseline | 0.037 | 0.023 | 71.900 | 0.166 | 0.024 | 0.016 | 0.009 | 10.000 | 0.075 | 0.012 | 0.102 | 0.089 | 38.320 | 0.305 | 0.085 |
| quadratic | baseline | 0.041 | 0.021 | 151.857 | 0.160 | 0.022 | 0.018 | 0.010 | 46.657 | 0.071 | 0.011 | 0.099 | 0.090 | 29.780 | 0.302 | 0.084 |
| anthropic/claude-opus-4.6 | vanilla | 0.039 | 0.022 | 113.570 | 0.161 | 0.023 | 0.017 | 0.010 | 38.200 | 0.068 | 0.011 | 0.100 | 0.088 | 25.410 | 0.305 | 0.086 |
| deepseek-reasoner | vanilla | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | 0.033 | 0.025 | 9.230 | 0.171 | 0.027 | 0.016 | 0.010 | 6.010 | 0.074 | 0.012 | 0.104 | 0.091 | 22.770 | 0.336 | 0.090 |
| openai/gpt-5.4-pro | vanilla | 0.032 | 0.025 | 5.340 | 0.132 | 0.027 | 0.015 | 0.009 | 5.340 | 0.065 | 0.011 | 0.099 | 0.091 | 2.550 | 0.282 | 0.091 |
| qwen3.6-plus:free | vanilla | 0.058 | 0.021 | 610.610 | 0.225 | 0.017 | 0.026 | 0.009 | 235.050 | 0.099 | 0.009 | 0.133 | 0.084 | 449.860 | 0.529 | 0.078 |
| anthropic/claude-opus-4.6 | agent | 0.036 | 0.022 | 76.230 | 0.159 | 0.023 | 0.017 | 0.010 | 28.770 | 0.066 | 0.011 | 0.099 | 0.089 | 23.320 | 0.312 | 0.085 |
| deepseek-reasoner | agent | 0.037 | 0.023 | 71.780 | 0.166 | 0.024 | 0.016 | 0.010 | 10.080 | 0.075 | 0.012 | 0.103 | 0.089 | 44.170 | 0.305 | 0.086 |
| google/gemini-3.1-pro-preview | agent | 0.033 | 0.025 | 9.230 | 0.171 | 0.027 | 0.016 | 0.010 | 6.010 | 0.074 | 0.012 | 0.104 | 0.091 | 22.770 | 0.336 | 0.090 |
| openai/gpt-5.4-pro | agent | 0.032 | 0.025 | 5.340 | 0.132 | 0.027 | 0.015 | 0.009 | 5.340 | 0.065 | 0.011 | 0.099 | 0.091 | 2.550 | 0.282 | 0.091 |
| qwen3.6-plus:free | agent | 0.065 | 0.021 | 745.500 | 0.190 | 0.018 | 0.028 | 0.009 | 265.320 | 0.077 | 0.009 | 0.141 | 0.086 | 545.660 | 0.365 | 0.079 |