Agent Conversation: deepseek-reasoner
Task: cv-pooling-aggregation
CV Global Pooling / Feature Aggregation Design
Research Question
Design a novel global pooling or feature aggregation strategy for image classification that improves test accuracy across different CNN architectures and datasets.
Background
Global pooling is the final spatial aggregation step in modern CNNs, reducing feature maps from [B, C, H, W] to [B, C] before the classifier head. The standard approach is Global Average Pooling (GAP), which computes the spatial mean per channel. While simple and effective, GAP discards spatial structure and treats all positions equally. Alternative strategies include:
- Global Max Pooling (GMP): Selects the strongest activation per channel, capturing the most salient features but ignoring distribution information.
- Generalized Mean (GeM) Pooling (Radenovic et al., 2018): Learnable power-mean that interpolates between average and max pooling.
- Average + Max: Element-wise combination of GAP and GMP, capturing both mean-field and peak statistics.
There is room to design pooling strategies that better capture the spatial statistics of feature maps, adapt to different architectures, or learn task-specific aggregation patterns.
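The interpolation behaviour of GeM pooling mentioned above can be illustrated with a plain-Python sketch of the power-mean formula (not the benchmark's code; activations are assumed non-negative, as after a ReLU):

```python
def gem_pool(activations, p):
    """Generalized mean: (1/N * sum(a_i**p)) ** (1/p) over spatial positions."""
    n = len(activations)
    return (sum(a ** p for a in activations) / n) ** (1.0 / p)

acts = [0.1, 0.5, 2.0, 4.0]

avg = gem_pool(acts, p=1.0)    # p=1 recovers global average pooling
peak = gem_pool(acts, p=64.0)  # large p approaches global max pooling
```

With p=1 the result equals the spatial mean; as p grows the largest activation dominates and the result approaches max(acts), which is why a learnable p lets the network slide between GAP-like and GMP-like aggregation per task.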
What You Can Modify
The CustomPool class (lines 31-48) in custom_pool.py. This class receives a 4D tensor [B, C, H, W] and must return a 2D tensor [B, C].
You can modify:
- The aggregation function (mean, max, learned weights, attention, higher-order statistics)
- Whether to use learnable parameters
- How spatial information is summarized (single-point, multi-scale, distribution-based)
- Channel-wise or spatial-wise weighting mechanisms
- Any combination of the above
Constraints:
- Input: [B, C, H, W] tensor (C varies by architecture: 64 for ResNet-56, 512 for VGG-16-BN, 1280 for MobileNetV2)
- Output: [B, C] tensor (must match input channel dimension exactly)
- Must work with variable spatial sizes (8×8 for ResNet on CIFAR, 1×1 for VGG after max-pools, 1×1 for MobileNetV2)
- No access to training data or labels within the pooling layer
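To make the interface contract concrete, here is a minimal compliant baseline (plain GAP, shown only to illustrate the required shapes; the class name and 4D-to-2D signature follow the task description, everything else is a sketch):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomPool(nn.Module):
    """Minimal compliant pooling: plain GAP, [B, C, H, W] -> [B, C]."""

    def forward(self, x):
        B, C, H, W = x.shape
        # adaptive_avg_pool2d handles any spatial size, including 1x1 inputs
        return F.adaptive_avg_pool2d(x, 1).view(B, C)

pool = CustomPool()
out_resnet = pool(torch.randn(2, 64, 8, 8))   # ResNet-56 head: 8x8, C=64
out_vgg = pool(torch.randn(2, 512, 1, 1))     # VGG-16-BN head: 1x1, C=512
```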
Evaluation
- Metric: Best test accuracy (%, higher is better)
- Architectures & datasets:
- ResNet-56 on CIFAR-100 (deep residual, 100 classes; final feature map 8×8, C=64)
- VGG-16-BN on CIFAR-100 (deep non-residual with BatchNorm, 100 classes; final feature map 1×1 after max-pools, C=512)
- MobileNetV2 on FashionMNIST (lightweight inverted-residual, 10 classes) — hidden, evaluated on final submission only
- Training: SGD (lr=0.1, momentum=0.9, wd=5e-4), cosine annealing, 200 epochs
- Data augmentation: RandomCrop(32, pad=4) + RandomHorizontalFlip
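The fixed training recipe above can be sketched as follows (the linear model is a stand-in for the real network; only the optimizer and scheduler settings are taken from the recipe):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 100)  # stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

lrs = []
for epoch in range(200):
    # ... one epoch of training would go here ...
    optimizer.step()   # scheduler.step() must follow optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```

The cosine schedule decays the learning rate smoothly from 0.1 to (approximately) 0 over the 200 epochs, matching the lr values visible in the training logs later in this transcript.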
pytorch-vision/custom_pool.py [EDITABLE — lines 31–48 only]
1: """CV Pooling / Feature Aggregation Benchmark.
2:
3: Train vision models (ResNet, VGG, MobileNetV2) on CIFAR-10/100/FashionMNIST to evaluate
4: global pooling and feature aggregation strategies.
5:
6: FIXED: Mod
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_pool.py",
  "start_line": 31,
  "end_line": 48,
  "content": "class CustomPool(nn.Module):\n    \"\"\"Spatial-Channel Attention Pooling (SCAP).\n\n    Learns adaptive spatial attention weights per channel group,\n    combining attentive averaging with max pooling via learnable gating.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        # Will be initialized in forward when input shape is known\n        self.spatial_att = None\n        self.gate_weight = nn.Parameter(torch.tensor(0.5))\n        self.gate_bias = nn.Parameter(torch.tensor(0.0))\n\n    def _create_spatial_attention(self, x):\n        B, C, H, W = x.shape\n        # Use a lightweight network to compute spatial attention\n        # First reduce channels via group convolution, then produce HxW attention map\n        # We'll create a small network: Conv2d with kernel 1 to reduce to C//8, then another to 1 channel\n        # Use adaptive architecture based on input size\n        mid_channels = max(C // 8, 8)\n        self.spatial_att = nn.Sequential(\n            nn.Conv2d(C, mid_channels, kernel_size=1, bias=False),\n            nn.BatchNorm2d(mid_channels),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(mid_channels, 1, kernel_size=1, bias=False),\n            nn.Sigmoid()\n        ).to(x.device)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        if self.spatial_att is None:\n            self._create_spatial_attention(x)\n\n        # Compute spatial attention weights\n        att_weights = self.spatial_att(x)  # [B, 1, H, W]\n\n        # Weighted average pooling using attention\n        weighted_sum = (x * att_weights).sum(dim=(2, 3))  # [B, C]\n        weights_sum = att_weights.sum(dim=(2, 3))  # [B, 1]\n        attended_avg = weighted_sum / (weights_sum + 1e-6)  # [B, C]\n\n        # Global max pooling\n        global_max = F.adaptive_max_pool2d(x, 1).view(B, C)\n\n        # Learnable gating between attended average and max\n        gate = torch.sigmoid(self.gate_weight * 2.0 + self.gate_bias)\n        out = gate * attended_avg + (1 - gate) * global_max\n        return out"
}
OK: Replaced lines 31..48 with 49 line(s) in pytorch-vision/custom_pool.py. Editable range: 31–79.
[Current file: pytorch-vision/custom_pool.py | editable: 31–79 | total: 457 lines]
28| # ============================================================================
29|
30| # -- EDITABLE REGION START (lines 31-48) ------------------------------------
31 class CustomPool(nn.Module):
32 """Spatial-Channel Attention Pooling (SCAP).
33
34 Learns adaptive spatial attention weights per channel group,
35 combining attentive averaging with max pooling via learnable gating.
36 """
37
38 def __init__(self):
39 super().__init__()
40 # Will be initialized in forward when input shape is known
41 self.spatial_att = None
42 self.gate_weight = nn.Parameter(torch.tensor(0.5))
43 self.gate_bias = nn.Parameter(torch.tensor(0.0))
44
45 def _create_spatia
Tool Input
{
  "n": 1
}
Undo complete: Restored: pytorch-vision/custom_pool.py
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_pool.py",
  "start_line": 31,
  "end_line": 48,
  "content": "class CustomPool(nn.Module):\n    \"\"\"Channel-wise Adaptive Pooling (CAP).\n\n    Learns per-channel gating between average and max pooling,\n    allowing each channel to adapt its aggregation strategy.\n    Also incorporates channel statistics (mean, max, std) to compute gates.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        # Statistics network: from 3 stats per channel to gate weight\n        # Input: 3 statistics (mean, max, std) per channel\n        # Output: gate weight per channel\n        self.stat_net = nn.Sequential(\n            nn.Linear(3, 16),\n            nn.ReLU(inplace=True),\n            nn.Linear(16, 1, bias=False),\n        )\n        # Learnable bias per channel\n        self.gate_bias = nn.Parameter(torch.zeros(1))\n        # Learnable per-channel scaling factor\n        self.channel_scale = nn.Parameter(torch.ones(1))\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # Compute basic statistics per channel\n        x_flat = x.view(B, C, -1)  # [B, C, H*W]\n        mean = x_flat.mean(dim=2)  # [B, C]\n        maximum = x_flat.max(dim=2).values  # [B, C]\n        std = x_flat.std(dim=2)  # [B, C]\n\n        # Concatenate statistics along new dimension for stat_net\n        stats = torch.stack([mean, maximum, std], dim=2)  # [B, C, 3]\n        # Process each channel independently\n        # Reshape to [B*C, 3], apply stat_net, reshape back\n        stats_flat = stats.view(-1, 3)\n        gate_raw = self.stat_net(stats_flat).view(B, C)  # [B, C]\n        gate = torch.sigmoid(gate_raw * self.channel_scale + self.gate_bias)\n\n        # Compute average and max pooling\n        avg_pool = F.adaptive_avg_pool2d(x, 1).view(B, C)\n        max_pool = F.adaptive_max_pool2d(x, 1).view(B, C)\n\n        # Channel-wise gated combination\n        out = gate * avg_pool + (1 - gate) * max_pool\n        return out"
}
OK: Replaced lines 31..48 with 46 line(s) in pytorch-vision/custom_pool.py. Editable range: 31–76.
[Current file: pytorch-vision/custom_pool.py | editable: 31–76 | total: 454 lines]
28| # ============================================================================
29|
30| # -- EDITABLE REGION START (lines 31-48) ------------------------------------
31 class CustomPool(nn.Module):
32 """Channel-wise Adaptive Pooling (CAP).
33
34 Learns per-channel gating between average and max pooling,
35 allowing each channel to adapt its aggregation strategy.
36 Also incorporates channel statistics (mean, max, std) to compute gates.
37 """
38
39 def __init__(self):
40 super().__init__()
41 # Statistics network: from 3 stats per channel to gate weight
42 # Input: 3 statistics (mean, max, std) per channel
43 # Output: gate weight per channel
44 self.stat_
[Test #1] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### resnet56-cifar100 (scripts/resnet56_cifar100.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.3743 train_acc=88.48 test_loss=1.2288 test_acc=69.24 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.2204 train_acc=93.49 test_loss=1.2561 test_acc=69.81 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.1097 train_acc=97.37 test_loss=1.2271 test_acc=71.53 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0658 train_acc=98.82 test_loss=1.2222 test_acc=72.39 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0566 train_acc=99.08 test_loss=1.2181 test_acc=72.39 lr=0.000000
Final metrics (resnet56-cifar100): test_acc_resnet56-cifar100=72.66

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=nan train_acc=1.00 test_loss=nan test_acc=1.00 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=nan train_acc=1.00 t
Tool Input
{
  "n": 1
}
ERROR: Nothing to undo
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_pool.py",
  "start_line": 31,
  "end_line": 76,
  "content": "class CustomPool(nn.Module):\n    \"\"\"Learnable Channel-wise Gated Pooling (LCGP).\n\n    Learns per-channel gating weight between average and max pooling.\n    Each channel can adapt its own aggregation strategy.\n    Simple, stable, and effective across architectures.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        # Per-channel gate weight (unscaled). Shape will be set in forward.\n        self.gate_weight = None\n        # Temperature parameter (learnable)\n        self.temperature = nn.Parameter(torch.tensor(1.0))\n        # Small epsilon for numerical stability\n        self.eps = 1e-6\n\n    def _initialize_gate_weight(self, C):\n        # Initialize gate weight for C channels\n        # Initialize to zero so sigmoid(gate_weight) = 0.5 (equal blend)\n        self.gate_weight = nn.Parameter(torch.zeros(C))\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        if self.gate_weight is None or len(self.gate_weight) != C:\n            self._initialize_gate_weight(C)\n\n        # Compute average and max pooling\n        avg_pool = F.adaptive_avg_pool2d(x, 1).view(B, C)\n        max_pool = F.adaptive_max_pool2d(x, 1).view(B, C)\n\n        # Per-channel gate\n        gate = torch.sigmoid(self.gate_weight * self.temperature)\n        # Ensure gate is broadcastable across batch dimension\n        gate = gate.view(1, C)\n\n        # Gated combination\n        out = gate * avg_pool + (1 - gate) * max_pool\n        return out"
}
OK: Replaced lines 31..76 with 39 line(s) in pytorch-vision/custom_pool.py. Editable range: 31–69.
[Current file: pytorch-vision/custom_pool.py | editable: 31–69 | total: 447 lines]
31 class CustomPool(nn.Module):
32 """Learnable Channel-wise Gated Pooling (LCGP).
33
... (33 more lines) ...
67 # Gated combination
68 out = gate * avg_pool + (1 - gate) * max_pool
69 return out
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_pool.py",
  "start_line": 31,
  "end_line": 69,
  "content": "class CustomPool(nn.Module):\n    \"\"\"Learnable Channel-wise Gated Pooling (LCGP).\n\n    Learns per-channel gating weight between average and max pooling.\n    Each channel can adapt its own aggregation strategy.\n    Simple, stable, and effective across architectures.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        # Maximum number of channels expected across architectures\n        # MobileNetV2 has 1280 channels, we add margin\n        self.max_channels = 2048\n        # Per-channel gate weight (unscaled). Initialize to zero -> sigmoid(0)=0.5\n        self.gate_weight = nn.Parameter(torch.zeros(self.max_channels))\n        # Temperature parameter (learnable)\n        self.temperature = nn.Parameter(torch.tensor(1.0))\n        # Small epsilon for numerical stability\n        self.eps = 1e-6\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        if C > self.max_channels:\n            # Expand gate_weight if needed (unlikely)\n            new_weight = torch.zeros(C, device=x.device)\n            new_weight[:self.max_channels] = self.gate_weight\n            self.gate_weight = nn.Parameter(new_weight)\n            self.max_channels = C\n\n        # Slice gate_weight to current channel count\n        gate_w = self.gate_weight[:C]\n\n        # Compute average and max pooling\n        avg_pool = F.adaptive_avg_pool2d(x, 1).view(B, C)\n        max_pool = F.adaptive_max_pool2d(x, 1).view(B, C)\n\n        # Per-channel gate\n        gate = torch.sigmoid(gate_w * self.temperature)\n        # Ensure gate is broadcastable across batch dimension\n        gate = gate.view(1, C)\n\n        # Gated combination\n        out = gate * avg_pool + (1 - gate) * max_pool\n        return out"
}
OK: Replaced lines 31..69 with 44 line(s) in pytorch-vision/custom_pool.py. Editable range: 31–74.
[Current file: pytorch-vision/custom_pool.py | editable: 31–74 | total: 452 lines]
31 class CustomPool(nn.Module):
32 """Learnable Channel-wise Gated Pooling (LCGP).
33
... (38 more lines) ...
72 # Gated combination
73 out = gate * avg_pool + (1 - gate) * max_pool
74 return out
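For reference, the final submitted design condenses to the standalone sketch below (the `max_channels=2048` budget is the margin chosen in the transcript; pre-allocating the gate parameter means it exists before the optimizer is built, unlike the earlier lazily-initialized variants). At initialization the zero gate weights give sigmoid(0) = 0.5, so the module starts as an equal blend of GAP and GMP:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LCGP(nn.Module):
    """Learnable Channel-wise Gated Pooling: per-channel blend of avg and max."""

    def __init__(self, max_channels=2048):
        super().__init__()
        # Zero init -> sigmoid(0) = 0.5, an equal avg/max blend at the start
        self.gate_weight = nn.Parameter(torch.zeros(max_channels))
        self.temperature = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        B, C, H, W = x.shape
        avg_pool = F.adaptive_avg_pool2d(x, 1).view(B, C)
        max_pool = F.adaptive_max_pool2d(x, 1).view(B, C)
        gate = torch.sigmoid(self.gate_weight[:C] * self.temperature).view(1, C)
        return gate * avg_pool + (1 - gate) * max_pool

pool = LCGP()
x = torch.randn(2, 64, 8, 8)       # ResNet-56 head shape
out = pool(x)
# At init every gate is 0.5, so the output is the midpoint of GAP and GMP
expected = 0.5 * x.mean(dim=(2, 3)) + 0.5 * x.amax(dim=(2, 3))
```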