Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 8
Messages: 17
User Step 0

Task: security-adversarial-training

Adversarial Training for Model Robustness

Research Question

How can we design better adversarial training methods that improve model robustness against L_inf adversarial attacks?

Background

Adversarial training is the most effective approach for improving neural network robustness against adversarial examples. The standard method (Madry et al., 2018) trains on PGD-generated adversarial examples using cross-entropy loss, but suffers from a trade-off between clean accuracy and robust accuracy. Advanced methods like TRADES and MART address this through different loss formulations that decouple the robustness objective from clean classification.
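The two loss formulations above can be sketched numerically. This is a minimal numpy illustration (not the benchmark's torch code): PGD-AT minimizes cross-entropy on adversarial inputs, while TRADES keeps cross-entropy on clean inputs and adds a beta-weighted KL(clean || adversarial) consistency term.

```python
import numpy as np

def cross_entropy(probs, label):
    # CE on a probability vector; PGD-AT evaluates this on adversarial inputs.
    return float(-np.log(probs[label]))

def kl_div(p, q):
    # KL(p || q), the consistency term TRADES adds between clean and
    # adversarial predictive distributions.
    return float(np.sum(p * np.log(p / q)))

def trades_loss(probs_clean, probs_adv, label, beta=6.0):
    # TRADES: clean-classification CE plus a beta-weighted robustness term,
    # decoupling the robustness objective from clean accuracy.
    return cross_entropy(probs_clean, label) + beta * kl_div(probs_clean, probs_adv)
```

When the adversarial prediction matches the clean one, the KL term vanishes and TRADES reduces to plain clean cross-entropy; the trade-off is controlled entirely by `beta`.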

Task

Implement a novel adversarial training method in bench/custom_adv_train.py by modifying the AdversarialTrainer class. Your method should improve robust accuracy against white-box L_inf attacks while maintaining reasonable clean accuracy.

Interface

You must implement the AdversarialTrainer class with two methods:

  • __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs): Initialize your trainer.

    • model: The neural network to train (nn.Module).
    • eps: L_inf perturbation budget (0.3 for MNIST, 8/255 for CIFAR).
    • alpha: Step size for inner PGD attack.
    • attack_steps: Number of PGD steps for adversarial example generation.
    • num_classes: Number of output classes (10 or 100).
  • train_step(self, images, labels, optimizer) -> dict: Perform one training step.

    • images: Clean images, shape (N, C, H, W), values in [0, 1].
    • labels: Ground truth labels, shape (N,).
    • optimizer: SGD optimizer (lr, momentum, weight_decay already configured).
    • Returns: dict with at least 'loss' key (float).

The training loop, learning rate schedule (cosine annealing), model architecture, and data loading are handled externally. You only control the adversarial training procedure within each step.
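As a rough illustration of this contract (not the harness's actual code), here is a framework-free sketch: `ToyModel` is a hypothetical logistic-regression stand-in for the `nn.Module`, a plain SGD update stands in for the `optimizer` argument, and the inner loop is standard L_inf PGD with the `eps`/`alpha`/`attack_steps` parameters from the interface.

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in for the nn.Module: binary logistic regression."""
    def __init__(self, dim):
        self.w = np.zeros(dim)
    def predict(self, x):
        # Probability of class 1 for a batch of flat inputs, shape (N, dim).
        return 1.0 / (1.0 + np.exp(-x @ self.w))

class AdversarialTrainer:
    def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):
        self.model, self.eps, self.alpha = model, eps, alpha
        self.attack_steps, self.num_classes = attack_steps, num_classes

    def train_step(self, images, labels, lr=0.1):
        m = self.model
        adv = images.copy()
        for _ in range(self.attack_steps):
            p = m.predict(adv)
            # Analytic d(CE)/dx for logistic loss: (p - y) * w.
            grad_x = (p - labels)[:, None] * m.w[None, :]
            adv = adv + self.alpha * np.sign(grad_x)           # L_inf PGD step
            adv = np.clip(adv, images - self.eps, images + self.eps)  # eps-ball
            adv = np.clip(adv, 0.0, 1.0)                        # valid pixel range
        p = m.predict(adv)  # outer minimization: CE on adversarial inputs
        loss = -np.mean(labels * np.log(p + 1e-12)
                        + (1 - labels) * np.log(1 - p + 1e-12))
        grad_w = (p - labels) @ adv / len(labels)  # analytic CE gradient in w
        m.w -= lr * grad_w                         # SGD update (optimizer stand-in)
        return {'loss': float(loss)}
```

The real benchmark differs in the details (torch autograd, multi-class CE, an externally configured SGD optimizer), but the shape of the contract is the same: generate adversarial examples inside `train_step`, take one optimizer step, and return a dict containing at least `'loss'`.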

Evaluation

After training, models are evaluated on:

  • Clean accuracy: Accuracy on unperturbed test images.
  • Robust accuracy (FGSM): Accuracy under 1-step FGSM attack.
  • Robust accuracy (PGD-50): Accuracy under 50-step PGD attack (primary metric).

Four scenarios (model + dataset):

  • SmallCNN on MNIST (eps = 0.3)
  • PreActResNet-18 on CIFAR-10 (eps = 8/255)
  • VGG-11-BN on CIFAR-10 (eps = 8/255)
  • PreActResNet-18 on CIFAR-100 (eps = 8/255)

Higher robust accuracy (PGD-50) across all scenarios is better.
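The three metrics above differ only in which attack is plugged into the same accuracy computation; a generic sketch (function names illustrative, not the evaluator's API):

```python
import numpy as np

def robust_accuracy(classify, attack, images, labels):
    """Fraction of test points still classified correctly after the attack.
    With attack = identity this is clean accuracy; with one signed-gradient
    step, FGSM; with 50 projected steps, the primary PGD-50 metric."""
    adv = attack(images, labels)
    return float(np.mean(classify(adv) == labels))
```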

Baselines

  • standard: Vanilla training (no adversarial examples). High clean accuracy, ~0% robust accuracy.
  • pgdat: PGD Adversarial Training (Madry et al., 2018). Trains on PGD adversarial examples with CE loss.
  • trades: TRADES (Zhang et al., 2019). Balances clean and robust accuracy via KL divergence regularization.
  • mart: MART (Wang et al., 2020). Misclassification-aware regularization that focuses on hard examples.
  • awp: AWP + TRADES (Wu et al., 2020). Adversarial weight perturbation applied on top of the TRADES objective to flatten the weight-loss landscape.
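The layer-wise weight-perturbation rule used by AWP-style methods (and by the trainer implemented in the steps below) can be sketched in isolation. This is a simplified numpy version over plain arrays; `gamma` is the relative perturbation size.

```python
import numpy as np

def awp_perturb(weights, grads, gamma=0.01):
    """Layer-wise adversarial weight perturbation (simplified AWP-style rule):
    move each weight tensor a relative distance gamma along its gradient,
    scaled so the step is proportional to the layer's weight norm."""
    perturbs = {}
    for name, w in weights.items():
        g = grads[name]
        g_norm, w_norm = np.linalg.norm(g), np.linalg.norm(w)
        if g_norm > 1e-10:  # skip layers with (near-)zero gradient
            e = gamma * w_norm / g_norm * g
            weights[name] = w + e
            perturbs[name] = e
    return perturbs  # subtract these afterwards to restore the clean weights
```

In full AWP the perturbation maximizes the adversarial loss before the outer update; the key mechanics, shown here, are the per-layer `||w|| / ||g||` scaling and the save-and-restore of the perturbations.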
Assistant Step 1 (edit)
Tool Input
{
  "start_line": 10,
  "content": "class AdversarialTrainer:\n    \"\"\"\n    Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).\n    Combines a confident-calibrated distillation target with an efficient AWP/SAM\n    perturbation wrapper for state-of-the-art robust learning.\n    \"\"\"\n\n    def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):\n        self.model = model\n        self.eps = eps\n        self.alpha = alpha\n        self.attack_steps = attack_steps\n        self.num_classes = num_classes\n        self.beta = 6.0\n        self.awp_gamma = 0.01  # Perturbation size for SAM\n        self.ls_eps = 0.1      # Label smoothing factor\n\n    def compute_loss(self, images, adv_images, labels):\n        logits_clean = self.model(images)\n        logits_adv = self.model(adv_images)\n        \n        loss_clean = F.cross_entropy(logits_clean, labels)\n        \n        # Calculate adaptive target based on clean prediction confidence\n        probs_clean = F.softmax(logits_clean, dim=1)\n        conf_y = probs_clean.gather(1, labels.unsqueeze(1)).detach()\n        \n        # Smooth one-hot labels for the true class\n        smooth_y = torch.full_like(probs_clean, self.ls_eps / (self.num_classes - 1))\n        smooth_y.scatter_(1, labels.unsqueeze(1), 1.0 - self.ls_eps)\n        \n        # Interpolate TRADES target (clean probs) and PGD target (true label)\n        # Confident/correct examples match clean distribution; unconfident match true label.\n        alpha_mix = torch.clamp(1.0 - conf_y, min=0.0, max=1.0)\n        target_probs = (1.0 - alpha_mix) * probs_clean.detach() + alpha_mix * smooth_y\n        \n        loss_kl = F.kl_div(\n            F.log_softmax(logits_adv, dim=1), \n            target_probs, \n            reduction='batchmean'\n        )\n        \n        return loss_clean + self.beta * loss_kl\n\n    def train_step(self, images, labels, optimizer):\n        self.model.eval()\n        \n        # 1. Generate adversarial examples (TRADES-style attack)\n        adv_images = images.clone().detach()\n        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n        adv_images = torch.clamp(adv_images, 0.0, 1.0)\n\n        with torch.no_grad():\n            logits_clean_ref = self.model(images).detach()\n\n        for _ in range(self.attack_steps):\n            adv_images.requires_grad_(True)\n            with torch.enable_grad():\n                logits_adv = self.model(adv_images)\n                loss_attack = F.kl_div(\n                    F.log_softmax(logits_adv, dim=1),\n                    F.softmax(logits_clean_ref, dim=1),\n                    reduction='batchmean'\n                )\n            grad = torch.autograd.grad(loss_attack, adv_images)[0]\n            adv_images = adv_images.detach() + self.alpha * grad.sign()\n            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)\n            adv_images = torch.clamp(images + delta, 0.0, 1.0).detach()\n\n        self.model.train()\n        \n        # 2. SAM/AWP Forward 1 (on clean weights)\n        loss = self.compute_loss(images, adv_images, labels)\n        optimizer.zero_grad()\n        loss.backward()\n\n        # 3. Apply standard SAM/AWP weight perturbation\n        e_w = {}\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                g_norm = torch.norm(param.grad)\n                p_norm = torch.norm(param)\n                if g_norm > 1e-10:\n                    e = self.awp_gamma * p_norm / g_norm * param.grad\n                    param.data.add_(e)\n                    e_w[name] = e\n\n        # 4. SAM/AWP Forward 2 (on perturbed weights)\n        optimizer.zero_grad()\n        loss_robust = self.compute_loss(images, adv_images, labels)\n        loss_robust.backward()\n\n        # 5. Restore clean weights\n        for name, param in self.model.named_parameters():\n            if name in e_w:\n                param.data.sub_(e_w[name])\n\n        # 6. Apply gradients\n        optimizer.step()\n\n        return {'loss': loss.item(), 'loss_robust': loss_robust.item()}\n",
  "op": "replace",
  "filename": "bench/custom_adv_train.py",
  "end_line": 54
}
Tool Result (Step 1)
ERROR: Package 'bench' is not in allowed packages
Assistant Step 2 (edit)
Tool Input
{
  "start_line": 10,
  "content": "class AdversarialTrainer:\n    \"\"\"\n    Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).\n    Combines a confident-calibrated distillation target with an efficient AWP/SAM\n    perturbation wrapper for state-of-the-art robust learning.\n    \"\"\"\n\n    def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):\n        self.model = model\n        self.eps = eps\n        self.alpha = alpha\n        self.attack_steps = attack_steps\n        self.num_classes = num_classes\n        self.beta = 6.0\n        self.awp_gamma = 0.01  # Perturbation size for SAM\n        self.ls_eps = 0.1      # Label smoothing factor\n\n    def compute_loss(self, images, adv_images, labels):\n        logits_clean = self.model(images)\n        logits_adv = self.model(adv_images)\n        \n        loss_clean = F.cross_entropy(logits_clean, labels)\n        \n        # Calculate adaptive target based on clean prediction confidence\n        probs_clean = F.softmax(logits_clean, dim=1)\n        conf_y = probs_clean.gather(1, labels.unsqueeze(1)).detach()\n        \n        # Smooth one-hot labels for the true class\n        smooth_y = torch.full_like(probs_clean, self.ls_eps / (self.num_classes - 1))\n        smooth_y.scatter_(1, labels.unsqueeze(1), 1.0 - self.ls_eps)\n        \n        # Interpolate TRADES target (clean probs) and PGD target (true label)\n        # Confident/correct examples match clean distribution; unconfident match true label.\n        alpha_mix = torch.clamp(1.0 - conf_y, min=0.0, max=1.0)\n        target_probs = (1.0 - alpha_mix) * probs_clean.detach() + alpha_mix * smooth_y\n        \n        loss_kl = F.kl_div(\n            F.log_softmax(logits_adv, dim=1), \n            target_probs, \n            reduction='batchmean'\n        )\n        \n        return loss_clean + self.beta * loss_kl\n\n    def train_step(self, images, labels, optimizer):\n        self.model.eval()\n        \n        # 1. Generate adversarial examples (TRADES-style attack)\n        adv_images = images.clone().detach()\n        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n        adv_images = torch.clamp(adv_images, 0.0, 1.0)\n\n        with torch.no_grad():\n            logits_clean_ref = self.model(images).detach()\n\n        for _ in range(self.attack_steps):\n            adv_images.requires_grad_(True)\n            with torch.enable_grad():\n                logits_adv = self.model(adv_images)\n                loss_attack = F.kl_div(\n                    F.log_softmax(logits_adv, dim=1),\n                    F.softmax(logits_clean_ref, dim=1),\n                    reduction='batchmean'\n                )\n            grad = torch.autograd.grad(loss_attack, adv_images)[0]\n            adv_images = adv_images.detach() + self.alpha * grad.sign()\n            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)\n            adv_images = torch.clamp(images + delta, 0.0, 1.0).detach()\n\n        self.model.train()\n        \n        # 2. SAM/AWP Forward 1 (on clean weights)\n        loss = self.compute_loss(images, adv_images, labels)\n        optimizer.zero_grad()\n        loss.backward()\n\n        # 3. Apply standard SAM/AWP weight perturbation\n        e_w = {}\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                g_norm = torch.norm(param.grad)\n                p_norm = torch.norm(param)\n                if g_norm > 1e-10:\n                    e = self.awp_gamma * p_norm / g_norm * param.grad\n                    param.data.add_(e)\n                    e_w[name] = e\n\n        # 4. SAM/AWP Forward 2 (on perturbed weights)\n        optimizer.zero_grad()\n        loss_robust = self.compute_loss(images, adv_images, labels)\n        loss_robust.backward()\n\n        # 5. Restore clean weights\n        for name, param in self.model.named_parameters():\n            if name in e_w:\n                param.data.sub_(e_w[name])\n\n        # 6. Apply gradients\n        optimizer.step()\n\n        return {'loss': loss.item(), 'loss_robust': loss_robust.item()}\n",
  "op": "replace",
  "filename": "torchattacks/bench/custom_adv_train.py",
  "end_line": 54
}
Tool Result (Step 2)
OK: Replaced lines 10..54 with 101 line(s) in torchattacks/bench/custom_adv_train.py. Editable range: 10–110.

[Current file: torchattacks/bench/custom_adv_train.py | editable: 10–110 | total: 113 lines]
     7| # ═══════════════════════════════════════════════════════════════════
     8| # EDITABLE — implement AdversarialTrainer below
     9| # ═══════════════════════════════════════════════════════════════════
    10  class AdversarialTrainer:
    11      """
    12      Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).
    13      Combines a confident-calibrated distillation target with an efficient AWP/SAM
    14      perturbation wrapper for state-of-the-art robust learning.
    15      """
    16  
    17      def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):
    18          self.model = model
    19          self.eps = eps
    20          self.alpha = alpha
    21          self.attack_steps = attack_steps
    22          self.num_classe
Assistant Step 3 (test)
Tool Result (Step 3)
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)


## Seed 42

### SmallCNN-MNIST (scripts/SmallCNN_MNIST.sh)
Training progress (SmallCNN-MNIST):
TRAIN_METRICS epoch=1 loss=10.9156 lr=0.009961
TRAIN_METRICS epoch=10 loss=10.9120 lr=0.006545
TRAIN_METRICS epoch=20 loss=10.9077 lr=0.000955
Results (SmallCNN-MNIST): clean_acc=0.1135, robust_acc_fgsm=0.1135, robust_acc_pgd=0.1135

### PreActResNet18-C10 (scripts/PreActResNet18_C10.sh)
Training progress (PreActResNet18-C10):
TRAIN_METRICS epoch=40 loss=10.9543 lr=0.050000
TRAIN_METRICS epoch=50 loss=10.9375 lr=0.030866
TRAIN_METRICS epoch=60 loss=10.9277 lr=0.014645
TRAIN_METRICS epoch=70 loss=10.9201 lr=0.003806
TRAIN_METRICS epoch=80 loss=10.9168 lr=0.000000
Results (PreActResNet18-C10): clean_acc=0.1000, robust_acc_fgsm=0.1000, robust_acc_pgd=0.1000


## Seed 123

### SmallCNN-MNIST (scripts/SmallCNN_MNIST.sh)
Training progress (SmallCNN-MNIST):
TRAIN_METRICS epoch=1 loss=10.9146 lr=0.009961
TRAIN_METRICS epoch=10 lo
Assistant Step 4 (edit)
Tool Input
{
  "end_line": 111,
  "op": "replace",
  "start_line": 10,
  "filename": "torchattacks/bench/custom_adv_train.py",
  "content": "class AdversarialTrainer:\n    \"\"\"\n    Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).\n    Combines an interpolated, confidence-calibrated distillation target with an \n    efficient global SAM perturbation step to achieve robust optimization.\n    \"\"\"\n\n    def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):\n        self.model = model\n        self.eps = eps\n        self.alpha = alpha\n        self.attack_steps = attack_steps\n        self.num_classes = num_classes\n        self.beta = 6.0\n        self.rho = 0.05        # Perturbation radius for SAM\n        self.ls_eps = 0.1      # Label smoothing factor\n\n    def compute_loss(self, images, adv_images, labels):\n        logits_clean = self.model(images)\n        logits_adv = self.model(adv_images)\n        \n        loss_clean = F.cross_entropy(logits_clean, labels)\n        \n        # Calculate adaptive target based on clean prediction confidence\n        probs_clean = F.softmax(logits_clean, dim=1)\n        conf_y = probs_clean.gather(1, labels.unsqueeze(1)).detach()\n        \n        # Smooth one-hot labels for the true class\n        smooth_y = torch.full_like(probs_clean, self.ls_eps / max(1, self.num_classes - 1))\n        smooth_y.scatter_(1, labels.unsqueeze(1), 1.0 - self.ls_eps)\n        \n        # Interpolate TRADES target (clean probs) and PGD target (true label)\n        # Confident examples match clean distribution; unconfident match smoothed true label.\n        alpha_mix = torch.clamp(1.0 - conf_y, min=0.0, max=1.0)\n        target_probs = (1.0 - alpha_mix) * probs_clean.detach() + alpha_mix * smooth_y\n        \n        loss_kl = F.kl_div(\n            F.log_softmax(logits_adv, dim=1), \n            target_probs, \n            reduction='batchmean'\n        )\n        \n        return loss_clean + self.beta * loss_kl\n\n    def train_step(self, images, labels, optimizer):\n        self.model.eval()\n        \n        # 1. Generate adversarial examples (TRADES-style attack)\n        adv_images = images.clone().detach()\n        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n        adv_images = torch.clamp(adv_images, 0.0, 1.0)\n\n        with torch.no_grad():\n            logits_clean_ref = self.model(images).detach()\n\n        for _ in range(self.attack_steps):\n            adv_images.requires_grad_(True)\n            with torch.enable_grad():\n                logits_adv = self.model(adv_images)\n                loss_attack = F.kl_div(\n                    F.log_softmax(logits_adv, dim=1),\n                    F.softmax(logits_clean_ref, dim=1),\n                    reduction='batchmean'\n                )\n            grad = torch.autograd.grad(loss_attack, adv_images)[0]\n            adv_images = adv_images.detach() + self.alpha * grad.sign()\n            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)\n            adv_images = torch.clamp(images + delta, 0.0, 1.0).detach()\n\n        self.model.train()\n        \n        # 2. SAM Forward 1 (on clean weights)\n        loss = self.compute_loss(images, adv_images, labels)\n        optimizer.zero_grad()\n        loss.backward()\n\n        # 3. Apply standard global SAM weight perturbation\n        e_w = {}\n        grads = []\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                grads.append(param.grad.detach().view(-1))\n        \n        if grads:\n            grad_norm = torch.norm(torch.cat(grads))\n            if grad_norm > 1e-12:\n                scale = self.rho / grad_norm\n                for name, param in self.model.named_parameters():\n                    if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                        e = param.grad.detach() * scale\n                        param.data.add_(e)\n                        e_w[name] = e\n\n        # 4. SAM Forward 2 (on perturbed weights)\n        optimizer.zero_grad()\n        loss_robust = self.compute_loss(images, adv_images, labels)\n        loss_robust.backward()\n\n        # 5. Restore clean weights\n        for name, param in self.model.named_parameters():\n            if name in e_w:\n                param.data.sub_(e_w[name])\n\n        # 6. Apply gradients step\n        optimizer.step()\n\n        return {'loss': loss.item(), 'loss_robust': loss_robust.item()}\n"
}
Tool Result (Step 4)
ERROR: Lines 10..111 exceed the editable range. You may only edit lines 10–110.

[Current file: torchattacks/bench/custom_adv_train.py | editable: 10–110 | total: 113 lines]
     7| # ═══════════════════════════════════════════════════════════════════
     8| # EDITABLE — implement AdversarialTrainer below
     9| # ═══════════════════════════════════════════════════════════════════
    10  class AdversarialTrainer:
    11      """
    12      Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).
    13      Combines a confident-calibrated distillation target with an efficient AWP/SAM
    14      perturbation wrapper for state-of-the-art robust learning.
    15      """
    16  
    17      def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):
    18          self.model = model
    19          self.eps = eps
    20          self.alpha = alpha
    21          self.attack_steps = attack_steps
    22          self.num_classes = num_classes
    23        
Assistant Step 5 (edit)
Tool Input
{
  "start_line": 10,
  "op": "replace",
  "end_line": 110,
  "filename": "torchattacks/bench/custom_adv_train.py",
  "content": "class AdversarialTrainer:\n    \"\"\"\n    Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).\n    Combines an interpolated, confidence-calibrated distillation target with an \n    efficient global SAM perturbation step to achieve robust optimization.\n    \"\"\"\n\n    def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):\n        self.model = model\n        self.eps = eps\n        self.alpha = alpha\n        self.attack_steps = attack_steps\n        self.num_classes = num_classes\n        self.beta = 6.0\n        self.rho = 0.05        # Perturbation radius for SAM\n        self.ls_eps = 0.1      # Label smoothing factor\n\n    def compute_loss(self, images, adv_images, labels):\n        logits_clean = self.model(images)\n        logits_adv = self.model(adv_images)\n        \n        loss_clean = F.cross_entropy(logits_clean, labels)\n        \n        # Calculate adaptive target based on clean prediction confidence\n        probs_clean = F.softmax(logits_clean, dim=1)\n        conf_y = probs_clean.gather(1, labels.unsqueeze(1)).detach()\n        \n        # Smooth one-hot labels for the true class\n        smooth_y = torch.full_like(probs_clean, self.ls_eps / max(1, self.num_classes - 1))\n        smooth_y.scatter_(1, labels.unsqueeze(1), 1.0 - self.ls_eps)\n        \n        # Interpolate TRADES target (clean probs) and PGD target (true label)\n        # Confident examples match clean distribution; unconfident match smoothed true label.\n        alpha_mix = torch.clamp(1.0 - conf_y, min=0.0, max=1.0)\n        target_probs = (1.0 - alpha_mix) * probs_clean.detach() + alpha_mix * smooth_y\n        \n        loss_kl = F.kl_div(\n            F.log_softmax(logits_adv, dim=1), \n            target_probs, \n            reduction='batchmean'\n        )\n        \n        return loss_clean + self.beta * loss_kl\n\n    def train_step(self, images, labels, optimizer):\n        self.model.eval()\n        \n        # 1. Generate adversarial examples (TRADES-style attack)\n        adv_images = images.clone().detach()\n        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n        adv_images = torch.clamp(adv_images, 0.0, 1.0)\n\n        with torch.no_grad():\n            logits_clean_ref = self.model(images).detach()\n\n        for _ in range(self.attack_steps):\n            adv_images.requires_grad_(True)\n            with torch.enable_grad():\n                logits_adv = self.model(adv_images)\n                loss_attack = F.kl_div(\n                    F.log_softmax(logits_adv, dim=1),\n                    F.softmax(logits_clean_ref, dim=1),\n                    reduction='batchmean'\n                )\n            grad = torch.autograd.grad(loss_attack, adv_images)[0]\n            adv_images = adv_images.detach() + self.alpha * grad.sign()\n            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)\n            adv_images = torch.clamp(images + delta, 0.0, 1.0).detach()\n\n        self.model.train()\n        \n        # 2. SAM Forward 1 (on clean weights)\n        loss = self.compute_loss(images, adv_images, labels)\n        optimizer.zero_grad()\n        loss.backward()\n\n        # 3. Apply standard global SAM weight perturbation\n        e_w = {}\n        grads = []\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                grads.append(param.grad.detach().view(-1))\n        \n        if grads:\n            grad_norm = torch.norm(torch.cat(grads))\n            if grad_norm > 1e-12:\n                scale = self.rho / grad_norm\n                for name, param in self.model.named_parameters():\n                    if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                        e = param.grad.detach() * scale\n                        param.data.add_(e)\n                        e_w[name] = e\n\n        # 4. SAM Forward 2 (on perturbed weights)\n        optimizer.zero_grad()\n        loss_robust = self.compute_loss(images, adv_images, labels)\n        loss_robust.backward()\n\n        # 5. Restore clean weights\n        for name, param in self.model.named_parameters():\n            if name in e_w:\n                param.data.sub_(e_w[name])\n\n        # 6. Apply gradients step\n        optimizer.step()\n\n        return {'loss': loss.item(), 'loss_robust': loss_robust.item()}"
}
Tool Result (Step 5)
OK: Replaced lines 10..110 with 107 line(s) in torchattacks/bench/custom_adv_train.py. Editable range: 10–116.

[Current file: torchattacks/bench/custom_adv_train.py | editable: 10–116 | total: 119 lines]
     7| # ═══════════════════════════════════════════════════════════════════
     8| # EDITABLE — implement AdversarialTrainer below
     9| # ═══════════════════════════════════════════════════════════════════
    10  class AdversarialTrainer:
    11      """
    12      Adaptive Target TRADES with Sharpness-Aware Minimization (AT-SAM).
    13      Combines an interpolated, confidence-calibrated distillation target with an 
    14      efficient global SAM perturbation step to achieve robust optimization.
    15      """
    16  
    17      def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):
    18          self.model = model
    19          self.eps = eps
    20          self.alpha = alpha
    21          self.attack_steps = attack_steps
    22          sel
Assistant Step 6 (test)
Tool Result (Step 6)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### SmallCNN-MNIST (scripts/SmallCNN_MNIST.sh)
Training progress (SmallCNN-MNIST):
TRAIN_METRICS epoch=1 loss=3.2477 lr=0.009961
TRAIN_METRICS epoch=10 loss=0.4031 lr=0.006545
TRAIN_METRICS epoch=20 loss=0.2719 lr=0.000955
Results (SmallCNN-MNIST): clean_acc=0.9893, robust_acc_fgsm=0.9691, robust_acc_pgd=0.9423

### PreActResNet18-C10 (scripts/PreActResNet18_C10.sh)
Training progress (PreActResNet18-C10):
TRAIN_METRICS epoch=40 loss=3.3656 lr=0.050000
TRAIN_METRICS epoch=50 loss=3.0526 lr=0.030866
TRAIN_METRICS epoch=60 loss=2.7006 lr=0.014645
TRAIN_METRICS epoch=70 loss=2.3003 lr=0.003806
TRAIN_METRICS epoch=80 loss=2.0765 lr=0.000000
Results (PreActResNet18-C10): clean_acc=0.8619, robust_acc_fgsm=0.5299, robust_acc_pgd=0.4469

[Leaderboard] Results saved: {'clean_acc_SmallCNN_MNIST': 0.9893, 'robust_acc_fgsm_SmallCNN_MNIST': 0.9691, 'robust_acc_pgd_SmallCNN_MNIST': 0.9423, 'elapsed_SmallCNN-MNIST': 591, 'clean_acc_
Assistant Step 7 (edit)
Tool Input
{
  "end_line": 116,
  "op": "replace",
  "filename": "torchattacks/bench/custom_adv_train.py",
  "start_line": 10,
  "content": "class AdversarialTrainer:\n    \"\"\"\n    Confidence-Adaptive TRADES with Layer-wise AWP (CA-AWP).\n    Introduces an instance-adaptive true-label objective that shifts between PGD-AT \n    for unconfident/hard examples and TRADES consistency for confident examples.\n    Coupled with efficient in-place Adversarial Weight Perturbation (AWP).\n    \"\"\"\n\n    def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):\n        self.model = model\n        self.eps = eps\n        self.alpha = alpha\n        self.attack_steps = attack_steps\n        self.num_classes = num_classes\n        self.beta = 6.0\n        self.awp_gamma = 0.01\n\n    def train_step(self, images, labels, optimizer):\n        self.model.eval()\n        \n        # 1. Generate adversarial examples (TRADES-style attack: maximize KL)\n        adv_images = images.clone().detach()\n        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n        adv_images = torch.clamp(adv_images, 0.0, 1.0)\n\n        with torch.no_grad():\n            logits_clean_ref = self.model(images).detach()\n\n        for _ in range(self.attack_steps):\n            adv_images.requires_grad_(True)\n            with torch.enable_grad():\n                logits_adv = self.model(adv_images)\n                loss_attack = F.kl_div(\n                    F.log_softmax(logits_adv, dim=1),\n                    F.softmax(logits_clean_ref, dim=1),\n                    reduction='batchmean'\n                )\n            grad = torch.autograd.grad(loss_attack, adv_images)[0]\n            adv_images = adv_images.detach() + self.alpha * grad.sign()\n            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)\n            adv_images = torch.clamp(images + delta, 0.0, 1.0).detach()\n\n        self.model.train()\n        \n        # 2. AWP gradient computation (Maximize KL divergence only)\n        logits_clean = self.model(images)\n        logits_adv = self.model(adv_images)\n        \n        loss_kl_awp = F.kl_div(\n            F.log_softmax(logits_adv, dim=1),\n            F.softmax(logits_clean, dim=1).detach(),\n            reduction='batchmean'\n        )\n        \n        optimizer.zero_grad()\n        loss_kl_awp.backward()\n        \n        # 3. Layer-wise AWP weight perturbation\n        e_w = {}\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and param.grad is not None and len(param.shape) > 1:\n                g_norm = torch.norm(param.grad.detach())\n                p_norm = torch.norm(param.detach())\n                if g_norm > 1e-10:\n                    e = self.awp_gamma * p_norm / g_norm * param.grad.detach()\n                    param.data.add_(e)\n                    e_w[name] = e\n\n        # 4. Outer optimization on perturbed weights\n        optimizer.zero_grad()\n        logits_clean_rob = self.model(images)\n        logits_adv_rob = self.model(adv_images)\n        \n        loss_clean = F.cross_entropy(logits_clean_rob, labels)\n        \n        # TRADES KL Consistency\n        loss_kl = F.kl_div(\n            F.log_softmax(logits_adv_rob, dim=1),\n            F.softmax(logits_clean_rob, dim=1).detach(),\n            reduction='batchmean'\n        )\n        \n        # Novel Contribution: Confidence-Adaptive Penalty\n        # Hard examples lack correct guidance from pure KL consistency.\n        # We supplement with instance-weighted adversarial cross-entropy.\n        probs_clean_rob = F.softmax(logits_clean_rob, dim=1).detach()\n        conf = probs_clean_rob.gather(1, labels.unsqueeze(1)).squeeze(1)\n        \n        adv_ce_weights = 1.0 - conf\n        loss_adv_inst = F.cross_entropy(logits_adv_rob, labels, reduction='none')\n        loss_adv_adaptive = (adv_ce_weights * loss_adv_inst).mean()\n        \n        loss_robust = loss_clean + self.beta * loss_kl + loss_adv_adaptive\n        \n        loss_robust.backward()\n\n        # 5. Restore clean weights\n        for name, param in self.model.named_parameters():\n            if name in e_w:\n                param.data.sub_(e_w[name])\n\n        # 6. Apply gradients step\n        optimizer.step()\n\n        return {'loss': loss_robust.item()}\n"
}
Tool Result (Step 7)
OK: Replaced lines 10..116 with 105 line(s) in torchattacks/bench/custom_adv_train.py. Editable range: 10–114.

[Current file: torchattacks/bench/custom_adv_train.py | editable: 10–114 | total: 117 lines]
     7| # ═══════════════════════════════════════════════════════════════════
     8| # EDITABLE — implement AdversarialTrainer below
     9| # ═══════════════════════════════════════════════════════════════════
    10  class AdversarialTrainer:
    11      """
    12      Confidence-Adaptive TRADES with Layer-wise AWP (CA-AWP).
    13      Introduces an instance-adaptive true-label objective that shifts between PGD-AT 
    14      for unconfident/hard examples and TRADES consistency for confident examples.
    15      Coupled with efficient in-place Adversarial Weight Perturbation (AWP).
    16      """
    17  
    18      def __init__(self, model, eps, alpha, attack_steps, num_classes, **kwargs):
    19          self.model = model
    20          self.eps = eps
    21          self
Assistant Step 8 (test)
Tool Result (Step 8)
[Test #3] (0 tests remaining)


## Seed 42

### SmallCNN-MNIST (scripts/SmallCNN_MNIST.sh)
Training progress (SmallCNN-MNIST):
TRAIN_METRICS epoch=1 loss=2.8734 lr=0.009961
TRAIN_METRICS epoch=10 loss=0.5954 lr=0.006545
TRAIN_METRICS epoch=20 loss=0.4833 lr=0.000955
Results (SmallCNN-MNIST): clean_acc=0.9860, robust_acc_fgsm=0.9589, robust_acc_pgd=0.9333

### PreActResNet18-C10 (scripts/PreActResNet18_C10.sh)
Training progress (PreActResNet18-C10):
TRAIN_METRICS epoch=40 loss=2.3210 lr=0.050000
TRAIN_METRICS epoch=50 loss=2.2589 lr=0.030866
TRAIN_METRICS epoch=60 loss=2.1925 lr=0.014645
TRAIN_METRICS epoch=70 loss=2.1275 lr=0.003806
TRAIN_METRICS epoch=80 loss=2.1305 lr=0.000000
Results (PreActResNet18-C10): clean_acc=0.8161, robust_acc_fgsm=0.5169, robust_acc_pgd=0.4687

### PreActResNet18-C100 (scripts/PreActResNet18_C100.sh)
Training progress (PreActResNet18-C100):
TRAIN_METRICS epoch=40 loss=4.7123 lr=0.050000
TRAIN_METRICS epoch=50 loss=4.3058 lr=0.030866
TRAIN_METRICS epoch=60 lo