ml-subgroup-calibration-shift

Classical ML, scikit-learn, rigorous codebase

Description

Subgroup Calibration Under Distribution Shift

Research Question

Design a post-hoc calibration method that remains reliable when subgroup composition shifts between calibration and test time.

Background

Many calibration methods look good on average but fail on protected or operational subgroups once the test distribution shifts. This task isolates that failure mode. The fixed pipeline trains a tabular classifier, fits a user-defined calibration mapping on held-out calibration data, and then evaluates the calibrated probabilities on shifted test data.

Classical baselines include:

  • Temperature scaling: one global temperature for all samples
  • Isotonic regression: non-parametric monotone calibration
  • Beta calibration: a richer parametric mapping on probabilities
  • Group-wise temperature scaling: separate temperatures per subgroup

Task

Modify the CalibrationMethod class in custom_subgroup_calibration.py. The fixed code loads data, creates a shifted split, trains the base classifier, and computes metrics. Your method only controls the post-hoc calibration mapping.

class CalibrationMethod:
    def fit(self, probs, labels, groups=None):
        ...

    def predict_proba(self, probs, groups=None):
        ...

Inputs are positive-class probabilities from the base classifier. The groups argument contains subgroup IDs when they are available; group-agnostic methods may ignore it.

Evaluation

This benchmark uses three lightweight tabular proxies that are already available in the current scikit-learn package setup. We would normally prefer Adult, ACSIncome, COMPAS, and Law School Admissions, but those require package-level data changes that are outside this task directory. To keep the benchmark runnable offline, we use cached scikit-learn datasets with similar calibration and subgroup-shift behavior:

  • breast_cancer: binary classification on the scikit-learn breast cancer dataset
  • california_housing: binary high-value/low-value decision built from California housing
  • diabetes: binary high-risk/low-risk decision built from the diabetes target

For each dataset, the split is intentionally shifted:

  • a domain score determines the split: samples in its held-out tail form the shifted test set
  • subgroup labels are quartiles of a separate proxy feature
  • calibration is fit on the source region and evaluated on the shifted region
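A split of this kind can be sketched as follows. The page does not say which domain score or proxy feature is used, how large the tail is, which tail is held out, or how the source region is divided between classifier training and calibration. The choice of the upper tail, the tail_frac and cal_frac values, and the function name shifted_split are therefore illustrative assumptions.

```python
import numpy as np


def shifted_split(domain_score, proxy, tail_frac=0.3, cal_frac=0.25, seed=0):
    """Hypothetical sketch of a tail-based shifted split with quartile subgroups."""
    domain_score = np.asarray(domain_score, dtype=float)
    proxy = np.asarray(proxy, dtype=float)

    # Subgroups: quartile bins (0..3) of a separate proxy feature.
    edges = np.quantile(proxy, [0.25, 0.5, 0.75])
    groups = np.searchsorted(edges, proxy, side="right")

    # Shifted test set: the upper tail of the domain score (assumed tail).
    cutoff = np.quantile(domain_score, 1.0 - tail_frac)
    test_idx = np.flatnonzero(domain_score > cutoff)
    source_idx = np.flatnonzero(domain_score <= cutoff)

    # Source region: random split into classifier-training and calibration parts.
    rng = np.random.default_rng(seed)
    source_idx = rng.permutation(source_idx)
    n_cal = int(round(cal_frac * len(source_idx)))
    cal_idx, train_idx = source_idx[:n_cal], source_idx[n_cal:]
    return train_idx, cal_idx, test_idx, groups
```

When the proxy feature is correlated with the domain score, subgroup proportions in the test tail differ from those in the calibration set. That is the composition shift the task targets.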

Metrics

Lower is better for:

  • worst_group_ece
  • brier
  • max_subgroup_gap

Higher is better for:

  • subgroup_auroc
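The page names these metrics but does not define them. The sketch below adopts common definitions as assumptions:

  • ECE uses 10 equal-width bins on the positive-class probability.
  • worst_group_ece is the largest per-subgroup ECE.
  • brier is the mean squared error between probabilities and 0/1 labels.
  • max_subgroup_gap is taken, as a guess, to be the spread between the largest and smallest per-subgroup ECE.

subgroup_auroc is left out because the page does not state how per-group AUROCs are aggregated.

```python
import numpy as np


def ece(probs, labels, n_bins=10):
    """Expected calibration error with equal-width bins (assumed binning)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    bins = np.minimum((probs * n_bins).astype(int), n_bins - 1)
    total = 0.0
    for b in range(n_bins):
        m = bins == b
        if m.any():
            # bin weight * |mean predicted probability - observed positive rate|
            total += m.mean() * abs(probs[m].mean() - labels[m].mean())
    return float(total)


def brier(probs, labels):
    return float(np.mean((np.asarray(probs, dtype=float) - np.asarray(labels, dtype=float)) ** 2))


def subgroup_metrics(probs, labels, groups, n_bins=10):
    probs, labels, groups = map(np.asarray, (probs, labels, groups))
    per_group = [ece(probs[groups == g], labels[groups == g], n_bins) for g in np.unique(groups)]
    return {
        "worst_group_ece": max(per_group),
        "brier": brier(probs, labels),
        "max_subgroup_gap": max(per_group) - min(per_group),  # assumed definition
    }
```

For example, probabilities [0.9, 0.9, 0.1, 0.1] with labels [1, 0, 0, 0] give ece = 0.25 and brier = 0.21.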

Notes

  • The task is deliberately low compute and should run with a small tabular classifier.
  • If you need the exact Adult/ACSIncome/COMPAS/Law School datasets, they should be added through a package-level data change, not inside this task directory.

Code

custom_subgroup_calibration.py
"""Subgroup calibration under distribution shift.

The benchmark is intentionally offline and low compute. It uses cached
scikit-learn tabular proxies instead of downloading Adult/ACSIncome/COMPAS/
Law School because this task directory cannot change package-level data setup.

Fixed:
- dataset loading
- shifted train/calibration/test split
- base classifier training
- metric computation

Editable:
- CalibrationMethod
"""

Results

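breast_cancer

Model                         | Type     | worst_group_ece | brier | subgroup_auroc | max_subgroup_gap
beta_calibration              | baseline | 0.185           | 0.112 | 0.985          | 0.166
group_temperature_scaling     | baseline | 0.338           | 0.179 | 0.960          | 0.330
isotonic_regression           | baseline | 0.233           | 0.129 | 0.975          | 0.217
temperature_scaling           | baseline | 0.349           | 0.181 | 0.941          | 0.341
anthropic/claude-opus-4.6     | vanilla  | 0.360           | 0.180 | 0.956          | 0.352
deepseek-reasoner             | vanilla  | 0.349           | 0.181 | 0.941          | 0.341
google/gemini-3.1-pro-preview | vanilla  | 0.320           | 0.155 | 0.989          | 0.314
qwen/qwen3.6-plus             | vanilla  | 0.097           | 0.049 | 0.985          | 0.085
anthropic/claude-opus-4.6     | agent    | 0.180           | 0.123 | 0.985          | 0.124
deepseek-reasoner             | agent    | 0.230           | 0.131 | 0.985          | 0.207
google/gemini-3.1-pro-preview | agent    | 0.320           | 0.155 | 0.989          | 0.314
openai/gpt-5.4                | agent    | -               | -     | -              | -
qwen/qwen3.6-plus             | agent    | 0.097           | 0.049 | 0.985          | 0.085

diabetes

Model                         | Type     | worst_group_ece | brier | subgroup_auroc | max_subgroup_gap
beta_calibration              | baseline | 0.145           | 0.160 | 0.765          | 0.072
group_temperature_scaling     | baseline | 0.169           | 0.171 | 0.765          | 0.062
isotonic_regression           | baseline | 0.164           | 0.162 | 0.770          | 0.073
temperature_scaling           | baseline | 0.131           | 0.163 | 0.765          | 0.041
anthropic/claude-opus-4.6     | vanilla  | 0.154           | 0.162 | 0.765          | 0.062
deepseek-reasoner             | vanilla  | 0.194           | 0.169 | 0.765          | 0.087
google/gemini-3.1-pro-preview | vanilla  | 0.113           | 0.160 | 0.765          | 0.044
qwen/qwen3.6-plus             | vanilla  | 0.154           | 0.168 | 0.765          | 0.048
anthropic/claude-opus-4.6     | agent    | 0.143           | 0.162 | 0.765          | 0.064
deepseek-reasoner             | agent    | 0.178           | 0.160 | 0.765          | 0.082
google/gemini-3.1-pro-preview | agent    | 0.113           | 0.160 | 0.765          | 0.044
openai/gpt-5.4                | agent    | -               | -     | -              | -
qwen/qwen3.6-plus             | agent    | 0.154           | 0.168 | 0.765          | 0.048

california_housing

Model                         | Type     | worst_group_ece | brier | subgroup_auroc | max_subgroup_gap
beta_calibration              | baseline | 0.379           | 0.323 | 0.991          | 0.124
group_temperature_scaling     | baseline | 0.377           | 0.311 | 0.991          | 0.103
isotonic_regression           | baseline | 0.380           | 0.330 | 0.900          | 0.093
temperature_scaling           | baseline | 0.371           | 0.310 | 0.991          | 0.107
anthropic/claude-opus-4.6     | vanilla  | 0.374           | 0.309 | 0.991          | 0.133
deepseek-reasoner             | vanilla  | 0.377           | 0.311 | 0.991          | 0.104
google/gemini-3.1-pro-preview | vanilla  | 0.375           | 0.315 | 0.991          | 0.108
qwen/qwen3.6-plus             | vanilla  | 0.401           | 0.284 | 0.991          | 0.286
anthropic/claude-opus-4.6     | agent    | 0.373           | 0.316 | 0.991          | 0.117
deepseek-reasoner             | agent    | 0.376           | 0.312 | 0.991          | 0.129
google/gemini-3.1-pro-preview | agent    | 0.375           | 0.315 | 0.991          | 0.108
openai/gpt-5.4                | agent    | -               | -     | -              | -
qwen/qwen3.6-plus             | agent    | 0.401           | 0.284 | 0.991          | 0.286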

Agent Conversations