Agent Conversation: anthropic/claude-opus-4.6
Task: cv-meanflow-perceptual-loss
Flow Matching with Perceptual Loss
Background
Flow matching trains a neural network to predict velocity fields that transport samples from noise to data. Traditional training uses only MSE loss on the predicted velocity:
loss = ||v_pred - v_target||^2
However, we can also compute the denoised image from the predicted velocity:
x_denoised = x_t - t * v_pred
We can then apply perceptual losses (LPIPS, gradient loss, etc.) to x_denoised to encourage the network to produce high-quality images, not just accurate velocities.
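As a concrete sketch (assuming the linear interpolation path x_t = (1-t)*x + t*noise and velocity target v = noise - x stated later in this task; the helper name here is ours, not the script's):

```python
import torch

def denoised_from_velocity(x_t: torch.Tensor, v_pred: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Recover the predicted clean image from a velocity prediction.

    Under x_t = (1 - t) * x + t * noise with velocity v = noise - x,
    the identity x = x_t - t * v recovers the clean image exactly when
    v_pred is exact, and an estimate of it otherwise.
    """
    t = t.view(-1, 1, 1, 1)  # broadcast per-sample t over (B, C, H, W)
    return x_t - t * v_pred
```

Perceptual losses such as LPIPS can then be evaluated on this estimate in addition to the raw velocity error.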
Research Question
Can adding perceptual losses to flow matching training improve FID scores?
Task
You are given custom_train_perceptual.py, a self-contained training script that trains a
small DiT on CIFAR-10 (32x32) using flow matching with mean velocity objectives.
The editable region contains the loss computation in the training loop:
# Current: MSE loss only
loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()
loss = loss_mse
The fixed code already exposes:
- lpips_fn(x_denoised, x_target) - perceptual loss
- compute_gradient_loss(x_denoised, x_target) - gradient-domain loss
- compute_multiscale_loss(x_denoised, x_target) - multi-resolution loss
Numerical note: skip the t <= 0.1 samples when computing auxiliary
losses on x_denoised. At very small t, x_t = (1-t)*x + t*noise is
essentially clean, the velocity target degenerates, and any auxiliary
gradient through x_denoised = x_t - t*v_pred becomes ill-conditioned.
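A minimal sketch of that mask (the 0.1 threshold comes from the note above; the helper name is ours):

```python
import torch

def aux_loss_mask(t: torch.Tensor, t_min: float = 0.1) -> torch.Tensor:
    """Per-sample 0/1 mask that drops small-t samples from auxiliary losses.

    For t <= t_min, x_t is essentially the clean image and the gradient
    through x_denoised = x_t - t * v_pred is ill-conditioned, so those
    samples should contribute only to the primary MSE term.
    """
    return (t.view(-1) > t_min).float()
```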
Evaluation
- Dataset: CIFAR-10 (32x32)
- Model: SmallDiT (512 hidden, 8 layers, ~40M params)
- Training: 10000 steps, batch size 128
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set
- Inference: 10-step Euler sampler
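For reference, a 10-step Euler sampler under this convention (pure noise at t=1, data at t=0) can be sketched as follows; `net(x, t)` is a stand-in for the trained model's velocity prediction, whose real call signature in the script is `net(x_t, sigma=t, sigma_next=t_next)`:

```python
import torch

@torch.no_grad()
def euler_sample(net, shape, steps=10, device="cpu"):
    """Integrate the learned velocity field from noise (t=1) to data (t=0).

    `net(x, t)` is assumed to return the velocity at state x and time t.
    """
    x = torch.randn(shape, device=device)          # x_1 = pure noise
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t, t_next = ts[i], ts[i + 1]
        v = net(x, t.expand(shape[0]))             # per-sample time conditioning
        x = x + (t_next - t) * v                   # Euler step toward t = 0
    return x
```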
Baselines
All baselines use a clean linear loss combination (no inverse-loss adaptive
reweighting), and the auxiliary perceptual losses on x_denoised are
masked out for t <= 0.1 and decayed quadratically by (1-t)**2 so they
peak at low noise where the denoised image is meaningful.
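Sketched as code, the baselines' combination rule looks like this (function, argument names, and the particular weights are placeholders, not the scripts' values):

```python
import torch

def combine_losses(loss_mse, aux_losses, weights, t, t_min=0.1):
    """Linear loss combination with the baselines' t-dependent decay.

    Per-sample auxiliary losses are masked out for t <= t_min and decayed
    by (1 - t)**2 so they peak at low noise, then added to the MSE term
    with fixed scalar weights (no adaptive reweighting).
    """
    t = t.view(-1)
    w_t = ((1.0 - t) ** 2) * (t > t_min).float()
    total = loss_mse.mean()
    for w, aux in zip(weights, aux_losses):
        total = total + w * (w_t * aux).mean()
    return total
```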
Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID (lower is better):
| Name | Recipe | fid (small) |
|---|---|---|
| mse_base | Pure MSE on predicted mean velocity. Floor. | 21.91 |
| lpips_grad | MSE + LPIPS (Zhang 2018) + Sobel-gradient L1 (compute_gradient_loss). Spatial-domain perceptual recipe, in the Pix2PixHD / EnhanceNet style of edge-sharpness matching. | 17.79 |
| lpips_spectral | MSE + LPIPS + FFT-magnitude L1. Frequency-domain perceptual recipe (Mathieu 2016; Fuoli 2021, "Fourier Space Losses"). | 17.38 |
The strongest baseline is lpips_spectral at FID 17.38 on train_small. Your method must achieve a lower FID than lpips_spectral on both train_small and train_medium to count as a successful improvement. A method that wins on one scale but regresses on the other is a partial failure.
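The FFT-magnitude term used by lpips_spectral can be sketched as follows (our naming; the script's internal implementation may differ):

```python
import torch

def fft_magnitude_l1(x_denoised: torch.Tensor, x_target: torch.Tensor) -> torch.Tensor:
    """Per-sample L1 distance between 2D FFT magnitude spectra.

    Matches the magnitude of the real 2D FFT while ignoring phase,
    i.e. the frequency-domain term of the lpips_spectral recipe.
    """
    mag_d = torch.fft.rfft2(x_denoised, norm="ortho").abs()
    mag_t = torch.fft.rfft2(x_target, norm="ortho").abs()
    return (mag_d - mag_t).abs().flatten(1).mean(1)
```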
Tool Input
{
  "op": "replace",
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
  "end_line": 428,
  "content": " # === Multi-domain perceptual loss with phase-aware spectral matching ===\n #\n # Novel contribution: Most spectral losses (Mathieu 2016, Fuoli 2021)\n # only match FFT magnitudes. We add a *magnitude-weighted phase\n # consistency* loss: phase errors are penalized proportionally to\n # the spectral energy at that frequency, so structurally important\n # phase relationships are preserved while noisy low-energy phases\n # are ignored. Combined with gradient-domain and LPIPS losses, this\n # gives comprehensive spatial + frequency + deep-perceptual coverage.\n\n # 1) MSE on predicted mean velocity (primary objective)\n err = pred_mean_vel - mean_vel_target\n loss_mse = (err ** 2).flatten(1).mean(1)\n\n # 2) Compute denoised image for auxiliary losses\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n loss_lpips_vec = torch.zeros(B, device=device)\n loss_grad_vec = torch.zeros(B, device=device)\n loss_spec_vec = torch.zeros(B, device=device)\n loss_phase_vec = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n\n # 3a) LPIPS deep perceptual loss\n loss_lpips_vec[mask] = lpips_fn(xd, xc).view(-1).float()\n\n # 3b) Gradient-domain (Sobel edge) loss\n loss_grad_vec[mask] = compute_gradient_loss(xd, xc).float()\n\n # 3c) FFT magnitude L1 loss (standard spectral)\n fft_d = torch.fft.rfft2(xd, norm='ortho')\n fft_c = torch.fft.rfft2(xc, norm='ortho')\n mag_diff = (fft_d.abs() - fft_c.abs()).abs()\n loss_spec_vec[mask] = mag_diff.flatten(1).mean(1)\n\n # 3d) NOVEL: Magnitude-weighted phase consistency loss\n # Phase differences are weighted by target magnitude so that\n # phase accuracy is enforced where spectral energy is high\n # (structurally important) and relaxed where energy is low.\n mag_weight = fft_c.abs()\n phase_d = torch.angle(fft_d)\n phase_c = torch.angle(fft_c)\n phase_diff = phase_d - phase_c\n # Wrap to [-pi, pi] for correct angular distance\n phase_diff = torch.atan2(torch.sin(phase_diff), torch.cos(phase_diff))\n weighted_phase = mag_weight * phase_diff.abs()\n loss_phase_vec[mask] = weighted_phase.flatten(1).mean(1)\n\n # 4) Combine all losses\n loss = loss_mse.mean() \\\n + 0.08 * (perceptual_w * loss_lpips_vec).mean() \\\n + 0.10 * (perceptual_w * loss_grad_vec).mean() \\\n + 0.05 * (perceptual_w * loss_spec_vec).mean() \\\n + 0.02 * (perceptual_w * loss_phase_vec).mean()\n"
}
OK: Replaced lines 411..428 with 60 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–470. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–470 | total: 527 lines]
Tool Input
{
  "op": "replace",
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
  "end_line": 428,
  "content": " err = pred_mean_vel - mean_vel_target\n loss_mse = (err ** 2).flatten(1).mean(1)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n pw = ((1.0 - t_flat) ** 2) * mask.float()\n l_lp, l_sp, l_ph, l_gr, l_ms = (torch.zeros(B, device=device) for _ in range(5))\n if mask.any():\n xd, xc = x_denoised[mask].clamp(-1, 1).float(), x[mask].clamp(-1, 1).float()\n l_lp[mask] = lpips_fn(xd, xc).view(-1).float()\n l_gr[mask] = compute_gradient_loss(xd, xc).float()\n l_ms[mask] = compute_multiscale_loss(xd, xc).float()\n fd, fc = torch.fft.rfft2(xd), torch.fft.rfft2(xc)\n l_sp[mask] = (fd.abs() - fc.abs()).abs().flatten(1).mean(1)\n mag_w = (fd.abs() + fc.abs()) * 0.5 + 1e-8\n l_ph[mask] = (mag_w * (1.0 - torch.cos(fd.angle() - fc.angle()))).flatten(1).mean(1)\n loss = loss_mse.mean() + (pw * (0.05 * l_lp + 0.01 * l_sp + 0.003 * l_ph + 0.01 * l_gr + 0.01 * l_ms)).mean()\n"
}
OK: Replaced lines 411..428 with 17 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–427. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–427 | total: 484 lines]
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 20.32, Best FID: 20.61
### train_medium (scripts/train_medium.sh)
Training results:
FID: 16.50, Best FID: 16.83
[Leaderboard] Results saved: {'fid': 16.5, 'best_fid': 16.83, 'fid_small': 20.32, 'best_fid_small': 20.61, 'elapsed_train_small': 6102, 'fid_medium': 16.5, 'best_fid_medium': 16.83, 'elapsed_train_medium': 7892}
Tool Input
{
  "content": " err = pred_mean_vel - mean_vel_target\n loss_mse = (err ** 2).flatten(1).mean(1)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n pw = ((1.0 - t_flat) ** 2) * mask.float()\n l_lp, l_sp, l_ph, l_gr = (torch.zeros(B, device=device) for _ in range(4))\n if mask.any():\n xd, xc = x_denoised[mask].clamp(-1, 1).float(), x[mask].clamp(-1, 1).float()\n l_lp[mask] = lpips_fn(xd, xc).view(-1).float()\n l_gr[mask] = compute_gradient_loss(xd, xc).float()\n fd, fc = torch.fft.rfft2(xd), torch.fft.rfft2(xc)\n l_sp[mask] = (fd.abs() - fc.abs()).abs().flatten(1).mean(1)\n phase_diff = fd.angle() - fc.angle()\n mag_w = fc.abs() / (fc.abs().flatten(1).mean(1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + 1e-8)\n l_ph[mask] = (mag_w * (1.0 - torch.cos(phase_diff))).flatten(1).mean(1)\n loss = loss_mse.mean() + (pw * (0.08 * l_lp + 0.015 * l_sp + 0.002 * l_ph + 0.005 * l_gr)).mean()\n",
  "end_line": 427,
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "op": "replace",
  "start_line": 411
}
OK: Replaced lines 411..427 with 17 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–427. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–427 | total: 484 lines]
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 19.88, Best FID: 20.19
### train_medium (scripts/train_medium.sh)
Training results:
FID: 15.44, Best FID: 15.85
[Leaderboard] Results saved: {'fid': 15.44, 'best_fid': 15.85, 'fid_small': 19.88, 'best_fid_small': 20.19, 'elapsed_train_small': 6057, 'fid_medium': 15.44, 'best_fid_medium': 15.85, 'elapsed_train_medium': 7852}
Tool Input
{
  "content": " err = pred_mean_vel - mean_vel_target\n loss_mse = (err ** 2).flatten(1).mean(1)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n pw = ((1.0 - t_flat) ** 2) * mask.float()\n l_lp, l_sp, l_ph = (torch.zeros(B, device=device) for _ in range(3))\n if mask.any():\n xd, xc = x_denoised[mask].clamp(-1, 1).float(), x[mask].clamp(-1, 1).float()\n l_lp[mask] = lpips_fn(xd, xc).view(-1).float()\n fd, fc = torch.fft.rfft2(xd), torch.fft.rfft2(xc)\n l_sp[mask] = (fd.abs() - fc.abs()).abs().flatten(1).mean(1)\n mag_w = fc.abs() / (fc.abs().flatten(1).mean(1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + 1e-8)\n l_ph[mask] = (mag_w * (1.0 - torch.cos(fd.angle() - fc.angle()))).flatten(1).mean(1)\n ramp = min(step / 1000.0, 1.0)\n aux = (pw * (0.1 * l_lp + 0.02 * l_sp + 0.003 * l_ph)).mean()\n loss = loss_mse.mean() + ramp * aux\n",
  "end_line": 427,
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "op": "replace",
  "start_line": 411
}
OK: Replaced lines 411..427 with 17 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–427. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–427 | total: 484 lines]
[Test #3] (0 tests remaining)
### train_small (scripts/train_small.sh)
Training results:
FID: 19.58, Best FID: 19.88
### train_medium (scripts/train_medium.sh)
Training results:
FID: 14.83, Best FID: 15.05
[Leaderboard] Results saved: {'fid': 14.83, 'best_fid': 15.05, 'fid_small': 19.58, 'best_fid_small': 19.88, 'elapsed_train_small': 5975, 'fid_medium': 14.83, 'best_fid_medium': 15.05, 'elapsed_train_medium': 7746}
[auto-submit] Test budget exhausted — auto-submitted BEST test #3 as final (metrics: {'fid': 14.83, 'best_fid': 15.05, 'fid_small': 19.58, 'best_fid_small': 19.88, 'elapsed_train_small': 5975, 'fid_medium': 14.83, 'best_fid_medium': 15.05, 'elapsed_train_medium': 7746}).