Agent Conversation: deepseek-reasoner

Model: deepseek-reasoner
Total steps: 8
Messages: 35
User: Step 0

Task: security-poison-robust-learning

Poison-Robust Learning under Fixed Label Poisoning

Research Question

How can we design a stronger loss function or sample-weighting rule that improves robustness to poisoned training labels without changing the model, optimizer, or data pipeline?

Background

A small fraction of poisoned labels can disproportionately distort model decision boundaries. Robust learning methods typically modify the objective to downweight suspicious samples or reduce memorization of corrupted targets.
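As a concrete illustration of the downweighting idea (a minimal sketch, not part of the benchmark; the function name and `keep_ratio` parameter are hypothetical), small-loss selection keeps only the fraction of each minibatch with the lowest per-sample cross-entropy, since poisoned labels tend to incur larger losses early in training:

```python
import torch
import torch.nn.functional as F

def small_loss_weighted_ce(logits, labels, keep_ratio=0.75):
    """Mean CE over the keep_ratio fraction of samples with the smallest loss.

    Hypothetical sketch: poisoned labels usually sit in the high-loss tail,
    so trimming that tail reduces how much they shape the decision boundary.
    """
    per_sample = F.cross_entropy(logits, labels, reduction="none")
    k = max(1, int(keep_ratio * per_sample.numel()))
    kept, _ = torch.topk(per_sample, k, largest=False)  # lowest-loss samples
    return kept.mean()
```

Dropping the highest-loss samples trades a little clean signal for excluding most poisoned ones; shrinking `keep_ratio` over epochs is a common refinement.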

Task

Implement a better poison-robust objective in bench/poison/custom_robust_loss.py. The fixed harness injects label corruption into the training set, trains with your loss, and evaluates on a clean test set.

Your method should improve clean test accuracy under poisoning while keeping the training procedure modular and transferable across datasets.

Editable Interface

You must implement:

class RobustLoss:
    def compute_loss(self, logits, labels, epoch):
        ...
  • logits: current minibatch model outputs
  • labels: possibly poisoned labels
  • epoch: current training epoch
  • Return value: scalar loss tensor

The corruption process, model, optimizer, and training schedule are fixed.
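A quick self-check that an implementation satisfies this contract (a sketch of how the fixed harness presumably calls it; the tensor shapes here are illustrative, not taken from the harness):

```python
import torch
import torch.nn.functional as F

class RobustLoss:
    """Default cross-entropy objective, matching the required interface."""

    def compute_loss(self, logits, labels, epoch):
        return F.cross_entropy(logits, labels)

# The harness supplies minibatch logits, (possibly poisoned) labels, and the epoch.
loss_fn = RobustLoss()
logits = torch.randn(8, 10, requires_grad=True)   # batch of 8, 10 classes
labels = torch.randint(0, 10, (8,))
loss = loss_fn.compute_loss(logits, labels, epoch=0)
assert loss.dim() == 0   # must return a scalar loss tensor
loss.backward()          # and it must be differentiable w.r.t. logits
```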

Evaluation

Benchmarks:

  • fmnist-labelflip
  • cifar10-labelflip
  • cifar100-labelflip

Reported metrics:

  • test_acc
  • poison_fit
  • robust_score

Primary metric: robust_score (higher is better).

Baselines

  • cross_entropy: standard ERM on poisoned labels
  • generalized_ce: generalized cross-entropy for noisy labels
  • symmetric_ce: CE plus reverse-CE penalty
  • bootstrap: target interpolation with model predictions
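For reference, the generalized_ce and symmetric_ce baselines can be sketched as follows (a hedged sketch with assumed hyperparameters: q = 0.7, matching the unused `self.q` in the default file, and alpha = beta = 1.0; the exact baseline settings are not given here):

```python
import torch
import torch.nn.functional as F

def generalized_ce(logits, labels, q=0.7):
    # GCE: (1 - p_y^q) / q interpolates between CE (q -> 0) and MAE (q = 1),
    # bounding the loss a single mislabeled sample can contribute.
    p_y = F.softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).squeeze(1)
    return ((1.0 - p_y.clamp_min(1e-8) ** q) / q).mean()

def symmetric_ce(logits, labels, alpha=1.0, beta=1.0):
    # SCE: standard CE plus a reverse-CE term, with the log of the one-hot
    # target clamped so log(0) stays finite.
    ce = F.cross_entropy(logits, labels)
    probs = F.softmax(logits, dim=1).clamp(min=1e-7)
    one_hot = F.one_hot(labels, num_classes=logits.shape[1]).float().clamp(min=1e-4)
    rce = -(probs * one_hot.log()).sum(dim=1).mean()
    return alpha * ce + beta * rce
```

GCE bounds any single sample's loss by 1/q, so a confidently mislabeled example cannot dominate the gradient the way it does under plain cross-entropy.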

pytorch-vision/bench/poison/custom_robust_loss.py [EDITABLE — entire file only]

     1: """Editable poison-robust loss for MLS-Bench."""
     2: 
     3: import torch
     4: import torch.nn.functional as F
     5: 
     6: # ============================================================
     7: # EDITABLE
     8: # ============================================================
     9: class RobustLoss:
    10:     """Default cross-entropy objective."""
    11: 
    12:     def __init__(self):
    13:         self.q = 0.7
    14: 
    15:     def compute_loss(self, logits, labels, epoch):
    16:         return F.cross_entropy(logits, labels)
    17: # ============================================================
    18: # END EDITABLE
    19: # ============================================================

cross_entropy baseline — editable region [READ-ONLY — reference implementation]

     1: """Editable poison-robust loss for MLS-Bench."""
     2: 
     3: import torch
     4: import torch.nn.functional as F
     5: 
     6: # ============================================================
     7: # EDITABLE
     8: class RobustLoss:
     9:     """Standard cross-entropy on poisoned labels."""
    10: 
    11:     def __init__(self):
    12:       
Assistant: Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4  "start_line": 1,
5  "end_line": 19,
6  "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n    \"\"\"Adaptive label smoothing with intra-class consistency regularization.\"\"\"\n\n    def __init__(self):\n        # Hyperparameters\n        self.ema_decay = 0.9  # decay for per-class prototype moving average\n        self.lambda_consistency = 0.5  # weight for consistency loss\n        self.smooth_eps = 0.1  # base label smoothing epsilon\n        self.adaptive_factor = 0.5  # scaling factor for adaptive smoothing\n        \n        # State variables\n        self.num_classes = None\n        self.class_prototypes = None  # shape (num_classes, num_classes)\n        self.initialized = False\n        \n    def _initialize(self, num_classes):\n        self.num_classes = num_classes\n        # Initialize prototypes as uniform distributions\n        self.class_prototypes = torch.ones(num_classes, num_classes) / num_classes\n        self.initialized = True\n\n    def compute_loss(self, logits, labels, epoch):\n        batch_size, num_classes = logits.shape\n        if not self.initialized:\n            self._initialize(num_classes)\n        \n        device = logits.device\n        self.class_prototypes = self.class_prototypes.to(device)\n        \n        # Convert labels to one-hot\n        one_hot = F.one_hot(labels, num_classes=num_classes).float()\n        \n        # Compute softmax probabilities\n        probs = torch.softmax(logits, dim=1)\n        \n        # Update per-class prototypes using exponential moving average\n        with torch.no_grad():\n            for c in range(num_classes):\n                mask = (labels == c)\n                if mask.any():\n                    class_probs = probs[mask]\n                    mean_prob = class_probs.mean(dim=0)\n                    # Update EMA\n                    self.class_prototypes[c] = (\n                        self.ema_decay * self.class_prototypes[c] +\n                        (1 - self.ema_decay) * mean_prob\n                    )\n        \n        # Intra-class consistency loss: KL divergence between sample probs and its class prototype\n        # Gather prototype for each sample\n        prototypes = self.class_prototypes[labels]  # shape (batch_size, num_classes)\n        # Add small epsilon to avoid log(0)\n        prototypes = prototypes.clamp_min(1e-8)\n        probs_clamped = probs.clamp_min(1e-8)\n        consistency_loss = (probs_clamped * (torch.log(probs_clamped) - torch.log(prototypes))).sum(dim=1).mean()\n        \n        # Adaptive label smoothing\n        # Confidence: probability assigned to the given label\n        confidence = probs.gather(1, labels.unsqueeze(1)).squeeze(1)  # shape (batch_size,)\n        # Smoothing factor: lower confidence -> more smoothing\n        alpha = (1 - confidence) * self.adaptive_factor\n        alpha = alpha.clamp(max=self.smooth_eps)  # cap at base smoothing epsilon\n        \n        # Smoothed target: mix of one-hot and uniform distribution\n        uniform = torch.ones_like(one_hot) / num_classes\n        # Expand alpha to shape (batch_size, 1) for broadcasting\n        alpha_expanded = alpha.unsqueeze(1)\n        target = (1 - alpha_expanded) * one_hot + alpha_expanded * uniform\n        \n        # Cross-entropy with smoothed targets\n        log_probs = F.log_softmax(logits, dim=1)\n        ce_loss = -(target * log_probs).sum(dim=1).mean()\n        \n        # Total loss\n        total_loss = ce_loss + self.lambda_consistency * consistency_loss\n        \n        return total_loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
Tool Result: Step 1
OK: Replaced lines 1..19 with 88 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.

[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–88 | total: 88 lines]
     1  """Editable poison-robust loss for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class RobustLoss:
    10      """Adaptive label smoothing with intra-class consistency regularization."""
    11  
    12      def __init__(self):
    13          # Hyperparameters
    14          self.ema_decay = 0.9  # decay for per-class prototype moving average
    15          self.lambda_consistency = 0.5  # weight for consistency loss
    16          self.smooth_eps = 0.1  # base label smoothing epsilon
    17          self.adaptive_factor = 0.5  # scaling fa
Assistant: Step 2 (test)
Tool Result: Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)


## Seed 42

### cifar10-labelflip (scripts/cifar10_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.2868
TRAIN_METRICS epoch=2 loss=2.2081
TRAIN_METRICS epoch=3 loss=2.1729
TRAIN_METRICS epoch=4 loss=2.1447
TRAIN_METRICS epoch=5 loss=2.1094
cifar10-labelflip: test_acc=0.2847, poison_fit=0.1108, robust_score=0.5869

### fmnist-labelflip (scripts/fmnist_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.2900
TRAIN_METRICS epoch=2 loss=1.9927
TRAIN_METRICS epoch=3 loss=1.8229
TRAIN_METRICS epoch=4 loss=1.7015
TRAIN_METRICS epoch=5 loss=1.6452
fmnist-labelflip: test_acc=0.7193, poison_fit=0.0210, robust_score=0.8492

### cifar100-labelflip (scripts/cifar100_labelflip.sh)
TRAIN_METRICS epoch=1 loss=4.6024
TRAIN_METRICS epoch=2 loss=4.5498
TRAIN_METRICS epoch=3 loss=4.5289
TRAIN_METRICS epoch=4 loss=4.5069
TRAIN_METRICS epoch=5 loss=4.4489
cifar100-labelflip: test_acc=0.0387, poison_fit=0.0276, robust_score=0.5055


## Seed 123

### c
Assistant: Step 3 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4  "start_line": 1,
5  "end_line": 88,
6  "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n    \"\"\"Adaptive Bootstrap with Confidence Weighting and Entropy Regularization.\"\"\"\n\n    def __init__(self):\n        # Hyperparameters\n        self.entropy_weight = 0.1  # weight for entropy regularization\n        self.confidence_threshold = 0.3  # threshold for high-confidence samples\n        self.temperature = 1.0  # temperature for soft targets\n        \n    def compute_loss(self, logits, labels, epoch):\n        batch_size, num_classes = logits.shape\n        \n        # Compute softmax probabilities with temperature\n        probs = torch.softmax(logits / self.temperature, dim=1)\n        # Compute confidence (probability of given label)\n        confidence = probs.gather(1, labels.unsqueeze(1)).squeeze(1)  # shape (batch_size,)\n        \n        # Adaptive bootstrap weight: beta = confidence, clipped to [0,1]\n        beta = confidence.detach()  # shape (batch_size,)\n        # Ensure beta is at least a small value to avoid zero weight on hard label\n        beta = beta.clamp(min=0.1, max=0.9)\n        \n        # Hard targets (one-hot)\n        hard = F.one_hot(labels, num_classes=num_classes).float()\n        # Soft targets (model predictions detached)\n        soft = probs.detach()\n        \n        # Interpolated targets per sample\n        beta_expanded = beta.unsqueeze(1)  # shape (batch_size, 1)\n        target = beta_expanded * hard + (1 - beta_expanded) * soft\n        \n        # Cross-entropy loss with interpolated targets\n        log_probs = F.log_softmax(logits, dim=1)\n        ce_loss = -(target * log_probs).sum(dim=1).mean()\n        \n        # Entropy regularization: encourage high confidence for samples with high given-label confidence\n        # Only apply to samples where confidence > threshold\n        mask = (confidence > self.confidence_threshold).float()\n        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1)\n        entropy_loss = (mask * entropy).mean()\n        \n        # Total loss\n        total_loss = ce_loss + self.entropy_weight * entropy_loss\n        \n        return total_loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
Tool Result: Step 3
OK: Replaced lines 1..88 with 56 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.

[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–56 | total: 56 lines]
     1  """Editable poison-robust loss for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class RobustLoss:
    10      """Adaptive Bootstrap with Confidence Weighting and Entropy Regularization."""
    11  
    12      def __init__(self):
    13          # Hyperparameters
    14          self.entropy_weight = 0.1  # weight for entropy regularization
    15          self.confidence_threshold = 0.3  # threshold for high-confidence samples
    16          self.temperature = 1.0  # temperature for soft targets
    17          
    18      def compute_loss(
Assistant: Step 4 (test)
Tool Result: Step 4
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### cifar10-labelflip (scripts/cifar10_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.1535
TRAIN_METRICS epoch=2 loss=1.8810
TRAIN_METRICS epoch=3 loss=1.6851
TRAIN_METRICS epoch=4 loss=1.6955
TRAIN_METRICS epoch=5 loss=1.6880
cifar10-labelflip: test_acc=0.2100, poison_fit=0.1200, robust_score=0.5450

### fmnist-labelflip (scripts/fmnist_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.1337
TRAIN_METRICS epoch=2 loss=1.2999
TRAIN_METRICS epoch=3 loss=0.9917
TRAIN_METRICS epoch=4 loss=0.8698
TRAIN_METRICS epoch=5 loss=0.7145
fmnist-labelflip: test_acc=0.7020, poison_fit=0.0140, robust_score=0.8440

### cifar100-labelflip (scripts/cifar100_labelflip.sh)
TRAIN_METRICS epoch=1 loss=4.5935
TRAIN_METRICS epoch=2 loss=4.5320
TRAIN_METRICS epoch=3 loss=4.4990
TRAIN_METRICS epoch=4 loss=4.4487
TRAIN_METRICS epoch=5 loss=4.4011
cifar100-labelflip: test_acc=0.0233, poison_fit=0.0219, robust_score=0.5007

[Leaderboard] Results saved: {'te
Assistant: Step 5 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4  "start_line": 1,
5  "end_line": 56,
6  "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n    \"\"\"Dynamic Label Smoothing with Class Prototype Targets.\"\"\"\n\n    def __init__(self):\n        # Hyperparameters\n        self.ema_decay = 0.9  # for prototype updates\n        self.initial_alpha = 0.8  # initial weight on prototype\n        self.final_alpha = 0.0   # final weight on prototype\n        self.total_epochs = 5    # assumed total epochs (training runs 5 epochs)\n        self.entropy_weight = 0.05  # entropy regularization weight\n        \n        # State\n        self.num_classes = None\n        self.class_prototypes = None  # shape (num_classes, num_classes)\n        self.initialized = False\n        \n    def _initialize(self, num_classes):\n        self.num_classes = num_classes\n        # Initialize prototypes as uniform distributions\n        self.class_prototypes = torch.ones(num_classes, num_classes) / num_classes\n        self.initialized = True\n\n    def compute_loss(self, logits, labels, epoch):\n        batch_size, num_classes = logits.shape\n        if not self.initialized:\n            self._initialize(num_classes)\n        \n        device = logits.device\n        self.class_prototypes = self.class_prototypes.to(device)\n        \n        # Compute softmax probabilities\n        probs = torch.softmax(logits, dim=1)\n        \n        # Update per-class prototypes using exponential moving average\n        with torch.no_grad():\n            for c in range(num_classes):\n                mask = (labels == c)\n                if mask.any():\n                    class_probs = probs[mask]\n                    mean_prob = class_probs.mean(dim=0)\n                    # Update EMA\n                    self.class_prototypes[c] = (\n                        self.ema_decay * self.class_prototypes[c] +\n                        (1 - self.ema_decay) * mean_prob\n                    )\n        \n        # Schedule alpha: linear decay from initial_alpha to final_alpha over total_epochs\n        alpha = self.initial_alpha - (self.initial_alpha - self.final_alpha) * (epoch / (self.total_epochs - 1))\n        alpha = max(alpha, 0.0)\n        \n        # Prototype targets for each sample\n        prototypes = self.class_prototypes[labels]  # shape (batch_size, num_classes)\n        # Hard targets\n        hard = F.one_hot(labels, num_classes=num_classes).float()\n        \n        # Combined target\n        target = (1 - alpha) * hard + alpha * prototypes\n        \n        # Cross-entropy loss\n        log_probs = F.log_softmax(logits, dim=1)\n        ce_loss = -(target * log_probs).sum(dim=1).mean()\n        \n        # Entropy regularization: encourage confident predictions\n        entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n        \n        # Total loss\n        total_loss = ce_loss + self.entropy_weight * entropy\n        \n        return total_loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
Tool Result: Step 5
OK: Replaced lines 1..56 with 80 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.

[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–80 | total: 80 lines]
     1  """Editable poison-robust loss for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class RobustLoss:
    10      """Dynamic Label Smoothing with Class Prototype Targets."""
    11  
    12      def __init__(self):
    13          # Hyperparameters
    14          self.ema_decay = 0.9  # for prototype updates
    15          self.initial_alpha = 0.8  # initial weight on prototype
    16          self.final_alpha = 0.0   # final weight on prototype
    17          self.total_epochs = 5    # assumed total epochs (training runs 5 epochs)
    18      
Assistant: Step 6 (test)
Tool Result: Step 6
[Test #3] (0 tests remaining)


## Seed 42

### cifar10-labelflip (scripts/cifar10_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.4175
TRAIN_METRICS epoch=2 loss=2.3788
TRAIN_METRICS epoch=3 loss=2.3013
TRAIN_METRICS epoch=4 loss=2.1861
TRAIN_METRICS epoch=5 loss=2.0512
cifar10-labelflip: test_acc=0.3087, poison_fit=0.0967, robust_score=0.6060

### fmnist-labelflip (scripts/fmnist_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.4238
TRAIN_METRICS epoch=2 loss=2.3223
TRAIN_METRICS epoch=3 loss=2.0651
TRAIN_METRICS epoch=4 loss=1.7671
TRAIN_METRICS epoch=5 loss=1.4306
fmnist-labelflip: test_acc=0.6593, poison_fit=0.0270, robust_score=0.8162

### cifar100-labelflip (scripts/cifar100_labelflip.sh)
TRAIN_METRICS epoch=1 loss=4.8362
TRAIN_METRICS epoch=2 loss=4.8159
TRAIN_METRICS epoch=3 loss=4.7598
TRAIN_METRICS epoch=4 loss=4.6704
TRAIN_METRICS epoch=5 loss=4.5525
cifar100-labelflip: test_acc=0.0293, poison_fit=0.0219, robust_score=0.5037


## Seed 123

### cifar10-labelflip (scripts/cifar10_labelf
Assistant: Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4  "start_line": 1,
5  "end_line": 19,
6  "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n    \"\"\"Adaptive Robust Loss (ARL): softmin-weighted cross-entropy.\n    \n    Weights samples by exp(-loss_i / temperature) where loss_i is per-sample CE.\n    Temperature decays with epoch to gradually focus on low-loss (clean) samples.\n    \"\"\"\n\n    def __init__(self):\n        # initial temperature (high = uniform weights)\n        self.t0 = 2.0\n        # final temperature after 100 epochs\n        self.t_min = 0.1\n        # decay factor per epoch\n        self.decay = (self.t0 / self.t_min) ** (1.0 / 100)\n\n    def compute_loss(self, logits, labels, epoch):\n        # per-sample cross-entropy\n        loss_per_sample = F.cross_entropy(logits, labels, reduction='none')\n        # temperature for current epoch\n        temperature = self.t0 * (self.decay ** (-epoch))\n        temperature = max(temperature, self.t_min)\n        # compute weights: softmin over losses\n        weights = F.softmax(-loss_per_sample.detach() / temperature, dim=0)\n        # weighted loss\n        loss = (weights * loss_per_sample).sum()\n        return loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
Tool Result: Step 1
OK: Replaced lines 1..19 with 37 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.

[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–37 | total: 37 lines]
     1  """Editable poison-robust loss for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class RobustLoss:
    10      """Adaptive Robust Loss (ARL): softmin-weighted cross-entropy.
    11      
    12      Weights samples by exp(-loss_i / temperature) where loss_i is per-sample CE.
    13      Temperature decays with epoch to gradually focus on low-loss (clean) samples.
    14      """
    15  
    16      def __init__(self):
    17          # initial temperature (high = uniform weights)
    18          self.t0 = 2.0
    19          # final temperatur
Assistant: Step 2 (test)
Assistant: Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4  "start_line": 1,
5  "end_line": 19,
6  "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n    \"\"\"Adaptive Confidence-based Label Smoothing (ACLS).\"\"\"\n\n    def __init__(self):\n        self.num_classes = None\n        self.eps = 1e-8\n        # smoothing schedule: start with low smoothing, increase later\n        self.smoothing_start = 0.1\n        self.smoothing_end = 0.5\n        # confidence threshold for weighting\n        self.conf_threshold = 0.3\n\n    def compute_loss(self, logits, labels, epoch):\n        if self.num_classes is None:\n            self.num_classes = logits.shape[1]\n        \n        # compute confidence\n        probs = torch.softmax(logits, dim=1)\n        confidence = probs.gather(1, labels.unsqueeze(1)).squeeze(1)  # shape (batch,)\n        \n        # adaptive smoothing factor per sample based on confidence\n        # low confidence -> higher smoothing\n        smoothing_factor = (1.0 - confidence).clamp(min=self.smoothing_start, max=self.smoothing_end)\n        \n        # schedule: increase smoothing factor as training progresses\n        epoch_factor = min(1.0, epoch / 100.0)  # assuming 100 epochs total\n        smoothing_factor = smoothing_factor * epoch_factor\n        \n        # create smoothed labels\n        one_hot = F.one_hot(labels, num_classes=self.num_classes).float()\n        smoothing_matrix = smoothing_factor.unsqueeze(1) / self.num_classes\n        smoothed_labels = one_hot * (1.0 - smoothing_factor.unsqueeze(1)) + smoothing_matrix\n        \n        # compute cross-entropy with smoothed labels\n        log_probs = F.log_softmax(logits, dim=1)\n        loss = -(smoothed_labels * log_probs).sum(dim=1).mean()\n        return loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
Tool Result: Step 1
OK: Replaced lines 1..19 with 48 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.

[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–48 | total: 48 lines]
     1  """Editable poison-robust loss for MLS-Bench."""
     2  
     3  import torch
     4  import torch.nn.functional as F
     5  
     6  # ============================================================
     7  # EDITABLE
     8  # ============================================================
     9  class RobustLoss:
    10      """Adaptive Confidence-based Label Smoothing (ACLS)."""
    11  
    12      def __init__(self):
    13          self.num_classes = None
    14          self.eps = 1e-8
    15          # smoothing schedule: start with low smoothing, increase later
    16          self.smoothing_start = 0.1
    17          self.smoothing_end = 0.5
    18          # confidence threshold for weighting
    19          self.conf_threshold = 0.3
    20  
   
Assistant: Step 2 (test)
Assistant: Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4  "start_line": 1,
5  "end_line": 19,
6  "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n    \"\"\"Adaptive Bootstrap with Confidence Thresholding (ABCT).\"\"\"\n\n    def __init__(self):\n        # Base mixing parameter for bootstrap\n        self.beta_base = 0.8\n        # Confidence threshold starts low, increases over epochs\n        self.conf_thresh_start = 0.7\n        self.conf_thresh_end = 0.95\n        # Number of total epochs (assuming 100 as in description)\n        self.total_epochs = 100\n\n    def compute_loss(self, logits, labels, epoch):\n        # Current confidence threshold linear schedule\n        conf_thresh = self.conf_thresh_start + (self.conf_thresh_end - self.conf_thresh_start) * (epoch / self.total_epochs)\n        \n        probs = torch.softmax(logits, dim=1)\n        conf, preds = torch.max(probs, dim=1)\n        \n        # Determine per-sample beta: low if high confidence and prediction != label (suspicious)\n        beta = torch.ones_like(conf) * self.beta_base\n        suspicious = (conf > conf_thresh) & (preds != labels)\n        beta[suspicious] = 0.2  # trust predictions more for suspicious samples\n        \n        # Bootstrap target mixing\n        hard = F.one_hot(labels, num_classes=logits.shape[1]).float()\n        soft = torch.softmax(logits.detach(), dim=1)\n        # Expand beta to match dimensions\n        beta_expanded = beta.view(-1, 1)\n        target = beta_expanded * hard + (1.0 - beta_expanded) * soft\n        \n        log_probs = F.log_softmax(logits, dim=1)\n        return -(target * log_probs).sum(dim=1).mean()\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}