security-poison-robust-learning

Tags: Adversarial ML · pytorch-vision · rigorous codebase

Description

Poison-Robust Learning under Label-Flip Poisoning

Research Question

How can we design a stronger loss function or sample-weighting rule that improves robustness to poisoned training labels without changing the model, optimizer, or data pipeline?

Background

A fraction of poisoned (label-flipped) training labels can disproportionately distort model decision boundaries. Robust learning methods typically modify the objective to downweight suspicious samples or reduce memorization of corrupted targets. This task uses research-scale models (ResNet-20, VGG-16-BN, MobileNetV2) trained on full datasets with standard SGD + CosineAnnealing for 100 epochs.

Task

Implement a better poison-robust objective in bench/poison/custom_robust_loss.py. The fixed harness injects random label-flip corruption into the training set, trains with your loss, and evaluates on a clean test set.

Your method should improve clean test accuracy under poisoning while reducing how much the model memorizes poisoned labels. The approach must be modular and transferable across architectures and datasets.
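The harness's poison-injection code is fixed and not shown on this page. For orientation only, random label-flip corruption of the kind described can be sketched as follows; the function name, seeding, and sampling details are assumptions, not the harness's actual implementation:

```python
import torch

def flip_labels(labels: torch.Tensor, num_classes: int, poison_frac: float,
                seed: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative label-flip corruption: a random fraction of labels is
    rotated to (y + 1) % num_classes. Returns the poisoned labels and a
    boolean mask marking which samples were flipped."""
    g = torch.Generator().manual_seed(seed)
    n = labels.numel()
    idx = torch.randperm(n, generator=g)[: int(poison_frac * n)]
    poisoned = labels.clone()
    poisoned[idx] = (poisoned[idx] + 1) % num_classes
    mask = torch.zeros(n, dtype=torch.bool)
    mask[idx] = True
    return poisoned, mask
```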

Editable Interface

You must implement:

```python
class RobustLoss:
    def compute_loss(self, logits, labels, epoch):
        ...
```

  • logits: model outputs for the current minibatch
  • labels: possibly poisoned labels (label-flip rule: (original + 1) % num_classes)
  • epoch: current training epoch (0-indexed)
  • Return value: a scalar loss tensor

The corruption process, model architectures, optimizer, and training schedule are fixed.
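As a minimal sketch of a conforming implementation, here is a generalized-cross-entropy-style loss that satisfies the interface. The choice of GCE and the hyperparameter `q` are illustrative assumptions, not the benchmark's reference solution:

```python
import torch
import torch.nn.functional as F

class RobustLoss:
    """Illustrative GCE-style robust loss (Zhang & Sabuncu, 2018 shape).

    As q -> 0 this approaches standard cross-entropy; q = 1 gives an
    MAE-like loss that is less sensitive to mislabeled samples.
    """

    def __init__(self, q: float = 0.7):
        self.q = q  # assumed default; tune per benchmark

    def compute_loss(self, logits, labels, epoch):
        probs = F.softmax(logits, dim=1)
        # probability assigned to the (possibly poisoned) target class
        p_y = probs.gather(1, labels.unsqueeze(1)).squeeze(1).clamp_min(1e-7)
        # GCE: (1 - p_y^q) / q, bounded gradient for low-confidence targets
        return ((1.0 - p_y.pow(self.q)) / self.q).mean()
```

The `epoch` argument is unused here but could drive a warm-up schedule (e.g. plain CE for the first few epochs before switching to the robust form).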

Evaluation

Benchmarks:

  • resnet20-cifar10-labelflip: ResNet-20 on CIFAR-10, 10% label-flip poison
  • vgg16bn-cifar100-labelflip: VGG-16-BN on CIFAR-100, 10% label-flip poison
  • mobilenetv2-fmnist-labelflip: MobileNetV2 on FashionMNIST, 15% label-flip poison

Reported metrics:

  • test_acc: accuracy on clean test set
  • poison_fit: fraction of poisoned training samples on which the model predicts the poisoned (wrong) label
  • robust_score = (test_acc + (1 - poison_fit)) / 2

Primary metric: robust_score (higher is better).
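The composite metric follows directly from the formula above; a one-line helper makes the rounding in the results table easy to check (for example, the bootstrap row on vgg16bn-cifar100 reports 0.681 and 0.564):

```python
def robust_score(test_acc: float, poison_fit: float) -> float:
    """Average of clean accuracy and resistance to fitting poisoned labels."""
    return (test_acc + (1.0 - poison_fit)) / 2.0

# robust_score(0.681, 0.564) = 0.5585, which rounds to the reported 0.559
```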

Baselines

  • cross_entropy: standard ERM on poisoned labels
  • generalized_ce: generalized cross-entropy for noisy labels
  • symmetric_ce: CE plus reverse-CE penalty
  • bootstrap: target interpolation with model predictions
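For reference, the symmetric_ce baseline augments CE with a reverse cross-entropy term that penalizes confident fits to possibly-wrong labels. A common formulation (Wang et al., 2019) looks roughly like this; the alpha/beta/clip values are typical defaults assumed here, not necessarily the harness's settings:

```python
import torch
import torch.nn.functional as F

def symmetric_ce(logits, labels, alpha=0.1, beta=1.0, clip=-4.0):
    """Sketch of symmetric cross-entropy: alpha * CE + beta * reverse-CE."""
    ce = F.cross_entropy(logits, labels)
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(labels, logits.size(1)).float()
    # reverse CE uses log of the one-hot target; log(0) is clipped to `clip`
    log_q = torch.where(one_hot.bool(),
                        torch.zeros_like(one_hot),
                        torch.full_like(one_hot, clip))
    rce = -(probs * log_q).sum(dim=1).mean()
    return alpha * ce + beta * rce
```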

Code

custom_robust_loss.py (editable)

```python
"""Editable poison-robust loss for MLS-Bench."""

import torch
import torch.nn.functional as F

# ============================================================
# EDITABLE
# ============================================================
class RobustLoss:
    """Default cross-entropy objective."""

    def __init__(self):
        pass

    def compute_loss(self, logits, labels, epoch):
        return F.cross_entropy(logits, labels)
```
run_poison_robust.py (read-only)

```python
"""Research-scale evaluation harness for poison-robust learning.

Train standard vision models (ResNet-20, VGG-16-BN, MobileNetV2) on
CIFAR-10/100/FashionMNIST with label-flip poisoning. The agent's custom
RobustLoss replaces nn.CrossEntropyLoss in the training loop.

FIXED: Model architectures, data pipeline, training schedule, poison injection.
EDITABLE: RobustLoss class in custom_robust_loss.py.
"""

import argparse
import os
import random

import numpy as np
```
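The harness snippet above is truncated. For orientation, the custom loss plugs into a training loop roughly as follows; everything except the `compute_loss(logits, labels, epoch)` call is a placeholder for the fixed harness code:

```python
import torch

def train_one_epoch(model, loader, optimizer, robust_loss, epoch, device="cpu"):
    """Illustrative training step: RobustLoss.compute_loss stands in where
    nn.CrossEntropyLoss()(logits, labels) would normally be called."""
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = robust_loss.compute_loss(logits, labels, epoch)
        loss.backward()
        optimizer.step()
```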

Results

| Model | Type | test_acc (vgg16bn-cifar100) | poison_fit (vgg16bn-cifar100) | robust_score (vgg16bn-cifar100) | test_acc (resnet20-cifar10) | poison_fit (resnet20-cifar10) | robust_score (resnet20-cifar10) | test_acc (mobilenetv2-fmnist) | poison_fit (mobilenetv2-fmnist) | robust_score (mobilenetv2-fmnist) |
|---|---|---|---|---|---|---|---|---|---|---|
| bootstrap | baseline | 0.681 | 0.564 | 0.559 | 0.902 | 0.114 | 0.894 | 0.948 | 0.008 | 0.970 |
| cross_entropy | baseline | 0.691 | 0.453 | 0.619 | 0.904 | 0.119 | 0.892 | 0.947 | 0.010 | 0.968 |
| generalized_ce | baseline | 0.667 | 0.026 | 0.821 | 0.905 | 0.040 | 0.932 | 0.933 | 0.005 | 0.964 |
| symmetric_ce | baseline | 0.703 | 0.181 | 0.761 | 0.908 | 0.080 | 0.914 | 0.943 | 0.007 | 0.968 |
| anthropic/claude-opus-4.6 | vanilla | 0.586 | 0.025 | 0.780 | 0.874 | 0.023 | 0.925 | 0.766 | 0.003 | 0.882 |
| deepseek-reasoner | vanilla | 0.692 | 0.457 | 0.618 | 0.899 | 0.118 | 0.890 | 0.947 | 0.009 | 0.969 |
| google/gemini-3.1-pro-preview | vanilla | 0.475 | 0.011 | 0.732 | 0.901 | 0.032 | 0.935 | 0.922 | 0.005 | 0.959 |
| openai/gpt-5.4-pro | vanilla | 0.725 | 0.045 | 0.840 | 0.919 | 0.031 | 0.944 | 0.946 | 0.004 | 0.971 |
| qwen3.6-plus:free | vanilla | 0.010 | 0.008 | 0.501 | 0.100 | 0.099 | 0.500 | 0.100 | 0.098 | 0.501 |
| anthropic/claude-opus-4.6 | agent | 0.692 | 0.036 | 0.828 | 0.896 | 0.029 | 0.934 | 0.928 | 0.005 | 0.961 |
| deepseek-reasoner | agent | 0.685 | 0.522 | 0.582 | 0.906 | 0.116 | 0.895 | 0.947 | 0.008 | 0.969 |
| google/gemini-3.1-pro-preview | agent | 0.544 | 0.030 | 0.757 | 0.899 | 0.035 | 0.932 | 0.924 | 0.005 | 0.960 |
| openai/gpt-5.4-pro | agent | 0.726 | 0.046 | 0.840 | 0.916 | 0.028 | 0.944 | 0.946 | 0.004 | 0.971 |
| qwen3.6-plus:free | agent | 0.689 | 0.380 | 0.654 | 0.909 | 0.089 | 0.910 | 0.947 | 0.007 | 0.970 |

Agent Conversations