ml-selective-deferral

Classical ML · scikit-learn · rigorous codebase

Description

Selective Deferral Under Subgroup Shift

Research Question

Design a practical selective prediction and deferral policy for high-stakes tabular decisions.

The task isolates one modular question: given a fixed base classifier, what acceptance / deferral rule best trades off selective risk, subgroup fairness, and overall discrimination (how well the acceptance score ranks correct above incorrect predictions)?

Background

Selective prediction systems should be able to say "I do not know" when the classifier is uncertain. In high-stakes settings, that deferral can be handed to a human reviewer or a slower backup process. The benchmark studies whether a policy can:

  • keep selective risk low at a fixed target coverage,
  • avoid concentrating deferrals on one subgroup,
  • preserve AUROC as a confidence ranking signal, and
  • remain simple enough to run offline on modest compute.

Task

Modify the SelectivePolicy class in custom_selective.py. The rest of the pipeline is fixed: dataset loading, train / calibration / test splitting, base model training, and metric computation.

The policy receives calibration-time base-model probabilities and subgroup labels, then decides whether each test example should be accepted or deferred. You may implement a single global threshold, a learned deferral score, subgroup-specific thresholds, or any other compact policy that fits the interface.
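The exact interface is defined in custom_selective.py; as an illustration only, a minimal global-threshold policy might look like the sketch below. The class shape, method names, and array layouts here are assumptions, not the benchmark's actual signatures.

```python
import numpy as np


class SelectivePolicy:
    """Sketch of a global-threshold policy: accept when the base model's
    top-class probability clears a cutoff tuned on calibration data.

    (Illustrative interface; the benchmark's real signatures may differ.)
    """

    def __init__(self, target_coverage: float = 0.8):
        self.target_coverage = target_coverage
        self.threshold_ = None

    def fit(self, calib_proba: np.ndarray, calib_groups: np.ndarray) -> "SelectivePolicy":
        # Confidence = probability assigned to the predicted class.
        conf = calib_proba.max(axis=1)
        # Choose the cutoff so that roughly `target_coverage` of the
        # calibration examples would be accepted.
        self.threshold_ = np.quantile(conf, 1.0 - self.target_coverage)
        return self

    def accept(self, test_proba: np.ndarray, test_groups: np.ndarray) -> np.ndarray:
        # Boolean mask: True = accept the base prediction, False = defer.
        return test_proba.max(axis=1) >= self.threshold_
```

Subgroup-specific thresholds or a learned deferral score would slot into the same fit/accept shape; only the cutoff computation changes.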

Evaluation

The benchmark runs on four offline tabular proxies from scikit-learn:

  • breast_cancer
  • diabetes (binarized around the training-set median)
  • california (binarized around the training-set median)
  • madelon
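For the regression datasets, binarizing around the training-set median can be sketched as follows; the split details and random seed are fixed in the harness, so the specifics below are assumptions for illustration. The key point is that the cut comes from the training targets only, so no test-set information leaks into the labels.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

# Threshold computed on TRAINING targets only, then applied to both splits.
cut = np.median(y_tr)
y_tr_bin = (y_tr > cut).astype(int)
y_te_bin = (y_te > cut).astype(int)
```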

Each dataset is split into train / calibration / test partitions. Subgroups are formed from a stable feature threshold so that worst-group behavior can be measured.
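A "stable feature threshold" subgroup can be sketched as a median split on a single input feature. Which feature the harness actually uses is fixed in the benchmark code; feature 0 ("mean radius" for breast_cancer) is an assumption here.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

# Illustrative subgroup construction: median split on one feature.
cut = np.median(X[:, 0])
groups = (X[:, 0] > cut).astype(int)  # 0 = below median, 1 = above
```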

Metrics:

  • selective_risk_at80: classification error on accepted examples at 80% target coverage
  • coverage_at80: realized fraction of test examples accepted under the tuned policy
  • worst_group_selective_risk: worst subgroup error on accepted examples
  • deferral_rate_gap: max subgroup deferral rate minus min subgroup deferral rate
  • auroc: AUROC of the acceptance score for predicting base-model correctness
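These metrics can be computed from an acceptance mask roughly as below. This is a sketch of the definitions, not the harness's exact code; the function name and argument layout are assumptions.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def selective_metrics(y_true, y_pred, conf, accept, groups):
    """Sketch of the benchmark metrics.

    conf   -- acceptance score (e.g. max predicted probability)
    accept -- boolean mask over test examples (True = accepted)
    """
    correct = (y_pred == y_true)
    out = {}
    # Error restricted to accepted examples.
    out["selective_risk"] = 1.0 - correct[accept].mean()
    # Worst subgroup error among accepted examples.
    out["worst_group_selective_risk"] = max(
        1.0 - correct[accept & (groups == g)].mean()
        for g in np.unique(groups)
        if (accept & (groups == g)).any()
    )
    # Spread of deferral rates across subgroups.
    defer_rates = [(~accept[groups == g]).mean() for g in np.unique(groups)]
    out["deferral_rate_gap"] = max(defer_rates) - min(defer_rates)
    # How well the acceptance score ranks correct above incorrect predictions.
    out["auroc"] = roc_auc_score(correct, conf)
    return out
```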

Baselines

  • confidence_thresholding: tune one confidence threshold to hit the target coverage
  • conformal_abstention: split-conformal abstention with a coverage target
  • learned_deferral: train a compact meta-model that predicts whether the base model will be correct
  • groupwise_thresholding: subgroup-specific thresholds as a stronger reference baseline
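The groupwise_thresholding idea, sketched under assumed names: give each subgroup its own confidence cutoff so every group is accepted at roughly the same rate, which directly shrinks the deferral rate gap.

```python
import numpy as np


def groupwise_thresholds(calib_conf, calib_groups, target_coverage=0.8):
    """Per-subgroup confidence cutoffs tuned on calibration data (a sketch)."""
    return {
        g: np.quantile(calib_conf[calib_groups == g], 1.0 - target_coverage)
        for g in np.unique(calib_groups)
    }


def accept_groupwise(conf, groups, thresholds):
    # Each example is compared against its own subgroup's cutoff.
    thr = np.array([thresholds[g] for g in groups])
    return conf >= thr
```

The trade-off relative to a single global threshold: equalized deferral rates can raise selective risk if one subgroup's confidences are systematically less reliable.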

Practical Notes

This task intentionally uses datasets that are already available offline in the repository's scikit-learn package setup, so no shared package edit is required.

If you want the exact adult / ACSIncome / COMPAS / Law School datasets from the prompt, that would require a shared package-data follow-up, which has not been implemented here.

Code

custom_selective.py
"""Selective prediction / deferral benchmark.

Fixed:
- offline dataset loading
- train / calibration / test splits
- base classifier training
- metric computation

Editable:
- SelectivePolicy, which decides whether to accept or defer predictions
  based on calibration outputs.
"""

from __future__ import annotations

Results

Per-dataset columns: selective_risk_at80, coverage_at80, worst_group_selective_risk, deferral_rate_gap, auroc.

breast_cancer

| model | type | sel. risk @80 | coverage @80 | worst-group risk | deferral gap | AUROC |
|---|---|---|---|---|---|---|
| confidence_thresholding | baseline | 0.000 | 0.711 | 0.000 | 0.063 | 0.955 |
| conformal_abstention | baseline | 0.000 | 0.711 | 0.000 | 0.063 | 0.955 |
| groupwise_thresholding | baseline | 0.000 | 0.719 | 0.000 | 0.273 | 0.955 |
| learned_deferral | baseline | 0.000 | 0.728 | 0.000 | 0.137 | 0.964 |
| anthropic/claude-opus-4.6 | vanilla | 0.012 | 0.737 | 0.019 | 0.233 | 0.977 |
| deepseek-reasoner | vanilla | 0.010 | 0.921 | 0.019 | 0.143 | 0.852 |
| google/gemini-3.1-pro-preview | vanilla | 0.000 | 0.719 | 0.000 | 0.118 | 0.948 |
| openai/gpt-5.4 | vanilla | 0.000 | 0.754 | 0.000 | 0.054 | 0.959 |
| qwen/qwen3.6-plus | vanilla | 0.000 | 0.596 | 0.000 | 0.512 | 0.959 |
| anthropic/claude-opus-4.6 | agent | 0.012 | 0.737 | 0.019 | 0.233 | 0.977 |
| deepseek-reasoner | agent | 0.010 | 0.921 | 0.019 | 0.143 | 0.852 |
| google/gemini-3.1-pro-preview | agent | 0.000 | 0.737 | 0.000 | 0.127 | 0.968 |
| openai/gpt-5.4 | agent | 0.000 | 0.754 | 0.000 | 0.054 | 0.959 |
| qwen/qwen3.6-plus | agent | 0.000 | 0.719 | 0.000 | 0.237 | 0.968 |

diabetes

| model | type | sel. risk @80 | coverage @80 | worst-group risk | deferral gap | AUROC |
|---|---|---|---|---|---|---|
| confidence_thresholding | baseline | 0.192 | 0.876 | 0.211 | 0.093 | 0.757 |
| conformal_abstention | baseline | 0.192 | 0.876 | 0.211 | 0.093 | 0.757 |
| groupwise_thresholding | baseline | 0.203 | 0.888 | 0.231 | 0.118 | 0.757 |
| learned_deferral | baseline | 0.164 | 0.820 | 0.176 | 0.017 | 0.807 |
| anthropic/claude-opus-4.6 | vanilla | 0.155 | 0.798 | 0.222 | 0.149 | 0.792 |
| deepseek-reasoner | vanilla | 0.209 | 0.966 | 0.211 | 0.073 | 0.702 |
| google/gemini-3.1-pro-preview | vanilla | 0.210 | 0.910 | 0.231 | 0.076 | 0.759 |
| openai/gpt-5.4 | vanilla | 0.195 | 0.865 | 0.250 | 0.205 | 0.804 |
| qwen/qwen3.6-plus | vanilla | 0.209 | 0.966 | 0.250 | 0.017 | 0.798 |
| anthropic/claude-opus-4.6 | agent | 0.155 | 0.798 | 0.222 | 0.149 | 0.792 |
| deepseek-reasoner | agent | 0.209 | 0.966 | 0.211 | 0.073 | 0.702 |
| google/gemini-3.1-pro-preview | agent | 0.160 | 0.843 | 0.222 | 0.066 | 0.821 |
| openai/gpt-5.4 | agent | 0.195 | 0.865 | 0.250 | 0.205 | 0.804 |
| qwen/qwen3.6-plus | agent | 0.179 | 0.876 | 0.222 | 0.003 | 0.802 |

california

| model | type | sel. risk @80 | coverage @80 | worst-group risk | deferral gap | AUROC |
|---|---|---|---|---|---|---|
| confidence_thresholding | baseline | 0.121 | 0.797 | 0.132 | 0.098 | 0.789 |
| conformal_abstention | baseline | 0.121 | 0.797 | 0.132 | 0.097 | 0.789 |
| groupwise_thresholding | baseline | 0.125 | 0.804 | 0.151 | 0.014 | 0.789 |
| learned_deferral | baseline | 0.122 | 0.797 | 0.139 | 0.050 | 0.788 |
| anthropic/claude-opus-4.6 | vanilla | 0.124 | 0.806 | 0.150 | 0.028 | 0.799 |
| deepseek-reasoner | vanilla | 0.124 | 0.812 | 0.130 | 0.140 | 0.787 |
| google/gemini-3.1-pro-preview | vanilla | 0.118 | 0.803 | 0.134 | 0.045 | 0.816 |
| openai/gpt-5.4 | vanilla | 0.125 | 0.804 | 0.147 | 0.012 | 0.786 |
| qwen/qwen3.6-plus | vanilla | 0.171 | 0.969 | 0.203 | 0.064 | 0.789 |
| anthropic/claude-opus-4.6 | agent | 0.124 | 0.806 | 0.150 | 0.028 | 0.799 |
| deepseek-reasoner | agent | 0.124 | 0.812 | 0.130 | 0.140 | 0.787 |
| google/gemini-3.1-pro-preview | agent | 0.123 | 0.808 | 0.150 | 0.004 | 0.799 |
| openai/gpt-5.4 | agent | 0.125 | 0.804 | 0.147 | 0.012 | 0.786 |
| qwen/qwen3.6-plus | agent | 0.126 | 0.801 | 0.151 | 0.009 | 0.789 |

madelon

| model | type | sel. risk @80 | coverage @80 | worst-group risk | deferral gap | AUROC |
|---|---|---|---|---|---|---|
| confidence_thresholding | baseline | 0.465 | 0.823 | 0.487 | 0.042 | 0.495 |
| conformal_abstention | baseline | 0.464 | 0.825 | 0.487 | 0.046 | 0.495 |
| groupwise_thresholding | baseline | 0.467 | 0.815 | 0.493 | 0.064 | 0.495 |
| learned_deferral | baseline | 0.458 | 0.781 | 0.472 | 0.487 | 0.493 |
| anthropic/claude-opus-4.6 | vanilla | 0.460 | 0.787 | 0.474 | 0.039 | 0.501 |
| deepseek-reasoner | vanilla | 0.459 | 0.956 | 0.472 | 0.098 | 0.501 |
| google/gemini-3.1-pro-preview | vanilla | 0.455 | 0.971 | 0.466 | 0.006 | 0.571 |
| openai/gpt-5.4 | vanilla | 0.459 | 0.817 | 0.463 | 0.181 | 0.491 |
| qwen/qwen3.6-plus | vanilla | 0.458 | 1.000 | 0.472 | 0.000 | 0.488 |
| anthropic/claude-opus-4.6 | agent | 0.460 | 0.787 | 0.474 | 0.039 | 0.501 |
| deepseek-reasoner | agent | 0.459 | 0.956 | 0.472 | 0.098 | 0.501 |
| google/gemini-3.1-pro-preview | agent | 0.436 | 0.790 | 0.452 | 0.039 | 0.543 |
| openai/gpt-5.4 | agent | 0.459 | 0.817 | 0.463 | 0.181 | 0.491 |
| qwen/qwen3.6-plus | agent | 0.460 | 0.823 | 0.475 | 0.036 | 0.494 |
