Agent Conversation: openai/gpt-5.4-pro

Model: openai/gpt-5.4-pro
Total steps: 5
Messages: 14
User (Step 0)

Task: security-membership-inference-defense

Membership Inference Defense via Training Regularization

Research Question

How can we design a stronger training-time regularizer that reduces membership inference leakage while preserving standard predictive accuracy?

Background

Membership inference attacks exploit the gap between train and non-train examples, often through confidence or loss statistics. Many defenses regularize predictions to reduce overconfidence and shrink the train-test generalization gap, but stronger privacy often hurts utility.
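The confidence-based attack described here scores each example by its predicted confidence and thresholds it to guess membership. A minimal rank-based AUC for that attack can be sketched as follows (NumPy only; this is an illustration, not the harness's actual attack implementation):

```python
import numpy as np

def confidence_attack_auc(train_conf, heldout_conf):
    """AUC of the rule 'predict member iff confidence > t', swept over all
    thresholds t. Computed via the Mann-Whitney U statistic (rank-based)."""
    scores = np.concatenate([train_conf, heldout_conf])
    labels = np.concatenate([np.ones(len(train_conf)), np.zeros(len(heldout_conf))])
    order = np.argsort(scores)
    ranks = np.empty_like(order, dtype=float)
    ranks[order] = np.arange(1, len(scores) + 1)
    n_pos, n_neg = len(train_conf), len(heldout_conf)
    # Sum of ranks of member examples, shifted by the minimum possible sum.
    u = ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)
```

An AUC of 0.5 means the attacker cannot distinguish train from held-out examples; the defense's goal is to push the attack toward that value.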

Task

Implement a better privacy-preserving training objective in bench/membership/custom_membership_defense.py. The fixed harness will train a model on a fixed dataset split, then run a confidence-based membership inference attack on train versus held-out examples.

Your method should improve the privacy-utility tradeoff: lower membership attack AUC while retaining high test accuracy.

Editable Interface

You must implement:

class MembershipDefense:
    def compute_loss(self, logits, labels, model, epoch):
        ...
  • logits: model outputs for the current minibatch
  • labels: ground-truth labels
  • model: current model instance, exposed for optional parameter regularization
  • epoch: current training epoch
  • Return value: scalar loss tensor used by the fixed training loop

The optimizer, architecture, batch schedule, and attack implementation are fixed.

Evaluation

Benchmarks:

  • mnist
  • fmnist
  • cifar10

Reported metrics:

  • test_acc
  • mia_auc
  • privacy_gap
  • privacy_score

Primary metric: privacy_score (higher is better).
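The harness does not state how privacy_score is computed, but both runs logged later in this transcript are consistent with test accuracy minus the attack's advantage over random guessing. Treat the formula below as an inference from the logged numbers, not documentation:

```python
def privacy_score(test_acc, mia_auc):
    # Test accuracy penalized by the attack's AUC advantage over chance (0.5).
    # Inferred from the two logged runs; the harness may use something else.
    return test_acc - (mia_auc - 0.5)

print(round(privacy_score(0.8937, 0.5953), 4))  # 0.7984, matching run 1
```

Under this reading, shaving 0.01 off mia_auc is worth exactly as much as 0.01 of test accuracy, which is why the task frames the problem as a privacy-utility tradeoff.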

Baselines

  • erm: standard cross-entropy training
  • label_smoothing: smoothed targets to reduce overconfidence
  • confidence_penalty: cross-entropy plus predictive entropy penalty
  • relaxloss: margin-aware loss relaxation as a stronger privacy baseline
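As a concrete reference point, the confidence_penalty baseline is cross-entropy plus a predictive-entropy bonus. A minimal sketch (the weight `beta` is illustrative; the harness's actual coefficient is not shown):

```python
import torch
import torch.nn.functional as F

def confidence_penalty_loss(logits, labels, beta=0.1):
    # Cross-entropy minus a weighted entropy bonus: higher predictive entropy
    # lowers the loss, discouraging the overconfident train-only predictions
    # that confidence-based MIA exploits.
    ce = F.cross_entropy(logits, labels)
    log_probs = F.log_softmax(logits, dim=1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=1).mean()
    return ce - beta * entropy
```

With `beta=0` this reduces to plain ERM, so the baseline interpolates smoothly between utility and privacy.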

opacus/bench/membership/custom_membership_defense.py [EDITABLE — entire file only]

     1: """Editable membership-inference defense for MLS-Bench."""
     2: 
     3: import torch
     4: import torch.nn.functional as F
     5: 
     6: # ============================================================
     7: # EDITABLE
     8: # ============================================================
     9: class MembershipDefense:
    10:     """Training-time regularizer for privacy-utility tradeoffs."""
    11: 
    12:     def __init__(self):
    13:         self.label_smoothing = 0.0
    14:         self.entropy_weight = 0.0
    15: 
    16:     def compute_loss(self, logits, labels, model, epoch):
    17:         ce = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
    18:         probs = torch.softmax(logits, dim=1)
    19:         entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
    20:         return ce - self.entropy_weight * entropy
    21: # ============================================================
    22
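For orientation, the fixed training loop presumably invokes the baseline above roughly as follows. The call shape (including passing `model` and `epoch`, which this baseline ignores) is an assumption based on the interface description, not the harness source:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MembershipDefense:
    """Baseline from the listing above: ERM with smoothing/entropy knobs at 0."""
    def __init__(self):
        self.label_smoothing = 0.0
        self.entropy_weight = 0.0

    def compute_loss(self, logits, labels, model, epoch):
        ce = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
        probs = torch.softmax(logits, dim=1)
        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()
        return ce - self.entropy_weight * entropy

model = nn.Linear(8, 10)                 # stand-in for the fixed architecture
x = torch.randn(4, 8)
labels = torch.randint(0, 10, (4,))
defense = MembershipDefense()
loss = defense.compute_loss(model(x), labels, model, epoch=0)
loss.backward()                          # scalar loss usable by the fixed loop
```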
Assistant (Step 1): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Confidence-capped adaptive smoothing for membership privacy.\n\n    Idea:\n    - keep standard ERM behavior on hard / misclassified examples;\n    - only regularize examples that are already correct and becoming\n      overconfident, since those dominate confidence-based MIA signals;\n    - use an EMA of correct-class confidence to set an adaptive target,\n      making the defense robust across datasets with very different\n      difficulty and class count.\n    \"\"\"\n\n    def __init__(self):\n        # Delay privacy regularization until the model has learned a useful\n        # classifier, then smoothly ramp the defense on.\n        self.warmup_epochs = 8\n        self.ramp_tau = 12.0\n\n        # EMA target for correctly-classified confidence.\n        self.ema_momentum = 0.90\n        self.correct_conf_ema = None\n\n        # Confidence target schedule. 
During warmup the target is near 1.0,\n        # then it anneals toward an EMA-based cap.\n        self.init_target = 0.98\n        self.min_target = 0.66\n        self.max_target = 0.82\n        self.conf_gap = 0.08\n\n        # Regularization strengths.\n        self.max_smoothing = 0.16\n        self.excess_weight = 0.70\n        self.mean_weight = 0.20\n        self.entropy_weight = 0.01\n\n    def _update_conf_ema(self, value):\n        value = value.detach()\n        if self.correct_conf_ema is None:\n            self.correct_conf_ema = value\n        else:\n            self.correct_conf_ema = (\n                self.ema_momentum * self.correct_conf_ema.to(value.device)\n                + (1.0 - self.ema_momentum) * value\n            )\n\n    def compute_loss(self, logits, labels, epoch):\n        log_probs = F.log_softmax(logits, dim=1)\n        probs = log_probs.exp()\n\n        true_log_probs = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)\n        true_probs = true_log_probs.exp()\n        nll = -true_log_probs\n        uniform_ce = -log_probs.mean(dim=1)\n\n        preds = probs.argmax(dim=1)\n        correct_mask = preds.eq(labels)\n        correct_conf = true_probs[correct_mask]\n\n        with torch.no_grad():\n            ema_source = correct_conf if correct_conf.numel() > 0 else true_probs\n            self._update_conf_ema(ema_source.mean())\n\n        if epoch < self.warmup_epochs:\n            ramp = logits.new_tensor(0.0)\n            target_conf = logits.new_tensor(self.init_target)\n        else:\n            steps = float(epoch - self.warmup_epochs + 1)\n            decay = torch.exp(logits.new_tensor(-steps / self.ramp_tau))\n            ramp = 1.0 - decay\n\n            ema_target = torch.clamp(\n                self.correct_conf_ema.to(logits.device) - self.conf_gap,\n                min=self.min_target,\n                max=self.max_target,\n            )\n            target_conf = ema_target + (self.init_target - ema_target) * 
decay\n\n        # Confidence-adaptive label smoothing: only examples whose true-class\n        # confidence exceeds the target receive extra smoothing.\n        excess_ratio = (\n            (true_probs.detach() - target_conf) / (1.0 - target_conf + 1e-6)\n        ).clamp(min=0.0, max=1.0)\n        adaptive_eps = self.max_smoothing * ramp * excess_ratio\n        adaptive_ce = ((1.0 - adaptive_eps) * nll + adaptive_eps * uniform_ce).mean()\n\n        # Additional cap on overconfident correct examples to directly shrink\n        # the train-confidence tail exploited by MIA.\n        over_mask = correct_mask & (true_probs.detach() > target_conf)\n        over_conf = true_probs[over_mask]\n        if over_conf.numel() > 0:\n            excess_pen = (over_conf - target_conf).pow(2).mean()\n            mean_pen = (over_conf.mean() - target_conf).clamp_min(0.0).pow(2)\n            over_entropy = -(probs[over_mask] * log_probs[over_mask]).sum(dim=1).mean()\n        else:\n            zero = logits.new_tensor(0.0)\n            excess_pen = zero\n            mean_pen = zero\n            over_entropy = zero\n\n        return adaptive_ce + ramp * (\n            self.excess_weight * excess_pen\n            + self.mean_weight * mean_pen\n            - self.entropy_weight * over_entropy\n        )\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5  "after_line": 0,
6  "start_line": 1,
7  "end_line": 32
8}
Tool Result (Step 1)
OK: Replaced lines 1..32 with 115 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–115 | total: 115 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class MembershipDefense:
    10      """Confidence-capped adaptive smoothing for membership privacy.
    11  
    12      Idea:
    13      - keep standard ERM behavior on hard / misclassified examples;
    14      - only regularize examples that are already correct and becoming
    15        overconfident, since those dominate confidence-based MIA signals;
    16      - use an EMA of correct-class confidence to set an adaptive target,
    17        making the def
Assistant (Step 2): test
Assistant (Step 1): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Adaptive confidence-capped self-distillation.\n\n    Core idea:\n    1) keep standard supervised learning on hard examples;\n    2) on already-easy / high-confidence examples, replace one-hot targets\n       with a tempered self-target so the model preserves class ranking while\n       reducing overconfident train-only predictions;\n    3) explicitly cap the upper tail of true-class confidence and shrink its\n       spread, which directly targets confidence-based membership inference.\n    \"\"\"\n\n    def __init__(self):\n        # Mild global smoothing keeps optimization stable.\n        self.base_smoothing = 0.02\n\n        # Adaptive self-distillation parameters.\n        self.alpha_max = 0.35\n        self.temperature = 2.2\n\n        # Confidence gate is applied on normalized true-class confidence so the\n        # same thresholds transfer across datasets with different class counts.\n        self.gate_center = 0.55\n        self.gate_width = 0.08\n\n        # Upper-tail confidence control.\n        self.cap_weight = 0.8\n        self.var_weight = 0.04\n\n    def compute_loss(self, logits, labels, epoch):\n        num_classes = logits.size(1)\n        dtype = logits.dtype\n\n        log_probs = F.log_softmax(logits, dim=1)\n        probs = log_probs.exp()\n\n        true_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1)\n        uniform_prob = 1.0 / float(num_classes)\n        norm_true = ((true_probs - uniform_prob) / max(1.0 - uniform_prob, 1e-6)).clamp(0.0, 1.0)\n\n        # Start close to ERM, then progressively regularize confident examples.\n        ramp = 1.0 - math.exp(-float(epoch + 1) / 20.0)\n        ramp_t = 
logits.new_tensor(ramp)\n\n        gate_center = self.gate_center + 0.10 * (1.0 - ramp)\n        gate = torch.sigmoid((norm_true.detach() - gate_center) / self.gate_width)\n\n        with torch.no_grad():\n            tempered = F.softmax(logits / self.temperature, dim=1)\n\n        one_hot = F.one_hot(labels, num_classes=num_classes).to(dtype=dtype)\n        uniform = torch.full_like(tempered, 1.0 / float(num_classes))\n        alpha = (self.alpha_max * ramp_t * gate).unsqueeze(1)\n\n        # Interpolate between one-hot supervision and a softer, model-aware\n        # target that preserves class similarities better than uniform label\n        # smoothing.\n        targets = (1.0 - self.base_smoothing) * ((1.0 - alpha) * one_hot + alpha * tempered) + self.base_smoothing * uniform\n        soft_ce = -(targets * log_probs).sum(dim=1)\n\n        # Confidence-based MIA relies on the train-only upper tail. Penalize\n        # only excessive confidence, not all confidence.\n        target_norm = 0.74 + 0.08 * ramp\n        cap_penalty = gate * F.relu(norm_true - target_norm).pow(2)\n\n        # Compress the spread of easy-example confidence to further reduce the\n        # separability exploited by confidence attacks.\n        weight_sum = gate.sum().clamp_min(1e-6)\n        mean_conf = (gate * norm_true.detach()).sum() / weight_sum\n        var_penalty = (gate * (norm_true - mean_conf).pow(2)).sum() / weight_sum\n\n        return (\n            soft_ce.mean()\n            + self.cap_weight * ramp_t * cap_penalty.mean()\n            + self.var_weight * ramp_t * var_penalty\n        )\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5  "after_line": 0,
6  "start_line": 1,
7  "end_line": 32
8}
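The tempered self-target at the core of this revision can be isolated into a small helper (parameter values `alpha=0.35` and `T=2.2` are taken from the tool input above; the helper name and fixed class count are illustrative):

```python
import torch
import torch.nn.functional as F

def self_distilled_targets(logits, labels, alpha=0.35, T=2.2, num_classes=10):
    # Mix one-hot supervision with the model's own temperature-softened
    # prediction. Unlike uniform label smoothing, the soft mass follows the
    # model's class-similarity structure instead of spreading evenly.
    one_hot = F.one_hot(labels, num_classes).float()
    tempered = F.softmax(logits.detach() / T, dim=1)  # no gradient to target
    return (1.0 - alpha) * one_hot + alpha * tempered
```

Each row remains a valid distribution, so the result can be plugged directly into a soft cross-entropy of the form `-(targets * log_probs).sum(dim=1)`.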
Tool Result (Step 1)
OK: Replaced lines 1..32 with 89 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–89 | total: 89 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import math
     4  
     5  import torch
     6  import torch.nn.functional as F
     7  
     8  # ============================================================
     9  # EDITABLE
    10  # ============================================================
    11  class MembershipDefense:
    12      """Adaptive confidence-capped self-distillation.
    13  
    14      Core idea:
    15      1) keep standard supervised learning on hard examples;
    16      2) on already-easy / high-confidence examples, replace one-hot targets
    17         with a tempered self-target so the model preserves class ranking while
    18         reducing overconfident train-only predictions;
    19      3) explicitly 
Assistant (Step 2): test
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.8253 train_acc=33.71 lr=0.099975
TRAIN_METRICS epoch=10 loss=0.9225 train_acc=77.39 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9667 train_acc=81.65 lr=0.090451
TRAIN_METRICS epoch=30 loss=1.0114 train_acc=84.34 lr=0.079389
TRAIN_METRICS epoch=40 loss=1.0300 train_acc=86.94 lr=0.065451
TRAIN_METRICS epoch=50 loss=1.0440 train_acc=88.79 lr=0.050000
TRAIN_METRICS epoch=60 loss=1.0369 train_acc=91.37 lr=0.034549
TRAIN_METRICS epoch=70 loss=1.0285 train_acc=94.06 lr=0.020611
TRAIN_METRICS epoch=80 loss=1.0154 train_acc=97.15 lr=0.009549
TRAIN_METRICS epoch=90 loss=1.0131 train_acc=98.81 lr=0.002447
TRAIN_METRICS epoch=100 loss=1.0146 train_acc=99.35 lr=0.000000
resnet20-cifar10: test_acc=0.8937, mia_auc=0.5953, privacy_gap=0.0458, privacy_score=0.7984

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.6580 train_a
Assistant (Step 3): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/custom_membership_defense.py",
4  "content": "\"\"\"Editable membership-inference defense for MLS-Bench.\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass MembershipDefense:\n    \"\"\"Adaptive tempered flooding for privacy-utility tradeoffs.\n\n    The loss keeps standard supervised learning, but only once an example has\n    become \"easy\" (low loss relative to the random-guess loss log(K)) it adds:\n      1) tempered self-distillation to soften train-only overconfidence while\n         preserving class similarity structure; and\n      2) a smooth flooding barrier that prevents the loss from collapsing far\n         below a class-count-aware target, directly shrinking the signal used by\n         confidence/loss-based membership inference.\n    \"\"\"\n\n    def __init__(self):\n        self.base_smoothing = 0.01\n        self.temperature = 2.4\n        self.distill_weight = 0.80\n        self.floor_weight = 0.90\n        self.spread_weight = 0.02\n\n        # Gate over easy examples using loss normalized by log(num_classes).\n        self.gate_center = 0.58\n        self.gate_width = 0.10\n\n    def compute_loss(self, logits, labels, epoch):\n        num_classes = logits.size(1)\n        log_k = math.log(float(num_classes))\n\n        log_probs = F.log_softmax(logits, dim=1)\n        hard_ce = F.cross_entropy(logits, labels, reduction=\"none\")\n        base_ce = F.cross_entropy(\n            logits,\n            labels,\n            reduction=\"none\",\n            label_smoothing=self.base_smoothing,\n        )\n\n        # Start near ERM and gradually turn on privacy regularization.\n        ramp = 1.0 - math.exp(-float(epoch + 1) / 15.0)\n        ramp_t = logits.new_tensor(ramp)\n\n        # Normalized skill = improvement over random-guess loss, in [0, 1].\n        skill = (1.0 - hard_ce.detach() / log_k).clamp(0.0, 
1.0)\n        gate_center = self.gate_center + 0.08 * (1.0 - ramp)\n        gate = torch.sigmoid((skill - gate_center) / self.gate_width)\n\n        with torch.no_grad():\n            tempered = F.softmax(logits / self.temperature, dim=1)\n\n        distill = F.kl_div(log_probs, tempered, reduction=\"none\").sum(dim=1)\n\n        # Smooth, class-count-aware flooding barrier: allow learning on hard\n        # samples, but discourage memorized train samples from driving the loss\n        # much below a moderate target.\n        target_ce = log_k * (0.24 + 0.05 * ramp)\n        flood = gate * F.relu(target_ce - hard_ce).pow(2)\n\n        # Compress the easy-sample loss spread; confidence attacks exploit the\n        # particularly low-loss tail of the train distribution.\n        weight_sum = gate.sum().clamp_min(1e-6)\n        mean_ce = (gate * hard_ce.detach()).sum() / weight_sum\n        spread = (gate * (hard_ce - mean_ce).pow(2)).sum() / weight_sum\n\n        return (\n            base_ce.mean()\n            + self.distill_weight * ramp_t * (gate * distill).mean()\n            + self.floor_weight * ramp_t * flood.mean()\n            + self.spread_weight * ramp_t * spread\n        )\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5  "after_line": 0,
6  "start_line": 1,
7  "end_line": 89
8}
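The "flooding barrier" in this revision is a gated, squared variant of the classic flooding loss. The ungated original can be sketched in a few lines (attributing the idea to Ishida et al.'s flooding regularizer is an assumption about lineage; the transcript only uses the term "flooding"):

```python
import torch
import torch.nn.functional as F

def flooded_ce(logits, labels, flood_level=0.3):
    # Flooding: |CE - b| + b. When the batch loss dips below the flood level b,
    # the gradient flips sign and pushes the loss back up, so the training loss
    # hovers near b instead of collapsing toward zero. This narrows the
    # train/held-out loss gap that loss-based membership inference exploits.
    ce = F.cross_entropy(logits, labels)
    return (ce - flood_level).abs() + flood_level
```

The gated version above only applies the barrier to already-easy examples, whereas this plain form applies it to every batch.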
Tool Result (Step 3)
OK: Replaced lines 1..89 with 81 line(s) in pytorch-vision/custom_membership_defense.py. Editable range: entire file.

[Current file: pytorch-vision/custom_membership_defense.py | editable: 1–81 | total: 81 lines]
     1  """Editable membership-inference defense for MLS-Bench."""
     2  
     3  import math
     4  
     5  import torch
     6  import torch.nn.functional as F
     7  
     8  # ============================================================
     9  # EDITABLE
    10  # ============================================================
    11  class MembershipDefense:
    12      """Adaptive tempered flooding for privacy-utility tradeoffs.
    13  
    14      The loss keeps standard supervised learning, but only once an example has
    15      become "easy" (low loss relative to the random-guess loss log(K)) it adds:
    16        1) tempered self-distillation to soften train-only overconfidence while
    17           preserving class similarity structure; and
    18        2)
Assistant (Step 4): test
Tool Result (Step 4)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.7912 train_acc=34.46 lr=0.099975
TRAIN_METRICS epoch=10 loss=0.9509 train_acc=77.37 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9497 train_acc=81.80 lr=0.090451
TRAIN_METRICS epoch=30 loss=0.9469 train_acc=84.34 lr=0.079389
TRAIN_METRICS epoch=40 loss=0.9243 train_acc=87.01 lr=0.065451
TRAIN_METRICS epoch=50 loss=0.9034 train_acc=88.63 lr=0.050000
TRAIN_METRICS epoch=60 loss=0.8689 train_acc=91.13 lr=0.034549
TRAIN_METRICS epoch=70 loss=0.8315 train_acc=93.57 lr=0.020611
TRAIN_METRICS epoch=80 loss=0.7866 train_acc=96.79 lr=0.009549
TRAIN_METRICS epoch=90 loss=0.7579 train_acc=98.82 lr=0.002447
TRAIN_METRICS epoch=100 loss=0.7494 train_acc=99.29 lr=0.000000
resnet20-cifar10: test_acc=0.8937, mia_auc=0.5853, privacy_gap=0.0312, privacy_score=0.8084

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.6652 train_ac
Assistant (Step 5): submit
Tool Input
1{
2  "n": 1
3}
Tool Result (Step 5)
[submit] Submitting result from test #1 as final.

### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
TRAIN_METRICS epoch=1 loss=1.8253 train_acc=33.71 lr=0.099975
TRAIN_METRICS epoch=10 loss=0.9225 train_acc=77.39 lr=0.097553
TRAIN_METRICS epoch=20 loss=0.9667 train_acc=81.65 lr=0.090451
TRAIN_METRICS epoch=30 loss=1.0114 train_acc=84.34 lr=0.079389
TRAIN_METRICS epoch=40 loss=1.0300 train_acc=86.94 lr=0.065451
TRAIN_METRICS epoch=50 loss=1.0440 train_acc=88.79 lr=0.050000
TRAIN_METRICS epoch=60 loss=1.0369 train_acc=91.37 lr=0.034549
TRAIN_METRICS epoch=70 loss=1.0285 train_acc=94.06 lr=0.020611
TRAIN_METRICS epoch=80 loss=1.0154 train_acc=97.15 lr=0.009549
TRAIN_METRICS epoch=90 loss=1.0131 train_acc=98.81 lr=0.002447
TRAIN_METRICS epoch=100 loss=1.0146 train_acc=99.35 lr=0.000000
resnet20-cifar10: test_acc=0.8937, mia_auc=0.5953, privacy_gap=0.0458, privacy_score=0.7984

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
TRAIN_METRICS epoch=1 loss=4.6580 train_acc=1.14 lr=0.099975