cv-diffusion-conditioning
Description
Class-Conditional Diffusion: Conditioning Injection Methods
Background
Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:
- Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
- Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
- FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks.
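The FiLM approach above can be sketched in a few lines. This is an illustrative PyTorch ResBlock, not code from `custom_train.py`: the combined (time + class) embedding is projected to a per-channel scale and shift applied after GroupNorm. All module and parameter names here are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FiLMResBlock(nn.Module):
    """Sketch of FiLM conditioning: the conditioning embedding produces a
    per-channel scale/shift applied after GroupNorm (adaptive GroupNorm)."""
    def __init__(self, channels: int, emb_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.to_scale_shift = nn.Linear(emb_dim, 2 * channels)
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        scale, shift = self.to_scale_shift(emb).chunk(2, dim=-1)
        h = self.norm(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        return x + self.conv(F.silu(h))

x = torch.randn(4, 64, 32, 32)
emb = torch.randn(4, 256)  # e.g. time_emb + class_emb, both 256-d
y = FiLMResBlock(64, 256)(x, emb)
print(y.shape)  # torch.Size([4, 64, 32, 32])
```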
Research Question
Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?
Task
You are given custom_train.py, a self-contained class-conditional DDPM training script with a small UNet on CIFAR-10 (32x32, 10 classes).
The editable region contains:
- `prepare_conditioning(time_emb, class_emb)` — controls how the class embedding is combined with the timestep embedding before entering ResBlocks.
- `ClassConditioner(nn.Module)` — an additional conditioning module applied after each ResBlock, enabling methods like cross-attention or adaptive norm.
Your goal is to design a conditioning injection method that achieves lower FID than the baselines.
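As a starting point, the two editable hooks might look like this. The function and class names come from the task description; everything else (signatures, dimensions, the surrounding plumbing) is an assumption. `prepare_conditioning` reproduces the simplest baseline behavior (summing the embeddings), and `ClassConditioner` sketches a DiT-style gated adaptive norm whose gate is zero-initialized so the module starts as an identity mapping.

```python
import torch
import torch.nn as nn

def prepare_conditioning(time_emb: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
    # Baseline-style behavior: sum the embeddings so every ResBlock's
    # FiLM layer sees time + class jointly. (Sketch; assumed signature.)
    return time_emb + class_emb

class ClassConditioner(nn.Module):
    """Sketch of an AdaLN-Zero-style post-ResBlock conditioner: the class
    embedding produces scale/shift/gate, with the projection zero-initialized
    so the module is the identity at the start of training."""
    def __init__(self, channels: int, emb_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels, affine=False)
        self.proj = nn.Linear(emb_dim, 3 * channels)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, h: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
        scale, shift, gate = self.proj(class_emb)[:, :, None, None].chunk(3, dim=1)
        return h + gate * (self.norm(h) * (1 + scale) + shift)

h = torch.randn(2, 128, 16, 16)
c = torch.randn(2, 256)
out = ClassConditioner(128, 256)(h, c)
print(torch.allclose(out, h))  # True: the gate starts at zero
```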
Evaluation
- Dataset: CIFAR-10 (32x32, 10 classes)
- Model: UNet2DModel (diffusers backbone) at three scales:
- Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
- Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
- Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
- Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
- Inference: 50-step DDIM sampling (class-conditional)
Baselines
- concat-film: Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method.
- cross-attn: Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method.
- adanorm: Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style method.
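To make the cross-attn baseline concrete, here is one plausible shape for it: spatial features attend to a single class token used as key/value. This is a sketch under assumed names and dimensions, not the script's actual implementation.

```python
import torch
import torch.nn as nn

class ClassCrossAttention(nn.Module):
    """Sketch of cross-attention conditioning: flattened spatial features
    are the queries; the class embedding is a one-token key/value sequence."""
    def __init__(self, channels: int, emb_dim: int, heads: int = 4):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.attn = nn.MultiheadAttention(
            channels, heads, kdim=emb_dim, vdim=emb_dim, batch_first=True
        )

    def forward(self, h: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
        b, c, ht, wt = h.shape
        q = self.norm(h).flatten(2).transpose(1, 2)  # (B, H*W, C) queries
        kv = class_emb.unsqueeze(1)                  # (B, 1, emb_dim) key/value
        out, _ = self.attn(q, kv, kv)
        return h + out.transpose(1, 2).reshape(b, c, ht, wt)

h = torch.randn(2, 64, 8, 8)
y = ClassCrossAttention(64, 256)(h, torch.randn(2, 256))
print(y.shape)  # torch.Size([2, 64, 8, 8])
```

With a single key/value token the attention weights are trivially uniform, so in practice the value projection of the class embedding is what carries the signal; the expressiveness advantage grows when the conditioning sequence has more than one token.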
Code
```python
"""Class-Conditional DDPM Training on CIFAR-10.

Uses diffusers UNet2DModel backbone (same architecture as google/ddpm-cifar10-32)
with configurable class-conditioning injection. Only the conditioning method
(prepare_conditioning + ClassConditioner) is editable.
"""

import copy
import math
import os
import sys
import time

import numpy as np
import torch
```
Results
| Model | Type | Best FID (small) ↓ | Best FID (medium) ↓ | Best FID (large) ↓ |
|---|---|---|---|---|
| adanorm | baseline | 20.010 | 12.310 | 11.770 |
| concat-film | baseline | 19.390 | 11.500 | 10.520 |
| cross-attn | baseline | 19.170 | 11.210 | 10.410 |
| anthropic/claude-opus-4.6 | vanilla | 19.160 | 11.440 | - |
| deepseek-reasoner | vanilla | 19.850 | 12.770 | - |
| google/gemini-3.1-pro-preview | vanilla | 19.100 | 11.430 | - |
| qwen/qwen3.6-plus | vanilla | 19.910 | 11.980 | - |
| anthropic/claude-opus-4.6 | agent | 19.160 | 11.440 | 9.930 |
| deepseek-reasoner | agent | 19.530 | 11.150 | 10.430 |
| google/gemini-3.1-pro-preview | agent | 19.100 | 11.430 | 11.110 |
| qwen/qwen3.6-plus | agent | 19.230 | 11.270 | 10.370 |