Agent Conversation: deepseek-reasoner
Task: cv-diffusion-conditioning
Class-Conditional Diffusion: Conditioning Injection Methods
Background
Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:
- Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
- Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
- FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks.
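To make the FiLM mechanism concrete, here is a minimal sketch of a ResBlock whose GroupNorm is modulated by a conditioning vector (e.g. the sum of timestep and class embeddings). The class name `FiLMResBlock` and all dimensions are illustrative, not taken from the training script.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FiLMResBlock(nn.Module):
    """Sketch of FiLM-style injection: a conditioning vector produces a
    per-channel scale and shift applied after GroupNorm inside the block."""

    def __init__(self, channels: int, cond_dim: int):
        super().__init__()
        # affine=False: the learned modulation below replaces the norm's affine params
        # (channels must be divisible by the group count, 8 here)
        self.norm = nn.GroupNorm(8, channels, affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond: [B, cond_dim] -> per-channel scale and shift, each [B, C, 1, 1]
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        out = self.norm(h) * (1 + scale) + shift  # FiLM modulation
        return h + self.conv(F.silu(out))          # residual connection
```

The `1 + scale` convention keeps the block close to plain GroupNorm when the projection output is small, which tends to stabilize early training.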
Research Question
Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?
Task
You are given custom_train.py, a self-contained class-conditional DDPM
training script with a small UNet on CIFAR-10 (32x32, 10 classes).
The editable region contains:
- prepare_conditioning(time_emb, class_emb): controls how the class embedding is combined with the timestep embedding before entering ResBlocks.
- ClassConditioner(nn.Module): an additional conditioning module applied after each ResBlock, enabling methods like cross-attention or adaptive norm.
Your goal is to design a conditioning injection method that achieves lower FID than the baselines.
Evaluation
- Dataset: CIFAR-10 (32x32, 10 classes)
- Model: UNet2DModel (diffusers backbone) at three scales:
- Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
- Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
- Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
- Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
- Inference: 50-step DDIM sampling (class-conditional)
Baselines
- concat-film: Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method.
- cross-attn: Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method.
- adanorm: Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style method.
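The core idea behind the DiT-style adanorm baseline can be sketched as follows: the class embedding produces scale, shift, and gate parameters, with the projection zero-initialized so the block starts as an exact identity. The module name `AdaLNZero` and its dimensions are assumptions for illustration, not the script's actual `AdaLNBlock`.

```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """Sketch of a DiT-style AdaLN-Zero block over a [B, C, H, W] map."""

    def __init__(self, channels: int, cond_dim: int):
        super().__init__()
        # A single group normalizes jointly over (C, H, W), LayerNorm-like
        self.norm = nn.GroupNorm(1, channels, affine=False)
        self.mod = nn.Linear(cond_dim, 3 * channels)
        nn.init.zeros_(self.mod.weight)  # the "Zero" in AdaLN-Zero:
        nn.init.zeros_(self.mod.bias)    # gate starts at 0, block is identity

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond: [B, cond_dim] -> scale, shift, gate, each [B, C, 1, 1]
        scale, shift, gate = self.mod(cond).chunk(3, dim=1)
        scale, shift, gate = (t[:, :, None, None] for t in (scale, shift, gate))
        return h + gate * (self.norm(h) * (1 + scale) + shift)
```

Because the gate is exactly zero at initialization, inserting this module anywhere in a pretrained or freshly initialized UNet does not perturb its function; conditioning strength is learned from scratch.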
diffusers-main/custom_train.py [EDITABLE — lines 195–227 only]
1: """Class-Conditional DDPM Training on CIFAR-10.
2:
3: Uses diffusers UNet2DModel backbone (same architecture as google/ddpm-cifar10-32)
4: with configurable class-conditioning injection. Only the conditioning method
5: (prepare_conditioning + ClassConditioner) is editable.
6: """
7:
8: import copy
9: import math
10: import os
11: import sys
12: import time
13:
14: import numpy as np
15: import torch
16: import torch.distributed as dist
17: import torch.nn as nn

Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Prepare the combined embedding used in ResBlocks.\n\n TODO: Implement your conditioning preparation here.\n\n Args:\n time_emb: [B, time_embed_dim] timestep embedding\n class_emb: [B, time_embed_dim] class embedding\n\n Returns: [B, time_embed_dim] embedding used in ResBlocks\n \"\"\"\n # Combined additive and multiplicative conditioning\n # Simple additive conditioning like concat-film\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Conditioning module applied after each UNet block.\n\n TODO: Implement your conditioning method here.\n\n Args (forward):\n h: [B, C, H, W] feature map\n class_emb: [B, time_embed_dim] class embedding\n\n Available utilities:\n CrossAttentionLayer(channels, context_dim, num_heads)\n AdaLNBlock(channels, cond_dim)\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n # Cross-attention path\n self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n # Adaptive normalization path \n self.adaln = AdaLNBlock(channels, cond_dim)\n # Gating network: produces per-channel weight\n self.gate_proj = nn.Sequential(\n nn.SiLU(),\n nn.Linear(cond_dim, channels)\n )\n # Initialize gate output to 0.0 (sigmoid(0.0)=0.5 for equal mixing)\n nn.init.constant_(self.gate_proj[-1].weight, 0.0)\n nn.init.constant_(self.gate_proj[-1].bias, 0.0)\n \n def forward(self, h, class_emb):\n attn_out = self.cross_attn(h, class_emb) # = h + attn_res\n norm_out = self.adaln(h, class_emb) # = h + norm_res\n # Compute gating weights per channel\n gate = torch.sigmoid(self.gate_proj(class_emb)) # [B, C]\n gate = gate.view(gate.shape[0], gate.shape[1], 1, 1) # [B, C, 1, 1]\n # Combine residuals\n combined_res = gate * (attn_out - h) + (1 - gate) * (norm_out - h)\n return h + combined_res"
7}

OK: Replaced lines 195..227 with 53 line(s) in diffusers-main/custom_train.py. Editable range: 195–247.
[Current file: diffusers-main/custom_train.py | editable: 195–247 | total: 630 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Prepare the combined embedding used in ResBlocks.
197
198     TODO: Implement your conditioning preparation here.
199
200     Args:
201         time_emb: [B, time_embed_dim] timestep embedding
202         class_emb: [B, time_embed_dim] class embedding
203
204     Returns: [B, time_embed_dim] embedding used in ResBlocks
205     """
206     # Combined additive and multiplicative conditioning
207     # Simple additive conditioning like concat-film
208     return time_emb + class_emb
209
210
211 class ClassConditioner(nn.Mod
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 10.8M | GPUs: 8
step 200/35000 | loss 0.0349 | 25.7s
step 400/35000 | loss 0.0321 | 24.4s
step 600/35000 | loss 0.0382 | 24.8s
step 800/35000 | loss 0.0270 | 24.2s
step 1000/35000 | loss 0.0401 | 24.7s
step 1200/35000 | loss 0.0334 | 24.5s
step 1400/35000 | loss 0.0317 | 24.9s
step 1600/35000 | loss 0.0228 | 24.7s
step 1800/35000 | loss 0.0323 | 24.5s
step 2000/35000 | loss 0.0434 | 25.4s
step 2200/35000 | loss 0.0429 | 24.5s
step 2400/35000 | loss 0.0275 | 24.7s
step 2600/35000 | loss 0.0466 | 24.7s
step 2800/35000 | loss 0.0285 | 24.7s
step 3000/35000 | loss 0.0348 | 24.7s
step 3200/35000 | loss 0.0347 | 25.2s
step 3400/35000 | loss 0.0325 | 24.3s
step 3600/35000 | loss 0.0300 | 24.3s
step 3800/35000 | loss 0.0354 | 25.3s
step 4000/35000 | loss 0.0228 | 24.8s
step 4200/35000 | loss 0.0346 | 24.8s
step 4400/35000 | loss 0.0205 | 25.3s
step
Tool Input
1{
2 "n": 1
3}

Undo complete: Restored: diffusers-main/custom_train.py
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Prepare the combined embedding used in ResBlocks.\n\n TODO: Implement your conditioning preparation here.\n\n Args:\n time_emb: [B, time_embed_dim] timestep embedding\n class_emb: [B, time_embed_dim] class embedding\n\n Returns: [B, time_embed_dim] embedding used in ResBlocks\n \"\"\"\n # Simple additive conditioning like concat-film\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Conditioning module applied after each UNet block.\n\n TODO: Implement your conditioning method here.\n\n Args (forward):\n h: [B, C, H, W] feature map\n class_emb: [B, time_embed_dim] class embedding\n\n Available utilities:\n CrossAttentionLayer(channels, context_dim, num_heads)\n AdaLNBlock(channels, cond_dim)\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n # Cross-attention layer\n self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n # Learn per-channel gating weight from class embedding\n self.gate_proj = nn.Linear(cond_dim, channels)\n # Initialize zero gate (start with no conditioning, let training learn)\n nn.init.zeros_(self.gate_proj.weight)\n nn.init.zeros_(self.gate_proj.bias)\n \n def forward(self, h, class_emb):\n attn_out = self.cross_attn(h, class_emb) # = h + attn_res\n # Compute per-channel gate\n gate = torch.sigmoid(self.gate_proj(class_emb)) # [B, C]\n gate = gate.view(gate.shape[0], gate.shape[1], 1, 1)\n # Blend between original h and attn_out\n # h + gate * (attn_out - h) = (1-gate)*h + gate*attn_out\n return h + gate * (attn_out - h)"
7}

OK: Replaced lines 195..227 with 46 line(s) in diffusers-main/custom_train.py. Editable range: 195–240.
[Current file: diffusers-main/custom_train.py | editable: 195–240 | total: 623 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Prepare the combined embedding used in ResBlocks.
197
198     TODO: Implement your conditioning preparation here.
199
200     Args:
201         time_emb: [B, time_embed_dim] timestep embedding
202         class_emb: [B, time_embed_dim] class embedding
203
204     Returns: [B, time_embed_dim] embedding used in ResBlocks
205     """
206     # Simple additive conditioning like concat-film
207     return time_emb + class_emb
208
209
210 class ClassConditioner(nn.Module):
211     """Conditioning module applied after each UNet
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 10.0M | GPUs: 8
step 200/35000 | loss 0.0350 | 23.2s
step 400/35000 | loss 0.0322 | 22.4s
step 600/35000 | loss 0.0386 | 21.9s
step 800/35000 | loss 0.0272 | 22.7s
step 1000/35000 | loss 0.0406 | 22.1s
step 1200/35000 | loss 0.0334 | 22.3s
step 1400/35000 | loss 0.0320 | 22.5s
step 1600/35000 | loss 0.0230 | 22.0s
step 1800/35000 | loss 0.0321 | 22.0s
step 2000/35000 | loss 0.0430 | 22.2s
step 2200/35000 | loss 0.0429 | 21.9s
step 2400/35000 | loss 0.0276 | 22.1s
step 2600/35000 | loss 0.0471 | 22.5s
step 2800/35000 | loss 0.0284 | 22.0s
step 3000/35000 | loss 0.0346 | 22.1s
step 3200/35000 | loss 0.0347 | 22.3s
step 3400/35000 | loss 0.0324 | 21.9s
step 3600/35000 | loss 0.0304 | 22.0s
step 3800/35000 | loss 0.0355 | 22.5s
step 4000/35000 | loss 0.0226 | 21.9s
step 4200/35000 | loss 0.0346 | 21.8s
step 4400/35000 | loss 0.0207 | 22.2s
step
Tool Input
1{
2 "n": 1
3}

Undo complete: Restored: diffusers-main/custom_train.py
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Prepare the combined embedding used in ResBlocks.\n\n TODO: Implement your conditioning preparation here.\n\n Args:\n time_emb: [B, time_embed_dim] timestep embedding\n class_emb: [B, time_embed_dim] class embedding\n\n Returns: [B, time_embed_dim] embedding used in ResBlocks\n \"\"\"\n # Simple additive conditioning like concat-film\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Conditioning module applied after each UNet block.\n\n TODO: Implement your conditioning method here.\n\n Args (forward):\n h: [B, C, H, W] feature map\n class_emb: [B, time_embed_dim] class embedding\n\n Available utilities:\n CrossAttentionLayer(channels, context_dim, num_heads)\n AdaLNBlock(channels, cond_dim)\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n # Adaptive LayerNorm-Zero conditioning\n self.adaln = AdaLNBlock(channels, cond_dim)\n # Self-gating: compute per-channel gate from feature map\n self.self_gate = nn.Sequential(\n nn.AdaptiveAvgPool2d(1),\n nn.Conv2d(channels, channels, kernel_size=1),\n nn.Sigmoid()\n )\n # Initialize conv weights zero so initial gate ~0.5\n nn.init.zeros_(self.self_gate[1].weight)\n nn.init.zeros_(self.self_gate[1].bias)\n \n def forward(self, h, class_emb):\n # Adaptive normalization output\n adaln_out = self.adaln(h, class_emb) # = h + gate * (norm - h)\n # Compute self-gate from original features\n gate = self.self_gate(h) # [B, C, 1, 1]\n # Blend between original h and conditioned output\n return h + gate * (adaln_out - h)"
7}

OK: Replaced lines 195..227 with 49 line(s) in diffusers-main/custom_train.py. Editable range: 195–243.
[Current file: diffusers-main/custom_train.py | editable: 195–243 | total: 626 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Prepare the combined embedding used in ResBlocks.
197
198     TODO: Implement your conditioning preparation here.
199
200     Args:
201         time_emb: [B, time_embed_dim] timestep embedding
202         class_emb: [B, time_embed_dim] class embedding
203
204     Returns: [B, time_embed_dim] embedding used in ResBlocks
205     """
206     # Simple additive conditioning like concat-film
207     return time_emb + class_emb
208
209
210 class ClassConditioner(nn.Module):
211     """Conditioning module applied after each UNet
[Test #3] (0 tests remaining)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 9.9M | GPUs: 8
step 200/35000 | loss 0.0348 | 20.6s
step 400/35000 | loss 0.0320 | 19.8s
step 600/35000 | loss 0.0382 | 20.0s
step 800/35000 | loss 0.0268 | 19.8s
step 1000/35000 | loss 0.0405 | 20.1s
step 1200/35000 | loss 0.0332 | 19.7s
step 1400/35000 | loss 0.0317 | 20.0s
step 1600/35000 | loss 0.0228 | 19.3s
step 1800/35000 | loss 0.0321 | 19.7s
step 2000/35000 | loss 0.0435 | 19.3s
step 2200/35000 | loss 0.0426 | 19.4s
step 2400/35000 | loss 0.0275 | 19.4s
step 2600/35000 | loss 0.0472 | 20.0s
step 2800/35000 | loss 0.0283 | 20.1s
step 3000/35000 | loss 0.0348 | 19.5s
step 3200/35000 | loss 0.0348 | 20.5s
step 3400/35000 | loss 0.0327 | 19.6s
step 3600/35000 | loss 0.0297 | 19.5s
step 3800/35000 | loss 0.0354 | 19.5s
step 4000/35000 | loss 0.0228 | 19.7s
step 4200/35000 | loss 0.0348 | 19.8s
step 4400/35000 | loss 0.0207 | 21.2s
step 4600/35000 | loss 0.0351 | 19.7s
step 48
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Gated cross-attention: additive conditioning in time embedding.\"\"\"\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Gated cross-attention with channel-wise gating.\"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n        # Gate network: produce per-channel gating scalar\n        self.gate_net = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(cond_dim, channels)\n        )\n        # Zero weights, bias 1.0: initial gate = sigmoid(1.0) ~ 0.73, so cross-attn contributes strongly from the start\n        nn.init.zeros_(self.gate_net[-1].weight)\n        nn.init.ones_(self.gate_net[-1].bias)\n\n    def forward(self, h, class_emb):\n        cross_out = self.cross_attn(h, class_emb)\n        # Compute gate per channel, shape [B, C]\n        gate = torch.sigmoid(self.gate_net(class_emb))\n        # Reshape to [B, C, 1, 1] for broadcasting\n        gate = gate.view(gate.shape[0], gate.shape[1], 1, 1)\n        # Blend: h + gate * cross_out\n        return h + gate * cross_out"
7}

OK: Replaced lines 195..227 with 27 line(s) in diffusers-main/custom_train.py. Editable range: 195–221.
[Current file: diffusers-main/custom_train.py | editable: 195–221 | total: 604 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Gated cross-attention: additive conditioning in time embedding."""
197     return time_emb + class_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Gated cross-attention with channel-wise gating."""
202     def __init__(self, channels, cond_dim):
203         super().__init__()
204         self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)
205         # Gate network: produce per-channel gating scalar
206         self.gate_net = nn.Sequential(
207             nn.SiLU(),
208             nn.Linear(cond_dim,