Agent Conversation: google/gemini-3.1-pro-preview
Task: cv-diffusion-conditioning
Class-Conditional Diffusion: Conditioning Injection Methods
Background
Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:
- Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
- Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
- FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks.
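The FiLM variant can be sketched in a few lines: the conditioning vector is projected to a per-channel scale and shift applied after a non-affine GroupNorm. This is an illustrative module, not the script's actual code; zero-initializing the projection makes it start out as a plain GroupNorm.

```python
import torch
import torch.nn as nn

class FiLMGroupNorm(nn.Module):
    """Conditioning-dependent scale/shift applied after a non-affine GroupNorm."""
    def __init__(self, channels, cond_dim, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, 2 * channels)
        # Zero init: the module behaves as a plain GroupNorm at step 0.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, h, cond):
        # h: [B, C, H, W], cond: [B, cond_dim]
        scale, shift = self.proj(cond).chunk(2, dim=1)
        return self.norm(h) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
```

In practice `cond` would be the summed timestep-plus-class embedding described above.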
Research Question
Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?
Task
You are given custom_train.py, a self-contained class-conditional DDPM
training script with a small UNet on CIFAR-10 (32x32, 10 classes).
The editable region contains:
- `prepare_conditioning(time_emb, class_emb)`: controls how the class embedding is combined with the timestep embedding before entering ResBlocks.
- `ClassConditioner(nn.Module)`: an additional conditioning module applied after each ResBlock, enabling methods like cross-attention or adaptive norm.
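As a reference point, the simplest valid pair of implementations (in the spirit of the concat-film baseline) would look roughly like this; only the two signatures come from the task description, the bodies are a sketch:

```python
import torch
import torch.nn as nn

def prepare_conditioning(time_emb, class_emb):
    # concat-film style: fold the class signal into the timestep embedding,
    # so every ResBlock's adaptive GroupNorm sees both.
    return time_emb + class_emb

class ClassConditioner(nn.Module):
    """No-op post-block conditioner: all injection happens upstream."""
    def __init__(self, channels, cond_dim):
        super().__init__()

    def forward(self, h, class_emb):
        return h
```

Any stronger method keeps these signatures and only changes what happens inside.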
Your goal is to design a conditioning injection method that achieves lower FID than the baselines.
Evaluation
- Dataset: CIFAR-10 (32x32, 10 classes)
- Model: UNet2DModel (diffusers backbone) at three scales:
- Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
- Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
- Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
- Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
- Inference: 50-step DDIM sampling (class-conditional)
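A single deterministic (eta = 0) DDIM update, as used by such a sampler, can be written as one predict-then-re-noise step; this is a generic sketch, not the script's sampler, and `alpha_bar` denotes the cumulative noise schedule.

```python
import torch

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic (eta=0) DDIM update from noise level t to the previous level."""
    # Invert the forward process to estimate the clean image.
    x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()
    # Re-noise the estimate to the previous (lower) noise level.
    return alpha_bar_prev.sqrt() * x0_pred + (1 - alpha_bar_prev).sqrt() * eps_pred
```

Iterating this over 50 schedule points, with `eps_pred` produced by the class-conditional UNet, yields the samples scored by FID.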
Baselines
Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID (lower is better) for the 3 baselines:
| Name | Strategy | fid (small) | fid (medium) |
|---|---|---|---|
concat-film | Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method. | 19.39 | 11.50 |
cross-attn | Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method. | 19.17 | 11.21 |
adanorm | Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style. | 20.01 | 12.31 |
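The DiT-style AdaLN-Zero used by the adanorm baseline can be sketched as follows: the class embedding produces scale, shift, and a gate, with the gate zero-initialized so the branch starts as an identity (illustrative module, not the script's exact code).

```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """DiT-style adaptive LayerNorm: cond -> scale, shift, and a zero-init gate."""
    def __init__(self, channels, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(channels, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, 3 * channels)
        # Zero init: gate starts at 0, so the branch is an identity at step 0.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, h, cond):
        # h: [B, C, H, W], cond: [B, cond_dim]
        scale, shift, gate = self.proj(cond).chunk(3, dim=1)
        x = h.permute(0, 2, 3, 1)  # LayerNorm normalizes the channel dim
        x = self.norm(x) * (1 + scale[:, None, None, :]) + shift[:, None, None, :]
        return h + gate[:, :, None, None] * x.permute(0, 3, 1, 2)
```

The zero-initialized gate is what makes this variant stable to drop into a pretrained-style residual path.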
The strongest baseline is cross-attn at FID 19.17 (small) / 11.21 (medium). Your method must achieve lower FID than cross-attn on both train_small and train_medium to count as a successful improvement. A method that wins on one scale but regresses on the other counts only as a partial failure: the task is about robust conditioning across model sizes.
Tool Input
{
  "start_line": 195,
  "op": "replace",
  "end_line": 227,
  "filename": "custom_train.py",
  "content": "def prepare_conditioning(time_emb, class_emb):\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = channels // num_heads\n self.num_tokens = num_tokens\n self.norm = nn.GroupNorm(32, channels)\n \n token_dim = cond_dim // num_tokens\n self.q_proj = nn.Linear(channels, channels)\n self.k_proj = nn.Linear(token_dim, channels)\n self.v_proj = nn.Linear(token_dim, channels)\n self.out_proj = zero_module(nn.Linear(channels, channels))\n\n def forward(self, h, class_emb):\n B, C, H, W = h.shape\n hx = self.norm(h).view(B, C, -1).transpose(1, 2)\n \n ctx = class_emb.view(B, self.num_tokens, -1)\n \n q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n \n attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n attn = F.softmax(attn, dim=-1)\n out = torch.matmul(attn, v)\n out = out.transpose(1, 2).reshape(B, H * W, C)\n out = self.out_proj(out)\n return h + out.transpose(1, 2).view(B, C, H, W)"
}
ERROR: Package 'custom_train.py' is not in allowed packages
Tool Input
{
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 195,
  "end_line": 228,
  "content": "def prepare_conditioning(time_emb, class_emb):\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = channels // num_heads\n self.num_tokens = num_tokens\n self.norm = nn.GroupNorm(32, channels)\n \n token_dim = cond_dim // num_tokens\n self.q_proj = nn.Linear(channels, channels)\n self.k_proj = nn.Linear(token_dim, channels)\n self.v_proj = nn.Linear(token_dim, channels)\n self.out_proj = zero_module(nn.Linear(channels, channels))\n\n def forward(self, h, class_emb):\n B, C, H, W = h.shape\n hx = self.norm(h).view(B, C, -1).transpose(1, 2)\n \n ctx = class_emb.view(B, self.num_tokens, -1)\n \n q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n \n attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n attn = F.softmax(attn, dim=-1)\n out = torch.matmul(attn, v)\n out = out.transpose(1, 2).reshape(B, H * W, C)\n out = self.out_proj(out)\n return h + out.transpose(1, 2).view(B, C, H, W)"
}
ERROR: Lines 195..228 exceed the editable range. You may only edit lines 195–227.
[Current file: diffusers-main/custom_train.py | editable: 195–227 | total: 610 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Prepare the combined embedding used in ResBlocks.
197
198     TODO: Implement your conditioning preparation here.
199
200     Args:
201         time_emb: [B, time_embed_dim] timestep embedding
202         class_emb: [B, time_embed_dim] class embedding
203
204     Returns: [B, time_embed_dim] embedding used in ResBlocks
205     """
206     raise NotImplementedError("Implement prepare_conditioning")
207
208
209 class ClassConditioner(nn.Module):
210     """Conditioning module applied after each UNet block.
211
212     TODO: Implement your
Tool Input
{
  "content": "def prepare_conditioning(time_emb, class_emb):\n return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = channels // num_heads\n self.num_tokens = num_tokens\n self.norm = nn.GroupNorm(32, channels)\n \n token_dim = cond_dim // num_tokens\n self.q_proj = nn.Linear(channels, channels)\n self.k_proj = nn.Linear(token_dim, channels)\n self.v_proj = nn.Linear(token_dim, channels)\n self.out_proj = zero_module(nn.Linear(channels, channels))\n\n def forward(self, h, class_emb):\n B, C, H, W = h.shape\n hx = self.norm(h).view(B, C, -1).transpose(1, 2)\n \n ctx = class_emb.view(B, self.num_tokens, -1)\n \n q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n \n attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n attn = F.softmax(attn, dim=-1)\n out = torch.matmul(attn, v)\n out = out.transpose(1, 2).reshape(B, H * W, C)\n out = self.out_proj(out)\n return h + out.transpose(1, 2).view(B, C, H, W)",
  "op": "replace",
  "filename": "diffusers-main/custom_train.py",
  "start_line": 195,
  "end_line": 227
}
OK: Replaced lines 195..227 with 34 line(s) in diffusers-main/custom_train.py. Editable range: 195–228.
[Current file: diffusers-main/custom_train.py | editable: 195–228 | total: 611 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     return time_emb + class_emb
197
198
199 class ClassConditioner(nn.Module):
200     def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):
201         super().__init__()
202         self.num_heads = num_heads
203         self.head_dim = channels // num_heads
204         self.num_tokens = num_tokens
205         self.norm = nn.GroupNorm(32, channels)
206
207         token_dim = cond_dim // num_tokens
208         self.q_proj = nn.Linear(channels, channels)
209         self.k_proj = nn.Linear(token_dim, channels)
210
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.3M | GPUs: 8
step 200/35000 | loss 0.0351 | 21.9s
step 400/35000 | loss 0.0322 | 21.1s
step 600/35000 | loss 0.0383 | 21.4s
step 800/35000 | loss 0.0271 | 21.3s
step 1000/35000 | loss 0.0408 | 21.4s
step 1200/35000 | loss 0.0332 | 21.0s
step 1400/35000 | loss 0.0317 | 21.3s
step 1600/35000 | loss 0.0229 | 21.5s
step 1800/35000 | loss 0.0323 | 21.3s
step 2000/35000 | loss 0.0438 | 21.0s
step 2200/35000 | loss 0.0429 | 21.0s
step 2400/35000 | loss 0.0277 | 20.8s
step 2600/35000 | loss 0.0471 | 21.1s
step 2800/35000 | loss 0.0284 | 21.3s
step 3000/35000 | loss 0.0347 | 21.1s
step 3200/35000 | loss 0.0349 | 21.6s
step 3400/35000 | loss 0.0325 | 21.0s
step 3600/35000 | loss 0.0300 | 21.1s
step 3800/35000 | loss 0.0355 | 21.5s
step 4000/35000 | loss 0.0228 | 21.0s
step 4200/35000 | loss 0.0346 | 21.1s
step 4400/35000 | loss 0.0206 | 21.2s
step
Tool Input
{
  "filename": "diffusers-main/custom_train.py",
  "end_line": 228,
  "content": "def prepare_conditioning(time_emb, class_emb):\n # Keep time embedding pure to prevent interference\n return time_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Feature-Adaptive FiLM (FA-FiLM).\n \n Dynamically predicts scale and shift for the features based on both\n the global class embedding and the current spatial feature state.\n \"\"\"\n def __init__(self, channels, cond_dim):\n super().__init__()\n self.pool = nn.AdaptiveAvgPool2d(1)\n self.mlp = nn.Sequential(\n nn.Linear(channels + cond_dim, channels),\n nn.LayerNorm(channels),\n nn.SiLU(),\n nn.Linear(channels, 2 * channels)\n )\n # Zero-initialize for identity mapping at start\n nn.init.zeros_(self.mlp[-1].weight)\n nn.init.zeros_(self.mlp[-1].bias)\n\n def forward(self, h, class_emb):\n B, C, _, _ = h.shape\n # Extract current feature state\n pooled = self.pool(h).view(B, C)\n # Combine state with global class condition\n ctx = torch.cat([pooled, class_emb], dim=1)\n \n # Predict scale and shift\n params = self.mlp(ctx).view(B, 2, C, 1, 1)\n scale, shift = params.chunk(2, dim=1)\n scale = scale.squeeze(1)\n shift = shift.squeeze(1)\n \n return h * (1 + scale) + shift",
  "start_line": 195,
  "op": "replace"
}
OK: Replaced lines 195..228 with 38 line(s) in diffusers-main/custom_train.py. Editable range: 195–232.
[Current file: diffusers-main/custom_train.py | editable: 195–232 | total: 615 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     # Keep time embedding pure to prevent interference
197     return time_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Feature-Adaptive FiLM (FA-FiLM).
202
203     Dynamically predicts scale and shift for the features based on both
204     the global class embedding and the current spatial feature state.
205     """
206     def __init__(self, channels, cond_dim):
207         super().__init__()
208         self.pool = nn.AdaptiveAvgPool2d(1)
209         self.mlp = nn.Sequential(
210             nn.Linear(channels
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.6M | GPUs: 8
step 200/35000 | loss 0.0345 | 19.9s
step 400/35000 | loss 0.0318 | 19.6s
step 600/35000 | loss 0.0382 | 19.4s
step 800/35000 | loss 0.0267 | 18.9s
step 1000/35000 | loss 0.0402 | 18.6s
step 1200/35000 | loss 0.0331 | 18.5s
step 1400/35000 | loss 0.0318 | 19.2s
step 1600/35000 | loss 0.0227 | 19.1s
step 1800/35000 | loss 0.0319 | 19.1s
step 2000/35000 | loss 0.0435 | 19.1s
step 2200/35000 | loss 0.0425 | 19.3s
step 2400/35000 | loss 0.0274 | 18.8s
step 2600/35000 | loss 0.0472 | 19.3s
step 2800/35000 | loss 0.0283 | 19.3s
step 3000/35000 | loss 0.0346 | 19.1s
step 3200/35000 | loss 0.0348 | 19.3s
step 3400/35000 | loss 0.0322 | 19.1s
step 3600/35000 | loss 0.0299 | 19.2s
step 3800/35000 | loss 0.0352 | 19.6s
step 4000/35000 | loss 0.0230 | 19.1s
step 4200/35000 | loss 0.0349 | 18.8s
step 4400/35000 | loss 0.0205 | 19.4s
step 4
Tool Input
{
  "start_line": 195,
  "end_line": 232,
  "filename": "diffusers-main/custom_train.py",
  "content": "def prepare_conditioning(time_emb, class_emb):\n # Keep time_emb pure to avoid timestep interference\n return time_emb\n\n\nclass ClassConditioner(nn.Module):\n \"\"\"Dual-Injection Conditioning: Global FiLM + Multi-Token Spatial Attention.\"\"\"\n def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = channels // num_heads\n self.num_tokens = num_tokens\n \n # 1. Global FiLM\n self.film = nn.Sequential(\n nn.SiLU(),\n nn.Linear(cond_dim, 2 * channels)\n )\n nn.init.zeros_(self.film[-1].weight)\n nn.init.zeros_(self.film[-1].bias)\n \n # 2. Multi-Token Cross Attention\n self.norm = nn.GroupNorm(32, channels)\n token_dim = cond_dim // num_tokens\n self.q_proj = nn.Linear(channels, channels)\n self.k_proj = nn.Linear(token_dim, channels)\n self.v_proj = nn.Linear(token_dim, channels)\n self.out_proj = zero_module(nn.Linear(channels, channels))\n\n def forward(self, h, class_emb):\n B, C, H, W = h.shape\n \n # Global modulation\n scale, shift = self.film(class_emb).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)\n h_film = h * (1 + scale) + shift\n \n # Spatial Multi-Token routing\n hx = self.norm(h_film).view(B, C, -1).transpose(1, 2)\n ctx = class_emb.view(B, self.num_tokens, -1)\n \n q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n \n attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n attn = F.softmax(attn, dim=-1)\n out = torch.matmul(attn, v)\n out = out.transpose(1, 2).reshape(B, H * W, C)\n out = self.out_proj(out)\n \n return h_film + out.transpose(1, 2).view(B, C, H, W)",
  "op": "replace"
}
OK: Replaced lines 195..232 with 51 line(s) in diffusers-main/custom_train.py. Editable range: 195–245.
[Current file: diffusers-main/custom_train.py | editable: 195–245 | total: 628 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     # Keep time_emb pure to avoid timestep interference
197     return time_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Dual-Injection Conditioning: Global FiLM + Multi-Token Spatial Attention."""
202     def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):
203         super().__init__()
204         self.num_heads = num_heads
205         self.head_dim = channels // num_heads
206         self.num_tokens = num_tokens
207
208         # 1. Global FiLM
209         self.film = nn.Sequential(
210
[Test #3] (0 tests remaining)
### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.9M | GPUs: 8
step 200/35000 | loss 0.0344 | 22.5s
step 400/35000 | loss 0.0318 | 22.3s
step 600/35000 | loss 0.0381 | 22.1s
step 800/35000 | loss 0.0271 | 22.0s
step 1000/35000 | loss 0.0404 | 23.0s
step 1200/35000 | loss 0.0330 | 22.7s
step 1400/35000 | loss 0.0316 | 22.9s
step 1600/35000 | loss 0.0226 | 22.4s
step 1800/35000 | loss 0.0319 | 22.8s
step 2000/35000 | loss 0.0431 | 22.6s
step 2200/35000 | loss 0.0427 | 22.1s
step 2400/35000 | loss 0.0275 | 21.9s
step 2600/35000 | loss 0.0472 | 22.3s
step 2800/35000 | loss 0.0282 | 22.0s
step 3000/35000 | loss 0.0346 | 22.2s
step 3200/35000 | loss 0.0349 | 22.7s
step 3400/35000 | loss 0.0326 | 22.4s
step 3600/35000 | loss 0.0300 | 21.9s
step 3800/35000 | loss 0.0355 | 22.1s
step 4000/35000 | loss 0.0227 | 22.3s
step 4200/35000 | loss 0.0348 | 22.3s
step 4400/35000 | loss 0.0204 | 22.3s
step 4600/35000 | loss 0.0352 | 22.2s
step 48