Agent Conversation: openai/gpt-5.4-pro

Model: openai/gpt-5.4-pro
Total steps: 6
Messages: 16
UserStep 0

Task: security-machine-unlearning

Machine Unlearning via Targeted Update Rules

Research Question

How can we design a stronger unlearning update rule that removes information about a forget set while retaining as much utility as possible on the retained data?

Background

Machine unlearning methods approximate the effect of retraining without the deleted data. The central tradeoff is clear: aggressive forgetting reduces utility, while conservative updates leave measurable traces of the forgotten examples.

Task

Implement a better unlearning rule in bench/unlearning/custom_unlearning.py. The fixed harness trains an initial model, defines a forget split, and then applies your update rule for a fixed number of unlearning steps using retain and forget minibatches.

Your method should lower forget-set memorization while preserving retained-task accuracy.

Editable Interface

You must implement:

class UnlearningMethod:
    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):
        ...
  • retain_batch: minibatch sampled from retained data
  • forget_batch: minibatch sampled from the forget set
  • optimizer: fixed optimizer instance
  • Return value: a dict containing at least the key loss

The architecture, initial training, forget split, and evaluation probes are fixed.

Evaluation

Benchmarks:

  • mnist-class0
  • fmnist-class0
  • cifar10-class0

Reported metrics:

  • retain_acc
  • forget_acc
  • forget_mia_auc (membership-inference AUC on the forget set; a loss-based sketch follows below)
  • unlearn_score

Primary metric: unlearn_score (higher is better).
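
For intuition, one common way such a membership probe is scored is a loss-threshold attack: if forgetting worked, per-example losses on the forget set should be indistinguishable from losses on held-out non-members, giving an AUC near 0.5. The harness's actual probe is not shown in this task; the following is a minimal sketch, assuming stand-in loaders forget_loader and test_loader.

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def loss_mia_auc(model, forget_loader, test_loader, device="cpu"):
    # Members (forget set) vs. non-members (held-out test set); after ideal
    # unlearning the two loss distributions match and the AUC is near 0.5.
    model.eval()
    scores, labels = [], []
    for loader, is_member in ((forget_loader, 1), (test_loader, 0)):
        for x, y in loader:
            losses = F.cross_entropy(model(x.to(device)), y.to(device), reduction="none")
            scores.extend((-losses).tolist())  # lower loss => more member-like
            labels.extend([is_member] * losses.numel())
    return roc_auc_score(labels, scores)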

Baselines

  • retain_finetune: continue training only on retained data
  • negative_gradient: ascend forget loss and descend retain loss (see the sketch after this list)
  • bad_teacher: distillation-style forgetting baseline
  • scrub: stronger representation-scrubbing baseline
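
A minimal sketch of what a negative_gradient baseline typically looks like under the interface above; this is illustrative only, and the benchmark's own implementation may differ.

import torch.nn.functional as F

class NegativeGradientMethod:
    """Hypothetical rendering of the negative_gradient baseline."""

    def __init__(self, forget_weight=1.0):
        self.forget_weight = forget_weight

    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):
        retain_x, retain_y = retain_batch
        forget_x, forget_y = forget_batch
        # Descend the retain loss while ascending the forget loss.
        loss = (
            F.cross_entropy(model(retain_x), retain_y)
            - self.forget_weight * F.cross_entropy(model(forget_x), forget_y)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return {"loss": loss.item()}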

pytorch-vision/bench/unlearning/custom_unlearning.py [EDITABLE — entire file only]

     1: """Editable unlearning method for MLS-Bench."""
     2: 
     3: import torch
     4: import torch.nn.functional as F
     5: 
     6: # ============================================================
     7: # EDITABLE
     8: # ============================================================
     9: class UnlearningMethod:
    10:     """Default retain-only finetuning update."""
    11: 
    12:     def __init__(self):
    13:         self.forget_weight = 0.0
    14: 
    15:     def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):
    16:         retain_x, retain_y = retain_batch
    17:         logits = model(retain_x)
    18:         loss = F.cross_entropy(logits, retain_y)
    19:         optimizer.zero_grad()
    20:         loss.backward()
    21:         optimizer.step()
    22:         return {"loss": loss.item()}
    23: # ============================================================
    24: # END EDITABLE
    25: # ============================================================

retain_finetune baseline — editable region [READ-ONLY — reference]

AssistantStep 1edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/unlearning/custom_unlearning.py",
4  "content": "\"\"\"Editable unlearning method for MLS-Bench.\"\"\"\n\nimport copy\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass UnlearningMethod:\n    \"\"\"Teacher-anchored counterfactual unlearning with gradient conflict control.\n\n    Core idea:\n    1) Snapshot the pretrained model on the first step and use it as a frozen\n       retain-teacher.\n    2) Preserve retained utility with supervised CE + teacher distillation.\n    3) Forget using a counterfactual target on forget examples: match the\n       teacher's non-true-class distribution mixed with a uniform non-true\n       distribution, while explicitly suppressing the true class and increasing\n       uncertainty.\n    4) Combine retain/forget gradients with conflict-aware projection and an\n       elementwise saliency gate so forgetting focuses on parameters that matter\n       more for forget than retain.\n    \"\"\"\n\n    def __init__(self):\n        self.teacher = None\n\n        # Retain-preservation hyperparameters.\n        self.distill_temp = 2.0\n        self.retain_kd_weight = 0.8\n\n        # Forget-objective hyperparameters.\n        self.forget_temp = 2.5\n        self.counterfactual_mix = 0.65\n        self.margin_weight = 0.75\n        self.confidence_weight = 0.5\n        self.entropy_weight = 0.05\n\n        # Gradient-combination hyperparameters.\n        self.base_forget_weight = 1.0\n        self.retain_protect = 1.0\n        self.total_unlearn_epochs = 20\n        self.eps = 1e-12\n\n    def _maybe_init_teacher(self, model):\n        if self.teacher is None:\n            self.teacher = copy.deepcopy(model)\n            self.teacher.eval()\n            for param in self.teacher.parameters():\n                param.requires_grad_(False)\n\n    def _counterfactual_target(self, teacher_logits, forget_y):\n        num_classes = teacher_logits.size(1)\n        teacher_probs = F.softmax(teacher_logits / self.forget_temp, dim=1)\n        true_mask = F.one_hot(forget_y, num_classes=num_classes).to(dtype=teacher_probs.dtype)\n\n        # Remove the forgotten class from the teacher distribution and renormalize.\n        teacher_nontrue = teacher_probs * (1.0 - true_mask)\n        teacher_nontrue = teacher_nontrue / teacher_nontrue.sum(dim=1, keepdim=True).clamp_min(1e-8)\n\n        # Mix with a uniform non-true target to encourage higher entropy and\n        # reduce membership traces.\n        uniform_nontrue = (1.0 - true_mask) / max(num_classes - 1, 1)\n        return self.counterfactual_mix * teacher_nontrue + (1.0 - self.counterfactual_mix) * uniform_nontrue\n\n    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):\n        self._maybe_init_teacher(model)\n\n        retain_x, retain_y = retain_batch\n        forget_x, forget_y = forget_batch\n        params = [param for param in model.parameters() if param.requires_grad]\n\n        retain_logits = model(retain_x)\n        forget_logits = model(forget_x)\n\n        with torch.no_grad():\n            teacher_retain_logits = self.teacher(retain_x)\n            teacher_forget_logits = self.teacher(forget_x)\n\n        # -----------------------------\n        # Retain objective: fit labels + preserve original function on retained data.\n        # -----------------------------\n        retain_ce = F.cross_entropy(retain_logits, retain_y)\n        retain_kd = F.kl_div(\n            
F.log_softmax(retain_logits / self.distill_temp, dim=1),\n            F.softmax(teacher_retain_logits / self.distill_temp, dim=1),\n            reduction=\"batchmean\",\n        ) * (self.distill_temp ** 2)\n        retain_loss = retain_ce + self.retain_kd_weight * retain_kd\n\n        # -----------------------------\n        # Forget objective: counterfactual relabeling + explicit true-class suppression.\n        # -----------------------------\n        cf_target = self._counterfactual_target(teacher_forget_logits, forget_y)\n        forget_log_probs = F.log_softmax(forget_logits / self.forget_temp, dim=1)\n        forget_soft_ce = -(cf_target * forget_log_probs).sum(dim=1).mean()\n\n        forget_probs = F.softmax(forget_logits, dim=1)\n        true_prob = forget_probs.gather(1, forget_y.unsqueeze(1)).squeeze(1)\n        entropy = -(forget_probs * torch.log(forget_probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n        num_classes = forget_logits.size(1)\n        true_mask_bool = F.one_hot(forget_y, num_classes=num_classes).bool()\n        true_logit = forget_logits.gather(1, forget_y.unsqueeze(1)).squeeze(1)\n        max_other_logit = forget_logits.masked_fill(true_mask_bool, -1e9).max(dim=1).values\n        margin_loss = F.softplus(true_logit - max_other_logit).mean()\n\n        forget_loss = (\n            forget_soft_ce\n            + self.margin_weight * margin_loss\n            + self.confidence_weight * true_prob.mean()\n            - self.entropy_weight * entropy\n        )\n\n        # -----------------------------\n        # Conflict-aware gradient composition.\n        # -----------------------------\n        retain_grads = torch.autograd.grad(retain_loss, params, retain_graph=True, allow_unused=True)\n        forget_grads = torch.autograd.grad(forget_loss, params, allow_unused=True)\n\n        dot = 0.0\n        retain_norm_sq = 0.0\n        forget_norm_sq = 0.0\n        for grad_r, grad_f in zip(retain_grads, forget_grads):\n            if grad_r is not None:\n                retain_norm_sq = retain_norm_sq + (grad_r * grad_r).sum()\n            if grad_f is not None:\n                forget_norm_sq = forget_norm_sq + (grad_f * grad_f).sum()\n            if grad_r is not None and grad_f is not None:\n                dot = dot + (grad_r * grad_f).sum()\n\n        retain_norm_sq = torch.as_tensor(retain_norm_sq, device=retain_logits.device).clamp_min(self.eps)\n        forget_norm_sq = torch.as_tensor(forget_norm_sq, device=retain_logits.device).clamp_min(self.eps)\n        dot = torch.as_tensor(dot, device=retain_logits.device)\n\n        # If retain and forget gradients conflict, drop the forget component that\n        # directly opposes retain performance.\n        if dot.detach().item() < 0.0:\n            proj_coeff = (dot / retain_norm_sq).detach()\n            adjusted_forget_grads = []\n            for grad_r, grad_f in zip(retain_grads, forget_grads):\n                if grad_f is None:\n                    adjusted_forget_grads.append(None)\n                elif grad_r is None:\n                    adjusted_forget_grads.append(grad_f)\n                else:\n                    adjusted_forget_grads.append(grad_f - proj_coeff * grad_r)\n            forget_grads = adjusted_forget_grads\n\n        # Norm balancing plus a mild epoch anneal: forget strongly early, then\n        # consolidate retained utility later in unlearning.\n        progress = min(max((float(epoch) + 1.0) / float(self.total_unlearn_epochs), 0.0), 1.0)\n        epoch_scale = 1.0 - 0.6 * progress\n      
  trust_ratio = torch.sqrt(retain_norm_sq / forget_norm_sq).clamp(0.25, 1.5).detach()\n        confidence_scale = (0.5 + true_prob.mean().detach()).clamp(0.5, 1.5)\n        forget_scale = float(self.base_forget_weight * epoch_scale * trust_ratio * confidence_scale)\n\n        optimizer.zero_grad()\n        for param, grad_r, grad_f in zip(params, retain_grads, forget_grads):\n            total_grad = None\n            if grad_r is not None:\n                total_grad = grad_r\n\n            if grad_f is not None:\n                effective_forget_grad = grad_f\n                if grad_r is not None:\n                    saliency_gate = effective_forget_grad.abs() / (\n                        effective_forget_grad.abs() + self.retain_protect * grad_r.abs() + self.eps\n                    )\n                    effective_forget_grad = saliency_gate * effective_forget_grad\n                total_grad = (\n                    forget_scale * effective_forget_grad\n                    if total_grad is None\n                    else total_grad + forget_scale * effective_forget_grad\n                )\n\n            if total_grad is not None:\n                param.grad = total_grad.detach()\n\n        optimizer.step()\n\n        return {\n            \"loss\": (retain_loss.detach() + forget_scale * forget_loss.detach()).item(),\n            \"retain_loss\": retain_loss.detach().item(),\n            \"forget_loss\": forget_loss.detach().item(),\n            \"forget_scale\": forget_scale,\n            \"retain_ce\": retain_ce.detach().item(),\n            \"retain_kd\": retain_kd.detach().item(),\n            \"forget_true_prob\": true_prob.mean().detach().item(),\n            \"forget_entropy\": entropy.detach().item(),\n        }\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5  "after_line": 0,
6  "start_line": 1,
7  "end_line": 25
8}
Tool ResultStep 1
OK: Replaced lines 1..25 with 196 line(s) in pytorch-vision/bench/unlearning/custom_unlearning.py. Editable range: entire file.

[Current file: pytorch-vision/bench/unlearning/custom_unlearning.py | editable: 1–196 | total: 196 lines]
     1  """Editable unlearning method for MLS-Bench."""
     2  
     3  import copy
     4  
     5  import torch
     6  import torch.nn.functional as F
     7  
     8  # ============================================================
     9  # EDITABLE
    10  # ============================================================
    11  class UnlearningMethod:
    12      """Teacher-anchored counterfactual unlearning with gradient conflict control.
    13  
    14      Core idea:
    15      1) Snapshot the pretrained model on the first step and use it as a frozen
    16         retain-teacher.
    17      2) Preserve retained utility with supervised CE + teacher distillation.
    18      3) Forget using a counterfactual target on forget examples: match the
   
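For readability, a plain rendering of the counterfactual-target construction carried in the escaped payload above (same logic as the payload's _counterfactual_target, shown unescaped; defaults follow its forget_temp=2.5 and counterfactual_mix=0.65).

import torch
import torch.nn.functional as F

def counterfactual_target(teacher_logits, forget_y, temperature=2.5, mix=0.65):
    num_classes = teacher_logits.size(1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    true_mask = F.one_hot(forget_y, num_classes=num_classes).to(teacher_probs.dtype)
    # Zero out the true class and renormalize over the remaining classes.
    nontrue = teacher_probs * (1.0 - true_mask)
    nontrue = nontrue / nontrue.sum(dim=1, keepdim=True).clamp_min(1e-8)
    # Mix with a uniform non-true distribution to raise entropy.
    uniform = (1.0 - true_mask) / max(num_classes - 1, 1)
    return mix * nontrue + (1.0 - mix) * uniform
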
AssistantStep 1edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/unlearning/custom_unlearning.py",
4  "content": "\"\"\"Editable unlearning method for MLS-Bench.\"\"\"\n\nimport copy\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass UnlearningMethod:\n    \"\"\"Teacher-anchored counterfactual unlearning with projected gradients.\n\n    The update couples four ideas:\n      1) Freeze the pretrained model as a teacher on the first step.\n      2) Preserve utility with retain-set CE + distillation to the teacher.\n      3) Forget by matching a counterfactual teacher target on forget samples:\n         remove the true class from the teacher distribution, renormalize over\n         the remaining classes, and mix with a uniform non-true prior.\n      4) Estimate retain-importance online from retain gradients, damp forget\n         gradients on important parameters, then project retain gradients so\n         they do not cancel the forgetting direction.\n    \"\"\"\n\n    def __init__(self):\n        self.teacher = None\n        self.importance = None\n\n        self.temperature = 2.5\n        self.retain_kd_weight = 0.75\n        self.uniform_mix = 0.35\n        self.margin_weight = 0.80\n        self.entropy_weight = 0.03\n        self.importance_momentum = 0.90\n        self.importance_scale = 1.50\n        self.max_forget_scale = 1.40\n        self.min_forget_scale = 0.70\n        self.total_unlearn_epochs = 20.0\n        self.eps = 1e-12\n\n    def _maybe_init_teacher(self, model):\n        if self.teacher is not None:\n            return\n        device = next(model.parameters()).device\n        self.teacher = copy.deepcopy(model).to(device)\n        self.teacher.eval()\n        for param in self.teacher.parameters():\n            param.requires_grad_(False)\n\n    @staticmethod\n    def _clone_grads(model):\n        return [None if p.grad is None else p.grad.detach().clone() for p in model.parameters()]\n\n    def _update_importance(self, model, retain_grads):\n        if self.importance is None:\n            self.importance = [torch.zeros_like(p) for p in model.parameters()]\n        with torch.no_grad():\n            for imp, grad in zip(self.importance, retain_grads):\n                if grad is None:\n                    continue\n                imp.mul_(self.importance_momentum).addcmul_(grad, grad, value=1.0 - self.importance_momentum)\n\n    def _counterfactual_targets(self, teacher_logits, labels):\n        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)\n        non_true = teacher_probs.clone()\n        non_true.scatter_(1, labels.view(-1, 1), 0.0)\n        non_true_mass = non_true.sum(dim=1, keepdim=True)\n\n        uniform = torch.ones_like(non_true)\n        uniform.scatter_(1, labels.view(-1, 1), 0.0)\n        uniform = uniform / uniform.sum(dim=1, keepdim=True).clamp_min(1.0)\n\n        counterfactual = non_true / non_true_mass.clamp_min(self.eps)\n        counterfactual = torch.where(non_true_mass > self.eps, counterfactual, uniform)\n        return (1.0 - self.uniform_mix) * counterfactual + self.uniform_mix * uniform\n\n    def _damp_forget_grads(self, forget_grads):\n        if self.importance is None:\n            return forget_grads\n\n        damped = []\n        with torch.no_grad():\n            for grad, imp in zip(forget_grads, self.importance):\n                if grad is None:\n                    damped.append(None)\n                    continue\n                mean_imp = 
imp.mean().clamp_min(self.eps)\n                rel_imp = imp / mean_imp\n                gate = 1.0 + self.importance_scale * rel_imp\n                damped.append(grad / gate)\n        return damped\n\n    def _assign_projected_grads(self, model, retain_grads, forget_grads, forget_scale):\n        device = next(model.parameters()).device\n        dot = torch.zeros((), device=device)\n        retain_norm_sq = torch.zeros((), device=device)\n        forget_norm_sq = torch.zeros((), device=device)\n\n        for r_grad, f_grad in zip(retain_grads, forget_grads):\n            if r_grad is not None:\n                retain_norm_sq = retain_norm_sq + (r_grad * r_grad).sum()\n            if f_grad is not None:\n                forget_norm_sq = forget_norm_sq + (f_grad * f_grad).sum()\n            if r_grad is not None and f_grad is not None:\n                dot = dot + (r_grad * f_grad).sum()\n\n        conflict_coeff = torch.zeros((), device=device)\n        if dot.item() < 0.0 and forget_norm_sq.item() > self.eps:\n            conflict_coeff = dot / forget_norm_sq.clamp_min(self.eps)\n\n        norm_balance = torch.sqrt(retain_norm_sq / forget_norm_sq.clamp_min(self.eps))\n        scaled_forget = forget_scale * norm_balance.clamp(0.60, 1.60)\n\n        for param, r_grad, f_grad in zip(model.parameters(), retain_grads, forget_grads):\n            if r_grad is None and f_grad is None:\n                param.grad = None\n                continue\n\n            if r_grad is None:\n                final_grad = torch.zeros_like(f_grad)\n            else:\n                final_grad = r_grad\n                if conflict_coeff.item() < 0.0 and f_grad is not None:\n                    final_grad = final_grad - conflict_coeff * f_grad\n\n            if f_grad is not None:\n                final_grad = final_grad + scaled_forget * f_grad\n\n            param.grad = final_grad\n\n        return {\n            \"grad_dot\": dot.item(),\n            \"retain_grad_norm\": retain_norm_sq.sqrt().item(),\n            \"forget_grad_norm\": forget_norm_sq.sqrt().item(),\n            \"forget_scale\": scaled_forget.item(),\n        }\n\n    def _retain_loss(self, model, retain_x, retain_y):\n        logits = model(retain_x)\n        retain_ce = F.cross_entropy(logits, retain_y)\n        with torch.no_grad():\n            teacher_logits = self.teacher(retain_x)\n        retain_kd = F.kl_div(\n            F.log_softmax(logits / self.temperature, dim=1),\n            F.softmax(teacher_logits / self.temperature, dim=1),\n            reduction=\"batchmean\",\n        ) * (self.temperature ** 2)\n        return retain_ce + self.retain_kd_weight * retain_kd, retain_ce, retain_kd\n\n    def _forget_loss(self, model, forget_x, forget_y):\n        logits = model(forget_x)\n        with torch.no_grad():\n            teacher_logits = self.teacher(forget_x)\n            targets = self._counterfactual_targets(teacher_logits, forget_y)\n\n        forget_cf = F.kl_div(\n            F.log_softmax(logits / self.temperature, dim=1),\n            targets,\n            reduction=\"batchmean\",\n        ) * (self.temperature ** 2)\n\n        true_logits = logits.gather(1, forget_y.view(-1, 1)).squeeze(1)\n        non_true_logits = logits.clone()\n        non_true_logits.scatter_(1, forget_y.view(-1, 1), float(\"-inf\"))\n        best_other = non_true_logits.max(dim=1).values\n        forget_margin = F.softplus(true_logits - best_other).mean()\n\n        probs = F.softmax(logits, dim=1)\n        true_prob = probs.gather(1, 
forget_y.view(-1, 1)).mean()\n        forget_entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n        forget_loss = forget_cf + self.margin_weight * forget_margin - self.entropy_weight * forget_entropy\n        return forget_loss, forget_cf, forget_margin, forget_entropy, true_prob\n\n    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):\n        self._maybe_init_teacher(model)\n\n        retain_x, retain_y = retain_batch\n        forget_x, forget_y = forget_batch\n\n        retain_loss, retain_ce, retain_kd = self._retain_loss(model, retain_x, retain_y)\n        optimizer.zero_grad()\n        retain_loss.backward()\n        retain_grads = self._clone_grads(model)\n        self._update_importance(model, retain_grads)\n\n        forget_loss, forget_cf, forget_margin, forget_entropy, true_prob = self._forget_loss(model, forget_x, forget_y)\n        optimizer.zero_grad()\n        forget_loss.backward()\n        forget_grads = self._clone_grads(model)\n        forget_grads = self._damp_forget_grads(forget_grads)\n\n        progress = min(max(float(epoch) / max(self.total_unlearn_epochs - 1.0, 1.0), 0.0), 1.0)\n        schedule = self.min_forget_scale + (self.max_forget_scale - self.min_forget_scale) * (0.5 * (1.0 + math.cos(math.pi * progress)))\n        confidence_boost = (0.75 + 1.50 * true_prob.detach()).clamp(0.75, 2.00).item()\n        grad_stats = self._assign_projected_grads(model, retain_grads, forget_grads, schedule * confidence_boost)\n\n        optimizer.step()\n\n        return {\n            \"loss\": (retain_loss.detach() + forget_loss.detach()).item(),\n            \"retain_loss\": retain_loss.item(),\n            \"retain_ce\": retain_ce.item(),\n            \"retain_kd\": retain_kd.item(),\n            \"forget_loss\": forget_loss.item(),\n            \"forget_cf\": forget_cf.item(),\n            \"forget_margin\": forget_margin.item(),\n            \"forget_entropy\": forget_entropy.item(),\n            \"forget_true_prob\": true_prob.item(),\n            **grad_stats,\n        }\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5  "after_line": 0,
6  "start_line": 1,
7  "end_line": 25
8}
Tool ResultStep 1
OK: Replaced lines 1..25 with 215 line(s) in pytorch-vision/bench/unlearning/custom_unlearning.py. Editable range: entire file.

[Current file: pytorch-vision/bench/unlearning/custom_unlearning.py | editable: 1–215 | total: 215 lines]
     1  """Editable unlearning method for MLS-Bench."""
     2  
     3  import copy
     4  import math
     5  
     6  import torch
     7  import torch.nn.functional as F
     8  
     9  # ============================================================
    10  # EDITABLE
    11  # ============================================================
    12  class UnlearningMethod:
    13      """Teacher-anchored counterfactual unlearning with projected gradients.
    14  
    15      The update couples four ideas:
    16        1) Freeze the pretrained model as a teacher on the first step.
    17        2) Preserve utility with retain-set CE + distillation to the teacher.
    18        3) Forget by matching a counterfactual teacher target on forget samples:
    
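For readability, a plain rendering of the conflict-aware projection this revision describes: when the retain and forget gradients have a negative dot product, the retain gradient is projected off the forget direction, PCGrad-style. The sketch assumes dense (non-None) per-parameter gradient lists.

def project_conflicting(retain_grads, forget_grads, eps=1e-12):
    # Dot product and squared norm accumulated across all parameter tensors.
    dot = sum((r * f).sum() for r, f in zip(retain_grads, forget_grads))
    forget_norm_sq = sum((f * f).sum() for f in forget_grads)
    if dot < 0:
        # Remove the component of the retain gradient that opposes forgetting:
        # r' = r - (r.f / ||f||^2) f  (the coefficient is negative here).
        coeff = dot / forget_norm_sq.clamp_min(eps)
        retain_grads = [r - coeff * f for r, f in zip(retain_grads, forget_grads)]
    return retain_grads
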
AssistantStep 2test
Tool ResultStep 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### resnet20-cifar10-class0 (scripts/resnet20_cifar10_class0.sh)
TRAIN_METRICS phase=pretrain epoch=1 loss=1.5574 acc=42.08
TRAIN_METRICS phase=pretrain epoch=10 loss=0.5655 acc=80.50
TRAIN_METRICS phase=pretrain epoch=20 loss=0.4727 acc=83.88
TRAIN_METRICS phase=pretrain epoch=30 loss=0.4142 acc=85.77
TRAIN_METRICS phase=pretrain epoch=40 loss=0.3473 acc=88.18
TRAIN_METRICS phase=pretrain epoch=50 loss=0.2721 acc=90.71
TRAIN_METRICS phase=pretrain epoch=60 loss=0.1868 acc=93.53
TRAIN_METRICS phase=pretrain epoch=70 loss=0.0872 acc=97.12
TRAIN_METRICS phase=pretrain epoch=80 loss=0.0496 acc=98.61
TRAIN_METRICS phase=unlearn epoch=1 loss=1.7090
TRAIN_METRICS phase=unlearn epoch=5 loss=0.7527
TRAIN_METRICS phase=unlearn epoch=10 loss=0.7573
TRAIN_METRICS phase=unlearn epoch=15 loss=0.7582
TRAIN_METRICS phase=unlearn epoch=20 loss=0.6637
resnet20-cifar10-class0: retain_acc=0.9008, forget_acc=0.0000, forget_mia_auc=0.42
AssistantStep 3edit
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/unlearning/custom_unlearning.py",
4  "content": "\"\"\"Editable unlearning method for MLS-Bench.\"\"\"\n\nimport copy\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass UnlearningMethod:\n    \"\"\"Teacher-anchored counterfactual unlearning with phased gradient surgery.\n\n    Key components:\n      1) Freeze the pretrained model as a teacher on the first step.\n      2) Preserve utility with retain-set CE + retain distillation.\n      3) Forget via counterfactual teacher targets that remove the true class.\n      4) Track online retain-importance to damp forget updates on parameters that\n         matter for the retained task.\n      5) Use phase-scheduled gradient projection: early steps prioritize\n         forgetting, late steps prioritize retain recovery.\n    \"\"\"\n\n    def __init__(self):\n        self.teacher = None\n        self.importance = None\n\n        self.temperature = 2.5\n        self.retain_kd_weight = 1.00\n        self.uniform_mix = 0.35\n        self.margin_weight = 0.55\n        self.entropy_weight = 0.02\n        self.importance_momentum = 0.90\n        self.importance_scale = 2.50\n        self.max_forget_scale = 1.10\n        self.min_forget_scale = 0.25\n        self.repair_start = 0.35\n        self.total_unlearn_epochs = 20.0\n        self.eps = 1e-12\n\n    def _maybe_init_teacher(self, model):\n        if self.teacher is not None:\n            return\n        device = next(model.parameters()).device\n        self.teacher = copy.deepcopy(model).to(device)\n        self.teacher.eval()\n        for param in self.teacher.parameters():\n            param.requires_grad_(False)\n\n    @staticmethod\n    def _clone_grads(model):\n        return [None if p.grad is None else p.grad.detach().clone() for p in model.parameters()]\n\n    def _update_importance(self, model, retain_grads):\n        if self.importance is None:\n            self.importance = [torch.zeros_like(p) for p in model.parameters()]\n        with torch.no_grad():\n            for imp, grad in zip(self.importance, retain_grads):\n                if grad is None:\n                    continue\n                imp.mul_(self.importance_momentum).addcmul_(grad, grad, value=1.0 - self.importance_momentum)\n\n    def _counterfactual_targets(self, teacher_logits, labels):\n        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)\n        non_true = teacher_probs.clone()\n        non_true.scatter_(1, labels.view(-1, 1), 0.0)\n        non_true_mass = non_true.sum(dim=1, keepdim=True)\n\n        uniform = torch.ones_like(non_true)\n        uniform.scatter_(1, labels.view(-1, 1), 0.0)\n        uniform = uniform / uniform.sum(dim=1, keepdim=True).clamp_min(1.0)\n\n        counterfactual = non_true / non_true_mass.clamp_min(self.eps)\n        counterfactual = torch.where(non_true_mass > self.eps, counterfactual, uniform)\n        return (1.0 - self.uniform_mix) * counterfactual + self.uniform_mix * uniform\n\n    def _damp_forget_grads(self, forget_grads):\n        if self.importance is None:\n            return forget_grads\n\n        damped = []\n        with torch.no_grad():\n            for grad, imp in zip(forget_grads, self.importance):\n                if grad is None:\n                    damped.append(None)\n                    continue\n                mean_imp = imp.mean().clamp_min(self.eps)\n                rel_imp = imp / mean_imp\n                gate = 
1.0 + self.importance_scale * rel_imp\n                damped.append(grad / gate)\n        return damped\n\n    def _assign_projected_grads(self, model, retain_grads, forget_grads, forget_scale, progress):\n        device = next(model.parameters()).device\n        dot = torch.zeros((), device=device)\n        retain_norm_sq = torch.zeros((), device=device)\n        forget_norm_sq = torch.zeros((), device=device)\n\n        for r_grad, f_grad in zip(retain_grads, forget_grads):\n            if r_grad is not None:\n                retain_norm_sq = retain_norm_sq + (r_grad * r_grad).sum()\n            if f_grad is not None:\n                forget_norm_sq = forget_norm_sq + (f_grad * f_grad).sum()\n            if r_grad is not None and f_grad is not None:\n                dot = dot + (r_grad * f_grad).sum()\n\n        retain_coeff = torch.zeros((), device=device)\n        forget_coeff = torch.zeros((), device=device)\n        if dot.item() < 0.0:\n            if progress < self.repair_start and forget_norm_sq.item() > self.eps:\n                retain_coeff = dot / forget_norm_sq.clamp_min(self.eps)\n            elif retain_norm_sq.item() > self.eps:\n                forget_coeff = dot / retain_norm_sq.clamp_min(self.eps)\n\n        norm_balance = torch.sqrt(retain_norm_sq / forget_norm_sq.clamp_min(self.eps))\n        scaled_forget = forget_scale * norm_balance.clamp(0.50, 1.35)\n\n        for param, r_grad, f_grad in zip(model.parameters(), retain_grads, forget_grads):\n            if r_grad is None and f_grad is None:\n                param.grad = None\n                continue\n\n            if r_grad is None:\n                adj_r = None\n            else:\n                adj_r = r_grad\n                if retain_coeff.item() < 0.0 and f_grad is not None:\n                    adj_r = adj_r - retain_coeff * f_grad\n\n            if f_grad is None:\n                adj_f = None\n            else:\n                adj_f = f_grad\n                if forget_coeff.item() < 0.0 and r_grad is not None:\n                    adj_f = adj_f - forget_coeff * r_grad\n\n            if adj_r is None:\n                final_grad = torch.zeros_like(adj_f)\n            else:\n                final_grad = adj_r\n\n            if adj_f is not None:\n                final_grad = final_grad + scaled_forget * adj_f\n\n            param.grad = final_grad\n\n        return {\n            \"grad_dot\": dot.item(),\n            \"retain_grad_norm\": retain_norm_sq.sqrt().item(),\n            \"forget_grad_norm\": forget_norm_sq.sqrt().item(),\n            \"forget_scale\": scaled_forget.item(),\n        }\n\n    def _retain_loss(self, model, retain_x, retain_y):\n        logits = model(retain_x)\n        retain_ce = F.cross_entropy(logits, retain_y)\n        with torch.no_grad():\n            teacher_logits = self.teacher(retain_x)\n        retain_kd = F.kl_div(\n            F.log_softmax(logits / self.temperature, dim=1),\n            F.softmax(teacher_logits / self.temperature, dim=1),\n            reduction=\"batchmean\",\n        ) * (self.temperature ** 2)\n        return retain_ce + self.retain_kd_weight * retain_kd, retain_ce, retain_kd\n\n    def _forget_loss(self, model, forget_x, forget_y):\n        logits = model(forget_x)\n        with torch.no_grad():\n            teacher_logits = self.teacher(forget_x)\n            targets = self._counterfactual_targets(teacher_logits, forget_y)\n\n        forget_cf = F.kl_div(\n            F.log_softmax(logits / self.temperature, dim=1),\n            targets,\n     
       reduction=\"batchmean\",\n        ) * (self.temperature ** 2)\n\n        true_logits = logits.gather(1, forget_y.view(-1, 1)).squeeze(1)\n        non_true_logits = logits.clone()\n        non_true_logits.scatter_(1, forget_y.view(-1, 1), float(\"-inf\"))\n        best_other = non_true_logits.max(dim=1).values\n        forget_margin = F.softplus(true_logits - best_other).mean()\n\n        probs = F.softmax(logits, dim=1)\n        true_prob = probs.gather(1, forget_y.view(-1, 1)).mean()\n        forget_entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n\n        forget_loss = forget_cf + self.margin_weight * forget_margin - self.entropy_weight * forget_entropy\n        return forget_loss, forget_cf, forget_margin, forget_entropy, true_prob\n\n    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):\n        self._maybe_init_teacher(model)\n\n        retain_x, retain_y = retain_batch\n        forget_x, forget_y = forget_batch\n\n        retain_loss, retain_ce, retain_kd = self._retain_loss(model, retain_x, retain_y)\n        optimizer.zero_grad()\n        retain_loss.backward()\n        retain_grads = self._clone_grads(model)\n        self._update_importance(model, retain_grads)\n\n        forget_loss, forget_cf, forget_margin, forget_entropy, true_prob = self._forget_loss(model, forget_x, forget_y)\n        optimizer.zero_grad()\n        forget_loss.backward()\n        forget_grads = self._clone_grads(model)\n        forget_grads = self._damp_forget_grads(forget_grads)\n\n        progress = min(max(float(epoch) / max(self.total_unlearn_epochs - 1.0, 1.0), 0.0), 1.0)\n        schedule = self.min_forget_scale + (self.max_forget_scale - self.min_forget_scale) * (0.5 * (1.0 + math.cos(math.pi * progress)))\n        confidence_boost = (0.65 + 1.10 * true_prob.detach()).clamp(0.60, 1.75).item()\n        grad_stats = self._assign_projected_grads(model, retain_grads, forget_grads, schedule * confidence_boost, progress)\n\n        optimizer.step()\n\n        return {\n            \"loss\": (retain_loss.detach() + forget_loss.detach()).item(),\n            \"retain_loss\": retain_loss.item(),\n            \"retain_ce\": retain_ce.item(),\n            \"retain_kd\": retain_kd.item(),\n            \"forget_loss\": forget_loss.item(),\n            \"forget_cf\": forget_cf.item(),\n            \"forget_margin\": forget_margin.item(),\n            \"forget_entropy\": forget_entropy.item(),\n            \"forget_true_prob\": true_prob.item(),\n            **grad_stats,\n        }\n# ============================================================\n# END EDITABLE\n# ============================================================\n",
5  "after_line": 0,
6  "start_line": 1,
7  "end_line": 215
8}
Tool ResultStep 3
OK: Replaced lines 1..215 with 231 line(s) in pytorch-vision/bench/unlearning/custom_unlearning.py. Editable range: entire file.

[Current file: pytorch-vision/bench/unlearning/custom_unlearning.py | editable: 1–231 | total: 231 lines]
     1  """Editable unlearning method for MLS-Bench."""
     2  
     3  import copy
     4  import math
     5  
     6  import torch
     7  import torch.nn.functional as F
     8  
     9  # ============================================================
    10  # EDITABLE
    11  # ============================================================
    12  class UnlearningMethod:
    13      """Teacher-anchored counterfactual unlearning with phased gradient surgery.
    14  
    15      Key components:
    16        1) Freeze the pretrained model as a teacher on the first step.
    17        2) Preserve utility with retain-set CE + retain distillation.
    18        3) Forget via counterfactual teacher targets that remove the true class.
    19        4) Track
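For readability, a plain rendering of the phase schedule carried in this payload: the forget-term weight follows a cosine decay from max_forget_scale=1.10 to min_forget_scale=0.25, and once progress passes repair_start=0.35 the projection flips from shielding the forget direction to shielding retention.

import math

def forget_scale(epoch, total_epochs=20.0, lo=0.25, hi=1.10):
    # Cosine decay from hi (epoch 0) down to lo (final epoch).
    progress = min(max(float(epoch) / max(total_epochs - 1.0, 1.0), 0.0), 1.0)
    return lo + (hi - lo) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Phase switch: while progress < repair_start the retain gradient is projected
# off the forget direction (prioritize forgetting); afterwards the forget
# gradient is projected off the retain direction (prioritize retain recovery).
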
AssistantStep 4test
Tool ResultStep 4
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### resnet20-cifar10-class0 (scripts/resnet20_cifar10_class0.sh)
TRAIN_METRICS phase=pretrain epoch=1 loss=1.5574 acc=42.08
TRAIN_METRICS phase=pretrain epoch=10 loss=0.5655 acc=80.50
TRAIN_METRICS phase=pretrain epoch=20 loss=0.4727 acc=83.88
TRAIN_METRICS phase=pretrain epoch=30 loss=0.4142 acc=85.77
TRAIN_METRICS phase=pretrain epoch=40 loss=0.3473 acc=88.18
TRAIN_METRICS phase=pretrain epoch=50 loss=0.2721 acc=90.71
TRAIN_METRICS phase=pretrain epoch=60 loss=0.1868 acc=93.53
TRAIN_METRICS phase=pretrain epoch=70 loss=0.0872 acc=97.12
TRAIN_METRICS phase=pretrain epoch=80 loss=0.0496 acc=98.61
TRAIN_METRICS phase=unlearn epoch=1 loss=2.0309
TRAIN_METRICS phase=unlearn epoch=5 loss=1.0178
TRAIN_METRICS phase=unlearn epoch=10 loss=0.9762
TRAIN_METRICS phase=unlearn epoch=15 loss=0.9883
TRAIN_METRICS phase=unlearn epoch=20 loss=0.9328
resnet20-cifar10-class0: retain_acc=0.8823, forget_acc=0.0010, forget_mia_auc=0.424
AssistantStep 5undo
Tool Input
1{
2  "n": 1
3}
Tool ResultStep 5
Undo complete:
Restored: pytorch-vision/bench/unlearning/custom_unlearning.py
AssistantStep 6submit
Tool Input
1{
2  "n": 1
3}
Tool ResultStep 6
[submit] Submitting result from test #1 as final.

### resnet20-cifar10-class0 (scripts/resnet20_cifar10_class0.sh)
TRAIN_METRICS phase=pretrain epoch=1 loss=1.5574 acc=42.08
TRAIN_METRICS phase=pretrain epoch=10 loss=0.5655 acc=80.50
TRAIN_METRICS phase=pretrain epoch=20 loss=0.4727 acc=83.88
TRAIN_METRICS phase=pretrain epoch=30 loss=0.4142 acc=85.77
TRAIN_METRICS phase=pretrain epoch=40 loss=0.3473 acc=88.18
TRAIN_METRICS phase=pretrain epoch=50 loss=0.2721 acc=90.71
TRAIN_METRICS phase=pretrain epoch=60 loss=0.1868 acc=93.53
TRAIN_METRICS phase=pretrain epoch=70 loss=0.0872 acc=97.12
TRAIN_METRICS phase=pretrain epoch=80 loss=0.0496 acc=98.61
TRAIN_METRICS phase=unlearn epoch=1 loss=1.7090
TRAIN_METRICS phase=unlearn epoch=5 loss=0.7527
TRAIN_METRICS phase=unlearn epoch=10 loss=0.7573
TRAIN_METRICS phase=unlearn epoch=15 loss=0.7582
TRAIN_METRICS phase=unlearn epoch=20 loss=0.6637
resnet20-cifar10-class0: retain_acc=0.9008, forget_acc=0.0000, forget_mia_auc=0.4289, unlearn_score=0.