Agent Conversation: qwen/qwen3.6-plus
Task: cv-meanflow-perceptual-loss
Flow Matching with Perceptual Loss
Background
Flow matching trains a neural network to predict velocity fields that transport samples from noise to data. Traditional training uses only MSE loss on the predicted velocity:
loss = ||v_pred - v_target||^2
However, we can also compute the denoised image from the predicted velocity:
x_denoised = x_t - t * v_pred
And apply perceptual losses (LPIPS, gradient loss, etc.) on x_denoised to encourage the network to generate high-quality images, not just accurate velocities.
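A minimal sketch of one flow-matching training step under these conventions. The `net(x_t, t)` call signature here is illustrative (the actual script's model takes `sigma`/`sigma_next` arguments), and the linear path `x_t = (1-t)*x + t*noise` with velocity target `noise - x` is assumed:

```python
import torch

def flow_matching_step(net, x):
    """One flow-matching training step on a batch x in [-1, 1].
    Assumes the linear path x_t = (1 - t) * x + t * noise, whose
    velocity target is noise - x. The net signature is illustrative."""
    B = x.shape[0]
    t = torch.rand(B, 1, 1, 1)            # time, uniform in [0, 1)
    noise = torch.randn_like(x)
    x_t = (1 - t) * x + t * noise         # interpolate data -> noise
    v_target = noise - x                  # velocity of the linear path
    v_pred = net(x_t, t)
    loss_mse = ((v_pred - v_target) ** 2).mean()
    x_denoised = x_t - t * v_pred         # denoised estimate for aux losses
    return loss_mse, x_denoised
```

Note that with a perfect velocity prediction, `x_t - t * v_pred` algebraically recovers the clean image `x` exactly.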
Research Question
Can adding perceptual losses to flow matching training improve FID scores?
Task
You are given custom_train_perceptual.py, a self-contained training script that trains a
small DiT on CIFAR-10 (32x32) using flow matching with mean velocity objectives.
The editable region contains the loss computation in the training loop:
# Current: MSE loss only
loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()
loss = loss_mse
The fixed code already exposes:
lpips_fn(x_denoised, x_target) - perceptual loss
compute_gradient_loss(x_denoised, x_target) - gradient-domain loss
compute_multiscale_loss(x_denoised, x_target) - multi-resolution loss
Numerical note: skip the t <= 0.1 samples when computing auxiliary
losses on x_denoised. At very small t, x_t = (1-t)*x + t*noise is
essentially clean, the velocity target degenerates, and any auxiliary
gradient through x_denoised = x_t - t*v_pred becomes ill-conditioned.
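The masking rule above can be sketched as a per-sample weight (variable names are illustrative, not the script's actual identifiers; the quadratic decay matches the baseline recipe described below):

```python
import torch

def aux_loss_weight(t, t_min=0.1):
    """Per-sample weight for auxiliary losses on x_denoised:
    zero where t <= t_min (the ill-conditioned regime described
    above), otherwise decayed quadratically by (1 - t)^2 so the
    weight peaks at low noise. t has shape [B, 1, 1, 1]."""
    t_flat = t.view(t.shape[0])
    mask = (t_flat > t_min).float()
    return (1.0 - t_flat) ** 2 * mask
```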
Evaluation
- Dataset: CIFAR-10 (32x32)
- Model: SmallDiT (512 hidden, 8 layers, ~40M params)
- Training: 10000 steps, batch size 128
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set
- Inference: 10-step Euler sampler
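Under the `x_t = (1-t)*x + t*noise` convention, a 10-step Euler sampler integrates the ODE dx/dt = v from t = 1 (pure noise) down to t = 0 (data). A minimal sketch, with an illustrative `net(x, t)` signature (the script's model takes `sigma`/`sigma_next`):

```python
import torch

@torch.no_grad()
def euler_sample(net, shape, steps=10):
    """Fixed-step Euler integration of dx/dt = v_pred from t = 1
    (noise) to t = 0 (data), matching x_t = (1-t)*x + t*noise."""
    x = torch.randn(shape)                     # start at t = 1
    ts = torch.linspace(1.0, 0.0, steps + 1)
    for i in range(steps):
        t = ts[i].expand(shape[0], 1, 1, 1)
        v = net(x, t)
        dt = ts[i] - ts[i + 1]                 # positive step size
        x = x - dt * v                         # step toward t = 0
    return x
```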
Baselines
All baselines use a clean linear loss combination (no inverse-loss adaptive
reweighting), and the auxiliary perceptual losses on x_denoised are
masked out for t <= 0.1 and decayed quadratically by (1-t)**2 so they
peak at low noise where the denoised image is meaningful.
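The clean linear combination described above can be sketched as follows; the coefficient values and helper name are illustrative placeholders, not the baselines' tuned weights:

```python
import torch

def combine_losses(loss_mse, aux_losses, aux_weight, coeffs):
    """Linear loss combination with no adaptive reweighting:
    per-sample velocity MSE plus a fixed-coefficient sum of
    auxiliary perceptual losses, gated by the per-sample
    (1-t)^2-masked weight. All loss tensors have shape [B]."""
    aux = sum(coeffs[name] * loss for name, loss in aux_losses.items())
    return (loss_mse + aux_weight * aux).mean()
```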
Your method is evaluated on train_small and train_medium (train_large is skipped during agent runs). FID (lower is better):
| Name | Recipe | fid (small) |
|---|---|---|
| mse_base | Pure MSE on predicted mean velocity. Floor. | 21.91 |
| lpips_grad | MSE + LPIPS (Zhang 2018) + Sobel-gradient L1 (compute_gradient_loss). Spatial-domain perceptual recipe — Pix2PixHD / EnhanceNet style edge-sharpness matching. | 17.79 |
| lpips_spectral | MSE + LPIPS + FFT-magnitude L1. Frequency-domain perceptual recipe (Mathieu 2016 / Fuoli 2021 "Fourier Space Losses"). | 17.38 |
The strongest baseline is lpips_spectral at fid 17.38 on train_small. Your method must achieve lower fid than lpips_spectral on both train_small and train_medium to count as a successful improvement. A method that wins on one scale but regresses on the other is a partial failure.
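For concreteness, one plausible form of the FFT-magnitude term used by the lpips_spectral baseline; this is a sketch, and the script's actual implementation may normalise or weight differently:

```python
import torch
import torch.nn.functional as F

def fft_magnitude_l1(x_denoised, x_target, eps=1e-4):
    """Per-sample L1 distance between log-magnitude rFFT spectra.
    The log compresses the dynamic range so high-frequency bins
    are not swamped by the DC/low-frequency energy."""
    mag_d = torch.abs(torch.fft.rfft2(x_denoised, norm="ortho")) + eps
    mag_t = torch.abs(torch.fft.rfft2(x_target, norm="ortho")) + eps
    diff = F.l1_loss(torch.log(mag_d), torch.log(mag_t), reduction="none")
    return diff.flatten(1).mean(1)             # one value per sample
```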
Tool Input
{
  "op": "replace",
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
  "end_line": 428,
6 "content": " # ── Proposed: Adaptive Perceptual Loss with Spectral Consistency ──\n #\n # Key ideas:\n # 1) Adaptive perceptual weighting: scale LPIPS/gradient/spectral\n # inversely with MSE magnitude. When the model is uncertain\n # (high MSE), reduce perceptual weight to avoid conflicting\n # gradients. When confident (low MSE), increase weight for\n # fine-grained quality refinement.\n # 2) FFT spectral loss: L1 on log-magnitude of 2D FFT per channel,\n # encouraging frequency-domain fidelity (Mathieu 2016 / Fuoli 2021).\n # 3) Charbonnier regulariser on velocity: smooth-L1 on residuals\n # for robust outlier handling.\n # 4) Multi-scale perceptual: included via compute_multiscale_loss.\n #\n # Numerical note: auxiliary losses on x_denoised are skipped for\n # t <= 0.1 (ill-conditioned) and decayed quadratically by (1-t)^2.\n\n # Base MSE on velocity (per-sample for adaptive weighting)\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Charbonnier regulariser on velocity residuals (robust to outliers)\n loss_charb = torch.sqrt(err ** 2 + 1e-6).flatten(1).mean(1)\n\n # Denoised image for perceptual supervision\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n # Adaptive perceptual weight: inversely proportional to normalised\n # MSE. 
High MSE → low perceptual weight; low MSE → high weight.\n # Sigmoid-like mapping clamped to [0.5, 1.5] for stability.\n mse_mean = loss_mse_unscaled.mean() + 1e-8\n loss_mse_normed = loss_mse_unscaled / mse_mean\n adapt_w = 1.5 / (1.0 + 0.5 * loss_mse_normed) # ≈ 1/(1+0.5*ratio)\n\n # Preallocate per-sample auxiliary losses\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n\n # FFT-magnitude L1 loss (per-channel, log-amplitude)\n eps_fft = 1e-4\n mag_d = torch.abs(torch.fft.rfft2(xd, norm='ortho')) + eps_fft\n mag_t = torch.abs(torch.fft.rfft2(xc, norm='ortho')) + eps_fft\n loss_spec[mask] = F.l1_loss(\n torch.log(mag_d), torch.log(mag_t),\n reduction='none'\n ).flatten(1).mean(1).float()\n\n # Weight coefficients (balanced for CIFAR-10 32x32)\n w_mse = 1.0\n w_charb = 0.03\n w_lpips = 0.12\n w_grad = 0.05\n w_multi = 0.03\n w_spec = 0.04\n\n loss = (\n w_mse * loss_mse_unscaled\n + w_charb * loss_charb\n + adapt_w * perceptual_w * (\n w_lpips * loss_lpips\n + w_grad * loss_grad\n + w_multi * loss_multi\n + w_spec * loss_spec\n )\n ).mean()"
}
OK: Replaced lines 411..428 with 77 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–487. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–487 | total: 544 lines]
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 19.53, Best FID: 19.79
### train_medium (scripts/train_medium.sh)
Training results:
FID: 15.25, Best FID: 15.56
[Leaderboard] Results saved: {'fid': 15.25, 'best_fid': 15.56, 'fid_small': 19.53, 'best_fid_small': 19.79, 'elapsed_train_small': 6106, 'fid_medium': 15.25, 'best_fid_medium': 15.56, 'elapsed_train_medium': 7893}
Tool Input
{
  "op": "replace",
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
  "end_line": 487,
6 "content": " # ── Proposed: Phase-Aware Spectral Perceptual Loss ──\n #\n # Novel contribution: In addition to standard MSE + LPIPS + FFT\n # magnitude, we add a Fourier PHASE loss. The Fourier phase\n # spectrum encodes spatial structure (edges, shapes, object\n # placement) far more than magnitude. Minimising phase error on\n # the denoised image encourages correct geometry and sharp edges.\n #\n # Phase difference uses a circular metric: d(φ1,φ2) = min(|Δφ|, 2π-|Δφ|)\n # which correctly handles the periodicity of phase angles.\n #\n # We also add a gradient-magnitude consistency term in Fourier\n # domain: the radial profile of the log-magnitude spectrum should\n # match between denoised and clean images.\n #\n # Numerical note: auxiliary losses on x_denoised are skipped for\n # t <= 0.1 (ill-conditioned) and decayed quadratically by (1-t)^2.\n\n # Base MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Denoised image for perceptual supervision\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n # Preallocate per-sample auxiliary losses\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_phase = torch.zeros(B, device=device)\n loss_radial = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n\n # LPIPS (VGG-based perceptual)\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n\n # Sobel gradient L1\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n\n # FFT-magnitude L1 (spectral fidelity)\n eps_fft = 1e-4\n fft_d = torch.fft.rfft2(xd, norm='ortho')\n fft_c = torch.fft.rfft2(xc, norm='ortho')\n mag_d = torch.abs(fft_d) + eps_fft\n mag_c = torch.abs(fft_c) + eps_fft\n loss_spec[mask] = F.l1_loss(\n torch.log(mag_d), torch.log(mag_c),\n 
reduction='none'\n ).flatten(1).mean(1).float()\n\n # Fourier PHASE loss: circular phase difference\n # Phase encodes spatial structure; magnitude encodes texture energy\n phase_d = torch.angle(fft_d) # [-pi, pi]\n phase_c = torch.angle(fft_c)\n # Circular difference: min(|dφ|, 2π-|dφ|)\n dphase = (phase_d - phase_c).abs()\n dphase = torch.minimum(dphase, 2 * math.pi - dphase)\n # Weight by magnitude to emphasise informative frequencies\n phase_weight = mag_c / (mag_c.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] + 1e-4)\n loss_phase[mask] = (dphase * phase_weight).flatten(1).mean(1).float()\n\n # Radial profile loss: compare average log-magnitude at each\n # radial frequency. This enforces correct frequency energy\n # distribution (blur vs sharp) independent of phase.\n H, W = fft_d.shape[-2], fft_d.shape[-1]\n # Compute radial distance for each frequency bin\n y_coords, x_coords = torch.meshgrid(\n torch.arange(H, device=fft_d.device, dtype=torch.float32),\n torch.arange(W, device=fft_d.device, dtype=torch.float32),\n indexing='ij'\n )\n # Normalised radial distance\n r = torch.sqrt((y_coords / H) ** 2 + (x_coords / W) ** 2)\n num_rings = 16\n r_quantized = (r * num_rings).floor().long().clamp(0, num_rings - 1)\n # Average log-magnitude per ring\n log_mag_d = torch.log(mag_d + eps_fft)\n log_mag_c = torch.log(mag_c + eps_fft)\n radial_loss_batch = torch.zeros(mask.sum(), device=fft_d.device)\n for ring_idx in range(num_rings):\n ring_mask = (r_quantized == ring_idx).float()\n count = ring_mask.sum() + 1e-4\n avg_d = (log_mag_d * ring_mask).flatten(1).sum(1) / count\n avg_c = (log_mag_c * ring_mask).flatten(1).sum(1) / count\n radial_loss_batch += (avg_d - avg_c).abs()\n loss_radial[mask] = radial_loss_batch.float()\n\n # Weight coefficients — tuned for CIFAR-10 32x32\n w_mse = 1.0\n w_lpips = 0.12\n w_grad = 0.05\n w_spec = 0.06\n w_phase = 0.08\n w_radial = 0.04\n\n loss = (\n w_mse * loss_mse_unscaled\n + perceptual_w * (\n w_lpips * loss_lpips\n + 
w_grad * loss_grad\n + w_spec * loss_spec\n + w_phase * loss_phase\n + w_radial * loss_radial\n )\n ).mean()"
}
OK: Replaced lines 411..487 with 111 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–521. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–521 | total: 578 lines]
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 35.24, Best FID: 35.79
### train_medium (scripts/train_medium.sh)
Training results:
FID: 633.24, Best FID: 633.24
[Leaderboard] Results saved: {'fid': 633.24, 'best_fid': 633.24, 'fid_small': 35.24, 'best_fid_small': 35.79, 'elapsed_train_small': 6279, 'fid_medium': 633.24, 'best_fid_medium': 633.24, 'elapsed_train_medium': 7907}
Tool Input
{
  "op": "replace",
  "end_line": 521,
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
6 "content": " # ── Proposed: Structure-Aware Perceptual Loss ──\n #\n # Novel contributions:\n # 1) Gradient DIRECTION consistency: match the angle of spatial\n # gradients between denoised and clean images. This is\n # complementary to gradient MAGNITUDE loss (already in baseline)\n # — direction enforces correct edge orientation, not just\n # intensity change.\n # 2) Huber loss on velocity: smooth-L1 instead of pure MSE for\n # robust outlier handling.\n # 3) Wavelet subband loss: simple 1-level Haar decomposition,\n # then L1 on each subband separately. This gives localized\n # frequency control that is more spatially aware than FFT.\n #\n # Numerical note: auxiliary losses on x_denoised are skipped for\n # t <= 0.1 (ill-conditioned) and decayed quadratically by (1-t)^2.\n\n # Base velocity loss: Huber (robust to outliers)\n err = pred_mean_vel - mean_vel_target\n # Huber loss: 0.5*err^2 if |err|<delta, delta*(|err|-0.5*delta) otherwise\n huber_delta = 1.0\n abs_err = err.abs()\n huber_per_elem = torch.where(\n abs_err < huber_delta,\n 0.5 * err ** 2,\n huber_delta * (abs_err - 0.5 * huber_delta)\n )\n loss_vel = huber_per_elem.flatten(1).mean(1)\n\n # Denoised image for perceptual supervision\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n # Preallocate per-sample auxiliary losses\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_gdir = torch.zeros(B, device=device)\n loss_wavelet = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n\n # LPIPS (VGG-based perceptual)\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n\n # Sobel gradient L1 (spatial edge fidelity)\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n\n # FFT-magnitude L1 (spectral fidelity)\n eps_fft = 1e-4\n mag_d = 
torch.abs(torch.fft.rfft2(xd, norm='ortho')) + eps_fft\n mag_c = torch.abs(torch.fft.rfft2(xc, norm='ortho')) + eps_fft\n loss_spec[mask] = F.l1_loss(\n torch.log(mag_d), torch.log(mag_c),\n reduction='none'\n ).flatten(1).mean(1).float()\n\n # Gradient DIRECTION consistency (novel)\n # Compute gradients via Sobel, then match gradient angles\n # using cos(θ_d - θ_c) = cos(θ_d)*cos(θ_c) + sin(θ_d)*sin(θ_c)\n # which equals (gx_d*gx_c + gy_d*gy_c) / (|g_d|*|g_c|)\n # We use 1 - cosine_similarity as the loss\n # First compute Sobel gradients\n sobel_x = torch.tensor(\n [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n device=device, dtype=torch.float32\n ).view(1, 1, 3, 3)\n sobel_y = torch.tensor(\n [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n device=device, dtype=torch.float32\n ).view(1, 1, 3, 3)\n\n gx_d = F.conv2d(xd, sobel_x.repeat(3, 1, 1, 1), groups=3, padding=1)\n gy_d = F.conv2d(xd, sobel_y.repeat(3, 1, 1, 1), groups=3, padding=1)\n gx_c = F.conv2d(xc, sobel_x.repeat(3, 1, 1, 1), groups=3, padding=1)\n gy_c = F.conv2d(xc, sobel_y.repeat(3, 1, 1, 1), groups=3, padding=1)\n\n # Cosine similarity of gradient vectors\n dot = gx_d * gx_c + gy_d * gy_c\n norm_d = torch.sqrt(gx_d ** 2 + gy_d ** 2 + 1e-8)\n norm_c = torch.sqrt(gx_c ** 2 + gy_c ** 2 + 1e-8)\n cos_sim = (dot / (norm_d * norm_c + 1e-8)).mean(dim=[1, 2, 3])\n # Loss = 1 - cos_sim (0 = perfect alignment, 2 = opposite)\n loss_gdir[mask] = (1.0 - cos_sim).clamp(0, 2).float()\n\n # Wavelet subband loss (1-level Haar, novel for flow matching)\n # Haar wavelet via averaging and differencing adjacent pixels\n def haar_decompose(img):\n \"\"\"1-level 2D Haar wavelet decomposition.\"\"\"\n # Horizontal split\n lo_h = 0.5 * (img[..., ::2] + img[..., 1::2])\n hi_h = 0.5 * (img[..., ::2] - img[..., 1::2])\n # Vertical split\n LL = 0.5 * (lo_h[..., ::2, :] + lo_h[..., 1::2, :])\n LH = 0.5 * (lo_h[..., ::2, :] - lo_h[..., 1::2, :])\n HL = 0.5 * (hi_h[..., ::2, :] + hi_h[..., 1::2, :])\n HH = 0.5 * (hi_h[..., ::2, :] - hi_h[..., 
1::2, :])\n return LL, LH, HL, HH\n\n LL_d, LH_d, HL_d, HH_d = haar_decompose(xd)\n LL_c, LH_c, HL_c, HH_c = haar_decompose(xc)\n\n # L1 on each subband, weighted by perceptual importance\n w_LL, w_LH, w_HL, w_HH = 0.5, 1.0, 1.0, 1.5\n wavelet_loss_batch = (\n w_LL * F.l1_loss(LL_d, LL_c, reduction='none').flatten(1).mean(1)\n + w_LH * F.l1_loss(LH_d, LH_c, reduction='none').flatten(1).mean(1)\n + w_HL * F.l1_loss(HL_d, HL_c, reduction='none').flatten(1).mean(1)\n + w_HH * F.l1_loss(HH_d, HH_c, reduction='none').flatten(1).mean(1)\n )\n loss_wavelet[mask] = wavelet_loss_batch.float()\n\n # Weight coefficients — calibrated for CIFAR-10 32x32\n w_vel = 1.0\n w_lpips = 0.12\n w_grad = 0.05\n w_spec = 0.04\n w_gdir = 0.06\n w_wavelet = 0.03\n\n loss = (\n w_vel * loss_vel\n + perceptual_w * (\n w_lpips * loss_lpips\n + w_grad * loss_grad\n + w_spec * loss_spec\n + w_gdir * loss_gdir\n + w_wavelet * loss_wavelet\n )\n ).mean()"
}
OK: Replaced lines 411..521 with 134 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–544. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–544 | total: 601 lines]
Tool Input
{
  "op": "replace",
  "end_line": 544,
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
6 "content": " # ── Proposed: MSE + LPIPS + Spectral + SSIM Structural Loss ──\n #\n # Novel contribution: SSIM (Structural Similarity) loss on the\n # denoised image. While LPIPS captures deep-feature similarity\n # and FFT captures global frequency content, SSIM directly\n # optimizes for local structural similarity (luminance, contrast,\n # and correlation structure) which is complementary to both.\n #\n # SSIM = (2*μx*μy + C1)/(μx²+μy²+C1) * (2*σxy + C2)/(σx²+σy²+C2)\n #\n # We also use gradient-direction cosine similarity (not just\n # magnitude), which encourages correct edge orientations.\n #\n # Numerical note: auxiliary losses on x_denoised are skipped for\n # t <= 0.1 (ill-conditioned) and decayed quadratically by (1-t)^2.\n\n # Base MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Denoised image for perceptual supervision\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n # Preallocate per-sample auxiliary losses\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_ssim = torch.zeros(B, device=device)\n loss_gdir = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n\n # LPIPS (VGG-based perceptual)\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n\n # Sobel gradient L1 (spatial edge fidelity)\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n\n # FFT-magnitude L1 (spectral fidelity)\n eps_fft = 1e-4\n mag_d = torch.abs(torch.fft.rfft2(xd, norm='ortho')) + eps_fft\n mag_c = torch.abs(torch.fft.rfft2(xc, norm='ortho')) + eps_fft\n loss_spec[mask] = F.l1_loss(\n torch.log(mag_d), torch.log(mag_c),\n reduction='none'\n ).flatten(1).mean(1).float()\n\n # SSIM structural similarity loss (novel contribution)\n # 
Gaussian-weighted SSIM with 7x7 kernel, computed per-channel\n # then averaged. SSIM = l * c * s (luminance, contrast, structure)\n def gaussian_1d(size, sigma):\n coords = torch.arange(size, device=xd.device, dtype=torch.float32) - (size // 2)\n g = torch.exp(-coords ** 2 / (2 * sigma ** 2))\n return g / g.sum()\n\n kernel_1d = gaussian_1d(7, 1.5)\n kernel_2d = (kernel_1d.unsqueeze(1) * kernel_1d.unsqueeze(0)).view(1, 1, 7, 7)\n\n C1 = 0.01 ** 2\n C2 = 0.03 ** 2\n\n # Convolve with 2D Gaussian kernel per channel\n pad = 3\n def conv_gaussian(x):\n # x: [B, 3, H, W]\n return F.conv2d(x, kernel_2d.repeat(3, 1, 1, 1), padding=pad, groups=3)\n\n mu_d = conv_gaussian(xd)\n mu_c = conv_gaussian(xc)\n\n mu_d_sq = mu_d ** 2\n mu_c_sq = mu_c ** 2\n mu_dc = mu_d * mu_c\n\n sigma_d_sq = conv_gaussian(xd ** 2) - mu_d_sq\n sigma_c_sq = conv_gaussian(xc ** 2) - mu_c_sq\n sigma_dc = conv_gaussian(xd * xc) - mu_dc\n\n sigma_d_sq = sigma_d_sq.clamp(min=0)\n sigma_c_sq = sigma_c_sq.clamp(min=0)\n\n # SSIM map per sample: average over spatial dims and channels\n numerator = (2 * mu_dc + C1) * (2 * sigma_dc + C2)\n denominator = (mu_d_sq + mu_c_sq + C1) * (sigma_d_sq + sigma_c_sq + C2)\n ssim_map = numerator / (denominator + 1e-12)\n # Average SSIM over channels and spatial dimensions\n ssim_per_sample = ssim_map.mean(dim=[1, 2, 3])\n loss_ssim[mask] = (1.0 - ssim_per_sample).float()\n\n # Gradient direction consistency (novel)\n sobel_x = torch.tensor(\n [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n device=xd.device, dtype=torch.float32\n ).view(1, 1, 3, 3)\n sobel_y = torch.tensor(\n [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n device=xd.device, dtype=torch.float32\n ).view(1, 1, 3, 3)\n\n gx_d = F.conv2d(xd, sobel_x.repeat(3, 1, 1, 1), groups=3, padding=1)\n gy_d = F.conv2d(xd, sobel_y.repeat(3, 1, 1, 1), groups=3, padding=1)\n gx_c = F.conv2d(xc, sobel_x.repeat(3, 1, 1, 1), groups=3, padding=1)\n gy_c = F.conv2d(xc, sobel_y.repeat(3, 1, 1, 1), groups=3, padding=1)\n\n # Cosine similarity of 
gradient vectors at each pixel\n dot = gx_d * gx_c + gy_d * gy_c\n norm_d = torch.sqrt(gx_d ** 2 + gy_d ** 2 + 1e-8)\n norm_c = torch.sqrt(gx_c ** 2 + gy_c ** 2 + 1e-8)\n cos_sim = dot / (norm_d * norm_c + 1e-8)\n # Weight by gradient magnitude of clean image (important edges matter more)\n g_mag_c = norm_c\n weight = g_mag_c / (g_mag_c.mean(dim=[1, 2, 3], keepdim=True) + 1e-8)\n weighted_cos = (cos_sim * weight).mean(dim=[1, 2, 3])\n loss_gdir[mask] = (1.0 - weighted_cos).clamp(0, 2).float()\n\n # Weight coefficients — conservative to not destabilize\n w_mse = 1.0\n w_lpips = 0.12\n w_grad = 0.05\n w_spec = 0.04\n w_ssim = 0.08\n w_gdir = 0.03\n\n loss = (\n w_mse * loss_mse_unscaled\n + perceptual_w * (\n w_lpips * loss_lpips\n + w_grad * loss_grad\n + w_spec * loss_spec\n + w_ssim * loss_ssim\n + w_gdir * loss_gdir\n )\n ).mean()"
}
OK: Replaced lines 411..544 with 138 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–548. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–548 | total: 605 lines]
Tool Input
{
  "op": "replace",
  "end_line": 548,
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "start_line": 411,
6 "content": " # ── Proposed: MSE + LPIPS + Spectral + Laplacian Pyramid Loss ──\n #\n # Novel contribution: Laplacian pyramid reconstruction loss.\n # The Laplacian pyramid decomposes an image into band-pass residuals\n # at multiple scales. Each level captures structure at a specific\n # spatial frequency band. Matching these residuals between denoised\n # and clean images enforces correct multi-scale texture and edge\n # structure — more localized than FFT, more spatially aware than\n # LPIPS.\n #\n # Numerical note: auxiliary losses on x_denoised are skipped for\n # t <= 0.1 (ill-conditioned) and decayed quadratically by (1-t)^2.\n\n # Base MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Denoised image for perceptual supervision\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n # Preallocate per-sample auxiliary losses\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_laplacian = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n\n # LPIPS (VGG-based perceptual)\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n\n # Sobel gradient L1 (spatial edge fidelity)\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n\n # FFT-magnitude L1 (spectral fidelity)\n eps_fft = 1e-4\n mag_d = torch.abs(torch.fft.rfft2(xd, norm='ortho')) + eps_fft\n mag_c = torch.abs(torch.fft.rfft2(xc, norm='ortho')) + eps_fft\n loss_spec[mask] = F.l1_loss(\n torch.log(mag_d), torch.log(mag_c),\n reduction='none'\n ).flatten(1).mean(1).float()\n\n # Laplacian pyramid reconstruction loss (novel contribution)\n # Build 3-level Laplacian pyramid using Gaussian pyramid\n def gaussian_blur(x):\n \"\"\"Simple 5x5 separable Gaussian blur.\"\"\"\n # 1D Gaussian 
kernel [1, 4, 6, 4, 1] / 16\n k = torch.tensor([1, 4, 6, 4, 1], dtype=torch.float32, device=x.device) / 16.0\n k_row = k.view(1, 1, 1, 5)\n k_col = k.view(1, 1, 5, 1)\n # Pad\n x_pad = F.pad(x, (2, 2, 0, 0), mode='reflect')\n x_pad = F.conv2d(x_pad, k_row.repeat(3, 1, 1, 1), groups=3)\n x_pad = F.pad(x_pad, (0, 0, 2, 2), mode='reflect')\n x_pad = F.conv2d(x_pad, k_col.repeat(3, 1, 1, 1), groups=3)\n return x_pad\n\n def build_laplacian_pyramid(img, levels=3):\n \"\"\"Build Laplacian pyramid. Returns list of band-pass residuals.\"\"\"\n pyramid = []\n current = img\n for _ in range(levels):\n blurred = gaussian_blur(current)\n # Downsample then upsample to get coarse version\n coarse = F.interpolate(\n blurred, scale_factor=0.5, mode='bilinear',\n align_corners=False, recompute_scale_factor=True\n )\n upsampled = F.interpolate(\n coarse, size=current.shape[2:], mode='bilinear',\n align_corners=False, recompute_scale_factor=True\n )\n # Laplacian = original upsampled - upsampled(coarse)\n laplacian = current - upsampled\n pyramid.append(laplacian)\n current = coarse\n # Add the coarsest level\n pyramid.append(current)\n return pyramid\n\n # Build pyramids for both images\n pyr_d = build_laplacian_pyramid(xd, levels=3)\n pyr_c = build_laplacian_pyramid(xc, levels=3)\n\n # L1 loss on each level, weighted to emphasize fine detail\n # Levels: 0=finest, 1=medium, 2=coarse, 3=coarsest\n level_weights = [1.0, 0.8, 0.5, 0.3]\n laplacian_loss_batch = torch.zeros(mask.sum(), device=xd.device)\n for level, (ld, lc) in enumerate(zip(pyr_d, pyr_c)):\n w = level_weights[level]\n laplacian_loss_batch += w * F.l1_loss(ld, lc, reduction='none').flatten(1).mean(1)\n loss_laplacian[mask] = laplacian_loss_batch.float()\n\n # Weight coefficients — tuned for CIFAR-10 32x32\n w_mse = 1.0\n w_lpips = 0.12\n w_grad = 0.05\n w_spec = 0.06\n w_laplacian = 0.05\n\n loss = (\n w_mse * loss_mse_unscaled\n + perceptual_w * (\n w_lpips * loss_lpips\n + w_grad * loss_grad\n + w_spec * loss_spec\n + 
w_laplacian * loss_laplacian\n )\n ).mean()"
}
OK: Replaced lines 411..548 with 115 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–525. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–525 | total: 582 lines]
[Test #3] (0 tests remaining)
### train_small (scripts/train_small.sh)
NCCL version 2.21.5+cuda12.4
Model parameters: 38.7M, 8 GPUs, ema=0.999
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
/home/yangyucheng/miniconda3/envs/mlsbench-alphaflow-main/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn(
/home/yangyucheng/miniconda3/envs/mlsbench-alphaflow-main/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weigh
Tool Input
{
2 "content": " # ================================================================\n # Multi-band Spectral-Structural Perceptual Loss\n #\n # Novel combination of:\n # 1. Base MSE on velocity (flow matching objective)\n # 2. LPIPS for deep perceptual features\n # 3. Log-spectral loss: L1 on log(1+|FFT|) for multi-band sensitivity\n # 4. Phase coherence: weighted angular consistency\n # 5. Gradient alignment: edge/sharpness matching\n #\n # Key insight: Log-magnitude compresses dynamic range, giving\n # high-frequency details (where CIFAR-10 quality matters most)\n # comparable weight to low-frequency structure. Combined with\n # phase coherence and gradient alignment, this captures both\n # structural and textural fidelity.\n #\n # All auxiliary losses are:\n # - Masked for t <= 0.1 (ill-conditioned denoised image)\n # - Weighted by (1-t)^2 (peaks at low noise)\n # - Ramped in over first 1000 steps\n # ================================================================\n\n # Base MSE on velocity (per-sample)\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Denoised image estimate\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n\n # Mask: skip ill-conditioned low-t samples\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n # Initialize per-sample auxiliary losses\n loss_lpips = torch.zeros(B, device=device)\n loss_log_spec = torch.zeros(B, device=device)\n loss_phase = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n\n # 1. LPIPS (VGG-based perceptual loss)\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n\n # 2. 
Log-spectral loss: L1 on log(1 + |FFT|)\n # Log compression gives comparable weight to high-frequency details\n fd = torch.fft.rfft2(xd)\n fc = torch.fft.rfft2(xc)\n log_mag_diff = (torch.log1p(fd.abs()) - torch.log1p(fc.abs())).abs()\n loss_log_spec[mask] = log_mag_diff.flatten(1).mean(1)\n\n # 3. Phase coherence with magnitude weighting\n # High-magnitude frequencies matter more for perception\n mag_w = fc.abs() / (fc.abs().flatten(1).mean(1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + 1e-8)\n phase_diff = 1.0 - torch.cos(fd.angle() - fc.angle())\n loss_phase[mask] = (mag_w * phase_diff).flatten(1).mean(1)\n\n # 4. Gradient alignment (edge sharpness)\n def sobel_grad(img):\n gx = F.conv2d(img, torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], device=img.device).repeat(3, 1, 1, 1), groups=3)\n gy = F.conv2d(img, torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], device=img.device).repeat(3, 1, 1, 1), groups=3)\n return torch.sqrt(gx**2 + gy**2 + 1e-8)\n\n gd = sobel_grad(xd)\n gc = sobel_grad(xc)\n loss_grad[mask] = (gd - gc).abs().flatten(1).mean(1)\n\n # Ramp up auxiliary losses over first 1000 steps\n ramp = min(step / 1000.0, 1.0)\n\n # Weighted combination:\n # - LPIPS: 0.05 (strong perceptual signal)\n # - Log-spectral: 0.04 (frequency fidelity)\n # - Phase: 0.005 (structural coherence, lower weight due to noise sensitivity)\n # - Gradient: 0.02 (edge sharpness)\n aux = (perceptual_w * (0.05 * loss_lpips + 0.04 * loss_log_spec + 0.005 * loss_phase + 0.02 * loss_grad)).mean()\n loss = loss_mse_unscaled.mean() + ramp * aux\n",
3 "filename": "custom_train_perceptual.py",
4 "op": "replace",
5 "start_line": 411,
6 "end_line": 428
7}ERROR: Package 'custom_train_perceptual.py' is not in allowed packages
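The auxiliary-loss weighting in the patch above (hard mask for t <= 0.1, quadratic (1-t)^2 decay elsewhere, linear ramp-in over the first 1000 steps) can be sketched standalone. A minimal scalar version (the script itself applies this per-sample to a batched tensor; names here are illustrative):

```python
def perceptual_weight(t, t_min=0.1):
    """Weight for auxiliary losses on x_denoised.

    Zero for t <= t_min, where x_t is essentially clean and the
    gradient through x_denoised = x_t - t * v_pred is ill-conditioned;
    elsewhere it decays as (1 - t)^2, so it is largest at low noise
    (just above t_min) and fades as t -> 1.
    """
    if t <= t_min:
        return 0.0
    return (1.0 - t) ** 2

def aux_ramp(step, warmup=1000):
    """Linear ramp-in of the auxiliary term over the first `warmup` steps."""
    return min(step / warmup, 1.0)
```

The effective auxiliary scale at a given step and timestep is then `aux_ramp(step) * perceptual_weight(t)`, matching the `perceptual_w` and `ramp` variables in the patch.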
Tool Input
{
  "content": "    # ================================================================\n    # Multi-band Spectral-Structural Perceptual Loss\n    #\n    # Novel combination of:\n    #   1. Base MSE on velocity (flow matching objective)\n    #   2. LPIPS for deep perceptual features\n    #   3. Log-spectral loss: L1 on log(1+|FFT|) for multi-band sensitivity\n    #   4. Phase coherence: weighted angular consistency\n    #   5. Gradient alignment: edge/sharpness matching\n    #\n    # Key insight: Log-magnitude compresses dynamic range, giving\n    # high-frequency details (where CIFAR-10 quality matters most)\n    # comparable weight to low-frequency structure. Combined with\n    # phase coherence and gradient alignment, this captures both\n    # structural and textural fidelity.\n    #\n    # All auxiliary losses are:\n    #   - Masked for t <= 0.1 (ill-conditioned denoised image)\n    #   - Weighted by (1-t)^2 (peaks at low noise)\n    #   - Ramped in over first 1000 steps\n    # ================================================================\n\n    # Base MSE on velocity (per-sample)\n    err = pred_mean_vel - mean_vel_target\n    loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n    # Denoised image estimate\n    x_denoised = x_t - t * pred_mean_vel\n    t_flat = t.view(B)\n\n    # Mask: skip ill-conditioned low-t samples\n    mask = (t_flat > 0.1)\n    perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n    # Initialize per-sample auxiliary losses\n    loss_lpips = torch.zeros(B, device=device)\n    loss_log_spec = torch.zeros(B, device=device)\n    loss_phase = torch.zeros(B, device=device)\n    loss_grad = torch.zeros(B, device=device)\n\n    if mask.any():\n        xd = x_denoised[mask].clamp(-1, 1).float()\n        xc = x[mask].clamp(-1, 1).float()\n\n        # 1. LPIPS (VGG-based perceptual loss)\n        loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n\n        # 2. Log-spectral loss: L1 on log(1 + |FFT|)\n        #    Log compression gives comparable weight to high-frequency details\n        fd = torch.fft.rfft2(xd)\n        fc = torch.fft.rfft2(xc)\n        log_mag_diff = (torch.log1p(fd.abs()) - torch.log1p(fc.abs())).abs()\n        loss_log_spec[mask] = log_mag_diff.flatten(1).mean(1)\n\n        # 3. Phase coherence with magnitude weighting\n        #    High-magnitude frequencies matter more for perception\n        mag_w = fc.abs() / (fc.abs().flatten(1).mean(1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + 1e-8)\n        phase_diff = 1.0 - torch.cos(fd.angle() - fc.angle())\n        loss_phase[mask] = (mag_w * phase_diff).flatten(1).mean(1)\n\n        # 4. Gradient alignment (edge sharpness)\n        def sobel_grad(img):\n            gx = F.conv2d(img, torch.tensor([[[[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]]], dtype=img.dtype, device=img.device).repeat(3, 1, 1, 1), groups=3)\n            gy = F.conv2d(img, torch.tensor([[[[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]]], dtype=img.dtype, device=img.device).repeat(3, 1, 1, 1), groups=3)\n            return torch.sqrt(gx**2 + gy**2 + 1e-8)\n\n        gd = sobel_grad(xd)\n        gc = sobel_grad(xc)\n        loss_grad[mask] = (gd - gc).abs().flatten(1).mean(1)\n\n    # Ramp up auxiliary losses over first 1000 steps\n    ramp = min(step / 1000.0, 1.0)\n\n    # Weighted combination:\n    #   - LPIPS: 0.05 (strong perceptual signal)\n    #   - Log-spectral: 0.04 (frequency fidelity)\n    #   - Phase: 0.005 (structural coherence, lower weight due to noise sensitivity)\n    #   - Gradient: 0.02 (edge sharpness)\n    aux = (perceptual_w * (0.05 * loss_lpips + 0.04 * loss_log_spec + 0.005 * loss_phase + 0.02 * loss_grad)).mean()\n    loss = loss_mse_unscaled.mean() + ramp * aux\n",
  "filename": "alphaflow-main/custom_train_perceptual.py",
  "op": "replace",
  "start_line": 411,
  "end_line": 428
}
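For readers who do not want to unescape the JSON payload, the two FFT-domain terms of the patch reduce to a few lines each. A hedged numpy sketch of the same math (the script itself uses `torch.fft.rfft2` on batched `(N, C, H, W)` tensors; function names here are illustrative):

```python
import numpy as np

def log_spectral_loss(xd, xc):
    """L1 distance between log(1 + |FFT|) spectra.

    Log compression flattens the spectrum's dynamic range, so
    high-frequency detail is weighted comparably to the dominant
    low-frequency structure.
    """
    fd, fc = np.fft.rfft2(xd), np.fft.rfft2(xc)
    return np.abs(np.log1p(np.abs(fd)) - np.log1p(np.abs(fc))).mean()

def phase_coherence_loss(xd, xc, eps=1e-8):
    """Mean of 1 - cos(phase difference), weighted by the reference
    spectrum's normalized magnitude so high-energy frequencies dominate.
    """
    fd, fc = np.fft.rfft2(xd), np.fft.rfft2(xc)
    mag_w = np.abs(fc) / (np.abs(fc).mean() + eps)
    return (mag_w * (1.0 - np.cos(np.angle(fd) - np.angle(fc)))).mean()
```

Both losses are zero when the two inputs are identical, and the phase term is bounded (the cosine factor lies in [0, 2]), which is one reason it can tolerate a small but nonzero weight (0.005 in the patch) despite its noise sensitivity.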