Agent Conversation: google/gemini-3.1-pro-preview
Model: google/gemini-3.1-pro-preview
Total steps: 7
Messages: 15
User (Step 0)
Task: security-adversarial-attack-black-box-score
Score-Based Query Black-Box Attack under Linf Constraint
Research Question
Can you design a stronger score-based query black-box attack that improves attack success rate (ASR) under a fixed query budget and L_inf perturbation constraint?
Objective
Implement a better black-box attack in bench/custom_attack.py:
- Threat model: query black-box (no gradient access).
- Constraint: ||x_adv - x||_inf <= eps.
- Budget: n_queries is a per-sample query budget.
- Primary metric: maximize ASR under a fixed budget.
- Tie-break: for similar ASR, lower avg_queries is better.
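A minimal sketch of the constraint in code, assuming images are float tensors in [0, 1]. The helper names project_linf and is_valid are hypothetical and not part of the benchmark. Because the eps-ball and the pixel range are both axis-aligned boxes, clamping coordinate-wise in this order gives the exact projection onto their intersection:

```python
import torch


def project_linf(x_adv: torch.Tensor, x: torch.Tensor, eps: float) -> torch.Tensor:
    """Project x_adv onto {z : ||z - x||_inf <= eps} intersected with [0, 1]^d."""
    # Clip the perturbation to the eps-box around x ...
    delta = torch.clamp(x_adv - x, -eps, eps)
    # ... then clip to the valid pixel range; since x itself lies in [0, 1],
    # the result still satisfies the eps constraint.
    return torch.clamp(x + delta, 0.0, 1.0)


def is_valid(x_adv: torch.Tensor, x: torch.Tensor, eps: float, tol: float = 1e-6) -> torch.Tensor:
    """Per-sample boolean mask: True where both the L_inf and [0, 1] constraints hold."""
    linf = (x_adv - x).flatten(1).abs().amax(dim=1)
    in_range = ((x_adv >= 0) & (x_adv <= 1)).flatten(1).all(dim=1)
    return (linf <= eps + tol) & in_range
```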
Editable Interface
You must implement:
run_attack(model, images, labels, eps, n_queries, device, n_classes) -> adv_images
Inputs:
- model: black-box wrapper that returns logits only.
- images: tensor of shape (N, C, H, W), values in [0, 1].
- labels: tensor of shape (N,).
- n_classes: 10 for CIFAR-10, 100 for CIFAR-100.
Output:
- adv_images: tensor with the same shape as images, values in [0, 1].
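To illustrate the interface only, here is a deliberately weak baseline that samples random vertices of the eps-ball. It assumes model(x) returns logits of shape (N, n_classes) and that each call costs x.shape[0] queries; it is not the attack developed later in this log:

```python
import torch
import torch.nn as nn


def run_attack(
    model: nn.Module,
    images: torch.Tensor,
    labels: torch.Tensor,
    eps: float,
    n_queries: int,
    device: torch.device,
    n_classes: int,
) -> torch.Tensor:
    """Random-vertex baseline: each unfooled sample is queried at most n_queries times."""
    images, labels = images.to(device), labels.to(device)
    adv = images.clone()
    fooled = torch.zeros(images.shape[0], dtype=torch.bool, device=device)
    with torch.no_grad():
        for _ in range(n_queries):
            idx = (~fooled).nonzero(as_tuple=True)[0]
            if idx.numel() == 0:
                break  # every sample is already misclassified
            # Random vertex of the eps-ball, clipped back into [0, 1].
            cand = (images[idx] + eps * torch.sign(torch.randn_like(images[idx]))).clamp(0.0, 1.0)
            hit = model(cand).argmax(dim=1) != labels[idx]  # costs idx.numel() queries
            adv[idx[hit]] = cand[hit]
            fooled[idx[hit]] = True
    return adv
```

Samples stop consuming queries as soon as they are misclassified, the same early-stopping idea the later versions in this log use to keep avg_queries low.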
Trusted Evaluation Logic
The evaluation logic in bench/run_eval.py is trusted and not editable.
- It tracks all model queries through a wrapper.
- If a batch exceeds its query budget (batch_size * n_queries), the entire batch is marked as an attack failure.
- L_inf and [0, 1] validity are checked per sample; only invalid samples are marked as attack failures.
Do not attempt to bypass wrapper behavior or evaluation logic. Only improve the attack algorithm in custom_attack.py.
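The split between per-batch and per-sample failure can be sketched as follows. This is a hypothetical reconstruction for illustration only; the trusted logic lives in bench/run_eval.py and may differ in details such as the numerical tolerance tol, which is an assumption here:

```python
import torch


def score_batch(preds, adv, x, labels, eps, queries_used, n_queries, tol=1e-6):
    """Per-sample success mask under the rules described above (hypothetical sketch).

    preds: model predictions on adv, shape (N,).
    """
    n = x.shape[0]
    if queries_used > n * n_queries:
        # Budget exceeded: the whole batch counts as attack failure.
        return torch.zeros(n, dtype=torch.bool)
    # Otherwise validity is judged sample by sample.
    linf = (adv - x).flatten(1).abs().amax(dim=1)
    in_range = ((adv >= 0) & (adv <= 1)).flatten(1).all(dim=1)
    valid = (linf <= eps + tol) & in_range
    return valid & (preds != labels)
```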
Query Semantics
- One call to model(x) consumes x.shape[0] queries.
- Repeated calls on the same sample still consume additional queries.
- Different batch partitioning should be treated as equivalent total budget usage.
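A counting wrapper in the spirit of these semantics might look like the following (QueryCounter is a hypothetical name, not the benchmark's wrapper). Splitting a batch of 8 into two calls of 4 charges the same 8 queries, and re-querying a sample charges again:

```python
import torch
import torch.nn as nn


class QueryCounter(nn.Module):
    """Hypothetical wrapper that charges x.shape[0] queries per forward call."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.queries = 0

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.queries += x.shape[0]  # one query per sample in this call
        return self.model(x)
```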
Performance Hint
The benchmark is query-heavy. Prefer batch-parallel score updates instead of pure per-sample loops when possible.
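Batch-parallel scoring usually starts from the margin loss, max over j != y of z_j minus z_y, which is positive only when some other class strictly beats the true one. The sketch below computes it for a whole batch with gather and scatter and checks it against a per-sample loop. Masking the true class with -inf, instead of subtracting a large constant as the get_margin helper in this log does with 1e4, avoids any dependence on the logit scale:

```python
import torch


def margin_batch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """max_{j != y} z_j - z_y for every row at once; > 0 means misclassified."""
    true = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
    # Replace the true-class logit with -inf so it cannot be the runner-up.
    other = logits.scatter(1, labels.unsqueeze(1), float("-inf"))
    return other.amax(dim=1) - true


def margin_loop(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-sample reference version, shown only for comparison."""
    out = []
    for z, y in zip(logits, labels.tolist()):
        others = torch.cat([z[:y], z[y + 1:]])
        out.append(others.max() - z[y])
    return torch.stack(out)
```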
Evaluation Scenarios (6)
- ResNet20 on CIFAR-10
- VGG11-BN on CIFAR-10
- MobileNetV2 on CIFAR-10
- ResNet20 on CIFAR-100
- VGG11-BN on CIFAR-100
- MobileNetV2 on CIFAR-100
Reported metrics line format:
ATTACK_METRICS asr=... clean_acc=... robust_acc=... avg_queries=...
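For collecting results across the six scenarios, a small parser for this line format could look like this (parse_metrics is a hypothetical helper, and it assumes every value is numeric):

```python
import re


def parse_metrics(line: str) -> dict:
    """Parse 'ATTACK_METRICS key=value ...' into a dict of floats."""
    if not line.startswith("ATTACK_METRICS"):
        raise ValueError("not an ATTACK_METRICS line")
    # Each field is key=number; the leading tag has no '=' so it is skipped.
    return {k: float(v) for k, v in re.findall(r"(\w+)=([-+0-9.eE]+)", line)}
```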
torchattacks/bench/custom_attack.py [EDITABLE — lines 7–56 only]
1: import torch
2: import torch.nn as nn
3:
4: # =====================================================================
5: # EDITABLE: implement run_attack below
6: # =====================================================================
7: def run_attack(
8:     model: nn.Module,
9:     images: torch.Tensor,
10:     labels: torch.Tensor,
11:     eps: float,
12:     n_queries: int,
13:     device: torch.device,
14:     n_classes: int,
15: ) -> torch.Tensor:
16:     """
17:     Score-based query black-box attack under Linf constraint.
18:
19:     Args:
20:         model: black-box wrapper that only exposes forward logits.
Assistant (Step 1): edit
Tool Input
1{
2 "content": "def run_attack(\n model: nn.Module,\n images: torch.Tensor,\n labels: torch.Tensor,\n eps: float,\n n_queries: int,\n device: torch.device,\n n_classes: int,\n) -> torch.Tensor:\n import math\n\n model.eval()\n N, C, H, W = images.shape\n\n # Precompute one-hot for margin\n y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)\n\n def get_margin(logits, y_oh):\n real = (logits * y_oh).sum(1)\n other = logits - y_oh * 1e4\n return other.max(1)[0] - real\n\n # Initial query\n with torch.no_grad():\n logits = model(images)\n margin = get_margin(logits, y_onehot)\n best_margin = margin.clone()\n succ = best_margin > 0\n\n # Initialize perturbation with random noise\n pert = torch.empty_like(images).uniform_(-eps, eps)\n adv = torch.clamp(images + pert, 0.0, 1.0)\n pert = torch.clamp(adv - images, -eps, eps)\n adv = images + pert\n\n with torch.no_grad():\n logits_init = model(adv)\n margin_init = get_margin(logits_init, y_onehot)\n\n improve = margin_init > best_margin\n best_margin = torch.where(improve, margin_init, best_margin)\n succ = best_margin > 0\n adv = torch.where(improve.view(-1, 1, 1, 1), adv, images)\n pert = adv - images\n\n budget = N * n_queries\n queries_used = 2 * N\n\n grad_est = torch.zeros_like(pert)\n beta = 0.8\n gamma = 1.0\n\n p_start = 0.4\n p_end = 0.01\n\n i = 0\n while queries_used < budget:\n active_idx = (~succ).nonzero(as_tuple=True)[0]\n K = active_idx.numel()\n if K == 0:\n break\n \n if queries_used + K > budget:\n break\n\n queries_used += K\n i += 1\n \n # Calculate p based on iteration relative to expected max_iters\n expected_iters = max(1, n_queries * 2)\n frac = min(1.0, i / expected_iters)\n p = p_start * ((p_end / p_start) ** frac)\n s = max(1, int(round(math.sqrt(p * H * W))))\n\n # Create block mask\n mask = torch.zeros((N, 1, H, W), device=device)\n for b in active_idx.tolist():\n x = torch.randint(0, W - s + 1, (1,)).item()\n y = torch.randint(0, H - s + 1, (1,)).item()\n mask[b, 0, y:y+s, 
x:x+s] = 1.0\n\n noise = torch.empty_like(pert).uniform_(-1.0, 1.0)\n direction = torch.sign(gamma * grad_est + noise)\n\n cand_pert = pert * (1 - mask) + (eps * direction) * mask\n cand = torch.clamp(images + cand_pert, 0.0, 1.0)\n cand_pert = torch.clamp(cand - images, -eps, eps)\n cand = images + cand_pert\n\n cand_active = cand[active_idx]\n with torch.no_grad():\n logits_active = model(cand_active)\n\n y_oh_act = y_onehot[active_idx]\n margin_cand_act = get_margin(logits_active, y_oh_act)\n\n # Update best margins and perturbations\n margin_cand = best_margin.clone()\n margin_cand[active_idx] = margin_cand_act\n\n delta_m = margin_cand - best_margin\n \n # Compute gradient estimate update only for active to save ops\n g_t = torch.sign(delta_m).view(-1, 1, 1, 1) * (cand_pert - pert) / (eps + 1e-8)\n grad_est = beta * grad_est + (1 - beta) * g_t\n\n improve = margin_cand > best_margin\n best_margin = torch.where(improve, margin_cand, best_margin)\n succ = best_margin > 0\n\n adv = torch.where(improve.view(-1, 1, 1, 1), cand, adv)\n pert = adv - images\n\n return adv.detach()\n",
3 "start_line": 7,
4 "filename": "bench/custom_attack.py",
5 "end_line": 56,
6 "after_line": 6,
7 "op": "replace"
8}
Tool Result (Step 1)
ERROR: Package 'bench' is not in allowed packages
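The per-sample Python loop that builds the square mask in the step-1 code can be replaced with a broadcast comparison, in line with the performance hint. A sketch, assuming one axis-aligned s x s square per sample at a uniformly random position (the function name is hypothetical):

```python
import torch


def random_square_masks(n: int, h: int, w: int, s: int, device=None) -> torch.Tensor:
    """One random s x s square per sample, built without a Python loop.

    Returns a float mask of shape (n, 1, h, w).
    """
    # Top-left corner of each square, shaped for broadcasting against pixel grids.
    ys = torch.randint(0, h - s + 1, (n, 1, 1), device=device)
    xs = torch.randint(0, w - s + 1, (n, 1, 1), device=device)
    rows = torch.arange(h, device=device).view(1, h, 1)
    cols = torch.arange(w, device=device).view(1, 1, w)
    # (n, h, 1) & (n, 1, w) broadcasts to (n, h, w).
    inside = (rows >= ys) & (rows < ys + s) & (cols >= xs) & (cols < xs + s)
    return inside.unsqueeze(1).float()
```

In the step-1 loop this would correspond to mask[active_idx] = random_square_masks(K, H, W, s, device).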
Assistant (Step 2): edit
Tool Input
1{
2 "content": "def run_attack(\n model: nn.Module,\n images: torch.Tensor,\n labels: torch.Tensor,\n eps: float,\n n_queries: int,\n device: torch.device,\n n_classes: int,\n) -> torch.Tensor:\n import math\n\n model.eval()\n N, C, H, W = images.shape\n\n y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)\n\n def get_margin(logits, y_oh):\n real = (logits * y_oh).sum(1)\n other = logits - y_oh * 1e4\n return other.max(1)[0] - real\n\n # Initial query\n with torch.no_grad():\n logits = model(images)\n margin = get_margin(logits, y_onehot)\n best_margin = margin.clone()\n succ = best_margin > 0\n\n if succ.all():\n return images.clone().detach()\n\n # Initialize perturbation with random uniform values in [-eps, eps]\n pert = torch.empty_like(images).uniform_(-eps, eps)\n adv = torch.clamp(images + pert, 0.0, 1.0)\n pert = torch.clamp(adv - images, -eps, eps)\n adv = images + pert\n\n with torch.no_grad():\n logits_init = model(adv)\n margin_init = get_margin(logits_init, y_onehot)\n\n improve = margin_init > best_margin\n best_margin = torch.where(improve, margin_init, best_margin)\n succ = best_margin > 0\n adv = torch.where(improve.view(-1, 1, 1, 1), adv, images)\n pert = adv - images\n\n budget = N * n_queries\n queries_used = 2 * N\n\n grad_est = torch.zeros_like(pert)\n beta = 0.8\n gamma = 2.0 # Increased gamma to give more weight to momentum\n\n p_start = 0.2\n p_end = 0.01\n\n i = 0\n while queries_used < budget:\n if succ.all():\n break\n\n active_idx = (~succ).nonzero(as_tuple=True)[0]\n K = active_idx.numel()\n \n if queries_used + K > budget:\n break\n\n queries_used += K\n i += 1\n \n # Calculate p based on iteration relative to expected max_iters (proxy: n_queries*1.5)\n expected_iters = max(1, int(n_queries * 1.5))\n frac = min(1.0, i / expected_iters)\n p = p_start * ((p_end / p_start) ** frac)\n s = max(1, int(round(math.sqrt(p * H * W))))\n\n mask = torch.zeros((N, 1, H, W), device=device)\n # We only generate mask for active 
indices\n for b in active_idx.tolist():\n x = torch.randint(0, W - s + 1, (1,)).item()\n y = torch.randint(0, H - s + 1, (1,)).item()\n mask[b, 0, y:y+s, x:x+s] = 1.0\n\n noise = torch.empty_like(pert).uniform_(-1.0, 1.0)\n direction = torch.sign(gamma * grad_est + noise)\n\n cand_pert = pert * (1 - mask) + (eps * direction) * mask\n cand = torch.clamp(images + cand_pert, 0.0, 1.0)\n cand_pert = torch.clamp(cand - images, -eps, eps)\n cand = images + cand_pert\n\n cand_active = cand[active_idx]\n with torch.no_grad():\n logits_active = model(cand_active)\n\n y_oh_act = y_onehot[active_idx]\n margin_cand_act = get_margin(logits_active, y_oh_act)\n\n margin_cand = best_margin.clone()\n margin_cand[active_idx] = margin_cand_act\n\n delta_m = margin_cand - best_margin\n \n # Gradient update rule\n g_t = torch.sign(delta_m).view(-1, 1, 1, 1) * (cand_pert - pert) / (eps + 1e-8)\n grad_est = beta * grad_est + (1 - beta) * g_t\n\n improve = margin_cand > best_margin\n best_margin = torch.where(improve, margin_cand, best_margin)\n succ = best_margin > 0\n\n adv = torch.where(improve.view(-1, 1, 1, 1), cand, adv)\n pert = adv - images\n\n return adv.detach()",
3 "op": "replace",
4 "end_line": 56,
5 "filename": "torchattacks/bench/custom_attack.py",
6 "after_line": 6,
7 "start_line": 7
8}
Tool Result (Step 2)
OK: Replaced lines 7..56 with 116 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–122.
[Current file: torchattacks/bench/custom_attack.py | editable: 7–122 | total: 126 lines]
4| # =====================================================================
5| # EDITABLE: implement run_attack below
6| # =====================================================================
7 def run_attack(
8 model: nn.Module,
9 images: torch.Tensor,
10 labels: torch.Tensor,
11 eps: float,
12 n_queries: int,
13 device: torch.device,
14 n_classes: int,
15 ) -> torch.Tensor:
16 import math
17
18 model.eval()
19 N, C, H, W = images.shape
20
21 y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)
22
23 def get_margin(logits, y_oh):
24 real = (logits * y_oh).sum(1)
25 other = logits - y_oh * 1e4
26
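The step-2 code shrinks the square side with a geometric schedule, p = p_start * (p_end / p_start)^frac with frac = i / expected_iters, and s = round(sqrt(p * H * W)). Worked through for 32x32 CIFAR images with the step-2 settings (p_start = 0.2, p_end = 0.01, expected_iters = 1.5 * n_queries) and an assumed n_queries = 1000, s starts at 14, is 7 at i = 750, is still 5 at i = 1000, and bottoms out at 3 once i >= 1500. The function below is a hypothetical standalone copy of that schedule:

```python
import math


def square_side(i, n_queries, h=32, w=32, p_start=0.2, p_end=0.01, horizon=1.5):
    """Side length of the square perturbed at iteration i (step-2 schedule)."""
    expected_iters = max(1, int(n_queries * horizon))
    frac = min(1.0, i / expected_iters)
    # Geometric interpolation from p_start down to p_end.
    p = p_start * (p_end / p_start) ** frac
    return max(1, int(round(math.sqrt(p * h * w))))
```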
Assistant (Step 3): test
Tool Result (Step 3)
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=53.15
### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9400, clean_acc=1.0000, robust_acc=0.0600, avg_queries=933.38
### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=87.44
### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=61.05
### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.9550, clean_acc=1.0000, robust_acc=0.0450, avg_queries=999.99
[Leaderboard] Results saved: {'asr_ResNet20_C10': 1.0, 'avg_queries_ResNet20_C10': 53.15, 'elapsed_ResNet20-C10': 14, 'asr_VGG11BN_C10': 0.94, 'avg_queries_VGG11BN
Assistant (Step 4): edit
Tool Input
1{
2 "content": "def run_attack(\n model: nn.Module,\n images: torch.Tensor,\n labels: torch.Tensor,\n eps: float,\n n_queries: int,\n device: torch.device,\n n_classes: int,\n) -> torch.Tensor:\n import math\n\n model.eval()\n N, C, H, W = images.shape\n\n y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)\n\n def get_margin(logits, y_oh):\n real = (logits * y_oh).sum(1)\n other = logits - y_oh * 1e4\n return other.max(1)[0] - real\n\n # Initial query\n with torch.no_grad():\n logits = model(images)\n margin = get_margin(logits, y_onehot)\n best_margin = margin.clone()\n succ = best_margin > 0\n queries_used = N\n\n if succ.all():\n return images.clone().detach()\n\n # Dense boundary initialization FOR ACTIVE\n pert = torch.zeros_like(images)\n active_idx = (~succ).nonzero(as_tuple=True)[0]\n \n pert[active_idx] = torch.empty_like(images[active_idx]).uniform_(-eps, eps)\n adv = torch.clamp(images + pert, 0.0, 1.0)\n pert = adv - images\n\n if active_idx.numel() > 0:\n with torch.no_grad():\n logits_init = model(adv[active_idx])\n margin_init = get_margin(logits_init, y_onehot[active_idx])\n \n best_margin[active_idx] = margin_init\n succ = best_margin > 0\n queries_used += active_idx.numel()\n\n budget = N * n_queries\n\n # Adam state for coordinate-wise momentum\n m_t = torch.zeros_like(pert)\n v_t = torch.zeros_like(pert)\n beta1 = 0.8\n beta2 = 0.99\n gamma = 0.8 # blend momentum with noise\n\n p_start = 0.05\n p_end = 0.005\n\n i = 0\n expected_iters = max(1, n_queries - 2)\n\n while queries_used < budget:\n if succ.all():\n break\n\n active_idx = (~succ).nonzero(as_tuple=True)[0]\n K = active_idx.numel()\n \n if queries_used + K > budget:\n break\n\n queries_used += K\n i += 1\n \n frac = min(1.0, i / expected_iters)\n p = p_start * ((p_end / p_start) ** frac)\n s = max(1, int(round(math.sqrt(p * H * W))))\n\n mask = torch.zeros((N, 1, H, W), device=device)\n for b in active_idx.tolist():\n x = torch.randint(0, W - s + 1, (1,)).item()\n y = 
torch.randint(0, H - s + 1, (1,)).item()\n mask[b, 0, y:y+s, x:x+s] = 1.0\n\n noise = torch.empty_like(pert).uniform_(-1.0, 1.0)\n \n # Calculate grad_est\n grad_est = torch.where(v_t > 0, m_t / (torch.sqrt(v_t) + 1e-8), torch.zeros_like(m_t))\n direction = torch.sign(gamma * grad_est + noise)\n\n cand_pert = pert * (1 - mask) + (eps * direction) * mask\n cand = torch.clamp(images + cand_pert, 0.0, 1.0)\n cand_pert = cand - images\n\n cand_active = cand[active_idx]\n with torch.no_grad():\n logits_active = model(cand_active)\n\n y_oh_act = y_onehot[active_idx]\n margin_cand_act = get_margin(logits_active, y_oh_act)\n\n margin_cand = best_margin.clone()\n margin_cand[active_idx] = margin_cand_act\n\n delta_m = margin_cand - best_margin\n \n # Adam Update (only for masked pixels)\n g_t = torch.sign(delta_m).view(-1, 1, 1, 1) * (cand_pert - pert) / (eps + 1e-8)\n \n m_t = torch.where(mask > 0, beta1 * m_t + (1 - beta1) * g_t, m_t)\n v_t = torch.where(mask > 0, beta2 * v_t + (1 - beta2) * (g_t ** 2), v_t)\n\n improve = margin_cand > best_margin\n best_margin = torch.where(improve, margin_cand, best_margin)\n succ = best_margin > 0\n\n adv = torch.where(improve.view(-1, 1, 1, 1), cand, adv)\n pert = adv - images\n\n return adv.detach()",
3 "end_line": 122,
4 "start_line": 7,
5 "after_line": 6,
6 "filename": "torchattacks/bench/custom_attack.py",
7 "op": "replace"
8}
Tool Result (Step 4)
OK: Replaced lines 7..122 with 123 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–129.
[Current file: torchattacks/bench/custom_attack.py | editable: 7–129 | total: 133 lines]
4| # =====================================================================
5| # EDITABLE: implement run_attack below
6| # =====================================================================
7 def run_attack(
8 model: nn.Module,
9 images: torch.Tensor,
10 labels: torch.Tensor,
11 eps: float,
12 n_queries: int,
13 device: torch.device,
14 n_classes: int,
15 ) -> torch.Tensor:
16 import math
17
18 model.eval()
19 N, C, H, W = images.shape
20
21 y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)
22
23 def get_margin(logits, y_oh):
24 real = (logits * y_oh).sum(1)
25 other = logits - y_oh * 1e4
26
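Step 4 replaces the plain exponential moving average with Adam-style first and second moments that are updated only on the pixels inside the current square, and uses m / sqrt(v) as the direction estimate. The sketch below isolates that update (the function name is hypothetical; beta1 = 0.8 and beta2 = 0.99 are the step-4 values). Because there is no bias correction, a single update with |g| = 1 gives an estimate of magnitude (1 - beta1) / sqrt(1 - beta2) = 2, not 1:

```python
import torch


def masked_adam_direction(m, v, g, mask, beta1=0.8, beta2=0.99, eps=1e-8):
    """Update moments only where mask > 0, then return (m, v, m / sqrt(v)).

    The direction is zero wherever v has never been updated.
    """
    upd = mask > 0
    m = torch.where(upd, beta1 * m + (1 - beta1) * g, m)
    v = torch.where(upd, beta2 * v + (1 - beta2) * g * g, v)
    direction = torch.where(v > 0, m / (v.sqrt() + eps), torch.zeros_like(m))
    return m, v, direction
```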
Assistant (Step 5): test
Tool Result (Step 5)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=40.32
### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=462.18
### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=53.35
### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=40.10
### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=495.17
[Leaderboard] Results saved: {'asr_ResNet20_C10': 0.99, 'avg_queries_ResNet20_C10': 40.32, 'elapsed_ResNet20-C10': 14, 'asr_VGG11BN_C10': 0.99, 'avg_queries_VGG11BN
Assistant (Step 6): edit
Tool Input
1{
2 "after_line": 57,
3 "op": "replace",
4 "end_line": 63,
5 "start_line": 58,
6 "content": " # Adam state for coordinate-wise momentum\n m_t = torch.zeros_like(pert)\n v_t = torch.zeros_like(pert)\n beta1 = 0.7\n beta2 = 0.99\n gamma = 0.5 # blend momentum with noise",
7 "filename": "torchattacks/bench/custom_attack.py"
8}
Tool Result (Step 6)
OK: Replaced lines 58..63 with 6 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–129.
[Current file: torchattacks/bench/custom_attack.py | editable: 7–129 | total: 133 lines]
4| # =====================================================================
5| # EDITABLE: implement run_attack below
6| # =====================================================================
7 def run_attack(
8 model: nn.Module,
9 images: torch.Tensor,
10 labels: torch.Tensor,
11 eps: float,
12 n_queries: int,
13 device: torch.device,
14 n_classes: int,
15 ) -> torch.Tensor:
16 import math
17
18 model.eval()
19 N, C, H, W = images.shape
20
21 y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)
22
23 def get_margin(logits, y_oh):
24 real = (logits * y_oh).sum(1)
25 other = logits - y_oh * 1e4
26
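Step 6 only lowers beta1 (0.8 to 0.7) and gamma (0.8 to 0.5). Because the proposal is sign(gamma * grad_est + noise) with noise uniform in [-1, 1], the chance that a coordinate follows the momentum sign is min(1, (1 + gamma * |grad_est|) / 2). For |grad_est| = 1 that is 0.9 at gamma = 0.8 and 0.75 at gamma = 0.5, so the change shifts proposals toward exploration. A short check of the formula against a Monte Carlo estimate (helper names hypothetical):

```python
import torch


def follow_prob(a: float) -> float:
    """P[sign(a + U) == sign(a)] for U ~ Uniform(-1, 1) and a != 0."""
    # For a > 0 this is P[U > -a] = (1 + a) / 2, capped at 1 once a >= 1.
    return min(1.0, (1.0 + abs(a)) / 2.0)


def follow_prob_mc(a: float, n: int = 200_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same probability."""
    gen = torch.Generator().manual_seed(seed)
    u = torch.rand(n, generator=gen) * 2 - 1  # Uniform(-1, 1)
    return (torch.sign(a + u) == torch.sign(torch.tensor(a))).float().mean().item()
```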
Assistant (Step 7): test
Tool Result (Step 7)
[Test #3] (0 tests remaining)
### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=40.91
### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=465.65
### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=52.21
### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=38.26
### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=577.56
### MobileNetV2-C100 (scripts/eval_mobilenetv2_cifar100.sh)
Attack results (MobileNetV2-C100): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=48.75
[Leaderboard] Results saved: {'asr_ResNe
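To read the three runs against the primary metric and the tie-break, the per-scenario numbers can be averaged over the five scenarios that appear in all three outputs (MobileNetV2-C100 is only shown for Test #3). This only re-aggregates the logged values:

```python
from statistics import mean

# (ASR, avg_queries) for ResNet20-C10, VGG11BN-C10, MobileNetV2-C10,
# ResNet20-C100, VGG11BN-C100, copied from the "Attack results" lines above.
runs = {
    "Test #1": [(1.0, 53.15), (0.94, 933.38), (1.0, 87.44), (1.0, 61.05), (0.955, 999.99)],
    "Test #2": [(0.99, 40.32), (0.99, 462.18), (0.995, 53.35), (1.0, 40.10), (0.99, 495.17)],
    "Test #3": [(1.0, 40.91), (0.99, 465.65), (0.99, 52.21), (1.0, 38.26), (0.995, 577.56)],
}


def summarize(rows):
    """Mean ASR and mean avg_queries over the shared scenarios."""
    return mean(a for a, _ in rows), mean(q for _, q in rows)


for name, rows in runs.items():
    asr, q = summarize(rows)
    print(f"{name}: mean ASR={asr:.3f}, mean avg_queries={q:.1f}")
```

On these five scenarios Test #3 has the highest mean ASR (about 0.995 versus 0.993 for Test #2), while Test #2 uses fewer queries on average (about 218 versus 235).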