Agent Conversation: anthropic/claude-opus-4.6
Task: security-membership-inference-defense
Membership Inference Defense via Training Regularization
Research Question
How can we design a stronger training-time regularizer that reduces membership inference leakage while preserving standard predictive accuracy?
Background
Membership inference attacks exploit the gap between train and non-train examples, often through confidence or loss statistics. Many defenses regularize predictions to reduce overconfidence and shrink the train-test generalization gap, but stronger privacy often hurts utility.
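The confidence-signal mechanism described above can be sketched with a toy attack, separate from the benchmark's fixed implementation: score each example by its maximum softmax probability, then measure how well that score separates training members from held-out non-members (attack AUC). All logits below are illustrative, not benchmark data.

```python
import math

def max_confidence(logits):
    """Maximum softmax probability for one example (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    return max(exps) / sum(exps)

def attack_auc(member_scores, nonmember_scores):
    """AUC of the rule 'higher confidence => member': the fraction of
    (member, non-member) pairs ranked correctly, counting ties as half."""
    pairs = [(m, n) for m in member_scores for n in nonmember_scores]
    wins = sum(1.0 if m > n else 0.5 if m == n else 0.0 for m, n in pairs)
    return wins / len(pairs)

# Overconfident outputs on training members vs. flatter outputs on held-out data
members = [max_confidence(z) for z in [[9.0, 0.1, 0.2], [8.5, 0.3, 0.1]]]
nonmembers = [max_confidence(z) for z in [[2.0, 1.5, 1.0], [1.2, 1.1, 0.9]]]
auc = attack_auc(members, nonmembers)  # 1.0: members perfectly separable here
```

A defense succeeds to the extent it pushes this AUC toward 0.5 (chance) without hurting accuracy.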
Task
Implement a better privacy-preserving training objective in bench/membership/custom_membership_defense.py. The fixed harness will train a model on a fixed dataset split, then run a confidence-based membership inference attack on train versus held-out examples.
Your method should improve the privacy-utility tradeoff: lower membership attack AUC while retaining high test accuracy.
Editable Interface
You must implement:
class MembershipDefense:
def compute_loss(self, logits, labels, model, epoch):
...
- logits: model outputs for the current minibatch
- labels: ground-truth labels
- model: current model instance, exposed for optional parameter regularization
- epoch: current training epoch
- Return value: scalar loss tensor used by the fixed training loop
The optimizer, architecture, batch schedule, and attack implementation are fixed.
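The harness itself is not shown in this transcript; a minimal training loop consistent with the stated interface might look like the following sketch, where `Net`, the SGD hyperparameters, and the random batch are placeholders rather than the benchmark's fixed choices.

```python
import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    """Stand-in model; the real architecture is fixed by the harness."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(784, 10)

    def forward(self, x):
        return self.fc(x)

class MembershipDefense:
    """Minimal ERM defense matching the editable interface."""
    def compute_loss(self, logits, labels, model, epoch):
        return F.cross_entropy(logits, labels)  # scalar loss tensor

model, defense = Net(), MembershipDefense()
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder hyperparameters
x = torch.randn(32, 784)
y = torch.randint(0, 10, (32,))
for epoch in range(2):  # the real batch schedule is likewise fixed
    opt.zero_grad()
    loss = defense.compute_loss(model(x), y, model, epoch)
    loss.backward()
    opt.step()
```

The defense only controls the scalar returned by `compute_loss`; everything else in the loop is outside the editable surface.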
Evaluation
Benchmarks:
- mnist
- fmnist
- cifar10
Reported metrics:
- test_acc
- mia_auc
- privacy_gap
- privacy_score
Primary metric: privacy_score (higher is better).
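The transcript never states the formula for privacy_score, but every result reported later is consistent with test accuracy minus the attack's above-chance advantage. Treat this as an inferred reconstruction, not the benchmark's documented definition:

```python
def privacy_score(test_acc, mia_auc):
    # Inferred formula: utility minus above-chance attack advantage.
    # An attack at or below AUC 0.5 is no better than random guessing,
    # so it incurs no penalty.
    return test_acc - max(0.0, mia_auc - 0.5)

# Reproduces the mnist numbers reported later in this transcript:
# test_acc=0.9215, mia_auc=0.5464 -> privacy_score=0.8751
score = round(privacy_score(0.9215, 0.5464), 4)  # 0.8751
```

Under this reading, a defense gains nothing from pushing mia_auc below 0.5, which matches the cifar10 row (mia_auc=0.4950, privacy_score equal to test_acc).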
Baselines
- erm: standard cross-entropy training
- label_smoothing: smoothed targets to reduce overconfidence
- confidence_penalty: cross-entropy plus a predictive entropy penalty
- relaxloss: margin-aware loss relaxation as a stronger privacy baseline
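None of the baseline implementations appear in this transcript; as an illustration, the confidence_penalty objective described above can be sketched as cross-entropy minus a weighted predictive-entropy bonus. The weight `beta=0.1` is an arbitrary illustrative value, not the benchmark's setting.

```python
import torch
import torch.nn.functional as F

def confidence_penalty_loss(logits, labels, beta=0.1):
    """Cross-entropy minus a weighted predictive-entropy bonus: subtracting
    entropy rewards flatter output distributions, discouraging the
    overconfident train-set predictions that confidence attacks exploit."""
    ce = F.cross_entropy(logits, labels)
    probs = torch.softmax(logits, dim=1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
    return ce - beta * entropy

logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = confidence_penalty_loss(logits, labels)  # scalar training loss
```

Note this is the same entropy term the baseline file below wires up with `entropy_weight = 0.0`.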
opacus/bench/membership/custom_membership_defense.py [EDITABLE — entire file only]
1: """Editable membership-inference defense for MLS-Bench."""
2:
3: import torch
4: import torch.nn.functional as F
5:
6: # ============================================================
7: # EDITABLE
8: # ============================================================
9: class MembershipDefense:
10: """Training-time regularizer for privacy-utility tradeoffs."""
11:
12: def __init__(self):
13: self.label_smoothing = 0.0
14: self.entropy_weight = 0.0
15:
16: def compute_loss(self, logits, labels, model, epoch):
17: ce = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
18: probs = torch.softmax(logits, dim=1)
19: entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
20: return ce - self.entropy_weight * entropy
21: # ============================================================
Tool Input
1{
2 "op": "replace",
3 "filename": "opacus/bench/membership/custom_membership_defense.py",
4 "start_line": 1,
5 "end_line": 23,
6 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Confidence-Bounded Entropy-Regularized Training (CBERT).\n \n Novel multi-component defense against membership inference attacks:\n \n 1. Soft Logit Bounding: Uses scaled tanh to create a differentiable \n upper bound on logit magnitudes, directly limiting prediction \n confidence without hard truncation.\n \n 2. Confidence-Adaptive Smoothing: Per-sample label smoothing that \n increases with model confidence, creating a negative feedback loop \n that prevents memorization of individual samples.\n \n 3. Entropy Maximization: Encourages spreading probability mass across \n classes while maintaining correct classification.\n \n 4. Logit Norm Penalty: L2 regularization on raw logits to ensure \n bounded representations complement the soft clipping.\n \"\"\"\n\n def __init__(self):\n # Soft logit bounding scale\n self.logit_bound = 3.0\n # Confidence-adaptive smoothing parameters\n self.base_smoothing = 0.15\n self.max_smoothing = 0.45\n self.confidence_threshold = 0.5\n self.smoothing_sharpness = 8.0\n # Entropy bonus weight\n self.entropy_weight = 0.15\n # L2 penalty on raw logits\n self.logit_l2 = 0.008\n\n def compute_loss(self, logits, labels, model, epoch):\n B, C = logits.shape\n \n # 1. Soft logit bounding via scaled tanh\n # Creates a differentiable ceiling on logit magnitudes,\n # directly limiting the maximum achievable confidence\n bounded = self.logit_bound * torch.tanh(logits / self.logit_bound)\n \n # 2. Confidence-adaptive per-sample label smoothing\n # More confident predictions receive stronger smoothing,\n # creating a self-regulating negative feedback loop\n with torch.no_grad():\n probs_det = torch.softmax(bounded, dim=1)\n confidence = probs_det.max(dim=1).values\n # Sigmoid gating: smooth transition from base to max smoothing\n alpha = self.base_smoothing + (self.max_smoothing - self.base_smoothing) * \\\n torch.sigmoid(self.smoothing_sharpness * (confidence - self.confidence_threshold))\n \n # Create per-sample soft targets: blend one-hot with uniform\n one_hot = F.one_hot(labels, C).float()\n uniform = torch.ones_like(one_hot) / C\n soft_targets = (1 - alpha.unsqueeze(1)) * one_hot + alpha.unsqueeze(1) * uniform\n \n # Cross-entropy with per-sample soft targets\n log_probs = F.log_softmax(bounded, dim=1)\n ce = -(soft_targets * log_probs).sum(dim=1).mean()\n \n # 3. Entropy maximization on bounded logits\n probs = torch.softmax(bounded, dim=1)\n entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n \n # 4. L2 penalty on raw logits to keep representations bounded\n logit_reg = self.logit_l2 * (logits ** 2).mean()\n \n return ce - self.entropy_weight * entropy + logit_reg\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
OK: Replaced lines 1..23 with 79 line(s) in opacus/bench/membership/custom_membership_defense.py. Editable range: entire file.
[Current file: opacus/bench/membership/custom_membership_defense.py | editable: 1–79 | total: 79 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Confidence-Bounded Entropy-Regularized Training (CBERT).
11
12 Novel multi-component defense against membership inference attacks:
13
14 1. Soft Logit Bounding: Uses scaled tanh to create a differentiable
15 upper bound on logit magnitudes, directly limiting prediction
16 confidence without hard truncation.
17
18 2. Confidence-A
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
## Seed 42
### mnist (scripts/mnist.sh)
TRAIN_METRICS epoch=1 loss=1.9436
TRAIN_METRICS epoch=2 loss=1.6534
TRAIN_METRICS epoch=3 loss=1.3666
TRAIN_METRICS epoch=4 loss=1.2659
TRAIN_METRICS epoch=5 loss=1.2308
mnist: test_acc=0.9215, mia_auc=0.5464, privacy_gap=0.0261, privacy_score=0.8751
### fmnist (scripts/fmnist.sh)
TRAIN_METRICS epoch=1 loss=1.8589
TRAIN_METRICS epoch=2 loss=1.4880
TRAIN_METRICS epoch=3 loss=1.3914
TRAIN_METRICS epoch=4 loss=1.3376
TRAIN_METRICS epoch=5 loss=1.3119
fmnist: test_acc=0.7285, mia_auc=0.5211, privacy_gap=0.0139, privacy_score=0.7074
### cifar10 (scripts/cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.9064
TRAIN_METRICS epoch=2 loss=1.7949
TRAIN_METRICS epoch=3 loss=1.7212
TRAIN_METRICS epoch=4 loss=1.6688
TRAIN_METRICS epoch=5 loss=1.6298
TRAIN_METRICS epoch=6 loss=1.6010
cifar10: test_acc=0.4405, mia_auc=0.4950, privacy_gap=-0.0025, privacy_score=0.4405
## Seed 123
### mnist (scri
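One property of the logit-bounding attempt above is easy to check analytically: with logits squashed into [-B, B] by the scaled tanh, softmax confidence can never exceed the value reached when one logit sits at +B and the other C-1 at -B. For the `logit_bound = 3.0` setting and 10 classes that cap is roughly 0.978, which may help explain why the defense reduced, but did not eliminate, the confidence signal. A quick sketch:

```python
import math

def max_conf_cap(bound, num_classes):
    """Largest softmax probability achievable when every logit lies in
    [-bound, bound]: one logit at +bound, the rest at -bound."""
    return 1.0 / (1.0 + (num_classes - 1) * math.exp(-2.0 * bound))

cap = max_conf_cap(3.0, 10)  # ~0.978 for the logit_bound=3.0 setting above
```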
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "start_line": 1,
5 "end_line": 31,
6 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Variance-Constrained Relaxed Entropy Loss (VCRE).\n\n Novel contribution: Combines relaxed loss training with per-sample\n loss variance regularization, entropy shaping, and logit magnitude\n control to minimize membership inference leakage.\n\n Key insight: MIA attacks exploit the distributional gap between\n per-sample losses of members vs non-members. By (1) relaxing\n gradients for overly well-fit samples, (2) penalizing the variance\n of per-sample losses to encourage uniform treatment, (3) maximizing\n prediction entropy, and (4) penalizing logit magnitudes, we\n systematically close this gap without sacrificing generalization.\n \"\"\"\n\n def __init__(self):\n # Relaxation parameters (proven in RelaxLoss)\n self.target_ce = 1.0\n self.alpha = 0.5\n # Label smoothing strength\n self.label_smoothing = 0.15\n # Entropy bonus weight\n self.entropy_weight = 0.15\n # Novel: loss variance penalty weight\n self.var_weight = 0.3\n # Logit magnitude penalty weight\n self.logit_reg_weight = 0.01\n\n def compute_loss(self, logits, labels, epoch):\n # Per-sample cross-entropy with label smoothing\n ce_per = F.cross_entropy(logits, labels, reduction='none',\n label_smoothing=self.label_smoothing)\n\n # Relaxed loss: for well-fit samples (low CE), reverse the gradient\n # direction to prevent memorization; for under-fit samples, train normally\n relaxed = torch.where(\n ce_per > self.target_ce,\n ce_per,\n self.alpha * (self.target_ce - ce_per),\n )\n main_loss = relaxed.mean()\n\n # Novel: Per-sample loss variance regularization\n # By penalizing the spread of per-sample losses, we encourage the model\n # to treat all training samples uniformly. This directly reduces the\n # statistical signal that MIA attacks use (the gap between the loss\n # distributions of members vs non-members).\n loss_var = ce_per.var()\n\n # Maximum entropy regularization: push output distributions toward\n # higher entropy, reducing overconfidence on training samples\n probs = F.softmax(logits, dim=1)\n entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n # Logit magnitude penalty: prevents the model from producing extreme\n # logit values that lead to high-confidence predictions exploitable by MIA\n logit_reg = (logits ** 2).mean()\n\n return (main_loss\n + self.var_weight * loss_var\n - self.entropy_weight * entropy\n + self.logit_reg_weight * logit_reg)\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
OK: Replaced lines 1..31 with 73 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–74 | total: 74 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Variance-Constrained Relaxed Entropy Loss (VCRE).
11
12 Novel contribution: Combines relaxed loss training with per-sample
13 loss variance regularization, entropy shaping, and logit magnitude
14 control to minimize membership inference leakage.
15
16 Key insight: MIA attacks exploit the distributional gap between
17 per-sample losses of members vs non-members. By (1) rel
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "start_line": 1,
5 "end_line": 31,
6 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Variance-Regularized Confidence-Bounded Loss (VRCBL).\n\n Novel contribution: Per-sample loss variance regularization combined with\n adaptive flooding, label smoothing, and entropy maximization. The variance\n penalty directly targets the MIA attack signal by encouraging uniform\n loss across all training samples, closing the gap between memorized and\n non-memorized examples that attackers exploit.\n \"\"\"\n\n def __init__(self):\n self.label_smoothing = 0.15\n self.flood_ratio = 0.15 # fraction of max entropy as flood level\n self.variance_weight = 1.0 # weight for per-sample loss variance penalty\n self.entropy_weight = 0.15 # weight for entropy bonus\n self.logit_reg_weight = 0.005 # weight for logit L2 penalty\n\n def compute_loss(self, logits, labels, epoch):\n num_classes = logits.size(1)\n\n # --- Per-sample cross-entropy with label smoothing ---\n ce_per = F.cross_entropy(\n logits, labels,\n label_smoothing=self.label_smoothing,\n reduction='none'\n )\n ce_mean = ce_per.mean()\n\n # --- Adaptive flooding: floor on mean loss to prevent memorization ---\n # Flood level scales with log(num_classes) so it adapts to dataset\n flood_level = self.flood_ratio * torch.log(\n torch.tensor(float(num_classes), device=logits.device)\n )\n ce_flooded = (ce_mean - flood_level).abs() + flood_level\n\n # --- Per-sample loss variance penalty (NOVEL) ---\n # Encourages uniform loss distribution across samples, directly\n # reducing the signal that MIA attacks exploit (loss gap between\n # memorized train samples and unseen test samples)\n loss_var = ce_per.var()\n\n # --- Entropy bonus: encourages less peaked predictions ---\n probs = torch.softmax(logits, dim=1)\n entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n # --- Logit L2 regularization: bounds output confidence ---\n logit_reg = (logits ** 2).mean()\n\n # --- Epoch-adaptive regularization (stronger in later epochs) ---\n # Later epochs are more prone to memorization\n adaptive = min(1.0, 0.5 + 0.5 * epoch / 50.0)\n\n total = (\n ce_flooded\n + self.variance_weight * adaptive * loss_var\n - self.entropy_weight * adaptive * entropy\n + self.logit_reg_weight * logit_reg\n )\n return total\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
OK: Replaced lines 1..31 with 70 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
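The `ce_flooded` line in this attempt is the "flooding" regularizer from the literature: above the flood level the loss is unchanged, below it the loss is mirrored upward, so the gradient reverses sign and the mean loss hovers near the floor instead of going to zero. A scalar illustration, separate from the harness code:

```python
def flooded(loss, flood):
    """Flooding: identity above the flood level, mirror below it."""
    return abs(loss - flood) + flood

high = flooded(1.5, 0.7)  # ~1.5: above the flood level the loss is unchanged
low = flooded(0.3, 0.7)   # ~1.1: below it the loss is reflected back up
```

Because the reflection is applied to the batch mean here (not per sample), individual examples can still be memorized as long as the average stays near the flood level.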
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–71 | total: 71 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Variance-Regularized Confidence-Bounded Loss (VRCBL).
11
12 Novel contribution: Per-sample loss variance regularization combined with
13 adaptive flooding, label smoothing, and entropy maximization. The variance
14 penalty directly targets the MIA attack signal by encouraging uniform
15 loss across all training samples, closing the gap between memorized and
16 non-memorized exa
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "start_line": 68,
5 "end_line": 71,
6 "content": "# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
OK: Replaced lines 68..71 with 3 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–70 | total: 70 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class MembershipDefense:
10 """Variance-Regularized Confidence-Bounded Loss (VRCBL).
11
12 Novel contribution: Per-sample loss variance regularization combined with
13 adaptive flooding, label smoothing, and entropy maximization. The variance
14 penalty directly targets the MIA attack signal by encouraging uniform
15 loss across all training samples, closing the gap between memorized and
16 non-memorized exa
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "start_line": 1,
5 "end_line": 31,
6 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Smooth Confidence Flooding with Multi-Regularization (SCF-MR).\n \n Novel defense combining:\n 1. Softplus-based smooth flooding: Unlike RelaxLoss's hard piecewise function,\n uses softplus(threshold - per_sample_ce) to create a differentiable penalty\n that smoothly increases as samples become too well-learned. This avoids \n gradient discontinuities while effectively preventing memorization.\n 2. Label smoothing for baseline confidence calibration.\n 3. Entropy maximization bonus to spread probability mass.\n 4. Logit norm regularization to directly limit confidence magnitude.\n \n The flood level automatically adapts to the number of classes.\n \"\"\"\n\n def __init__(self):\n self.label_smoothing = 0.1\n self.flood_weight = 0.5\n self.flood_beta = 3.0 # sharpness of flooding transition\n self.entropy_weight = 0.1\n self.logit_norm_weight = 0.005\n\n def compute_loss(self, logits, labels, epoch):\n num_classes = logits.size(1)\n \n # Class-adaptive flood level: scales with log(num_classes)\n # For 10 classes: ~0.69, for 100 classes: ~1.38\n flood_level = 0.3 * math.log(num_classes)\n \n # Per-sample cross-entropy with label smoothing\n ce_per = F.cross_entropy(logits, labels, reduction='none',\n label_smoothing=self.label_smoothing)\n ce = ce_per.mean()\n \n # Smooth flooding: softplus creates a differentiable penalty for \n # samples with CE below the flood level, preventing memorization\n flood_penalty = F.softplus(flood_level - ce_per, beta=self.flood_beta).mean()\n \n # Predictive entropy bonus: encourages less peaked distributions\n probs = torch.softmax(logits, dim=1)\n entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n \n # Logit norm regularization: directly limits confidence magnitude\n logit_norm = logits.pow(2).mean()\n \n loss = (ce \n + self.flood_weight * flood_penalty \n - self.entropy_weight * entropy \n + self.logit_norm_weight * logit_norm)\n \n return loss\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
OK: Replaced lines 1..31 with 63 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–64 | total: 64 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import math
4 import torch
5 import torch.nn.functional as F
6
7 # ============================================================
8 # EDITABLE
9 # ============================================================
10 class MembershipDefense:
11 """Smooth Confidence Flooding with Multi-Regularization (SCF-MR).
12
13 Novel defense combining:
14 1. Softplus-based smooth flooding: Unlike RelaxLoss's hard piecewise function,
15 uses softplus(threshold - per_sample_ce) to create a differentiable penalty
16 that smoothly increases as samples become too well-learned. This avoids
17 gradient disc
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.7523 train_acc=33.01 lr=0.099975
TRAIN_METRICS epoch=10 loss=1.0178 train_acc=77.25 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9272 train_acc=82.55 lr=0.090451
TRAIN_METRICS epoch=30 loss=0.8844 train_acc=84.78 lr=0.079389
TRAIN_METRICS epoch=40 loss=0.8464 train_acc=87.22 lr=0.065451
TRAIN_METRICS epoch=50 loss=0.8113 train_acc=88.98 lr=0.050000
TRAIN_METRICS epoch=60 loss=0.7702 train_acc=91.43 lr=0.034549
TRAIN_METRICS epoch=70 loss=0.7274 train_acc=94.08 lr=0.020611
TRAIN_METRICS epoch=80 loss=0.6738 train_acc=97.34 lr=0.009549
TRAIN_METRICS epoch=90 loss=0.6432 train_acc=99.20 lr=0.002447
TRAIN_METRICS epoch=100 loss=0.6360 train_acc=99.52 lr=0.000000
resnet20-cifar10: test_acc=0.8945, mia_auc=0.6045, privacy_gap=0.0595, privacy_score=0.7900
### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.1895 train_a
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_membership_defense.py",
4 "start_line": 1,
5 "end_line": 64,
6 "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n \"\"\"Smooth Relaxation with Confidence Calibration (SRC).\n\n A novel training-time membership inference defense with three mechanisms:\n\n 1. Sigmoid-Smoothed Loss Relaxation: Unlike RelaxLoss's hard piecewise\n switching between normal and inverted gradients, we use a sigmoid to\n smoothly interpolate. Well-learned samples (low per-sample CE) get\n partially inverted gradients pushing them toward lower confidence,\n while poorly-learned samples receive standard CE gradients.\n This eliminates the gradient discontinuity at the threshold, leading\n to smoother optimization. The blending is self-curricularizing:\n early in training all samples have high CE so blend~0 (pure CE),\n and regularization automatically activates as samples become memorized.\n\n 2. Label Smoothing: Standard soft targets prevent overconfident predictions.\n\n 3. Direct Max-Confidence Penalty: Soft upper bound on maximum prediction\n probability via softplus, directly reducing the signal available to\n confidence-based MIA attacks.\n\n All thresholds adapt to the number of classes via log-scaling.\n \"\"\"\n\n def __init__(self):\n self.label_smoothing = 0.1\n self.alpha = 0.5 # inversion strength for well-learned samples\n self.sharpness = 5.0 # sigmoid transition sharpness\n self.conf_weight = 0.5 # weight for max-confidence penalty\n self.conf_cap = 0.85 # soft cap on max prediction probability\n\n def compute_loss(self, logits, labels, epoch):\n num_classes = logits.size(1)\n\n # Class-adaptive target CE: scales with log(num_classes)\n # 10 classes: ~1.0, 100 classes: ~1.92\n target_ce = 0.4 * math.log(num_classes) + 0.08\n\n # Per-sample cross-entropy with label smoothing\n ce_per = F.cross_entropy(logits, labels, reduction='none',\n label_smoothing=self.label_smoothing)\n\n # Smooth blending factor (detached for stable gradients):\n # ≈1 for well-learned samples (low CE), ≈0 for poorly-learned\n with torch.no_grad():\n blend = torch.sigmoid(self.sharpness * (target_ce - ce_per))\n\n # Smoothly relaxed loss:\n # For high-loss samples (blend≈0): standard CE gradient for learning\n # For low-loss samples (blend≈1): inverted gradient pushes away from\n # memorization, preventing the train-test gap MIA exploits\n relaxed = (1.0 - blend) * ce_per + blend * self.alpha * (target_ce - ce_per)\n loss = relaxed.mean()\n\n # Soft max-confidence penalty: directly limits prediction confidence\n # which is the primary signal used by confidence-based MIA attacks\n probs = torch.softmax(logits, dim=1)\n max_prob = probs.max(dim=1)[0]\n conf_penalty = F.softplus(max_prob - self.conf_cap, beta=10.0).mean()\n\n return loss + self.conf_weight * conf_penalty\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
OK: Replaced lines 1..64 with 73 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.
[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–73 | total: 73 lines]
1 """Editable membership-inference defense for MLS-Bench."""
2
3 import math
4 import torch
5 import torch.nn.functional as F
6
7 # ============================================================
8 # EDITABLE
9 # ============================================================
10 class MembershipDefense:
11 """Smooth Relaxation with Confidence Calibration (SRC).
12
13 A novel training-time membership inference defense with three mechanisms:
14
15 1. Sigmoid-Smoothed Loss Relaxation: Unlike RelaxLoss's hard piecewise
16 switching between normal and inverted gradients, we use a sigmoid to
17 smoothly interpolate. Well-learned samples (low per-sample CE) get
1