Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 8
Messages: 17
User (Step 0)

Task: cv-diffusion-conditioning

Class-Conditional Diffusion: Conditioning Injection Methods

Background

Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:

  • Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
  • Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
  • FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks.
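
As a rough illustration of the mechanical difference between these methods, FiLM-style injection can be sketched as below. This is not the task's actual UNet code; the module, group count, and tensor shapes are illustrative.

```python
import torch
import torch.nn as nn

class FiLMResBlockCond(nn.Module):
    """Minimal FiLM conditioning: a linear layer maps the combined
    time+class embedding to per-channel scale and shift."""
    def __init__(self, channels: int, emb_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.to_scale_shift = nn.Linear(emb_dim, 2 * channels)

    def forward(self, h: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # emb: [B, emb_dim] -> per-channel scale/shift: [B, C, 1, 1]
        scale, shift = self.to_scale_shift(emb).chunk(2, dim=1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return self.norm(h) * (1 + scale) + shift

x = torch.randn(2, 16, 8, 8)
emb = torch.randn(2, 32)
out = FiLMResBlockCond(16, 32)(x, emb)
print(out.shape)  # torch.Size([2, 16, 8, 8])
```

Cross-attention and AdaLN-Zero differ mainly in where this modulation happens (spatially, via attention weights) and in the extra zero-initialized gate, respectively.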

Research Question

Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?

Task

You are given custom_train.py, a self-contained class-conditional DDPM training script with a small UNet on CIFAR-10 (32x32, 10 classes).

The editable region contains:

  1. prepare_conditioning(time_emb, class_emb) — controls how class embedding is combined with the timestep embedding before entering ResBlocks.

  2. ClassConditioner(nn.Module) — an additional conditioning module applied after each ResBlock, enabling methods like cross-attention or adaptive norm.
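
A minimal fill-in for these two hooks, mirroring the simplest (concat-film-style) approach, might look like the sketch below. Only the two hook names come from the task description; the signatures beyond them are illustrative assumptions.

```python
import torch
import torch.nn as nn

def prepare_conditioning(time_emb, class_emb):
    # Simplest fusion: add class information into the timestep embedding,
    # so every ResBlock sees it through its existing FiLM pathway.
    return time_emb + class_emb

class ClassConditioner(nn.Module):
    """Identity conditioner: all class information flows through
    prepare_conditioning; this post-block hook does nothing."""
    def __init__(self, channels, cond_dim):
        super().__init__()

    def forward(self, h, class_emb):
        return h

emb = prepare_conditioning(torch.zeros(4, 128), torch.ones(4, 128))
h = torch.randn(4, 64, 16, 16)
h_out = ClassConditioner(64, 128)(h, torch.randn(4, 128))
print(emb.shape, h_out.shape)  # torch.Size([4, 128]) torch.Size([4, 64, 16, 16])
```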

Your goal is to design a conditioning injection method that achieves lower FID than the baselines.

Evaluation

  • Dataset: CIFAR-10 (32x32, 10 classes)
  • Model: UNet2DModel (diffusers backbone) at three scales:
    • Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
    • Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
    • Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
  • Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
  • Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
  • Inference: 50-step DDIM sampling (class-conditional)
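
For reference, one step of the deterministic DDIM update (eta = 0) used at inference reduces to the arithmetic below; the exact forward noise here stands in for the trained UNet's prediction.

```python
import torch

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0).
    eps is the model's noise prediction at timestep t."""
    x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    return alpha_bar_prev.sqrt() * x0_pred + (1 - alpha_bar_prev).sqrt() * eps

# Sanity check: if eps is the exact noise that produced x_t from x0,
# stepping with alpha_bar_prev = 1 recovers x0.
x0 = torch.randn(2, 3, 32, 32)
noise = torch.randn_like(x0)
a_t = torch.tensor(0.5)
x_t = a_t.sqrt() * x0 + (1 - a_t).sqrt() * noise
x_prev = ddim_step(x_t, noise, a_t, torch.tensor(1.0))
print(torch.allclose(x_prev, x0, atol=1e-5))  # True
```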

Baselines

Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID for the three baselines (lower is better):

  • concat-film: Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method. FID 19.39 (small) / 11.50 (medium)
  • cross-attn: Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method. FID 19.17 (small) / 11.21 (medium)
  • adanorm: Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style. FID 20.01 (small) / 12.31 (medium)

The strongest baseline is cross-attn at FID 19.17 (small) / 11.21 (medium). Your method must achieve a lower FID than cross-attn on both train_small and train_medium to count as a successful improvement. A method that wins at one scale but regresses at the other is a partial failure; the task is about robust conditioning across model sizes.

Ac

Assistant (Step 1): edit

Tool Input

{
  "start_line": 195,
  "op": "replace",
  "end_line": 227,
  "filename": "custom_train.py",
  "content": "def prepare_conditioning(time_emb, class_emb):\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = channels // num_heads\n        self.num_tokens = num_tokens\n        self.norm = nn.GroupNorm(32, channels)\n        \n        token_dim = cond_dim // num_tokens\n        self.q_proj = nn.Linear(channels, channels)\n        self.k_proj = nn.Linear(token_dim, channels)\n        self.v_proj = nn.Linear(token_dim, channels)\n        self.out_proj = zero_module(nn.Linear(channels, channels))\n\n    def forward(self, h, class_emb):\n        B, C, H, W = h.shape\n        hx = self.norm(h).view(B, C, -1).transpose(1, 2)\n        \n        ctx = class_emb.view(B, self.num_tokens, -1)\n        \n        q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        \n        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n        attn = F.softmax(attn, dim=-1)\n        out = torch.matmul(attn, v)\n        out = out.transpose(1, 2).reshape(B, H * W, C)\n        out = self.out_proj(out)\n        return h + out.transpose(1, 2).view(B, C, H, W)"
}
Tool Result (Step 1)
ERROR: Package 'custom_train.py' is not in allowed packages
Assistant (Step 2): edit

Tool Input

{
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 195,
  "end_line": 228,
  "content": "def prepare_conditioning(time_emb, class_emb):\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = channels // num_heads\n        self.num_tokens = num_tokens\n        self.norm = nn.GroupNorm(32, channels)\n        \n        token_dim = cond_dim // num_tokens\n        self.q_proj = nn.Linear(channels, channels)\n        self.k_proj = nn.Linear(token_dim, channels)\n        self.v_proj = nn.Linear(token_dim, channels)\n        self.out_proj = zero_module(nn.Linear(channels, channels))\n\n    def forward(self, h, class_emb):\n        B, C, H, W = h.shape\n        hx = self.norm(h).view(B, C, -1).transpose(1, 2)\n        \n        ctx = class_emb.view(B, self.num_tokens, -1)\n        \n        q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        \n        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n        attn = F.softmax(attn, dim=-1)\n        out = torch.matmul(attn, v)\n        out = out.transpose(1, 2).reshape(B, H * W, C)\n        out = self.out_proj(out)\n        return h + out.transpose(1, 2).view(B, C, H, W)"
}
Tool Result (Step 2)
ERROR: Lines 195..228 exceed the editable range. You may only edit lines 195–227.

[Current file: diffusers-main/custom_train.py | editable: 195–227 | total: 610 lines]
   192| # Conditioning injection (EDITABLE REGION)
   193| # ============================================================================
   194| 
   195  def prepare_conditioning(time_emb, class_emb):
   196      """Prepare the combined embedding used in ResBlocks.
   197  
   198      TODO: Implement your conditioning preparation here.
   199  
   200      Args:
   201          time_emb:  [B, time_embed_dim] timestep embedding
   202          class_emb: [B, time_embed_dim] class embedding
   203  
   204      Returns: [B, time_embed_dim] embedding used in ResBlocks
   205      """
   206      raise NotImplementedError("Implement prepare_conditioning")
   207  
   208  
   209  class ClassConditioner(nn.Module):
   210      """Conditioning module applied after each UNet block.
   211  
   212      TODO: Implement your 
Assistant (Step 3): edit

Tool Input

{
  "content": "def prepare_conditioning(time_emb, class_emb):\n    return time_emb + class_emb\n\n\nclass ClassConditioner(nn.Module):\n    def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = channels // num_heads\n        self.num_tokens = num_tokens\n        self.norm = nn.GroupNorm(32, channels)\n        \n        token_dim = cond_dim // num_tokens\n        self.q_proj = nn.Linear(channels, channels)\n        self.k_proj = nn.Linear(token_dim, channels)\n        self.v_proj = nn.Linear(token_dim, channels)\n        self.out_proj = zero_module(nn.Linear(channels, channels))\n\n    def forward(self, h, class_emb):\n        B, C, H, W = h.shape\n        hx = self.norm(h).view(B, C, -1).transpose(1, 2)\n        \n        ctx = class_emb.view(B, self.num_tokens, -1)\n        \n        q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        \n        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n        attn = F.softmax(attn, dim=-1)\n        out = torch.matmul(attn, v)\n        out = out.transpose(1, 2).reshape(B, H * W, C)\n        out = self.out_proj(out)\n        return h + out.transpose(1, 2).view(B, C, H, W)",
  "op": "replace",
  "filename": "diffusers-main/custom_train.py",
  "start_line": 195,
  "end_line": 227
}
Tool Result (Step 3)
OK: Replaced lines 195..227 with 34 line(s) in diffusers-main/custom_train.py. Editable range: 195–228.

[Current file: diffusers-main/custom_train.py | editable: 195–228 | total: 611 lines]
   192| # Conditioning injection (EDITABLE REGION)
   193| # ============================================================================
   194| 
   195  def prepare_conditioning(time_emb, class_emb):
   196      return time_emb + class_emb
   197  
   198  
   199  class ClassConditioner(nn.Module):
   200      def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):
   201          super().__init__()
   202          self.num_heads = num_heads
   203          self.head_dim = channels // num_heads
   204          self.num_tokens = num_tokens
   205          self.norm = nn.GroupNorm(32, channels)
   206          
   207          token_dim = cond_dim // num_tokens
   208          self.q_proj = nn.Linear(channels, channels)
   209          self.k_proj = nn.Linear(token_dim, channels)
   210
Assistant (Step 4): test
Tool Result (Step 4)
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.3M | GPUs: 8
step 200/35000 | loss 0.0351 | 21.9s
step 400/35000 | loss 0.0322 | 21.1s
step 600/35000 | loss 0.0383 | 21.4s
step 800/35000 | loss 0.0271 | 21.3s
step 1000/35000 | loss 0.0408 | 21.4s
step 1200/35000 | loss 0.0332 | 21.0s
step 1400/35000 | loss 0.0317 | 21.3s
step 1600/35000 | loss 0.0229 | 21.5s
step 1800/35000 | loss 0.0323 | 21.3s
step 2000/35000 | loss 0.0438 | 21.0s
step 2200/35000 | loss 0.0429 | 21.0s
step 2400/35000 | loss 0.0277 | 20.8s
step 2600/35000 | loss 0.0471 | 21.1s
step 2800/35000 | loss 0.0284 | 21.3s
step 3000/35000 | loss 0.0347 | 21.1s
step 3200/35000 | loss 0.0349 | 21.6s
step 3400/35000 | loss 0.0325 | 21.0s
step 3600/35000 | loss 0.0300 | 21.1s
step 3800/35000 | loss 0.0355 | 21.5s
step 4000/35000 | loss 0.0228 | 21.0s
step 4200/35000 | loss 0.0346 | 21.1s
step 4400/35000 | loss 0.0206 | 21.2s
step 
Assistant (Step 5): edit

Tool Input

{
  "filename": "diffusers-main/custom_train.py",
  "end_line": 228,
  "content": "def prepare_conditioning(time_emb, class_emb):\n    # Keep time embedding pure to prevent interference\n    return time_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Feature-Adaptive FiLM (FA-FiLM).\n    \n    Dynamically predicts scale and shift for the features based on both\n    the global class embedding and the current spatial feature state.\n    \"\"\"\n    def __init__(self, channels, cond_dim):\n        super().__init__()\n        self.pool = nn.AdaptiveAvgPool2d(1)\n        self.mlp = nn.Sequential(\n            nn.Linear(channels + cond_dim, channels),\n            nn.LayerNorm(channels),\n            nn.SiLU(),\n            nn.Linear(channels, 2 * channels)\n        )\n        # Zero-initialize for identity mapping at start\n        nn.init.zeros_(self.mlp[-1].weight)\n        nn.init.zeros_(self.mlp[-1].bias)\n\n    def forward(self, h, class_emb):\n        B, C, _, _ = h.shape\n        # Extract current feature state\n        pooled = self.pool(h).view(B, C)\n        # Combine state with global class condition\n        ctx = torch.cat([pooled, class_emb], dim=1)\n        \n        # Predict scale and shift\n        params = self.mlp(ctx).view(B, 2, C, 1, 1)\n        scale, shift = params.chunk(2, dim=1)\n        scale = scale.squeeze(1)\n        shift = shift.squeeze(1)\n        \n        return h * (1 + scale) + shift",
  "start_line": 195,
  "op": "replace"
}
Tool Result (Step 5)
OK: Replaced lines 195..228 with 38 line(s) in diffusers-main/custom_train.py. Editable range: 195–232.

[Current file: diffusers-main/custom_train.py | editable: 195–232 | total: 615 lines]
   192| # Conditioning injection (EDITABLE REGION)
   193| # ============================================================================
   194| 
   195  def prepare_conditioning(time_emb, class_emb):
   196      # Keep time embedding pure to prevent interference
   197      return time_emb
   198  
   199  
   200  class ClassConditioner(nn.Module):
   201      """Feature-Adaptive FiLM (FA-FiLM).
   202      
   203      Dynamically predicts scale and shift for the features based on both
   204      the global class embedding and the current spatial feature state.
   205      """
   206      def __init__(self, channels, cond_dim):
   207          super().__init__()
   208          self.pool = nn.AdaptiveAvgPool2d(1)
   209          self.mlp = nn.Sequential(
   210              nn.Linear(channels 
Assistant (Step 6): test
Tool Result (Step 6)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.6M | GPUs: 8
step 200/35000 | loss 0.0345 | 19.9s
step 400/35000 | loss 0.0318 | 19.6s
step 600/35000 | loss 0.0382 | 19.4s
step 800/35000 | loss 0.0267 | 18.9s
step 1000/35000 | loss 0.0402 | 18.6s
step 1200/35000 | loss 0.0331 | 18.5s
step 1400/35000 | loss 0.0318 | 19.2s
step 1600/35000 | loss 0.0227 | 19.1s
step 1800/35000 | loss 0.0319 | 19.1s
step 2000/35000 | loss 0.0435 | 19.1s
step 2200/35000 | loss 0.0425 | 19.3s
step 2400/35000 | loss 0.0274 | 18.8s
step 2600/35000 | loss 0.0472 | 19.3s
step 2800/35000 | loss 0.0283 | 19.3s
step 3000/35000 | loss 0.0346 | 19.1s
step 3200/35000 | loss 0.0348 | 19.3s
step 3400/35000 | loss 0.0322 | 19.1s
step 3600/35000 | loss 0.0299 | 19.2s
step 3800/35000 | loss 0.0352 | 19.6s
step 4000/35000 | loss 0.0230 | 19.1s
step 4200/35000 | loss 0.0349 | 18.8s
step 4400/35000 | loss 0.0205 | 19.4s
step 4
Assistant (Step 7): edit

Tool Input

{
  "start_line": 195,
  "end_line": 232,
  "filename": "diffusers-main/custom_train.py",
  "content": "def prepare_conditioning(time_emb, class_emb):\n    # Keep time_emb pure to avoid timestep interference\n    return time_emb\n\n\nclass ClassConditioner(nn.Module):\n    \"\"\"Dual-Injection Conditioning: Global FiLM + Multi-Token Spatial Attention.\"\"\"\n    def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = channels // num_heads\n        self.num_tokens = num_tokens\n        \n        # 1. Global FiLM\n        self.film = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(cond_dim, 2 * channels)\n        )\n        nn.init.zeros_(self.film[-1].weight)\n        nn.init.zeros_(self.film[-1].bias)\n        \n        # 2. Multi-Token Cross Attention\n        self.norm = nn.GroupNorm(32, channels)\n        token_dim = cond_dim // num_tokens\n        self.q_proj = nn.Linear(channels, channels)\n        self.k_proj = nn.Linear(token_dim, channels)\n        self.v_proj = nn.Linear(token_dim, channels)\n        self.out_proj = zero_module(nn.Linear(channels, channels))\n\n    def forward(self, h, class_emb):\n        B, C, H, W = h.shape\n        \n        # Global modulation\n        scale, shift = self.film(class_emb).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)\n        h_film = h * (1 + scale) + shift\n        \n        # Spatial Multi-Token routing\n        hx = self.norm(h_film).view(B, C, -1).transpose(1, 2)\n        ctx = class_emb.view(B, self.num_tokens, -1)\n        \n        q = self.q_proj(hx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        k = self.k_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        v = self.v_proj(ctx).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)\n        \n        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)\n        attn = F.softmax(attn, dim=-1)\n        out = torch.matmul(attn, v)\n        out = out.transpose(1, 2).reshape(B, H * W, C)\n        out = self.out_proj(out)\n        \n        return h_film + out.transpose(1, 2).view(B, C, H, W)",
  "op": "replace"
}
Tool Result (Step 7)
OK: Replaced lines 195..232 with 51 line(s) in diffusers-main/custom_train.py. Editable range: 195–245.

[Current file: diffusers-main/custom_train.py | editable: 195–245 | total: 628 lines]
   192| # Conditioning injection (EDITABLE REGION)
   193| # ============================================================================
   194| 
   195  def prepare_conditioning(time_emb, class_emb):
   196      # Keep time_emb pure to avoid timestep interference
   197      return time_emb
   198  
   199  
   200  class ClassConditioner(nn.Module):
   201      """Dual-Injection Conditioning: Global FiLM + Multi-Token Spatial Attention."""
   202      def __init__(self, channels, cond_dim, num_heads=4, num_tokens=4):
   203          super().__init__()
   204          self.num_heads = num_heads
   205          self.head_dim = channels // num_heads
   206          self.num_tokens = num_tokens
   207          
   208          # 1. Global FiLM
   209          self.film = nn.Sequential(
   210       
Assistant (Step 8): test
Tool Result (Step 8)
[Test #3] (0 tests remaining)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.9M | GPUs: 8
step 200/35000 | loss 0.0344 | 22.5s
step 400/35000 | loss 0.0318 | 22.3s
step 600/35000 | loss 0.0381 | 22.1s
step 800/35000 | loss 0.0271 | 22.0s
step 1000/35000 | loss 0.0404 | 23.0s
step 1200/35000 | loss 0.0330 | 22.7s
step 1400/35000 | loss 0.0316 | 22.9s
step 1600/35000 | loss 0.0226 | 22.4s
step 1800/35000 | loss 0.0319 | 22.8s
step 2000/35000 | loss 0.0431 | 22.6s
step 2200/35000 | loss 0.0427 | 22.1s
step 2400/35000 | loss 0.0275 | 21.9s
step 2600/35000 | loss 0.0472 | 22.3s
step 2800/35000 | loss 0.0282 | 22.0s
step 3000/35000 | loss 0.0346 | 22.2s
step 3200/35000 | loss 0.0349 | 22.7s
step 3400/35000 | loss 0.0326 | 22.4s
step 3600/35000 | loss 0.0300 | 21.9s
step 3800/35000 | loss 0.0355 | 22.1s
step 4000/35000 | loss 0.0227 | 22.3s
step 4200/35000 | loss 0.0348 | 22.3s
step 4400/35000 | loss 0.0204 | 22.3s
step 4600/35000 | loss 0.0352 | 22.2s
step 48