Agent Conversation: gpt-5.4-pro
Model: gpt-5.4-pro
Total steps: 5
Messages: 11
User · Step 0
Task: security-adversarial-attack-black-box-score
Score-Based Query Black-Box Attack under Linf Constraint
Research Question
Can you design a stronger score-based query black-box attack that improves attack success rate (ASR) under a fixed query budget and L_inf perturbation constraint?
Objective
Implement a better black-box attack in bench/custom_attack.py:
- Threat model: query black-box (no gradient access).
- Constraint: ||x_adv - x||_inf <= eps.
- Budget: n_queries is a per-sample query budget.
- Primary metric: maximize ASR under the fixed budget.
- Tie-break: for similar ASR, lower avg_queries is better.
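The constraint and the [0, 1] validity requirement can be enforced together by a per-element projection. The following is a minimal plain-Python sketch on flat lists, so it runs without PyTorch. `project_linf` and `linf_dist` are hypothetical helpers, and eps = 8/255 is only an example value because the benchmark's eps is not stated in this log. For x in [0, 1] it computes the same result as the `torch.clamp(images + torch.clamp(adv - images, -eps, eps), 0.0, 1.0)` pattern used in the attack code.

```python
def project_linf(x_adv, x, eps):
    """Clamp x_adv into the L_inf ball of radius eps around x, intersected with [0, 1].

    Equivalent to clamp(x + clamp(x_adv - x, -eps, eps), 0, 1) when x lies in [0, 1].
    """
    out = []
    for a, b in zip(x_adv, x):
        lo = max(0.0, b - eps)  # lower edge of the feasible interval
        hi = min(1.0, b + eps)  # upper edge of the feasible interval
        out.append(min(max(a, lo), hi))
    return out


def linf_dist(u, v):
    """||u - v||_inf for two flat lists of equal length."""
    return max(abs(a - b) for a, b in zip(u, v))


eps = 8 / 255  # example value only (assumption)
x = [0.80, 0.02, 0.50]
x_adv = project_linf([0.95, -0.10, 0.51], x, eps)
print(x_adv, linf_dist(x_adv, x) <= eps + 1e-12)
```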
Editable Interface
You must implement:
run_attack(model, images, labels, eps, n_queries, device, n_classes) -> adv_images
Inputs:
- model: black-box wrapper that returns logits only.
- images: tensor of shape (N, C, H, W), values in [0, 1].
- labels: tensor of shape (N,).
- n_classes: 10 for CIFAR-10, 100 for CIFAR-100.
Output:
- adv_images: tensor with the same shape as images, values in [0, 1].
Trusted Evaluation Logic
The evaluation logic in bench/run_eval.py is trusted and not editable.
- It tracks all model queries through a wrapper.
- If a batch exceeds its query budget (batch_size * n_queries), the entire batch is marked as an attack failure.
- L_inf and [0, 1] validity are checked per sample; only invalid samples are marked as attack failures.
Do not attempt to bypass wrapper behavior or evaluation logic. Only improve the attack algorithm in custom_attack.py.
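bench/run_eval.py is not shown in this log, so the following is only a hypothetical sketch of the two checks described above. A batch-level budget check fails every sample at once. A per-sample check then fails only samples that leave the eps-ball or the [0, 1] box. `check_batch` and the numerical tolerance `tol` are assumptions made for illustration. Misclassification, which decides actual success, is not modeled here.

```python
def check_batch(x, x_adv, queries_used, n_queries, eps, tol=1e-6):
    """Hypothetical validity checks; x and x_adv are lists of flat pixel lists.

    Returns one boolean per sample: True if the sample is still eligible to
    count as a success (the misclassification check is done elsewhere).
    """
    batch_size = len(x)
    if queries_used > batch_size * n_queries:
        return [False] * batch_size  # budget overrun: the whole batch fails
    ok = []
    for xs, xa in zip(x, x_adv):
        in_ball = all(abs(a - b) <= eps + tol for a, b in zip(xa, xs))
        in_box = all(-tol <= a <= 1.0 + tol for a in xa)
        ok.append(in_ball and in_box)  # only invalid samples are failed
    return ok


x = [[0.5, 0.5], [0.2, 0.9]]
x_adv = [[0.53, 0.47], [0.2, 1.05]]
print(check_batch(x, x_adv, queries_used=20, n_queries=10, eps=0.031))  # [True, False]
print(check_batch(x, x_adv, queries_used=21, n_queries=10, eps=0.031))  # [False, False]
```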
Query Semantics
- One call to model(x) consumes x.shape[0] queries.
- Repeated calls on the same sample still consume additional queries.
- Different batch partitioning should be treated as equivalent total budget usage.
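These rules amount to charging one query per sample per call. The hypothetical stand-in below for the counting wrapper (the real wrapper is not shown in this log) shows why splitting a batch does not change the total and why repeated samples are still charged. `QueryCounter` and `toy_logits` are illustrative names only.

```python
class QueryCounter:
    """Hypothetical query-tracking wrapper: each call on k samples costs k queries."""

    def __init__(self, fn):
        self.fn = fn
        self.queries = 0

    def __call__(self, batch):
        self.queries += len(batch)  # mirrors "model(x) consumes x.shape[0] queries"
        return [self.fn(x) for x in batch]


def toy_logits(x):
    """Toy two-class 'model' on scalar inputs."""
    return [x, -x]


a = QueryCounter(toy_logits)
a(list(range(10)))      # one call on 10 samples

b = QueryCounter(toy_logits)
b(list(range(5)))       # the same 10 samples split into two calls
b(list(range(5, 10)))

c = QueryCounter(toy_logits)
c([3, 3, 3])            # the same sample three times is still 3 queries

print(a.queries, b.queries, c.queries)  # 10 10 3
```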
Performance Hint
The benchmark is query-heavy. Prefer batch-parallel score updates instead of pure per-sample loops when possible.
Evaluation Scenarios (6)
- ResNet20 on CIFAR-10
- VGG11-BN on CIFAR-10
- MobileNetV2 on CIFAR-10
- ResNet20 on CIFAR-100
- VGG11-BN on CIFAR-100
- MobileNetV2 on CIFAR-100
Reported metrics line format:
ATTACK_METRICS asr=... clean_acc=... robust_acc=... avg_queries=...
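When collecting results across the six scenarios, the metrics line can be parsed into a dictionary. A minimal sketch follows; `parse_attack_metrics` is a hypothetical helper, and it assumes the fields are space-separated `key=value` pairs exactly as in the format above. The sample values are the ResNet20-C10 numbers from Test #1 in this log.

```python
def parse_attack_metrics(line):
    """Parse 'ATTACK_METRICS asr=... clean_acc=... robust_acc=... avg_queries=...'."""
    tag, *fields = line.split()
    if tag != "ATTACK_METRICS":
        raise ValueError("not an ATTACK_METRICS line")
    # Each remaining token is key=value; values are numeric.
    return {k: float(v) for k, v in (f.split("=", 1) for f in fields)}


line = "ATTACK_METRICS asr=0.9950 clean_acc=1.0000 robust_acc=0.0050 avg_queries=279.32"
print(parse_attack_metrics(line))
```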
torchattacks/bench/custom_attack.py [EDITABLE — lines 7–56 only]
1: import torch
2: import torch.nn as nn
3:
4: # =====================================================================
5: # EDITABLE: implement run_attack below
6: # =====================================================================
7: def run_attack(
8:     model: nn.Module,
9:     images: torch.Tensor,
10:     labels: torch.Tensor,
11:     eps: float,
12:     n_queries: int,
13:     device: torch.device,
14:     n_classes: int,
15: ) -> torch.Tensor:
16:     """
17:     Score-based query black-box attack under Linf constraint.
18:
19:     Args:
20:         model: black-box wrapper that only exposes forward logits.
21:
Assistant · Step 1 · edit
Tool Input
{
  "op": "replace",
  "filename": "torchattacks/bench/custom_attack.py",
  "content": "def run_attack(\n model: nn.Module,\n images: torch.Tensor,\n labels: torch.Tensor,\n eps: float,\n n_queries: int,\n device: torch.device,\n n_classes: int,\n) -> torch.Tensor:\n import os, sys\n sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\")))\n import torchattacks\n\n model.eval()\n q = max(1, int(n_queries))\n if q < 8:\n atk = torchattacks.Square(model, norm=\"Linf\", eps=eps, n_queries=q, n_restarts=1,\n p_init=0.8, seed=0, verbose=False, loss=\"margin\", resc_schedule=True)\n adv = atk(images, labels)\n delta = torch.clamp(adv - images, -eps, eps)\n return torch.clamp(images + delta, 0.0, 1.0).detach()\n\n n, c, h, w = images.shape\n budget = n * q\n q1 = max(1, min(q - 2, int(0.6 * q)))\n topk = min(4, max(1, n_classes - 1))\n\n def score(z, y):\n a = torch.arange(z.size(0), device=z.device)\n zz = z.clone()\n zz[a, y] = -1e9\n return z[a, y] - torch.logsumexp(zz.topk(topk, 1).values, 1)\n\n atk = torchattacks.Square(model, norm=\"Linf\", eps=eps, n_queries=q1, n_restarts=1,\n p_init=0.8, seed=0, verbose=False, loss=\"ce\", resc_schedule=True)\n with torch.no_grad():\n adv = atk(images, labels)\n used, delta = n * q1, adv - images\n logits = model(adv)\n used += n\n scr, pred = score(logits, labels), logits.argmax(1)\n\n while used < budget and pred.eq(labels).any():\n idx = pred.eq(labels).nonzero(as_tuple=False).squeeze(1)\n rem = budget - used\n if rem <= 0:\n break\n two = rem >= 2 * idx.numel()\n cap = max(1, rem // (2 if two else 1))\n if idx.numel() > cap:\n idx = idx[scr[idx].argsort()[:cap]]\n m = idx.numel()\n p = max(0.02, 0.8 * (1.0 - used / budget) ** 2)\n s = max(1, min(h, w, int(round((p * h * w) ** 0.5))))\n d1 = delta[idx].clone()\n d2 = delta[idx].clone() if two else None\n sg = eps * (torch.randint(0, 2, (m, c, 1, 1), device=device).to(images.dtype) * 2 - 1)\n xs = torch.randint(0, h - s + 1, (m,), device=device)\n ys = torch.randint(0, w - s + 1, (m,), device=device)\n for i in range(m):\n x0, y0 = xs[i].item(), ys[i].item()\n d1[i, :, x0:x0 + s, y0:y0 + s] = sg[i]\n if two:\n d2[i, :, x0:x0 + s, y0:y0 + s] = -sg[i]\n c1 = torch.clamp(images[idx] + d1, 0.0, 1.0)\n if two:\n c2 = torch.clamp(images[idx] + d2, 0.0, 1.0)\n z = model(torch.cat([c1, c2], 0)); used += 2 * m\n z1, z2 = z[:m], z[m:]\n s1, s2 = score(z1, labels[idx]), score(z2, labels[idx])\n p1, p2 = z1.argmax(1), z2.argmax(1)\n pick2 = ((p2 != labels[idx]) & (p1 == labels[idx])) | (((p2 != labels[idx]) == (p1 != labels[idx])) & (s2 < s1))\n cand = torch.where(pick2[:, None, None, None], c2, c1)\n ns, np = torch.where(pick2, s2, s1), torch.where(pick2, p2, p1)\n else:\n z1 = model(c1); used += m\n cand, ns, np = c1, score(z1, labels[idx]), z1.argmax(1)\n good = (np != labels[idx]) | (ns < scr[idx])\n if good.any():\n j = idx[good]\n adv[j], delta[j], scr[j], pred[j] = cand[good], cand[good] - images[j], ns[good], np[good]\n\n delta = torch.clamp(adv - images, -eps, eps)\n return torch.clamp(images + delta, 0.0, 1.0).detach()",
  "after_line": 0,
  "start_line": 7,
  "end_line": 56
}
Tool Result · Step 1
OK: Replaced lines 7..56 with 84 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–90.
[Current file: torchattacks/bench/custom_attack.py | editable: 7–90 | total: 94 lines]
4| # =====================================================================
5| # EDITABLE: implement run_attack below
6| # =====================================================================
7 def run_attack(
8     model: nn.Module,
9     images: torch.Tensor,
10     labels: torch.Tensor,
11     eps: float,
12     n_queries: int,
13     device: torch.device,
14     n_classes: int,
15 ) -> torch.Tensor:
16     import os, sys
17     sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
18     import torchattacks
19
20     model.eval()
21     q = max(1, int(n_queries))
22     if q < 8:
23         atk = torchattacks.Square(model, norm="Linf", eps=eps, n_queries=q, n_rest
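The Step 1 code runs torchattacks' Square attack (cross-entropy loss, p_init=0.8) for about 60% of the per-sample budget. It then spends the remainder on a refinement loop over samples that are still classified correctly. When the budget runs short, it keeps the samples with the lowest score. The plain-Python sketch below re-expresses three pieces of that loop for a single sample so they can be checked in isolation:

- `score`: the true-class logit minus a log-sum-exp over the top-k competing logits (Step 1 uses k = min(4, n_classes - 1)).
- The shrinking square side length.
- The rule that picks between the two antithetic proposals (+patch and -patch).

`square_side`, `pick_pair` and `accept` are hypothetical names for logic that is inline in the log. The list-based form is a simplification of the batched tensor code.

```python
import math


def score(logits, y, topk):
    """True-class logit minus log-sum-exp of the top-k competing logits.

    Lower is better for the attacker. A misclassified sample never scores
    above zero, because log-sum-exp is at least the largest competing logit.
    """
    others = sorted((z for i, z in enumerate(logits) if i != y), reverse=True)[:topk]
    m = others[0]
    lse = m + math.log(sum(math.exp(z - m) for z in others))  # numerically stable
    return logits[y] - lse


def square_side(used, budget, h, w):
    """Square side length: area fraction p decays quadratically, floored at 0.02."""
    p = max(0.02, 0.8 * (1.0 - used / budget) ** 2)
    return max(1, min(h, w, int(round((p * h * w) ** 0.5))))


def pick_pair(s1, p1, s2, p2, y):
    """Choose between the +patch (s1, p1) and -patch (s2, p2) proposals.

    Prefer the proposal that flips the label; if both or neither do,
    prefer the lower score (same rule as `pick2` in the log).
    """
    fool1, fool2 = p1 != y, p2 != y
    take2 = (fool2 and not fool1) or (fool1 == fool2 and s2 < s1)
    return (s2, p2) if take2 else (s1, p1)


def accept(new_score, new_pred, cur_score, y):
    """Keep the chosen candidate if it fools the model or lowers the score."""
    return new_pred != y or new_score < cur_score


print(score([2.0, 1.0, 0.5], 0, 1))                            # 1.0
print([square_side(u, 1000, 32, 32) for u in (0, 500, 999)])   # [29, 14, 5]
print(pick_pair(0.3, 0, -0.1, 1, 0))                           # (-0.1, 1)
```

Evaluating both signs of the same square costs two queries per sample. This is why the log only uses the paired form when the remaining budget covers 2 queries per selected sample.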
Assistant · Step 2 · test
Tool Result · Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=279.32
### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9100, clean_acc=1.0000, robust_acc=0.0900, avg_queries=560.62
### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=251.03
### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=309.09
### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.8600, clean_acc=1.0000, robust_acc=0.1400, avg_queries=589.25
[Leaderboard] Results saved: {'asr_ResNet20_C10': 0.995, 'avg_queries_ResNet20_C10': 279.32, 'elapsed_ResNet20-C10': 167, 'asr_VGG11BN_C10': 0.91, 'avg_queries_
Assistant · Step 3 · edit
Tool Input
{
  "op": "replace",
  "filename": "torchattacks/bench/custom_attack.py",
  "content": "def run_attack(\n model: nn.Module,\n images: torch.Tensor,\n labels: torch.Tensor,\n eps: float,\n n_queries: int,\n device: torch.device,\n n_classes: int,\n) -> torch.Tensor:\n import os, sys\n sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\")))\n import torchattacks\n\n model.eval()\n q = max(1, int(n_queries))\n if q < 8:\n atk = torchattacks.Square(model, norm=\"Linf\", eps=eps, n_queries=q, n_restarts=1,\n p_init=0.8, seed=0, verbose=False, loss=\"margin\", resc_schedule=True)\n adv = atk(images, labels)\n delta = torch.clamp(adv - images, -eps, eps)\n return torch.clamp(images + delta, 0.0, 1.0).detach()\n\n n, c, h, w = images.shape\n budget = n * q\n topk = min(2 if n_classes <= 10 else 4, max(1, n_classes - 1))\n\n def qcount():\n names = (\"query_count\", \"queries\", \"num_queries\", \"n_queries\", \"total_queries\", \"count\")\n for obj in (model, getattr(model, \"model\", None)):\n if obj is None:\n continue\n for name in names:\n v = getattr(obj, name, None)\n if isinstance(v, (int, float)):\n return int(v)\n if torch.is_tensor(v) and v.numel() == 1:\n return int(v.item())\n return None\n\n def proj(x, base):\n return torch.clamp(base + torch.clamp(x - base, -eps, eps), 0.0, 1.0)\n\n def score(z, y):\n a = torch.arange(z.size(0), device=z.device)\n zz = z.clone()\n zz[a, y] = -1e9\n return z[a, y] - torch.logsumexp(zz.topk(topk, 1).values, 1)\n\n def eval_logits(x):\n q0 = qcount()\n z = model(x)\n q1 = qcount()\n spent = (q1 - q0) if (q0 is not None and q1 is not None and q1 >= q0) else x.size(0)\n return z, spent\n\n def run_square(x, y, qx, loss, p_init, seed):\n q0 = qcount()\n atk = torchattacks.Square(model=model, norm=\"Linf\", eps=eps, n_queries=int(qx), n_restarts=1,\n p_init=p_init, seed=seed, verbose=False, loss=loss, resc_schedule=True)\n out = proj(atk(x, y), x)\n q1 = qcount()\n spent = (q1 - q0) if (q0 is not None and q1 is not None and q1 >= q0) else x.size(0) * int(qx)\n return out, spent\n\n with torch.no_grad():\n used = 0\n q1 = max(1, min(q - 5, int(0.36 * q)))\n adv, spent = run_square(images, labels, q1, \"margin\", 0.8, 0)\n used += spent\n logits, spent = eval_logits(adv)\n used += spent\n scr, pred = score(logits, labels), logits.argmax(1)\n\n for share, loss, p_init, seed in ((0.28, \"ce\", 0.25, 1), (0.18, \"margin\", 0.08, 2)):\n idx = pred.eq(labels).nonzero(as_tuple=False).squeeze(1)\n if idx.numel() == 0 or used >= budget:\n break\n left = budget - used\n m = idx.numel()\n qx = min(int(share * q), max(0, left // m - 2))\n if qx < 1:\n continue\n cand, spent = run_square(images[idx], labels[idx], qx, loss, p_init, seed)\n used += spent\n logits2, spent = eval_logits(cand)\n used += spent\n s2, p2 = score(logits2, labels[idx]), logits2.argmax(1)\n take = ((p2 != labels[idx]) & (pred[idx] == labels[idx])) | (((p2 != labels[idx]) == (pred[idx] != labels[idx])) & (s2 < scr[idx]))\n if take.any():\n j = idx[take]\n adv[j], scr[j], pred[j] = cand[take], s2[take], p2[take]\n\n delta = adv - images\n while used < budget and pred.eq(labels).any():\n idx = pred.eq(labels).nonzero(as_tuple=False).squeeze(1)\n left = budget - used\n if left <= 0:\n break\n two = left >= 2 * idx.numel()\n need = 2 if two else 1\n cap = max(1, left // need)\n if idx.numel() > cap:\n idx = idx[scr[idx].argsort()[:cap]]\n m = idx.numel()\n if m == 0:\n break\n t = max(0.0, 1.0 - used / max(1, budget))\n p = 0.04 + 0.55 * t * t\n s = max(1, min(h, w, int(round((p * h * w) ** 0.5))))\n d1 = delta[idx].clone()\n d2 = delta[idx].clone() if two else None\n textured = s <= 6\n xs = torch.randint(0, h - s + 1, (m,), device=device)\n ys = torch.randint(0, w - s + 1, (m,), device=device)\n for i in range(m):\n if textured:\n patch = eps * (torch.randint(0, 2, (c, s, s), device=device).to(images.dtype) * 2 - 1)\n else:\n patch = eps * (torch.randint(0, 2, (c, 1, 1), device=device).to(images.dtype) * 2 - 1)\n x0, y0 = xs[i].item(), ys[i].item()\n d1[i, :, x0:x0 + s, y0:y0 + s] = patch\n if two:\n d2[i, :, x0:x0 + s, y0:y0 + s] = -patch\n c1 = torch.clamp(images[idx] + torch.clamp(d1, -eps, eps), 0.0, 1.0)\n if two:\n c2 = torch.clamp(images[idx] + torch.clamp(d2, -eps, eps), 0.0, 1.0)\n z, spent = eval_logits(torch.cat([c1, c2], 0))\n used += spent\n z1, z2 = z[:m], z[m:]\n s1, s2 = score(z1, labels[idx]), score(z2, labels[idx])\n p1, p2 = z1.argmax(1), z2.argmax(1)\n pick2 = ((p2 != labels[idx]) & (p1 == labels[idx])) | (((p2 != labels[idx]) == (p1 != labels[idx])) & (s2 < s1))\n cand = torch.where(pick2[:, None, None, None], c2, c1)\n ns, np = torch.where(pick2, s2, s1), torch.where(pick2, p2, p1)\n else:\n z1, spent = eval_logits(c1)\n used += spent\n cand, ns, np = c1, score(z1, labels[idx]), z1.argmax(1)\n good = (np != labels[idx]) | (ns < scr[idx])\n if good.any():\n j = idx[good]\n adv[j], delta[j], scr[j], pred[j] = cand[good], cand[good] - images[j], ns[good], np[good]\n\n delta = torch.clamp(adv - images, -eps, eps)\n return torch.clamp(images + delta, 0.0, 1.0).detach()",
  "after_line": 0,
  "start_line": 7,
  "end_line": 90
}
Tool Result · Step 3
OK: Replaced lines 7..90 with 145 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–151.
[Current file: torchattacks/bench/custom_attack.py | editable: 7–151 | total: 155 lines]
4| # =====================================================================
5| # EDITABLE: implement run_attack below
6| # =====================================================================
7 def run_attack(
8     model: nn.Module,
9     images: torch.Tensor,
10     labels: torch.Tensor,
11     eps: float,
12     n_queries: int,
13     device: torch.device,
14     n_classes: int,
15 ) -> torch.Tensor:
16     import os, sys
17     sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
18     import torchattacks
19
20     model.eval()
21     q = max(1, int(n_queries))
22     if q < 8:
23         atk = torchattacks.Square(model, norm="Linf", eps=eps, n_queries=q, n_
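The Step 3 code splits the budget into three Square stages before the refinement loop:

- A margin-loss run with p_init=0.8 for about 36% of q on all samples.
- A cross-entropy run (p_init=0.25, up to 28% of q) on the samples still classified correctly.
- A margin-loss run (p_init=0.08, up to 18% of q) on the samples that remain correct after that.

Each follow-up stage is capped by the remaining budget per survivor minus two. The refinement loop then uses the area fraction p = 0.04 + 0.55 t^2 with t = 1 - used/budget, and switches to per-pixel random-sign ("textured") patches once the side length is at most 6.

The plain-Python sketch below reproduces that arithmetic. The function names are hypothetical. The worked example assumes that every stage-1 query is charged in full, which is the fallback the log uses when no counter attribute is found on the wrapper.

```python
import random


def stage1_queries(q):
    """Per-sample Square budget for the first stage (~36% of q, leaving >= 5)."""
    return max(1, min(q - 5, int(0.36 * q)))


def followup_queries(share, q, left, m):
    """Per-sample budget for a follow-up Square run on m survivors.

    Capped by share * q and by the per-survivor remainder minus 2;
    a result below 1 means the stage is skipped.
    """
    return min(int(share * q), max(0, left // m - 2))


def area_fraction(used, budget):
    """Refinement patch area fraction: 0.04 floor plus a decaying 0.55 * t^2."""
    t = max(0.0, 1.0 - used / max(1, budget))
    return 0.04 + 0.55 * t * t


def make_patch(c, s, eps, rng):
    """Random-sign patch of shape (c, s, s) as nested lists.

    Small squares (s <= 6) get an independent sign per pixel ("textured");
    larger squares get one sign per channel.
    """
    if s <= 6:
        return [[[eps * rng.choice((-1, 1)) for _ in range(s)] for _ in range(s)]
                for _ in range(c)]
    signs = [eps * rng.choice((-1, 1)) for _ in range(c)]
    return [[[signs[k]] * s for _ in range(s)] for k in range(c)]


# Worked example: q = 1000 queries per sample, a batch of 100 images.
q, n = 1000, 100
used = n * stage1_queries(q) + n  # stage 1 plus one scoring pass (assumed fully charged)
print(stage1_queries(q), used)                       # 360 36100
print(followup_queries(0.28, q, n * q - used, 10))   # 280 (10 survivors)
```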
Assistant · Step 4 · test
Tool Result · Step 4
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=418.08
### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.8600, clean_acc=1.0000, robust_acc=0.1400, avg_queries=766.00
### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=628.02
### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=582.66
### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.8150, clean_acc=1.0000, robust_acc=0.1850, avg_queries=783.15
[Leaderboard] Results saved: {'asr_ResNet20_C10': 0.995, 'avg_queries_ResNet20_C10': 418.08, 'elapsed_ResNet20-C10': 76, 'asr_VGG11BN_C10': 0.86, 'avg_queries_VG
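The two test runs can be compared directly from the reported numbers. The short script below copies the values from the Test #1 and Test #2 outputs above. It includes only the five scenarios printed in both runs, because MobileNetV2-C100 is cut off in these excerpts. It computes per-scenario ASR differences and the mean ASR and avg_queries. Test #1 is ahead on mean ASR (0.949 vs 0.931) and needs fewer queries on average (about 398 vs 636). Test #2 wins only on ResNet20-C100 ASR (0.995 vs 0.990).

```python
# (ASR, avg_queries) copied from the Test #1 and Test #2 outputs in this log.
test1 = {
    "ResNet20-C10": (0.9950, 279.32),
    "VGG11BN-C10": (0.9100, 560.62),
    "MobileNetV2-C10": (0.9900, 251.03),
    "ResNet20-C100": (0.9900, 309.09),
    "VGG11BN-C100": (0.8600, 589.25),
}
test2 = {
    "ResNet20-C10": (0.9950, 418.08),
    "VGG11BN-C10": (0.8600, 766.00),
    "MobileNetV2-C10": (0.9900, 628.02),
    "ResNet20-C100": (0.9950, 582.66),
    "VGG11BN-C100": (0.8150, 783.15),
}


def means(results):
    """Mean ASR and mean avg_queries over a dict of (asr, avg_queries) pairs."""
    asr = sum(a for a, _ in results.values()) / len(results)
    queries = sum(q for _, q in results.values()) / len(results)
    return asr, queries


for name in test1:
    d_asr = test1[name][0] - test2[name][0]  # positive: Test #1 is better
    print(f"{name:16s} dASR={d_asr:+.3f}")

asr1, q1 = means(test1)
asr2, q2 = means(test2)
print(f"test1: ASR={asr1:.3f} queries={q1:.1f}")  # test1: ASR=0.949 queries=397.9
print(f"test2: ASR={asr2:.3f} queries={q2:.1f}")  # test2: ASR=0.931 queries=635.6
```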
Assistant · Step 5 · submit
Tool Input
{
  "n": 1
}
Tool Result · Step 5
[submit] Submitting result from test #1 as final.
### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=279.32
### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9100, clean_acc=1.0000, robust_acc=0.0900, avg_queries=560.62
### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=251.03
### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=309.09
### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.8600, clean_acc=1.0000, robust_acc=0.1400, avg_queries=589.25
### MobileNetV2-C100 (scripts/eval_mobilenetv2_cifar100.sh)
Attack results (MobileNetV2-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=315.85
[Leaderboard] Re