cv-meanflow-perceptual-loss
Description
Flow Matching with Perceptual Loss
Background
Flow matching trains a neural network to predict velocity fields that transport samples from noise to data. Standard training uses only an MSE loss on the predicted velocity:
loss = ||v_pred - v_target||^2
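A minimal sketch of this plain flow-matching objective follows. It assumes the common linear (rectified-flow) path x_t = (1 - t) * x + t * noise, with t = 0 at the data and t = 1 at pure noise, so the regression target is v_target = noise - x. That convention is an assumption chosen to agree with x_denoised = x_t - t * v_pred. `make_flow_pair` and `velocity_mse` are hypothetical names. The script's MeanFlow objective regresses a mean velocity whose target is computed differently, and the script already implements that target.

```python
import torch


def make_flow_pair(x: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
    """Return (x_t, v_target) on the assumed linear path x_t = (1 - t) * x + t * noise.

    x, noise: (B, C, H, W) tensors; t: (B,) tensor with values in [0, 1].
    """
    t_ = t.view(-1, 1, 1, 1)            # broadcast t over C, H, W
    x_t = (1.0 - t_) * x + t_ * noise   # t = 0 is data, t = 1 is pure noise
    v_target = noise - x                # constant velocity along the straight path
    return x_t, v_target


def velocity_mse(v_pred: torch.Tensor, v_target: torch.Tensor) -> torch.Tensor:
    """loss = ||v_pred - v_target||^2, averaged over batch and pixels."""
    return ((v_pred - v_target) ** 2).mean()


torch.manual_seed(0)
x = torch.randn(4, 3, 32, 32)
noise = torch.randn_like(x)
t = torch.rand(4)
x_t, v_target = make_flow_pair(x, noise, t)
print(velocity_mse(v_target, v_target).item())  # 0.0 for a perfect prediction
```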
The predicted velocity also yields a one-step estimate of the denoised image:
x_denoised = x_t - t * v_pred
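Assuming the linear path x_t = (1 - t) * x + t * noise, whose exact velocity is noise - x, this formula inverts the interpolation. With the exact velocity it returns x, and with a learned velocity it gives a one-step estimate of the clean image. The check below is a sketch under that assumption, and `denoise_estimate` is a hypothetical name.

```python
import torch


def denoise_estimate(x_t: torch.Tensor, t: torch.Tensor, v_pred: torch.Tensor) -> torch.Tensor:
    """One-step clean-image estimate x_denoised = x_t - t * v_pred.

    t has shape (B,) and is broadcast over the (C, H, W) dimensions.
    """
    return x_t - t.view(-1, 1, 1, 1) * v_pred


torch.manual_seed(0)
x = torch.rand(8, 3, 32, 32) * 2 - 1   # stand-in "images" in [-1, 1]
noise = torch.randn_like(x)
t = torch.rand(8)
t_ = t.view(-1, 1, 1, 1)
x_t = (1 - t_) * x + t_ * noise        # assumed linear interpolation
v_true = noise - x                     # its exact velocity
x_hat = denoise_estimate(x_t, t, v_true)
print(torch.allclose(x_hat, x, atol=1e-5))  # True: the exact velocity recovers x
```

An error of size delta in v_pred shifts x_denoised by t * delta. At small t the estimate therefore stays close to x_t regardless of what the network predicts.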
Perceptual losses (LPIPS, gradient loss, etc.) can then be applied to x_denoised. This encourages the network to produce high-quality images rather than only accurate velocities.
Research Question
Can adding perceptual losses to flow matching training improve FID scores?
Task
You are given custom_train_perceptual.py, a self-contained training script that trains a small DiT on CIFAR-10 (32x32) using flow matching with a mean-velocity (MeanFlow) objective.
The editable region contains the loss computation in the training loop:
# Current: MSE loss only
loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()
loss = loss_mse
The fixed code already exposes:
- lpips_fn(x_denoised, x_target): perceptual loss
- compute_gradient_loss(x_denoised, x_target): gradient-domain loss
- compute_multiscale_loss(x_denoised, x_target): multi-resolution loss
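The actual implementations of these helpers are not shown in this excerpt. To illustrate roughly what a gradient-domain loss and a multi-resolution loss compute, here are two hypothetical stand-ins. Their names, the L1/MSE choices, and the number of scales are assumptions, not the script's definitions.

```python
import torch
import torch.nn.functional as F


def gradient_loss_sketch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical gradient-domain loss: L1 distance between finite-difference gradients."""

    def finite_diffs(img: torch.Tensor):
        dx = img[..., :, 1:] - img[..., :, :-1]  # horizontal neighbor differences
        dy = img[..., 1:, :] - img[..., :-1, :]  # vertical neighbor differences
        return dx, dy

    pdx, pdy = finite_diffs(pred)
    tdx, tdy = finite_diffs(target)
    return (pdx - tdx).abs().mean() + (pdy - tdy).abs().mean()


def multiscale_loss_sketch(pred: torch.Tensor, target: torch.Tensor, num_scales: int = 3) -> torch.Tensor:
    """Hypothetical multi-resolution loss: MSE averaged over an average-pooled pyramid."""
    total = pred.new_zeros(())
    for scale in range(num_scales):
        total = total + F.mse_loss(pred, target)
        if scale < num_scales - 1:
            # Halve the resolution, e.g. 32x32 -> 16x16 -> 8x8 for CIFAR-10.
            pred = F.avg_pool2d(pred, kernel_size=2)
            target = F.avg_pool2d(target, kernel_size=2)
    return total / num_scales


torch.manual_seed(0)
a = torch.rand(2, 3, 32, 32)
b = torch.rand(2, 3, 32, 32)
print(gradient_loss_sketch(a, a + 0.5).item() < 1e-6)  # True: a constant shift leaves gradients unchanged
print(multiscale_loss_sketch(a, b).item() > 0.0)        # True: different images give a positive loss
```

The gradient term is blind to a global intensity offset. It only makes sense alongside a loss that does see absolute intensities, such as the velocity MSE.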
Key constraint: Only apply auxiliary losses when t > 0.1 to avoid instability at small noise levels.
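One way to honor this constraint is to compute the auxiliary loss only on the batch elements whose t exceeds the threshold. The sketch below makes several assumptions:

- `t` is a per-sample time tensor of shape (B,).
- The denoising formula x_t - t * v is applied to pred_mean_vel.
- The auxiliary loss is passed in as a callable (for example lpips_fn or compute_gradient_loss), so the sketch runs without LPIPS weights.
- `combined_loss`, `aux_weight`, and `t_min` are hypothetical names, and the default weight is an arbitrary placeholder.

If the script samples MeanFlow intervals (r, t) with r > 0, the mean-velocity jump x_t - (t - r) * u lands at x_r rather than at the clean image. In that case the denoised estimate is only exact for r = 0.

```python
from typing import Callable

import torch


def combined_loss(
    pred_mean_vel: torch.Tensor,
    mean_vel_target: torch.Tensor,
    x_t: torch.Tensor,
    x_target: torch.Tensor,
    t: torch.Tensor,
    aux_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    aux_weight: float = 0.1,  # placeholder, not a tuned value
    t_min: float = 0.1,       # the t > 0.1 constraint from the task
) -> torch.Tensor:
    # Base objective, identical to the current editable region.
    loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()

    mask = t > t_min  # (B,) boolean: samples eligible for the auxiliary loss
    if not mask.any():
        return loss_mse  # no eligible samples in this batch

    t_sel = t[mask].view(-1, 1, 1, 1)
    x_denoised = x_t[mask] - t_sel * pred_mean_vel[mask]  # one-step clean estimate
    # .mean() reduces whatever shape aux_fn returns (LPIPS returns one value per sample).
    loss_aux = aux_fn(x_denoised, x_target[mask]).mean()
    return loss_mse + aux_weight * loss_aux


def per_sample_mse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Cheap stand-in for an auxiliary loss, returning one value per sample."""
    return ((a - b) ** 2).flatten(1).mean(dim=1)
```

Averaging the auxiliary term over the selected samples keeps its scale independent of how many samples pass the threshold. Dividing by the full batch size instead would weaken it in batches with few high-t samples, and either choice is defensible.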
Evaluation
- Dataset: CIFAR-10 (32x32)
- Model: SmallDiT (512 hidden, 8 layers, ~40M params)
- Training: 10000 steps, batch size 128
- Metric: FID (lower is better), computed with clean-fid against the CIFAR-10 train set
- Inference: 10-step Euler sampler
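A 10-step Euler sampler integrates from t = 1 (noise) to t = 0 (data). The sketch below assumes three things, none of which are taken from the script's actual sampler:

- That time convention.
- A model callable `u_fn(x, r, t)` returning the (mean) velocity for the jump from t down to r. For an instantaneous-velocity model, r can simply be ignored.
- A uniform time grid.

```python
from typing import Callable

import torch

VelocityFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


@torch.no_grad()
def euler_sample(u_fn: VelocityFn, noise: torch.Tensor, num_steps: int = 10) -> torch.Tensor:
    """Integrate from t = 1 (noise) to t = 0 (data) with uniform Euler steps.

    u_fn(x, r, t) returns the velocity used for the jump from t down to r;
    r and t are per-sample tensors of shape (B,).
    """
    x = noise
    times = torch.linspace(1.0, 0.0, num_steps + 1)
    batch = x.shape[0]
    for i in range(num_steps):
        t, r = times[i].item(), times[i + 1].item()
        t_b = torch.full((batch,), t, device=x.device)
        r_b = torch.full((batch,), r, device=x.device)
        x = x - (t - r) * u_fn(x, r_b, t_b)  # x_r = x_t - (t - r) * u
    return x


# Sanity check under the linear path x_t = (1 - t) * x0 + t * noise: if the "dataset"
# is the single image x0, the exact velocity is (x_t - x0) / t, the paths are straight,
# and Euler lands on x0.
x0 = torch.full((2, 3, 32, 32), 0.5)
samples = euler_sample(lambda x, r, t: (x - x0) / t.view(-1, 1, 1, 1), torch.randn(2, 3, 32, 32))
print(torch.allclose(samples, x0, atol=1e-5))  # True
```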
Baselines
- mse-only: Pure MSE loss on velocity
- mse-lpips: MSE + LPIPS perceptual loss (VGG features)
- mse-lpips-grad: MSE + LPIPS + gradient loss with timestep-adaptive weighting
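The task does not specify the timestep-adaptive weighting used by mse-lpips-grad. As one hypothetical example, a per-sample weight can ramp linearly from 0 at the t = 0.1 cutoff to 1 at t = 1. This enforces the constraint and also fades the auxiliary terms in smoothly. The ramp shape and the names `ramp_weight` and `weighted_aux` are assumptions.

```python
import torch


def ramp_weight(t: torch.Tensor, t_min: float = 0.1) -> torch.Tensor:
    """Hypothetical weight: 0 for t <= t_min, rising linearly to 1 at t = 1."""
    return ((t - t_min) / (1.0 - t_min)).clamp(min=0.0, max=1.0)


def weighted_aux(per_sample_loss: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Weighted mean of a (B,) auxiliary loss; returns 0 when every weight is 0."""
    w = ramp_weight(t)
    return (w * per_sample_loss).sum() / w.sum().clamp(min=1e-8)


t = torch.tensor([0.05, 0.1, 0.55, 1.0])
print(ramp_weight(t))  # weights 0, 0, 0.5, 1 (up to float rounding)
```

Normalizing by the weight sum makes the result a weighted average. Dividing by the batch size instead would also let the weights shrink the overall auxiliary magnitude.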
Code
"""Custom Flow Matching Training Script — Perceptual Loss Variant
Small-scale flow matching training on CIFAR-10 with a lightweight DiT.
The training objective (MeanFlow) is pre-implemented; your task is to
design an improved loss function, optionally using perceptual losses.
"""

import math
import os
import time

import lpips
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Additional context files (read-only):
alphaflow-main/perceptual_utils.py
Results
| Model | Type | Best FID (small) ↓ | Best FID (medium) ↓ | Best FID (large) ↓ |
|---|---|---|---|---|
| lpips_grad | baseline | 17.790 | 17.190 | 14.490 |
| lpips_spectral | baseline | 17.380 | 15.820 | 13.630 |
| mse_base | baseline | 22.330 | 21.910 | N/A |
| anthropic/claude-opus-4.6 | vanilla | 20.610 | 16.830 | - |
| deepseek-reasoner | vanilla | 26.940 | 26.610 | - |
| google/gemini-3.1-pro-preview | vanilla | 20.980 | 20.720 | - |
| qwen/qwen3.6-plus | vanilla | - | 16.520 | - |
| anthropic/claude-opus-4.6 | agent | 19.880 | 15.050 | 14.000 |
| deepseek-reasoner | agent | 20.810 | 20.160 | 19.160 |
| google/gemini-3.1-pro-preview | agent | 20.980 | 20.720 | 18.630 |
| qwen/qwen3.6-plus | agent | N/A | 16.520 | N/A |
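To read the table against the research question, the relative FID reduction of each perceptual baseline over mse_base can be computed directly from the values above. Only the small and medium settings are used, because mse_base has no large result.

```python
# Best-FID values copied from the Results table (lower is better).
fid = {
    "mse_base": {"small": 22.33, "medium": 21.91},
    "lpips_grad": {"small": 17.79, "medium": 17.19},
    "lpips_spectral": {"small": 17.38, "medium": 15.82},
}


def relative_reduction(baseline: float, variant: float) -> float:
    """Fractional FID reduction of `variant` relative to `baseline`."""
    return (baseline - variant) / baseline


for name in ("lpips_grad", "lpips_spectral"):
    for setting in ("small", "medium"):
        r = relative_reduction(fid["mse_base"][setting], fid[name][setting])
        print(f"{name:15s} {setting:6s} {100 * r:5.1f}%")
```

This gives roughly 20.3% and 21.5% for lpips_grad and 22.2% and 27.8% for lpips_spectral. Both perceptual baselines therefore cut FID by about a fifth or more relative to MSE-only training in these settings.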