ml-selective-deferral
Description
Selective Deferral Under Subgroup Shift
Research Question
Design a practical selective prediction and deferral policy for high-stakes tabular decisions.
The task isolates one modular question: given a fixed base classifier, what acceptance / deferral rule best trades off selective risk, subgroup fairness, and overall discrimination?
Background
Selective prediction systems should be able to say "I do not know" when the classifier is uncertain. In high-stakes settings, that deferral can be handed to a human reviewer or a slower backup process. The benchmark studies whether a policy can:
- keep selective risk low at a fixed target coverage,
- avoid concentrating deferrals on one subgroup,
- preserve AUROC as a confidence ranking signal, and
- remain simple enough to run offline on modest compute.
Task
Modify the `SelectivePolicy` class in `custom_selective.py`. The rest of the pipeline is fixed: dataset loading, train / calibration / test splitting, base model training, and metric computation.
The policy receives calibration-time base-model probabilities and subgroup labels, then decides whether each test example should be accepted or deferred. You may implement a single global threshold, a learned deferral score, subgroup-specific thresholds, or any other compact policy that fits the interface.
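The exact interface is defined in `custom_selective.py`; as a rough illustration of the simplest admissible policy (the method names and signatures below are illustrative, not the benchmark's actual API), a single global confidence threshold tuned on calibration data could look like:

```python
class SelectivePolicy:
    """Sketch: single global confidence threshold tuned to hit a
    target coverage on the calibration split.

    Hypothetical interface: fit() sees calibration-time base-model
    probabilities (P(class 1)) and optional subgroup labels;
    accept() returns True for test examples the policy keeps.
    """

    def __init__(self, target_coverage=0.8):
        self.target_coverage = target_coverage
        self.threshold = 0.0

    def fit(self, calib_probs, calib_groups=None):
        # Confidence of a binary prediction: max(p, 1 - p).
        conf = sorted(max(p, 1.0 - p) for p in calib_probs)
        # Keep the most-confident target_coverage fraction: the
        # threshold is the confidence at the deferral cut point.
        cut = int((1.0 - self.target_coverage) * len(conf))
        self.threshold = conf[cut] if conf else 0.0
        return self

    def accept(self, test_probs):
        return [max(p, 1.0 - p) >= self.threshold for p in test_probs]
```

Subgroup-specific thresholds or a learned deferral score would slot into the same shape; only `fit` and `accept` change.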
Evaluation
The benchmark runs on four offline tabular proxies from scikit-learn:
- `breast_cancer`
- `diabetes` (binarized around the training-set median)
- `california` (binarized around the training-set median)
- `madelon`
Each dataset is split into train / calibration / test partitions. Subgroups are formed from a stable feature threshold so that worst-group behavior can be measured.
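The benchmark fixes which feature defines the subgroups; as an illustration of the general recipe (split on one stable feature at its training-set median, with the function name here purely illustrative):

```python
def make_subgroups(feature_values, threshold=None):
    """Assign binary subgroup labels by thresholding one stable feature.

    If no threshold is given, fall back to the median of the provided
    values; in the benchmark the threshold is fixed from the training
    split so that calibration and test subgroups are comparable.
    """
    if threshold is None:
        vals = sorted(feature_values)
        mid = len(vals) // 2
        threshold = vals[mid] if len(vals) % 2 else 0.5 * (vals[mid - 1] + vals[mid])
    return [int(v >= threshold) for v in feature_values]
```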
Metrics:
- `selective_risk_at80`: classification error on accepted examples at 80% target coverage
- `coverage_at80`: realized coverage (fraction of test examples accepted) at the 80% target
- `worst_group_selective_risk`: worst subgroup error on accepted examples
- `deferral_rate_gap`: max subgroup deferral rate minus min subgroup deferral rate
- `auroc`: AUROC of the acceptance score for predicting correctness
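Metric computation is fixed in the pipeline, but the definitions above amount to roughly the following (a sketch; the function and variable names are illustrative):

```python
def selective_metrics(accepted, correct, groups):
    """Compute selective risk and subgroup deferral metrics.

    accepted[i]: the policy kept example i; correct[i]: the base
    model was right on it; groups[i]: its subgroup label.
    """
    kept = [i for i, a in enumerate(accepted) if a]
    # Selective risk: error rate measured on accepted examples only.
    selective_risk = sum(not correct[i] for i in kept) / max(len(kept), 1)

    deferral_rates, group_risks = [], []
    for g in set(groups):
        members = [i for i, gi in enumerate(groups) if gi == g]
        kept_g = [i for i in members if accepted[i]]
        deferral_rates.append(1.0 - len(kept_g) / len(members))
        if kept_g:
            group_risks.append(sum(not correct[i] for i in kept_g) / len(kept_g))
    return {
        "selective_risk": selective_risk,
        "worst_group_selective_risk": max(group_risks) if group_risks else 0.0,
        "deferral_rate_gap": max(deferral_rates) - min(deferral_rates),
    }
```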
Baselines
- `confidence_thresholding`: tune one confidence threshold to hit the target coverage
- `conformal_abstention`: split-conformal abstention with a coverage target
- `learned_deferral`: train a compact meta-model that predicts whether the base model will be correct
- `groupwise_thresholding`: subgroup-specific thresholds as a stronger reference baseline
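The last baseline generalizes global confidence thresholding by tuning one cut per subgroup. A minimal sketch of that tuning step, assuming per-example binary confidences as before (names here are illustrative):

```python
def groupwise_thresholds(calib_conf, calib_groups, target_coverage=0.8):
    """Pick one confidence threshold per subgroup so that each group
    independently hits the target coverage on calibration data.

    Per-group thresholds can shrink the deferral-rate gap, at the cost
    of sometimes accepting lower-confidence examples in weaker groups.
    """
    thresholds = {}
    for g in set(calib_groups):
        conf = sorted(c for c, gi in zip(calib_conf, calib_groups) if gi == g)
        # Defer the least-confident (1 - target_coverage) fraction.
        cut = int((1.0 - target_coverage) * len(conf))
        thresholds[g] = conf[cut]
    return thresholds
```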
Practical Notes
This task intentionally uses datasets that are already available offline in the repository's scikit-learn package setup, so no shared package edit is required.
If you want the exact adult / ACSIncome / COMPAS / Law School datasets from the prompt, that would require a shared package-data follow-up, which I have not implemented here.
Code
```python
"""Selective prediction / deferral benchmark.

Fixed:
- offline dataset loading
- train / calibration / test splits
- base classifier training
- metric computation

Editable:
- SelectivePolicy, which decides whether to accept or defer predictions
  based on calibration outputs.
"""

from __future__ import annotations
```
Results
| Model | Type | selective risk at80 breast cancer ↓ | coverage at80 breast cancer ↑ | worst group selective risk breast cancer ↓ | deferral rate gap breast cancer ↓ | auroc breast cancer ↑ | selective risk at80 diabetes ↓ | coverage at80 diabetes ↑ | worst group selective risk diabetes ↓ | deferral rate gap diabetes ↓ | auroc diabetes ↑ | selective risk at80 california ↓ | coverage at80 california ↑ | worst group selective risk california ↓ | deferral rate gap california ↓ | auroc california ↑ | selective risk at80 madelon ↓ | coverage at80 madelon ↑ | worst group selective risk madelon ↓ | deferral rate gap madelon ↓ | auroc madelon ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| confidence_thresholding | baseline | 0.000 | 0.711 | 0.000 | 0.063 | 0.955 | 0.192 | 0.876 | 0.211 | 0.093 | 0.757 | 0.121 | 0.797 | 0.132 | 0.098 | 0.789 | 0.465 | 0.823 | 0.487 | 0.042 | 0.495 |
| conformal_abstention | baseline | 0.000 | 0.711 | 0.000 | 0.063 | 0.955 | 0.192 | 0.876 | 0.211 | 0.093 | 0.757 | 0.121 | 0.797 | 0.132 | 0.097 | 0.789 | 0.464 | 0.825 | 0.487 | 0.046 | 0.495 |
| groupwise_thresholding | baseline | 0.000 | 0.719 | 0.000 | 0.273 | 0.955 | 0.203 | 0.888 | 0.231 | 0.118 | 0.757 | 0.125 | 0.804 | 0.151 | 0.014 | 0.789 | 0.467 | 0.815 | 0.493 | 0.064 | 0.495 |
| learned_deferral | baseline | 0.000 | 0.728 | 0.000 | 0.137 | 0.964 | 0.164 | 0.820 | 0.176 | 0.017 | 0.807 | 0.122 | 0.797 | 0.139 | 0.050 | 0.788 | 0.458 | 0.781 | 0.472 | 0.487 | 0.493 |
| anthropic/claude-opus-4.6 | vanilla | 0.012 | 0.737 | 0.019 | 0.233 | 0.977 | 0.155 | 0.798 | 0.222 | 0.149 | 0.792 | 0.124 | 0.806 | 0.150 | 0.028 | 0.799 | 0.460 | 0.787 | 0.474 | 0.039 | 0.501 |
| deepseek-reasoner | vanilla | 0.010 | 0.921 | 0.019 | 0.143 | 0.852 | 0.209 | 0.966 | 0.211 | 0.073 | 0.702 | 0.124 | 0.812 | 0.130 | 0.140 | 0.787 | 0.459 | 0.956 | 0.472 | 0.098 | 0.501 |
| google/gemini-3.1-pro-preview | vanilla | 0.000 | 0.719 | 0.000 | 0.118 | 0.948 | 0.210 | 0.910 | 0.231 | 0.076 | 0.759 | 0.118 | 0.803 | 0.134 | 0.045 | 0.816 | 0.455 | 0.971 | 0.466 | 0.006 | 0.571 |
| openai/gpt-5.4 | vanilla | 0.000 | 0.754 | 0.000 | 0.054 | 0.959 | 0.195 | 0.865 | 0.250 | 0.205 | 0.804 | 0.125 | 0.804 | 0.147 | 0.012 | 0.786 | 0.459 | 0.817 | 0.463 | 0.181 | 0.491 |
| qwen/qwen3.6-plus | vanilla | 0.000 | 0.596 | 0.000 | 0.512 | 0.959 | 0.209 | 0.966 | 0.250 | 0.017 | 0.798 | 0.171 | 0.969 | 0.203 | 0.064 | 0.789 | 0.458 | 1.000 | 0.472 | 0.000 | 0.488 |
| anthropic/claude-opus-4.6 | agent | 0.012 | 0.737 | 0.019 | 0.233 | 0.977 | 0.155 | 0.798 | 0.222 | 0.149 | 0.792 | 0.124 | 0.806 | 0.150 | 0.028 | 0.799 | 0.460 | 0.787 | 0.474 | 0.039 | 0.501 |
| deepseek-reasoner | agent | 0.010 | 0.921 | 0.019 | 0.143 | 0.852 | 0.209 | 0.966 | 0.211 | 0.073 | 0.702 | 0.124 | 0.812 | 0.130 | 0.140 | 0.787 | 0.459 | 0.956 | 0.472 | 0.098 | 0.501 |
| google/gemini-3.1-pro-preview | agent | 0.000 | 0.737 | 0.000 | 0.127 | 0.968 | 0.160 | 0.843 | 0.222 | 0.066 | 0.821 | 0.123 | 0.808 | 0.150 | 0.004 | 0.799 | 0.436 | 0.790 | 0.452 | 0.039 | 0.543 |
| openai/gpt-5.4 | agent | 0.000 | 0.754 | 0.000 | 0.054 | 0.959 | 0.195 | 0.865 | 0.250 | 0.205 | 0.804 | 0.125 | 0.804 | 0.147 | 0.012 | 0.786 | 0.459 | 0.817 | 0.463 | 0.181 | 0.491 |
| qwen/qwen3.6-plus | agent | 0.000 | 0.719 | 0.000 | 0.237 | 0.968 | 0.179 | 0.876 | 0.222 | 0.003 | 0.802 | 0.126 | 0.801 | 0.151 | 0.009 | 0.789 | 0.460 | 0.823 | 0.475 | 0.036 | 0.494 |