Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 7
Messages: 15
User (Step 0)

Task: cv-diffusion-prediction

Diffusion Prediction Parameterization

Background

In DDPM training, the model predicts a target from the noisy input x_t, which is constructed from the clean image x_0 and Gaussian noise epsilon ~ N(0, I) as:

x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * epsilon

where alpha_t denotes the cumulative product alphas_cumprod at timestep t.
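In code, the forward process is a per-timestep affine mix of image and noise. A minimal sketch using the schedule tensors described later in this task (the helper name `add_noise` is illustrative, not part of the script):

```python
import torch

def add_noise(x_0, noise, timesteps, schedule):
    # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    # `schedule` holds 1-D tensors indexed by timestep; reshape for broadcasting
    # over (batch, channels, height, width).
    sa = schedule["sqrt_alpha"][timesteps].view(-1, 1, 1, 1)
    sma = schedule["sqrt_one_minus_alpha"][timesteps].view(-1, 1, 1, 1)
    return sa * x_0 + sma * noise
```

Because sqrt_alpha**2 + sqrt_one_minus_alpha**2 = 1, the marginal variance of x_t stays fixed (a variance-preserving schedule).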

There are three standard parameterizations for what the model predicts:

  1. Epsilon prediction (Ho et al., 2020): predict the noise epsilon
  2. X0 prediction: directly predict the clean image x_0
  3. V-prediction (Salimans & Ho, 2022): predict the velocity v = sqrt(alpha_t) * epsilon - sqrt(1 - alpha_t) * x_0

These parameterizations are mathematically equivalent (given x_t and the schedule, any one target can be converted into the others), but they induce different loss landscapes and training dynamics, leading to different FID scores under the same training budget.
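The equivalence follows by substituting x_t = a * x_0 + s * eps with a = sqrt(alpha_t), s = sqrt(1 - alpha_t) into each definition. A sketch of the conversions (helper names are illustrative; `a` and `s` are the per-timestep coefficients above):

```python
import torch

def eps_to_x0(eps, x_t, a, s):
    # Invert x_t = a * x_0 + s * eps for x_0.
    return (x_t - s * eps) / a

def v_to_x0(v, x_t, a, s):
    # With v = a * eps - s * x_0: a * x_t - s * v = (a^2 + s^2) * x_0 = x_0.
    return a * x_t - s * v

def v_to_eps(v, x_t, a, s):
    # Symmetric identity: s * x_t + a * v = (s^2 + a^2) * eps = eps.
    return s * x_t + a * v
```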

Research Question

Can we design a prediction parameterization that achieves better FID than the standard epsilon, v-prediction, and x0-prediction baselines?

Task

You are given custom_train.py, a self-contained training script that trains an unconditional UNet2DModel (google/ddpm-cifar10-32 architecture) on CIFAR-10.

The editable region contains two functions:

  1. compute_training_target(x_0, noise, timesteps, schedule) — defines what the model should predict during training.

  2. predict_x0(model_output, x_t, timesteps, schedule) — recovers the predicted clean image from the model's output (used during DDIM sampling).

These two functions must be consistent: the sampling procedure must correctly invert the training parameterization.
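For reference, the standard epsilon parameterization yields one consistent pair. This is a sketch under the schedule layout described below, not the script's actual contents; the key property is that a model whose output exactly equals the training target must make predict_x0 return x_0.

```python
import torch

def compute_training_target(x_0, noise, timesteps, schedule):
    # Epsilon prediction: the target is simply the injected noise.
    return noise

def predict_x0(model_output, x_t, timesteps, schedule):
    # Invert x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * eps,
    # treating model_output as the predicted eps.
    sa = schedule["sqrt_alpha"][timesteps].view(-1, 1, 1, 1)
    sma = schedule["sqrt_one_minus_alpha"][timesteps].view(-1, 1, 1, 1)
    return (x_t - sma * model_output) / sa
```

A quick round-trip check of this property is a cheap way to catch an inconsistent pair before spending a training run on it.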

The schedule dict provides precomputed noise schedule tensors:

  • alphas_cumprod: cumulative product of (1 - beta)
  • sqrt_alpha: sqrt(alphas_cumprod)
  • sqrt_one_minus_alpha: sqrt(1 - alphas_cumprod)
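For illustration, one plausible way these tensors could be precomputed. The linear beta endpoints below are the DDPM defaults and an assumption here; the actual script's schedule may differ:

```python
import torch

def make_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
    # Linear beta schedule (DDPM default, assumed); alphas_cumprod is the
    # running product of (1 - beta), and the two sqrt tensors are cached
    # so the training loop avoids recomputing them per batch.
    betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return {
        "alphas_cumprod": alphas_cumprod,
        "sqrt_alpha": alphas_cumprod.sqrt(),
        "sqrt_one_minus_alpha": (1.0 - alphas_cumprod).sqrt(),
    }
```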

Evaluation

  • Dataset: CIFAR-10 (32x32)
  • Model: UNet2DModel (diffusers backbone) at three scales:
    • Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
    • Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
    • Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
  • Training: 35000 steps per scale, AdamW lr=2e-4, EMA rate 0.9995, 8-GPU DDP
  • Inference: 50-step DDIM sampler
  • Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50k samples)
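A deterministic DDIM update can be written entirely in terms of the predicted clean image, which is why predict_x0 is the only hook the sampler needs. A sketch (function name illustrative; eta = 0, scalar timesteps):

```python
import torch

def ddim_step(x_t, x0_pred, t, t_prev, schedule):
    # Deterministic DDIM (eta = 0): recover the implied noise estimate from
    # x0_pred, then re-noise to the earlier timestep t_prev.
    sa_t = schedule["sqrt_alpha"][t]
    sma_t = schedule["sqrt_one_minus_alpha"][t]
    sa_p = schedule["sqrt_alpha"][t_prev]
    sma_p = schedule["sqrt_one_minus_alpha"][t_prev]
    eps_pred = (x_t - sa_t * x0_pred) / sma_t
    return sa_p * x0_pred + sma_p * eps_pred
```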

Baselines

Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID for the three baselines (lower is better):

| Name    | Target                                                                      | FID (small) | FID (medium) |
|---------|-----------------------------------------------------------------------------|-------------|--------------|
| epsilon | predict noise ε (standard DDPM parameterization)                            | 18.58       | 8.46         |
| vpred   | predict velocity v = α_t * ε - σ_t * x_0 (v-prediction, Salimans & Ho 2022) | 21.70       | 11.59        |
| x0pred  | predict clean data x_0 directly                                             | 25.25       | 13.56        |

The strongest baseline is epsilon at FID 18.58 (small) / 8.46 (medium) — standard noise prediction remains the best single-target parameterization at this scale. Your method must achieve **lower FID than `epsilon`** on both evaluated scales.

Assistant (Step 1: edit)
Tool Input
{
  "filename": "custom_train.py",
  "op": "replace",
  "start_line": 83,
  "content": "def get_cv_minsnr_c(timesteps, schedule, gamma=5.0):\n    sa = schedule[\"sqrt_alpha\"][timesteps].view(-1, 1, 1, 1)\n    sma = schedule[\"sqrt_one_minus_alpha\"][timesteps].view(-1, 1, 1, 1)\n    snr = (sa**2) / (sma**2)\n    w = torch.clamp(snr, max=gamma)\n    S = torch.sqrt(w * (sma**2))\n    s_comp = torch.sqrt((1.0 - w * (sma**2)).clamp(min=0))\n    c1 = sa * s_comp - sma * S\n    c2 = sma * s_comp + sa * S\n    return c1, c2, S, sma\n\ndef compute_training_target(x_0, noise, timesteps, schedule):\n    c1, c2, _, _ = get_cv_minsnr_c(timesteps, schedule)\n    return c1 * x_0 + c2 * noise\n\ndef predict_x0(model_output, x_t, timesteps, schedule):\n    _, c2, S, sma = get_cv_minsnr_c(timesteps, schedule)\n    return (c2 * x_t - sma * model_output) / S.clamp(min=1e-8)\n",
  "end_line": 118
}
Tool Result (Step 1)
ERROR: Package 'custom_train.py' is not in allowed packages
Assistant (Step 2: edit)
Tool Input
{
  "end_line": 118,
  "content": "def get_cv_minsnr_c(timesteps, schedule, gamma=5.0):\n    sa = schedule[\"sqrt_alpha\"][timesteps].view(-1, 1, 1, 1)\n    sma = schedule[\"sqrt_one_minus_alpha\"][timesteps].view(-1, 1, 1, 1)\n    snr = (sa**2) / (sma**2)\n    w = torch.clamp(snr, max=gamma)\n    S = torch.sqrt(w * (sma**2))\n    s_comp = torch.sqrt((1.0 - w * (sma**2)).clamp(min=0))\n    c1 = sa * s_comp - sma * S\n    c2 = sma * s_comp + sa * S\n    return c1, c2, S, sma\n\ndef compute_training_target(x_0, noise, timesteps, schedule):\n    c1, c2, _, _ = get_cv_minsnr_c(timesteps, schedule)\n    return c1 * x_0 + c2 * noise\n\ndef predict_x0(model_output, x_t, timesteps, schedule):\n    _, c2, S, sma = get_cv_minsnr_c(timesteps, schedule)\n    return (c2 * x_t - sma * model_output) / S.clamp(min=1e-8)",
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 83
}
Tool Result (Step 2)
OK: Replaced lines 83..118 with 18 line(s) in diffusers-main/custom_train.py. Editable range: 83–100.

[Current file: diffusers-main/custom_train.py | editable: 83–100 | total: 474 lines]
    80| # ============================================================================
    81| # Prediction parameterization (EDITABLE REGION)
    82| # ============================================================================
    83  def get_cv_minsnr_c(timesteps, schedule, gamma=5.0):
    84      sa = schedule["sqrt_alpha"][timesteps].view(-1, 1, 1, 1)
    85      sma = schedule["sqrt_one_minus_alpha"][timesteps].view(-1, 1, 1, 1)
    86      snr = (sa**2) / (sma**2)
    87      w = torch.clamp(snr, max=gamma)
    88      S = torch.sqrt(w * (sma**2))
    89      s_comp = torch.sqrt((1.0 - w * (sma**2)).clamp(min=0))
    90      c1 = sa * s_comp - sma * S
    91      c2 = sma * s_comp + sa * S
    92      return c1, c2, S, sma
    93  
    94  def compute_training_target(x_0, noise, timesteps, sch
Assistant (Step 3: test)
Tool Result (Step 3)
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.0M | GPUs: 8
step 200/35000 | loss 0.0268 | 16.9s
step 400/35000 | loss 0.0158 | 16.3s
step 600/35000 | loss 0.0197 | 17.1s
step 800/35000 | loss 0.0147 | 16.3s
step 1000/35000 | loss 0.0182 | 15.9s
step 1200/35000 | loss 0.0187 | 16.8s
step 1400/35000 | loss 0.0164 | 16.7s
step 1600/35000 | loss 0.0140 | 16.3s
step 1800/35000 | loss 0.0166 | 16.8s
step 2000/35000 | loss 0.0193 | 16.2s
step 2200/35000 | loss 0.0177 | 16.3s
step 2400/35000 | loss 0.0164 | 16.1s
step 2600/35000 | loss 0.0175 | 17.2s
step 2800/35000 | loss 0.0166 | 16.5s
step 3000/35000 | loss 0.0178 | 16.5s
step 3200/35000 | loss 0.0156 | 16.3s
step 3400/35000 | loss 0.0126 | 16.2s
step 3600/35000 | loss 0.0154 | 16.6s
step 3800/35000 | loss 0.0150 | 16.6s
step 4000/35000 | loss 0.0140 | 15.9s
step 4200/35000 | loss 0.0190 | 16.2s
step 4400/35000 | loss 0.0159 | 16.0s
step 
Assistant (Step 4: edit)
Tool Input
{
  "content": "def get_scaled_eps_target(timesteps, schedule, gamma=5.0):\n    sa = schedule[\"sqrt_alpha\"][timesteps].view(-1, 1, 1, 1)\n    sma = schedule[\"sqrt_one_minus_alpha\"][timesteps].view(-1, 1, 1, 1)\n    snr = (sa**2) / (sma**2)\n    w = torch.clamp(gamma / snr, max=1.0)\n    return torch.sqrt(w), sa, sma\n\ndef compute_training_target(x_0, noise, timesteps, schedule):\n    scale, _, _ = get_scaled_eps_target(timesteps, schedule)\n    return scale * noise\n\ndef predict_x0(model_output, x_t, timesteps, schedule):\n    scale, sa, sma = get_scaled_eps_target(timesteps, schedule)\n    eps_pred = model_output / scale.clamp(min=1e-8)\n    return (x_t - sma * eps_pred) / sa.clamp(min=1e-8)",
  "start_line": 83,
  "op": "replace",
  "filename": "diffusers-main/custom_train.py",
  "end_line": 100
}
Tool Result (Step 4)
OK: Replaced lines 83..100 with 15 line(s) in diffusers-main/custom_train.py. Editable range: 83–97.

[Current file: diffusers-main/custom_train.py | editable: 83–97 | total: 471 lines]
    80| # ============================================================================
    81| # Prediction parameterization (EDITABLE REGION)
    82| # ============================================================================
    83  def get_scaled_eps_target(timesteps, schedule, gamma=5.0):
    84      sa = schedule["sqrt_alpha"][timesteps].view(-1, 1, 1, 1)
    85      sma = schedule["sqrt_one_minus_alpha"][timesteps].view(-1, 1, 1, 1)
    86      snr = (sa**2) / (sma**2)
    87      w = torch.clamp(gamma / snr, max=1.0)
    88      return torch.sqrt(w), sa, sma
    89  
    90  def compute_training_target(x_0, noise, timesteps, schedule):
    91      scale, _, _ = get_scaled_eps_target(timesteps, schedule)
    92      return scale * noise
    93  
    94  def predict_x0(model_output, x_t, timeste
Assistant (Step 5: test)
Tool Result (Step 5)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.0M | GPUs: 8
step 200/35000 | loss 0.0259 | 17.1s
step 400/35000 | loss 0.0155 | 16.4s
step 600/35000 | loss 0.0192 | 16.6s
step 800/35000 | loss 0.0147 | 16.6s
step 1000/35000 | loss 0.0179 | 16.6s
step 1200/35000 | loss 0.0185 | 16.5s
step 1400/35000 | loss 0.0162 | 16.9s
step 1600/35000 | loss 0.0138 | 16.4s
step 1800/35000 | loss 0.0166 | 16.9s
step 2000/35000 | loss 0.0189 | 16.6s
step 2200/35000 | loss 0.0175 | 16.4s
step 2400/35000 | loss 0.0163 | 16.4s
step 2600/35000 | loss 0.0174 | 16.8s
step 2800/35000 | loss 0.0162 | 16.3s
step 3000/35000 | loss 0.0176 | 16.3s
step 3200/35000 | loss 0.0155 | 16.3s
step 3400/35000 | loss 0.0125 | 16.4s
step 3600/35000 | loss 0.0154 | 16.3s
step 3800/35000 | loss 0.0149 | 16.7s
step 4000/35000 | loss 0.0139 | 16.4s
step 4200/35000 | loss 0.0190 | 16.3s
step 4400/35000 | loss 0.0160 | 16.2s
step 4
Assistant (Step 6: edit)
Tool Input
{
  "content": "def get_tail_rotated_c(timesteps, schedule, alpha_thresh=0.05):\n    sa = schedule[\"sqrt_alpha\"][timesteps].view(-1, 1, 1, 1)\n    sma = schedule[\"sqrt_one_minus_alpha\"][timesteps].view(-1, 1, 1, 1)\n    alpha = sa ** 2\n    \n    ratio = (1.0 - alpha / alpha_thresh).clamp(min=0.0, max=1.0)\n    theta = - (math.pi / 2.0) * ratio\n    \n    c1 = torch.sin(theta)\n    c2 = torch.cos(theta)\n    return c1, c2, sa, sma\n\ndef compute_training_target(x_0, noise, timesteps, schedule):\n    c1, c2, _, _ = get_tail_rotated_c(timesteps, schedule)\n    return c1 * x_0 + c2 * noise\n\ndef predict_x0(model_output, x_t, timesteps, schedule):\n    c1, c2, sa, sma = get_tail_rotated_c(timesteps, schedule)\n    num = sma * model_output - c2 * x_t\n    den = sma * c1 - sa * c2\n    return num / den.clamp(max=-1e-8)",
  "end_line": 97,
  "filename": "diffusers-main/custom_train.py",
  "op": "replace",
  "start_line": 83
}
Tool Result (Step 6)
OK: Replaced lines 83..97 with 21 line(s) in diffusers-main/custom_train.py. Editable range: 83–103.

[Current file: diffusers-main/custom_train.py | editable: 83–103 | total: 477 lines]
    80| # ============================================================================
    81| # Prediction parameterization (EDITABLE REGION)
    82| # ============================================================================
    83  def get_tail_rotated_c(timesteps, schedule, alpha_thresh=0.05):
    84      sa = schedule["sqrt_alpha"][timesteps].view(-1, 1, 1, 1)
    85      sma = schedule["sqrt_one_minus_alpha"][timesteps].view(-1, 1, 1, 1)
    86      alpha = sa ** 2
    87      
    88      ratio = (1.0 - alpha / alpha_thresh).clamp(min=0.0, max=1.0)
    89      theta = - (math.pi / 2.0) * ratio
    90      
    91      c1 = torch.sin(theta)
    92      c2 = torch.cos(theta)
    93      return c1, c2, sa, sma
    94  
    95  def compute_training_target(x_0, noise, timesteps, schedule):
    96 
Assistant (Step 7: test)
Tool Result (Step 7)
[Test #3] (0 tests remaining)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 9.0M | GPUs: 8
step 200/35000 | loss 0.1276 | 17.1s
step 400/35000 | loss 0.1478 | 16.3s
step 600/35000 | loss 0.1354 | 16.5s
step 800/35000 | loss 0.1216 | 16.2s
step 1000/35000 | loss 0.1167 | 16.2s
step 1200/35000 | loss 0.1128 | 16.3s
step 1400/35000 | loss 0.1195 | 16.8s
step 1600/35000 | loss 0.1432 | 16.3s
step 1800/35000 | loss 0.1193 | 16.7s
step 2000/35000 | loss 0.1236 | 16.6s
step 2200/35000 | loss 0.1264 | 16.4s
step 2400/35000 | loss 0.1294 | 16.4s
step 2600/35000 | loss 0.1231 | 16.7s
step 2800/35000 | loss 0.1170 | 16.4s
step 3000/35000 | loss 0.1184 | 16.7s
step 3200/35000 | loss 0.1274 | 16.4s
step 3400/35000 | loss 0.1278 | 16.4s
step 3600/35000 | loss 0.1247 | 16.4s
step 3800/35000 | loss 0.1178 | 16.4s
step 4000/35000 | loss 0.1196 | 16.2s
step 4200/35000 | loss 0.1240 | 16.1s
step 4400/35000 | loss 0.1298 | 16.5s
step 4600/35000 | loss 0.1118 | 16.5s
step 48