Agent Conversation: deepseek-reasoner
Task: dl-weight-initialization
DL Weight Initialization Strategy Design
Research Question
Design a novel weight initialization strategy for deep convolutional neural networks that improves convergence speed and final test accuracy across different architectures and datasets.
Background
Weight initialization is fundamental to training deep neural networks. Poor initialization leads to vanishing/exploding gradients, slow convergence, or suboptimal generalization. Classic methods include:
- Kaiming/He (2015): Accounts for the ReLU nonlinearity by sampling weights from N(0, 2/fan_out), i.e. a zero-mean normal with std sqrt(2/fan_out)
- Orthogonal (2014): Preserves gradient norms via orthogonal matrices
- Fixup (2019): Scales the last conv in each residual block by L^(-0.5) where L is the number of blocks, controlling variance accumulation across depth; zero-initializes the last BN per block so residual branches start near identity
However, these methods each address only one aspect of initialization. There is room to design strategies that jointly account for residual connections, batch normalization's re-scaling effect, depth-dependent scaling, and the interaction between different layer types (convolution vs classifier).
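For reference, the classic schemes above map directly onto `torch.nn.init` calls. The sketch below shows each on its own conv layer; the block count `L` in the Fixup-style example is a hypothetical value, not taken from any of the benchmark architectures:

```python
import math
import torch
import torch.nn as nn

# Three separate convs so each scheme is visible in isolation.
k = nn.Conv2d(16, 32, 3, bias=False)   # Kaiming/He
o = nn.Conv2d(16, 32, 3, bias=False)   # orthogonal
f = nn.Conv2d(16, 32, 3, bias=False)   # Fixup-style damping

# Kaiming/He: zero-mean normal with variance 2/fan_out
# (here fan_out = 32 * 3 * 3 = 288, so std ~= 0.083)
nn.init.kaiming_normal_(k.weight, mode='fan_out', nonlinearity='relu')

# Orthogonal: rows of the flattened (32, 144) weight matrix are orthonormal
nn.init.orthogonal_(o.weight)

# Fixup-style: Kaiming first, then damp by L^(-1/2);
# L (number of residual blocks) is illustrative here
L = 9
nn.init.kaiming_normal_(f.weight, mode='fan_out', nonlinearity='relu')
with torch.no_grad():
    f.weight.mul_(L ** -0.5)
```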
What You Can Modify
The initialize_weights(model, config) function in custom_init.py (editable lines 228–261). This function receives the fully constructed model and a config dict, and must initialize all parameters.
You can modify:
- How nn.Conv2d weights are initialized (distribution, fan-in/fan-out, gain)
- How nn.BatchNorm2d parameters (weight/bias) are initialized
- How nn.Linear weights and biases are initialized
- Per-layer or depth-dependent scaling strategies
- Special handling for residual shortcut projections vs main-path convolutions
- Any data-independent initialization logic (no training data access)
The config dict provides: arch (str), num_classes (int), depth (int = number of Conv2d + Linear layers). You can also iterate over model.named_modules() or model.named_parameters().
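A minimal skeleton matching that interface, assuming only the config keys stated above; the per-layer rules here are a plain Kaiming baseline for illustration, not the benchmark's reference implementation:

```python
import torch
import torch.nn as nn

def initialize_weights(model, config):
    """Skeleton of the required signature; config supplies
    'arch', 'num_classes', and 'depth' per the task description."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                    nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode='fan_in',
                                    nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
```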
Evaluation
- Metric: Best test accuracy (%, higher is better)
- Architectures & datasets:
- ResNet-56 on CIFAR-100 (deep residual, 100 classes)
- VGG-16-BN on CIFAR-100 (deep non-residual with BatchNorm, 100 classes)
- MobileNetV2 on FashionMNIST (lightweight inverted-residual, 10 classes) — hidden, evaluated on final submission only
- Training: SGD (lr=0.1, momentum=0.9, wd=5e-4), cosine annealing, 200 epochs
- Data augmentation: RandomCrop(32, pad=4) + RandomHorizontalFlip
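The fixed optimizer and schedule can be reconstructed as below; only the hyperparameters (lr=0.1, momentum=0.9, wd=5e-4, 200 epochs of cosine annealing) come from the recipe above, and the function name is illustrative:

```python
import torch
import torch.nn as nn

def make_optimizer(model, epochs=200):
    # SGD with momentum and weight decay, as specified in the fixed recipe
    opt = torch.optim.SGD(model.parameters(), lr=0.1,
                          momentum=0.9, weight_decay=5e-4)
    # Cosine annealing from lr=0.1 down to ~0 over the full run
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    return opt, sched
```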
pytorch-vision/custom_init.py [EDITABLE — lines 228–261 only]
1: """CV Weight Initialization Benchmark.
2:
3: Train vision models (ResNet, VGG, MobileNetV2) on CIFAR-10/100/FashionMNIST to evaluate
4: weight initialization strategies.
5:
6: FIXED: Model architectures, data pipeline, training loop.
7: EDITABLE: initialize_weights() function.
8:
9: Usage:
10:     python custom_init.py --arch resnet20 --da
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_init.py",
4 "start_line": 228,
5 "end_line": 261,
6  "content": "def initialize_weights(model, config):\n    \"\"\"Residual-Aware Variance Scaling (RAVS) initialization.\n    \n    Combines depth-aware scaling for residual networks with adaptive\n    BatchNorm initialization to improve convergence.\n    \n    Key ideas:\n    1. For Conv2d layers: scale variance by sqrt(2/(fan_out * depth_factor))\n       where depth_factor = 1 + depth_index/total_depth for residual layers\n       to gradually increase variance with depth.\n    2. For BatchNorm2d: weight=1, bias=small positive (0.01) to avoid dead ReLUs.\n    3. For Linear: Kaiming normal with fan_in.\n    4. Special handling for residual shortcut projections: use fan_in scaling.\n    \"\"\"\n    arch = config['arch']\n    total_depth = config['depth']\n    \n    # First pass: assign depth indices to conv and linear layers\n    layer_indices = {}\n    idx = 0\n    for name, module in model.named_modules():\n        if isinstance(module, (nn.Conv2d, nn.Linear)):\n            layer_indices[name] = idx\n            idx += 1\n    \n    # Second pass: initialize\n    for name, module in model.named_modules():\n        if isinstance(module, nn.Conv2d):\n            # Determine if this is a shortcut projection (1x1 conv)\n            is_shortcut = (module.kernel_size == (1, 1)) and (module.stride == (1, 1) or module.stride == (2, 2))\n            \n            # Compute depth factor\n            depth_idx = layer_indices.get(name, 0)\n            depth_factor = 1.0 + (depth_idx / total_depth)  # ranges from 1 to ~2\n            \n            # Use fan_out for regular convs, fan_in for shortcuts\n            mode = 'fan_in' if is_shortcut else 'fan_out'\n            gain = math.sqrt(2.0)  # for ReLU\n            \n            # Adjust gain by depth factor\n            gain /= math.sqrt(depth_factor)\n            \n            # Kaiming normal with adjusted gain\n            nn.init.kaiming_normal_(module.weight, mode=mode, nonlinearity='relu')\n            # Scale weights by gain / std(initialized weights)\n            with torch.no_grad():\n                std = module.weight.std().item()\n                if std > 0:\n                    module.weight.mul_(gain / std)\n        \n        elif isinstance(module, nn.BatchNorm2d):\n            nn.init.constant_(module.weight, 1.0)\n            nn.init.constant_(module.bias, 0.01)  # small positive to avoid dead ReLUs\n        \n        elif isinstance(module, nn.Linear):\n            # Kaiming normal with fan_in\n            nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n            if module.bias is not None:\n                nn.init.constant_(module.bias, 0.0)"
7}
OK: Replaced lines 228..261 with 59 line(s) in pytorch-vision/custom_init.py. Editable range: 228–286.
[Current file: pytorch-vision/custom_init.py | editable: 228–286 | total: 451 lines]
225| # ============================================================================
226|
227| # -- EDITABLE REGION START (lines 228-261) ------------------------------------
228 def initialize_weights(model, config):
229     """Residual-Aware Variance Scaling (RAVS) initialization.
230
231     Combines depth-aware scaling for residual networks with adaptive
232     BatchNorm initialization to improve convergence.
233
234     Key ideas:
235     1. For Conv2d layers: scale variance by sqrt(2/(fan_out * depth_factor))
236        where depth_factor = 1 + depth_index/total_depth for residual layers
237        to gradually increase variance with depth.
238     2. For BatchNorm2d: weight=1, bias=small positive (0.01) to avoid dead ReLUs.
239
Tool Input
1{
2 "n": 1
3}
Undo complete: Restored: pytorch-vision/custom_init.py
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_init.py",
4 "start_line": 228,
5 "end_line": 261,
6 "content": "def initialize_weights(model, config):\n \"\"\"Depth-Adaptive Residual Scaling (DARS) initialization.\n \n For residual networks: scale conv weights by (1 + layer_depth/total_depth)^(-0.5)\n to gradually increase variance with depth while controlling accumulation.\n BatchNorm: weight=1, bias=small positive (0.01).\n Linear: Kaiming normal fan_in.\n \"\"\"\n arch = config['arch']\n total_depth = config['depth']\n \n # First pass: collect conv layers and assign depth indices\n conv_layers = []\n for name, module in model.named_modules():\n if isinstance(module, nn.Conv2d):\n conv_layers.append((name, module))\n \n # Initialize all convs with Kaiming normal first\n for name, module in model.named_modules():\n if isinstance(module, nn.Conv2d):\n nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')\n elif isinstance(module, nn.BatchNorm2d):\n nn.init.constant_(module.weight, 1.0)\n nn.init.constant_(module.bias, 0.01)\n elif isinstance(module, nn.Linear):\n nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n if module.bias is not None:\n nn.init.constant_(module.bias, 0.0)\n \n # Apply depth scaling for residual architectures\n if arch.startswith('resnet') or arch == 'mobilenetv2':\n for idx, (name, module) in enumerate(conv_layers):\n depth_factor = (1.0 + idx / total_depth) ** 0.5 # sqrt of depth factor\n module.weight.data.div_(depth_factor)"
7}
OK: Replaced lines 228..261 with 34 line(s) in pytorch-vision/custom_init.py. Editable range: 228–261.
[Current file: pytorch-vision/custom_init.py | editable: 228–261 | total: 426 lines]
225| # ============================================================================
226|
227| # -- EDITABLE REGION START (lines 228-261) ------------------------------------
228 def initialize_weights(model, config):
229     """Depth-Adaptive Residual Scaling (DARS) initialization.
230
231     For residual networks: scale conv weights by (1 + layer_depth/total_depth)^(-0.5)
232     to gradually increase variance with depth while controlling accumulation.
233     BatchNorm: weight=1, bias=small positive (0.01).
234     Linear: Kaiming normal fan_in.
235     """
236     arch = config['arch']
237     total_depth = config['depth']
238
239     # First pass: collect conv layers and assign depth indices
240     conv_layers = []
241
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### resnet56-cifar100 (scripts/resnet56_cifar100.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.3597 train_acc=88.70 test_loss=1.2788 test_acc=67.66 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.2088 train_acc=93.87 test_loss=1.2492 test_acc=70.95 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.1027 train_acc=97.53 test_loss=1.2272 test_acc=72.03 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0625 train_acc=98.96 test_loss=1.2280 test_acc=72.46 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0533 train_acc=99.24 test_loss=1.2187 test_acc=72.35 lr=0.000000
Final metrics (resnet56-cifar100): test_acc_resnet56-cifar100=72.61

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.2293 train_acc=93.34 test_loss=1.4907 test_acc=69.51 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.1004 trai
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_init.py",
4 "start_line": 228,
5 "end_line": 261,
6 "content": "def initialize_weights(model, config):\n \"\"\"Residual-Aware Fixup++ initialization.\n \n For residual networks: apply Fixup scaling to conv layers inside residual blocks.\n Additionally, scale shortcut projections by sqrt(0.5).\n BatchNorm: weight=1, bias=0 (except last BN in block zero-init).\n Linear: Kaiming normal fan_in.\n \"\"\"\n arch = config['arch']\n is_residual = arch.startswith('resnet') or arch == 'mobilenetv2'\n \n # First pass: count residual blocks\n block_count = 0\n for m in model.modules():\n if isinstance(m, BasicBlock) or isinstance(m, InvertedResidual):\n block_count += 1\n \n # Second pass: initialize\n for m in model.modules():\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n # Scale shortcuts\n if is_residual and m.kernel_size == (1, 1):\n m.weight.data.mul_(0.5**0.5) # sqrt(0.5)\n elif isinstance(m, nn.BatchNorm2d):\n nn.init.constant_(m.weight, 1.0)\n nn.init.constant_(m.bias, 0.0)\n elif isinstance(m, nn.Linear):\n nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')\n if m.bias is not None:\n nn.init.constant_(m.bias, 0.0)\n \n # Fixup scaling for residual blocks\n if is_residual and block_count > 0:\n fixup_scale = block_count ** (-0.5)\n for m in model.modules():\n if isinstance(m, BasicBlock):\n m.conv2.weight.data.mul_(fixup_scale)\n nn.init.constant_(m.bn2.weight, 0.0)\n elif isinstance(m, InvertedResidual):\n # Scale last conv in the block (the 1x1 pointwise conv)\n last_conv = None\n for sub in m.conv.modules():\n if isinstance(sub, nn.Conv2d):\n last_conv = sub\n if last_conv:\n last_conv.weight.data.mul_(fixup_scale)\n # Zero-init last BN in block\n for sub in reversed(list(m.conv.modules())):\n if isinstance(sub, nn.BatchNorm2d):\n nn.init.constant_(sub.weight, 0.0)\n break"
7}
OK: Replaced lines 228..261 with 52 line(s) in pytorch-vision/custom_init.py. Editable range: 228–279.
[Current file: pytorch-vision/custom_init.py | editable: 228–279 | total: 444 lines]
225| # ============================================================================
226|
227| # -- EDITABLE REGION START (lines 228-261) ------------------------------------
228 def initialize_weights(model, config):
229 """Residual-Aware Fixup++ initialization.
230
231 For residual networks: apply Fixup scaling to conv layers inside residual blocks.
232 Additionally, scale shortcut projections by sqrt(0.5).
233 BatchNorm: weight=1, bias=0 (except last BN in block zero-init).
234 Linear: Kaiming normal fan_in.
235 """
236 arch = config['arch']
237 is_residual = arch.startswith('resnet') or arch == 'mobilenetv2'
238
239 # First pass: count residual blocks
240 block_count = 0
241 for
Tool Input
1{
2 "n": 1
3}
Undo complete: Restored: pytorch-vision/custom_init.py
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_init.py",
4 "start_line": 228,
5 "end_line": 261,
6 "content": "def initialize_weights(model, config):\n \"\"\"Depth-Adaptive Variance Scaling (DAVS) initialization.\n \n For all conv layers: scale variance by (1 + layer_depth/total_depth)^(-alpha)\n where alpha=0.25 provides gentle depth adjustment.\n BatchNorm: weight=1, bias=0.\n Linear: Kaiming normal fan_in.\n \"\"\"\n arch = config['arch']\n total_depth = config['depth']\n alpha = 0.25\n \n # Collect conv layers\n conv_layers = []\n for name, module in model.named_modules():\n if isinstance(module, nn.Conv2d):\n conv_layers.append(module)\n \n # Initialize all layers\n for module in model.modules():\n if isinstance(module, nn.Conv2d):\n nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')\n elif isinstance(module, nn.BatchNorm2d):\n nn.init.constant_(module.weight, 1.0)\n nn.init.constant_(module.bias, 0.0)\n elif isinstance(module, nn.Linear):\n nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n if module.bias is not None:\n nn.init.constant_(module.bias, 0.0)\n \n # Apply depth scaling\n for idx, conv in enumerate(conv_layers):\n depth_factor = (1.0 + idx / total_depth) ** alpha\n conv.weight.data.div_(depth_factor)"
7}
OK: Replaced lines 228..261 with 34 line(s) in pytorch-vision/custom_init.py. Editable range: 228–261.
[Current file: pytorch-vision/custom_init.py | editable: 228–261 | total: 426 lines]
225| # ============================================================================
226|
227| # -- EDITABLE REGION START (lines 228-261) ------------------------------------
228 def initialize_weights(model, config):
229     """Depth-Adaptive Variance Scaling (DAVS) initialization.
230
231     For all conv layers: scale variance by (1 + layer_depth/total_depth)^(-alpha)
232     where alpha=0.25 provides gentle depth adjustment.
233     BatchNorm: weight=1, bias=0.
234     Linear: Kaiming normal fan_in.
235     """
236     arch = config['arch']
237     total_depth = config['depth']
238     alpha = 0.25
239
240     # Collect conv layers
241     conv_layers = []
242     for name, module in model.named_modules():
243
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### resnet56-cifar100 (scripts/resnet56_cifar100.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.3792 train_acc=88.24 test_loss=1.2972 test_acc=67.69 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.2270 train_acc=93.35 test_loss=1.2566 test_acc=69.70 lr=0.005450
TRAIN_METRICS: epoch=180 train_loss=0.1200 train_acc=97.08 test_loss=1.2543 test_acc=71.28 lr=0.002447
TRAIN_METRICS: epoch=190 train_loss=0.0733 train_acc=98.68 test_loss=1.2497 test_acc=71.82 lr=0.000616
TRAIN_METRICS: epoch=200 train_loss=0.0653 train_acc=98.93 test_loss=1.2533 test_acc=72.22 lr=0.000000
Final metrics (resnet56-cifar100): test_acc_resnet56-cifar100=72.22

### vgg16bn-cifar100 (scripts/vgg16bn_cifar100.sh)
Training progress (last epochs):
TRAIN_METRICS: epoch=160 train_loss=0.2519 train_acc=92.65 test_loss=1.4667 test_acc=69.03 lr=0.009549
TRAIN_METRICS: epoch=170 train_loss=0.1003 train
Tool Input
1{
2 "op": "replace",
3 "filename": "pytorch-vision/custom_init.py",
4 "start_line": 228,
5 "end_line": 261,
6 "content": "def initialize_weights(model, config):\n \"\"\"Residual Branch Balancing (RBB) initialization.\n \n For residual networks: scale last conv in each residual block by L^(-0.5)\n where L is number of blocks (Fixup). Additionally scale shortcut convs by L^(-0.5).\n Zero-init last BN in each block.\n For VGG: standard Kaiming.\n \"\"\"\n arch = config['arch']\n is_res = arch.startswith('resnet') or arch == 'mobilenetv2'\n \n # Count residual blocks\n L = 0\n for m in model.modules():\n if isinstance(m, BasicBlock) or isinstance(m, InvertedResidual):\n L += 1\n \n # Kaiming init for all\n for m in model.modules():\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n elif isinstance(m, nn.BatchNorm2d):\n nn.init.constant_(m.weight, 1.0)\n nn.init.constant_(m.bias, 0.0)\n elif isinstance(m, nn.Linear):\n nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')\n if m.bias is not None:\n nn.init.constant_(m.bias, 0.0)\n \n if is_res and L > 0:\n scale = L ** (-0.5)\n # Scale residual block convs and shortcuts\n for m in model.modules():\n if isinstance(m, BasicBlock):\n m.conv2.weight.data.mul_(scale)\n nn.init.constant_(m.bn2.weight, 0.0)\n # shortcut conv if present\n if len(m.shortcut) > 0:\n for subm in m.shortcut:\n if isinstance(subm, nn.Conv2d):\n subm.weight.data.mul_(scale)\n elif isinstance(m, InvertedResidual):\n # Find last conv in block\n last_conv = None\n for subm in m.conv.modules():\n if isinstance(subm, nn.Conv2d):\n last_conv = subm\n if last_conv:\n last_conv.weight.data.mul_(scale)\n # Zero-init last BN in block\n for subm in reversed(list(m.conv.modules())):\n if isinstance(subm, nn.BatchNorm2d):\n nn.init.constant_(subm.weight, 0.0)\n break"
7}
OK: Replaced lines 228..261 with 54 line(s) in pytorch-vision/custom_init.py. Editable range: 228–281.
[Current file: pytorch-vision/custom_init.py | editable: 228–281 | total: 446 lines]
225| # ============================================================================
226|
227| # -- EDITABLE REGION START (lines 228-261) ------------------------------------
228 def initialize_weights(model, config):
229 """Residual Branch Balancing (RBB) initialization.
230
231 For residual networks: scale last conv in each residual block by L^(-0.5)
232 where L is number of blocks (Fixup). Additionally scale shortcut convs by L^(-0.5).
233 Zero-init last BN in each block.
234 For VGG: standard Kaiming.
235 """
236 arch = config['arch']
237 is_res = arch.startswith('resnet') or arch == 'mobilenetv2'
238
239 # Count residual blocks
240 L = 0
241 for m in model.modules():
242