Agent Conversation: gpt-5.4-pro

Model: gpt-5.4-pro
Total steps: 5
Messages: 11
User Step 0

Task: security-adversarial-attack-black-box-score

Score-Based Query Black-Box Attack under Linf Constraint

Research Question

Can you design a stronger score-based query black-box attack that improves attack success rate (ASR) under a fixed query budget and an L_inf perturbation constraint?

Objective

Implement a better black-box attack in bench/custom_attack.py:

  • Threat model: query black-box (no gradient access).
  • Constraint: ||x_adv - x||_inf <= eps.
  • Budget: n_queries is a per-sample query budget.
  • Primary metric: maximize ASR under fixed budget.
  • Tie-break: for similar ASR, lower avg_queries is better.

Editable Interface

You must implement:

run_attack(model, images, labels, eps, n_queries, device, n_classes) -> adv_images

Inputs:

  • model: black-box wrapper that returns logits only.
  • images: tensor of shape (N, C, H, W), in [0, 1].
  • labels: tensor of shape (N,).
  • n_classes: 10 for CIFAR-10, 100 for CIFAR-100.

Output:

  • adv_images: tensor with same shape as images, values in [0, 1].
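Together, the constraint and the output requirement amount to a projection applied before returning: clip the perturbation to [-eps, eps], then clip the image to [0, 1]. Below is a minimal sketch. The helper name `project_linf` is illustrative and not part of the benchmark. It uses the same two-clamp pattern as the return path of the agent's code later in this transcript.

```python
import torch

def project_linf(x_adv: torch.Tensor, x: torch.Tensor, eps: float) -> torch.Tensor:
    """Project x_adv onto the L_inf ball of radius eps around x, then onto [0, 1]."""
    # Clip the perturbation coordinate-wise to [-eps, eps] ...
    delta = torch.clamp(x_adv - x, -eps, eps)
    # ... then clip to the valid pixel range. Since x itself lies in [0, 1],
    # this second clamp can only shrink |delta|, so the L_inf bound still holds.
    return torch.clamp(x + delta, 0.0, 1.0)
```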

Trusted Evaluation Logic

The evaluation logic in bench/run_eval.py is trusted and not editable.

  • It tracks all model queries through a wrapper.
  • If a batch exceeds its query budget (batch_size * n_queries), every sample in that batch is marked as an attack failure.
  • L_inf and [0, 1] validity are checked per sample; only the invalid samples are marked as attack failures.
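The following hypothetical restatement of the two failure rules can help when checking an attack locally. It is not the trusted run_eval.py. The function name and the numerical tolerance `tol` are assumptions, since the actual tolerance is not stated.

```python
import torch

def apply_eval_rules(x, x_adv, fooled, queries_used, n_queries, eps, tol=1e-6):
    """Hypothetical sketch of the rules above; NOT the actual run_eval.py.

    fooled: bool tensor (N,), True where the model misclassifies x_adv.
    tol: assumed numerical tolerance for the L_inf check.
    """
    n = x.shape[0]
    # Rule 1: exceeding batch_size * n_queries fails the whole batch.
    if queries_used > n * n_queries:
        return torch.zeros(n, dtype=torch.bool, device=x.device)
    # Rule 2: per-sample validity; only invalid samples become failures.
    flat = (x_adv - x).reshape(n, -1)
    in_ball = flat.abs().amax(dim=1) <= eps + tol
    in_range = ((x_adv >= 0) & (x_adv <= 1)).reshape(n, -1).all(dim=1)
    return fooled & in_ball & in_range
```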

Do not attempt to bypass wrapper behavior or evaluation logic. Only improve the attack algorithm in custom_attack.py.

Query Semantics

  • One call to model(x) consumes x.shape[0] queries.
  • Repeated calls on the same sample still consume additional queries.
  • Different batch partitionings are treated as equivalent: splitting the same queries across more or fewer model calls does not change total budget usage.
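A hypothetical counting wrapper illustrates these semantics. The real wrapper in run_eval.py is not shown, so the class and its `queries` attribute are assumptions. Each call adds `x.shape[0]`, whether or not rows repeat. One call on 6 rows therefore costs the same as three calls on 4, 1 and 1 rows.

```python
import torch
import torch.nn as nn

class QueryCountingModel(nn.Module):
    """Hypothetical wrapper that charges one query per input row."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.queries = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every row costs one query, even if the same sample was queried before.
        self.queries += x.shape[0]
        return self.model(x)
```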

Performance Hint

The benchmark is query-heavy. Prefer batch-parallel score updates instead of pure per-sample loops when possible.
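In line with this hint, the per-sample Python loops in the agent's code (each writes one square patch per sample) could be replaced by a broadcast mask. The sketch below assumes square patches with one random sign per channel, as in the step-1 code. `place_square_patches` is a hypothetical helper.

```python
import torch

def place_square_patches(delta, s, eps, generator=None):
    """Overwrite one random s-by-s square per sample with +/-eps, with no Python loop."""
    m, c, h, w = delta.shape
    dev = delta.device
    # One random top-left corner per sample.
    xs = torch.randint(0, h - s + 1, (m,), device=dev, generator=generator)
    ys = torch.randint(0, w - s + 1, (m,), device=dev, generator=generator)
    rows = torch.arange(h, device=dev).view(1, h, 1)
    cols = torch.arange(w, device=dev).view(1, 1, w)
    # mask[i, r, q] is True inside sample i's square; shape (m, h, w).
    mask = ((rows >= xs.view(m, 1, 1)) & (rows < xs.view(m, 1, 1) + s) &
            (cols >= ys.view(m, 1, 1)) & (cols < ys.view(m, 1, 1) + s))
    # One random sign per (sample, channel), constant over the square.
    signs = torch.randint(0, 2, (m, c, 1, 1), device=dev, generator=generator)
    values = eps * (signs.to(delta.dtype) * 2 - 1)
    # Broadcast (m,1,h,w) mask and (m,c,1,1) values against (m,c,h,w) delta.
    return torch.where(mask.unsqueeze(1), values, delta)
```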

Evaluation Scenarios (6)

  • ResNet20 on CIFAR-10
  • VGG11-BN on CIFAR-10
  • MobileNetV2 on CIFAR-10
  • ResNet20 on CIFAR-100
  • VGG11-BN on CIFAR-100
  • MobileNetV2 on CIFAR-100

Reported metrics line format:

ATTACK_METRICS asr=... clean_acc=... robust_acc=... avg_queries=...
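Below is a small parser for this line. It assumes whitespace-separated `key=value` pairs as in the format above. `parse_attack_metrics` is hypothetical. The test line reuses the ResNet20-C10 numbers from Test #1, rewritten in this format for illustration.

```python
import re

def parse_attack_metrics(line: str) -> dict:
    """Parse an 'ATTACK_METRICS key=value ...' line into a dict of floats."""
    if not line.startswith("ATTACK_METRICS"):
        raise ValueError("not an ATTACK_METRICS line")
    # Each pair is a word, '=', then a non-whitespace value.
    return {k: float(v) for k, v in re.findall(r"(\w+)=(\S+)", line)}
```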

torchattacks/bench/custom_attack.py [EDITABLE — lines 7–56 only]

     1: import torch
     2: import torch.nn as nn
     3: 
     4: # =====================================================================
     5: # EDITABLE: implement run_attack below
     6: # =====================================================================
     7: def run_attack(
     8:     model: nn.Module,
     9:     images: torch.Tensor,
    10:     labels: torch.Tensor,
    11:     eps: float,
    12:     n_queries: int,
    13:     device: torch.device,
    14:     n_classes: int,
    15: ) -> torch.Tensor:
    16:     """
    17:     Score-based query black-box attack under Linf constraint.
    18: 
    19:     Args:
    20:         model: black-box wrapper that only exposes forward logits.
    21:     
Assistant Step 1: edit
Tool Input
1{
2  "op": "replace",
3  "filename": "torchattacks/bench/custom_attack.py",
4  "content": "def run_attack(\n    model: nn.Module,\n    images: torch.Tensor,\n    labels: torch.Tensor,\n    eps: float,\n    n_queries: int,\n    device: torch.device,\n    n_classes: int,\n) -> torch.Tensor:\n    import os, sys\n    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\")))\n    import torchattacks\n\n    model.eval()\n    q = max(1, int(n_queries))\n    if q < 8:\n        atk = torchattacks.Square(model, norm=\"Linf\", eps=eps, n_queries=q, n_restarts=1,\n                                  p_init=0.8, seed=0, verbose=False, loss=\"margin\", resc_schedule=True)\n        adv = atk(images, labels)\n        delta = torch.clamp(adv - images, -eps, eps)\n        return torch.clamp(images + delta, 0.0, 1.0).detach()\n\n    n, c, h, w = images.shape\n    budget = n * q\n    q1 = max(1, min(q - 2, int(0.6 * q)))\n    topk = min(4, max(1, n_classes - 1))\n\n    def score(z, y):\n        a = torch.arange(z.size(0), device=z.device)\n        zz = z.clone()\n        zz[a, y] = -1e9\n        return z[a, y] - torch.logsumexp(zz.topk(topk, 1).values, 1)\n\n    atk = torchattacks.Square(model, norm=\"Linf\", eps=eps, n_queries=q1, n_restarts=1,\n                              p_init=0.8, seed=0, verbose=False, loss=\"ce\", resc_schedule=True)\n    with torch.no_grad():\n        adv = atk(images, labels)\n        used, delta = n * q1, adv - images\n        logits = model(adv)\n        used += n\n        scr, pred = score(logits, labels), logits.argmax(1)\n\n        while used < budget and pred.eq(labels).any():\n            idx = pred.eq(labels).nonzero(as_tuple=False).squeeze(1)\n            rem = budget - used\n            if rem <= 0:\n                break\n            two = rem >= 2 * idx.numel()\n            cap = max(1, rem // (2 if two else 1))\n            if idx.numel() > cap:\n                idx = idx[scr[idx].argsort()[:cap]]\n            m = idx.numel()\n            p = max(0.02, 0.8 * (1.0 - used / budget) ** 2)\n        
    s = max(1, min(h, w, int(round((p * h * w) ** 0.5))))\n            d1 = delta[idx].clone()\n            d2 = delta[idx].clone() if two else None\n            sg = eps * (torch.randint(0, 2, (m, c, 1, 1), device=device).to(images.dtype) * 2 - 1)\n            xs = torch.randint(0, h - s + 1, (m,), device=device)\n            ys = torch.randint(0, w - s + 1, (m,), device=device)\n            for i in range(m):\n                x0, y0 = xs[i].item(), ys[i].item()\n                d1[i, :, x0:x0 + s, y0:y0 + s] = sg[i]\n                if two:\n                    d2[i, :, x0:x0 + s, y0:y0 + s] = -sg[i]\n            c1 = torch.clamp(images[idx] + d1, 0.0, 1.0)\n            if two:\n                c2 = torch.clamp(images[idx] + d2, 0.0, 1.0)\n                z = model(torch.cat([c1, c2], 0)); used += 2 * m\n                z1, z2 = z[:m], z[m:]\n                s1, s2 = score(z1, labels[idx]), score(z2, labels[idx])\n                p1, p2 = z1.argmax(1), z2.argmax(1)\n                pick2 = ((p2 != labels[idx]) & (p1 == labels[idx])) | (((p2 != labels[idx]) == (p1 != labels[idx])) & (s2 < s1))\n                cand = torch.where(pick2[:, None, None, None], c2, c1)\n                ns, np = torch.where(pick2, s2, s1), torch.where(pick2, p2, p1)\n            else:\n                z1 = model(c1); used += m\n                cand, ns, np = c1, score(z1, labels[idx]), z1.argmax(1)\n            good = (np != labels[idx]) | (ns < scr[idx])\n            if good.any():\n                j = idx[good]\n                adv[j], delta[j], scr[j], pred[j] = cand[good], cand[good] - images[j], ns[good], np[good]\n\n        delta = torch.clamp(adv - images, -eps, eps)\n        return torch.clamp(images + delta, 0.0, 1.0).detach()",
5  "after_line": 0,
6  "start_line": 7,
7  "end_line": 56
8}
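The `score` closure in the step-1 code is a soft margin: the true-class logit minus the logsumexp of the `topk` largest wrong-class logits. Lower is better for the attacker. Because logsumexp is at least the max, any misclassified input gets a negative score, although a negative score alone does not imply misclassification. An unrolled version for reference (the name `soft_margin_score` is illustrative):

```python
import torch

def soft_margin_score(logits: torch.Tensor, labels: torch.Tensor, topk: int) -> torch.Tensor:
    """True-class logit minus a smooth max over the top-k wrong-class logits."""
    rows = torch.arange(logits.size(0), device=logits.device)
    others = logits.clone()
    others[rows, labels] = -1e9  # exclude the true class, as the step-1 code does
    # logsumexp over the top-k wrong classes is a smooth upper bound on their max.
    runner_up = torch.logsumexp(others.topk(topk, dim=1).values, dim=1)
    return logits[rows, labels] - runner_up
```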
Tool Result Step 1
OK: Replaced lines 7..56 with 84 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–90.

[Current file: torchattacks/bench/custom_attack.py | editable: 7–90 | total: 94 lines]
     4| # =====================================================================
     5| # EDITABLE: implement run_attack below
     6| # =====================================================================
     7  def run_attack(
     8      model: nn.Module,
     9      images: torch.Tensor,
    10      labels: torch.Tensor,
    11      eps: float,
    12      n_queries: int,
    13      device: torch.device,
    14      n_classes: int,
    15  ) -> torch.Tensor:
    16      import os, sys
    17      sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    18      import torchattacks
    19  
    20      model.eval()
    21      q = max(1, int(n_queries))
    22      if q < 8:
    23          atk = torchattacks.Square(model, norm="Linf", eps=eps, n_queries=q, n_rest
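In the step-1 refinement loop, the fraction p of pixels covered by each new square decays quadratically with the share of the batch budget already used, with a floor of 0.02. The side length is round(sqrt(p*h*w)). The step-1 code counts n*q1 queries for the initial torchattacks.Square run (q1 = int(0.6*q) per sample) plus one scoring query per sample. The loop therefore starts with p near 0.8*0.4^2 = 0.128. The sketch below reproduces that schedule for 32x32 CIFAR inputs; `square_side` is a hypothetical name.

```python
def square_side(used: int, budget: int, h: int = 32, w: int = 32) -> int:
    """Side of the square resampled at this point of the batch budget (step-1 schedule)."""
    # Fraction of pixels to perturb, shrinking as the batch budget is consumed.
    p = max(0.02, 0.8 * (1.0 - used / budget) ** 2)
    return max(1, min(h, w, int(round((p * h * w) ** 0.5))))
```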
Assistant Step 2: test
Tool Result Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=279.32

### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9100, clean_acc=1.0000, robust_acc=0.0900, avg_queries=560.62

### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=251.03

### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=309.09

### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.8600, clean_acc=1.0000, robust_acc=0.1400, avg_queries=589.25

[Leaderboard] Results saved: {'asr_ResNet20_C10': 0.995, 'avg_queries_ResNet20_C10': 279.32, 'elapsed_ResNet20-C10': 167, 'asr_VGG11BN_C10': 0.91, 'avg_queries_
Assistant Step 3: edit
Tool Input
1{
2  "op": "replace",
3  "filename": "torchattacks/bench/custom_attack.py",
4  "content": "def run_attack(\n    model: nn.Module,\n    images: torch.Tensor,\n    labels: torch.Tensor,\n    eps: float,\n    n_queries: int,\n    device: torch.device,\n    n_classes: int,\n) -> torch.Tensor:\n    import os, sys\n    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\")))\n    import torchattacks\n\n    model.eval()\n    q = max(1, int(n_queries))\n    if q < 8:\n        atk = torchattacks.Square(model, norm=\"Linf\", eps=eps, n_queries=q, n_restarts=1,\n                                  p_init=0.8, seed=0, verbose=False, loss=\"margin\", resc_schedule=True)\n        adv = atk(images, labels)\n        delta = torch.clamp(adv - images, -eps, eps)\n        return torch.clamp(images + delta, 0.0, 1.0).detach()\n\n    n, c, h, w = images.shape\n    budget = n * q\n    topk = min(2 if n_classes <= 10 else 4, max(1, n_classes - 1))\n\n    def qcount():\n        names = (\"query_count\", \"queries\", \"num_queries\", \"n_queries\", \"total_queries\", \"count\")\n        for obj in (model, getattr(model, \"model\", None)):\n            if obj is None:\n                continue\n            for name in names:\n                v = getattr(obj, name, None)\n                if isinstance(v, (int, float)):\n                    return int(v)\n                if torch.is_tensor(v) and v.numel() == 1:\n                    return int(v.item())\n        return None\n\n    def proj(x, base):\n        return torch.clamp(base + torch.clamp(x - base, -eps, eps), 0.0, 1.0)\n\n    def score(z, y):\n        a = torch.arange(z.size(0), device=z.device)\n        zz = z.clone()\n        zz[a, y] = -1e9\n        return z[a, y] - torch.logsumexp(zz.topk(topk, 1).values, 1)\n\n    def eval_logits(x):\n        q0 = qcount()\n        z = model(x)\n        q1 = qcount()\n        spent = (q1 - q0) if (q0 is not None and q1 is not None and q1 >= q0) else x.size(0)\n        return z, spent\n\n    def run_square(x, y, qx, loss, p_init, seed):\n        
q0 = qcount()\n        atk = torchattacks.Square(model=model, norm=\"Linf\", eps=eps, n_queries=int(qx), n_restarts=1,\n                                  p_init=p_init, seed=seed, verbose=False, loss=loss, resc_schedule=True)\n        out = proj(atk(x, y), x)\n        q1 = qcount()\n        spent = (q1 - q0) if (q0 is not None and q1 is not None and q1 >= q0) else x.size(0) * int(qx)\n        return out, spent\n\n    with torch.no_grad():\n        used = 0\n        q1 = max(1, min(q - 5, int(0.36 * q)))\n        adv, spent = run_square(images, labels, q1, \"margin\", 0.8, 0)\n        used += spent\n        logits, spent = eval_logits(adv)\n        used += spent\n        scr, pred = score(logits, labels), logits.argmax(1)\n\n        for share, loss, p_init, seed in ((0.28, \"ce\", 0.25, 1), (0.18, \"margin\", 0.08, 2)):\n            idx = pred.eq(labels).nonzero(as_tuple=False).squeeze(1)\n            if idx.numel() == 0 or used >= budget:\n                break\n            left = budget - used\n            m = idx.numel()\n            qx = min(int(share * q), max(0, left // m - 2))\n            if qx < 1:\n                continue\n            cand, spent = run_square(images[idx], labels[idx], qx, loss, p_init, seed)\n            used += spent\n            logits2, spent = eval_logits(cand)\n            used += spent\n            s2, p2 = score(logits2, labels[idx]), logits2.argmax(1)\n            take = ((p2 != labels[idx]) & (pred[idx] == labels[idx])) | (((p2 != labels[idx]) == (pred[idx] != labels[idx])) & (s2 < scr[idx]))\n            if take.any():\n                j = idx[take]\n                adv[j], scr[j], pred[j] = cand[take], s2[take], p2[take]\n\n        delta = adv - images\n        while used < budget and pred.eq(labels).any():\n            idx = pred.eq(labels).nonzero(as_tuple=False).squeeze(1)\n            left = budget - used\n            if left <= 0:\n                break\n            two = left >= 2 * idx.numel()\n            need = 2 if 
two else 1\n            cap = max(1, left // need)\n            if idx.numel() > cap:\n                idx = idx[scr[idx].argsort()[:cap]]\n            m = idx.numel()\n            if m == 0:\n                break\n            t = max(0.0, 1.0 - used / max(1, budget))\n            p = 0.04 + 0.55 * t * t\n            s = max(1, min(h, w, int(round((p * h * w) ** 0.5))))\n            d1 = delta[idx].clone()\n            d2 = delta[idx].clone() if two else None\n            textured = s <= 6\n            xs = torch.randint(0, h - s + 1, (m,), device=device)\n            ys = torch.randint(0, w - s + 1, (m,), device=device)\n            for i in range(m):\n                if textured:\n                    patch = eps * (torch.randint(0, 2, (c, s, s), device=device).to(images.dtype) * 2 - 1)\n                else:\n                    patch = eps * (torch.randint(0, 2, (c, 1, 1), device=device).to(images.dtype) * 2 - 1)\n                x0, y0 = xs[i].item(), ys[i].item()\n                d1[i, :, x0:x0 + s, y0:y0 + s] = patch\n                if two:\n                    d2[i, :, x0:x0 + s, y0:y0 + s] = -patch\n            c1 = torch.clamp(images[idx] + torch.clamp(d1, -eps, eps), 0.0, 1.0)\n            if two:\n                c2 = torch.clamp(images[idx] + torch.clamp(d2, -eps, eps), 0.0, 1.0)\n                z, spent = eval_logits(torch.cat([c1, c2], 0))\n                used += spent\n                z1, z2 = z[:m], z[m:]\n                s1, s2 = score(z1, labels[idx]), score(z2, labels[idx])\n                p1, p2 = z1.argmax(1), z2.argmax(1)\n                pick2 = ((p2 != labels[idx]) & (p1 == labels[idx])) | (((p2 != labels[idx]) == (p1 != labels[idx])) & (s2 < s1))\n                cand = torch.where(pick2[:, None, None, None], c2, c1)\n                ns, np = torch.where(pick2, s2, s1), torch.where(pick2, p2, p1)\n            else:\n                z1, spent = eval_logits(c1)\n                used += spent\n                cand, ns, np = c1, score(z1, 
labels[idx]), z1.argmax(1)\n            good = (np != labels[idx]) | (ns < scr[idx])\n            if good.any():\n                j = idx[good]\n                adv[j], delta[j], scr[j], pred[j] = cand[good], cand[good] - images[j], ns[good], np[good]\n\n        delta = torch.clamp(adv - images, -eps, eps)\n        return torch.clamp(images + delta, 0.0, 1.0).detach()",
5  "after_line": 0,
6  "start_line": 7,
7  "end_line": 90
8}
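Before its refinement loop, step 3 splits the per-sample budget q across three torchattacks.Square runs: 36% on all samples, then caps of 28% and 18% on the samples still correctly classified. Each run is followed by one scoring query per attacked sample. The worked example below rests on several assumptions: q = 1000, n = 100, 20 and 10 samples still correct before stages 2 and 3, and every run spending its full per-sample cap. The real code reads the spend from the wrapper when it can. `stage_budgets` is a hypothetical helper.

```python
def stage_budgets(q: int, n: int, still_correct: tuple) -> tuple:
    """Per-sample caps of the three Square stages and the budget left for refinement."""
    budget = n * q
    q1 = max(1, min(q - 5, int(0.36 * q)))
    used = n * q1 + n  # stage-1 run on all samples plus one scoring query each
    caps = [q1]
    for share, m in zip((0.28, 0.18), still_correct):
        if m == 0 or used >= budget:
            break
        left = budget - used
        qx = min(int(share * q), max(0, left // m - 2))
        if qx < 1:
            continue
        caps.append(qx)
        used += m * qx + m  # Square run on m samples plus re-scoring
    return caps, budget - used
```

Under these assumptions about 56% of the batch budget is still left for the refinement loop.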
Tool Result Step 3
OK: Replaced lines 7..90 with 145 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–151.

[Current file: torchattacks/bench/custom_attack.py | editable: 7–151 | total: 155 lines]
     4| # =====================================================================
     5| # EDITABLE: implement run_attack below
     6| # =====================================================================
     7  def run_attack(
     8      model: nn.Module,
     9      images: torch.Tensor,
    10      labels: torch.Tensor,
    11      eps: float,
    12      n_queries: int,
    13      device: torch.device,
    14      n_classes: int,
    15  ) -> torch.Tensor:
    16      import os, sys
    17      sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    18      import torchattacks
    19  
    20      model.eval()
    21      q = max(1, int(n_queries))
    22      if q < 8:
    23          atk = torchattacks.Square(model, norm="Linf", eps=eps, n_queries=q, n_
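When the remaining budget allows two queries per remaining sample, both refinement loops evaluate an antithetic pair: the same square written once with +patch and once with -patch. The rule below is a direct rewrite of the `pick2` expression and resolves each pair. It prefers the candidate that fools the model, and falls back to the lower score when both or neither do. The function name is hypothetical.

```python
import torch

def pick_antithetic(s1, s2, p1, p2, y):
    """Return (pick2 mask, chosen score, chosen prediction) for each +/- patch pair."""
    fool1, fool2 = p1 != y, p2 != y
    # Take the minus-patch candidate if only it fools the model, or if both/neither
    # fool it and it has the lower (better) score.
    pick2 = (fool2 & ~fool1) | ((fool1 == fool2) & (s2 < s1))
    return pick2, torch.where(pick2, s2, s1), torch.where(pick2, p2, p1)
```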
Assistant Step 4: test
Tool Result Step 4
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=418.08

### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.8600, clean_acc=1.0000, robust_acc=0.1400, avg_queries=766.00

### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=628.02

### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=582.66

### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.8150, clean_acc=1.0000, robust_acc=0.1850, avg_queries=783.15

[Leaderboard] Results saved: {'asr_ResNet20_C10': 0.995, 'avg_queries_ResNet20_C10': 418.08, 'elapsed_ResNet20-C10': 76, 'asr_VGG11BN_C10': 0.86, 'avg_queries_VG
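The tally below covers the five scenarios printed in both test outputs (MobileNetV2-C100 is cut off in both). Test #1 comes out ahead on mean ASR and has lower mean avg_queries, which is consistent with the final `submit(n=1)`. The numbers are copied from the outputs above.

```python
# (ASR, avg_queries) per scenario, copied from the Test #1 and Test #2 outputs.
TEST1 = {"ResNet20-C10": (0.9950, 279.32), "VGG11BN-C10": (0.9100, 560.62),
         "MobileNetV2-C10": (0.9900, 251.03), "ResNet20-C100": (0.9900, 309.09),
         "VGG11BN-C100": (0.8600, 589.25)}
TEST2 = {"ResNet20-C10": (0.9950, 418.08), "VGG11BN-C10": (0.8600, 766.00),
         "MobileNetV2-C10": (0.9900, 628.02), "ResNet20-C100": (0.9950, 582.66),
         "VGG11BN-C100": (0.8150, 783.15)}

def summarize(results: dict) -> tuple:
    """Return (mean ASR, mean avg_queries) over the listed scenarios."""
    n = len(results)
    mean_asr = sum(asr for asr, _ in results.values()) / n
    mean_q = sum(q for _, q in results.values()) / n
    return mean_asr, mean_q

asr1, avgq1 = summarize(TEST1)
asr2, avgq2 = summarize(TEST2)
print(f"Test #1: mean ASR={asr1:.3f}, mean avg_queries={avgq1:.2f}")
print(f"Test #2: mean ASR={asr2:.3f}, mean avg_queries={avgq2:.2f}")
```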
Assistant Step 5: submit
Tool Input
1{
2  "n": 1
3}
Tool Result Step 5
[submit] Submitting result from test #1 as final.

### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=279.32

### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9100, clean_acc=1.0000, robust_acc=0.0900, avg_queries=560.62

### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=251.03

### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=309.09

### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.8600, clean_acc=1.0000, robust_acc=0.1400, avg_queries=589.25

### MobileNetV2-C100 (scripts/eval_mobilenetv2_cifar100.sh)
Attack results (MobileNetV2-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=315.85

[Leaderboard] Re