Agent Conversation: deepseek-reasoner
Task: security-poison-robust-learning
Poison-Robust Learning under Fixed Label Poisoning
Research Question
How can we design a stronger loss function or sample-weighting rule that improves robustness to poisoned training labels without changing the model, optimizer, or data pipeline?
Background
A small fraction of poisoned labels can disproportionately distort model decision boundaries. Robust learning methods typically modify the objective to downweight suspicious samples or reduce memorization of corrupted targets.
Task
Implement a better poison-robust objective in bench/poison/custom_robust_loss.py. The fixed harness injects label corruption into the training set, trains with your loss, and evaluates on a clean test set.
Your method should improve clean test accuracy under poisoning while keeping the training procedure modular and transferable across datasets.
Editable Interface
You must implement:
class RobustLoss:
    def compute_loss(self, logits, labels, epoch):
        ...
- logits: current minibatch model outputs
- labels: possibly poisoned labels
- epoch: current training epoch
- Return value: scalar loss tensor
The corruption process, model, optimizer, and training schedule are fixed.
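A minimal sketch of a class that satisfies this interface. It assumes `logits` is a float tensor of shape (batch, num_classes), `labels` is a long tensor of shape (batch,), and `epoch` is an integer whose starting value (0 or 1) the task text does not specify. To illustrate downweighting suspicious samples, it averages cross-entropy over only the lowest-loss fraction of each minibatch after a warm-up, which is a small-loss selection rule. `drop_rate` and `warmup_epochs` are hypothetical parameters, not part of the harness.

```python
import torch
import torch.nn.functional as F


class RobustLoss:
    """Illustrative small-loss objective; settings are hypothetical, not tuned."""

    def __init__(self, drop_rate=0.2, warmup_epochs=1):
        self.drop_rate = drop_rate          # hypothetical: fraction of highest-loss samples ignored
        self.warmup_epochs = warmup_epochs  # hypothetical: epochs of plain CE before dropping

    def compute_loss(self, logits, labels, epoch):
        # Per-sample cross-entropy, shape (batch,)
        per_sample = F.cross_entropy(logits, labels, reduction="none")
        if epoch < self.warmup_epochs:
            return per_sample.mean()
        # Average only the lowest-loss (1 - drop_rate) fraction of the minibatch;
        # flipped labels often land in the high-loss tail early in training.
        num_keep = max(1, int(round((1.0 - self.drop_rate) * per_sample.numel())))
        kept, _ = torch.topk(per_sample, num_keep, largest=False)
        return kept.mean()


# Usage with random data standing in for one minibatch
torch.manual_seed(0)
loss_fn = RobustLoss()
logits = torch.randn(8, 10, requires_grad=True)
labels = torch.randint(0, 10, (8,))
loss = loss_fn.compute_loss(logits, labels, epoch=2)
loss.backward()
```

The warm-up keeps the early epochs identical to plain cross-entropy, so the loss ranking used for dropping comes from a model that has seen the data at least once.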
Evaluation
Benchmarks:
- fmnist-labelflip
- cifar10-labelflip
- cifar100-labelflip
Reported metrics:
- test_acc
- poison_fit
- robust_score
Primary metric: robust_score (higher is better).
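The task text does not define robust_score. Every triple logged in the test runs of this transcript agrees, within the rounding of the four-decimal logs, with the mean of clean test accuracy and 1 - poison_fit. The sketch below encodes that inferred relation and checks it against the logged values. It is a reading of the logs, not the harness's confirmed formula.

```python
def inferred_robust_score(test_acc: float, poison_fit: float) -> float:
    """Inferred from the logs (not confirmed): mean of clean accuracy and 1 - poison_fit."""
    return 0.5 * (test_acc + (1.0 - poison_fit))


# (test_acc, poison_fit, logged robust_score) copied from the visible test logs
LOGGED = [
    (0.2847, 0.1108, 0.5869),  # Test #1, cifar10-labelflip
    (0.7193, 0.0210, 0.8492),  # Test #1, fmnist-labelflip
    (0.0387, 0.0276, 0.5055),  # Test #1, cifar100-labelflip
    (0.2100, 0.1200, 0.5450),  # Test #2, cifar10-labelflip
    (0.7020, 0.0140, 0.8440),  # Test #2, fmnist-labelflip
    (0.0233, 0.0219, 0.5007),  # Test #2, cifar100-labelflip
    (0.3087, 0.0967, 0.6060),  # Test #3, cifar10-labelflip
    (0.6593, 0.0270, 0.8162),  # Test #3, fmnist-labelflip
    (0.0293, 0.0219, 0.5037),  # Test #3, cifar100-labelflip
]

# Largest gap between the inferred formula and the logged score
max_gap = max(abs(inferred_robust_score(acc, fit) - score) for acc, fit, score in LOGGED)
print(max_gap < 1e-3)  # True
```

If this reading is right, poison_fit is penalized exactly as much as clean accuracy is rewarded, so a loss that lowers poison_fit by a point is worth the same as one that raises test_acc by a point.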
Baselines
- cross_entropy: standard ERM on poisoned labels
- generalized_ce: generalized cross-entropy for noisy labels
- symmetric_ce: CE plus reverse-CE penalty
- bootstrap: target interpolation with model predictions
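The baselines are listed by name only. The sketch below writes three of them (generalized_ce, symmetric_ce, bootstrap) in their commonly published forms. Several details are assumptions, not values read from the harness: the default hyperparameters (q = 0.7, which matches `self.q` in the default file, alpha = 0.1, beta = 1.0, the 1e-4 log clamp, and bootstrap beta = 0.95) and the choice to detach the bootstrap predictions, which is one common variant.

```python
import torch
import torch.nn.functional as F


def generalized_ce(logits, labels, q=0.7):
    """GCE: mean of (1 - p_y**q) / q; tends to CE as q -> 0 and is proportional to MAE at q = 1."""
    p_y = torch.softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).squeeze(1)
    return ((1.0 - p_y.clamp_min(1e-8) ** q) / q).mean()


def symmetric_ce(logits, labels, alpha=0.1, beta=1.0, clamp=1e-4):
    """alpha * CE + beta * reverse CE, with log(0) in the one-hot target replaced by log(clamp)."""
    ce = F.cross_entropy(logits, labels)
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(labels, num_classes=logits.shape[1]).to(probs.dtype)
    rce = -(probs * torch.log(one_hot.clamp_min(clamp))).sum(dim=1).mean()
    return alpha * ce + beta * rce


def soft_bootstrap(logits, labels, beta=0.95):
    """CE against beta * one_hot + (1 - beta) * detached model predictions."""
    log_probs = F.log_softmax(logits, dim=1)
    one_hot = F.one_hot(labels, num_classes=logits.shape[1]).to(log_probs.dtype)
    target = beta * one_hot + (1.0 - beta) * log_probs.exp().detach()
    return -(target * log_probs).sum(dim=1).mean()


# Usage on a toy batch
torch.manual_seed(0)
logits = torch.randn(4, 5, dtype=torch.float64)
labels = torch.tensor([0, 1, 2, 3])
losses = {fn.__name__: fn(logits, labels).item() for fn in (generalized_ce, symmetric_ce, soft_bootstrap)}
```

Each reduces to plain cross-entropy at an extreme setting (q -> 0, alpha = 1 with beta = 0, bootstrap beta = 1), which makes them easy to sanity-check against `F.cross_entropy`.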
pytorch-vision/bench/poison/custom_robust_loss.py [EDITABLE — entire file only]
1: """Editable poison-robust loss for MLS-Bench."""
2:
3: import torch
4: import torch.nn.functional as F
5:
6: # ============================================================
7: # EDITABLE
8: # ============================================================
9: class RobustLoss:
10:     """Default cross-entropy objective."""
11:
12:     def __init__(self):
13:         self.q = 0.7
14:
15:     def compute_loss(self, logits, labels, epoch):
16:         return F.cross_entropy(logits, labels)
17: # ============================================================
18: # END EDITABLE
19: # ============================================================
cross_entropy baseline — editable region [READ-ONLY — reference implementation]
1: """Editable poison-robust loss for MLS-Bench."""
2:
3: import torch
4: import torch.nn.functional as F
5:
6: # ============================================================
7: # EDITABLE
8: class RobustLoss:
9:     """Standard cross-entropy on poisoned labels."""
10:
11:     def __init__(self):
12:
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4 "start_line": 1,
5 "end_line": 19,
6 "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n \"\"\"Adaptive label smoothing with intra-class consistency regularization.\"\"\"\n\n def __init__(self):\n # Hyperparameters\n self.ema_decay = 0.9 # decay for per-class prototype moving average\n self.lambda_consistency = 0.5 # weight for consistency loss\n self.smooth_eps = 0.1 # base label smoothing epsilon\n self.adaptive_factor = 0.5 # scaling factor for adaptive smoothing\n \n # State variables\n self.num_classes = None\n self.class_prototypes = None # shape (num_classes, num_classes)\n self.initialized = False\n \n def _initialize(self, num_classes):\n self.num_classes = num_classes\n # Initialize prototypes as uniform distributions\n self.class_prototypes = torch.ones(num_classes, num_classes) / num_classes\n self.initialized = True\n\n def compute_loss(self, logits, labels, epoch):\n batch_size, num_classes = logits.shape\n if not self.initialized:\n self._initialize(num_classes)\n \n device = logits.device\n self.class_prototypes = self.class_prototypes.to(device)\n \n # Convert labels to one-hot\n one_hot = F.one_hot(labels, num_classes=num_classes).float()\n \n # Compute softmax probabilities\n probs = torch.softmax(logits, dim=1)\n \n # Update per-class prototypes using exponential moving average\n with torch.no_grad():\n for c in range(num_classes):\n mask = (labels == c)\n if mask.any():\n class_probs = probs[mask]\n mean_prob = class_probs.mean(dim=0)\n # Update EMA\n self.class_prototypes[c] = (\n self.ema_decay * self.class_prototypes[c] +\n (1 - self.ema_decay) * mean_prob\n )\n \n # Intra-class consistency loss: KL divergence between sample probs and its class prototype\n # Gather prototype for each sample\n prototypes = self.class_prototypes[labels] # shape 
(batch_size, num_classes)\n # Add small epsilon to avoid log(0)\n prototypes = prototypes.clamp_min(1e-8)\n probs_clamped = probs.clamp_min(1e-8)\n consistency_loss = (probs_clamped * (torch.log(probs_clamped) - torch.log(prototypes))).sum(dim=1).mean()\n \n # Adaptive label smoothing\n # Confidence: probability assigned to the given label\n confidence = probs.gather(1, labels.unsqueeze(1)).squeeze(1) # shape (batch_size,)\n # Smoothing factor: lower confidence -> more smoothing\n alpha = (1 - confidence) * self.adaptive_factor\n alpha = alpha.clamp(max=self.smooth_eps) # cap at base smoothing epsilon\n \n # Smoothed target: mix of one-hot and uniform distribution\n uniform = torch.ones_like(one_hot) / num_classes\n # Expand alpha to shape (batch_size, 1) for broadcasting\n alpha_expanded = alpha.unsqueeze(1)\n target = (1 - alpha_expanded) * one_hot + alpha_expanded * uniform\n \n # Cross-entropy with smoothed targets\n log_probs = F.log_softmax(logits, dim=1)\n ce_loss = -(target * log_probs).sum(dim=1).mean()\n \n # Total loss\n total_loss = ce_loss + self.lambda_consistency * consistency_loss\n \n return total_loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
OK: Replaced lines 1..19 with 88 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.
[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–88 | total: 88 lines]
1 """Editable poison-robust loss for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class RobustLoss:
10     """Adaptive label smoothing with intra-class consistency regularization."""
11
12     def __init__(self):
13         # Hyperparameters
14         self.ema_decay = 0.9 # decay for per-class prototype moving average
15         self.lambda_consistency = 0.5 # weight for consistency loss
16         self.smooth_eps = 0.1 # base label smoothing epsilon
17         self.adaptive_factor = 0.5 # scaling fa
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
## Seed 42
### cifar10-labelflip (scripts/cifar10_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.2868
TRAIN_METRICS epoch=2 loss=2.2081
TRAIN_METRICS epoch=3 loss=2.1729
TRAIN_METRICS epoch=4 loss=2.1447
TRAIN_METRICS epoch=5 loss=2.1094
cifar10-labelflip: test_acc=0.2847, poison_fit=0.1108, robust_score=0.5869
### fmnist-labelflip (scripts/fmnist_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.2900
TRAIN_METRICS epoch=2 loss=1.9927
TRAIN_METRICS epoch=3 loss=1.8229
TRAIN_METRICS epoch=4 loss=1.7015
TRAIN_METRICS epoch=5 loss=1.6452
fmnist-labelflip: test_acc=0.7193, poison_fit=0.0210, robust_score=0.8492
### cifar100-labelflip (scripts/cifar100_labelflip.sh)
TRAIN_METRICS epoch=1 loss=4.6024
TRAIN_METRICS epoch=2 loss=4.5498
TRAIN_METRICS epoch=3 loss=4.5289
TRAIN_METRICS epoch=4 loss=4.5069
TRAIN_METRICS epoch=5 loss=4.4489
cifar100-labelflip: test_acc=0.0387, poison_fit=0.0276, robust_score=0.5055
## Seed 123
### c
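The first attempt updates its class prototypes with a Python loop over all classes, which costs one masked mean per class per step (100 iterations per batch on CIFAR-100). Below is a vectorized sketch of the same EMA rule. It assumes the same (num_classes, num_classes) prototype tensor and, like the loop, updates only classes that occur in the minibatch. `update_prototypes` is a hypothetical helper name.

```python
import torch


def update_prototypes(prototypes, probs, labels, ema_decay=0.9):
    """Vectorized per-class EMA of mean predicted distributions.

    prototypes: (C, C) tensor updated in place; probs: (B, C); labels: (B,) long.
    Classes that do not occur in the batch keep their previous prototype.
    """
    num_classes = prototypes.shape[0]
    with torch.no_grad():
        probs = probs.detach().to(prototypes.dtype)
        # Per-class sum of predicted distributions and per-class sample counts
        sums = torch.zeros_like(prototypes).index_add_(0, labels, probs)
        counts = torch.bincount(labels, minlength=num_classes).to(prototypes.dtype)
        present = counts > 0
        means = sums[present] / counts[present].unsqueeze(1)
        prototypes[present] = ema_decay * prototypes[present] + (1 - ema_decay) * means
    return prototypes


# Check against the loop-based update used in the first attempt
torch.manual_seed(0)
C = 5
probs = torch.softmax(torch.randn(16, C), dim=1)
labels = torch.randint(0, C - 1, (16,))  # class C - 1 never occurs
reference = torch.ones(C, C) / C
for c in range(C):
    mask = labels == c
    if mask.any():
        reference[c] = 0.9 * reference[c] + (1 - 0.9) * probs[mask].mean(dim=0)
vectorized = update_prototypes(torch.ones(C, C) / C, probs, labels)
```

Both versions produce the same prototypes up to floating-point summation order; the vectorized one replaces the per-class loop with a single `index_add_` and `bincount`.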
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4 "start_line": 1,
5 "end_line": 88,
6 "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n \"\"\"Adaptive Bootstrap with Confidence Weighting and Entropy Regularization.\"\"\"\n\n def __init__(self):\n # Hyperparameters\n self.entropy_weight = 0.1 # weight for entropy regularization\n self.confidence_threshold = 0.3 # threshold for high-confidence samples\n self.temperature = 1.0 # temperature for soft targets\n \n def compute_loss(self, logits, labels, epoch):\n batch_size, num_classes = logits.shape\n \n # Compute softmax probabilities with temperature\n probs = torch.softmax(logits / self.temperature, dim=1)\n # Compute confidence (probability of given label)\n confidence = probs.gather(1, labels.unsqueeze(1)).squeeze(1) # shape (batch_size,)\n \n # Adaptive bootstrap weight: beta = confidence, clipped to [0,1]\n beta = confidence.detach() # shape (batch_size,)\n # Ensure beta is at least a small value to avoid zero weight on hard label\n beta = beta.clamp(min=0.1, max=0.9)\n \n # Hard targets (one-hot)\n hard = F.one_hot(labels, num_classes=num_classes).float()\n # Soft targets (model predictions detached)\n soft = probs.detach()\n \n # Interpolated targets per sample\n beta_expanded = beta.unsqueeze(1) # shape (batch_size, 1)\n target = beta_expanded * hard + (1 - beta_expanded) * soft\n \n # Cross-entropy loss with interpolated targets\n log_probs = F.log_softmax(logits, dim=1)\n ce_loss = -(target * log_probs).sum(dim=1).mean()\n \n # Entropy regularization: encourage high confidence for samples with high given-label confidence\n # Only apply to samples where confidence > threshold\n mask = (confidence > self.confidence_threshold).float()\n entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1)\n entropy_loss = (mask * entropy).mean()\n \n # Total loss\n 
total_loss = ce_loss + self.entropy_weight * entropy_loss\n \n return total_loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
OK: Replaced lines 1..88 with 56 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.
[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–56 | total: 56 lines]
1 """Editable poison-robust loss for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class RobustLoss:
10     """Adaptive Bootstrap with Confidence Weighting and Entropy Regularization."""
11
12     def __init__(self):
13         # Hyperparameters
14         self.entropy_weight = 0.1 # weight for entropy regularization
15         self.confidence_threshold = 0.3 # threshold for high-confidence samples
16         self.temperature = 1.0 # temperature for soft targets
17
18     def compute_loss(
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### cifar10-labelflip (scripts/cifar10_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.1535
TRAIN_METRICS epoch=2 loss=1.8810
TRAIN_METRICS epoch=3 loss=1.6851
TRAIN_METRICS epoch=4 loss=1.6955
TRAIN_METRICS epoch=5 loss=1.6880
cifar10-labelflip: test_acc=0.2100, poison_fit=0.1200, robust_score=0.5450
### fmnist-labelflip (scripts/fmnist_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.1337
TRAIN_METRICS epoch=2 loss=1.2999
TRAIN_METRICS epoch=3 loss=0.9917
TRAIN_METRICS epoch=4 loss=0.8698
TRAIN_METRICS epoch=5 loss=0.7145
fmnist-labelflip: test_acc=0.7020, poison_fit=0.0140, robust_score=0.8440
### cifar100-labelflip (scripts/cifar100_labelflip.sh)
TRAIN_METRICS epoch=1 loss=4.5935
TRAIN_METRICS epoch=2 loss=4.5320
TRAIN_METRICS epoch=3 loss=4.4990
TRAIN_METRICS epoch=4 loss=4.4487
TRAIN_METRICS epoch=5 loss=4.4011
cifar100-labelflip: test_acc=0.0233, poison_fit=0.0219, robust_score=0.5007
[Leaderboard] Results saved: {'te
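The second attempt's bootstrap term is easiest to read through its gradient. For a detached target t whose entries sum to one, the gradient of -Σ_k t_k log softmax(z)_k with respect to the logits z is p - t. With t = βy + (1 - β)p this becomes β(p - y): the ordinary per-sample cross-entropy gradient scaled by β. Because β is the clamped given-label confidence, the term upweights samples the model already agrees with (up to 0.9) and downweights low-confidence ones (down to 0.1), which are likely poisoned but may also be hard clean samples. It never pushes toward the model's own prediction. The autograd check below covers the bootstrap term only; the entropy term is left out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 4
logits_a = torch.randn(6, num_classes, dtype=torch.float64, requires_grad=True)
logits_b = logits_a.detach().clone().requires_grad_(True)
labels = torch.randint(0, num_classes, (6,))

# Bootstrap term as written in the second attempt (temperature 1.0, beta clamped to [0.1, 0.9])
probs = torch.softmax(logits_a, dim=1)
beta = probs.gather(1, labels.unsqueeze(1)).squeeze(1).detach().clamp(min=0.1, max=0.9)
hard = F.one_hot(labels, num_classes=num_classes).to(probs.dtype)
target = beta.unsqueeze(1) * hard + (1 - beta.unsqueeze(1)) * probs.detach()
bootstrap = -(target * F.log_softmax(logits_a, dim=1)).sum(dim=1).mean()
bootstrap.backward()

# Plain per-sample cross-entropy, reweighted by the same detached beta
weighted_ce = (beta * F.cross_entropy(logits_b, labels, reduction="none")).mean()
weighted_ce.backward()

grads_match = torch.allclose(logits_a.grad, logits_b.grad)
print(grads_match)  # True
```

On this reading, the second attempt is a confidence-weighted cross-entropy (a self-paced style reweighting) rather than a label-correction method, which may explain why its poison_fit stays close to the first attempt's.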
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4 "start_line": 1,
5 "end_line": 56,
6 "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n \"\"\"Dynamic Label Smoothing with Class Prototype Targets.\"\"\"\n\n def __init__(self):\n # Hyperparameters\n self.ema_decay = 0.9 # for prototype updates\n self.initial_alpha = 0.8 # initial weight on prototype\n self.final_alpha = 0.0 # final weight on prototype\n self.total_epochs = 5 # assumed total epochs (training runs 5 epochs)\n self.entropy_weight = 0.05 # entropy regularization weight\n \n # State\n self.num_classes = None\n self.class_prototypes = None # shape (num_classes, num_classes)\n self.initialized = False\n \n def _initialize(self, num_classes):\n self.num_classes = num_classes\n # Initialize prototypes as uniform distributions\n self.class_prototypes = torch.ones(num_classes, num_classes) / num_classes\n self.initialized = True\n\n def compute_loss(self, logits, labels, epoch):\n batch_size, num_classes = logits.shape\n if not self.initialized:\n self._initialize(num_classes)\n \n device = logits.device\n self.class_prototypes = self.class_prototypes.to(device)\n \n # Compute softmax probabilities\n probs = torch.softmax(logits, dim=1)\n \n # Update per-class prototypes using exponential moving average\n with torch.no_grad():\n for c in range(num_classes):\n mask = (labels == c)\n if mask.any():\n class_probs = probs[mask]\n mean_prob = class_probs.mean(dim=0)\n # Update EMA\n self.class_prototypes[c] = (\n self.ema_decay * self.class_prototypes[c] +\n (1 - self.ema_decay) * mean_prob\n )\n \n # Schedule alpha: linear decay from initial_alpha to final_alpha over total_epochs\n alpha = self.initial_alpha - (self.initial_alpha - self.final_alpha) * (epoch / (self.total_epochs - 1))\n alpha = max(alpha, 0.0)\n \n # Prototype targets for each sample\n prototypes = 
self.class_prototypes[labels] # shape (batch_size, num_classes)\n # Hard targets\n hard = F.one_hot(labels, num_classes=num_classes).float()\n \n # Combined target\n target = (1 - alpha) * hard + alpha * prototypes\n \n # Cross-entropy loss\n log_probs = F.log_softmax(logits, dim=1)\n ce_loss = -(target * log_probs).sum(dim=1).mean()\n \n # Entropy regularization: encourage confident predictions\n entropy = -(probs * torch.log(probs.clamp_min(1e-8))).sum(dim=1).mean()\n \n # Total loss\n total_loss = ce_loss + self.entropy_weight * entropy\n \n return total_loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
OK: Replaced lines 1..56 with 80 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.
[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–80 | total: 80 lines]
1 """Editable poison-robust loss for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class RobustLoss:
10     """Dynamic Label Smoothing with Class Prototype Targets."""
11
12     def __init__(self):
13         # Hyperparameters
14         self.ema_decay = 0.9 # for prototype updates
15         self.initial_alpha = 0.8 # initial weight on prototype
16         self.final_alpha = 0.0 # final weight on prototype
17         self.total_epochs = 5 # assumed total epochs (training runs 5 epochs)
18
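The third attempt hard-codes `total_epochs = 5` and divides by `total_epochs - 1`, so where the schedule ends depends on whether `epoch` starts at 0 or 1. The TRAIN_METRICS lines count epochs from 1, but the transcript does not show which convention the harness passes to `compute_loss`. The sketch below evaluates the schedule exactly as written under both assumptions.

```python
def prototype_alpha(epoch, initial_alpha=0.8, final_alpha=0.0, total_epochs=5):
    """Linear schedule copied from the third attempt, including the clamp at 0."""
    alpha = initial_alpha - (initial_alpha - final_alpha) * (epoch / (total_epochs - 1))
    return max(alpha, 0.0)


zero_indexed = [round(prototype_alpha(e), 2) for e in range(0, 5)]
one_indexed = [round(prototype_alpha(e), 2) for e in range(1, 6)]
print(zero_indexed)  # [0.8, 0.6, 0.4, 0.2, 0.0]
print(one_indexed)   # [0.6, 0.4, 0.2, 0.0, 0.0]
```

Under 1-indexing, the first epoch already starts at 0.6 and the last two epochs train on pure hard labels plus the entropy term, i.e. the phase where poisoned labels are most easily memorized gets no prototype smoothing at all.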
[Test #3] (0 tests remaining)
## Seed 42
### cifar10-labelflip (scripts/cifar10_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.4175
TRAIN_METRICS epoch=2 loss=2.3788
TRAIN_METRICS epoch=3 loss=2.3013
TRAIN_METRICS epoch=4 loss=2.1861
TRAIN_METRICS epoch=5 loss=2.0512
cifar10-labelflip: test_acc=0.3087, poison_fit=0.0967, robust_score=0.6060
### fmnist-labelflip (scripts/fmnist_labelflip.sh)
TRAIN_METRICS epoch=1 loss=2.4238
TRAIN_METRICS epoch=2 loss=2.3223
TRAIN_METRICS epoch=3 loss=2.0651
TRAIN_METRICS epoch=4 loss=1.7671
TRAIN_METRICS epoch=5 loss=1.4306
fmnist-labelflip: test_acc=0.6593, poison_fit=0.0270, robust_score=0.8162
### cifar100-labelflip (scripts/cifar100_labelflip.sh)
TRAIN_METRICS epoch=1 loss=4.8362
TRAIN_METRICS epoch=2 loss=4.8159
TRAIN_METRICS epoch=3 loss=4.7598
TRAIN_METRICS epoch=4 loss=4.6704
TRAIN_METRICS epoch=5 loss=4.5525
cifar100-labelflip: test_acc=0.0293, poison_fit=0.0219, robust_score=0.5037
## Seed 123
### cifar10-labelflip (scripts/cifar10_labelf
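Only seed 42 is visible for Tests #1 and #3, the seed header of Test #2 is not shown, and no baseline scores appear in the transcript, so any comparison between attempts is rough. The sketch below averages the three visible robust_score values per test.

```python
# Visible robust_score values per test, in the order cifar10, fmnist, cifar100
visible_scores = {
    "Test #1": [0.5869, 0.8492, 0.5055],
    "Test #2": [0.5450, 0.8440, 0.5007],
    "Test #3": [0.6060, 0.8162, 0.5037],
}
means = {name: sum(s) / len(s) for name, s in visible_scores.items()}
for name, mean in means.items():
    print(f"{name}: {mean:.4f}")
# Test #1: 0.6472
# Test #2: 0.6299
# Test #3: 0.6420
```

On this crude average the first attempt leads, driven mostly by fmnist; the third attempt scores best on cifar10.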
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4 "start_line": 1,
5 "end_line": 19,
6 "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n \"\"\"Adaptive Robust Loss (ARL): softmin-weighted cross-entropy.\n \n Weights samples by exp(-loss_i / temperature) where loss_i is per-sample CE.\n Temperature decays with epoch to gradually focus on low-loss (clean) samples.\n \"\"\"\n\n def __init__(self):\n # initial temperature (high = uniform weights)\n self.t0 = 2.0\n # final temperature after 100 epochs\n self.t_min = 0.1\n # decay factor per epoch\n self.decay = (self.t0 / self.t_min) ** (1.0 / 100)\n\n def compute_loss(self, logits, labels, epoch):\n # per-sample cross-entropy\n loss_per_sample = F.cross_entropy(logits, labels, reduction='none')\n # temperature for current epoch\n temperature = self.t0 * (self.decay ** (-epoch))\n temperature = max(temperature, self.t_min)\n # compute weights: softmin over losses\n weights = F.softmax(-loss_per_sample.detach() / temperature, dim=0)\n # weighted loss\n loss = (weights * loss_per_sample).sum()\n return loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
OK: Replaced lines 1..19 with 37 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.
[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–37 | total: 37 lines]
1 """Editable poison-robust loss for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class RobustLoss:
10     """Adaptive Robust Loss (ARL): softmin-weighted cross-entropy.
11
12     Weights samples by exp(-loss_i / temperature) where loss_i is per-sample CE.
13     Temperature decays with epoch to gradually focus on low-loss (clean) samples.
14     """
15
16     def __init__(self):
17         # initial temperature (high = uniform weights)
18         self.t0 = 2.0
19         # final temperatur
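The fourth attempt sizes its temperature decay for 100 epochs (`decay = (t0 / t_min) ** (1.0 / 100)`), while every logged run trains for five. The sketch below evaluates the schedule as written, assuming `epoch` runs from 0 to 5. It shows the temperature falling only from 2.0 to about 1.72, far from `t_min = 0.1`, so the softmin weighting sharpens only slightly over a run. Because `F.softmax(..., dim=0)` normalizes the weights over the batch, the returned loss is a weighted mean of per-sample CE.

```python
t0, t_min = 2.0, 0.1
decay = (t0 / t_min) ** (1.0 / 100)  # decay factor exactly as in the fourth attempt


def temperature(epoch):
    # Same formula and floor as compute_loss in the fourth attempt
    return max(t0 * decay ** (-epoch), t_min)


temps = [round(temperature(e), 3) for e in range(6)]
print(temps)  # [2.0, 1.941, 1.884, 1.828, 1.774, 1.722]
```

To reach `t_min` within five epochs, the exponent would have to use the actual run length (for example 1.0 / 5) instead of 100.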
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4 "start_line": 1,
5 "end_line": 19,
6 "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n \"\"\"Adaptive Confidence-based Label Smoothing (ACLS).\"\"\"\n\n def __init__(self):\n self.num_classes = None\n self.eps = 1e-8\n # smoothing schedule: start with low smoothing, increase later\n self.smoothing_start = 0.1\n self.smoothing_end = 0.5\n # confidence threshold for weighting\n self.conf_threshold = 0.3\n\n def compute_loss(self, logits, labels, epoch):\n if self.num_classes is None:\n self.num_classes = logits.shape[1]\n \n # compute confidence\n probs = torch.softmax(logits, dim=1)\n confidence = probs.gather(1, labels.unsqueeze(1)).squeeze(1) # shape (batch,)\n \n # adaptive smoothing factor per sample based on confidence\n # low confidence -> higher smoothing\n smoothing_factor = (1.0 - confidence).clamp(min=self.smoothing_start, max=self.smoothing_end)\n \n # schedule: increase smoothing factor as training progresses\n epoch_factor = min(1.0, epoch / 100.0) # assuming 100 epochs total\n smoothing_factor = smoothing_factor * epoch_factor\n \n # create smoothed labels\n one_hot = F.one_hot(labels, num_classes=self.num_classes).float()\n smoothing_matrix = smoothing_factor.unsqueeze(1) / self.num_classes\n smoothed_labels = one_hot * (1.0 - smoothing_factor.unsqueeze(1)) + smoothing_matrix\n \n # compute cross-entropy with smoothed labels\n log_probs = F.log_softmax(logits, dim=1)\n loss = -(smoothed_labels * log_probs).sum(dim=1).mean()\n return loss\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
OK: Replaced lines 1..19 with 48 line(s) in pytorch-vision/bench/poison/custom_robust_loss.py. Editable range: entire file.
[Current file: pytorch-vision/bench/poison/custom_robust_loss.py | editable: 1–48 | total: 48 lines]
1 """Editable poison-robust loss for MLS-Bench."""
2
3 import torch
4 import torch.nn.functional as F
5
6 # ============================================================
7 # EDITABLE
8 # ============================================================
9 class RobustLoss:
10     """Adaptive Confidence-based Label Smoothing (ACLS)."""
11
12     def __init__(self):
13         self.num_classes = None
14         self.eps = 1e-8
15         # smoothing schedule: start with low smoothing, increase later
16         self.smoothing_start = 0.1
17         self.smoothing_end = 0.5
18         # confidence threshold for weighting
19         self.conf_threshold = 0.3
20
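In the fifth attempt, the per-sample smoothing factor is first clamped to [0.1, 0.5] and then multiplied by `min(1.0, epoch / 100.0)`, a ramp that assumes 100 epochs. Over a five-epoch run that ramp is at most 0.05, so the smoothing mass never exceeds 0.025 and is exactly zero if the harness passes epoch 0. The objective is then very close to plain cross-entropy. `conf_threshold` and `eps` are defined but not used. The sketch below evaluates the bound directly, assuming `epoch` in 0..5.

```python
smoothing_end = 0.5  # upper clamp on the per-sample factor in the fifth attempt


def max_effective_smoothing(epoch):
    # Largest clamped factor times the epoch ramp min(1, epoch / 100)
    return smoothing_end * min(1.0, epoch / 100.0)


print([round(max_effective_smoothing(e), 3) for e in range(6)])
# [0.0, 0.005, 0.01, 0.015, 0.02, 0.025]
```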
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/bench/poison/custom_robust_loss.py",
4 "start_line": 1,
5 "end_line": 19,
6 "content": "\"\"\"Editable poison-robust loss for MLS-Bench.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n# ============================================================\n# EDITABLE\n# ============================================================\nclass RobustLoss:\n \"\"\"Adaptive Bootstrap with Confidence Thresholding (ABCT).\"\"\"\n\n def __init__(self):\n # Base mixing parameter for bootstrap\n self.beta_base = 0.8\n # Confidence threshold starts low, increases over epochs\n self.conf_thresh_start = 0.7\n self.conf_thresh_end = 0.95\n # Number of total epochs (assuming 100 as in description)\n self.total_epochs = 100\n\n def compute_loss(self, logits, labels, epoch):\n # Current confidence threshold linear schedule\n conf_thresh = self.conf_thresh_start + (self.conf_thresh_end - self.conf_thresh_start) * (epoch / self.total_epochs)\n \n probs = torch.softmax(logits, dim=1)\n conf, preds = torch.max(probs, dim=1)\n \n # Determine per-sample beta: low if high confidence and prediction != label (suspicious)\n beta = torch.ones_like(conf) * self.beta_base\n suspicious = (conf > conf_thresh) & (preds != labels)\n beta[suspicious] = 0.2 # trust predictions more for suspicious samples\n \n # Bootstrap target mixing\n hard = F.one_hot(labels, num_classes=logits.shape[1]).float()\n soft = torch.softmax(logits.detach(), dim=1)\n # Expand beta to match dimensions\n beta_expanded = beta.view(-1, 1)\n target = beta_expanded * hard + (1.0 - beta_expanded) * soft\n \n log_probs = F.log_softmax(logits, dim=1)\n return -(target * log_probs).sum(dim=1).mean()\n# ============================================================\n# END EDITABLE\n# ============================================================"
7}
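The last edit, on which the transcript ends without a further test, also schedules against an assumed 100 epochs. Its confidence threshold only moves from 0.70 to 0.7125 by epoch 5. In effect, a sample is flagged as suspicious whenever the model's top prediction disagrees with the given label at confidence above roughly 0.7. Flagged samples get β = 0.2, so their target is mostly the model's own prediction; all others get β = 0.8. Because the soft target is detached and each target row sums to one, the per-sample gradient is β(p - y), so flagged samples are downweighted rather than relabeled. A per-sample restatement of the rule, with `bootstrap_beta` as a hypothetical helper name:

```python
conf_thresh_start, conf_thresh_end, total_epochs = 0.7, 0.95, 100  # from the last edit


def conf_thresh(epoch):
    # Linear threshold schedule as written in the last edit
    return conf_thresh_start + (conf_thresh_end - conf_thresh_start) * (epoch / total_epochs)


def bootstrap_beta(conf, pred, label, epoch, beta_base=0.8, beta_suspicious=0.2):
    """Per-sample beta rule from the last edit, restated for a single sample."""
    suspicious = conf > conf_thresh(epoch) and pred != label
    return beta_suspicious if suspicious else beta_base


print(round(conf_thresh(5), 4))             # 0.7125
print(bootstrap_beta(0.75, 3, 1, epoch=5))  # 0.2: confident disagreement with the label
print(bootstrap_beta(0.75, 1, 1, epoch=5))  # 0.8: prediction agrees with the label
```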