Agent Conversation: google/gemini-3.1-pro-preview
Task: security-membership-inference-defense
Membership Inference Defense via Training Regularization
Research Question
How can we design a stronger training-time regularizer that reduces membership inference leakage while preserving standard predictive accuracy?
Background
Membership inference attacks exploit the gap in model behavior between training and non-training examples, most often through confidence or loss statistics. Many defenses regularize predictions to reduce overconfidence and shrink the train-test generalization gap, but stronger privacy typically comes at a cost in utility.
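As a toy illustration of the attack surface (the benchmark's actual attack implementation is fixed and not shown here), a confidence-thresholding attack reduces to ranking examples by top softmax confidence; its AUC is the probability that a random training example outranks a random held-out one:

```python
def mia_auc(member_conf, nonmember_conf):
    """AUC of a confidence-thresholding attack: the fraction of
    (member, non-member) pairs where the member is more confident.
    0.5 means chance level, i.e. no measurable leakage."""
    wins = 0.0
    for m in member_conf:
        for n in nonmember_conf:
            if m > n:
                wins += 1.0
            elif m == n:
                wins += 0.5  # ties count half, as in the Mann-Whitney U statistic
    return wins / (len(member_conf) * len(nonmember_conf))

# A model that is systematically more confident on members leaks membership:
print(mia_auc([0.99, 0.98, 0.97], [0.70, 0.95, 0.60]))  # → 1.0
```

This pairwise-comparison form is exact but quadratic; the harness presumably computes the same quantity from the ROC curve.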
Task
Implement a better privacy-preserving training objective in bench/membership/custom_membership_defense.py. The fixed harness will train a model on a fixed dataset split, then run a confidence-based membership inference attack on train versus held-out examples.
Your method should improve the privacy-utility tradeoff: lower membership attack AUC while retaining high test accuracy.
Editable Interface
You must implement:
class MembershipDefense:
def compute_loss(self, logits, labels, model, epoch):
...
- logits: model outputs for the current minibatch
- labels: ground-truth labels
- model: current model instance, exposed for optional parameter regularization
- epoch: current training epoch
- Return value: scalar loss tensor used by the fixed training loop
The optimizer, architecture, batch schedule, and attack implementation are fixed.
Evaluation
Benchmarks:
- mnist
- fmnist
- cifar10
Reported metrics:
- test_acc
- mia_auc
- privacy_gap
- privacy_score
Primary metric: privacy_score (higher is better).
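The harness does not print the formula for privacy_score, but every (test_acc, mia_auc, privacy_score) triple in the test logs below is consistent with the following combination; treat it as inferred from the logs, not as the harness's actual source:

```python
def privacy_score(test_acc, mia_auc):
    # Inferred from the logged metrics: test accuracy minus the
    # above-chance part of the attack AUC, clamped at zero
    # (e.g. cifar10 in Test #2 has auc < 0.5 and score == test_acc).
    return test_acc - max(0.0, mia_auc - 0.5)

# Reproduces Test #1 mnist: 0.8360 - (0.5530 - 0.5) = 0.7830
print(round(privacy_score(0.8360, 0.5530), 4))  # → 0.783
```

Under this reading, a point of above-chance attack AUC costs exactly a point of accuracy, which explains why near-chance AUC on cifar10 makes privacy_score collapse to raw test accuracy.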
Baselines
- erm: standard cross-entropy training
- label_smoothing: smoothed targets to reduce overconfidence
- confidence_penalty: cross-entropy plus predictive entropy penalty
- relaxloss: margin-aware loss relaxation as a stronger privacy baseline
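As a concrete reference point, the confidence_penalty baseline amounts to cross-entropy with a bonus for predictive entropy (equivalently, a penalty on confidence). A single-example sketch in plain Python; the weight beta here is illustrative and need not match the harness's setting:

```python
import math

def confidence_penalty_loss(logits, label, beta=0.1):
    """Cross-entropy minus beta * predictive entropy for one example.
    Subtracting entropy rewards softer, less confident predictions."""
    m = max(logits)                                  # stabilized softmax
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ce = -math.log(probs[label])
    entropy = -sum(p * math.log(p) for p in probs)
    return ce - beta * entropy

# Uniform logits: ce = entropy = log(4), so the loss is (1 - beta) * log(4)
print(confidence_penalty_loss([0.0, 0.0, 0.0, 0.0], 0))  # ≈ 1.2477
```

This matches the structure of the editable file's default objective (`ce - entropy_weight * entropy`) with a nonzero weight.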
opacus/bench/membership/custom_membership_defense.py [EDITABLE — entire file only]
1: """Editable membership-inference defense for MLS-Bench."""
2:
3: import torch
4: import torch.nn.functional as F
5:
6: # ============================================================
7: # EDITABLE
8: # ============================================================
9: class MembershipDefense:
10: """Training-time regularizer for privacy-utility tradeoffs."""
11:
12: def __init__(self):
13: self.label_smoothing = 0.0
14: self.entropy_weight = 0.0
15:
16: def compute_loss(self, logits, labels, model, epoch):
17: ce = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
18: probs = torch.softmax(logits, dim=1)
19: entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
20: return ce - self.entropy_weight * entropy
21: # ============================================================
22
Tool Input
1{
2 "end_line": 20,
3 "start_line": 9,
4 "content": "class MembershipDefense:\n \"\"\"Training-time regularizer for privacy-utility tradeoffs.\"\"\"\n\n def __init__(self):\n self.mixup_alpha = 1.0\n self.logit_pen_weight = 0.01\n\n def compute_loss(self, logits, labels, model, epoch):\n batch_size = logits.size(0)\n device = logits.device\n num_classes = logits.size(1)\n\n # Logit Mixup\n lam = torch.distributions.Beta(self.mixup_alpha, self.mixup_alpha).sample((batch_size, 1)).to(device)\n lam = torch.max(lam, 1 - lam)\n indices = torch.randperm(batch_size, device=device)\n \n mixed_logits = lam * logits + (1 - lam) * logits[indices]\n \n labels_onehot = F.one_hot(labels, num_classes=num_classes).float()\n mixed_targets = lam * labels_onehot + (1 - lam) * labels_onehot[indices]\n \n mix_ce = F.cross_entropy(mixed_logits, mixed_targets)\n \n # Standard CE\n ce = F.cross_entropy(logits, labels, label_smoothing=0.1)\n \n # L2 Logit Penalty to restrict confidence\n logit_pen = torch.norm(logits, dim=1).mean()\n \n return 0.5 * ce + 0.5 * mix_ce + self.logit_pen_weight * logit_pen",
5 "filename": "opacus/bench/membership/custom_membership_defense.py",
6 "op": "replace"
7}
OK: Replaced lines 9..20 with 31 line(s) in opacus/bench/membership/custom_membership_defense.py. Editable range: entire file.
[Current file: opacus/bench/membership/custom_membership_defense.py | editable: 1–42 | total: 42 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Training-time regularizer for privacy-utility tradeoffs."""
11
12 def __init__(self):
13 self.mixup_alpha = 1.0
14 self.logit_pen_weight = 0.01
15
16 def compute_loss(self, logits, labels, model, epoch):
17 batch_size = logits.size(0)
18 device = logits.device
19 num_classes = logits.size(1)
20
21
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
## Seed 42
### mnist (scripts/mnist.sh)
TRAIN_METRICS epoch=1 loss=2.3128
TRAIN_METRICS epoch=2 loss=2.2974
TRAIN_METRICS epoch=3 loss=2.1623
TRAIN_METRICS epoch=4 loss=1.6083
TRAIN_METRICS epoch=5 loss=1.1449
mnist: test_acc=0.8360, mia_auc=0.5530, privacy_gap=0.0365, privacy_score=0.7830
### fmnist (scripts/fmnist.sh)
TRAIN_METRICS epoch=1 loss=2.0748
TRAIN_METRICS epoch=2 loss=1.6575
TRAIN_METRICS epoch=3 loss=1.4879
TRAIN_METRICS epoch=4 loss=1.3472
TRAIN_METRICS epoch=5 loss=1.2687
fmnist: test_acc=0.7255, mia_auc=0.5225, privacy_gap=0.0161, privacy_score=0.7030
### cifar10 (scripts/cifar10.sh)
TRAIN_METRICS epoch=1 loss=2.3140
TRAIN_METRICS epoch=2 loss=2.2173
TRAIN_METRICS epoch=3 loss=2.1678
TRAIN_METRICS epoch=4 loss=2.0334
TRAIN_METRICS epoch=5 loss=1.9827
TRAIN_METRICS epoch=6 loss=1.9373
cifar10: test_acc=0.3320, mia_auc=0.5027, privacy_gap=0.0012, privacy_score=0.3293
## Seed 123
### mnist (scrip
Tool Input
1{
2 "start_line": 9,
3 "op": "replace",
4 "end_line": 41,
5 "content": "class MembershipDefense:\n \"\"\"Generalized Cross Entropy for Privacy.\"\"\"\n\n def __init__(self):\n self.q = 0.5\n\n def compute_loss(self, logits, labels, model, epoch):\n # We can add a bit of label smoothing to GCE\n # But let's first manually compute GCE\n probs = torch.softmax(logits, dim=1)\n p_true = probs[torch.arange(logits.size(0)), labels]\n \n # GCE Loss: (1 - p^q)/q\n gce_loss = (1.0 - torch.pow(p_true + 1e-6, self.q)).mean() / self.q\n \n # To ensure the model learns, we can mix a little bit of standard cross entropy\n ce_loss = F.cross_entropy(logits, labels, label_smoothing=0.1)\n \n return 0.8 * gce_loss + 0.2 * ce_loss",
6 "filename": "opacus/bench/membership/custom_membership_defense.py"
7}
OK: Replaced lines 9..41 with 19 line(s) in opacus/bench/membership/custom_membership_defense.py. Editable range: entire file.
[Current file: opacus/bench/membership/custom_membership_defense.py | editable: 1–28 | total: 28 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Generalized Cross Entropy for Privacy."""
11
12 def __init__(self):
13 self.q = 0.5
14
15 def compute_loss(self, logits, labels, model, epoch):
16 # We can add a bit of label smoothing to GCE
17 # But let's first manually compute GCE
18 probs = torch.softmax(logits, dim=1)
19 p_true = probs[torch.arange(logits.size(0
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### mnist (scripts/mnist.sh)
TRAIN_METRICS epoch=1 loss=1.5506
TRAIN_METRICS epoch=2 loss=1.3839
TRAIN_METRICS epoch=3 loss=1.0841
TRAIN_METRICS epoch=4 loss=0.7743
TRAIN_METRICS epoch=5 loss=0.5742
mnist: test_acc=0.7810, mia_auc=0.5387, privacy_gap=0.0233, privacy_score=0.7423
### fmnist (scripts/fmnist.sh)
TRAIN_METRICS epoch=1 loss=1.3431
TRAIN_METRICS epoch=2 loss=0.9421
TRAIN_METRICS epoch=3 loss=0.7632
TRAIN_METRICS epoch=4 loss=0.6708
TRAIN_METRICS epoch=5 loss=0.6326
fmnist: test_acc=0.7255, mia_auc=0.5203, privacy_gap=0.0145, privacy_score=0.7052
### cifar10 (scripts/cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.4887
TRAIN_METRICS epoch=2 loss=1.3841
TRAIN_METRICS epoch=3 loss=1.3084
TRAIN_METRICS epoch=4 loss=1.2701
TRAIN_METRICS epoch=5 loss=1.2257
TRAIN_METRICS epoch=6 loss=1.1636
cifar10: test_acc=0.4145, mia_auc=0.4927, privacy_gap=-0.0035, privacy_score=0.4145
[Leaderboard] Results saved: {'test_acc_mni
Tool Input
1{
2 "filename": "opacus/bench/membership/custom_membership_defense.py",
3 "end_line": 28,
4 "content": "class MembershipDefense:\n \"\"\"Variance Penalty on Cross Entropy.\"\"\"\n\n def __init__(self):\n self.var_weight = 1.0\n\n def compute_loss(self, logits, labels, model, epoch):\n # We use label smoothing to have a softer baseline\n ce_per = F.cross_entropy(logits, labels, label_smoothing=0.1, reduction='none')\n \n # Penalize the variance of the loss in the batch\n mean_ce = ce_per.mean()\n var_ce = ce_per.var()\n \n return mean_ce + self.var_weight * var_ce",
5 "op": "replace",
6 "start_line": 9
7}
OK: Replaced lines 9..28 with 15 line(s) in opacus/bench/membership/custom_membership_defense.py. Editable range: entire file.
[Current file: opacus/bench/membership/custom_membership_defense.py | editable: 1–23 | total: 23 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Variance Penalty on Cross Entropy."""
11
12 def __init__(self):
13 self.var_weight = 1.0
14
15 def compute_loss(self, logits, labels, model, epoch):
16 # We use label smoothing to have a softer baseline
17 ce_per = F.cross_entropy(logits, labels, label_smoothing=0.1, reduction='none')
18
19 # Penalize the variance o
[Test #3] (0 tests remaining)
## Seed 42
### mnist (scripts/mnist.sh)
TRAIN_METRICS epoch=1 loss=2.2959
TRAIN_METRICS epoch=2 loss=2.0876
TRAIN_METRICS epoch=3 loss=1.6668
TRAIN_METRICS epoch=4 loss=1.2507
TRAIN_METRICS epoch=5 loss=1.0082
mnist: test_acc=0.9080, mia_auc=0.5472, privacy_gap=0.0307, privacy_score=0.8608
### fmnist (scripts/fmnist.sh)
TRAIN_METRICS epoch=1 loss=2.1936
TRAIN_METRICS epoch=2 loss=1.7476
TRAIN_METRICS epoch=3 loss=1.6183
TRAIN_METRICS epoch=4 loss=1.5404
TRAIN_METRICS epoch=5 loss=1.4904
fmnist: test_acc=0.6425, mia_auc=0.5206, privacy_gap=0.0098, privacy_score=0.6219
### cifar10 (scripts/cifar10.sh)
TRAIN_METRICS epoch=1 loss=2.2965
TRAIN_METRICS epoch=2 loss=2.2202
TRAIN_METRICS epoch=3 loss=2.1461
TRAIN_METRICS epoch=4 loss=2.1091
TRAIN_METRICS epoch=5 loss=2.0712
TRAIN_METRICS epoch=6 loss=2.0319
cifar10: test_acc=0.3650, mia_auc=0.4966, privacy_gap=-0.0004, privacy_score=0.3650
## Seed 123
### mnist (scripts/mnist.sh)
TRAIN_METRICS epoch=1 loss
Tool Input
1{
2 "filename": "bench/membership/custom_membership_defense.py",
3 "op": "replace",
4 "start_line": 9,
5 "content": "class MembershipDefense:\n \"\"\"Median-Relaxed Cross Entropy.\n \n Dynamically reverses gradients for the easiest samples in the batch to \n compress the loss variance and prevent memorization of easy samples, \n making train and test loss distributions indistinguishable to MIA attackers.\n \"\"\"\n\n def __init__(self):\n self.alpha = 0.5 # Gradient reversal weight\n self.q = 0.5 # Quantile (0.5 = median)\n self.ls = 0.05 # Light label smoothing\n\n def compute_loss(self, logits, labels, model, epoch):\n ce_per = F.cross_entropy(logits, labels, label_smoothing=self.ls, reduction='none')\n \n # Determine dynamic target as the batch q-th quantile\n target_ce = torch.quantile(ce_per.float(), self.q).detach()\n \n relaxed = torch.where(\n ce_per > target_ce,\n ce_per,\n target_ce + self.alpha * (target_ce - ce_per)\n )\n return relaxed.mean()",
6 "end_line": 29
7}
ERROR: Package 'bench' is not in allowed packages
Tool Input
1{
2 "op": "replace",
3 "start_line": 9,
4 "filename": "pytorch-vision/custom_membership_defense.py",
5 "content": "class MembershipDefense:\n \"\"\"Median-Relaxed Cross Entropy.\n \n Dynamically reverses gradients for the easiest samples in the batch to \n compress the loss variance and prevent memorization of easy samples, \n making train and test loss distributions indistinguishable to MIA attackers.\n \"\"\"\n\n def __init__(self):\n self.alpha = 0.5 # Gradient reversal weight\n self.q = 0.5 # Quantile (0.5 = median)\n self.ls = 0.05 # Light label smoothing\n\n def compute_loss(self, logits, labels, model, epoch):\n ce_per = F.cross_entropy(logits, labels, label_smoothing=self.ls, reduction='none')\n \n # Determine dynamic target as the batch q-th quantile\n target_ce = torch.quantile(ce_per.float(), self.q).detach()\n \n relaxed = torch.where(\n ce_per > target_ce,\n ce_per,\n target_ce + self.alpha * (target_ce - ce_per)\n )\n return relaxed.mean()",
6 "end_line": 29
7}
OK: Replaced lines 9..29 with 25 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–36 | total: 36 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Median-Relaxed Cross Entropy.
11
12 Dynamically reverses gradients for the easiest samples in the batch to
13 compress the loss variance and prevent memorization of easy samples,
14 making train and test loss distributions indistinguishable to MIA attackers.
15 """
16
17 def __init__(self):
18 self.alpha = 0.5 # Gradient reversal weight
19
Tool Input
1{
2 "content": "# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Logit Normalization + Label Smoothing.\n \"\"\"\n\n def __init__(self):\n self.tau = 0.04\n self.ls = 0.1\n\n def compute_loss(self, logits, labels, model, epoch):\n norms = torch.norm(logits, p=2, dim=1, keepdim=True)\n logits_norm = logits / (norms + 1e-7)\n scaled_logits = logits_norm / self.tau\n return F.cross_entropy(scaled_logits, labels, label_smoothing=self.ls)\n# ============================================================\n# END EDITABLE\n# ============================================================",
3 "start_line": 6,
4 "op": "replace",
5 "end_line": 32,
6 "filename": "pytorch-vision/custom_membership_defense.py"
7}
OK: Replaced lines 6..32 with 19 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–24 | total: 24 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Logit Normalization + Label Smoothing.
11 """
12
13 def __init__(self):
14 self.tau = 0.04
15 self.ls = 0.1
16
17 def compute_loss(self, logits, labels, model, epoch):
18 norms = torch.norm(logits, p=2, dim=1, keepdim=True)
19 logits_norm = logits / (norms + 1e-7)
20 scaled_logits = logits_norm / self.tau
21 return F.cr