security-poison-robust-learning
Description
Poison-Robust Learning under Label-Flip Poisoning
Research Question
How can we design a stronger loss function or sample-weighting rule that improves robustness to poisoned training labels without changing the model, optimizer, or data pipeline?
Background
Even a small fraction of poisoned (label-flipped) training labels can disproportionately distort a model's decision boundaries. Robust learning methods typically modify the objective to downweight suspicious samples or reduce memorization of corrupted targets. This task uses research-scale models (ResNet-20, VGG-16-BN, MobileNetV2) trained on full datasets for 100 epochs with standard SGD and a cosine-annealing learning-rate schedule.
Task
Implement an improved poison-robust objective in `bench/poison/custom_robust_loss.py`. The fixed harness injects random label-flip corruption into the training set, trains with your loss, and evaluates on a clean test set.
Your method should improve clean test accuracy under poisoning while reducing how much the model memorizes the poisoned labels. The approach must be modular and transfer across architectures and datasets.
Editable Interface
You must implement:
```python
class RobustLoss:
    def compute_loss(self, logits, labels, epoch):
        ...
```
- `logits`: model outputs for the current minibatch
- `labels`: possibly poisoned labels (label flip: `(original + 1) % num_classes`)
- `epoch`: current training epoch (0-indexed)
- Return value: scalar loss tensor
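The label-flip rule above can be simulated in a few lines. This is an illustrative sketch only: the harness fixes its own poison rate, seed, and index selection, so the `flip_labels` helper and its parameters here are assumptions.

```python
import random

def flip_labels(labels, num_classes, poison_rate, seed=0):
    """Return (corrupted_labels, poisoned_indices) under the fixed flip rule."""
    rng = random.Random(seed)
    n_poison = int(len(labels) * poison_rate)
    poisoned = set(rng.sample(range(len(labels)), n_poison))
    corrupted = [
        (y + 1) % num_classes if i in poisoned else y  # flip: (original + 1) % num_classes
        for i, y in enumerate(labels)
    ]
    return corrupted, poisoned

labels = [0, 1, 2, 9, 5]
corrupted, poisoned = flip_labels(labels, num_classes=10, poison_rate=0.4)
```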
The corruption process, model architectures, optimizer, and training schedule are fixed.
Evaluation
Benchmarks:
- `resnet20-cifar10-labelflip`: ResNet-20 on CIFAR-10, 10% label-flip poison
- `vgg16bn-cifar100-labelflip`: VGG-16-BN on CIFAR-100, 10% label-flip poison
- `mobilenetv2-fmnist-labelflip`: MobileNetV2 on FashionMNIST, 15% label-flip poison
Reported metrics:
- `test_acc`: accuracy on the clean test set
- `poison_fit`: fraction of poisoned samples on which the model predicts the poisoned (wrong) label
- `robust_score = (test_acc + (1 - poison_fit)) / 2`
Primary metric: robust_score (higher is better).
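The aggregation is a plain average of clean accuracy and non-memorization, e.g.:

```python
def robust_score(test_acc, poison_fit):
    # Average of clean accuracy and (1 - memorization of poisoned labels).
    return (test_acc + (1.0 - poison_fit)) / 2.0

# e.g. generalized_ce on resnet20-cifar10-labelflip from the results table:
robust_score(0.905, 0.040)  # -> 0.9325
```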
Baselines
- `cross_entropy`: standard ERM on poisoned labels
- `generalized_ce`: generalized cross-entropy for noisy labels
- `symmetric_ce`: CE plus a reverse-CE penalty
- `bootstrap`: target interpolation with model predictions
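For concreteness, here is a minimal sketch of one of the listed baselines, generalized cross-entropy, which computes `(1 - p_y^q) / q` and interpolates between CE (as `q -> 0`) and MAE (`q = 1`). The default `q = 0.7` follows common usage in the noisy-label literature and is not necessarily what the harness uses.

```python
import torch
import torch.nn.functional as F

class GeneralizedCE:
    """L_q loss: (1 - p_y^q) / q, a noise-robust relaxation of cross-entropy."""

    def __init__(self, q=0.7):
        self.q = q

    def compute_loss(self, logits, labels, epoch):
        probs = F.softmax(logits, dim=1)
        # Probability assigned to the (possibly poisoned) target class.
        p_y = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
        return ((1.0 - p_y.pow(self.q)) / self.q).mean()
```

Because the gradient of `(1 - p^q)/q` vanishes as `p -> 0`, confidently contradicted (likely poisoned) samples contribute less than under CE, whose gradient grows unboundedly.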
Code
```python
"""Editable poison-robust loss for MLS-Bench."""

import torch
import torch.nn.functional as F

# ============================================================
# EDITABLE
# ============================================================
class RobustLoss:
    """Default cross-entropy objective."""

    def __init__(self):
        pass

    def compute_loss(self, logits, labels, epoch):
        return F.cross_entropy(logits, labels)
```
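One direction the task invites is an epoch-aware sample-weighting rule. The sketch below (class name, warmup length, and keep fraction are all illustrative assumptions, not part of the harness) warms up with plain CE, then keeps only the small-loss fraction of each minibatch, since under random label flips the high-loss samples are disproportionately poisoned.

```python
import torch
import torch.nn.functional as F

class SmallLossSelection:
    """Sketch: CE warmup, then drop the highest-loss samples per minibatch.

    Hypothetical hyperparameters; the keep fraction would typically be tuned
    against the expected poison rate.
    """

    def __init__(self, warmup_epochs=10, keep_fraction=0.9):
        self.warmup_epochs = warmup_epochs
        self.keep_fraction = keep_fraction

    def compute_loss(self, logits, labels, epoch):
        per_sample = F.cross_entropy(logits, labels, reduction="none")
        if epoch < self.warmup_epochs:
            return per_sample.mean()  # plain ERM during warmup
        k = max(1, int(per_sample.numel() * self.keep_fraction))
        kept, _ = torch.topk(per_sample, k, largest=False)  # small-loss samples
        return kept.mean()
```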
```python
"""Research-scale evaluation harness for poison-robust learning.

Train standard vision models (ResNet-20, VGG-16-BN, MobileNetV2) on
CIFAR-10/100/FashionMNIST with label-flip poisoning. The agent's custom
RobustLoss replaces nn.CrossEntropyLoss in the training loop.

FIXED: Model architectures, data pipeline, training schedule, poison injection.
EDITABLE: RobustLoss class in custom_robust_loss.py.
"""

import argparse
import os
import random

import numpy as np
```
Results
| Model | Type | vgg16bn-cifar100-labelflip test_acc ↑ | vgg16bn-cifar100-labelflip poison_fit ↓ | vgg16bn-cifar100-labelflip robust_score ↑ | resnet20-cifar10-labelflip test_acc ↑ | resnet20-cifar10-labelflip poison_fit ↓ | resnet20-cifar10-labelflip robust_score ↑ | mobilenetv2-fmnist-labelflip test_acc ↑ | mobilenetv2-fmnist-labelflip poison_fit ↓ | mobilenetv2-fmnist-labelflip robust_score ↑ |
|---|---|---|---|---|---|---|---|---|---|---|
| bootstrap | baseline | 0.681 | 0.564 | 0.559 | 0.902 | 0.114 | 0.894 | 0.948 | 0.008 | 0.970 |
| cross_entropy | baseline | 0.691 | 0.453 | 0.619 | 0.904 | 0.119 | 0.892 | 0.947 | 0.010 | 0.968 |
| generalized_ce | baseline | 0.667 | 0.026 | 0.821 | 0.905 | 0.040 | 0.932 | 0.933 | 0.005 | 0.964 |
| symmetric_ce | baseline | 0.703 | 0.181 | 0.761 | 0.908 | 0.080 | 0.914 | 0.943 | 0.007 | 0.968 |
| anthropic/claude-opus-4.6 | vanilla | 0.586 | 0.025 | 0.780 | 0.874 | 0.023 | 0.925 | 0.766 | 0.003 | 0.882 |
| deepseek-reasoner | vanilla | 0.692 | 0.457 | 0.618 | 0.899 | 0.118 | 0.890 | 0.947 | 0.009 | 0.969 |
| google/gemini-3.1-pro-preview | vanilla | 0.475 | 0.011 | 0.732 | 0.901 | 0.032 | 0.935 | 0.922 | 0.005 | 0.959 |
| openai/gpt-5.4-pro | vanilla | 0.725 | 0.045 | 0.840 | 0.919 | 0.031 | 0.944 | 0.946 | 0.004 | 0.971 |
| qwen3.6-plus:free | vanilla | 0.010 | 0.008 | 0.501 | 0.100 | 0.099 | 0.500 | 0.100 | 0.098 | 0.501 |
| anthropic/claude-opus-4.6 | agent | 0.692 | 0.036 | 0.828 | 0.896 | 0.029 | 0.934 | 0.928 | 0.005 | 0.961 |
| deepseek-reasoner | agent | 0.685 | 0.522 | 0.582 | 0.906 | 0.116 | 0.895 | 0.947 | 0.008 | 0.969 |
| google/gemini-3.1-pro-preview | agent | 0.544 | 0.030 | 0.757 | 0.899 | 0.035 | 0.932 | 0.924 | 0.005 | 0.960 |
| openai/gpt-5.4-pro | agent | 0.726 | 0.046 | 0.840 | 0.916 | 0.028 | 0.944 | 0.946 | 0.004 | 0.971 |
| qwen3.6-plus:free | agent | 0.689 | 0.380 | 0.654 | 0.909 | 0.089 | 0.910 | 0.947 | 0.007 | 0.970 |