Agent Conversation: anthropic/claude-opus-4.6

Model: anthropic/claude-opus-4.6
Total steps: 6
Messages: 25
User · Step 0

Task: security-membership-inference-defense

Membership Inference Defense via Training Regularization

Research Question

How can we design a stronger training-time regularizer that reduces membership inference leakage while preserving standard predictive accuracy?

Background

Membership inference attacks exploit the behavioral gap between training and held-out examples, typically through confidence or loss statistics. Many defenses regularize predictions to reduce overconfidence and shrink the train-test generalization gap, but stronger privacy often comes at the cost of utility.
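The confidence-gap signal described above can be made concrete with a small sketch. Nothing below is the harness's actual attack implementation; the function names (`confidence`, `rank_auc`) and the use of max-softmax probability as the attack score are illustrative assumptions.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a single row of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def confidence(logits):
    # Attack score: maximum predicted class probability (illustrative choice).
    return max(softmax(logits))

def rank_auc(member_scores, nonmember_scores):
    # AUC = P(member score > non-member score), counting ties as 1/2.
    wins = 0.0
    for m in member_scores:
        for n in nonmember_scores:
            if m > n:
                wins += 1.0
            elif m == n:
                wins += 0.5
    return wins / (len(member_scores) * len(nonmember_scores))
```

If the model is systematically more confident on training members than on held-out points, this AUC rises above 0.5; a defense that closes the confidence gap pushes it back toward 0.5.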

Task

Implement a better privacy-preserving training objective in bench/membership/custom_membership_defense.py. The fixed harness will train a model on a fixed dataset split, then run a confidence-based membership inference attack on train versus held-out examples.

Your method should improve the privacy-utility tradeoff: lower membership attack AUC while retaining high test accuracy.

Editable Interface

You must implement:

class MembershipDefense:
    def compute_loss(self, logits, labels, model, epoch):
        ...
  • logits: model outputs for the current minibatch
  • labels: ground-truth labels
  • model: current model instance, exposed for optional parameter regularization
  • epoch: current training epoch
  • Return value: scalar loss tensor used by the fixed training loop

The optimizer, architecture, batch schedule, and attack implementation are fixed.
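As a shape check for the interface above, here is a minimal pure-Python stand-in for a `confidence_penalty`-style objective. It mirrors the structure of the editable stub, but the real `compute_loss` receives torch tensors and must return a scalar loss tensor; the list-based `softmax` helper and the default `entropy_weight` value here are purely illustrative.

```python
import math

def softmax(logits):
    # Numerically stable softmax over one row of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

class MembershipDefense:
    """Confidence-penalty sketch: cross-entropy minus a weighted entropy bonus."""

    def __init__(self, entropy_weight=0.1):
        self.entropy_weight = entropy_weight

    def compute_loss(self, logits, labels, model=None, epoch=0):
        total = 0.0
        for row, y in zip(logits, labels):
            probs = softmax(row)
            ce = -math.log(max(probs[y], 1e-8))  # per-sample cross-entropy
            ent = -sum(p * math.log(max(p, 1e-8)) for p in probs)  # predictive entropy
            # Subtracting entropy rewards less peaked predictions.
            total += ce - self.entropy_weight * ent
        return total / len(labels)
```

With `entropy_weight = 0` this reduces to plain cross-entropy, matching the `erm` baseline's behavior.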

Evaluation

Benchmarks:

  • mnist
  • fmnist
  • cifar10

Reported metrics:

  • test_acc
  • mia_auc
  • privacy_gap
  • privacy_score

Primary metric: privacy_score (higher is better).

Baselines

  • erm: standard cross-entropy training
  • label_smoothing: smoothed targets to reduce overconfidence
  • confidence_penalty: cross-entropy plus predictive entropy penalty
  • relaxloss: margin-aware loss relaxation as a stronger privacy baseline
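For orientation, the core rule behind a relaxloss-style baseline can be sketched as a per-sample transform of the cross-entropy; the `target_ce` and `alpha` values below are illustrative guesses, not the benchmark's actual hyperparameters.

```python
def relaxed_loss(ce_per_sample, target_ce=1.0, alpha=0.5):
    # Samples fit better than the target loss get a reversed objective that
    # pushes their loss back up toward the target, discouraging memorization;
    # under-fit samples keep the standard cross-entropy objective.
    out = []
    for ce in ce_per_sample:
        if ce > target_ce:
            out.append(ce)                        # under-fit: train normally
        else:
            out.append(alpha * (target_ce - ce))  # well-fit: relax away from zero loss
    return sum(out) / len(out)
```

The effect is a floor on how well any individual training sample can be fit, which shrinks the member/non-member loss gap that confidence-based attacks exploit.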

opacus/bench/membership/custom_membership_defense.py [EDITABLE — entire file only]

     1: """Editable membership-inference defense for MLS-Bench."""
     2: 
     3: import torch
     4: import torch.nn.functional as F
     5: 
     6: # ============================================================
     7: # EDITABLE
     8: # ============================================================
     9: class MembershipDefense:
    10:     """Training-time regularizer for privacy-utility tradeoffs."""
    11: 
    12:     def __init__(self):
    13:         self.label_smoothing = 0.0
    14:         self.entropy_weight = 0.0
    15: 
    16:     def compute_loss(self, logits, labels, model, epoch):
    17:         ce = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
    18:         probs = torch.softmax(logits, dim=1)
    19:         entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
    20:         return ce - self.entropy_weight * entropy
    21: # ============================================================
    22:
Assistant · Step 1 · edit
Tool Input
1{
2  "op": "replace",
3  "filename": "opacus/bench/membership/custom_membership_defense.py",
4  "start_line": 1,
5  "end_line": 23,
6  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Confidence-Bounded Entropy-Regularized Training (CBERT).\n    \n    Novel multi-component defense against membership inference attacks:\n    \n    1. Soft Logit Bounding: Uses scaled tanh to create a differentiable \n       upper bound on logit magnitudes, directly limiting prediction \n       confidence without hard truncation.\n    \n    2. Confidence-Adaptive Smoothing: Per-sample label smoothing that \n       increases with model confidence, creating a negative feedback loop \n       that prevents memorization of individual samples.\n    \n    3. Entropy Maximization: Encourages spreading probability mass across \n       classes while maintaining correct classification.\n    \n    4. Logit Norm Penalty: L2 regularization on raw logits to ensure \n       bounded representations complement the soft clipping.\n    \"\"\"\n\n    def __init__(self):\n        # Soft logit bounding scale\n        self.logit_bound = 3.0\n        # Confidence-adaptive smoothing parameters\n        self.base_smoothing = 0.15\n        self.max_smoothing = 0.45\n        self.confidence_threshold = 0.5\n        self.smoothing_sharpness = 8.0\n        # Entropy bonus weight\n        self.entropy_weight = 0.15\n        # L2 penalty on raw logits\n        self.logit_l2 = 0.008\n\n    def compute_loss(self, logits, labels, model, epoch):\n        B, C = logits.shape\n        \n        # 1. Soft logit bounding via scaled tanh\n        # Creates a differentiable ceiling on logit magnitudes,\n        # directly limiting the maximum achievable confidence\n        bounded = self.logit_bound * torch.tanh(logits / self.logit_bound)\n        \n        # 2. Confidence-adaptive per-sample label smoothing\n        # More confident predictions receive stronger smoothing,\n        # creating a self-regulating negative feedback loop\n        with torch.no_grad():\n            probs_det = torch.softmax(bounded, dim=1)\n            confidence = probs_det.max(dim=1).values\n            # Sigmoid gating: smooth transition from base to max smoothing\n            alpha = self.base_smoothing + (self.max_smoothing - self.base_smoothing) * \\\n                    torch.sigmoid(self.smoothing_sharpness * (confidence - self.confidence_threshold))\n        \n        # Create per-sample soft targets: blend one-hot with uniform\n        one_hot = F.one_hot(labels, C).float()\n        uniform = torch.ones_like(one_hot) / C\n        soft_targets = (1 - alpha.unsqueeze(1)) * one_hot + alpha.unsqueeze(1) * uniform\n        \n        # Cross-entropy with per-sample soft targets\n        log_probs = F.log_softmax(bounded, dim=1)\n        ce = -(soft_targets * log_probs).sum(dim=1).mean()\n        \n        # 3. Entropy maximization on bounded logits\n        probs = torch.softmax(bounded, dim=1)\n        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n        \n        # 4. L2 penalty on raw logits to keep representations bounded\n        logit_reg = self.logit_l2 * (logits ** 2).mean()\n        \n        return ce - self.entropy_weight * entropy + logit_reg\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
Tool Result · Step 1
OK: Replaced lines 1..23 with 79 line(s) in opacus/bench/membership/custom_membership_defense.py. Editable range: entire file.

[Current file: opacus/bench/membership/custom_membership_defense.py | editable: 1–79 | total: 79 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class MembershipDefense:
    10      """Confidence-Bounded Entropy-Regularized Training (CBERT).
    11      
    12      Novel multi-component defense against membership inference attacks:
    13      
    14      1. Soft Logit Bounding: Uses scaled tanh to create a differentiable 
    15         upper bound on logit magnitudes, directly limiting prediction 
    16         confidence without hard truncation.
    17      
    18      2. Confidence-A
Assistant · Step 2 · test
Tool Result · Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)


## Seed 42

### mnist (scripts/mnist.sh)
TRAIN_METRICS epoch=1 loss=1.9436
TRAIN_METRICS epoch=2 loss=1.6534
TRAIN_METRICS epoch=3 loss=1.3666
TRAIN_METRICS epoch=4 loss=1.2659
TRAIN_METRICS epoch=5 loss=1.2308
mnist: test_acc=0.9215, mia_auc=0.5464, privacy_gap=0.0261, privacy_score=0.8751

### fmnist (scripts/fmnist.sh)
TRAIN_METRICS epoch=1 loss=1.8589
TRAIN_METRICS epoch=2 loss=1.4880
TRAIN_METRICS epoch=3 loss=1.3914
TRAIN_METRICS epoch=4 loss=1.3376
TRAIN_METRICS epoch=5 loss=1.3119
fmnist: test_acc=0.7285, mia_auc=0.5211, privacy_gap=0.0139, privacy_score=0.7074

### cifar10 (scripts/cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.9064
TRAIN_METRICS epoch=2 loss=1.7949
TRAIN_METRICS epoch=3 loss=1.7212
TRAIN_METRICS epoch=4 loss=1.6688
TRAIN_METRICS epoch=5 loss=1.6298
TRAIN_METRICS epoch=6 loss=1.6010
cifar10: test_acc=0.4405, mia_auc=0.4950, privacy_gap=-0.0025, privacy_score=0.4405


## Seed 123

### mnist (scri
Assistant · Step 1 · edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "start_line": 1,
5  "end_line": 31,
6  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Variance-Constrained Relaxed Entropy Loss (VCRE).\n\n    Novel contribution: Combines relaxed loss training with per-sample\n    loss variance regularization, entropy shaping, and logit magnitude\n    control to minimize membership inference leakage.\n\n    Key insight: MIA attacks exploit the distributional gap between\n    per-sample losses of members vs non-members. By (1) relaxing\n    gradients for overly well-fit samples, (2) penalizing the variance\n    of per-sample losses to encourage uniform treatment, (3) maximizing\n    prediction entropy, and (4) penalizing logit magnitudes, we\n    systematically close this gap without sacrificing generalization.\n    \"\"\"\n\n    def __init__(self):\n        # Relaxation parameters (proven in RelaxLoss)\n        self.target_ce = 1.0\n        self.alpha = 0.5\n        # Label smoothing strength\n        self.label_smoothing = 0.15\n        # Entropy bonus weight\n        self.entropy_weight = 0.15\n        # Novel: loss variance penalty weight\n        self.var_weight = 0.3\n        # Logit magnitude penalty weight\n        self.logit_reg_weight = 0.01\n\n    def compute_loss(self, logits, labels, epoch):\n        # Per-sample cross-entropy with label smoothing\n        ce_per = F.cross_entropy(logits, labels, reduction='none',\n                                  label_smoothing=self.label_smoothing)\n\n        # Relaxed loss: for well-fit samples (low CE), reverse the gradient\n        # direction to prevent memorization; for under-fit samples, train normally\n        relaxed = torch.where(\n            ce_per > self.target_ce,\n            ce_per,\n            self.alpha * (self.target_ce - ce_per),\n        )\n        main_loss = relaxed.mean()\n\n        # Novel: Per-sample loss variance regularization\n        # By penalizing the spread of per-sample losses, we encourage the model\n        # to treat all training samples uniformly. This directly reduces the\n        # statistical signal that MIA attacks use (the gap between the loss\n        # distributions of members vs non-members).\n        loss_var = ce_per.var()\n\n        # Maximum entropy regularization: push output distributions toward\n        # higher entropy, reducing overconfidence on training samples\n        probs = F.softmax(logits, dim=1)\n        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n        # Logit magnitude penalty: prevents the model from producing extreme\n        # logit values that lead to high-confidence predictions exploitable by MIA\n        logit_reg = (logits ** 2).mean()\n\n        return (main_loss\n                + self.var_weight * loss_var\n                - self.entropy_weight * entropy\n                + self.logit_reg_weight * logit_reg)\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
Tool Result · Step 1
OK: Replaced lines 1..31 with 73 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–74 | total: 74 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class MembershipDefense:
    10      """Variance-Constrained Relaxed Entropy Loss (VCRE).
    11  
    12      Novel contribution: Combines relaxed loss training with per-sample
    13      loss variance regularization, entropy shaping, and logit magnitude
    14      control to minimize membership inference leakage.
    15  
    16      Key insight: MIA attacks exploit the distributional gap between
    17      per-sample losses of members vs non-members. By (1) rel
Assistant · Step 2 · test
Assistant · Step 1 · edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "start_line": 1,
5  "end_line": 31,
6  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Variance-Regularized Confidence-Bounded Loss (VRCBL).\n\n    Novel contribution: Per-sample loss variance regularization combined with\n    adaptive flooding, label smoothing, and entropy maximization. The variance\n    penalty directly targets the MIA attack signal by encouraging uniform\n    loss across all training samples, closing the gap between memorized and\n    non-memorized examples that attackers exploit.\n    \"\"\"\n\n    def __init__(self):\n        self.label_smoothing = 0.15\n        self.flood_ratio = 0.15        # fraction of max entropy as flood level\n        self.variance_weight = 1.0      # weight for per-sample loss variance penalty\n        self.entropy_weight = 0.15      # weight for entropy bonus\n        self.logit_reg_weight = 0.005   # weight for logit L2 penalty\n\n    def compute_loss(self, logits, labels, epoch):\n        num_classes = logits.size(1)\n\n        # --- Per-sample cross-entropy with label smoothing ---\n        ce_per = F.cross_entropy(\n            logits, labels,\n            label_smoothing=self.label_smoothing,\n            reduction='none'\n        )\n        ce_mean = ce_per.mean()\n\n        # --- Adaptive flooding: floor on mean loss to prevent memorization ---\n        # Flood level scales with log(num_classes) so it adapts to dataset\n        flood_level = self.flood_ratio * torch.log(\n            torch.tensor(float(num_classes), device=logits.device)\n        )\n        ce_flooded = (ce_mean - flood_level).abs() + flood_level\n\n        # --- Per-sample loss variance penalty (NOVEL) ---\n        # Encourages uniform loss distribution across samples, directly\n        # reducing the signal that MIA attacks exploit (loss gap between\n        # memorized train samples and unseen test samples)\n        loss_var = ce_per.var()\n\n        # --- Entropy bonus: encourages less peaked predictions ---\n        probs = torch.softmax(logits, dim=1)\n        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n        # --- Logit L2 regularization: bounds output confidence ---\n        logit_reg = (logits ** 2).mean()\n\n        # --- Epoch-adaptive regularization (stronger in later epochs) ---\n        # Later epochs are more prone to memorization\n        adaptive = min(1.0, 0.5 + 0.5 * epoch / 50.0)\n\n        total = (\n            ce_flooded\n            + self.variance_weight * adaptive * loss_var\n            - self.entropy_weight * adaptive * entropy\n            + self.logit_reg_weight * logit_reg\n        )\n        return total\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
Tool Result · Step 1
OK: Replaced lines 1..31 with 70 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–71 | total: 71 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class MembershipDefense:
    10      """Variance-Regularized Confidence-Bounded Loss (VRCBL).
    11  
    12      Novel contribution: Per-sample loss variance regularization combined with
    13      adaptive flooding, label smoothing, and entropy maximization. The variance
    14      penalty directly targets the MIA attack signal by encouraging uniform
    15      loss across all training samples, closing the gap between memorized and
    16      non-memorized exa
Assistant · Step 2 · edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "start_line": 68,
5  "end_line": 71,
6  "content": "# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
Tool Result · Step 2
OK: Replaced lines 68..71 with 3 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–70 | total: 70 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class MembershipDefense:
    10      """Variance-Regularized Confidence-Bounded Loss (VRCBL).
    11  
    12      Novel contribution: Per-sample loss variance regularization combined with
    13      adaptive flooding, label smoothing, and entropy maximization. The variance
    14      penalty directly targets the MIA attack signal by encouraging uniform
    15      loss across all training samples, closing the gap between memorized and
    16      non-memorized exa
Assistant · Step 3 · test
Assistant · Step 1 · edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "start_line": 1,
5  "end_line": 31,
6  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Smooth Confidence Flooding with Multi-Regularization (SCF-MR).\n    \n    Novel defense combining:\n    1. Softplus-based smooth flooding: Unlike RelaxLoss's hard piecewise function,\n       uses softplus(threshold - per_sample_ce) to create a differentiable penalty\n       that smoothly increases as samples become too well-learned. This avoids \n       gradient discontinuities while effectively preventing memorization.\n    2. Label smoothing for baseline confidence calibration.\n    3. Entropy maximization bonus to spread probability mass.\n    4. Logit norm regularization to directly limit confidence magnitude.\n    \n    The flood level automatically adapts to the number of classes.\n    \"\"\"\n\n    def __init__(self):\n        self.label_smoothing = 0.1\n        self.flood_weight = 0.5\n        self.flood_beta = 3.0       # sharpness of flooding transition\n        self.entropy_weight = 0.1\n        self.logit_norm_weight = 0.005\n\n    def compute_loss(self, logits, labels, epoch):\n        num_classes = logits.size(1)\n        \n        # Class-adaptive flood level: scales with log(num_classes)\n        # For 10 classes: ~0.69, for 100 classes: ~1.38\n        flood_level = 0.3 * math.log(num_classes)\n        \n        # Per-sample cross-entropy with label smoothing\n        ce_per = F.cross_entropy(logits, labels, reduction='none',\n                                  label_smoothing=self.label_smoothing)\n        ce = ce_per.mean()\n        \n        # Smooth flooding: softplus creates a differentiable penalty for \n        # samples with CE below the flood level, preventing memorization\n        flood_penalty = F.softplus(flood_level - ce_per, beta=self.flood_beta).mean()\n        \n        # Predictive entropy bonus: encourages less peaked distributions\n        probs = torch.softmax(logits, dim=1)\n        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n        \n        # Logit norm regularization: directly limits confidence magnitude\n        logit_norm = logits.pow(2).mean()\n        \n        loss = (ce \n                + self.flood_weight * flood_penalty \n                - self.entropy_weight * entropy \n                + self.logit_norm_weight * logit_norm)\n        \n        return loss\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
Tool Result · Step 1
OK: Replaced lines 1..31 with 63 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–64 | total: 64 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import math
     4  import torch
     5  import torch.nn.functional as F
     6  
     7  # ============================================================
     8  # EDITABLE
     9  # ============================================================
    10  class MembershipDefense:
    11      """Smooth Confidence Flooding with Multi-Regularization (SCF-MR).
    12      
    13      Novel defense combining:
    14      1. Softplus-based smooth flooding: Unlike RelaxLoss's hard piecewise function,
    15         uses softplus(threshold - per_sample_ce) to create a differentiable penalty
    16         that smoothly increases as samples become too well-learned. This avoids 
    17         gradient disc
Assistant · Step 2 · test
Tool Result · Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.7523 train_acc=33.01 lr=0.099975
TRAIN_METRICS epoch=10 loss=1.0178 train_acc=77.25 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9272 train_acc=82.55 lr=0.090451
TRAIN_METRICS epoch=30 loss=0.8844 train_acc=84.78 lr=0.079389
TRAIN_METRICS epoch=40 loss=0.8464 train_acc=87.22 lr=0.065451
TRAIN_METRICS epoch=50 loss=0.8113 train_acc=88.98 lr=0.050000
TRAIN_METRICS epoch=60 loss=0.7702 train_acc=91.43 lr=0.034549
TRAIN_METRICS epoch=70 loss=0.7274 train_acc=94.08 lr=0.020611
TRAIN_METRICS epoch=80 loss=0.6738 train_acc=97.34 lr=0.009549
TRAIN_METRICS epoch=90 loss=0.6432 train_acc=99.20 lr=0.002447
TRAIN_METRICS epoch=100 loss=0.6360 train_acc=99.52 lr=0.000000
resnet20-cifar10: test_acc=0.8945, mia_auc=0.6045, privacy_gap=0.0595, privacy_score=0.7900

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.1895 train_a
Assistant · Step 3 · edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "start_line": 1,
5  "end_line": 64,
6  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Smooth Relaxation with Confidence Calibration (SRC).\n\n    A novel training-time membership inference defense with three mechanisms:\n\n    1. Sigmoid-Smoothed Loss Relaxation: Unlike RelaxLoss's hard piecewise\n       switching between normal and inverted gradients, we use a sigmoid to\n       smoothly interpolate. Well-learned samples (low per-sample CE) get\n       partially inverted gradients pushing them toward lower confidence,\n       while poorly-learned samples receive standard CE gradients.\n       This eliminates the gradient discontinuity at the threshold, leading\n       to smoother optimization. The blending is self-curricularizing:\n       early in training all samples have high CE so blend~0 (pure CE),\n       and regularization automatically activates as samples become memorized.\n\n    2. Label Smoothing: Standard soft targets prevent overconfident predictions.\n\n    3. Direct Max-Confidence Penalty: Soft upper bound on maximum prediction\n       probability via softplus, directly reducing the signal available to\n       confidence-based MIA attacks.\n\n    All thresholds adapt to the number of classes via log-scaling.\n    \"\"\"\n\n    def __init__(self):\n        self.label_smoothing = 0.1\n        self.alpha = 0.5       # inversion strength for well-learned samples\n        self.sharpness = 5.0   # sigmoid transition sharpness\n        self.conf_weight = 0.5 # weight for max-confidence penalty\n        self.conf_cap = 0.85   # soft cap on max prediction probability\n\n    def compute_loss(self, logits, labels, epoch):\n        num_classes = logits.size(1)\n\n        # Class-adaptive target CE: scales with log(num_classes)\n        # 10 classes: ~1.0, 100 classes: ~1.92\n        target_ce = 0.4 * math.log(num_classes) + 0.08\n\n        # Per-sample cross-entropy with label smoothing\n        ce_per = F.cross_entropy(logits, labels, reduction='none',\n                                  label_smoothing=self.label_smoothing)\n\n        # Smooth blending factor (detached for stable gradients):\n        # ≈1 for well-learned samples (low CE), ≈0 for poorly-learned\n        with torch.no_grad():\n            blend = torch.sigmoid(self.sharpness * (target_ce - ce_per))\n\n        # Smoothly relaxed loss:\n        # For high-loss samples (blend≈0): standard CE gradient for learning\n        # For low-loss samples (blend≈1): inverted gradient pushes away from\n        #   memorization, preventing the train-test gap MIA exploits\n        relaxed = (1.0 - blend) * ce_per + blend * self.alpha * (target_ce - ce_per)\n        loss = relaxed.mean()\n\n        # Soft max-confidence penalty: directly limits prediction confidence\n        # which is the primary signal used by confidence-based MIA attacks\n        probs = torch.softmax(logits, dim=1)\n        max_prob = probs.max(dim=1)[0]\n        conf_penalty = F.softplus(max_prob - self.conf_cap, beta=10.0).mean()\n\n        return loss + self.conf_weight * conf_penalty\n# ============================================================\n# END EDITABLE\n# ============================================================\n"
7}
Tool Result · Step 3
OK: Replaced lines 1..64 with 73 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–73 | total: 73 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import math
     4  import torch
     5  import torch.nn.functional as F
     6  
     7  # ============================================================
     8  # EDITABLE
     9  # ============================================================
    10  class MembershipDefense:
    11      """Smooth Relaxation with Confidence Calibration (SRC).
    12  
    13      A novel training-time membership inference defense with three mechanisms:
    14  
    15      1. Sigmoid-Smoothed Loss Relaxation: Unlike RelaxLoss's hard piecewise
    16         switching between normal and inverted gradients, we use a sigmoid to
    17         smoothly interpolate. Well-learned samples (low per-sample CE) get
    1
Assistant · Step 4 · test