security-machine-unlearning

Tags: Adversarial ML · pytorch-vision · rigorous codebase

Description

Machine Unlearning via Targeted Update Rules

Research Question

How can we design a stronger unlearning update rule that removes information about a forget set while retaining as much utility as possible on the retained data?

Background

Machine unlearning methods approximate the effect of retraining without the deleted data. The central tradeoff is clear: aggressive forgetting reduces utility, while conservative updates leave measurable traces of the forgotten examples.

The harness pretrains a standard vision model (ResNet-20, VGG-16-BN, or MobileNetV2) on the full training set for 80 epochs using SGD with cosine annealing. After pretraining, a single class is designated as the forget set. Your unlearning method then runs for 20 epochs, receiving both retain-set and forget-set minibatches each step, with an Adam optimizer (lr=0.001).
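The unlearning phase described above can be pictured as a loop like the following. This is a minimal sketch with illustrative names (`run_unlearning`, the loader arguments); the real harness in run_unlearning.py is fixed and may differ in detail, e.g. in how it cycles the smaller forget loader:

```python
import itertools

import torch


def run_unlearning(model, method, retain_loader, forget_loader, optimizer,
                   epochs=20, device="cpu"):
    """Sketch of the unlearning phase: every step receives one retain batch
    and one forget batch; the forget loader is cycled because the forget
    class is much smaller than the retained data."""
    model.train()
    step = 0
    stats = {}
    for epoch in range(epochs):
        for retain_batch, forget_batch in zip(retain_loader,
                                              itertools.cycle(forget_loader)):
            retain_batch = tuple(t.to(device) for t in retain_batch)
            forget_batch = tuple(t.to(device) for t in forget_batch)
            stats = method.unlearn_step(model, retain_batch, forget_batch,
                                        optimizer, step, epoch)
            step += 1
    return stats
```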

Task

Implement a better unlearning rule in bench/unlearning/custom_unlearning.py. The fixed harness trains an initial model, defines a forget split, and then applies your update rule for a fixed number of unlearning steps using retain and forget minibatches.

Your method should lower forget-set memorization while preserving retained-task accuracy.

Editable Interface

You must implement:

class UnlearningMethod:
    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):
        ...
  • retain_batch: (images, labels) tuple from retained data (already on device)
  • forget_batch: (images, labels) tuple from the forget set (already on device)
  • optimizer: fixed Adam optimizer instance (lr=0.001)
  • Return value: a dict containing at least a loss entry

The architecture, initial training, forget split, and evaluation probes are fixed.

Evaluation

Benchmarks:

  • resnet20-cifar10-class0: ResNet-20 on CIFAR-10, forgetting class 0
  • vgg16bn-cifar100-class0: VGG-16-BN on CIFAR-100, forgetting class 0
  • mobilenetv2-fmnist-class0: MobileNetV2 on FashionMNIST, forgetting class 0

Reported metrics:

  • retain_acc: accuracy on non-forget test data
  • forget_acc: accuracy on forget-class test data (lower is better)
  • forget_mia_auc: membership inference attack AUC on forget set (lower is better)
  • unlearn_score: (retain_acc + (1 - forget_acc) + (1 - forget_mia_auc)) / 3
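As a simplified picture of the MIA probe, consider a loss-threshold attack: each example's loss is used as a membership score (lower loss suggests the model memorized the example), and AUC is computed against member/non-member labels. The harness's actual attack is fixed and may be more sophisticated; this sketch only illustrates why an AUC near 0.5 indicates successful forgetting:

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def loss_threshold_mia_auc(member_losses, nonmember_losses):
    """Loss-threshold membership inference: lower loss => more likely a
    member, so the attack score is the negated loss. AUC 0.5 means the
    attack cannot distinguish forget-set members from non-members."""
    scores = -np.concatenate([member_losses, nonmember_losses])
    labels = np.concatenate([np.ones(len(member_losses)),
                             np.zeros(len(nonmember_losses))])
    return roc_auc_score(labels, scores)
```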

Primary metric: unlearn_score (higher is better).
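The score weights utility and forgetting equally. For instance, plugging in the retain_finetune baseline's resnet20-cifar10 numbers from the results table (retain 0.876, forget 0.000, MIA AUC 0.451) recovers its reported score:

```python
def unlearn_score(retain_acc, forget_acc, forget_mia_auc):
    # Equal-weight average of utility kept and information removed.
    return (retain_acc + (1 - forget_acc) + (1 - forget_mia_auc)) / 3


print(round(unlearn_score(0.876, 0.000, 0.451), 3))  # → 0.808
```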

Baselines

  • retain_finetune: continue training only on retained data
  • negative_gradient: ascend forget loss and descend retain loss
  • bad_teacher: distillation-style forgetting baseline
  • scrub: stronger representation-scrubbing baseline
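The negative_gradient baseline, for example, fits the editable interface roughly as follows. This is a sketch under stated assumptions, not the harness's exact implementation; the forget_weight hyperparameter is illustrative:

```python
import torch
import torch.nn.functional as F


class NegativeGradientMethod:
    """Descend the retain loss while ascending the forget loss."""

    def __init__(self, forget_weight=1.0):
        self.forget_weight = forget_weight

    def unlearn_step(self, model, retain_batch, forget_batch, optimizer,
                     step, epoch):
        retain_x, retain_y = retain_batch
        forget_x, forget_y = forget_batch
        retain_loss = F.cross_entropy(model(retain_x), retain_y)
        forget_loss = F.cross_entropy(model(forget_x), forget_y)
        # Minimizing -forget_loss is gradient *ascent* on the forget set.
        loss = retain_loss - self.forget_weight * forget_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return {"loss": loss.item(),
                "retain_loss": retain_loss.item(),
                "forget_loss": forget_loss.item()}
```

Unchecked ascent on the forget loss tends to wreck retain accuracy (visible in the negative_gradient rows of the results table), which is why stronger methods bound or anneal the forgetting term.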

Code

custom_unlearning.py (editable)
"""Editable unlearning method for MLS-Bench."""

import torch
import torch.nn.functional as F

# ============================================================
# EDITABLE
# ============================================================
class UnlearningMethod:
    """Default retain-only finetuning update."""

    def __init__(self):
        self.forget_weight = 0.0  # default ignores the forget batch

    def unlearn_step(self, model, retain_batch, forget_batch, optimizer, step, epoch):
        # Retain-only finetuning, per the class docstring; the listing was
        # truncated here, so this body is a reconstruction of that default.
        retain_x, retain_y = retain_batch
        loss = F.cross_entropy(model(retain_x), retain_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return {"loss": loss.item()}
run_unlearning.py (read-only)
"""Fixed evaluation harness for security-machine-unlearning.

Pipeline:
    1. Load full dataset with standard augmentation
    2. Split into retain set (all classes except forget_class) and forget set
    3. Pretrain model on FULL training set for --pretrain-epochs (SGD + CosineAnnealing)
    4. Run unlearning: agent method processes retain/forget batches for --unlearn-epochs
    5. Evaluate: retain_acc, forget_acc, forget_mia_auc
    6. Compute unlearn_score = (retain_acc + (1-forget_acc) + (1-forget_mia_auc)) / 3
"""

import argparse
import math
import os
import random
Results

Each cell shows retain_acc / forget_acc / forget_mia_auc / unlearn_score for that benchmark.

| Model | Type | vgg16bn-cifar100-class0 | resnet20-cifar10-class0 | mobilenetv2-fmnist-class0 |
|---|---|---|---|---|
| bad_teacher | baseline | 0.463 / 0.000 / 0.420 / 0.681 | 0.844 / 0.001 / 0.414 / 0.810 | 0.929 / 0.000 / 0.494 / 0.812 |
| negative_gradient | baseline | 0.010 / 0.000 / 0.363 / 0.549 | 0.173 / 0.000 / 0.126 / 0.682 | 0.111 / 0.000 / 0.038 / 0.691 |
| retain_finetune | baseline | 0.534 / 0.000 / 0.476 / 0.686 | 0.876 / 0.000 / 0.451 / 0.808 | 0.937 / 0.000 / 0.482 / 0.819 |
| scrub | baseline | 0.451 / 0.000 / 0.440 / 0.670 | 0.831 / 0.000 / 0.397 / 0.811 | 0.924 / 0.000 / 0.521 / 0.801 |
| anthropic/claude-opus-4.6 | vanilla | 0.392 / 0.000 / 0.412 / 0.660 | 0.199 / 0.000 / 0.460 / 0.580 | 0.858 / 0.003 / 0.467 / 0.796 |
| deepseek-reasoner | vanilla | 0.010 / 0.000 / 0.363 / 0.549 | 0.123 / 0.000 / 0.146 / 0.659 | 0.112 / 0.000 / 0.046 / 0.689 |
| google/gemini-3.1-pro-preview | vanilla | 0.514 / 0.000 / 0.518 / 0.665 | 0.909 / 0.000 / 0.439 / 0.824 | 0.948 / 0.000 / 0.521 / 0.809 |
| openai/gpt-5.4-pro | vanilla | 0.522 / 0.000 / 0.508 / 0.671 | 0.901 / 0.000 / 0.429 / 0.824 | 0.943 / 0.000 / 0.518 / 0.808 |
| qwen3.6-plus:free | vanilla | 0.487 / 0.000 / 0.414 / 0.691 | 0.854 / 0.000 / 0.418 / 0.812 | 0.934 / 0.000 / 0.503 / 0.810 |
| anthropic/claude-opus-4.6 | agent | 0.038 / 0.000 / 0.453 / 0.528 | 0.869 / 0.033 / 0.455 / 0.794 | 0.884 / 0.000 / 0.495 / 0.796 |
| deepseek-reasoner | agent | 0.089 / 0.000 / 0.569 / 0.507 | 0.157 / 0.000 / 0.264 / 0.631 | 0.111 / 0.000 / 0.048 / 0.688 |
| google/gemini-3.1-pro-preview | agent | 0.549 / 0.000 / 0.491 / 0.686 | 0.909 / 0.003 / 0.420 / 0.829 | - |
| openai/gpt-5.4-pro | agent | 0.522 / 0.000 / 0.508 / 0.671 | 0.901 / 0.000 / 0.429 / 0.824 | 0.943 / 0.000 / 0.518 / 0.808 |
| qwen3.6-plus:free | agent | 0.464 / 0.000 / 0.409 / 0.685 | 0.859 / 0.001 / 0.391 / 0.822 | 0.933 / 0.000 / 0.482 / 0.817 |

Agent Conversations