Agent Conversation: anthropic/claude-opus-4.6
Task: cv-diffusion-conditioning
Class-Conditional Diffusion: Conditioning Injection Methods
Background
Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:
- Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
- Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
- FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks.
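The three injection styles can be sketched in a few lines of PyTorch. This toy snippet (shapes and module sizes are illustrative, not the task's actual UNet) shows where the class embedding enters in each case, including the known degeneracy of single-token cross-attention:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy feature map [B, C, H, W] and a class embedding [B, D].
B, C, D = 2, 8, 16
h = torch.randn(B, C, 4, 4)
class_emb = torch.randn(B, D)

# 1) FiLM: project the embedding to per-channel scale and shift.
film = nn.Linear(D, 2 * C)
scale, shift = film(class_emb).chunk(2, dim=1)
h_film = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

# 2) AdaLN-Zero: scale, shift, and gate from the embedding, with the
#    projection zero-initialized so the block starts as the identity.
adaln = nn.Linear(D, 3 * C)
nn.init.zeros_(adaln.weight)
nn.init.zeros_(adaln.bias)
s, b, g = adaln(class_emb).chunk(3, dim=1)
normed = nn.GroupNorm(4, C)(h)
h_ada = h + g[:, :, None, None] * (
    normed * (1 + s[:, :, None, None]) + b[:, :, None, None]
)  # equals h exactly at initialization, since g == 0

# 3) Cross-attention: the class embedding becomes the key/value context.
#    Note that with a single KV token, the softmax over keys is identically
#    1, so the attention weights themselves carry no information.
attn_weights = torch.softmax(torch.randn(B, 16, 1), dim=-1)
```

The zero-gate initialization in (2) is what makes AdaLN-Zero safe to bolt onto a pretrained or freshly initialized backbone: the conditioner contributes nothing until training moves the gate.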
Research Question
Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?
Task
You are given custom_train.py, a self-contained class-conditional DDPM
training script with a small UNet on CIFAR-10 (32x32, 10 classes).
The editable region contains:
- prepare_conditioning(time_emb, class_emb): controls how the class embedding is combined with the timestep embedding before entering ResBlocks.
- ClassConditioner(nn.Module): an additional conditioning module applied after each ResBlock, enabling methods such as cross-attention or adaptive norm.
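A minimal fill-in for the two hooks, in the spirit of the concat-film baseline, might look like the following. Only the two signatures come from the task description; the bodies are assumptions about the simplest baseline, not the script's actual contents:

```python
import torch
import torch.nn as nn

def prepare_conditioning(time_emb, class_emb):
    # The summed embedding flows into every ResBlock's adaptive GroupNorm,
    # so class information is injected via FiLM scale/shift (concat-film).
    return time_emb + class_emb

class ClassConditioner(nn.Module):
    """Identity post-block conditioner: all class signal rides the FiLM path."""
    def __init__(self, channels, cond_dim):
        super().__init__()

    def forward(self, h, class_emb):
        return h
```

Any competing design replaces one or both bodies, e.g. swapping the identity conditioner for a cross-attention or adaptive-norm module while keeping the same interface.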
Your goal is to design a conditioning injection method that achieves lower FID than the baselines.
Evaluation
- Dataset: CIFAR-10 (32x32, 10 classes)
- Model: UNet2DModel (diffusers backbone) at three scales:
- Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
- Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
- Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
- Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
- Inference: 50-step DDIM sampling (class-conditional)
Baselines
- concat-film: Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method.
- cross-attn: Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method.
- adanorm: Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style method.
Action Budget (HARD LIMIT)
You have 20 total actions (every edit/test/submit/undo counts as 1) and 3 total test() calls. After 20 actions the run is force-stopped immediately. You MUST call submit(n=...) before the budget runs out — if no submit happens, your run records ZERO and all your work is lost. Plan accordingly: a typical good run uses ~15 edits, 3 tests, and 1 submit. Save submit budget for the end.
diffusers-main/custom_train.py [EDITABLE — lines 195–227 only]
1: """Class-Conditional DDPM Training on CIFAR-10.
2:
3: Uses diffusers UNet2DModel backbone (same
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Dual-path gated fusion: combines additive and multiplicative interactions.\n \n Instead of simple addition (FiLM) or no fusion (cross-attn/adanorm),\n we use a content-dependent gating that allows the class embedding to\n both additively shift AND multiplicatively modulate the time embedding.\n This provides richer time-class interactions: at different timesteps,\n different class dimensions become more/less relevant.\n \n Args:\n time_emb: [B, time_embed_dim] timestep embedding\n class_emb: [B, time_embed_dim] class embedding\n\n Returns: [B, time_embed_dim] embedding used in ResBlocks\n \"\"\"\n # Multiplicative gating: class determines which time dimensions are active\n gate = torch.sigmoid(class_emb)\n # Combined: additive class info + multiplicative modulation of time\n return time_emb * (1.0 + 0.1 * gate) + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Adaptive LayerNorm-Zero conditioning applied after each block.\n \n Combined with the FiLM-style injection in prepare_conditioning,\n this provides dual-path conditioning: class info enters both through\n the time embedding (into every ResBlock) AND through per-block AdaLN-Zero\n modulation. The AdaLN-Zero provides fine-grained, block-specific control\n while FiLM provides deep integration throughout the ResBlock internals.\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n self.adaln = AdaLNBlock(channels, cond_dim)\n\n def forward(self, h, class_emb):\n return self.adaln(h, class_emb)\n"
7}

OK: Replaced lines 195..227 with 36 line(s) in diffusers-main/custom_train.py. Editable range: 195–230.
[Current file: diffusers-main/custom_train.py | editable: 195–230 | total: 613 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195  def prepare_conditioning(time_emb, class_emb):
196      """Dual-path gated fusion: combines additive and multiplicative interactions.
197
198      Instead of simple addition (FiLM) or no fusion (cross-attn/adanorm),
199      we use a content-dependent gating that allows the class embedding to
200      both additively shift AND multiplicatively modulate the time embedding.
201      This provides richer time-class interactions: at different timesteps,
202      different class dimensions become more/less relevant.
203
204      Args:
205          time_emb: [B, time_embed_dim] timestep embedding
206
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Dual-pathway conditioning: inject class info through BOTH the FiLM path\n (adding to time_emb, which modulates ResBlock GroupNorm scale/shift) AND\n through the cross-attention ClassConditioner after each block.\n\n Motivation: FiLM provides global, channel-wise class conditioning through\n the ResBlock's built-in adaptive GroupNorm. Cross-attention provides\n spatially-varying, per-pixel class conditioning. These are complementary\n signals — combining them gives the model two orthogonal ways to use class\n information without any additional parameters.\n\n Args:\n time_emb: [B, time_embed_dim] timestep embedding\n class_emb: [B, time_embed_dim] class embedding\n\n Returns: [B, time_embed_dim] embedding used in ResBlocks\n \"\"\"\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Cross-attention pathway for spatial class conditioning.\n\n Combined with the FiLM pathway in prepare_conditioning, this creates\n a dual-pathway scheme: global (FiLM) + spatial (cross-attention).\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n\n def forward(self, h, class_emb):\n return self.cross_attn(h, class_emb)"
7}

OK: Replaced lines 195..227 with 32 line(s) in diffusers-main/custom_train.py. Editable range: 195–226.
[Current file: diffusers-main/custom_train.py | editable: 195–226 | total: 609 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195  def prepare_conditioning(time_emb, class_emb):
196      """Dual-pathway conditioning: inject class info through BOTH the FiLM path
197      (adding to time_emb, which modulates ResBlock GroupNorm scale/shift) AND
198      through the cross-attention ClassConditioner after each block.
199
200      Motivation: FiLM provides global, channel-wise class conditioning through
201      the ResBlock's built-in adaptive GroupNorm. Cross-attention provides
202      spatially-varying, per-pixel class conditioning. These are complementary
203      signals — combining them gives the model two orthogonal ways to use class
204      i
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "start_line": 195,
5 "end_line": 227,
6 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.\n\n Adding class_emb to time_emb injects class info through the ResBlock\n adaptive GroupNorm (scale/shift), providing multiplicative channel\n modulation inside each ResBlock. This complements the cross-attention\n ClassConditioner which adds a residual bias after each block.\n \"\"\"\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Multi-token cross-attention with class-conditioned token expansion.\n\n Standard cross-attention with 1 KV token degenerates (softmax=1 always),\n making Q irrelevant. We expand the class embedding into K learnable\n context tokens, making the attention non-degenerate: different spatial\n positions can attend to different class aspects.\n\n Combined with FiLM (via prepare_conditioning), this gives dual-path\n class conditioning: FiLM for global channel modulation inside ResBlocks,\n multi-token cross-attention for spatially-adaptive conditioning after.\n \"\"\"\n def __init__(self, channels, cond_dim, num_tokens=4):\n super().__init__()\n self.num_tokens = num_tokens\n # Project class embedding to multiple context tokens\n self.token_proj = nn.Linear(cond_dim, num_tokens * cond_dim)\n # Cross-attention with multi-token context\n self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n\n def forward(self, h, class_emb):\n # Expand class_emb into multiple tokens: [B, cond_dim] -> [B, num_tokens, cond_dim]\n B = class_emb.shape[0]\n multi_ctx = self.token_proj(class_emb).view(B, self.num_tokens, -1)\n # Use multi-token cross attention\n return self._multi_token_cross_attn(h, multi_ctx)\n\n def _multi_token_cross_attn(self, x, context):\n \"\"\"Cross-attention with multiple KV tokens.\"\"\"\n B, C, H, W = x.shape\n ca = self.cross_attn\n h = ca.norm(x).view(B, C, -1).transpose(1, 2) # [B, HW, C]\n # context: [B, num_tokens, cond_dim]\n q = 
ca.q_proj(h).view(B, -1, ca.num_heads, ca.head_dim).transpose(1, 2)\n k = ca.k_proj(context).view(B, -1, ca.num_heads, ca.head_dim).transpose(1, 2)\n v = ca.v_proj(context).view(B, -1, ca.num_heads, ca.head_dim).transpose(1, 2)\n attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(ca.head_dim)\n attn = F.softmax(attn, dim=-1)\n out = torch.matmul(attn, v)\n out = out.transpose(1, 2).reshape(B, H * W, C)\n out = ca.out_proj(out)\n return x + out.transpose(1, 2).view(B, C, H, W)\n"
7}

OK: Replaced lines 195..227 with 53 line(s) in diffusers-main/custom_train.py. Editable range: 195–247.
[Current file: diffusers-main/custom_train.py | editable: 195–247 | total: 630 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195  def prepare_conditioning(time_emb, class_emb):
196      """Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.
197
198      Adding class_emb to time_emb injects class info through the ResBlock
199      adaptive GroupNorm (scale/shift), providing multiplicative channel
200      modulation inside each ResBlock. This complements the cross-attention
201      ClassConditioner which adds a residual bias after each block.
202      """
203      return time_emb + class_emb
204
205
206  class ClassConditioner(nn.Module):
207      """Multi-token cross-attention with class-conditioned token expansi
Tool Input
1{
2 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.\n\n Adding class_emb to time_emb injects class info through the ResBlock\n adaptive GroupNorm (scale/shift), providing multiplicative channel\n modulation inside each ResBlock. This complements the cross-attention\n ClassConditioner which adds a residual bias after each block.\n \"\"\"\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Multi-token cross-attention with class-conditioned token expansion.\n\n Standard cross-attention with 1 KV token degenerates (softmax=1 always),\n making Q irrelevant. We expand the class embedding into K learnable\n context tokens via a lightweight bottleneck projection, making the\n attention non-degenerate: different spatial positions can attend to\n different class aspects.\n\n Combined with FiLM (via prepare_conditioning), this gives dual-path\n class conditioning: FiLM for global channel modulation inside ResBlocks,\n multi-token cross-attention for spatially-adaptive conditioning after.\n \"\"\"\n def __init__(self, channels, cond_dim, num_tokens=4):\n super().__init__()\n self.num_tokens = num_tokens\n self.cond_dim = cond_dim\n # Lightweight bottleneck to generate multiple tokens\n # bottleneck_dim keeps param count low\n bottleneck_dim = cond_dim // 4\n self.token_proj = nn.Sequential(\n nn.SiLU(),\n nn.Linear(cond_dim, bottleneck_dim),\n nn.SiLU(),\n nn.Linear(bottleneck_dim, num_tokens * cond_dim),\n )\n # Cross-attention with multi-token context\n self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n\n def forward(self, h, class_emb):\n # Expand class_emb into multiple tokens: [B, cond_dim] -> [B, num_tokens, cond_dim]\n B = class_emb.shape[0]\n multi_ctx = self.token_proj(class_emb).view(B, self.num_tokens, self.cond_dim)\n # Use multi-token cross attention\n return self._multi_token_cross_attn(h, multi_ctx)\n\n def 
_multi_token_cross_attn(self, x, context):\n \"\"\"Cross-attention with multiple KV tokens.\"\"\"\n B, C, H, W = x.shape\n ca = self.cross_attn\n h = ca.norm(x).view(B, C, -1).transpose(1, 2) # [B, HW, C]\n # context: [B, num_tokens, cond_dim]\n q = ca.q_proj(h).view(B, -1, ca.num_heads, ca.head_dim).transpose(1, 2)\n k = ca.k_proj(context).view(B, -1, ca.num_heads, ca.head_dim).transpose(1, 2)\n v = ca.v_proj(context).view(B, -1, ca.num_heads, ca.head_dim).transpose(1, 2)\n attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(ca.head_dim)\n attn = F.softmax(attn, dim=-1)\n out = torch.matmul(attn, v)\n out = out.transpose(1, 2).reshape(B, H * W, C)\n out = ca.out_proj(out)\n return x + out.transpose(1, 2).view(B, C, H, W)\n",
3 "end_line": 247,
4 "filename": "diffusers-main/custom_train.py",
5 "op": "replace",
6 "start_line": 195
7}

OK: Replaced lines 195..247 with 62 line(s) in diffusers-main/custom_train.py. Editable range: 195–256.
[Current file: diffusers-main/custom_train.py | editable: 195–256 | total: 639 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195  def prepare_conditioning(time_emb, class_emb):
196      """Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.
197
198      Adding class_emb to time_emb injects class info through the ResBlock
199      adaptive GroupNorm (scale/shift), providing multiplicative channel
200      modulation inside each ResBlock. This complements the cross-attention
201      ClassConditioner which adds a residual bias after each block.
202      """
203      return time_emb + class_emb
204
205
206  class ClassConditioner(nn.Module):
207      """Multi-token cross-attention with class-conditioned token expansi
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 10.5M | GPUs: 8
step 200/35000 | loss 0.0351 | 23.6s
step 400/35000 | loss 0.0321 | 22.6s
step 600/35000 | loss 0.0380 | 22.6s
step 800/35000 | loss 0.0270 | 22.2s
step 1000/35000 | loss 0.0405 | 22.6s
step 1200/35000 | loss 0.0330 | 22.2s
step 1400/35000 | loss 0.0318 | 22.6s
step 1600/35000 | loss 0.0227 | 22.6s
step 1800/35000 | loss 0.0319 | 22.7s
step 2000/35000 | loss 0.0434 | 22.7s
step 2200/35000 | loss 0.0428 | 22.3s
step 2400/35000 | loss 0.0276 | 22.3s
step 2600/35000 | loss 0.0464 | 22.6s
step 2800/35000 | loss 0.0285 | 22.5s
step 3000/35000 | loss 0.0346 | 22.3s
step 3200/35000 | loss 0.0350 | 23.0s
step 3400/35000 | loss 0.0325 | 22.3s
step 3600/35000 | loss 0.0302 | 22.3s
step 3800/35000 | loss 0.0352 | 22.5s
step 4000/35000 | loss 0.0228 | 22.4s
step 4200/35000 | loss 0.0346 | 22.4s
step 4400/35000 | loss 0.0206 | 22.7s
step
Tool Input
1{
2 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.\n\n Adding class_emb to time_emb injects class info through the ResBlock\n adaptive GroupNorm (scale/shift), providing multiplicative channel\n modulation inside each ResBlock. This complements the cross-attention\n ClassConditioner which adds a learned residual bias after each block.\n\n Motivation: FiLM and cross-attention are complementary conditioning paths.\n FiLM modulates feature statistics globally via GroupNorm affine params,\n while the post-block conditioner adds spatially-uniform but channel-wise\n adaptive residuals. Having both paths provides redundant but\n complementary class signal, improving conditioning robustness.\n \"\"\"\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Gated channel modulation: projects class embedding to scale, shift,\n and gate parameters applied after GroupNorm, with sigmoid gating for\n stable training.\n\n Unlike AdaLN-Zero which uses unbounded linear gating, we use sigmoid\n gating which is bounded in [0,1], providing better gradient flow and\n more stable interpolation between the identity (gate=0.5 at init with\n appropriate bias) and the modulated features.\n\n Combined with FiLM (via prepare_conditioning), this provides dual-path\n class conditioning: FiLM for in-ResBlock modulation, gated channel\n modulation for post-block refinement.\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n self.norm = nn.GroupNorm(32, channels)\n # Project class embedding to scale, shift, gate\n self.proj = nn.Sequential(\n nn.SiLU(),\n nn.Linear(cond_dim, 3 * channels),\n )\n # Zero-init for stable start (identity mapping)\n nn.init.zeros_(self.proj[-1].weight)\n nn.init.zeros_(self.proj[-1].bias)\n # Set gate bias so sigmoid(0) = 0.5 initially -> start near identity\n # Actually with zero init, gate=sigmoid(0)=0.5, scale=0, shift=0\n # Result: x 
+ 0.5 * (norm(x) * 1 + 0 - x) = 0.5*x + 0.5*norm(x)\n # Better: init gate bias to -3 so sigmoid(-3)≈0.05, nearly identity\n with torch.no_grad():\n self.proj[-1].bias.data[2*channels:] = -3.0\n\n def forward(self, h, class_emb):\n params = self.proj(class_emb)\n scale, shift, gate_logit = params.unsqueeze(-1).unsqueeze(-1).chunk(3, dim=1)\n gate = torch.sigmoid(gate_logit)\n normed = self.norm(h.float()).to(h.dtype)\n modulated = normed * (1 + scale) + shift\n return h + gate * (modulated - h)\n",
3 "end_line": 256,
4 "filename": "diffusers-main/custom_train.py",
5 "op": "replace",
6 "start_line": 195
7}

OK: Replaced lines 195..256 with 56 line(s) in diffusers-main/custom_train.py. Editable range: 195–250.
[Current file: diffusers-main/custom_train.py | editable: 195–250 | total: 633 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195  def prepare_conditioning(time_emb, class_emb):
196      """Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.
197
198      Adding class_emb to time_emb injects class info through the ResBlock
199      adaptive GroupNorm (scale/shift), providing multiplicative channel
200      modulation inside each ResBlock. This complements the cross-attention
201      ClassConditioner which adds a learned residual bias after each block.
202
203      Motivation: FiLM and cross-attention are complementary conditioning paths.
204      FiLM modulates feature statistics globally via GroupNorm affine params,
205
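The gate-initialization arithmetic in the docstring of the edit above (zero-init giving sigmoid(0) = 0.5, versus a -3 bias giving a near-identity start) checks out numerically; a quick sanity sketch:

```python
import torch

# Zero-initialized projection -> gate logit 0 -> gate = sigmoid(0) = 0.5:
# the module's output is a 50/50 blend of h and the modulated features.
gate_zero_init = torch.sigmoid(torch.tensor(0.0)).item()

# Biasing the gate logits to -3 gives sigmoid(-3) ~ 0.047, so the
# residual branch is nearly off and the module starts close to identity.
gate_biased = torch.sigmoid(torch.tensor(-3.0)).item()
```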
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 9.7M | GPUs: 8
step 200/35000 | loss 0.0349 | 20.3s
step 400/35000 | loss 0.0321 | 18.8s
step 600/35000 | loss 0.0383 | 18.7s
step 800/35000 | loss 0.0270 | 18.3s
step 1000/35000 | loss 0.0404 | 18.7s
step 1200/35000 | loss 0.0333 | 18.5s
step 1400/35000 | loss 0.0317 | 18.9s
step 1600/35000 | loss 0.0226 | 18.7s
step 1800/35000 | loss 0.0319 | 18.9s
step 2000/35000 | loss 0.0434 | 18.7s
step 2200/35000 | loss 0.0424 | 18.5s
step 2400/35000 | loss 0.0274 | 18.3s
step 2600/35000 | loss 0.0472 | 19.1s
step 2800/35000 | loss 0.0281 | 18.7s
step 3000/35000 | loss 0.0348 | 18.8s
step 3200/35000 | loss 0.0350 | 19.1s
step 3400/35000 | loss 0.0323 | 18.9s
step 3600/35000 | loss 0.0301 | 18.6s
step 3800/35000 | loss 0.0354 | 19.6s
step 4000/35000 | loss 0.0227 | 18.9s
step 4200/35000 | loss 0.0349 | 18.5s
step 4400/35000 | loss 0.0204 | 19.1s
4
Tool Input
1{
2 "content": "def prepare_conditioning(time_emb, class_emb):\n \"\"\"Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.\n\n Adding class_emb to time_emb injects class info through the ResBlock\n adaptive GroupNorm (scale/shift), providing multiplicative channel\n modulation inside each ResBlock. This complements the cross-attention\n ClassConditioner which adds a residual bias after each block.\n\n Motivation: FiLM and cross-attention are complementary conditioning paths.\n FiLM modulates feature statistics globally via GroupNorm affine params,\n while cross-attention adds spatially-uniform but channel-wise adaptive\n residuals. Having both paths provides redundant but complementary class\n signal, improving conditioning robustness.\n \"\"\"\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Cross-attention with FiLM dual-path conditioning.\n\n Uses standard cross-attention (class embedding as key/value) combined\n with FiLM conditioning through the time embedding path. The dual-path\n approach provides complementary conditioning: FiLM for in-ResBlock\n channel modulation, cross-attention for post-block refinement.\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n\n def forward(self, h, class_emb):\n return self.cross_attn(h, class_emb)\n",
3 "end_line": 250,
4 "filename": "diffusers-main/custom_train.py",
5 "op": "replace",
6 "start_line": 195
7}

OK: Replaced lines 195..250 with 31 line(s) in diffusers-main/custom_train.py. Editable range: 195–225.
[Current file: diffusers-main/custom_train.py | editable: 195–225 | total: 608 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195  def prepare_conditioning(time_emb, class_emb):
196      """Dual-path conditioning: FiLM (additive to time_emb) + cross-attention.
197
198      Adding class_emb to time_emb injects class info through the ResBlock
199      adaptive GroupNorm (scale/shift), providing multiplicative channel
200      modulation inside each ResBlock. This complements the cross-attention
201      ClassConditioner which adds a residual bias after each block.
202
203      Motivation: FiLM and cross-attention are complementary conditioning paths.
204      FiLM modulates feature statistics globally via GroupNorm affine params,
205      whil