Agent Conversation: deepseek-reasoner
Task: cv-sample-weighting
CV Sample Reweighting Strategy Design
Research Question
Design a novel sample reweighting strategy for class-imbalanced image classification that improves test accuracy on long-tailed datasets, across different architectures and imbalance ratios.
Background
Real-world datasets often exhibit long-tail class distributions where a few head classes dominate while many tail classes have very few samples. Standard training with uniform loss weighting biases the model toward frequent classes, degrading performance on rare ones. Sample reweighting assigns per-class weights to the cross-entropy loss to counteract this imbalance. Classic approaches include:
- Inverse frequency: weight[c] = total / (C * count[c]) — directly compensates for imbalance
- Effective number (Cui et al., CVPR 2019): models data overlap using E_n = (1 - beta^n) / (1 - beta)
- Square-root inverse: weight[c] = 1/sqrt(count[c]) — a gentler smoothed variant
These methods define different mappings from class frequency to loss weight and may behave differently across datasets and imbalance regimes.
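For concreteness, the three classic mappings above can be sketched in a few lines of pure Python (illustrative only; the benchmark itself operates on torch tensors, and the helper name classic_weights is ours, not from the harness):

```python
import math

def classic_weights(counts, scheme="inverse", beta=0.999):
    """Map per-class sample counts to per-class loss weights (sketch)."""
    C = len(counts)
    total = sum(counts)
    if scheme == "inverse":
        # inverse frequency: total / (C * count_c)
        w = [total / (C * n) for n in counts]
    elif scheme == "sqrt":
        # square-root inverse: gentler smoothed variant
        w = [1.0 / math.sqrt(n) for n in counts]
    elif scheme == "effective":
        # effective number (Cui et al.): w_c proportional to (1 - beta) / (1 - beta^n)
        w = [(1.0 - beta) / (1.0 - beta ** n) for n in counts]
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    # normalize so weights sum to C (average weight = 1, loss scale preserved)
    s = sum(w)
    return [wi * C / s for wi in w]
```

Under all three schemes, rarer classes receive strictly larger weights; they differ only in how aggressively the tail is boosted.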
What You Can Modify
The compute_class_weights(class_counts, num_classes, config) function (lines 164-195) in custom_weighting.py. This function receives per-class sample counts and must return a weight tensor for nn.CrossEntropyLoss(weight=...).
You can modify:
- The functional form mapping class counts to weights (inverse, power-law, logarithmic, piecewise, etc.)
- Use of the config dict: imbalance_ratio, dataset, arch, total_samples
- Normalization strategy (sum to C, sum to 1, unnormalized, etc.)
- Any pure-computation logic (no access to training data or model parameters)
Evaluation
- Metric: Best test accuracy (%, higher is better) on the balanced test set
- Benchmarks (all long-tail imbalanced):
- ResNet-32 on CIFAR-10-LT (imbalance ratio = 100, 10 classes)
- ResNet-32 on CIFAR-100-LT (imbalance ratio = 100, 100 classes)
- VGG-16-BN on CIFAR-100-LT (imbalance ratio = 50, 100 classes) — hidden, evaluated on final submission only
- Training: SGD (lr=0.1, momentum=0.9, wd=5e-4), cosine annealing, 200 epochs
- Data augmentation: RandomCrop(32, pad=4) + RandomHorizontalFlip
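The cosine-annealing schedule decays the learning rate as lr_t = 0.5 * lr_0 * (1 + cos(pi * t / T)); a minimal stand-alone sketch (the helper name cosine_lr is ours, not from the harness):

```python
import math

def cosine_lr(epoch, total_epochs=200, base_lr=0.1):
    """Cosine annealing: base_lr at epoch 0, decaying to 0 at total_epochs."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))
```

This matches the lr values that appear in the training logs, e.g. epoch 160 of 200 gives roughly 0.00955.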
pytorch-vision/custom_weighting.py [EDITABLE — lines 164–195 only]
1: """CV Sample Reweighting Benchmark.
2:
3: Train vision models (ResNet-32, VGG-16-BN) on long-tail imbalanced CIFAR
4: to evaluate sample reweighting strategies for class-imbalanced classification.
5:
6: FIXED: Model architectures, imbalanced dataset creation, data pipeline, training loop.
7: EDITABLE: compute_class_weights() function.
8:
9: Usage:
10: python custom_weighting.py --arch resnet32 --dataset cifar10 --imbalance-ratio 100 --seed 42
11: """
12:
13: import argparse
14: import math
15: import os
16: import time
17:
18: import numpy as np

Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_weighting.py",
4 "start_line": 164,
5 "end_line": 195,
6 "content": "def compute_class_weights(class_counts, num_classes, config):\n \"\"\"Adaptive Power Weighting (APW) for long-tail classification.\n\n Weight for class c: w_c = (total / (C * count_c)) ^ p\n\n The exponent p is adaptively set based on imbalance ratio and dataset:\n p = base_exponent / log10(imbalance_ratio) * dataset_factor\n\n Where:\n - base_exponent = 0.5 (moderate damping)\n - dataset_factor = 1.0 for CIFAR-10, 0.8 for CIFAR-100 (more damping for more classes)\n\n This ensures:\n - Higher imbalance ratio → smaller p → gentler weighting to avoid overfitting to tail.\n - More classes (CIFAR-100) → smaller p → further damping for stability.\n - Weights are normalized to sum to C for consistent loss scale.\n\n The square-root weighting (p=0.5) is a special case; APW adjusts p per scenario.\n \"\"\"\n total = class_counts.sum().float()\n imbalance_ratio = config['imbalance_ratio']\n dataset = config['dataset']\n arch = config['arch']\n\n # Base exponent (square root is 0.5)\n base_exp = 0.5\n\n # Dataset factor: more damping for more classes\n if dataset == 'cifar100':\n dataset_factor = 0.8\n else: # cifar10\n dataset_factor = 1.0\n\n # Architecture factor: VGG may need more damping due to no skip connections\n if arch == 'vgg16bn':\n arch_factor = 0.9\n else:\n arch_factor = 1.0\n\n # Compute adaptive exponent\n # log10(imbalance_ratio) ≥ 1 (since ratio ≥ 10)\n p = base_exp * dataset_factor * arch_factor / math.log10(imbalance_ratio)\n\n # Clamp p to reasonable range [0.2, 1.0]\n p = max(0.2, min(p, 1.0))\n\n # Compute raw inverse frequency weights\n inv_freq = total / (num_classes * class_counts.float())\n \n # Apply power weighting\n weights = torch.pow(inv_freq, p)\n\n # Normalize so weights sum to num_classes (preserves average weight = 1)\n weights = weights / weights.sum() * num_classes\n\n return weights"
7}

OK: Replaced lines 164..195 with 56 line(s) in pytorch-vision/custom_weighting.py. Editable range: 164–219. [Current file: pytorch-vision/custom_weighting.py | editable: 164–219 | total: 430 lines]
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### resnet32-cifar10lt (scripts/resnet32_cifar10lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0058 train_acc=99.90 test_loss=2.0276 test_acc=71.39 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.0039 train_acc=99.94 test_loss=2.3598 test_acc=68.75 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.0035 train_acc=99.97 test_loss=2.1570 test_acc=70.49 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0025 train_acc=99.99 test_loss=2.2028 test_acc=70.15 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0023 train_acc=99.98 test_loss=2.0985 test_acc=70.80 lr=0.000000
Final metrics (resnet32-cifar10lt): test_acc_resnet32-cifar10lt=72.61

### resnet32-cifar100lt (scripts/resnet32_cifar100lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0230 train_acc=99.91 test_loss=4.0486 test_acc=38.87 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_weighting.py",
4 "start_line": 164,
5 "end_line": 219,
6 "content": "def compute_class_weights(class_counts, num_classes, config):\n \"\"\"Adaptive Piecewise Power Weighting (APPW) for long-tail classification.\n\n Only tail classes (below median count) receive extra weight; head classes stay at 1.\n This prevents over-weighting already frequent classes while giving targeted boost to tail.\n\n For class c:\n if count_c >= median_count: weight = 1.0\n else: weight = (median_count / count_c) ^ p\n\n Exponent p adapts to imbalance ratio and dataset:\n p = base_exp * dataset_factor * arch_factor / log10(imbalance_ratio)\n with base_exp = 0.5, dataset_factor = 1.0 (cifar10) / 0.8 (cifar100),\n arch_factor = 1.0 (resnet32) / 0.9 (vgg16bn).\n\n Weights are normalized to sum to num_classes.\n \"\"\"\n total = class_counts.sum().float()\n imbalance_ratio = config['imbalance_ratio']\n dataset = config['dataset']\n arch = config['arch']\n\n # Base exponent\n base_exp = 0.5\n\n # Dataset factor\n if dataset == 'cifar100':\n dataset_factor = 0.8\n else:\n dataset_factor = 1.0\n\n # Architecture factor\n if arch == 'vgg16bn':\n arch_factor = 0.9\n else:\n arch_factor = 1.0\n\n # Adaptive exponent\n p = base_exp * dataset_factor * arch_factor / math.log10(imbalance_ratio)\n p = max(0.2, min(p, 1.0))\n\n # Median count\n median_count = torch.median(class_counts.float())\n\n # Compute weights\n weights = torch.ones(num_classes, dtype=torch.float32)\n for c in range(num_classes):\n cnt = class_counts[c].float()\n if cnt < median_count:\n # Boost tail classes\n ratio = median_count / cnt\n weights[c] = torch.pow(ratio, p)\n # else weight stays 1.0\n\n # Normalize so weights sum to num_classes\n weights = weights / weights.sum() * num_classes\n\n return weights"
7}

OK: Replaced lines 164..219 with 58 line(s) in pytorch-vision/custom_weighting.py. Editable range: 164–221. [Current file: pytorch-vision/custom_weighting.py | editable: 164–221 | total: 432 lines]
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### resnet32-cifar10lt (scripts/resnet32_cifar10lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0092 train_acc=99.81 test_loss=1.7541 test_acc=72.06 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.0042 train_acc=99.93 test_loss=1.9017 test_acc=70.75 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.0033 train_acc=99.94 test_loss=1.8469 test_acc=71.61 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0022 train_acc=99.98 test_loss=1.8936 test_acc=71.25 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0022 train_acc=99.98 test_loss=1.8365 test_acc=71.60 lr=0.000000
Final metrics (resnet32-cifar10lt): test_acc_resnet32-cifar10lt=72.56

### resnet32-cifar100lt (scripts/resnet32_cifar100lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0223 train_acc=99.92 test_loss=3.6851 test_acc=40.31 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_weighting.py",
4 "start_line": 164,
5 "end_line": 221,
6 "content": "def compute_class_weights(class_counts, num_classes, config):\n \"\"\"Piecewise Effective Weighting (PEW) for long-tail classification.\n\n Combines effective number weighting (Cui et al.) with piecewise normalization:\n - Compute effective weight for each class: w_eff = (1 - beta) / (1 - beta^count)\n - Find median class (by count) and its effective weight w_med.\n - For classes with count >= median, set weight = 1 (no extra boost).\n - For classes with count < median, set weight = w_eff / w_med (relative boost).\n - Normalize all weights to sum to num_classes.\n\n This gives targeted boosting to tail classes based on effective number theory,\n while preventing over-weighting of head classes.\n\n Beta is fixed at 0.9999 as recommended for CIFAR-100 LT.\n \"\"\"\n imbalance_ratio = config['imbalance_ratio']\n dataset = config['dataset']\n arch = config['arch']\n\n # Effective number beta (Cui et al. CVPR 2019)\n beta = 0.9999\n\n # Compute effective weights\n counts_float = class_counts.float()\n effective = (1.0 - beta) / (1.0 - torch.pow(beta, counts_float))\n\n # Find median class by count (not by index)\n median_val = torch.median(counts_float)\n # Find class closest to median (could be multiple)\n median_idx = torch.argmin(torch.abs(counts_float - median_val))\n w_med = effective[median_idx]\n\n # Piecewise scaling\n weights = torch.ones(num_classes, dtype=torch.float32)\n for c in range(num_classes):\n if counts_float[c] < median_val:\n weights[c] = effective[c] / w_med\n # else weight stays 1.0\n\n # Normalize so weights sum to num_classes\n weights = weights / weights.sum() * num_classes\n\n return weights"
7}

OK: Replaced lines 164..221 with 43 line(s) in pytorch-vision/custom_weighting.py. Editable range: 164–206. [Current file: pytorch-vision/custom_weighting.py | editable: 164–206 | total: 417 lines]
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### resnet32-cifar10lt (scripts/resnet32_cifar10lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0096 train_acc=99.78 test_loss=2.3195 test_acc=71.56 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.0044 train_acc=99.95 test_loss=2.6320 test_acc=70.00 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.0038 train_acc=99.94 test_loss=2.5724 test_acc=70.36 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0027 train_acc=99.98 test_loss=2.5496 test_acc=70.37 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0029 train_acc=99.98 test_loss=2.4270 test_acc=71.29 lr=0.000000
Final metrics (resnet32-cifar10lt): test_acc_resnet32-cifar10lt=72.50

### resnet32-cifar100lt (scripts/resnet32_cifar100lt.sh)
Training progress (last epochs):
TRAIN_METRICS: ep
Tool Input
1{
2 "n": 2
3}

[submit] Submitting result from test #2 as final.

### resnet32-cifar10lt (scripts/resnet32_cifar10lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0092 train_acc=99.81 test_loss=1.7541 test_acc=72.06 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.0042 train_acc=99.93 test_loss=1.9017 test_acc=70.75 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.0033 train_acc=99.94 test_loss=1.8469 test_acc=71.61 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0022 train_acc=99.98 test_loss=1.8936 test_acc=71.25 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0022 train_acc=99.98 test_loss=1.8365 test_acc=71.60 lr=0.000000
Final metrics (resnet32-cifar10lt): test_acc_resnet32-cifar10lt=72.56

### resnet32-cifar100lt (scripts/resnet32_cifar100lt.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.0223 train_acc=99.92 test_loss=3.6851 test_acc=40.31 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.0190 train_acc=99.94 test_loss=3.7192 test_a