Agent Conversation: openai/gpt-5.4-pro
Task: security-membership-inference-defense
Membership Inference Defense via Training Regularization
Research Question
How can we design a stronger training-time regularizer that reduces membership inference leakage while preserving standard predictive accuracy?
Background
Membership inference attacks exploit the gap between train and non-train examples, often through confidence or loss statistics. Many defenses regularize predictions to reduce overconfidence and shrink the train-test generalization gap, but stronger privacy often hurts utility.
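The confidence statistic mentioned above can be turned into an attack score directly: rank examples by the model's confidence and measure how well that ranking separates members from non-members. A minimal, self-contained sketch (the data and names below are illustrative, not the benchmark's fixed attack implementation):

```python
# Illustrative sketch of a confidence-based membership inference attack;
# the benchmark's actual attack code is fixed and not shown here.

def mia_auc(member_conf, nonmember_conf):
    """AUC of the attacker that flags higher-confidence examples as members.

    Equals the probability that a random member outscores a random
    non-member, counting ties as half a win.
    """
    wins = 0.0
    for m in member_conf:
        for n in nonmember_conf:
            if m > n:
                wins += 1.0
            elif m == n:
                wins += 0.5
    return wins / (len(member_conf) * len(nonmember_conf))

# An overconfident model separates train from held-out confidences cleanly:
members = [0.99, 0.98, 0.97, 0.90]   # confidences on training examples
heldout = [0.80, 0.75, 0.95, 0.60]   # confidences on held-out examples
print(mia_auc(members, heldout))     # -> 0.9375
```

An AUC of 0.5 means the attacker does no better than chance, which is exactly what a defense tries to approach while keeping test accuracy high.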
Task
Implement a better privacy-preserving training objective in bench/membership/custom_membership_defense.py. The fixed harness will train a model on a fixed dataset split, then run a confidence-based membership inference attack on train versus held-out examples.
Your method should improve the privacy-utility tradeoff: lower membership attack AUC while retaining high test accuracy.
Editable Interface
You must implement:
class MembershipDefense:
    def compute_loss(self, logits, labels, model, epoch):
        ...
- logits: model outputs for the current minibatch
- labels: ground-truth labels
- model: current model instance, exposed for optional parameter regularization
- epoch: current training epoch
- Return value: scalar loss tensor used by the fixed training loop
The optimizer, architecture, batch schedule, and attack implementation are fixed.
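As a rough sketch of how such a loss hook is typically driven, the snippet below wires a plain cross-entropy implementation of the interface into a toy loop. The real harness, model, and schedule are fixed by the benchmark and not shown in this log, so every name outside the compute_loss signature is an assumption.

```python
import torch
import torch.nn.functional as F

class ErmDefense:
    """Plain cross-entropy stand-in for the compute_loss interface."""

    def compute_loss(self, logits, labels, model, epoch):
        # model and epoch are unused here, but the interface exposes them
        # for regularizers that penalize parameters or ramp in over epochs.
        return F.cross_entropy(logits, labels)

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)              # toy stand-in for the fixed net
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
defense = ErmDefense()

inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))
for epoch in range(3):
    optimizer.zero_grad()
    loss = defense.compute_loss(model(inputs), labels, model, epoch)
    loss.backward()                        # requires a scalar loss tensor
    optimizer.step()
```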
Evaluation
Benchmarks:
mnist, fmnist, cifar10
Reported metrics:
test_acc, mia_auc, privacy_gap, privacy_score
Primary metric: privacy_score (higher is better).
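The harness does not print the formula for privacy_score, but the numbers reported in the test runs later in this log are consistent with test accuracy minus the attack's advantage over chance. Treat the reconstruction below as an assumption, not benchmark documentation:

```python
def privacy_score(test_acc, mia_auc):
    # Assumed form: accuracy minus the attack's edge over a 0.5 (chance) AUC.
    return test_acc - (mia_auc - 0.5)

# Matches test #1 below: test_acc=0.8937, mia_auc=0.5953 -> 0.7984
print(round(privacy_score(0.8937, 0.5953), 4))
```

Under this reading, reducing mia_auc by one point is worth exactly as much as gaining one point of test accuracy.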
Baselines
- erm: standard cross-entropy training
- label_smoothing: smoothed targets to reduce overconfidence
- confidence_penalty: cross-entropy plus predictive entropy penalty
- relaxloss: margin-aware loss relaxation as a stronger privacy baseline
opacus/bench/membership/custom_membership_defense.py [EDITABLE — entire file only]
1: """Editable membership-inference defense for MLS-Bench."""
2:
3: import torch
4: import torch.nn.functional as F
5:
6: # ============================================================
7: # EDITABLE
8: # ============================================================
9: class MembershipDefense:
10:     """Training-time regularizer for privacy-utility tradeoffs."""
11:
12:     def __init__(self):
13:         self.label_smoothing = 0.0
14:         self.entropy_weight = 0.0
15:
16:     def compute_loss(self, logits, labels, model, epoch):
17:         ce = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
18:         probs = torch.softmax(logits, dim=1)
19:         entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
20:         return ce - self.entropy_weight * entropy
21: # ============================================================
22: # END EDITABLE
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Confidence-capped adaptive smoothing for membership privacy.\n\n Idea:\n - keep standard ERM behavior on hard / misclassified examples;\n - only regularize examples that are already correct and becoming\n overconfident, since those dominate confidence-based MIA signals;\n - use an EMA of correct-class confidence to set an adaptive target,\n making the defense robust across datasets with very different\n difficulty and class count.\n \"\"\"\n\n def __init__(self):\n # Delay privacy regularization until the model has learned a useful\n # classifier, then smoothly ramp the defense on.\n self.warmup_epochs = 8\n self.ramp_tau = 12.0\n\n # EMA target for correctly-classified confidence.\n self.ema_momentum = 0.90\n self.correct_conf_ema = None\n\n # Confidence target schedule. 
During warmup the target is near 1.0,\n # then it anneals toward an EMA-based cap.\n self.init_target = 0.98\n self.min_target = 0.66\n self.max_target = 0.82\n self.conf_gap = 0.08\n\n # Regularization strengths.\n self.max_smoothing = 0.16\n self.excess_weight = 0.70\n self.mean_weight = 0.20\n self.entropy_weight = 0.01\n\n def _update_conf_ema(self, value):\n value = value.detach()\n if self.correct_conf_ema is None:\n self.correct_conf_ema = value\n else:\n self.correct_conf_ema = (\n self.ema_momentum * self.correct_conf_ema.to(value.device)\n + (1.0 - self.ema_momentum) * value\n )\n\n def compute_loss(self, logits, labels, epoch):\n log_probs = F.log_softmax(logits, dim=1)\n probs = log_probs.exp()\n\n true_log_probs = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)\n true_probs = true_log_probs.exp()\n nll = -true_log_probs\n uniform_ce = -log_probs.mean(dim=1)\n\n preds = probs.argmax(dim=1)\n correct_mask = preds.eq(labels)\n correct_conf = true_probs[correct_mask]\n\n with torch.no_grad():\n ema_source = correct_conf if correct_conf.numel() > 0 else true_probs\n self._update_conf_ema(ema_source.mean())\n\n if epoch < self.warmup_epochs:\n ramp = logits.new_tensor(0.0)\n target_conf = logits.new_tensor(self.init_target)\n else:\n steps = float(epoch - self.warmup_epochs + 1)\n decay = torch.exp(logits.new_tensor(-steps / self.ramp_tau))\n ramp = 1.0 - decay\n\n ema_target = torch.clamp(\n self.correct_conf_ema.to(logits.device) - self.conf_gap,\n min=self.min_target,\n max=self.max_target,\n )\n target_conf = ema_target + (self.init_target - ema_target) * decay\n\n # Confidence-adaptive label smoothing: only examples whose true-class\n # confidence exceeds the target receive extra smoothing.\n excess_ratio = (\n (true_probs.detach() - target_conf) / (1.0 - target_conf + 1e-6)\n ).clamp(min=0.0, max=1.0)\n adaptive_eps = self.max_smoothing * ramp * excess_ratio\n adaptive_ce = ((1.0 - adaptive_eps) * nll + adaptive_eps * uniform_ce).mean()\n\n # 
Additional cap on overconfident correct examples to directly shrink\n # the train-confidence tail exploited by MIA.\n over_mask = correct_mask & (true_probs.detach() > target_conf)\n over_conf = true_probs[over_mask]\n if over_conf.numel() > 0:\n excess_pen = (over_conf - target_conf).pow(2).mean()\n mean_pen = (over_conf.mean() - target_conf).clamp_min(0.0).pow(2)\n over_entropy = -(probs[over_mask] * log_probs[over_mask]).sum(dim=1).mean()\n else:\n zero = logits.new_tensor(0.0)\n excess_pen = zero\n mean_pen = zero\n over_entropy = zero\n\n return adaptive_ce + ramp * (\n self.excess_weight * excess_pen\n + self.mean_weight * mean_pen\n - self.entropy_weight * over_entropy\n )\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5 "after_line": 0,
6 "start_line": 1,
7 "end_line": 32
8}
OK: Replaced lines 1..32 with 115 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–115 | total: 115 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10     """Confidence-capped adaptive smoothing for membership privacy.
11
12     Idea:
13     - keep standard ERM behavior on hard / misclassified examples;
14     - only regularize examples that are already correct and becoming
15       overconfident, since those dominate confidence-based MIA signals;
16     - use an EMA of correct-class confidence to set an adaptive target,
17       making the def
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Adaptive confidence-capped self-distillation.\n\n Core idea:\n 1) keep standard supervised learning on hard examples;\n 2) on already-easy / high-confidence examples, replace one-hot targets\n with a tempered self-target so the model preserves class ranking while\n reducing overconfident train-only predictions;\n 3) explicitly cap the upper tail of true-class confidence and shrink its\n spread, which directly targets confidence-based membership inference.\n \"\"\"\n\n def __init__(self):\n # Mild global smoothing keeps optimization stable.\n self.base_smoothing = 0.02\n\n # Adaptive self-distillation parameters.\n self.alpha_max = 0.35\n self.temperature = 2.2\n\n # Confidence gate is applied on normalized true-class confidence so the\n # same thresholds transfer across datasets with different class counts.\n self.gate_center = 0.55\n self.gate_width = 0.08\n\n # Upper-tail confidence control.\n self.cap_weight = 0.8\n self.var_weight = 0.04\n\n def compute_loss(self, logits, labels, epoch):\n num_classes = logits.size(1)\n dtype = logits.dtype\n\n log_probs = F.log_softmax(logits, dim=1)\n probs = log_probs.exp()\n\n true_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1)\n uniform_prob = 1.0 / float(num_classes)\n norm_true = ((true_probs - uniform_prob) / max(1.0 - uniform_prob, 1e-6)).clamp(0.0, 1.0)\n\n # Start close to ERM, then progressively regularize confident examples.\n ramp = 1.0 - math.exp(-float(epoch + 1) / 20.0)\n ramp_t = logits.new_tensor(ramp)\n\n gate_center = self.gate_center + 0.10 * (1.0 - ramp)\n gate = torch.sigmoid((norm_true.detach() - gate_center) / self.gate_width)\n\n with torch.no_grad():\n tempered = 
F.softmax(logits / self.temperature, dim=1)\n\n one_hot = F.one_hot(labels, num_classes=num_classes).to(dtype=dtype)\n uniform = torch.full_like(tempered, 1.0 / float(num_classes))\n alpha = (self.alpha_max * ramp_t * gate).unsqueeze(1)\n\n # Interpolate between one-hot supervision and a softer, model-aware\n # target that preserves class similarities better than uniform label\n # smoothing.\n targets = (1.0 - self.base_smoothing) * ((1.0 - alpha) * one_hot + alpha * tempered) + self.base_smoothing * uniform\n soft_ce = -(targets * log_probs).sum(dim=1)\n\n # Confidence-based MIA relies on the train-only upper tail. Penalize\n # only excessive confidence, not all confidence.\n target_norm = 0.74 + 0.08 * ramp\n cap_penalty = gate * F.relu(norm_true - target_norm).pow(2)\n\n # Compress the spread of easy-example confidence to further reduce the\n # separability exploited by confidence attacks.\n weight_sum = gate.sum().clamp_min(1e-6)\n mean_conf = (gate * norm_true.detach()).sum() / weight_sum\n var_penalty = (gate * (norm_true - mean_conf).pow(2)).sum() / weight_sum\n\n return (\n soft_ce.mean()\n + self.cap_weight * ramp_t * cap_penalty.mean()\n + self.var_weight * ramp_t * var_penalty\n )\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5 "after_line": 0,
6 "start_line": 1,
7 "end_line": 32
8}
OK: Replaced lines 1..32 with 89 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–89 | total: 89 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import math
4
5 import torch
6 import torch.nn.functional as F
7
8 # ============================================================
9 # EDITABLE
10 # ============================================================
11 class MembershipDefense:
12     """Adaptive confidence-capped self-distillation.
13
14     Core idea:
15     1) keep standard supervised learning on hard examples;
16     2) on already-easy / high-confidence examples, replace one-hot targets
17        with a tempered self-target so the model preserves class ranking while
18        reducing overconfident train-only predictions;
19     3) explicitly
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.8253 train_acc=33.71 lr=0.099975
TRAIN_METRICS epoch=10 loss=0.9225 train_acc=77.39 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9667 train_acc=81.65 lr=0.090451
TRAIN_METRICS epoch=30 loss=1.0114 train_acc=84.34 lr=0.079389
TRAIN_METRICS epoch=40 loss=1.0300 train_acc=86.94 lr=0.065451
TRAIN_METRICS epoch=50 loss=1.0440 train_acc=88.79 lr=0.050000
TRAIN_METRICS epoch=60 loss=1.0369 train_acc=91.37 lr=0.034549
TRAIN_METRICS epoch=70 loss=1.0285 train_acc=94.06 lr=0.020611
TRAIN_METRICS epoch=80 loss=1.0154 train_acc=97.15 lr=0.009549
TRAIN_METRICS epoch=90 loss=1.0131 train_acc=98.81 lr=0.002447
TRAIN_METRICS epoch=100 loss=1.0146 train_acc=99.35 lr=0.000000
resnet20-cifar10: test_acc=0.8937, mia_auc=0.5953, privacy_gap=0.0458, privacy_score=0.7984
### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.6580 train_a
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Adaptive tempered flooding for privacy-utility tradeoffs.\n\n The loss keeps standard supervised learning, but only once an example has\n become \"easy\" (low loss relative to the random-guess loss log(K)) it adds:\n 1) tempered self-distillation to soften train-only overconfidence while\n preserving class similarity structure; and\n 2) a smooth flooding barrier that prevents the loss from collapsing far\n below a class-count-aware target, directly shrinking the signal used by\n confidence/loss-based membership inference.\n \"\"\"\n\n def __init__(self):\n self.base_smoothing = 0.01\n self.temperature = 2.4\n self.distill_weight = 0.80\n self.floor_weight = 0.90\n self.spread_weight = 0.02\n\n # Gate over easy examples using loss normalized by log(num_classes).\n self.gate_center = 0.58\n self.gate_width = 0.10\n\n def compute_loss(self, logits, labels, epoch):\n num_classes = logits.size(1)\n log_k = math.log(float(num_classes))\n\n log_probs = F.log_softmax(logits, dim=1)\n hard_ce = F.cross_entropy(logits, labels, reduction=\"none\")\n base_ce = F.cross_entropy(\n logits,\n labels,\n reduction=\"none\",\n label_smoothing=self.base_smoothing,\n )\n\n # Start near ERM and gradually turn on privacy regularization.\n ramp = 1.0 - math.exp(-float(epoch + 1) / 15.0)\n ramp_t = logits.new_tensor(ramp)\n\n # Normalized skill = improvement over random-guess loss, in [0, 1].\n skill = (1.0 - hard_ce.detach() / log_k).clamp(0.0, 1.0)\n gate_center = self.gate_center + 0.08 * (1.0 - ramp)\n gate = torch.sigmoid((skill - gate_center) / self.gate_width)\n\n with torch.no_grad():\n tempered = F.softmax(logits / self.temperature, dim=1)\n\n distill = 
F.kl_div(log_probs, tempered, reduction=\"none\").sum(dim=1)\n\n # Smooth, class-count-aware flooding barrier: allow learning on hard\n # samples, but discourage memorized train samples from driving the loss\n # much below a moderate target.\n target_ce = log_k * (0.24 + 0.05 * ramp)\n flood = gate * F.relu(target_ce - hard_ce).pow(2)\n\n # Compress the easy-sample loss spread; confidence attacks exploit the\n # particularly low-loss tail of the train distribution.\n weight_sum = gate.sum().clamp_min(1e-6)\n mean_ce = (gate * hard_ce.detach()).sum() / weight_sum\n spread = (gate * (hard_ce - mean_ce).pow(2)).sum() / weight_sum\n\n return (\n base_ce.mean()\n + self.distill_weight * ramp_t * (gate * distill).mean()\n + self.floor_weight * ramp_t * flood.mean()\n + self.spread_weight * ramp_t * spread\n )\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5 "after_line": 0,
6 "start_line": 1,
7 "end_line": 89
8}
OK: Replaced lines 1..89 with 81 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–81 | total: 81 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import math
4
5 import torch
6 import torch.nn.functional as F
7
8 # ============================================================
9 # EDITABLE
10 # ============================================================
11 class MembershipDefense:
12     """Adaptive tempered flooding for privacy-utility tradeoffs.
13
14     The loss keeps standard supervised learning, but only once an example has
15     become "easy" (low loss relative to the random-guess loss log(K)) it adds:
16     1) tempered self-distillation to soften train-only overconfidence while
17        preserving class similarity structure; and
18     2)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.7912 train_acc=34.46 lr=0.099975
TRAIN_METRICS epoch=10 loss=0.9509 train_acc=77.37 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9497 train_acc=81.80 lr=0.090451
TRAIN_METRICS epoch=30 loss=0.9469 train_acc=84.34 lr=0.079389
TRAIN_METRICS epoch=40 loss=0.9243 train_acc=87.01 lr=0.065451
TRAIN_METRICS epoch=50 loss=0.9034 train_acc=88.63 lr=0.050000
TRAIN_METRICS epoch=60 loss=0.8689 train_acc=91.13 lr=0.034549
TRAIN_METRICS epoch=70 loss=0.8315 train_acc=93.57 lr=0.020611
TRAIN_METRICS epoch=80 loss=0.7866 train_acc=96.79 lr=0.009549
TRAIN_METRICS epoch=90 loss=0.7579 train_acc=98.82 lr=0.002447
TRAIN_METRICS epoch=100 loss=0.7494 train_acc=99.29 lr=0.000000
resnet20-cifar10: test_acc=0.8937, mia_auc=0.5853, privacy_gap=0.0312, privacy_score=0.8084
### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.6652 train_ac
Tool Input
1{
2 "n": 1
3}
[submit] Submitting result from test #1 as final.
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.8253 train_acc=33.71 lr=0.099975
TRAIN_METRICS epoch=10 loss=0.9225 train_acc=77.39 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9667 train_acc=81.65 lr=0.090451
TRAIN_METRICS epoch=30 loss=1.0114 train_acc=84.34 lr=0.079389
TRAIN_METRICS epoch=40 loss=1.0300 train_acc=86.94 lr=0.065451
TRAIN_METRICS epoch=50 loss=1.0440 train_acc=88.79 lr=0.050000
TRAIN_METRICS epoch=60 loss=1.0369 train_acc=91.37 lr=0.034549
TRAIN_METRICS epoch=70 loss=1.0285 train_acc=94.06 lr=0.020611
TRAIN_METRICS epoch=80 loss=1.0154 train_acc=97.15 lr=0.009549
TRAIN_METRICS epoch=90 loss=1.0131 train_acc=98.81 lr=0.002447
TRAIN_METRICS epoch=100 loss=1.0146 train_acc=99.35 lr=0.000000
resnet20-cifar10: test_acc=0.8937, mia_auc=0.5953, privacy_gap=0.0458, privacy_score=0.7984
### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.6580 train_acc=1.14 lr=0.099975