cv-diffusion-conditioning

Computer Vision · diffusers-main · rigorous codebase

Description

Class-Conditional Diffusion: Conditioning Injection Methods

Background

Class-conditional diffusion models generate images conditioned on a class label. The key design choice is how to inject the class information into the UNet:

  • Cross-Attention: Class embedding serves as key/value in a cross-attention layer after each ResBlock. Used in Stable Diffusion for text conditioning.
  • Adaptive Normalization (AdaLN-Zero): Class embedding modulates LayerNorm with learned scale, shift, and gating parameters. Used in DiT.
  • FiLM Conditioning: Class embedding is added to the timestep embedding and injected via adaptive GroupNorm (scale/shift) in ResBlocks; a code sketch of this variant follows the list.
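As a concrete example of the simplest of these, the sketch below shows a FiLM-style ResBlock in which the class embedding is summed into the timestep embedding and a linear projection produces a per-channel scale and shift applied after GroupNorm. Names, layer sizes, and the residual wiring are illustrative assumptions, not the actual interface of the training script.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FiLMResBlock(nn.Module):
    """Illustrative ResBlock with FiLM-style conditioning via adaptive GroupNorm."""

    def __init__(self, channels: int, emb_dim: int):
        super().__init__()
        # channels is assumed divisible by num_groups.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # Project the combined embedding to a per-channel scale and shift.
        self.to_scale_shift = nn.Linear(emb_dim, 2 * channels)

    def forward(self, x, time_emb, class_emb):
        emb = time_emb + class_emb                       # concat-film: additive fusion
        scale, shift = self.to_scale_shift(emb).chunk(2, dim=-1)
        h = self.norm(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        return x + self.conv(F.silu(h))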

Research Question

Which conditioning injection method achieves the best class-conditional FID on CIFAR-10?

Task

You are given custom_train.py, a self-contained class-conditional DDPM training script with a small UNet on CIFAR-10 (32x32, 10 classes).

The editable region contains:

  1. prepare_conditioning(time_emb, class_emb) — controls how the class embedding is combined with the timestep embedding before it enters the ResBlocks.

  2. ClassConditioner(nn.Module) — an additional conditioning module applied after each ResBlock, enabling methods like cross-attention or adaptive norm.

Your goal is to design a conditioning injection method that achieves lower FID than the baselines.
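To make that interface concrete, a minimal sketch of the two editable pieces is shown below. The signatures are assumptions inferred from the description above (custom_train.py may pass extra arguments such as channel counts); the defaults shown correspond to the concat-film baseline with a no-op ClassConditioner.

import torch
import torch.nn as nn

def prepare_conditioning(time_emb: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
    """Combine the class and timestep embeddings before the ResBlocks (concat-film default)."""
    return time_emb + class_emb

class ClassConditioner(nn.Module):
    """Extra conditioning applied after each ResBlock; identity here.

    Replace with cross-attention or adaptive-norm layers to change the injection method.
    """

    def __init__(self, channels: int, class_emb_dim: int):
        super().__init__()

    def forward(self, hidden_states: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
        return hidden_states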

Evaluation

  • Dataset: CIFAR-10 (32x32, 10 classes)
  • Model: UNet2DModel (diffusers backbone) at three scales:
    • Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
    • Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
    • Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
  • Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995
  • Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
  • Inference: 50-step DDIM sampling (class-conditional); a sketch of the sampling-plus-FID pipeline follows this list
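The evaluation pipeline is outside the editable region, but its general shape is sketched below: 50-step class-conditional DDIM sampling followed by clean-fid against the precomputed CIFAR-10 train statistics. It assumes the UNet2DModel was built with num_class_embeds=10 (so it accepts class_labels), that images were trained in the [-1, 1] range, and that the scheduler settings match training; batch sizes and folder names are placeholders.

import os

import torch
from cleanfid import fid
from diffusers import DDIMScheduler, UNet2DModel
from torchvision.utils import save_image


@torch.no_grad()
def sample_and_score(model: UNet2DModel, num_images: int = 50_000,
                     batch_size: int = 500, out_dir: str = "samples",
                     device: str = "cuda") -> float:
    """Generate class-conditional samples with 50-step DDIM and score them with clean-fid."""
    os.makedirs(out_dir, exist_ok=True)
    scheduler = DDIMScheduler(num_train_timesteps=1000)  # should mirror the training noise schedule
    scheduler.set_timesteps(50)                          # 50 DDIM inference steps
    model.eval().to(device)
    saved = 0
    while saved < num_images:
        n = min(batch_size, num_images - saved)
        labels = torch.randint(0, 10, (n,), device=device)    # uniform over the 10 classes
        x = torch.randn(n, 3, 32, 32, device=device)          # start from pure noise
        for t in scheduler.timesteps:
            noise_pred = model(x, t, class_labels=labels).sample
            x = scheduler.step(noise_pred, t, x).prev_sample
        imgs = (x / 2 + 0.5).clamp(0, 1)                       # map [-1, 1] back to [0, 1]
        for i in range(n):
            save_image(imgs[i], os.path.join(out_dir, f"{saved + i:06d}.png"))
        saved += n
    # clean-fid ships reference statistics for the CIFAR-10 train split.
    return fid.compute_fid(out_dir, dataset_name="cifar10", dataset_res=32,
                           mode="clean", dataset_split="train")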

Baselines

  1. concat-film: Class embedding added to timestep embedding, injected via FiLM (adaptive GroupNorm) in ResBlocks. Simplest method.
  2. cross-attn: Class embedding used as key/value in cross-attention layers after ResBlocks. Most expressive method; sketched after this list.
  3. adanorm: Class embedding generates scale/shift/gate parameters for adaptive LayerNorm after ResBlocks. DiT-style method.
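For reference, a cross-attention conditioner along the lines of baseline 2 could look like the minimal sketch below, where the flattened feature map attends to a single class token used as key and value. Head count, normalization choice, and the residual wiring are assumptions rather than the actual baseline implementation.

import torch
import torch.nn as nn


class CrossAttnClassConditioner(nn.Module):
    """Minimal cross-attention block: image features attend to one class token."""

    def __init__(self, channels: int, class_emb_dim: int, num_heads: int = 4):
        super().__init__()
        # channels is assumed divisible by both num_groups and num_heads.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads,
                                          kdim=class_emb_dim, vdim=class_emb_dim,
                                          batch_first=True)

    def forward(self, x: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q = self.norm(x).flatten(2).transpose(1, 2)   # (B, H*W, C) queries from image features
        kv = class_emb.unsqueeze(1)                   # (B, 1, D) class token as key/value
        out, _ = self.attn(q, kv, kv)
        return x + out.transpose(1, 2).reshape(b, c, h, w)   # residual connection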

Code

custom_train.py
1"""Class-Conditional DDPM Training on CIFAR-10.
2
3Uses diffusers UNet2DModel backbone (same architecture as google/ddpm-cifar10-32)
4with configurable class-conditioning injection. Only the conditioning method
5(prepare_conditioning + ClassConditioner) is editable.
6"""
7
8import copy
9import math
10import os
11import sys
12import time
13
14import numpy as np
15import torch

Results

Model                            Type      Best FID (small)  Best FID (medium)  Best FID (large)
adanorm                          baseline  20.010            12.310             11.770
concat-film                      baseline  19.390            11.500             10.520
cross-attn                       baseline  19.170            11.210             10.410
anthropic/claude-opus-4.6        vanilla   19.160            11.440             -
deepseek-reasoner                vanilla   19.850            12.770             -
google/gemini-3.1-pro-preview    vanilla   19.100            11.430             -
qwen/qwen3.6-plus                vanilla   19.910            11.980             -
anthropic/claude-opus-4.6        agent     19.160            11.440             9.930
deepseek-reasoner                agent     19.530            11.150             10.430
google/gemini-3.1-pro-preview    agent     19.100            11.430             11.110
qwen/qwen3.6-plus                agent     19.230            11.270             10.370

Agent Conversations