optimization-pac-bayes-bound

Optimization · PBB · rigorous codebase

Description

Task: PAC-Bayes Generalization Bound Optimization

Research Question

Design a tighter PAC-Bayes generalization bound by optimizing the bound formulation, prior/posterior parameterization, and KL divergence estimation for stochastic neural networks.

Background

PAC-Bayes theory provides non-vacuous generalization bounds for stochastic classifiers. Given a prior distribution P over hypotheses (chosen before seeing data) and a posterior Q (learned from data), PAC-Bayes bounds certify that, with probability at least 1 - delta over the draw of the training sample, the true risk of a stochastic classifier sampled from Q is bounded in terms of its empirical risk and KL(Q||P).

The key components of a PAC-Bayes bound are:

  • Empirical risk: estimated loss of the stochastic predictor on training data
  • KL divergence: KL(Q||P) measuring complexity of the posterior relative to the prior
  • Bound formula: how these terms combine to yield the final certificate

Standard bounds include (writing KL_term = KL(Q||P) + ln(2*sqrt(n)/delta) for the complexity term):

  • McAllester/Maurer: risk + sqrt(KL_term / (2n)) -- simple but loose
  • Catoni/lambda: risk/(1 - lam/2) + KL_term/(n * lam * (1 - lam/2)) -- tighter with a tuned lam in (0, 2)
  • Quadratic: (sqrt(risk + KL_term/(2n)) + sqrt(KL_term/(2n)))^2 -- tighter than McAllester at low risk
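The three formulas above can be sketched as a single helper. This is a minimal illustration, not the graded implementation; folding the ln(2*sqrt(n)/delta) confidence penalty into the KL term is an assumption mirroring common PAC-Bayes certificates, and the `variant`/`lam` parameters are hypothetical names.

```python
import math

def compute_bound(empirical_risk, kl, n, delta, variant="quadratic", lam=1.0):
    """Sketch of the classical PAC-Bayes bound formulas (illustrative only)."""
    # Complexity term: KL divergence plus the confidence penalty.
    kl_term = kl + math.log(2 * math.sqrt(n) / delta)
    if variant == "mcallester":
        # Simple square-root bound: loose but robust.
        return empirical_risk + math.sqrt(kl_term / (2 * n))
    if variant == "catoni":
        # lam in (0, 2); tighter when lam is tuned to the problem.
        return empirical_risk / (1 - lam / 2) + kl_term / (n * lam * (1 - lam / 2))
    # Quadratic bound: dominates McAllester when empirical risk is small.
    ratio = kl_term / (2 * n)
    return (math.sqrt(empirical_risk + ratio) + math.sqrt(ratio)) ** 2
```

At low empirical risk the quadratic variant yields a strictly smaller certificate than McAllester for the same KL budget, which is why tuned-lambda and quadratic bounds dominate the baselines below.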

The bound can be further tightened through:

  • Optimizing the bound functional form (beyond classical inequalities)
  • Better training objectives that minimize the bound directly
  • Improved risk certificate evaluation (e.g., PAC-Bayes-kl inversion)
  • Data-dependent prior construction
  • Tighter KL estimation or alternative divergence measures
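The PAC-Bayes-kl inversion mentioned above is the main lever for tightening the certificate. A sketch of the binary KL divergence and its inversion by bisection follows; the tolerance and clamping epsilon are arbitrary choices, not values prescribed by the task.

```python
import math

def binary_kl(q, p):
    """KL(Ber(q) || Ber(p)), with arguments clamped away from {0, 1}."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)
    q = min(max(q, eps), 1 - eps)
    return q * math.log(q / p) + (1 - q) * math.log((1 - q) / (1 - p))

def inv_kl(q, c, tol=1e-9):
    """Largest p >= q with KL(Ber(q)||Ber(p)) <= c, found by bisection.

    binary_kl(q, .) is increasing on [q, 1), so bisection converges.
    """
    lo, hi = q, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if binary_kl(q, mid) <= c:
            lo = mid
        else:
            hi = mid
    return lo
```

Because this inversion is exact (up to the bisection tolerance), it is strictly tighter than relaxations such as Pinsker's inequality, which would give p <= q + sqrt(c/2).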

What to Implement

Implement the BoundOptimizer class in custom_pac_bayes.py. You must implement:

  1. compute_bound(empirical_risk, kl, n, delta): The PAC-Bayes bound formula
  2. train_step(model, data, target, device, n_bound, delta): Training objective
  3. compute_risk_certificate(model, bound_loader, device, delta, mc_samples): Final certificate evaluation

Interface

  • model(x, sample=True/False): stochastic forward pass (sample=True) or posterior mean (sample=False)
  • get_total_kl(model): sum of KL divergence across all probabilistic layers
  • inv_kl(q, c): binary KL inversion -- find p such that KL(Ber(q)||Ber(p)) = c
  • compute_01_risk(model, loader, device, mc_samples): MC estimate of 0-1 risk
  • Available losses: F.nll_loss, F.cross_entropy on log_softmax outputs
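A common way to combine the interface pieces into a final 0-1 certificate is a two-step inversion: first bound the Monte Carlo estimation error of `compute_01_risk`, then invert the PAC-Bayes-kl bound itself. The sketch below assumes a bisection-based `inv_kl` as described above; the delta split (`delta` for the bound, `delta_prime` for MC sampling) and the specific confidence penalty are assumptions mirroring common practice, not requirements of the task.

```python
import math

def inv_kl(q, c, tol=1e-9):
    """Largest p >= q with KL(Ber(q)||Ber(p)) <= c (bisection sketch)."""
    def binary_kl(q, p):
        eps = 1e-12
        p = min(max(p, eps), 1 - eps)
        qc = min(max(q, eps), 1 - eps)
        return qc * math.log(qc / p) + (1 - qc) * math.log((1 - qc) / (1 - p))
    lo, hi = q, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        lo, hi = (mid, hi) if binary_kl(q, mid) <= c else (lo, mid)
    return lo

def risk_certificate(mc_risk, kl, n_bound, delta, mc_samples, delta_prime=0.01):
    """Two-step certificate from an MC estimate of the 0-1 risk.

    mc_risk: Monte Carlo estimate over mc_samples posterior draws.
    n_bound: size of the bound split (not the full training set).
    """
    # Step 1: bound the true stochastic risk of Q given finitely many draws.
    risk_ub = inv_kl(mc_risk, math.log(2 / delta_prime) / mc_samples)
    # Step 2: invert the PAC-Bayes-kl bound on the population risk.
    kl_term = (kl + math.log(2 * math.sqrt(n_bound) / delta)) / n_bound
    return inv_kl(risk_ub, kl_term)
```

Note that `n_bound` must count only the bound-evaluation split: once the prior is data-dependent, the samples it was trained on can no longer contribute to n.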

Evaluation

The bound optimizer is tested on three settings:

  1. MNIST-FCN: 4-layer fully connected network (784-600-600-600-10) on MNIST
  2. MNIST-CNN: 4-layer CNN (2 conv + 2 fc) on MNIST
  3. FashionMNIST-CNN: Same CNN architecture on FashionMNIST

Primary metric: risk_certificate (0-1 loss PAC-Bayes bound) -- lower is better (tighter bound).

Training uses data-dependent priors: 50% of training data trains a deterministic prior, 50% evaluates the bound.
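The 50/50 split can be sketched as a disjoint index partition (function and parameter names here are hypothetical; the fixed harness handles the actual data loading):

```python
import random

def split_prior_bound(n_train, prior_frac=0.5, seed=0):
    """Disjoint split: one half trains the data-dependent prior,
    the other half is held out for evaluating the bound."""
    idx = list(range(n_train))
    random.Random(seed).shuffle(idx)
    cut = int(prior_frac * n_train)
    return idx[:cut], idx[cut:]
```

Keeping the two halves disjoint is what makes the prior legitimately "chosen before seeing" the bound-evaluation data, so the PAC-Bayes guarantee still holds.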

Code

custom_pac_bayes.py
"""PAC-Bayes Bound Optimization — custom template.

This script trains a stochastic neural network by minimizing a PAC-Bayes
bound and then evaluates the tightness of the resulting risk certificate.

The agent edits the EDITABLE section (BoundOptimizer class) which controls:
 1. How the PAC-Bayes bound is computed from empirical risk + KL divergence
 2. How the posterior distribution is optimized (training objective)
 3. How the final risk certificate is evaluated

Fixed sections handle data loading, model architecture, stochastic layers,
and the outer training loop.
"""

import argparse

Additional context files (read-only):

  • PBB/pbb/models.py
  • PBB/pbb/utils.py

Results

MNIST-FCN

| Model | Type | risk certificate | test error | KL divergence | CE bound | empirical 0-1 risk |
|---|---|---|---|---|---|---|
| catoni | baseline | 0.064 | 0.021 | 731.767 | 0.232 | 0.017 |
| mcallester | baseline | 0.037 | 0.023 | 71.900 | 0.166 | 0.024 |
| quadratic | baseline | 0.041 | 0.021 | 151.857 | 0.160 | 0.022 |
| anthropic/claude-opus-4.6 | vanilla | 0.039 | 0.022 | 113.570 | 0.161 | 0.023 |
| deepseek-reasoner | vanilla | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | 0.033 | 0.025 | 9.230 | 0.171 | 0.027 |
| openai/gpt-5.4-pro | vanilla | 0.032 | 0.025 | 5.340 | 0.132 | 0.027 |
| qwen3.6-plus:free | vanilla | 0.058 | 0.021 | 610.610 | 0.225 | 0.017 |
| anthropic/claude-opus-4.6 | agent | 0.036 | 0.022 | 76.230 | 0.159 | 0.023 |
| deepseek-reasoner | agent | 0.037 | 0.023 | 71.780 | 0.166 | 0.024 |
| google/gemini-3.1-pro-preview | agent | 0.033 | 0.025 | 9.230 | 0.171 | 0.027 |
| openai/gpt-5.4-pro | agent | 0.032 | 0.025 | 5.340 | 0.132 | 0.027 |
| qwen3.6-plus:free | agent | 0.065 | 0.021 | 745.500 | 0.190 | 0.018 |
| qwen3.6-plus:free | agent | 0.058 | 0.021 | 610.610 | 0.225 | 0.017 |

MNIST-CNN

| Model | Type | risk certificate | test error | KL divergence | CE bound | empirical 0-1 risk |
|---|---|---|---|---|---|---|
| catoni | baseline | 0.028 | 0.009 | 261.040 | 0.103 | 0.009 |
| mcallester | baseline | 0.016 | 0.009 | 10.000 | 0.075 | 0.012 |
| quadratic | baseline | 0.018 | 0.010 | 46.657 | 0.071 | 0.011 |
| anthropic/claude-opus-4.6 | vanilla | 0.017 | 0.010 | 38.200 | 0.068 | 0.011 |
| deepseek-reasoner | vanilla | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | 0.016 | 0.010 | 6.010 | 0.074 | 0.012 |
| openai/gpt-5.4-pro | vanilla | 0.015 | 0.009 | 5.340 | 0.065 | 0.011 |
| qwen3.6-plus:free | vanilla | 0.026 | 0.009 | 235.050 | 0.099 | 0.009 |
| anthropic/claude-opus-4.6 | agent | 0.017 | 0.010 | 28.770 | 0.066 | 0.011 |
| deepseek-reasoner | agent | 0.016 | 0.010 | 10.080 | 0.075 | 0.012 |
| google/gemini-3.1-pro-preview | agent | 0.016 | 0.010 | 6.010 | 0.074 | 0.012 |
| openai/gpt-5.4-pro | agent | 0.015 | 0.009 | 5.340 | 0.065 | 0.011 |
| qwen3.6-plus:free | agent | 0.028 | 0.009 | 265.320 | 0.077 | 0.009 |
| qwen3.6-plus:free | agent | 0.026 | 0.009 | 235.050 | 0.099 | 0.009 |

FashionMNIST-CNN

| Model | Type | risk certificate | test error | KL divergence | CE bound | empirical 0-1 risk |
|---|---|---|---|---|---|---|
| catoni | baseline | 0.134 | 0.088 | 445.063 | 0.524 | 0.078 |
| mcallester | baseline | 0.102 | 0.089 | 38.320 | 0.305 | 0.085 |
| quadratic | baseline | 0.099 | 0.090 | 29.780 | 0.302 | 0.084 |
| anthropic/claude-opus-4.6 | vanilla | 0.100 | 0.088 | 25.410 | 0.305 | 0.086 |
| deepseek-reasoner | vanilla | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | 0.104 | 0.091 | 22.770 | 0.336 | 0.090 |
| openai/gpt-5.4-pro | vanilla | 0.099 | 0.091 | 2.550 | 0.282 | 0.091 |
| qwen3.6-plus:free | vanilla | 0.133 | 0.084 | 449.860 | 0.529 | 0.078 |
| anthropic/claude-opus-4.6 | agent | 0.099 | 0.089 | 23.320 | 0.312 | 0.085 |
| deepseek-reasoner | agent | 0.103 | 0.089 | 44.170 | 0.305 | 0.086 |
| google/gemini-3.1-pro-preview | agent | 0.104 | 0.091 | 22.770 | 0.336 | 0.090 |
| openai/gpt-5.4-pro | agent | 0.099 | 0.091 | 2.550 | 0.282 | 0.091 |
| qwen3.6-plus:free | agent | 0.141 | 0.086 | 545.660 | 0.365 | 0.079 |
| qwen3.6-plus:free | agent | 0.133 | 0.084 | 449.860 | 0.529 | 0.078 |

Agent Conversations