Agent Conversation: google/gemini-3.1-pro-preview
Task: cv-diffusion-architecture
Diffusion Model Architecture Design
Background
The UNet backbone is the standard architecture for denoising diffusion models. Key design choices include:
- Block types: whether to use pure convolutional blocks (DownBlock2D/UpBlock2D) or blocks with self-attention (AttnDownBlock2D/AttnUpBlock2D), and at which resolution levels.
- Attention placement: self-attention is expensive at high resolutions (32x32) but may improve global coherence. The original DDPM places attention only at 16x16.
- Depth and normalization: layers_per_block, norm_num_groups, attention_head_dim, and other structural hyperparameters.
- Custom modules: entirely new backbone designs (e.g., hybrid convolution-transformer, gated blocks, multi-scale fusion), as long as they satisfy the input/output interface.
Research Question
What UNet architecture achieves the best FID on unconditional CIFAR-10 diffusion, given a fixed training procedure (epsilon prediction, DDIM sampling, same optimizer and schedule)?
Task
You are given custom_train.py, a self-contained unconditional DDPM training
script on CIFAR-10. Everything is fixed except the build_model(device)
function.
Your goal is to design a model architecture that achieves a lower FID than the baselines. The model must satisfy:
- Input: (x, timestep), where x is [B, 3, 32, 32] and timestep is [B]
- Output: an object with a .sample attribute of shape [B, 3, 32, 32]
UNet2DModel from diffusers satisfies this interface, but you may also build a fully custom nn.Module.
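As an illustration of the required interface, here is a minimal, hypothetical custom nn.Module that satisfies the (x, timestep) → .sample contract; the SimpleNamespace wrapper stands in for diffusers' output object, and the module itself is a toy, not a competitive architecture:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

class TinyDenoiser(nn.Module):
    """Minimal module matching the build_model interface (illustrative only)."""
    def __init__(self, hidden=64):
        super().__init__()
        # Assumes the standard 1000-step training schedule for the embedding table.
        self.time_embed = nn.Embedding(1000, hidden)
        self.conv_in = nn.Conv2d(3, hidden, 3, padding=1)
        self.act = nn.SiLU()
        self.conv_out = nn.Conv2d(hidden, 3, 3, padding=1)

    def forward(self, x, timestep):
        # Broadcast a per-sample time embedding over the spatial dimensions.
        temb = self.time_embed(timestep)[:, :, None, None]
        h = self.act(self.conv_in(x) + temb)
        # Mimic diffusers' output object: result exposed via .sample
        return SimpleNamespace(sample=self.conv_out(h))
```

Any module with this signature and output shape would be accepted by the fixed training loop.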
Channel widths are provided via the BLOCK_OUT_CHANNELS environment variable
(e.g. "128,256,256,256") so the same architecture scales across evaluation
tiers. LAYERS_PER_BLOCK (default 2) is also available.
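The environment-variable handling described above can be sketched in plain Python (the variable names are from the task; the defaults shown are an assumption matching the Medium tier):

```python
import os

def read_arch_config(env=os.environ):
    # Parse "128,256,256,256" into a tuple of ints; fall back to the
    # assumed Medium-tier widths when the variable is unset.
    raw = env.get("BLOCK_OUT_CHANNELS", "128,256,256,256")
    channels = tuple(int(c) for c in raw.split(","))
    layers = int(env.get("LAYERS_PER_BLOCK", 2))
    return channels, layers

channels, layers = read_arch_config({"BLOCK_OUT_CHANNELS": "64,128,128,128"})
# → ((64, 128, 128, 128), 2)
```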
Evaluation
- Dataset: CIFAR-10 (32x32, unconditional)
- Training: fixed epsilon prediction, MSE loss, AdamW lr=2e-4, EMA
- Model scales:
- Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
- Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
- Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
- Training budget: 35,000 steps per scale, EMA rate 0.9995
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50,000 samples)
- Inference: 50-step DDIM sampling
Baselines
All three baselines use modern UNet improvements: scale-shift FiLM time
conditioning (resnet_time_scale_shift="scale_shift"), flip_sin_to_cos=True,
freq_shift=0, and explicit multi-head attention. They differ in
attention placement and head capacity:
- iddpm (Improved DDPM, Nichol & Dhariwal 2021): self-attention at the 16x16 and 8x8 resolutions (2 levels), attention_head_dim=8. Modest compute, focused attention on the mid-deep features.
- adm (ADM, "Diffusion Models Beat GANs", Dhariwal & Nichol 2021): self-attention at 16x16, 8x8, and 4x4 (
Tool Input
1{
2 "end_line": 58,
3 "content": "def build_model(device):\n import diffusers\n import os\n import torch\n import torch.nn as nn\n import torch.nn.functional as F\n\n class TGCBlock(nn.Module):\n \"\"\"Time-conditioned Global Context Block.\"\"\"\n def __init__(self, in_channels, time_emb_dim, ratio=4):\n super().__init__()\n mid_channels = max(16, in_channels // ratio)\n self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)\n self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)\n self.norm = nn.GroupNorm(1, mid_channels)\n self.time_proj = nn.Linear(time_emb_dim, mid_channels)\n self.conv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1)\n \n nn.init.constant_(self.conv2.weight, 0)\n nn.init.constant_(self.conv2.bias, 0)\n\n def forward(self, x, temb):\n B, C, H, W = x.size()\n attn = self.spatial_attn(x).view(B, 1, H * W)\n attn = F.softmax(attn, dim=-1)\n \n context = torch.bmm(x.view(B, C, H * W), attn.transpose(1, 2))\n context = context.view(B, C, 1, 1)\n \n hidden = self.norm(self.conv1(context))\n \n if temb is not None:\n time_bias = self.time_proj(F.mish(temb)).view(B, -1, 1, 1)\n hidden = hidden + time_bias\n \n hidden = F.mish(hidden)\n out = self.conv2(hidden)\n return x + out\n\n class ResnetWithTGC(nn.Module):\n def __init__(self, resnet, tgc_block):\n super().__init__()\n self.resnet = resnet\n self.tgc_block = tgc_block\n \n def forward(self, input_tensor, temb, *args, **kwargs):\n x = self.resnet(input_tensor, temb, *args, **kwargs)\n return self.tgc_block(x, temb)\n\n channels = (128, 256, 256, 256)\n if os.environ.get('BLOCK_OUT_CHANNELS'):\n channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n time_emb_dim = channels[0] * 4\n\n model = diffusers.UNet2DModel(\n sample_size=32,\n in_channels=3,\n out_channels=3,\n block_out_channels=channels,\n down_block_types=(\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\"),\n 
up_block_types=(\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\"),\n layers_per_block=layers,\n norm_num_groups=32,\n norm_eps=1e-6,\n act_fn=\"mish\",\n time_embedding_type=\"positional\",\n flip_sin_to_cos=True,\n freq_shift=0,\n downsample_padding=1,\n attention_head_dim=32,\n resnet_time_scale_shift=\"scale_shift\",\n )\n\n def inject_tgc_blocks(mod):\n for name, child in mod.named_children():\n if child.__class__.__name__ == \"ResnetBlock2D\":\n out_channels = getattr(child, \"out_channels\", None)\n if out_channels is None:\n out_channels = child.conv1.out_channels\n tgc = TGCBlock(out_channels, time_emb_dim)\n wrapped = ResnetWithTGC(child, tgc)\n if isinstance(mod, nn.ModuleList):\n mod[int(name)] = wrapped\n else:\n setattr(mod, name, wrapped)\n else:\n inject_tgc_blocks(child)\n \n inject_tgc_blocks(model)\n return model.to(device)\n",
4 "filename": "custom_train.py",
5 "op": "replace",
6 "start_line": 31
7}
ERROR: Package 'custom_train.py' is not in allowed packages
Tool Input
1{
2 "content": "def build_model(device):\n import diffusers\n import os\n import torch\n import torch.nn as nn\n import torch.nn.functional as F\n\n class TGCBlock(nn.Module):\n \"\"\"Time-conditioned Global Context Block.\"\"\"\n def __init__(self, in_channels, time_emb_dim, ratio=4):\n super().__init__()\n mid_channels = max(16, in_channels // ratio)\n self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)\n self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)\n self.norm = nn.GroupNorm(1, mid_channels)\n self.time_proj = nn.Linear(time_emb_dim, mid_channels)\n self.conv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1)\n \n nn.init.constant_(self.conv2.weight, 0)\n nn.init.constant_(self.conv2.bias, 0)\n\n def forward(self, x, temb):\n B, C, H, W = x.size()\n attn = self.spatial_attn(x).view(B, 1, H * W)\n attn = F.softmax(attn, dim=-1)\n \n context = torch.bmm(x.view(B, C, H * W), attn.transpose(1, 2))\n context = context.view(B, C, 1, 1)\n \n hidden = self.norm(self.conv1(context))\n \n if temb is not None:\n time_bias = self.time_proj(F.mish(temb)).view(B, -1, 1, 1)\n hidden = hidden + time_bias\n \n hidden = F.mish(hidden)\n out = self.conv2(hidden)\n return x + out\n\n class ResnetWithTGC(nn.Module):\n def __init__(self, resnet, tgc_block):\n super().__init__()\n self.resnet = resnet\n self.tgc_block = tgc_block\n \n def forward(self, input_tensor, temb, *args, **kwargs):\n x = self.resnet(input_tensor, temb, *args, **kwargs)\n return self.tgc_block(x, temb)\n\n channels = (128, 256, 256, 256)\n if os.environ.get('BLOCK_OUT_CHANNELS'):\n channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n time_emb_dim = channels[0] * 4\n\n model = diffusers.UNet2DModel(\n sample_size=32,\n in_channels=3,\n out_channels=3,\n block_out_channels=channels,\n down_block_types=(\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\"),\n 
up_block_types=(\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\"),\n layers_per_block=layers,\n norm_num_groups=32,\n norm_eps=1e-6,\n act_fn=\"mish\",\n time_embedding_type=\"positional\",\n flip_sin_to_cos=True,\n freq_shift=0,\n downsample_padding=1,\n attention_head_dim=32,\n resnet_time_scale_shift=\"scale_shift\",\n )\n\n def inject_tgc_blocks(mod):\n for name, child in mod.named_children():\n if child.__class__.__name__ == \"ResnetBlock2D\":\n out_channels = getattr(child, \"out_channels\", None)\n if out_channels is None:\n out_channels = child.conv1.out_channels\n tgc = TGCBlock(out_channels, time_emb_dim)\n wrapped = ResnetWithTGC(child, tgc)\n if isinstance(mod, nn.ModuleList):\n mod[int(name)] = wrapped\n else:\n setattr(mod, name, wrapped)\n else:\n inject_tgc_blocks(child)\n \n inject_tgc_blocks(model)\n return model.to(device)\n",
3 "end_line": 58,
4 "start_line": 31,
5 "filename": "diffusers-main/custom_train.py",
6 "op": "replace"
7}
OK: Replaced lines 31..58 with 91 line(s) in diffusers-main/custom_train.py. Editable range: 31–121.
[Current file: diffusers-main/custom_train.py | editable: 31–121 | total: 474 lines]
28| # Model Architecture (EDITABLE REGION)
29| # ============================================================================
30|
31 def build_model(device):
32 import diffusers
33 import os
34 import torch
35 import torch.nn as nn
36 import torch.nn.functional as F
37
38 class TGCBlock(nn.Module):
39 """Time-conditioned Global Context Block."""
40 def __init__(self, in_channels, time_emb_dim, ratio=4):
41 super().__init__()
42 mid_channels = max(16, in_channels // ratio)
43 self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)
44 self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
45 self.norm =
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
Model parameters: 10.3M | GPUs: 8
step 200/35000 | loss 0.0348 | 32.3s
step 400/35000 | loss 0.0319 | 31.5s
step 600/35000 | loss 0.0378 | 31.3s
step 800/35000 | loss 0.0268 | 31.1s
step 1000/35000 | loss 0.0402 | 30.6s
step 1200/35000 | loss 0.0327 | 31.1s
step 1400/35000 | loss 0.0311 | 31.0s
step 1600/35000 | loss 0.0226 | 32.0s
step 1800/35000 | loss 0.0318 | 31.9s
step 2000/35000 | loss 0.0426 | 31.2s
step 2200/35000 | loss 0.0426 | 31.3s
step 2400/35000 | loss 0.0275 | 32.0s
step 2600/35000 | loss 0.0464 | 31.2s
step 2800/35000 | loss 0.0280 | 30.8s
step 3000/35000 | loss 0.0342 | 31.4s
step 3200/35000 | loss 0.0343 | 31.2s
step 3400/35000 | loss 0.0325 | 31.3s
step 3600/35000 | loss 0.0297 | 32.0s
step 3800/35000 | loss 0.0351 | 31.7s
step 4000/35000 | loss 0.0227 | 31.5s
step 4200/35000 | loss 0.0345 | 31.5s
step 4400/35000 | loss 0.0201 | 31.9s
step
Tool Input
1{
2 "op": "replace",
3 "filename": "diffusers-main/custom_train.py",
4 "content": " channels = (128, 256, 256, 256)\n if os.environ.get('BLOCK_OUT_CHANNELS'):\n channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n time_emb_dim = channels[0] * 4\n\n if channels[0] <= 64:\n down_blocks = (\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\")\n else:\n down_blocks = (\"DownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"UpBlock2D\")\n\n model = diffusers.UNet2DModel(\n sample_size=32,\n in_channels=3,\n out_channels=3,\n block_out_channels=channels,\n down_block_types=down_blocks,\n up_block_types=up_blocks,\n layers_per_block=layers,\n norm_num_groups=32,\n norm_eps=1e-6,\n act_fn=\"mish\",\n time_embedding_type=\"positional\",\n flip_sin_to_cos=True,\n freq_shift=0,\n downsample_padding=1,\n attention_head_dim=32,\n resnet_time_scale_shift=\"scale_shift\",\n )",
5 "end_line": 103,
6 "start_line": 80
7}
OK: Replaced lines 80..103 with 31 line(s) in diffusers-main/custom_train.py. Editable range: 31–128.
[Current file: diffusers-main/custom_train.py | editable: 31–128 | total: 481 lines]
28| # Model Architecture (EDITABLE REGION)
29| # ============================================================================
30|
31 def build_model(device):
32 import diffusers
33 import os
34 import torch
35 import torch.nn as nn
36 import torch.nn.functional as F
37
38 class TGCBlock(nn.Module):
39 """Time-conditioned Global Context Block."""
40 def __init__(self, in_channels, time_emb_dim, ratio=4):
41 super().__init__()
42 mid_channels = max(16, in_channels // ratio)
43 self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)
44 self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
45 self.norm =
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
Model parameters: 10.3M | GPUs: 8
step 200/35000 | loss 0.0348 | 31.9s
step 400/35000 | loss 0.0320 | 31.1s
step 600/35000 | loss 0.0378 | 31.3s
step 800/35000 | loss 0.0268 | 31.3s
step 1000/35000 | loss 0.0404 | 31.1s
step 1200/35000 | loss 0.0328 | 31.7s
step 1400/35000 | loss 0.0311 | 31.8s
step 1600/35000 | loss 0.0226 | 31.4s
step 1800/35000 | loss 0.0317 | 30.9s
step 2000/35000 | loss 0.0426 | 31.7s
step 2200/35000 | loss 0.0425 | 31.8s
step 2400/35000 | loss 0.0274 | 31.2s
step 2600/35000 | loss 0.0464 | 31.4s
step 2800/35000 | loss 0.0280 | 31.0s
step 3000/35000 | loss 0.0343 | 31.4s
step 3200/35000 | loss 0.0344 | 31.8s
step 3400/35000 | loss 0.0324 | 31.2s
step 3600/35000 | loss 0.0297 | 31.4s
step 3800/35000 | loss 0.0351 | 31.5s
step 4000/35000 | loss 0.0227 | 31.6s
step 4200/35000 | loss 0.0345 | 31.1s
step 4400/35000 | loss 0.0201 | 31.3s
step
Tool Input
1{
2 "content": "def build_model(device):\n import diffusers\n import os\n import torch\n import torch.nn as nn\n import torch.nn.functional as F\n\n class TGCBlock(nn.Module):\n \"\"\"Time-conditioned Global Context Block.\"\"\"\n def __init__(self, in_channels, time_emb_dim, ratio=4):\n super().__init__()\n mid_channels = max(16, in_channels // ratio)\n self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)\n self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)\n self.norm = nn.GroupNorm(1, mid_channels)\n self.time_proj = nn.Linear(time_emb_dim, mid_channels)\n self.conv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1)\n \n nn.init.constant_(self.conv2.weight, 0)\n nn.init.constant_(self.conv2.bias, 0)\n\n def forward(self, x, temb):\n B, C, H, W = x.size()\n attn = self.spatial_attn(x).view(B, 1, H * W)\n attn = F.softmax(attn, dim=-1)\n \n context = torch.bmm(x.view(B, C, H * W), attn.transpose(1, 2))\n context = context.view(B, C, 1, 1)\n \n hidden = self.norm(self.conv1(context))\n \n if temb is not None:\n time_bias = self.time_proj(F.mish(temb)).view(B, -1, 1, 1)\n hidden = hidden + time_bias\n \n hidden = F.mish(hidden)\n out = self.conv2(hidden)\n return x + out\n\n class ResnetWithTGC(nn.Module):\n def __init__(self, resnet, tgc_block):\n super().__init__()\n self.resnet = resnet\n self.tgc_block = tgc_block\n \n def forward(self, input_tensor, temb, *args, **kwargs):\n x = self.resnet(input_tensor, temb, *args, **kwargs)\n return self.tgc_block(x, temb)\n\n channels = (128, 256, 256, 256)\n if os.environ.get('BLOCK_OUT_CHANNELS'):\n channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n time_emb_dim = channels[0] * 4\n\n is_small = (channels[0] <= 64)\n\n if is_small:\n down_blocks = (\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\")\n 
head_dim = 32\n else:\n down_blocks = (\"DownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"UpBlock2D\")\n head_dim = 8\n\n model = diffusers.UNet2DModel(\n sample_size=32,\n in_channels=3,\n out_channels=3,\n block_out_channels=channels,\n down_block_types=down_blocks,\n up_block_types=up_blocks,\n layers_per_block=layers,\n norm_num_groups=32,\n norm_eps=1e-6,\n act_fn=\"mish\",\n time_embedding_type=\"positional\",\n flip_sin_to_cos=True,\n freq_shift=0,\n downsample_padding=1,\n attention_head_dim=head_dim,\n resnet_time_scale_shift=\"scale_shift\",\n )\n\n def inject_tgc_blocks(mod):\n for name, child in mod.named_children():\n if child.__class__.__name__ == \"ResnetBlock2D\":\n out_channels = getattr(child, \"out_channels\", None)\n if out_channels is None:\n out_channels = child.conv1.out_channels\n tgc = TGCBlock(out_channels, time_emb_dim)\n wrapped = ResnetWithTGC(child, tgc)\n if isinstance(mod, nn.ModuleList):\n mod[int(name)] = wrapped\n else:\n setattr(mod, name, wrapped)\n else:\n inject_tgc_blocks(child)\n \n def zero_init_resnet(mod):\n for name, child in mod.named_children():\n if child.__class__.__name__ == \"ResnetBlock2D\":\n if hasattr(child, 'conv2'):\n nn.init.constant_(child.conv2.weight, 0)\n nn.init.constant_(child.conv2.bias, 0)\n else:\n zero_init_resnet(child)\n\n if is_small:\n inject_tgc_blocks(model)\n else:\n zero_init_resnet(model)\n\n return model.to(device)\n",
3 "end_line": 128,
4 "op": "replace",
5 "start_line": 31,
6 "filename": "diffusers-main/custom_train.py"
7}
OK: Replaced lines 31..128 with 115 line(s) in diffusers-main/custom_train.py. Editable range: 31–145.
[Current file: diffusers-main/custom_train.py | editable: 31–145 | total: 498 lines]
28| # Model Architecture (EDITABLE REGION)
29| # ============================================================================
30|
31 def build_model(device):
32 import diffusers
33 import os
34 import torch
35 import torch.nn as nn
36 import torch.nn.functional as F
37
38 class TGCBlock(nn.Module):
39 """Time-conditioned Global Context Block."""
40 def __init__(self, in_channels, time_emb_dim, ratio=4):
41 super().__init__()
42 mid_channels = max(16, in_channels // ratio)
43 self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)
44 self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
45 self.norm
[Test #3] (0 tests remaining)
### train_small (scripts/train_small.sh)
Training results:
Model parameters: 10.3M | GPUs: 8
step 200/35000 | loss 0.0349 | 32.1s
step 400/35000 | loss 0.0318 | 31.9s
step 600/35000 | loss 0.0377 | 31.6s
step 800/35000 | loss 0.0271 | 32.4s
step 1000/35000 | loss 0.0404 | 31.2s
step 1200/35000 | loss 0.0327 | 30.9s
step 1400/35000 | loss 0.0310 | 31.4s
step 1600/35000 | loss 0.0227 | 30.7s
step 1800/35000 | loss 0.0317 | 31.3s
step 2000/35000 | loss 0.0426 | 31.6s
step 2200/35000 | loss 0.0427 | 31.2s
step 2400/35000 | loss 0.0275 | 31.7s
step 2600/35000 | loss 0.0464 | 32.2s
step 2800/35000 | loss 0.0280 | 31.2s
step 3000/35000 | loss 0.0342 | 31.3s
step 3200/35000 | loss 0.0343 | 31.2s
step 3400/35000 | loss 0.0324 | 31.1s
step 3600/35000 | loss 0.0298 | 31.1s
step 3800/35000 | loss 0.0352 | 31.5s
step 4000/35000 | loss 0.0228 | 31.1s
step 4200/35000 | loss 0.0344 | 31.2s
step 4400/35000 | loss 0.0202 | 31.3s
step 4600/35000 | loss 0.0349 | 31.7s
4