Agent Conversation: google/gemini-3.1-pro-preview
Task: cv-meanflow-perceptual-loss
Flow Matching with Perceptual Loss
Background
Flow matching trains a neural network to predict velocity fields that transport samples from noise to data. Traditional training uses only MSE loss on the predicted velocity:
loss = ||v_pred - v_target||^2
However, we can also compute the denoised image from the predicted velocity:
x_denoised = x_t - t * v_pred
We can then apply perceptual losses (LPIPS, gradient loss, etc.) to x_denoised, encouraging the network to generate high-quality images rather than just accurate velocities.
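The identity above can be checked numerically: for the linear interpolant x_t = (1-t)*x + t*noise with the exact velocity v = noise - x, the denoising formula recovers the clean image. A minimal sketch (illustration only, not the training script):

```python
import torch

# With x_t = (1 - t) * x + t * noise and the true velocity v = noise - x,
# the one-step denoising formula x_t - t * v returns x exactly.
torch.manual_seed(0)
x = torch.randn(4, 3, 32, 32)        # clean images
noise = torch.randn_like(x)
t = torch.rand(4, 1, 1, 1)           # per-sample timesteps in (0, 1)

x_t = (1 - t) * x + t * noise        # noisy interpolant
v_true = noise - x                   # ground-truth velocity
x_denoised = x_t - t * v_true        # one-step denoising

print(torch.allclose(x_denoised, x, atol=1e-5))  # True
```

During training v_pred only approximates v_true, so x_denoised is an imperfect estimate of x; the perceptual losses penalize that estimate directly.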
Research Question
Can adding perceptual losses to flow matching training improve FID scores?
Task
You are given custom_train_perceptual.py, a self-contained training script that trains a
small DiT on CIFAR-10 (32x32) using flow matching with mean velocity objectives.
The editable region contains the loss computation in the training loop:
# Current: MSE loss only
loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()
loss = loss_mse
The fixed code already exposes:
- lpips_fn(x_denoised, x_target): LPIPS perceptual loss
- compute_gradient_loss(x_denoised, x_target): gradient-domain loss
- compute_multiscale_loss(x_denoised, x_target): multi-resolution loss
Numerical note: skip the t <= 0.1 samples when computing auxiliary
losses on x_denoised. At very small t, x_t = (1-t)*x + t*noise is
essentially clean, the velocity target degenerates, and any auxiliary
gradient through x_denoised = x_t - t*v_pred becomes ill-conditioned.
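The masking described above can be sketched as follows. This is illustrative only; `aux_loss_fn` is a hypothetical stand-in for any of the exposed per-sample losses (lpips_fn, compute_gradient_loss, compute_multiscale_loss):

```python
import torch

def masked_aux_loss(x_denoised, x_clean, t, aux_loss_fn, t_min=0.1):
    # Zero-fill auxiliary losses for samples with t <= t_min, where the
    # gradient through x_denoised = x_t - t * v_pred is ill-conditioned.
    B = x_denoised.shape[0]
    mask = t.view(B) > t_min
    loss = torch.zeros(B, device=x_denoised.device)
    if mask.any():
        loss[mask] = aux_loss_fn(x_denoised[mask], x_clean[mask])
    return loss

# Per-sample L1 as a stand-in for the real auxiliary losses
l1 = lambda a, b: (a - b).abs().mean(dim=(1, 2, 3))
t = torch.tensor([0.05, 0.5, 0.9, 0.02]).view(4, 1, 1, 1)
out = masked_aux_loss(torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8), t, l1)
print(out)  # entries 0 and 3 stay exactly 0.0
```

Zero-filling (rather than dropping masked samples) keeps the per-sample loss tensor aligned with the batch, so it can be combined with the velocity loss by a simple elementwise sum and mean.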
Evaluation
- Dataset: CIFAR-10 (32x32)
- Model: SmallDiT (512 hidden, 8 layers, ~40M params)
- Training: 10000 steps, batch size 128
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set
- Inference: 10-step Euler sampler
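The 10-step Euler sampler integrates the predicted velocity field from t=1 (pure noise) down to t=0 (data). A hedged sketch under assumed conventions; the actual sampler in custom_train_perceptual.py may differ in signature and step schedule:

```python
import torch

@torch.no_grad()
def euler_sample(model, shape, steps=10):
    # model(x, t) is a stand-in for the velocity predictor.
    x = torch.randn(shape)                    # start from pure noise at t=1
    ts = torch.linspace(1.0, 0.0, steps + 1)
    for i in range(steps):
        dt = ts[i + 1] - ts[i]                # negative step toward t=0
        x = x + dt * model(x, ts[i])          # Euler update
    return x

# Sanity check: for a constant velocity field v = 1, Euler integration is
# exact, and integrating t from 1 to 0 returns x_init - 1.
torch.manual_seed(0)
x_init = torch.randn(2, 3, 32, 32)
torch.manual_seed(0)
out = euler_sample(lambda x, t: torch.ones_like(x), (2, 3, 32, 32))
print(torch.allclose(out, x_init - 1.0, atol=1e-5))  # True
```

With only 10 steps, discretization error is significant, which is one reason image-space losses on x_denoised can help: they push the model toward predictions that denoise well in few steps.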
Baselines
All baselines use a clean linear loss combination (no inverse-loss adaptive
reweighting), and the auxiliary perceptual losses on x_denoised are
masked out for t <= 0.1 and decayed quadratically by (1-t)**2 so they
peak at low noise where the denoised image is meaningful.
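The baseline weighting scheme reduces to a simple per-sample weight. A sketch of the assumed form (the script's exact implementation may differ):

```python
import torch

def perceptual_weight(t, t_min=0.1):
    # Zero for t <= t_min, then (1 - t)^2 so the auxiliary losses peak
    # at low noise, where x_denoised is a meaningful image.
    t_flat = t.flatten()
    return ((1.0 - t_flat) ** 2) * (t_flat > t_min).float()

t = torch.tensor([0.05, 0.2, 0.5, 0.9])
w = perceptual_weight(t)
print(w)  # approximately [0.0, 0.64, 0.25, 0.01]
```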
Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID (lower is better):
| Name | Recipe | fid (small) |
|---|---|---|
| mse_base | Pure MSE on predicted mean velocity. Floor. | 21.91 |
| lpips_grad | MSE + LPIPS (Zhang 2018) + Sobel-gradient L1 (compute_gradient_loss). Spatial-domain perceptual recipe — Pix2PixHD / EnhanceNet style edge-sharpness matching. | 17.79 |
| lpips_spectral | MSE + LPIPS + FFT-magnitude L1. Frequency-domain perceptual recipe (Mathieu 2016 / Fuoli 2021 "Fourier Space Losses"). | 17.38 |
The strongest baseline is lpips_spectral at FID 17.38 on train_small. Your method must achieve a lower FID than lpips_spectral on both train_small and train_medium to count as a successful improvement. A method that wins on one scale but regresses on the other is a partial failure.
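An FFT-magnitude L1 term of the kind used by the lpips_spectral baseline can be sketched as below. This is an assumed form for illustration; the script's actual spectral loss may normalize or weight frequencies differently:

```python
import torch

def fft_magnitude_l1(xd, xc):
    # Compare amplitude spectra of denoised and clean images; phase is
    # ignored, so this term targets the frequency content of the output.
    mag_d = torch.fft.rfft2(xd, norm="ortho").abs()
    mag_c = torch.fft.rfft2(xc, norm="ortho").abs()
    return (mag_d - mag_c).abs().mean(dim=(1, 2, 3))  # per-sample scalar

x = torch.randn(2, 3, 32, 32)
print(fft_magnitude_l1(x, x))  # zeros for identical inputs
```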
Tool Input
1{
2 "content": " # Predict mean velocity\n pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)\n\n # 1. Base losses on Velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n loss_vel_charb = torch.sqrt(err ** 2 + 1e-6).flatten(1).mean(1)\n \n # 2. Auxiliary perceptual losses on denoised image\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n \n # Weighting: peak at low noise where denoised image is clear\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n \n loss_lpips = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_l1_img = torch.zeros(B, device=device)\n \n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n \n # LPIPS perceptual loss\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n \n # Gradient / Edge loss\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n \n # Multi-resolution loss\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n \n # Image L1 loss (Charbonnier)\n loss_l1_img[mask] = torch.sqrt((xd - xc) ** 2 + 1e-6).mean(dim=(1, 2, 3)).float()\n \n # Spectral loss (Magnitude + Complex L1)\n fft_xd = torch.fft.rfft2(xd, norm=\"ortho\")\n fft_xc = torch.fft.rfft2(xc, norm=\"ortho\")\n \n mag_diff = (fft_xd.abs() - fft_xc.abs()).abs().mean(dim=(1, 2, 3))\n complex_diff = (fft_xd - fft_xc).abs().mean(dim=(1, 2, 3))\n \n loss_spec[mask] = (mag_diff + complex_diff).float()\n\n loss_aux = loss_lpips + loss_spec + loss_grad + loss_multi + loss_l1_img\n loss = (loss_mse_unscaled + loss_vel_charb + perceptual_w * loss_aux).mean()",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "after_line": 407,
5 "end_line": 428,
6 "start_line": 408,
7 "op": "replace"
8}
ERROR: Lines 408..428 exceed the editable range. You may only edit lines 411–428.
[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–428 | total: 485 lines]
408| # Predict mean velocity
409| pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
410|
411  # TODO: Implement your loss function here.
412  #
413  # You have access to:
414  #   pred_mean_vel : [B, C, H, W] — model's predicted mean velocity
415  #   mean_vel_target: [B, C, H, W] — ground-truth mean velocity target
416  #   x : [B, C, H, W] — clean image (normalized to [-1, 1])
417  #   x_t : [B, C, H, W] — noisy image at timestep t
418  #   t : [B, 1, 1, 1] — current timestep
419  #   t_next : [B, 1, 1, 1] — target timestep
420  #   dt : [B, 1, 1, 1] — step size
42
Tool Input
1{
2 "after_line": 410,
3 "op": "replace",
4 "content": " # 1. Base losses on Velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n loss_vel_charb = torch.sqrt(err ** 2 + 1e-6).flatten(1).mean(1)\n \n # 2. Auxiliary perceptual losses on denoised image\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n \n # Weighting: peak at low noise where denoised image is clear\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n \n loss_lpips = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_l1_img = torch.zeros(B, device=device)\n \n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n \n # LPIPS perceptual loss\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n \n # Gradient / Edge loss\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n \n # Multi-resolution loss\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n \n # Image L1 loss (Charbonnier)\n loss_l1_img[mask] = torch.sqrt((xd - xc) ** 2 + 1e-6).mean(dim=(1, 2, 3)).float()\n \n # Spectral loss (Magnitude + Complex L1)\n fft_xd = torch.fft.rfft2(xd, norm=\"ortho\")\n fft_xc = torch.fft.rfft2(xc, norm=\"ortho\")\n \n mag_diff = (fft_xd.abs() - fft_xc.abs()).abs().mean(dim=(1, 2, 3))\n complex_diff = (fft_xd - fft_xc).abs().mean(dim=(1, 2, 3))\n \n loss_spec[mask] = (mag_diff + complex_diff).float()\n\n loss_aux = loss_lpips + loss_spec + loss_grad + loss_multi + loss_l1_img\n loss = (loss_mse_unscaled + loss_vel_charb + perceptual_w * loss_aux).mean()",
5 "filename": "alphaflow-main/custom_train_perceptual.py",
6 "end_line": 428,
7 "start_line": 411
8}
OK: Replaced lines 411..428 with 46 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–456.
[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–456 | total: 513 lines]
408| # Predict mean velocity
409| pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
410|
411  # 1. Base losses on Velocity
412  err = pred_mean_vel - mean_vel_target
413  loss_mse_unscaled = (err ** 2).flatten(1).mean(1)
414  loss_vel_charb = torch.sqrt(err ** 2 + 1e-6).flatten(1).mean(1)
415
416  # 2. Auxiliary perceptual losses on denoised image
417  x_denoised = x_t - t * pred_mean_vel
418  t_flat = t.view(B)
419  mask = (t_flat > 0.1)
420
421  # Weighting: peak at low noise where denoised image is clear
422  perceptual_w = ((1.0 - t_flat) ** 2) * mask.floa
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 20.94, Best FID: 20.98
### train_medium (scripts/train_medium.sh)
Training results:
FID: 23.77, Best FID: 20.72
[Leaderboard] Results saved: {'fid': 23.77, 'best_fid': 20.72, 'fid_small': 20.94, 'best_fid_small': 20.98, 'elapsed_train_small': 6117, 'fid_medium': 23.77, 'best_fid_medium': 20.72, 'elapsed_train_medium': 7896}
Tool Input
1{
2 "op": "replace",
3 "end_line": 456,
4 "start_line": 411,
5 "after_line": 410,
6 "content": " # 1. Base loss on Velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n \n # 2. Auxiliary perceptual losses on denoised image\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n \n # Weighting: peak at low noise where denoised image is clear\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n \n loss_lpips = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n \n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n \n # LPIPS perceptual loss\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n \n # Gradient / Edge loss\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n \n # Spectral loss\n fft_xd = torch.fft.rfft2(xd)\n fft_xc = torch.fft.rfft2(xc)\n \n mag_diff = (fft_xd.abs() - fft_xc.abs()).abs().mean(dim=(1, 2, 3))\n complex_diff = (fft_xd - fft_xc).abs().mean(dim=(1, 2, 3))\n \n loss_spec[mask] = (mag_diff + complex_diff).float()\n\n loss_aux = loss_lpips + loss_spec + loss_grad\n loss = (loss_mse_unscaled + perceptual_w * loss_aux).mean()",
7 "filename": "alphaflow-main/custom_train_perceptual.py"
8}
OK: Replaced lines 411..456 with 37 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–447.
[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–447 | total: 504 lines]
408| # Predict mean velocity
409| pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
410|
411  # 1. Base loss on Velocity
412  err = pred_mean_vel - mean_vel_target
413  loss_mse_unscaled = (err ** 2).flatten(1).mean(1)
414
415  # 2. Auxiliary perceptual losses on denoised image
416  x_denoised = x_t - t * pred_mean_vel
417  t_flat = t.view(B)
418  mask = (t_flat > 0.1)
419
420  # Weighting: peak at low noise where denoised image is clear
421  perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()
422
423  loss_lpips = torch.zeros(B, device=device
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 25.47, Best FID: 25.42
### train_medium (scripts/train_medium.sh)
Training results:
FID: 31.40, Best FID: 25.09
[Leaderboard] Results saved: {'fid': 31.4, 'best_fid': 25.09, 'fid_small': 25.47, 'best_fid_small': 25.42, 'elapsed_train_small': 6070, 'fid_medium': 31.4, 'best_fid_medium': 25.09, 'elapsed_train_medium': 7865}
Tool Input
1{
2 "start_line": 411,
3 "content": " # 1. Base loss on Velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n \n # 2. Auxiliary perceptual losses on denoised image\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n \n # Weighting: peak at low noise where denoised image is clear\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n \n loss_lpips = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n \n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n \n # LPIPS perceptual loss\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n \n # Spectral loss: Comprehensive Frequency-Domain Recipe\n fft_xd = torch.fft.rfft2(xd)\n fft_xc = torch.fft.rfft2(xc)\n amp_xd = fft_xd.abs()\n amp_xc = fft_xc.abs()\n \n diff = (amp_xd - amp_xc).abs()\n \n # a) Standard magnitude L1 (baseline)\n loss_mag = diff.mean(dim=(1, 2, 3))\n \n # b) Focal penalty: emphasizes frequencies with highest absolute errors (low/mid freq)\n max_diff = diff.amax(dim=(1, 2, 3), keepdim=True)\n focal_weight = diff / (max_diff + 1e-6)\n loss_focal = (focal_weight * diff).mean(dim=(1, 2, 3))\n \n # c) Log magnitude: emphasizes frequencies with highest relative errors (high freq)\n log_diff = (torch.log(amp_xd + 1e-4) - torch.log(amp_xc + 1e-4)).abs()\n loss_log = log_diff.mean(dim=(1, 2, 3))\n \n loss_spec[mask] = (loss_mag + loss_focal + 0.1 * loss_log).float()\n\n loss_aux = loss_lpips + loss_spec\n loss = (loss_mse_unscaled + perceptual_w * loss_aux).mean()",
4 "filename": "alphaflow-main/custom_train_perceptual.py",
5 "end_line": 447,
6 "op": "replace",
7 "after_line": 410
8}
OK: Replaced lines 411..447 with 46 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–456.
[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–456 | total: 513 lines]
408| # Predict mean velocity
409| pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
410|
411  # 1. Base loss on Velocity
412  err = pred_mean_vel - mean_vel_target
413  loss_mse_unscaled = (err ** 2).flatten(1).mean(1)
414
415  # 2. Auxiliary perceptual losses on denoised image
416  x_denoised = x_t - t * pred_mean_vel
417  t_flat = t.view(B)
418  mask = (t_flat > 0.1)
419
420  # Weighting: peak at low noise where denoised image is clear
421  perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()
422
423  loss_lpips = torch.zeros(B, device=device
[Test #3] (0 tests remaining)
### train_small (scripts/train_small.sh)
Training results:
FID: 25.33, Best FID: 25.41
### train_medium (scripts/train_medium.sh)
Training results:
FID: 26.82, Best FID: 26.77
[Leaderboard] Results saved: {'fid': 26.82, 'best_fid': 26.77, 'fid_small': 25.33, 'best_fid_small': 25.41, 'elapsed_train_small': 5968, 'fid_medium': 26.82, 'best_fid_medium': 26.77, 'elapsed_train_medium': 7736}
[auto-submit] Test budget exhausted — auto-submitted BEST test #1 as final (metrics: {'fid': 23.77, 'best_fid': 20.72, 'fid_small': 20.94, 'best_fid_small': 20.98, 'elapsed_train_small': 6117, 'fid_medium': 23.77, 'best_fid_medium': 20.72, 'elapsed_train_medium': 7896}).