Agent Conversation: deepseek-reasoner

Model: deepseek-reasoner
Total steps: 10
Messages: 31
User Step 0

Task: cv-meanflow-perceptual-loss

Flow Matching with Perceptual Loss

Background

Flow matching trains a neural network to predict velocity fields that transport samples from noise to data. Traditional training uses only MSE loss on the predicted velocity:

loss = ||v_pred - v_target||^2

However, we can also compute the denoised image from the predicted velocity:

x_denoised = x_t - t * v_pred

We can then apply perceptual losses (LPIPS, gradient loss, etc.) to x_denoised, encouraging the network to generate high-quality images rather than merely accurate velocities.
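As a sanity check of the identity above, here is a minimal sketch with NumPy stand-ins (not the script's tensors), assuming the standard linear interpolation path where the true velocity is noise minus data; the script's exact convention may differ. With the exact velocity, x_t - t * v recovers the clean image:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 3, 8, 8))      # stand-in clean images
noise = rng.standard_normal((4, 3, 8, 8))   # stand-in Gaussian noise
t = 0.7

# Linear path: x_t = (1 - t) * x0 + t * noise, with velocity v = noise - x0
x_t = (1 - t) * x0 + t * noise
v_target = noise - x0

# Denoised image from the (here exact) velocity, as in the formula above
x_denoised = x_t - t * v_target
assert np.allclose(x_denoised, x0)
```

In training, v_pred replaces v_target, so x_denoised is only an estimate of the clean image; that is what the perceptual losses are applied to.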

Research Question

Can adding perceptual losses to flow matching training improve FID scores?

Task

You are given custom_train_perceptual.py, a self-contained training script that trains a small DiT on CIFAR-10 (32x32) using flow matching with mean velocity objectives.

The editable region contains the loss computation in the training loop:

# Current: MSE loss only
loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()
loss = loss_mse

The fixed code already exposes:

  • lpips_fn(x_denoised, x_target) - perceptual loss
  • compute_gradient_loss(x_denoised, x_target) - gradient-domain loss
  • compute_multiscale_loss(x_denoised, x_target) - multi-resolution loss

Key constraint: Only apply auxiliary losses when t > 0.1 to avoid instability at small noise levels.
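One way to honor this constraint is to zero the auxiliary term per sample wherever t <= 0.1. A sketch with NumPy stand-ins for the per-sample losses (names and the 0.1 auxiliary coefficient are illustrative, not from the script):

```python
import numpy as np

B = 6
t = np.array([0.05, 0.10, 0.20, 0.50, 0.80, 0.95])  # per-sample timesteps
loss_mse = np.ones(B)        # stand-in per-sample velocity MSE
loss_aux = np.full(B, 0.3)   # stand-in per-sample perceptual loss

# Auxiliary losses contribute only where t > 0.1 (strict inequality)
mask = t > 0.1
loss = loss_mse + 0.1 * np.where(mask, loss_aux, 0.0)
```

Note that t = 0.10 is excluded by the strict inequality, so only the last four samples receive the auxiliary term.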

Evaluation

  • Dataset: CIFAR-10 (32x32)
  • Model: SmallDiT (512 hidden, 8 layers, ~40M params)
  • Training: 10000 steps, batch size 128
  • Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set
  • Inference: 10-step Euler sampler
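For reference, a 10-step Euler sampler integrates dx/dt = v(x, t) from t = 1 (noise) to t = 0 (data). A minimal sketch, where v_fn stands in for the trained network and the time convention (t = 1 is pure noise) is an assumption:

```python
import numpy as np

def euler_sample(v_fn, x_start, n_steps=10):
    """Euler integration of dx/dt = v(x, t) from t = 1 down to t = 0."""
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    x = x_start
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t) * v_fn(x, t)  # dt = t_next - t is negative
    return x

# With the exact velocity of the linear path, Euler recovers the data exactly
rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 3, 8, 8))      # pretend clean data
noise = rng.standard_normal((2, 3, 8, 8))   # initial sample at t = 1
v_exact = lambda x, t: noise - x0
sample = euler_sample(v_exact, noise.copy(), n_steps=10)
assert np.allclose(sample, x0)
```

With a learned, imperfect velocity field, 10 steps introduce discretization error, which is why the loss design affects FID under this fixed sampler.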

Baselines

  1. mse-only: Pure MSE loss on velocity
  2. mse-lpips: MSE + LPIPS perceptual loss (VGG features)
  3. mse-lpips-grad: MSE + LPIPS + Gradient loss with timestep-adaptive weighting
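"Timestep-adaptive weighting" can take several forms; one schedule tried in this run is a Gaussian bump centered at t = 0.5 with std 0.2, zeroed below the t > 0.1 cutoff. A sketch of that schedule:

```python
import numpy as np

t = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
# Gaussian weight exp(-(t - 0.5)^2 / (2 * 0.2^2)), masked where t <= 0.1
w = np.exp(-((t - 0.5) ** 2) / (2 * 0.2 ** 2)) * (t > 0.1)
```

The weight peaks at 1.0 for t = 0.5, is symmetric around the peak, and is exactly zero at t = 0.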

alphaflow-main/custom_train_perceptual.py [EDITABLE — lines 411–428 only]

     1: """Custom Flow Matching Training Script — Perceptual Loss Variant
     2: Small-scale flow matching training on CIFAR-10 with a lightweight DiT.
     3: The training objective (MeanFlow) is pre-implemented; your task is to
     4: design an improved loss function, optionally using perceptual losses.
     5: """
     6: 
     7: import math
     8: import os
     9: import time
    10: 
    11: import lpips
    12: import numpy as np
    13: import torch
    14: import torch.nn as nn
    15: import torch.nn.functional as F
    16: from torch.autograd.functional import jvp
    17: from torchvision import datasets, transforms
    18: from torchvision.utils import save_image
    19: from perceptual_utils import compute_gradient_loss, compute_multiscale_loss
    20: 
    21: # ============================================================================
    22: # Model: Lightweight DiT for CIFAR-10 (32x32)
    23: # ============================================================================
    24: 
    25: def modulate(x, shift, scale):
    
Assistant Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "alphaflow-main/custom_train_perceptual.py",
4  "start_line": 411,
5  "end_line": 428,
6  "content": "            # Compute denoised image from predicted velocity\n            x_denoised = x_t - t * pred_mean_vel\n\n            # MSE loss on velocity (per-sample)\n            loss_mse_unscaled = ((pred_mean_vel - mean_vel_target) ** 2).flatten(1).mean(1)\n\n            # Adaptive weighting for stability (as in baseline)\n            t_flat = t.view(B)\n            t_next_flat = t_next.view(B)\n            dt_flat = dt.view(B)\n            mask_d = dt_flat > 0\n            weight = torch.ones(B, device=device)\n            if mask_d.any():\n                weight[mask_d] = alpha\n            weight = weight / (loss_mse_unscaled.detach() + 1e-3)\n\n            # Timestep-adaptive weighting for perceptual losses (Gaussian centered at t=0.5)\n            # Only apply perceptual losses when t > 0.1 (stability)\n            mask_aux = (t_flat > 0.1)\n            # Gaussian weight: exp(-(t-0.5)^2 / (2*0.2^2)), peak at 0.5, std 0.2\n            t_centered = t_flat - 0.5\n            gaussian_weight = torch.exp(-(t_centered ** 2) / (2 * 0.04))  # 0.04 = 0.2^2\n            # Zero weight when t <= 0.1\n            gaussian_weight = gaussian_weight * mask_aux.float()\n\n            # Initialize perceptual loss tensors\n            loss_lpips = torch.zeros(B, device=device)\n            loss_grad = torch.zeros(B, device=device)\n            loss_multiscale = torch.zeros(B, device=device)\n\n            if mask_aux.any():\n                # Clamp to [-1, 1] for perceptual losses\n                x_d_clamped = x_denoised[mask_aux].clamp(-1, 1)\n                x_clamped = x[mask_aux].clamp(-1, 1)\n\n                # LPIPS loss\n                loss_lpips_per_sample = lpips_fn(x_d_clamped, x_clamped).view(-1).float()\n                loss_lpips[mask_aux] = loss_lpips_per_sample\n\n                # Gradient loss (edge sharpness)\n                loss_grad_scalar = compute_gradient_loss(x_d_clamped, x_clamped)\n                loss_grad[mask_aux] = loss_grad_scalar\n\n                # Multiscale loss (structural similarity across scales)\n                loss_multiscale_scalar = compute_multiscale_loss(x_d_clamped, x_clamped)\n                loss_multiscale[mask_aux] = loss_multiscale_scalar\n\n            # Combine losses with coefficients\n            # Base MSE weight: 1.0 (already weighted by adaptive weight)\n            # Perceptual weights: tuned hyperparameters\n            lambda_lpips = 0.1\n            lambda_grad = 0.05\n            lambda_multiscale = 0.05\n\n            # Weighted loss per sample\n            loss_per_sample = weight * loss_mse_unscaled + \\\n                              gaussian_weight * (lambda_lpips * loss_lpips + \n                                                 lambda_grad * loss_grad + \n                                                 lambda_multiscale * loss_multiscale)\n\n            # Final scalar loss\n            loss = loss_per_sample.mean()"
7}
Tool Result Step 1
OK: Replaced lines 411..428 with 62 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–472.

[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–472 | total: 529 lines]
   408|             # Predict mean velocity
   409|             pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
   410| 
   411              # Compute denoised image from predicted velocity
   412              x_denoised = x_t - t * pred_mean_vel
   413  
   414              # MSE loss on velocity (per-sample)
   415              loss_mse_unscaled = ((pred_mean_vel - mean_vel_target) ** 2).flatten(1).mean(1)
   416  
   417              # Adaptive weighting for stability (as in baseline)
   418              t_flat = t.view(B)
   419              t_next_flat = t_next.view(B)
   420              dt_flat = dt.view(B)
   421              mask_d = dt_flat > 0
   422              weight = torch.ones(B, device=device)
   423              if mask_d.any():
   424                  
Assistant Step 2 (test)
Tool Result Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
FID: 21.86, Best FID: 21.50

[Leaderboard] Results saved: {'fid': 21.86, 'best_fid': 21.5, 'elapsed_train_small': 6172}
Assistant Step 3 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "alphaflow-main/custom_train_perceptual.py",
4  "start_line": 411,
5  "end_line": 472,
6  "content": "            # Compute denoised image from predicted velocity\n            x_denoised = x_t - t * pred_mean_vel\n\n            # MSE loss on velocity (per-sample)\n            loss_mse_unscaled = ((pred_mean_vel - mean_vel_target) ** 2).flatten(1).mean(1)\n\n            # Adaptive weighting for stability (as in baseline)\n            t_flat = t.view(B)\n            t_next_flat = t_next.view(B)\n            dt_flat = dt.view(B)\n            mask_d = dt_flat > 0\n            weight = torch.ones(B, device=device)\n            if mask_d.any():\n                weight[mask_d] = alpha\n            weight = weight / (loss_mse_unscaled.detach() + 1e-3)\n\n            # Timestep-adaptive weighting for perceptual losses (linear increase with t)\n            # Only apply perceptual losses when t > 0.1 (stability)\n            mask_aux = (t_flat > 0.1)\n            # Linear weight: 0 at t=0.1, 1 at t=1.0\n            linear_weight = (t_flat - 0.1) / 0.9\n            linear_weight = linear_weight.clamp(min=0)  # ensure non-negative\n            # Zero weight when t <= 0.1\n            linear_weight = linear_weight * mask_aux.float()\n\n            # Initialize perceptual loss tensors\n            loss_lpips = torch.zeros(B, device=device)\n            loss_grad = torch.zeros(B, device=device)\n            loss_multiscale = torch.zeros(B, device=device)\n            loss_freq = torch.zeros(B, device=device)\n\n            if mask_aux.any():\n                # Clamp to [-1, 1] for perceptual losses\n                x_d_clamped = x_denoised[mask_aux].clamp(-1, 1)\n                x_clamped = x[mask_aux].clamp(-1, 1)\n\n                # LPIPS loss\n                loss_lpips_per_sample = lpips_fn(x_d_clamped, x_clamped).view(-1).float()\n                loss_lpips[mask_aux] = loss_lpips_per_sample\n\n                # Gradient loss (edge sharpness)\n                loss_grad_scalar = compute_gradient_loss(x_d_clamped, x_clamped)\n                loss_grad[mask_aux] = loss_grad_scalar\n\n                # Multiscale loss (structural similarity across scales)\n                loss_multiscale_scalar = compute_multiscale_loss(x_d_clamped, x_clamped)\n                loss_multiscale[mask_aux] = loss_multiscale_scalar\n\n                # Frequency-domain loss (FFT magnitude + phase)\n                # Compute FFT of both images\n                fft_d = torch.fft.rfft2(x_d_clamped, dim=(-2, -1))\n                fft_clean = torch.fft.rfft2(x_clamped, dim=(-2, -1))\n                # Magnitude loss\n                mag_loss = torch.abs(torch.abs(fft_d) - torch.abs(fft_clean)).mean(dim=(1,2,3))\n                # Phase loss (circular distance)\n                phase_loss = 1 - torch.cos(torch.angle(fft_d) - torch.angle(fft_clean))\n                phase_loss = phase_loss.mean(dim=(1,2,3))\n                # Combine magnitude and phase losses\n                loss_freq_per_sample = mag_loss + 0.5 * phase_loss\n                loss_freq[mask_aux] = loss_freq_per_sample\n\n            # Combine losses with coefficients\n            # Base MSE weight: 1.0 (already weighted by adaptive weight)\n            # Perceptual weights: tuned hyperparameters\n            lambda_lpips = 0.05\n            lambda_grad = 0.02\n            lambda_multiscale = 0.02\n            lambda_freq = 0.03\n\n            # Weighted loss per sample\n            loss_per_sample = weight * loss_mse_unscaled + \\\n                              linear_weight * (lambda_lpips * loss_lpips + \n                                               lambda_grad * loss_grad + \n                                               lambda_multiscale * loss_multiscale +\n                                               lambda_freq * loss_freq)\n\n            # Final scalar loss\n            loss = loss_per_sample.mean()"
7}
Tool Result Step 3
OK: Replaced lines 411..472 with 78 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–488.

[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–488 | total: 545 lines]
   408|             # Predict mean velocity
   409|             pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
   410| 
   411              # Compute denoised image from predicted velocity
   412              x_denoised = x_t - t * pred_mean_vel
   413  
   414              # MSE loss on velocity (per-sample)
   415              loss_mse_unscaled = ((pred_mean_vel - mean_vel_target) ** 2).flatten(1).mean(1)
   416  
   417              # Adaptive weighting for stability (as in baseline)
   418              t_flat = t.view(B)
   419              t_next_flat = t_next.view(B)
   420              dt_flat = dt.view(B)
   421              mask_d = dt_flat > 0
   422              weight = torch.ones(B, device=device)
   423              if mask_d.any():
   424                  
Assistant Step 4 (test)
Assistant Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "alphaflow-main/custom_train_perceptual.py",
4  "start_line": 411,
5  "end_line": 428,
6  "content": "            # MSE on velocity\n            err = pred_mean_vel - mean_vel_target\n            loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n            # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n            x_denoised = x_t - t * pred_mean_vel\n            t_flat = t.view(B)\n            mask = (t_flat > 0.1)\n            perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n            loss_lpips = torch.zeros(B, device=device)\n            loss_grad = torch.zeros(B, device=device)\n            loss_multi = torch.zeros(B, device=device)\n            loss_spec = torch.zeros(B, device=device)\n            loss_complex = torch.zeros(B, device=device)\n            if mask.any():\n                xd = x_denoised[mask].clamp(-1, 1).float()\n                xc = x[mask].clamp(-1, 1).float()\n                loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n                loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n                loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n                # FFT magnitude L1: per-channel rfft2, abs, L1 of difference\n                xd_fft = torch.fft.rfft2(xd, norm='ortho')\n                xc_fft = torch.fft.rfft2(xc, norm='ortho')\n                loss_spec[mask] = torch.abs(torch.abs(xd_fft) - torch.abs(xc_fft)).mean(dim=(1,2,3)).float()\n                # Complex FFT loss: L1 on complex difference (magnitude + phase)\n                loss_complex[mask] = torch.abs(xd_fft - xc_fft).mean(dim=(1,2,3)).float()\n\n            # Combine losses with weights\n            w_mse = 1.0\n            w_lpips = 0.1\n            w_grad = 0.05\n            w_multi = 0.01\n            w_spec = 0.05\n            w_complex = 0.05\n            \n            loss_per_sample = (w_mse * loss_mse_unscaled +\n                               perceptual_w * (w_lpips * loss_lpips +\n                                               w_grad * loss_grad +\n                                               w_multi * loss_multi +\n                                               w_spec * loss_spec +\n                                               w_complex * loss_complex))\n            loss = loss_per_sample.mean()"
7}
Tool Result Step 1
OK: Replaced lines 411..428 with 43 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–453.

[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–453 | total: 510 lines]
   408|             # Predict mean velocity
   409|             pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
   410| 
   411              # MSE on velocity
   412              err = pred_mean_vel - mean_vel_target
   413              loss_mse_unscaled = (err ** 2).flatten(1).mean(1)
   414  
   415              # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)
   416              x_denoised = x_t - t * pred_mean_vel
   417              t_flat = t.view(B)
   418              mask = (t_flat > 0.1)
   419              perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()
   420  
   421              loss_lpips = torch.zeros(B, device=device)
   422              loss_grad = torch.zeros(B, device=device)
   423              loss_multi = torch.zeros(B, de
Assistant Step 2 (test)
Assistant Step 1 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "alphaflow-main/custom_train_perceptual.py",
4  "start_line": 411,
5  "end_line": 428,
6  "content": "            # MSE on velocity\n            err = pred_mean_vel - mean_vel_target\n            loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n            # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n            x_denoised = x_t - t * pred_mean_vel\n            t_flat = t.view(B)\n            mask = (t_flat > 0.1)\n            perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n            loss_lpips = torch.zeros(B, device=device)\n            loss_grad = torch.zeros(B, device=device)\n            loss_multi = torch.zeros(B, device=device)\n            loss_spec = torch.zeros(B, device=device)\n            loss_complex = torch.zeros(B, device=device)\n            if mask.any():\n                xd = x_denoised[mask].clamp(-1, 1).float()\n                xc = x[mask].clamp(-1, 1).float()\n                loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n                loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n                loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n                # FFT magnitude L1: per-channel rfft2, abs, L1 of difference\n                xd_fft = torch.fft.rfft2(xd, dim=(-2, -1))\n                xc_fft = torch.fft.rfft2(xc, dim=(-2, -1))\n                loss_spec[mask] = (xd_fft.abs() - xc_fft.abs()).abs().flatten(1).mean(1).float()\n                # Complex loss: L1 on real and imaginary parts\n                loss_real = (xd_fft.real - xc_fft.real).abs().flatten(1).mean(1)\n                loss_imag = (xd_fft.imag - xc_fft.imag).abs().flatten(1).mean(1)\n                loss_complex[mask] = (loss_real + loss_imag).float()\n\n            # Combine losses\n            loss_vel = loss_mse_unscaled.mean()\n            loss_perceptual = (perceptual_w * (loss_lpips + loss_grad + loss_multi + loss_spec + loss_complex)).mean()\n            loss = loss_vel + loss_perceptual"
7}
Tool Result Step 1
OK: Replaced lines 411..428 with 34 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–444.

[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–444 | total: 501 lines]
   408|             # Predict mean velocity
   409|             pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
   410| 
   411              # MSE on velocity
   412              err = pred_mean_vel - mean_vel_target
   413              loss_mse_unscaled = (err ** 2).flatten(1).mean(1)
   414  
   415              # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)
   416              x_denoised = x_t - t * pred_mean_vel
   417              t_flat = t.view(B)
   418              mask = (t_flat > 0.1)
   419              perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()
   420  
   421              loss_lpips = torch.zeros(B, device=device)
   422              loss_grad = torch.zeros(B, device=device)
   423              loss_multi = torch.zeros(B, de
Assistant Step 2 (test)
Tool Result Step 2
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
FID: 26.73, Best FID: 26.94

### train_medium (scripts/train_medium.sh)
Training results:
FID: 32.58, Best FID: 26.61

[Leaderboard] Results saved: {'fid': 32.58, 'best_fid': 26.61, 'fid_small': 26.73, 'best_fid_small': 26.94, 'elapsed_train_small': 6137, 'fid_medium': 32.58, 'best_fid_medium': 26.61, 'elapsed_train_medium': 7894}
Assistant Step 3 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "alphaflow-main/custom_train_perceptual.py",
4  "start_line": 411,
5  "end_line": 444,
6  "content": "            # MSE on velocity\n            err = pred_mean_vel - mean_vel_target\n            loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n            # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n            x_denoised = x_t - t * pred_mean_vel\n            t_flat = t.view(B)\n            mask = (t_flat > 0.1)\n            perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n            loss_lpips = torch.zeros(B, device=device)\n            loss_grad = torch.zeros(B, device=device)\n            loss_multi = torch.zeros(B, device=device)\n            loss_mag = torch.zeros(B, device=device)\n            loss_phase = torch.zeros(B, device=device)\n            if mask.any():\n                xd = x_denoised[mask].clamp(-1, 1).float()\n                xc = x[mask].clamp(-1, 1).float()\n                loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n                loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n                loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n                # FFT magnitude L1\n                xd_fft = torch.fft.rfft2(xd, dim=(-2, -1))\n                xc_fft = torch.fft.rfft2(xc, dim=(-2, -1))\n                loss_mag[mask] = (xd_fft.abs() - xc_fft.abs()).abs().flatten(1).mean(1).float()\n                # Phase consistency loss: normalized complex difference\n                # Normalize by magnitude to focus on phase\n                eps = 1e-6\n                xd_mag = xd_fft.abs().clamp(min=eps)\n                xc_mag = xc_fft.abs().clamp(min=eps)\n                xd_unit = xd_fft / xd_mag\n                xc_unit = xc_fft / xc_mag\n                # L1 on real and imag parts of normalized spectra\n                loss_real = (xd_unit.real - xc_unit.real).abs().flatten(1).mean(1)\n                loss_imag = (xd_unit.imag - xc_unit.imag).abs().flatten(1).mean(1)\n                loss_phase[mask] = (loss_real + loss_imag).float()\n\n            # Combine losses\n            loss_vel = loss_mse_unscaled.mean()\n            loss_perceptual = (perceptual_w * (loss_lpips + loss_grad + loss_multi + loss_mag + loss_phase)).mean()\n            loss = loss_vel + loss_perceptual"
7}
Tool Result Step 3
OK: Replaced lines 411..444 with 41 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–451.

[Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–451 | total: 508 lines]
   408|             # Predict mean velocity
   409|             pred_mean_vel = net(x_t, sigma=t, sigma_next=t_next)
   410| 
   411              # MSE on velocity
   412              err = pred_mean_vel - mean_vel_target
   413              loss_mse_unscaled = (err ** 2).flatten(1).mean(1)
   414  
   415              # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)
   416              x_denoised = x_t - t * pred_mean_vel
   417              t_flat = t.view(B)
   418              mask = (t_flat > 0.1)
   419              perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()
   420  
   421              loss_lpips = torch.zeros(B, device=device)
   422              loss_grad = torch.zeros(B, device=device)
   423              loss_multi = torch.zeros(B, de
Assistant Step 4 (test)
Tool Result Step 4
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
FID: 23.59, Best FID: 23.94

### train_medium (scripts/train_medium.sh)
Training results:
FID: 24.85, Best FID: 23.36

[Leaderboard] Results saved: {'fid': 24.85, 'best_fid': 23.36, 'fid_small': 23.59, 'best_fid_small': 23.94, 'elapsed_train_small': 6114, 'fid_medium': 24.85, 'best_fid_medium': 23.36, 'elapsed_train_medium': 7897}
Assistant Step 5 (edit)
Tool Input
1{
2  "op": "replace",
3  "filename": "alphaflow-main/custom_train_perceptual.py",
4  "start_line": 411,
5  "end_line": 451,
6  "content": "            # MSE on velocity\n            err = pred_mean_vel - mean_vel_target\n            loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n            # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n            x_denoised = x_t - t * pred_mean_vel\n            t_flat = t.view(B)\n            mask = (t_flat > 0.1)\n            perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n            loss_lpips = torch.zeros(B, device=device)\n            loss_grad = torch.zeros(B, device=device)\n            loss_multi = torch.zeros(B, device=device)\n            loss_mag = torch.zeros(B, device=device)\n            loss_phase = torch.zeros(B, device=device)\n            if mask.any():\n                xd = x_denoised[mask].clamp(-1, 1).float()\n                xc = x[mask].clamp(-1, 1).float()\n                loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n                loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n                loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n                # FFT magnitude L1\n                xd_fft = torch.fft.rfft2(xd, dim=(-2, -1))\n                xc_fft = torch.fft.rfft2(xc, dim=(-2, -1))\n                loss_mag[mask] = (xd_fft.abs() - xc_fft.abs()).abs().flatten(1).mean(1).float()\n                # Phase consistency loss: normalized complex difference\n                eps = 1e-6\n                xd_mag = xd_fft.abs().clamp(min=eps)\n                xc_mag = xc_fft.abs().clamp(min=eps)\n                xd_unit = xd_fft / xd_mag\n                xc_unit = xc_fft / xc_mag\n                loss_real = (xd_unit.real - xc_unit.real).abs().flatten(1).mean(1)\n                loss_imag = (xd_unit.imag - xc_unit.imag).abs().flatten(1).mean(1)\n                loss_phase[mask] = (loss_real + loss_imag).float()\n\n            # Weighting factors for each perceptual loss\n            w_lpips = 1.0\n            w_grad = 0.1\n            w_multi = 0.1\n            w_mag = 0.5\n            w_phase = 0.5\n            perceptual_total = (w_lpips * loss_lpips + w_grad * loss_grad + w_multi * loss_multi +\n                                w_mag * loss_mag + w_phase * loss_phase)\n            \n            # Overall perceptual scaling\n            lambda_perceptual = 0.5\n            loss_vel = loss_mse_unscaled.mean()\n            loss_perceptual = lambda_perceptual * (perceptual_w * perceptual_total).mean()\n            loss = loss_vel + loss_perceptual"
7}