Agent Conversation: deepseek-reasoner
Task: ml-subgroup-calibration-shift
Subgroup Calibration Under Distribution Shift
Research Question
Design a post-hoc calibration method that remains reliable when subgroup composition shifts between calibration and test time.
Background
Many calibration methods look good on average but fail on protected or operational subgroups once the test distribution shifts. This task isolates that failure mode. The fixed pipeline trains a tabular classifier, then applies a user-defined calibration mapping on held-out calibration data before evaluation on shifted test data.
Classical baselines include:
- Temperature scaling: one global temperature for all samples
- Isotonic regression: non-parametric monotone calibration
- Beta calibration: a richer, three-parameter mapping on probabilities
- Group-wise temperature scaling: separate temperatures per subgroup
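As a rough illustration of the first and last baselines, the sketch below fits a single temperature by minimizing negative log-likelihood over log T, then fits one independent temperature per subgroup. It assumes NumPy and SciPy. The helper names (fit_temperature, fit_group_temperatures, apply_temperature), the log T search range of [-3, 3] and the synthetic data are illustrative choices, not part of the benchmark.

```python
import numpy as np
from scipy import optimize, special

EPS = 1e-6  # keeps logit() finite when p is exactly 0 or 1


def fit_temperature(probs, labels):
    """Temperature scaling: find T minimizing the NLL of sigmoid(logit(p) / T)."""
    logits = special.logit(np.clip(probs, EPS, 1 - EPS))
    labels = np.asarray(labels, dtype=float)

    def nll(log_t):
        p = np.clip(special.expit(logits / np.exp(log_t)), EPS, 1 - EPS)
        return -np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p))

    # Searching over log T keeps the temperature strictly positive.
    res = optimize.minimize_scalar(nll, bounds=(-3.0, 3.0), method="bounded")
    return float(np.exp(res.x))


def fit_group_temperatures(probs, labels, groups):
    """Group-wise temperature scaling: one independent T per subgroup ID."""
    return {int(g): fit_temperature(probs[groups == g], labels[groups == g])
            for g in np.unique(groups)}


def apply_temperature(probs, temp):
    return special.expit(special.logit(np.clip(probs, EPS, 1 - EPS)) / temp)


# Synthetic, overconfident scores: the true logits are doubled, so T should land near 2.
rng = np.random.default_rng(0)
true_p = rng.uniform(0.05, 0.95, size=2000)
labels = rng.binomial(1, true_p)
probs = special.expit(2.0 * special.logit(true_p))
groups = rng.integers(0, 4, size=2000)

global_t = fit_temperature(probs, labels)
group_t = fit_group_temperatures(probs, labels, groups)
print(round(global_t, 2), {g: round(t, 2) for g, t in group_t.items()})
```

Each per-group temperature is estimated from a fraction of the calibration data, so it is noisier than the global one; that variance is what the shrinkage-based attempts later in this log try to control.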
Task
Modify the CalibrationMethod class in custom_subgroup_calibration.py. The fixed code loads data, creates a shifted split, trains the base classifier, and computes metrics. Your method only controls the post-hoc calibration mapping.
class CalibrationMethod:
    def fit(self, probs, labels, groups=None):
        ...
    def predict_proba(self, probs, groups=None):
        ...
Inputs are positive-class probabilities from the base classifier. groups contains subgroup IDs when available and may be ignored by group-agnostic methods.
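For a group-agnostic method, satisfying this interface can be as simple as wrapping a one-dimensional fit and ignoring groups. The sketch below does that with scikit-learn's IsotonicRegression (the isotonic baseline listed above). The eps clipping range and the synthetic data are assumptions made for illustration; this is not the benchmark's reference implementation.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression


class CalibrationMethod:
    """Isotonic baseline that accepts but ignores `groups`."""

    def __init__(self, eps=1e-6):
        self.eps = eps
        # y_min/y_max keep outputs inside (0, 1); out_of_bounds="clip" clamps
        # test scores that fall outside the range seen during calibration.
        self.iso_ = IsotonicRegression(y_min=eps, y_max=1.0 - eps, out_of_bounds="clip")

    def fit(self, probs, labels, groups=None):
        probs = np.asarray(probs, dtype=float).reshape(-1)
        labels = np.asarray(labels, dtype=float).reshape(-1)
        self.iso_.fit(probs, labels)
        return self

    def predict_proba(self, probs, groups=None):
        probs = np.asarray(probs, dtype=float).reshape(-1)
        return self.iso_.predict(probs)


rng = np.random.default_rng(1)
p_cal = rng.uniform(size=500)
y_cal = rng.binomial(1, p_cal ** 2)  # base scores overestimate the positive rate
method = CalibrationMethod().fit(p_cal, y_cal, groups=rng.integers(0, 4, size=500))
calibrated = method.predict_proba(p_cal)
```

Because the mapping is shared across groups, it cannot correct a miscalibration that differs by subgroup; it only fixes the average reliability curve.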
Evaluation
This benchmark uses three lightweight tabular proxies that are already available in the current scikit-learn package setup. We would normally prefer Adult, ACSIncome, COMPAS, and Law School Admissions, but those require package-level data changes that are outside this task directory. To keep the benchmark runnable offline, we use cached scikit-learn datasets with similar calibration and subgroup-shift behavior:
- breast_cancer: binary classification on the scikit-learn breast cancer dataset
- california_housing: binary high-value/low-value decision built from California housing
- diabetes: binary high-risk/low-risk decision built from the diabetes target
For each dataset, the split is intentionally shifted:
- a domain score determines the held-out test tail
- subgroup labels are quartiles of a separate proxy feature
- calibration is fit on the source region and evaluated on the shifted region
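The fixed pipeline's split code is not shown in this log. As a hypothetical sketch of the kind of split described above, the function below holds out the upper tail of a domain score as the shifted test set and assigns subgroups by quartile of a separate proxy feature. The function name, the 0.75 cutoff and computing quartile edges on all samples are assumptions; in the benchmark the source region would also supply the training and calibration data.

```python
import numpy as np


def make_shifted_split(domain_score, proxy_feature, test_quantile=0.75):
    """Hypothetical tail split: samples above the test_quantile of domain_score form the test set."""
    domain_score = np.asarray(domain_score, dtype=float)
    proxy_feature = np.asarray(proxy_feature, dtype=float)
    is_test = domain_score > np.quantile(domain_score, test_quantile)
    # Subgroup IDs 0..3 are the quartile of the proxy feature (edges from all samples).
    edges = np.quantile(proxy_feature, [0.25, 0.5, 0.75])
    groups = np.digitize(proxy_feature, edges)
    return ~is_test, is_test, groups


rng = np.random.default_rng(0)
score = rng.normal(size=1000)
proxy = score + 0.5 * rng.normal(size=1000)  # proxy correlated with the domain score
source, test, groups = make_shifted_split(score, proxy)

# Because the proxy tracks the domain score, the top quartile is over-represented
# in the shifted test tail: subgroup composition differs between the two regions.
share_src = np.bincount(groups[source], minlength=4) / source.sum()
share_test = np.bincount(groups[test], minlength=4) / test.sum()
print(np.round(share_src, 2), np.round(share_test, 2))
```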
Metrics
Lower is better for:
- worst_group_ece
- brier
- max_subgroup_gap
Higher is better for:
- subgroup_auroc
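The log does not show how the fixed code computes these metrics. The sketch below assumes the common equal-width binned ECE (10 bins), evaluates it per subgroup to get a worst-group value, and reports the Brier score. Treating max_subgroup_gap as the spread between the worst and best subgroup ECE is an assumption, as are the function names; subgroup_auroc is left out because its aggregation across groups is not specified here.

```python
import numpy as np


def ece(probs, labels, n_bins=10):
    """Binned expected calibration error with equal-width bins (bin count assumed)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Bin index 0..n_bins-1; a probability of exactly 1.0 goes into the last bin.
    bins = np.minimum((probs * n_bins).astype(int), n_bins - 1)
    total = 0.0
    for b in range(n_bins):
        in_bin = bins == b
        if in_bin.any():
            # Weight each bin's |mean confidence - observed rate| by its share of samples.
            total += in_bin.mean() * abs(probs[in_bin].mean() - labels[in_bin].mean())
    return float(total)


def brier(probs, labels):
    return float(np.mean((np.asarray(probs, dtype=float) - np.asarray(labels, dtype=float)) ** 2))


def subgroup_report(probs, labels, groups, n_bins=10):
    probs, labels, groups = map(np.asarray, (probs, labels, groups))
    per_group = [ece(probs[groups == g], labels[groups == g], n_bins) for g in np.unique(groups)]
    return {
        "worst_group_ece": max(per_group),
        "brier": brier(probs, labels),
        # Assumed definition: spread between the worst and best subgroup ECE.
        "max_subgroup_gap": max(per_group) - min(per_group),
    }


report = subgroup_report([0.2, 0.2, 0.8, 0.8], [0, 1, 1, 1], [0, 0, 1, 1])
print(report)  # worst_group_ece ~0.3, brier ~0.19, max_subgroup_gap ~0.1
```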
Notes
- The task is deliberately low compute and should run with a small tabular classifier.
- If you need the exact Adult/ACSIncome/COMPAS/Law School datasets, they should be added through a package-level data change, not inside this task directory.
scikit-learn/custom_subgroup_calibration.py [EDITABLE — lines 72–91 only]
1: """Subgroup calibration under distribution shift.
2:
3: The benchmark is intentionally offline and low compute. It uses
Tool Input
{
  "op": "replace",
  "filename": "scikit-learn/custom_subgroup_calibration.py",
  "start_line": 72,
  "end_line": 91,
  "content": "class CalibrationMethod:\n    \"\"\"Regularized Group Temperature Scaling (RGTS).\n\n    Fits a global temperature, then regularized group-specific temperatures\n    with shrinkage towards the global prior. The regularization strength\n    adapts to group size to prevent overfitting.\n    \"\"\"\n\n    def __init__(self):\n        self.eps = 1e-6\n        self.global_temp_ = 1.0\n        self.group_temps_ = {}\n        # regularization hyperparameter base\n        self.reg_strength = 0.1  # fixed base strength\n        self.min_group_size = 8\n\n    def _fit_temperature(self, probs, labels, prior_temp=1.0, reg_weight=0.0):\n        probs = np.asarray(probs).reshape(-1)\n        labels = np.asarray(labels).reshape(-1).astype(int)\n        logits = special.logit(np.clip(probs, self.eps, 1.0 - self.eps))\n\n        def objective(log_t):\n            t = float(np.exp(log_t))\n            cal = special.expit(logits / t)\n            p = np.clip(cal, self.eps, 1.0 - self.eps)\n            nll = -np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p))\n            # regularization penalty: squared log deviation from prior\n            penalty = reg_weight * (log_t - np.log(prior_temp))**2\n            return float(nll + penalty)\n\n        result = optimize.minimize_scalar(objective, bounds=(-3.0, 3.0), method=\"bounded\")\n        return float(np.exp(result.x)) if result.success else prior_temp\n\n    def fit(self, probs, labels, groups=None):\n        probs = np.asarray(probs).reshape(-1)\n        labels = np.asarray(labels).reshape(-1).astype(int)\n        # global temperature without regularization\n        self.global_temp_ = self._fit_temperature(probs, labels, prior_temp=1.0, reg_weight=0.0)\n\n        self.group_temps_ = {}\n        if groups is None:\n            return self\n\n        groups = np.asarray(groups).reshape(-1)\n        unique_groups = np.unique(groups)\n        for g in unique_groups:\n            mask = groups == g\n            g_size = mask.sum()\n            if g_size < self.min_group_size or np.unique(labels[mask]).size < 2:\n                # insufficient data, fallback to global\n                self.group_temps_[int(g)] = self.global_temp_\n            else:\n                # regularization weight: base strength divided by sqrt(group size)\n                # larger groups get less regularization\n                reg_weight = self.reg_strength / np.sqrt(g_size)\n                temp = self._fit_temperature(\n                    probs[mask], labels[mask],\n                    prior_temp=self.global_temp_,\n                    reg_weight=reg_weight\n                )\n                self.group_temps_[int(g)] = temp\n        return self\n\n    def predict_proba(self, probs, groups=None):\n        probs = np.asarray(probs).reshape(-1)\n        logits = special.logit(np.clip(probs, self.eps, 1.0 - self.eps))\n        if groups is None:\n            temp = self.global_temp_\n            return np.clip(special.expit(logits / temp), self.eps, 1.0 - self.eps)\n\n        groups = np.asarray(groups).reshape(-1)\n        out = np.empty_like(probs)\n        for g in np.unique(groups):\n            mask = groups == g\n            temp = self.group_temps_.get(int(g), self.global_temp_)\n            out[mask] = special.expit(logits[mask] / temp)\n        return np.clip(out, self.eps, 1.0 - self.eps)"
}
OK: Replaced lines 72..91 with 77 line(s) in scikit-learn/custom_subgroup_calibration.py. Editable range: 72–148.
[Current file: scikit-learn/custom_subgroup_calibration.py | editable: 72–148 | total: 289 lines]
72 class CalibrationMethod:
73     """Regularized Group Temperature Scaling (RGTS).
74
... (71 more lines) ...
146             temp = self.group_temps_.get(int(g), self.global_temp_)
147             out[mask] = special.expit(logits[mask] / temp)
148         return np.clip(out, self.eps, 1.0 - self.eps)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### breast_cancer (scripts/breast_cancer.sh)
Training progress:
TRAIN_METRICS: dataset=breast_cancer cal_ece_before=0.014909 cal_ece_after=0.000001 cal_brier_before=0.001191 cal_brier_after=0.000000
Final metrics (breast_cancer): worst_group_ece_breast_cancer=0.348658, brier_breast_cancer=0.181076, subgroup_auroc_breast_cancer=0.940760, max_subgroup_gap_breast_cancer=0.341088
### diabetes (scripts/diabetes.sh)
Training progress:
TRAIN_METRICS: dataset=diabetes cal_ece_before=0.141277 cal_ece_after=0.051129 cal_brier_before=0.154074 cal_brier_after=0.149920
Final metrics (diabetes): worst_group_ece_diabetes=0.194176, brier_diabetes=0.168511, subgroup_auroc_diabetes=0.765085, max_subgroup_gap_diabetes=0.086521
[Leaderboard] Results saved: {'worst_group_ece_breast_cancer': 0.348658, 'brier_breast_cancer': 0.181076, 'subgroup_auroc_breast_cancer': 0.94076, 'max_subgroup_gap_breast_cancer': 0.3
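Read directly from the RGTS code above, the first attempt solves two one-dimensional problems. Here p_i is the base probability clipped to [ε, 1 − ε], z_i its logit, σ the sigmoid, ℓ the binary cross-entropy, λ = reg_strength = 0.1, and the search runs over u = log T in [−3, 3]. Groups with fewer than 8 calibration samples or a single label class skip the second problem and reuse the global temperature.

```latex
\begin{aligned}
\ell(y, q) &= -\bigl[y \log q + (1 - y)\log(1 - q)\bigr], \qquad z_i = \operatorname{logit}(p_i),\\
\hat u &= \arg\min_{u \in [-3,\,3]} \frac{1}{n}\sum_{i=1}^{n} \ell\bigl(y_i, \sigma(z_i e^{-u})\bigr), \qquad \hat T = e^{\hat u},\\
u_g &= \arg\min_{u \in [-3,\,3]} \frac{1}{n_g}\sum_{i \in g} \ell\bigl(y_i, \sigma(z_i e^{-u})\bigr) + \frac{\lambda}{\sqrt{n_g}}\bigl(u - \log \hat T\bigr)^2, \qquad T_g = e^{u_g}.
\end{aligned}
```

Multiplying the group objective by n_g turns it into a summed NLL plus λ√n_g (u − log T̂)². The data term grows like n_g while the penalty grows only like √n_g, so the pull toward the global temperature weakens as groups get larger, which is what the "larger groups get less regularization" comment in the code describes.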
Tool Input
{
  "op": "replace",
  "filename": "scikit-learn/custom_subgroup_calibration.py",
  "start_line": 72,
  "end_line": 148,
  "content": "class CalibrationMethod:\n    \"\"\"Group-Adaptive Platt Scaling with Hierarchical Shrinkage (GAPS).\n\n    Fits global Platt scaling (slope + intercept) via logistic regression on logits.\n    For each group, fits group-specific Platt parameters regularized towards\n    the global parameters with strength inversely proportional to group size.\n    \"\"\"\n\n    def __init__(self):\n        self.eps = 1e-6\n        self.global_w_ = 1.0  # slope\n        self.global_b_ = 0.0  # intercept\n        self.group_params_ = {}  # dict g -> (w, b)\n        self.reg_lambda = 1.0  # base regularization strength\n        self.min_group_size = 8\n\n    def _platt_transform(self, probs, w, b):\n        \"\"\"Apply Platt scaling: sigmoid(w * logit(p) + b).\"\"\"\n        p = np.clip(probs, self.eps, 1.0 - self.eps)\n        logits = special.logit(p)\n        return special.expit(w * logits + b)\n\n    def _fit_platt(self, probs, labels, prior_w=1.0, prior_b=0.0, reg_weight=0.0):\n        \"\"\"MAP estimation of Platt parameters with L2 regularization towards prior.\"\"\"\n        probs = np.asarray(probs).reshape(-1)\n        labels = np.asarray(labels).reshape(-1).astype(int)\n        p = np.clip(probs, self.eps, 1.0 - self.eps)\n        logits = special.logit(p)\n        # feature matrix: logits (slope) and intercept\n        X = np.column_stack([logits, np.ones_like(logits)])\n        # negative log-likelihood with L2 penalty\n        def objective(params):\n            w, b = params[0], params[1]\n            z = w * logits + b\n            # stable sigmoid\n            p_cal = special.expit(z)\n            p_cal = np.clip(p_cal, self.eps, 1.0 - self.eps)\n            nll = -np.mean(labels * np.log(p_cal) + (1 - labels) * np.log(1 - p_cal))\n            # L2 penalty on deviation from prior\n            penalty = reg_weight * ((w - prior_w)**2 + (b - prior_b)**2)\n            return nll + penalty\n        # initial guess from prior\n        init = np.array([prior_w, prior_b])\n        bounds = [(-5.0, 5.0), (-5.0, 5.0)]  # reasonable bounds\n        result = optimize.minimize(objective, init, method='L-BFGS-B', bounds=bounds)\n        if result.success:\n            w, b = result.x\n            return float(w), float(b)\n        else:\n            return prior_w, prior_b\n\n    def fit(self, probs, labels, groups=None):\n        probs = np.asarray(probs).reshape(-1)\n        labels = np.asarray(labels).reshape(-1).astype(int)\n        # global Platt scaling (no regularization)\n        self.global_w_, self.global_b_ = self._fit_platt(\n            probs, labels, prior_w=1.0, prior_b=0.0, reg_weight=0.0\n        )\n\n        self.group_params_ = {}\n        if groups is None:\n            return self\n\n        groups = np.asarray(groups).reshape(-1)\n        unique_groups = np.unique(groups)\n        for g in unique_groups:\n            mask = groups == g\n            g_size = mask.sum()\n            if g_size < self.min_group_size or np.unique(labels[mask]).size < 2:\n                self.group_params_[int(g)] = (self.global_w_, self.global_b_)\n            else:\n                # regularization weight: base lambda divided by group size\n                reg_weight = self.reg_lambda / g_size\n                w, b = self._fit_platt(\n                    probs[mask], labels[mask],\n                    prior_w=self.global_w_,\n                    prior_b=self.global_b_,\n                    reg_weight=reg_weight\n                )\n                self.group_params_[int(g)] = (w, b)\n        return self\n\n    def predict_proba(self, probs, groups=None):\n        probs = np.asarray(probs).reshape(-1)\n        if groups is None:\n            w, b = self.global_w_, self.global_b_\n            return np.clip(self._platt_transform(probs, w, b), self.eps, 1.0 - self.eps)\n\n        groups = np.asarray(groups).reshape(-1)\n        out = np.empty_like(probs)\n        for g in np.unique(groups):\n            mask = groups == g\n            w, b = self.group_params_.get(int(g), (self.global_w_, self.global_b_))\n            out[mask] = self._platt_transform(probs[mask], w, b)\n        return np.clip(out, self.eps, 1.0 - self.eps)"
}
OK: Replaced lines 72..148 with 95 line(s) in scikit-learn/custom_subgroup_calibration.py. Editable range: 72–166.
[Current file: scikit-learn/custom_subgroup_calibration.py | editable: 72–166 | total: 307 lines]
72 class CalibrationMethod:
73     """Group-Adaptive Platt Scaling with Hierarchical Shrinkage (GAPS).
74
... (89 more lines) ...
164             w, b = self.group_params_.get(int(g), (self.global_w_, self.global_b_))
165             out[mask] = self._platt_transform(probs[mask], w, b)
166         return np.clip(out, self.eps, 1.0 - self.eps)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### breast_cancer (scripts/breast_cancer.sh)
Training progress:
TRAIN_METRICS: dataset=breast_cancer cal_ece_before=0.014909 cal_ece_after=0.000007 cal_brier_before=0.001191 cal_brier_after=0.000000
Final metrics (breast_cancer): worst_group_ece_breast_cancer=0.320289, brier_breast_cancer=0.155423, subgroup_auroc_breast_cancer=0.988741, max_subgroup_gap_breast_cancer=0.313956
### diabetes (scripts/diabetes.sh)
Training progress:
TRAIN_METRICS: dataset=diabetes cal_ece_before=0.141277 cal_ece_after=0.122354 cal_brier_before=0.154074 cal_brier_after=0.145969
Final metrics (diabetes): worst_group_ece_diabetes=0.169147, brier_diabetes=0.160340, subgroup_auroc_diabetes=0.765085, max_subgroup_gap_diabetes=0.080683
[Leaderboard] Results saved: {'worst_group_ece_breast_cancer': 0.320289, 'brier_breast_cancer': 0.155423, 'subgroup_auroc_breast_cancer': 0.988741, 'max_subgroup_gap_breast_cancer': 0.3
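The second attempt (GAPS) replaces the single temperature with a slope w and an intercept b on the logit. Reading the code above, with z_i the clipped logit of the base probability and ℓ the binary cross-entropy: the global pair is fit without a penalty, each group is shrunk toward it with weight λ/n_g (λ = reg_lambda = 1.0), and both parameters are box-bounded to [−5, 5].

```latex
\begin{aligned}
(\hat w, \hat b) &= \arg\min_{w,\,b \in [-5,\,5]} \frac{1}{n}\sum_{i=1}^{n} \ell\bigl(y_i, \sigma(w z_i + b)\bigr),\\
(w_g, b_g) &= \arg\min_{w,\,b \in [-5,\,5]} \frac{1}{n_g}\sum_{i \in g} \ell\bigl(y_i, \sigma(w z_i + b)\bigr) + \frac{\lambda}{n_g}\Bigl[(w - \hat w)^2 + (b - \hat b)^2\Bigr].
\end{aligned}
```

Multiplying the group objective by n_g gives the summed NLL plus λ[(w − ŵ)² + (b − b̂)²], which is MAP estimation under a fixed Gaussian prior centred on (ŵ, b̂) with variance 1/(2λ) per parameter, matching the "MAP estimation" wording in the _fit_platt docstring. Temperature scaling is the special case w = 1/T, b = 0.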
Tool Input
{
  "op": "replace",
  "filename": "scikit-learn/custom_subgroup_calibration.py",
  "start_line": 72,
  "end_line": 166,
  "content": "class CalibrationMethod:\n    \"\"\"Robust Group Platt Scaling with Strong Regularization (RGPS).\n\n    Platt scaling with slope and intercept per group, regularized towards\n    identity mapping (slope=1, intercept=0). Regularization strength inversely\n    proportional to group size. Parameters bounded to avoid extreme scaling.\n    \"\"\"\n\n    def __init__(self):\n        self.eps = 1e-6\n        self.global_w_ = 1.0\n        self.global_b_ = 0.0\n        self.group_params_ = {}\n        self.reg_lambda = 2.0  # base regularization strength\n        self.min_group_size = 8\n        self.bound = 3.0  # max absolute deviation for slope and intercept\n\n    def _platt_transform(self, probs, w, b):\n        p = np.clip(probs, self.eps, 1.0 - self.eps)\n        logits = special.logit(p)\n        return special.expit(w * logits + b)\n\n    def _fit_platt(self, probs, labels, prior_w=1.0, prior_b=0.0, reg_weight=0.0):\n        probs = np.asarray(probs).reshape(-1)\n        labels = np.asarray(labels).reshape(-1).astype(int)\n        p = np.clip(probs, self.eps, 1.0 - self.eps)\n        logits = special.logit(p)\n        # negative log-likelihood with L2 regularization\n        def objective(params):\n            w, b = params[0], params[1]\n            # enforce bounds via penalty (soft constraint)\n            bound_penalty = 0.0\n            if abs(w - prior_w) > self.bound:\n                bound_penalty += 10.0 * (abs(w - prior_w) - self.bound)**2\n            if abs(b - prior_b) > self.bound:\n                bound_penalty += 10.0 * (abs(b - prior_b) - self.bound)**2\n            z = w * logits + b\n            p_cal = special.expit(z)\n            p_cal = np.clip(p_cal, self.eps, 1.0 - self.eps)\n            nll = -np.mean(labels * np.log(p_cal) + (1 - labels) * np.log(1 - p_cal))\n            reg = reg_weight * ((w - prior_w)**2 + (b - prior_b)**2)\n            return nll + reg + bound_penalty\n        init = np.array([prior_w, prior_b])\n        # hard bounds\n        w_lower = prior_w - self.bound\n        w_upper = prior_w + self.bound\n        b_lower = prior_b - self.bound\n        b_upper = prior_b + self.bound\n        bounds = [(w_lower, w_upper), (b_lower, b_upper)]\n        result = optimize.minimize(objective, init, method='L-BFGS-B', bounds=bounds)\n        if result.success:\n            w, b = result.x\n            return float(w), float(b)\n        else:\n            return prior_w, prior_b\n\n    def fit(self, probs, labels, groups=None):\n        probs = np.asarray(probs).reshape(-1)\n        labels = np.asarray(labels).reshape(-1).astype(int)\n        # global Platt scaling with mild regularization\n        self.global_w_, self.global_b_ = self._fit_platt(\n            probs, labels, prior_w=1.0, prior_b=0.0, reg_weight=0.1\n        )\n\n        self.group_params_ = {}\n        if groups is None:\n            return self\n\n        groups = np.asarray(groups).reshape(-1)\n        unique_groups = np.unique(groups)\n        for g in unique_groups:\n            mask = groups == g\n            g_size = mask.sum()\n            if g_size < self.min_group_size or np.unique(labels[mask]).size < 2:\n                self.group_params_[int(g)] = (self.global_w_, self.global_b_)\n            else:\n                # regularization weight: base lambda / group_size\n                reg_weight = self.reg_lambda / g_size\n                w, b = self._fit_platt(\n                    probs[mask], labels[mask],\n                    prior_w=self.global_w_,\n                    prior_b=self.global_b_,\n                    reg_weight=reg_weight\n                )\n                self.group_params_[int(g)] = (w, b)\n        return self\n\n    def predict_proba(self, probs, groups=None):\n        probs = np.asarray(probs).reshape(-1)\n        if groups is None:\n            w, b = self.global_w_, self.global_b_\n            return np.clip(self._platt_transform(probs, w, b), self.eps, 1.0 - self.eps)\n\n        groups = np.asarray(groups).reshape(-1)\n        out = np.empty_like(probs)\n        for g in np.unique(groups):\n            mask = groups == g\n            w, b = self.group_params_.get(int(g), (self.global_w_, self.global_b_))\n            out[mask] = self._platt_transform(probs[mask], w, b)\n        return np.clip(out, self.eps, 1.0 - self.eps)"
}
OK: Replaced lines 72..166 with 100 line(s) in scikit-learn/custom_subgroup_calibration.py. Editable range: 72–171.
[Current file: scikit-learn/custom_subgroup_calibration.py | editable: 72–171 | total: 312 lines]
72 class CalibrationMethod:
73     """Robust Group Platt Scaling with Strong Regularization (RGPS).
74
... (94 more lines) ...
169             w, b = self.group_params_.get(int(g), (self.global_w_, self.global_b_))
170             out[mask] = self._platt_transform(probs[mask], w, b)
171         return np.clip(out, self.eps, 1.0 - self.eps)
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.
### breast_cancer (scripts/breast_cancer.sh)
Training progress:
TRAIN_METRICS: dataset=breast_cancer cal_ece_before=0.014909 cal_ece_after=0.007607 cal_brier_before=0.001191 cal_brier_after=0.000419
Final metrics (breast_cancer): worst_group_ece_breast_cancer=0.230231, brier_breast_cancer=0.131412, subgroup_auroc_breast_cancer=0.984643, max_subgroup_gap_breast_cancer=0.207263
### diabetes (scripts/diabetes.sh)
Training progress:
TRAIN_METRICS: dataset=diabetes cal_ece_before=0.141277 cal_ece_after=0.099501 cal_brier_before=0.154074 cal_brier_after=0.147987
Final metrics (diabetes): worst_group_ece_diabetes=0.178284, brier_diabetes=0.160154, subgroup_auroc_diabetes=0.765085, max_subgroup_gap_diabetes=0.082443
### california_housing (scripts/california_housing.sh)
Traini
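The third attempt (RGPS) keeps the GAPS structure but raises reg_lambda to 2.0, regularizes the global fit toward the identity map (w = 1, b = 0) with weight 0.1, and bounds parameters to within ±3 of their prior. In the code, groups are shrunk toward the global Platt parameters rather than directly toward the identity. The hypothetical helper below re-creates that penalized group fit outside the class to show the effect of the λ/n_g weight on one small group. The soft bound_penalty term is omitted for brevity because the hard L-BFGS-B box is also applied; the synthetic data and the stand-in global parameters are illustrative assumptions.

```python
import numpy as np
from scipy import optimize, special

EPS = 1e-6


def fit_penalized_platt(probs, labels, prior=(1.0, 0.0), reg_weight=0.0, bound=3.0):
    """Hypothetical helper: Platt scaling sigmoid(w * logit(p) + b) with an L2 pull
    toward `prior`, box-bounded to prior +/- bound as in RGPS's hard bounds."""
    z = special.logit(np.clip(probs, EPS, 1 - EPS))
    y = np.asarray(labels, dtype=float)
    prior = np.asarray(prior, dtype=float)

    def objective(theta):
        q = np.clip(special.expit(theta[0] * z + theta[1]), EPS, 1 - EPS)
        nll = -np.mean(y * np.log(q) + (1 - y) * np.log(1 - q))
        return nll + reg_weight * np.sum((theta - prior) ** 2)

    box = [(prior[0] - bound, prior[0] + bound), (prior[1] - bound, prior[1] + bound)]
    res = optimize.minimize(objective, prior.copy(), method="L-BFGS-B", bounds=box)
    return res.x


rng = np.random.default_rng(0)
true_p = rng.uniform(0.05, 0.95, size=16)  # one small subgroup
y_small = rng.binomial(1, true_p)
p_small = special.expit(2.5 * special.logit(true_p) + 0.5)  # overconfident, shifted scores

global_wb = np.array([0.6, 0.0])  # stand-in for a fitted global (w, b); illustrative only
free_wb = fit_penalized_platt(p_small, y_small, prior=global_wb, reg_weight=0.0)
shrunk_wb = fit_penalized_platt(p_small, y_small, prior=global_wb, reg_weight=2.0 / 16)

# The penalized solution sits no farther from the global parameters than the free one.
print(np.linalg.norm(free_wb - global_wb), np.linalg.norm(shrunk_wb - global_wb))
```

Because the weight is 2.0 / n_g, a 16-sample group pays 25 times more per unit of squared deviation than a 400-sample group would, which is what keeps sparse groups close to the global fit.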
Tool Input
{
  "n": 3
}
[submit] Submitting result from test #3 as final.
### breast_cancer (scripts/breast_cancer.sh)
Training progress:
TRAIN_METRICS: dataset=breast_cancer cal_ece_before=0.014909 cal_ece_after=0.007607 cal_brier_before=0.001191 cal_brier_after=0.000419
Final metrics (breast_cancer): worst_group_ece_breast_cancer=0.230231, brier_breast_cancer=0.131412, subgroup_auroc_breast_cancer=0.984643, max_subgroup_gap_breast_cancer=0.207263
### diabetes (scripts/diabetes.sh)
Training progress:
TRAIN_METRICS: dataset=diabetes cal_ece_before=0.141277 cal_ece_after=0.099501 cal_brier_before=0.154074 cal_brier_after=0.147987
Final metrics (diabetes): worst_group_ece_diabetes=0.178284, brier_diabetes=0.160154, subgroup_auroc_diabetes=0.765085, max_subgroup_gap_diabetes=0.082443
### california_housing (scripts/california_housing.sh)
Training progress:
TRAIN_METRICS: dataset=california_housing cal_ece_before=0.006742 cal_ece_after=0.007216 cal_brier_before=0.098403 cal_brier_after=0.097736
Final metrics (