Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 7
Messages: 15
UserStep 0

Task: security-adversarial-attack-black-box-score

Score-Based Query Black-Box Attack under an L_inf Constraint

Research Question

Can you design a stronger score-based query black-box attack that improves attack success rate (ASR) under a fixed query budget and an L_inf perturbation constraint?

Objective

Implement a better black-box attack in bench/custom_attack.py:

  • Threat model: query black-box (no gradient access).
  • Constraint: ||x_adv - x||_inf <= eps.
  • Budget: n_queries is a per-sample query budget.
  • Primary metric: maximize ASR under fixed budget.
  • Tie-break: for similar ASR, lower avg_queries is better.
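
The primary metric and tie-break above can be read as a two-level comparison. The sketch below is only an illustration of that reading: the function name `is_better`, the dictionary keys, and the tolerance `asr_tol` that decides when two ASR values count as "similar" are assumptions, not part of the benchmark.

```python
def is_better(candidate, baseline, asr_tol=0.005):
    """Return True if `candidate` ranks above `baseline`.

    ASR decides unless the two values are within `asr_tol` (an assumed
    threshold for "similar"); in that case fewer average queries wins.
    """
    if abs(candidate["asr"] - baseline["asr"]) > asr_tol:
        return candidate["asr"] > baseline["asr"]
    return candidate["avg_queries"] < baseline["avg_queries"]


# Illustrative values only.
run_a = {"asr": 0.99, "avg_queries": 462.18}
run_b = {"asr": 0.94, "avg_queries": 933.38}
print(is_better(run_a, run_b))
```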

Editable Interface

You must implement:

run_attack(model, images, labels, eps, n_queries, device, n_classes) -> adv_images

Inputs:

  • model: black-box wrapper that returns logits only.
  • images: tensor of shape (N, C, H, W), in [0, 1].
  • labels: tensor of shape (N,).
  • n_classes: 10 for CIFAR-10, 100 for CIFAR-100.

Output:

  • adv_images: tensor with same shape as images, values in [0, 1].
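
Both output requirements (the L_inf ball around the clean image and the [0, 1] pixel range) can be met by clipping each value to the intersection of the two intervals. The following is a minimal sketch on flat Python lists, purely to show the arithmetic; the helper name `project_linf` is hypothetical, and tensor code would typically express the same thing with clamp operations.

```python
def project_linf(x_adv, x, eps):
    """Clip each value of x_adv into [x - eps, x + eps] intersected with [0, 1]."""
    projected = []
    for adv_value, clean_value in zip(x_adv, x):
        low = max(0.0, clean_value - eps)    # lower bound from both constraints
        high = min(1.0, clean_value + eps)   # upper bound from both constraints
        projected.append(min(max(adv_value, low), high))
    return projected


clean = [0.4, 0.0, 1.0]
candidate = [0.5, -0.1, 0.99]
print(project_linf(candidate, clean, eps=0.05))
```

Because the clean value itself always lies in both intervals, the intersection is never empty and the result always satisfies both checks, up to floating-point rounding.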

Trusted Evaluation Logic

The evaluation logic in bench/run_eval.py is trusted and not editable.

  • It tracks all model queries through a wrapper.
  • If a batch exceeds its query budget (batch_size * n_queries), the entire batch is marked as an attack failure.
  • L_inf and [0, 1] validity are checked per sample; only invalid samples are marked as attack failures.
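
The two failure rules differ in scope: the budget rule applies to the whole batch, the validity rule to individual samples. The sketch below restates that logic in plain Python. It is not the code in bench/run_eval.py, which is not shown here; the function name, argument names, and the small tolerance `tol` are assumptions made for illustration.

```python
def forced_failures(queries_used, n_queries, linf_dists, in_unit_range, eps, tol=1e-6):
    """Return one flag per sample: True means the sample is forced to count as a failure.

    linf_dists:    per-sample max |x_adv - x| values
    in_unit_range: per-sample booleans, True if all pixels lie in [0, 1]
    """
    batch_size = len(linf_dists)
    # Batch-level rule: exceeding the total budget fails every sample.
    if queries_used > batch_size * n_queries:
        return [True] * batch_size
    # Sample-level rule: only samples violating a constraint are failed.
    return [dist > eps + tol or not ok for dist, ok in zip(linf_dists, in_unit_range)]


print(forced_failures(2000, 1000, [0.03, 0.05], [True, True], eps=0.031))
```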

Do not attempt to bypass wrapper behavior or evaluation logic. Only improve the attack algorithm in custom_attack.py.

Query Semantics

  • One call to model(x) consumes x.shape[0] queries.
  • Repeated calls on the same sample still consume additional queries.
  • Splitting the same inputs across differently sized batches does not change the total budget consumed; only the number of samples queried counts.
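
These semantics amount to counting the leading dimension of every batch passed to the model. A minimal sketch of such a counter follows; the class name `QueryCounter` is hypothetical and the real wrapper in the benchmark may differ. It uses `len(batch)`, which for a tensor equals `shape[0]`.

```python
class QueryCounter:
    """Wrap a callable model and count one query per sample in each call."""

    def __init__(self, model):
        self.model = model
        self.queries = 0

    def __call__(self, batch):
        self.queries += len(batch)  # repeated samples are counted again
        return self.model(batch)


# Stand-in "model" that returns one dummy score per input.
counter = QueryCounter(lambda batch: [0.0 for _ in batch])
counter([1, 2, 3])   # 3 queries
counter([1, 2, 3])   # same samples again: 3 more queries
counter([4])         # 1 query
print(counter.queries)
```

Calling the model once on three samples or three times on one sample each yields the same count, which is the sense in which batch partitioning does not matter.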

Performance Hint

The benchmark is query-heavy. Prefer batch-parallel score updates instead of pure per-sample loops when possible.
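
One place where a per-sample loop can be replaced by a batched operation is the generation of random square masks for block-wise perturbation updates. The sketch below builds all masks at once with broadcasting; it assumes PyTorch is available, and the helper name `random_square_masks` is hypothetical.

```python
import torch


def random_square_masks(n, h, w, s, generator=None):
    """Return an (n, 1, h, w) float mask with one random s x s square per sample."""
    # One random top-left corner per sample, shaped for broadcasting.
    top = torch.randint(0, h - s + 1, (n, 1, 1), generator=generator)
    left = torch.randint(0, w - s + 1, (n, 1, 1), generator=generator)
    rows = torch.arange(h).view(1, h, 1)
    cols = torch.arange(w).view(1, 1, w)
    # A pixel is inside the square if both its row and column fall in range.
    inside = (rows >= top) & (rows < top + s) & (cols >= left) & (cols < left + s)
    return inside.unsqueeze(1).float()


masks = random_square_masks(n=4, h=32, w=32, s=5)
print(masks.shape)
```

The mask has a singleton channel dimension, so it broadcasts across all channels when multiplied with an (N, C, H, W) perturbation.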

Evaluation Scenarios (6)

  • ResNet20 on CIFAR-10
  • VGG11-BN on CIFAR-10
  • MobileNetV2 on CIFAR-10
  • ResNet20 on CIFAR-100
  • VGG11-BN on CIFAR-100
  • MobileNetV2 on CIFAR-100

Reported metrics line format:

ATTACK_METRICS asr=... clean_acc=... robust_acc=... avg_queries=...
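
A line in this format can be split into key=value pairs for later comparison. The parser below is a small sketch under the assumption that fields are whitespace-separated and all values are numeric; the function name `parse_attack_metrics` and the sample line are illustrative, not benchmark output.

```python
def parse_attack_metrics(line):
    """Parse 'ATTACK_METRICS key=value ...' into a dict of floats."""
    tokens = line.split()
    if not tokens or tokens[0] != "ATTACK_METRICS":
        raise ValueError("not an ATTACK_METRICS line")
    metrics = {}
    for token in tokens[1:]:
        key, value = token.split("=", 1)
        metrics[key] = float(value)
    return metrics


sample = "ATTACK_METRICS asr=0.9900 clean_acc=1.0000 robust_acc=0.0100 avg_queries=462.18"
print(parse_attack_metrics(sample))
```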

torchattacks/bench/custom_attack.py [EDITABLE — lines 7–56 only]

     1: import torch
     2: import torch.nn as nn
     3: 
     4: # =====================================================================
     5: # EDITABLE: implement run_attack below
     6: # =====================================================================
     7: def run_attack(
     8:     model: nn.Module,
     9:     images: torch.Tensor,
    10:     labels: torch.Tensor,
    11:     eps: float,
    12:     n_queries: int,
    13:     device: torch.device,
    14:     n_classes: int,
    15: ) -> torch.Tensor:
    16:     """
    17:     Score-based query black-box attack under Linf constraint.
    18: 
    19:     Args:
    20:         model: black-box wrapper that only exposes forward logits.
    21:     
AssistantStep 1edit
Tool Input
{
  "content": "def run_attack(\n    model: nn.Module,\n    images: torch.Tensor,\n    labels: torch.Tensor,\n    eps: float,\n    n_queries: int,\n    device: torch.device,\n    n_classes: int,\n) -> torch.Tensor:\n    import math\n\n    model.eval()\n    N, C, H, W = images.shape\n\n    # Precompute one-hot for margin\n    y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)\n\n    def get_margin(logits, y_oh):\n        real = (logits * y_oh).sum(1)\n        other = logits - y_oh * 1e4\n        return other.max(1)[0] - real\n\n    # Initial query\n    with torch.no_grad():\n        logits = model(images)\n    margin = get_margin(logits, y_onehot)\n    best_margin = margin.clone()\n    succ = best_margin > 0\n\n    # Initialize perturbation with random noise\n    pert = torch.empty_like(images).uniform_(-eps, eps)\n    adv = torch.clamp(images + pert, 0.0, 1.0)\n    pert = torch.clamp(adv - images, -eps, eps)\n    adv = images + pert\n\n    with torch.no_grad():\n        logits_init = model(adv)\n    margin_init = get_margin(logits_init, y_onehot)\n\n    improve = margin_init > best_margin\n    best_margin = torch.where(improve, margin_init, best_margin)\n    succ = best_margin > 0\n    adv = torch.where(improve.view(-1, 1, 1, 1), adv, images)\n    pert = adv - images\n\n    budget = N * n_queries\n    queries_used = 2 * N\n\n    grad_est = torch.zeros_like(pert)\n    beta = 0.8\n    gamma = 1.0\n\n    p_start = 0.4\n    p_end = 0.01\n\n    i = 0\n    while queries_used < budget:\n        active_idx = (~succ).nonzero(as_tuple=True)[0]\n        K = active_idx.numel()\n        if K == 0:\n            break\n        \n        if queries_used + K > budget:\n            break\n\n        queries_used += K\n        i += 1\n        \n        # Calculate p based on iteration relative to expected max_iters\n        expected_iters = max(1, n_queries * 2)\n        frac = min(1.0, i / expected_iters)\n        p = p_start * ((p_end / p_start) ** frac)\n        s = max(1, int(round(math.sqrt(p * H * W))))\n\n        # Create block mask\n        mask = torch.zeros((N, 1, H, W), device=device)\n        for b in active_idx.tolist():\n            x = torch.randint(0, W - s + 1, (1,)).item()\n            y = torch.randint(0, H - s + 1, (1,)).item()\n            mask[b, 0, y:y+s, x:x+s] = 1.0\n\n        noise = torch.empty_like(pert).uniform_(-1.0, 1.0)\n        direction = torch.sign(gamma * grad_est + noise)\n\n        cand_pert = pert * (1 - mask) + (eps * direction) * mask\n        cand = torch.clamp(images + cand_pert, 0.0, 1.0)\n        cand_pert = torch.clamp(cand - images, -eps, eps)\n        cand = images + cand_pert\n\n        cand_active = cand[active_idx]\n        with torch.no_grad():\n            logits_active = model(cand_active)\n\n        y_oh_act = y_onehot[active_idx]\n        margin_cand_act = get_margin(logits_active, y_oh_act)\n\n        # Update best margins and perturbations\n        margin_cand = best_margin.clone()\n        margin_cand[active_idx] = margin_cand_act\n\n        delta_m = margin_cand - best_margin\n        \n        # Compute gradient estimate update only for active to save ops\n        g_t = torch.sign(delta_m).view(-1, 1, 1, 1) * (cand_pert - pert) / (eps + 1e-8)\n        grad_est = beta * grad_est + (1 - beta) * g_t\n\n        improve = margin_cand > best_margin\n        best_margin = torch.where(improve, margin_cand, best_margin)\n        succ = best_margin > 0\n\n        adv = torch.where(improve.view(-1, 1, 1, 1), cand, adv)\n        pert = adv - images\n\n    return adv.detach()\n",
  "start_line": 7,
  "filename": "bench/custom_attack.py",
  "end_line": 56,
  "after_line": 6,
  "op": "replace"
}
Tool ResultStep 1
ERROR: Package 'bench' is not in allowed packages
AssistantStep 2edit
Tool Input
{
  "content": "def run_attack(\n    model: nn.Module,\n    images: torch.Tensor,\n    labels: torch.Tensor,\n    eps: float,\n    n_queries: int,\n    device: torch.device,\n    n_classes: int,\n) -> torch.Tensor:\n    import math\n\n    model.eval()\n    N, C, H, W = images.shape\n\n    y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)\n\n    def get_margin(logits, y_oh):\n        real = (logits * y_oh).sum(1)\n        other = logits - y_oh * 1e4\n        return other.max(1)[0] - real\n\n    # Initial query\n    with torch.no_grad():\n        logits = model(images)\n    margin = get_margin(logits, y_onehot)\n    best_margin = margin.clone()\n    succ = best_margin > 0\n\n    if succ.all():\n        return images.clone().detach()\n\n    # Initialize perturbation with random uniform values in [-eps, eps]\n    pert = torch.empty_like(images).uniform_(-eps, eps)\n    adv = torch.clamp(images + pert, 0.0, 1.0)\n    pert = torch.clamp(adv - images, -eps, eps)\n    adv = images + pert\n\n    with torch.no_grad():\n        logits_init = model(adv)\n    margin_init = get_margin(logits_init, y_onehot)\n\n    improve = margin_init > best_margin\n    best_margin = torch.where(improve, margin_init, best_margin)\n    succ = best_margin > 0\n    adv = torch.where(improve.view(-1, 1, 1, 1), adv, images)\n    pert = adv - images\n\n    budget = N * n_queries\n    queries_used = 2 * N\n\n    grad_est = torch.zeros_like(pert)\n    beta = 0.8\n    gamma = 2.0  # Increased gamma to give more weight to momentum\n\n    p_start = 0.2\n    p_end = 0.01\n\n    i = 0\n    while queries_used < budget:\n        if succ.all():\n            break\n\n        active_idx = (~succ).nonzero(as_tuple=True)[0]\n        K = active_idx.numel()\n        \n        if queries_used + K > budget:\n            break\n\n        queries_used += K\n        i += 1\n        \n        # Calculate p based on iteration relative to expected max_iters (proxy: n_queries*1.5)\n        expected_iters = max(1, int(n_queries * 1.5))\n        frac = min(1.0, i / expected_iters)\n        p = p_start * ((p_end / p_start) ** frac)\n        s = max(1, int(round(math.sqrt(p * H * W))))\n\n        mask = torch.zeros((N, 1, H, W), device=device)\n        # We only generate mask for active indices\n        for b in active_idx.tolist():\n            x = torch.randint(0, W - s + 1, (1,)).item()\n            y = torch.randint(0, H - s + 1, (1,)).item()\n            mask[b, 0, y:y+s, x:x+s] = 1.0\n\n        noise = torch.empty_like(pert).uniform_(-1.0, 1.0)\n        direction = torch.sign(gamma * grad_est + noise)\n\n        cand_pert = pert * (1 - mask) + (eps * direction) * mask\n        cand = torch.clamp(images + cand_pert, 0.0, 1.0)\n        cand_pert = torch.clamp(cand - images, -eps, eps)\n        cand = images + cand_pert\n\n        cand_active = cand[active_idx]\n        with torch.no_grad():\n            logits_active = model(cand_active)\n\n        y_oh_act = y_onehot[active_idx]\n        margin_cand_act = get_margin(logits_active, y_oh_act)\n\n        margin_cand = best_margin.clone()\n        margin_cand[active_idx] = margin_cand_act\n\n        delta_m = margin_cand - best_margin\n        \n        # Gradient update rule\n        g_t = torch.sign(delta_m).view(-1, 1, 1, 1) * (cand_pert - pert) / (eps + 1e-8)\n        grad_est = beta * grad_est + (1 - beta) * g_t\n\n        improve = margin_cand > best_margin\n        best_margin = torch.where(improve, margin_cand, best_margin)\n        succ = best_margin > 0\n\n        adv = torch.where(improve.view(-1, 1, 1, 1), cand, adv)\n        pert = adv - images\n\n    return adv.detach()",
  "op": "replace",
  "end_line": 56,
  "filename": "torchattacks/bench/custom_attack.py",
  "after_line": 6,
  "start_line": 7
}
Tool ResultStep 2
OK: Replaced lines 7..56 with 116 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–122.

[Current file: torchattacks/bench/custom_attack.py | editable: 7–122 | total: 126 lines]
     4| # =====================================================================
     5| # EDITABLE: implement run_attack below
     6| # =====================================================================
     7  def run_attack(
     8      model: nn.Module,
     9      images: torch.Tensor,
    10      labels: torch.Tensor,
    11      eps: float,
    12      n_queries: int,
    13      device: torch.device,
    14      n_classes: int,
    15  ) -> torch.Tensor:
    16      import math
    17  
    18      model.eval()
    19      N, C, H, W = images.shape
    20  
    21      y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)
    22  
    23      def get_margin(logits, y_oh):
    24          real = (logits * y_oh).sum(1)
    25          other = logits - y_oh * 1e4
    26  
AssistantStep 3test
Tool ResultStep 3
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=53.15

### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9400, clean_acc=1.0000, robust_acc=0.0600, avg_queries=933.38

### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=87.44

### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=61.05

### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.9550, clean_acc=1.0000, robust_acc=0.0450, avg_queries=999.99

[Leaderboard] Results saved: {'asr_ResNet20_C10': 1.0, 'avg_queries_ResNet20_C10': 53.15, 'elapsed_ResNet20-C10': 14, 'asr_VGG11BN_C10': 0.94, 'avg_queries_VGG11BN
AssistantStep 4edit
Tool Input
{
  "content": "def run_attack(\n    model: nn.Module,\n    images: torch.Tensor,\n    labels: torch.Tensor,\n    eps: float,\n    n_queries: int,\n    device: torch.device,\n    n_classes: int,\n) -> torch.Tensor:\n    import math\n\n    model.eval()\n    N, C, H, W = images.shape\n\n    y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)\n\n    def get_margin(logits, y_oh):\n        real = (logits * y_oh).sum(1)\n        other = logits - y_oh * 1e4\n        return other.max(1)[0] - real\n\n    # Initial query\n    with torch.no_grad():\n        logits = model(images)\n    margin = get_margin(logits, y_onehot)\n    best_margin = margin.clone()\n    succ = best_margin > 0\n    queries_used = N\n\n    if succ.all():\n        return images.clone().detach()\n\n    # Dense boundary initialization FOR ACTIVE\n    pert = torch.zeros_like(images)\n    active_idx = (~succ).nonzero(as_tuple=True)[0]\n    \n    pert[active_idx] = torch.empty_like(images[active_idx]).uniform_(-eps, eps)\n    adv = torch.clamp(images + pert, 0.0, 1.0)\n    pert = adv - images\n\n    if active_idx.numel() > 0:\n        with torch.no_grad():\n            logits_init = model(adv[active_idx])\n        margin_init = get_margin(logits_init, y_onehot[active_idx])\n        \n        best_margin[active_idx] = margin_init\n        succ = best_margin > 0\n        queries_used += active_idx.numel()\n\n    budget = N * n_queries\n\n    # Adam state for coordinate-wise momentum\n    m_t = torch.zeros_like(pert)\n    v_t = torch.zeros_like(pert)\n    beta1 = 0.8\n    beta2 = 0.99\n    gamma = 0.8  # blend momentum with noise\n\n    p_start = 0.05\n    p_end = 0.005\n\n    i = 0\n    expected_iters = max(1, n_queries - 2)\n\n    while queries_used < budget:\n        if succ.all():\n            break\n\n        active_idx = (~succ).nonzero(as_tuple=True)[0]\n        K = active_idx.numel()\n        \n        if queries_used + K > budget:\n            break\n\n        queries_used += K\n        i += 1\n        \n        frac = min(1.0, i / expected_iters)\n        p = p_start * ((p_end / p_start) ** frac)\n        s = max(1, int(round(math.sqrt(p * H * W))))\n\n        mask = torch.zeros((N, 1, H, W), device=device)\n        for b in active_idx.tolist():\n            x = torch.randint(0, W - s + 1, (1,)).item()\n            y = torch.randint(0, H - s + 1, (1,)).item()\n            mask[b, 0, y:y+s, x:x+s] = 1.0\n\n        noise = torch.empty_like(pert).uniform_(-1.0, 1.0)\n        \n        # Calculate grad_est\n        grad_est = torch.where(v_t > 0, m_t / (torch.sqrt(v_t) + 1e-8), torch.zeros_like(m_t))\n        direction = torch.sign(gamma * grad_est + noise)\n\n        cand_pert = pert * (1 - mask) + (eps * direction) * mask\n        cand = torch.clamp(images + cand_pert, 0.0, 1.0)\n        cand_pert = cand - images\n\n        cand_active = cand[active_idx]\n        with torch.no_grad():\n            logits_active = model(cand_active)\n\n        y_oh_act = y_onehot[active_idx]\n        margin_cand_act = get_margin(logits_active, y_oh_act)\n\n        margin_cand = best_margin.clone()\n        margin_cand[active_idx] = margin_cand_act\n\n        delta_m = margin_cand - best_margin\n        \n        # Adam Update (only for masked pixels)\n        g_t = torch.sign(delta_m).view(-1, 1, 1, 1) * (cand_pert - pert) / (eps + 1e-8)\n        \n        m_t = torch.where(mask > 0, beta1 * m_t + (1 - beta1) * g_t, m_t)\n        v_t = torch.where(mask > 0, beta2 * v_t + (1 - beta2) * (g_t ** 2), v_t)\n\n        improve = margin_cand > best_margin\n        best_margin = torch.where(improve, margin_cand, best_margin)\n        succ = best_margin > 0\n\n        adv = torch.where(improve.view(-1, 1, 1, 1), cand, adv)\n        pert = adv - images\n\n    return adv.detach()",
  "end_line": 122,
  "start_line": 7,
  "after_line": 6,
  "filename": "torchattacks/bench/custom_attack.py",
  "op": "replace"
}
Tool ResultStep 4
OK: Replaced lines 7..122 with 123 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–129.

[Current file: torchattacks/bench/custom_attack.py | editable: 7–129 | total: 133 lines]
     4| # =====================================================================
     5| # EDITABLE: implement run_attack below
     6| # =====================================================================
     7  def run_attack(
     8      model: nn.Module,
     9      images: torch.Tensor,
    10      labels: torch.Tensor,
    11      eps: float,
    12      n_queries: int,
    13      device: torch.device,
    14      n_classes: int,
    15  ) -> torch.Tensor:
    16      import math
    17  
    18      model.eval()
    19      N, C, H, W = images.shape
    20  
    21      y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)
    22  
    23      def get_margin(logits, y_oh):
    24          real = (logits * y_oh).sum(1)
    25          other = logits - y_oh * 1e4
    26 
AssistantStep 5test
Tool ResultStep 5
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=40.32

### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=462.18

### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=53.35

### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=40.10

### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=495.17

[Leaderboard] Results saved: {'asr_ResNet20_C10': 0.99, 'avg_queries_ResNet20_C10': 40.32, 'elapsed_ResNet20-C10': 14, 'asr_VGG11BN_C10': 0.99, 'avg_queries_VGG11BN
AssistantStep 6edit
Tool Input
{
  "after_line": 57,
  "op": "replace",
  "end_line": 63,
  "start_line": 58,
  "content": "    # Adam state for coordinate-wise momentum\n    m_t = torch.zeros_like(pert)\n    v_t = torch.zeros_like(pert)\n    beta1 = 0.7\n    beta2 = 0.99\n    gamma = 0.5  # blend momentum with noise",
  "filename": "torchattacks/bench/custom_attack.py"
}
Tool ResultStep 6
OK: Replaced lines 58..63 with 6 line(s) in torchattacks/bench/custom_attack.py. Editable range: 7–129.

[Current file: torchattacks/bench/custom_attack.py | editable: 7–129 | total: 133 lines]
     4| # =====================================================================
     5| # EDITABLE: implement run_attack below
     6| # =====================================================================
     7  def run_attack(
     8      model: nn.Module,
     9      images: torch.Tensor,
    10      labels: torch.Tensor,
    11      eps: float,
    12      n_queries: int,
    13      device: torch.device,
    14      n_classes: int,
    15  ) -> torch.Tensor:
    16      import math
    17  
    18      model.eval()
    19      N, C, H, W = images.shape
    20  
    21      y_onehot = torch.nn.functional.one_hot(labels, n_classes).to(device)
    22  
    23      def get_margin(logits, y_oh):
    24          real = (logits * y_oh).sum(1)
    25          other = logits - y_oh * 1e4
    26   
AssistantStep 7test
Tool ResultStep 7
[Test #3] (0 tests remaining)

### ResNet20-C10 (scripts/eval_resnet20_cifar10.sh)
Attack results (ResNet20-C10): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=40.91

### VGG11BN-C10 (scripts/eval_vgg11bn_cifar10.sh)
Attack results (VGG11BN-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=465.65

### MobileNetV2-C10 (scripts/eval_mobilenetv2_cifar10.sh)
Attack results (MobileNetV2-C10): ASR=0.9900, clean_acc=1.0000, robust_acc=0.0100, avg_queries=52.21

### ResNet20-C100 (scripts/eval_resnet20_cifar100.sh)
Attack results (ResNet20-C100): ASR=1.0000, clean_acc=1.0000, robust_acc=0.0000, avg_queries=38.26

### VGG11BN-C100 (scripts/eval_vgg11bn_cifar100.sh)
Attack results (VGG11BN-C100): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=577.56

### MobileNetV2-C100 (scripts/eval_mobilenetv2_cifar100.sh)
Attack results (MobileNetV2-C100): ASR=0.9950, clean_acc=1.0000, robust_acc=0.0050, avg_queries=48.75

[Leaderboard] Results saved: {'asr_ResNe