Agent Conversation: qwen/qwen3.6-plus
Task: cv-diffusion-conditioning
Class-Conditional Diffusion: Conditioning Injection Methods
Background
Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:
- Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
- Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
- FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks.
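The three injection styles differ mainly in where the class-dependent affine parameters act. As a minimal, self-contained sketch of the FiLM-style variant (a hypothetical module with illustrative shapes, not the script's actual implementation):

```python
import torch
import torch.nn as nn

class FiLMGroupNorm(nn.Module):
    """GroupNorm whose per-channel scale/shift are predicted from a
    conditioning vector (e.g. the combined time+class embedding)."""

    def __init__(self, channels, cond_dim, groups=8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, 2 * channels)
        # zero-init => identity modulation (scale=1, shift=0) at the start
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, h, cond):
        # h: (B, C, H, W) features; cond: (B, cond_dim) embedding
        scale, shift = self.proj(cond).chunk(2, dim=-1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return self.norm(h) * (1 + scale) + shift

h = torch.randn(2, 16, 8, 8)
cond = torch.randn(2, 32)
out = FiLMGroupNorm(16, 32)(h, cond)  # same shape as h
```

AdaLN-Zero follows the same pattern but modulates a LayerNorm and adds a learned gate on the residual branch.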
Research Question
Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?
Task
You are given custom_train.py, a self-contained class-conditional DDPM
training script with a small UNet on CIFAR-10 (32x32, 10 classes).
The editable region contains:
- prepare_conditioning(time_emb, class_emb): controls how the class embedding is combined with the timestep embedding before entering ResBlocks.
- ClassConditioner(nn.Module): an additional conditioning module applied after each ResBlock, enabling methods such as cross-attention or adaptive normalization.
Your goal is to design a conditioning injection method that achieves lower FID than the baselines.
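For concreteness, the simplest baseline (concat-film) can be sketched against these two hooks. The signatures follow the description above, but the surrounding UNet plumbing is assumed, so treat this as an illustration rather than the script's exact code:

```python
import torch
import torch.nn as nn

def prepare_conditioning(time_emb, class_emb):
    # concat-film baseline: fold the class embedding into the timestep
    # embedding, so every ResBlock's FiLM layers see class information
    return time_emb + class_emb

class ClassConditioner(nn.Module):
    # identity conditioner: all class signal travels through time_emb
    def __init__(self, channels, cond_dim):
        super().__init__()

    def forward(self, h, class_emb):
        return h
```

Cross-attention and adaptive-norm methods instead do real work inside ClassConditioner and may leave prepare_conditioning as a pass-through.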
Evaluation
- Dataset: CIFAR-10 (32x32, 10 classes)
- Model: UNet2DModel (diffusers backbone) at three scales:
- Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
- Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
- Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
- Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
- Inference: 50-step DDIM sampling (class-conditional)
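The deterministic DDIM update used at inference (eta = 0, epsilon-prediction parameterization) can be sketched as follows; alpha_bar denotes the cumulative noise-schedule product, and the function name is illustrative:

```python
import torch

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0)."""
    # reconstruct the model's current estimate of the clean image x0
    x0 = (x_t - (1 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()
    # deterministically re-noise x0 to the previous, less noisy timestep
    return alpha_bar_prev.sqrt() * x0 + (1 - alpha_bar_prev).sqrt() * eps_pred
```

A 50-step sampler applies this update along a strided subsequence of the training timesteps, with eps_pred produced by the class-conditional UNet at each step.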
Baselines
Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID (lower is better) for the 3 baselines:
| Name | Strategy | fid (small) | fid (medium) |
|---|---|---|---|
concat-film | Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method. | 19.39 | 11.50 |
cross-attn | Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method. | 19.17 | 11.21 |
adanorm | Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style. | 20.01 | 12.31 |
The strongest baseline is cross-attn at FID 19.17 (small) / 11.21 (medium). Your method must achieve a lower FID than cross-attn on both train_small and train_medium to count as a successful improvement. A method that wins at one scale but regresses at the other is a partial failure: the task is about conditioning that is robust across model sizes.
Tool Input
{
  "op": "replace",
  "filename": "diffusers-main/custom_train.py",
  "start_line": 195,
  "end_line": 227,
  "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Gated fusion of time and class embeddings.\n\n    The class embedding modulates the time embedding via a sigmoid gate,\n    creating a synergistic combined embedding that carries both temporal\n    (noise level) and semantic (class) information to ResBlocks.\n    \"\"\"\n    gate = torch.sigmoid(class_emb)\n    return time_emb * (1 - gate * 0.5) + class_emb * gate * 0.5\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Cross-attention with complementary FiLM modulation.\n\n    Combines spatial cross-attention (which learns WHERE to attend to class info)\n    with FiLM-style channel modulation (which learns WHICH channels to emphasize).\n    The FiLM path is zero-initialized and provides a complementary signal to\n    the cross-attention, improving conditioning expressiveness within budget.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n        # Lightweight FiLM scale modulation (zero-initialized)\n        self.film_proj = nn.Linear(cond_dim, channels)\n        nn.init.zeros_(self.film_proj.weight)\n        nn.init.zeros_(self.film_proj.bias)\n\n    def forward(self, h, class_emb):\n        # Cross-attention path: spatial features attend to class embedding\n        attn_out = self.cross_attn(h, class_emb)  # = h + residual_attn\n        # FiLM path: channel-wise scale modulation\n        scale = self.film_proj(class_emb).unsqueeze(-1).unsqueeze(-1)\n        # Combine: cross-attn residual + FiLM-modulated skip connection\n        # = residual_attn + h * (1 + scale)\n        return attn_out - h + h * (1 + scale)"
}

OK: Replaced lines 195..227 with 35 line(s) in diffusers-main/custom_train.py. Editable range: 195–229.
[Current file: diffusers-main/custom_train.py | editable: 195–229 | total: 612 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Gated fusion of time and class embeddings.
197
198     The class embedding modulates the time embedding via a sigmoid gate,
199     creating a synergistic combined embedding that carries both temporal
200     (noise level) and semantic (class) information to ResBlocks.
201     """
202     gate = torch.sigmoid(class_emb)
203     return time_emb * (1 - gate * 0.5) + class_emb * gate * 0.5
204
205
206 class ClassConditioner(nn.Module):
207     """Cross-attention with complementary FiLM modulation.
208
209     Combines spatial
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 10.0M | GPUs: 8
step 200/35000 | loss 0.0346 | 22.9s
step 400/35000 | loss 0.0321 | 22.0s
step 600/35000 | loss 0.0385 | 22.3s
step 800/35000 | loss 0.0271 | 21.9s
step 1000/35000 | loss 0.0406 | 21.9s
step 1200/35000 | loss 0.0334 | 21.8s
step 1400/35000 | loss 0.0318 | 22.4s
step 1600/35000 | loss 0.0227 | 22.1s
step 1800/35000 | loss 0.0319 | 22.6s
step 2000/35000 | loss 0.0432 | 22.8s
step 2200/35000 | loss 0.0427 | 22.4s
step 2400/35000 | loss 0.0275 | 22.5s
step 2600/35000 | loss 0.0472 | 22.3s
step 2800/35000 | loss 0.0282 | 22.0s
step 3000/35000 | loss 0.0347 | 21.9s
step 3200/35000 | loss 0.0347 | 22.4s
step 3400/35000 | loss 0.0324 | 22.3s
step 3600/35000 | loss 0.0303 | 22.2s
step 3800/35000 | loss 0.0355 | 22.6s
step 4000/35000 | loss 0.0226 | 22.3s
step 4200/35000 | loss 0.0345 | 22.0s
step 4400/35000 | loss 0.0205 | 22.2s
step
Tool Input
{
  "op": "replace",
  "filename": "diffusers-main/custom_train.py",
  "start_line": 195,
  "end_line": 229,
  "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Dual-pathway: inject class info into time embedding.\n\n    The class embedding is added to the time embedding, creating a combined\n    embedding that carries class information into every ResBlock's FiLM layers.\n    This works synergistically with the spatial cross-attention in ClassConditioner.\n    \"\"\"\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Cross-attention with class-dependent residual gating.\n\n    Combines spatial cross-attention (WHERE to attend) with a learned\n    per-channel gating mechanism that controls how much cross-attention\n    signal passes through each feature channel. The gate is predicted from\n    the class embedding, enabling class-specific conditioning strength.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n        # Class-dependent residual gate (zero-initialized for stability)\n        self.gate_proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, channels))\n        nn.init.zeros_(self.gate_proj[-1].weight)\n        nn.init.zeros_(self.gate_proj[-1].bias)\n\n    def forward(self, h, class_emb):\n        # Spatial cross-attention: features attend to class embedding\n        attn_out = self.cross_attn(h, class_emb)  # = h + residual_attn\n        # Extract the attention residual\n        residual = attn_out - h\n        # Class-dependent gate: sigmoid ensures 0..1 range\n        gate = torch.sigmoid(self.gate_proj(class_emb)).unsqueeze(-1).unsqueeze(-1)\n        # Gated residual: class controls how much attention signal flows\n        return h + residual * gate"
}

OK: Replaced lines 195..229 with 35 line(s) in diffusers-main/custom_train.py. Editable range: 195–229.
[Current file: diffusers-main/custom_train.py | editable: 195–229 | total: 612 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Dual-pathway: inject class info into time embedding.
197
198     The class embedding is added to the time embedding, creating a combined
199     embedding that carries class information into every ResBlock's FiLM layers.
200     This works synergistically with the spatial cross-attention in ClassConditioner.
201     """
202     return time_emb + class_emb
203
204
205 class ClassConditioner(nn.Module):
206     """Cross-attention with class-dependent residual gating.
207
208     Combines spatial cross-attention (WHERE to attend)
Tool Input
{
  "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Keep time embedding unchanged; class conditioning via ClassConditioner only.\"\"\"\n    return time_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Cross-attention + FiLM: complementary additive and multiplicative conditioning.\n\n    Cross-attention provides a class-dependent additive signal to spatial features.\n    FiLM provides a class-dependent multiplicative signal to channels.\n    Together they give richer class conditioning than either alone.\n    FiLM is zero-initialized (scale=1, shift=0) for stable training start.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        # Spatial conditioning via cross-attention\n        self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n        # Channel conditioning via FiLM (scale + shift from class embedding)\n        self.film_proj = nn.Linear(cond_dim, 2 * channels)\n        # Initialize to identity: scale=1, shift=0\n        nn.init.zeros_(self.film_proj.weight)\n        nn.init.ones_(self.film_proj.bias[:channels])\n        nn.init.zeros_(self.film_proj.bias[channels:])\n\n    def forward(self, h, class_emb):\n        # Cross-attention: class-dependent additive signal\n        attn_out = self.cross_attn(h, class_emb)\n        # FiLM: class-dependent multiplicative signal\n        scale, shift = self.film_proj(class_emb).chunk(2, dim=-1)\n        scale = scale.unsqueeze(-1).unsqueeze(-1)\n        shift = shift.unsqueeze(-1).unsqueeze(-1)\n        return attn_out * scale + shift",
  "end_line": 229,
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 195
}

OK: Replaced lines 195..229 with 32 line(s) in diffusers-main/custom_train.py. Editable range: 195–226.
[Current file: diffusers-main/custom_train.py | editable: 195–226 | total: 609 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Keep time embedding unchanged; class conditioning via ClassConditioner only."""
197     return time_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Cross-attention + FiLM: complementary additive and multiplicative conditioning.
202
203     Cross-attention provides a class-dependent additive signal to spatial features.
204     FiLM provides a class-dependent multiplicative signal to channels.
205     Together they give richer class conditioning than either alone.
206     FiLM is zero-initialized (scale=1, shift=0) for stable t
Tool Input
{
  "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Dual-pathway: class embedding is added to time embedding for FiLM ResBlocks.\"\"\"\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Cross-attention + FiLM: complementary additive and multiplicative conditioning.\n\n    Cross-attention provides a class-dependent additive signal to spatial features.\n    FiLM provides a class-dependent multiplicative signal to channels.\n    Together they give richer class conditioning than either alone.\n    FiLM is zero-initialized (scale=1, shift=0) for stable training start.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        # Spatial conditioning via cross-attention\n        self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n        # Channel conditioning via FiLM (scale + shift from class embedding)\n        self.film_proj = nn.Linear(cond_dim, 2 * channels)\n        # Initialize to identity: scale=1, shift=0\n        nn.init.zeros_(self.film_proj.weight)\n        nn.init.ones_(self.film_proj.bias[:channels])\n        nn.init.zeros_(self.film_proj.bias[channels:])\n\n    def forward(self, h, class_emb):\n        # Cross-attention: class-dependent additive signal\n        attn_out = self.cross_attn(h, class_emb)\n        # FiLM: class-dependent multiplicative signal\n        scale, shift = self.film_proj(class_emb).chunk(2, dim=-1)\n        scale = scale.unsqueeze(-1).unsqueeze(-1)\n        shift = shift.unsqueeze(-1).unsqueeze(-1)\n        return attn_out * scale + shift",
  "end_line": 229,
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 195
}

ERROR: Lines 195..229 exceed the editable range. You may only edit lines 195–226.
[Current file: diffusers-main/custom_train.py | editable: 195–226 | total: 609 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Keep time embedding unchanged; class conditioning via ClassConditioner only."""
197     return time_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Cross-attention + FiLM: complementary additive and multiplicative conditioning.
202
203     Cross-attention provides a class-dependent additive signal to spatial features.
204     FiLM provides a class-dependent multiplicative signal to channels.
205     Together they give richer class conditioning than either alone.
206     FiLM is zero-initialized (scale=1, shift=0) for stable training start.
207
Tool Input
{
  "end_line": 226,
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 195,
  "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Dual-pathway: inject class embedding into time embedding for ResBlock FiLM layers.\"\"\"\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Cross-attention for spatial class conditioning.\n\n    Class embedding serves as key/value in cross-attention, allowing each\n    spatial position to attend to class-specific information. This provides\n    a complementary conditioning pathway to the FiLM injection via time_emb.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        self.cross_attn = CrossAttentionLayer(channels, cond_dim, num_heads=4)\n\n    def forward(self, h, class_emb):\n        return self.cross_attn(h, class_emb)"
}

OK: Replaced lines 195..226 with 18 line(s) in diffusers-main/custom_train.py. Editable range: 195–212.
[Current file: diffusers-main/custom_train.py | editable: 195–212 | total: 595 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Dual-pathway: inject class embedding into time embedding for ResBlock FiLM layers."""
197     return time_emb + class_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Cross-attention for spatial class conditioning.
202
203     Class embedding serves as key/value in cross-attention, allowing each
204     spatial position to attend to class-specific information. This provides
205     a complementary conditioning pathway to the FiLM injection via time_emb.
206     """
207     def __init__(self, channels, cond_dim):
208
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 9.7M | GPUs: 8
step 200/35000 | loss 0.0349 | 22.5s
step 400/35000 | loss 0.0324 | 21.4s
step 600/35000 | loss 0.0382 | 21.2s
step 800/35000 | loss 0.0269 | 21.2s
step 1000/35000 | loss 0.0406 | 21.8s
step 1200/35000 | loss 0.0332 | 20.9s
step 1400/35000 | loss 0.0318 | 21.5s
step 1600/35000 | loss 0.0227 | 21.1s
step 1800/35000 | loss 0.0321 | 21.3s
step 2000/35000 | loss 0.0439 | 21.3s
step 2200/35000 | loss 0.0427 | 21.2s
step 2400/35000 | loss 0.0276 | 21.2s
step 2600/35000 | loss 0.0466 | 21.6s
step 2800/35000 | loss 0.0281 | 20.9s
step 3000/35000 | loss 0.0344 | 20.7s
step 3200/35000 | loss 0.0347 | 21.0s
step 3400/35000 | loss 0.0324 | 21.1s
step 3600/35000 | loss 0.0299 | 20.8s
step 3800/35000 | loss 0.0355 | 21.3s
step 4000/35000 | loss 0.0227 | 21.0s
step 4200/35000 | loss 0.0347 | 21.1s
step 4400/35000 | loss 0.0205 | 21.2s
step 4
Tool Input
{
  "content": "def prepare_conditioning(time_emb, class_emb):\n    \"\"\"Keep time embedding unchanged; class conditioning entirely via ClassConditioner.\"\"\"\n    return time_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Class-projected cross-attention: project class to channel space before attention.\n\n    Instead of using the raw 512-dim class embedding as cross-attention context,\n    we first project it to match the feature channel dimension. This creates a\n    class-specific representation in the same space as spatial features, making\n    the K/V projections more effective and natural. The cross-attention then\n    operates on features and class projections in the same channel space.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        # Project class embedding to feature channel space\n        self.class_proj = nn.Linear(cond_dim, channels)\n        # Cross-attention with matched dimensions\n        self.cross_attn = CrossAttentionLayer(channels, channels, num_heads=4)\n\n    def forward(self, h, class_emb):\n        context = self.class_proj(class_emb)\n        return self.cross_attn(h, context)",
  "end_line": 212,
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 195
}

OK: Replaced lines 195..212 with 24 line(s) in diffusers-main/custom_train.py. Editable range: 195–218.
[Current file: diffusers-main/custom_train.py | editable: 195–218 | total: 601 lines]
192| # Conditioning injection (EDITABLE REGION)
193| # ============================================================================
194|
195 def prepare_conditioning(time_emb, class_emb):
196     """Keep time embedding unchanged; class conditioning entirely via ClassConditioner."""
197     return time_emb
198
199
200 class ClassConditioner(nn.Module):
201     """Class-projected cross-attention: project class to channel space before attention.
202
203     Instead of using the raw 512-dim class embedding as cross-attention context,
204     we first project it to match the feature channel dimension. This creates a
205     class-specific representation in the same space as spatial features, making
206     the K/V projections more effectiv
[Test #3] (0 tests remaining)

### train_small (scripts/train_small.sh)

Training results:
Model parameters: 9.7M | GPUs: 8
step 200/35000 | loss 0.0347 | 22.5s
step 400/35000 | loss 0.0319 | 22.0s
step 600/35000 | loss 0.0383 | 21.6s
step 800/35000 | loss 0.0268 | 21.6s
step 1000/35000 | loss 0.0407 | 21.6s
step 1200/35000 | loss 0.0332 | 21.9s
step 1400/35000 | loss 0.0318 | 22.3s
step 1600/35000 | loss 0.0225 | 21.8s
step 1800/35000 | loss 0.0319 | 21.8s
step 2000/35000 | loss 0.0431 | 21.9s
step 2200/35000 | loss 0.0424 | 22.2s
step 2400/35000 | loss 0.0274 | 21.7s
step 2600/35000 | loss 0.0473 | 21.8s
step 2800/35000 | loss 0.0280 | 21.7s
step 3000/35000 | loss 0.0348 | 21.9s
step 3200/35000 | loss 0.0348 | 21.6s
step 3400/35000 | loss 0.0323 | 21.7s
step 3600/35000 | loss 0.0299 | 21.5s
step 3800/35000 | loss 0.0357 | 21.9s
step 4000/35000 | loss 0.0228 | 21.5s
step 4200/35000 | loss 0.0347 | 21.7s
step 4400/35000 | loss 0.0205 | 21.5s
step 4600/35000 | loss 0.0353 | 21.2s
step 48