Agent Conversation: deepseek-reasoner
Task: cv-meanflow-perceptual-loss
Flow Matching with Perceptual Loss
Background
Flow matching trains a neural network to predict velocity fields that transport samples from noise to data. Traditional training uses only MSE loss on the predicted velocity:
loss = ||v_pred - v_target||^2
However, we can also compute the denoised image from the predicted velocity:
x_denoised = x_t - t * v_pred
We can then apply perceptual losses (LPIPS, gradient loss, etc.) to x_denoised, encouraging the network to generate high-quality images rather than merely accurate velocities.
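As a concrete sketch of this denoise-and-score step (helper names and the simplified noising path x_t = x + t * noise are illustrative assumptions, not the script's exact API):

```python
import torch

def denoised_from_velocity(x_t, t, v_pred):
    """Invert the noising path using the predicted velocity:
    one step back from the noised input toward the data sample."""
    return x_t - t * v_pred

def auxiliary_loss(x_t, t, v_pred, x_target, perceptual_fn):
    # Clamp to the image range before scoring; perceptual metrics
    # typically expect inputs in [-1, 1].
    x_denoised = denoised_from_velocity(x_t, t, v_pred).clamp(-1, 1)
    return perceptual_fn(x_denoised, x_target)

# Toy usage with an L1 stand-in for the perceptual metric:
x = torch.zeros(2, 3, 32, 32)            # clean "images"
noise = torch.randn(2, 3, 32, 32)
t = torch.full((2, 1, 1, 1), 0.5)
x_t = x + t * noise                      # simplified noising path
# With a perfect velocity prediction (v_pred == noise here), the
# denoised image matches x exactly, so the auxiliary loss is zero.
loss = auxiliary_loss(x_t, t, noise, x, lambda a, b: (a - b).abs().mean())
```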
Research Question
Can adding perceptual losses to flow matching training improve FID scores?
Task
You are given custom_train_perceptual.py, a self-contained training script that trains a
small DiT on CIFAR-10 (32x32) using flow matching with mean velocity objectives.
The editable region contains the loss computation in the training loop:
# Current: MSE loss only
loss_mse = ((pred_mean_vel - mean_vel_target) ** 2).mean()
loss = loss_mse
The fixed code already exposes:
- lpips_fn(x_denoised, x_target): perceptual loss (LPIPS)
- compute_gradient_loss(x_denoised, x_target): gradient-domain loss
- compute_multiscale_loss(x_denoised, x_target): multi-resolution loss
Key constraint: Only apply auxiliary losses when t > 0.1 to avoid instability at small noise levels.
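One way to honor this constraint (a sketch; how you normalize the masked average is a design choice) is a per-sample mask that zeroes the auxiliary term for small t:

```python
import torch

def masked_aux_mean(aux_per_sample, t, t_min=0.1):
    """Average the auxiliary loss only over samples with t > t_min,
    so small-t samples contribute no (potentially unstable) gradient."""
    mask = (t > t_min).float()
    denom = mask.sum().clamp(min=1.0)   # avoid 0/0 when all t <= t_min
    return (mask * aux_per_sample).sum() / denom

t = torch.tensor([0.05, 0.50, 0.90])
aux = torch.tensor([10.0, 1.0, 2.0])    # per-sample auxiliary losses
loss_aux = masked_aux_mean(aux, t)      # averages only 1.0 and 2.0
```

Normalizing by the number of surviving samples (rather than the full batch size) keeps the auxiliary term's scale independent of how many timesteps happen to fall below t_min in a given batch.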
Evaluation
- Dataset: CIFAR-10 (32x32)
- Model: SmallDiT (512 hidden, 8 layers, ~40M params)
- Training: 10000 steps, batch size 128
- Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set
- Inference: 10-step Euler sampler
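For orientation, a 10-step Euler sampler for a velocity-field ODE can be sketched as follows (the sign convention and time direction are assumptions; the script's sampler is fixed and may differ):

```python
import torch

@torch.no_grad()
def euler_sample(net, x_init, steps=10):
    """Integrate dx/dt = v(x, t) from t=1 (noise) down to t=0 (data)
    with fixed-size Euler steps."""
    x = x_init
    ts = torch.linspace(1.0, 0.0, steps + 1)
    for i in range(steps):
        dt = ts[i + 1] - ts[i]          # negative: stepping toward t=0
        x = x + dt * net(x, ts[i])
    return x

# Toy check: with v(x, t) = x, each step multiplies x by (1 + dt) = 0.9,
# so 10 steps shrink the input by roughly 0.9**10.
x0 = torch.ones(1, 3, 4, 4)
out = euler_sample(lambda x, t: x, x0, steps=10)
```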
Baselines
- mse-only: Pure MSE loss on velocity
- mse-lpips: MSE + LPIPS perceptual loss (VGG features)
- mse-lpips-grad: MSE + LPIPS + Gradient loss with timestep-adaptive weighting
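Timestep-adaptive weighting can take several forms; one option, sketched below (a Gaussian bump is an assumption, not necessarily the baseline's schedule), up-weights mid-range timesteps where the denoised estimate is informative but not yet trivial:

```python
import torch

def gaussian_t_weight(t, center=0.5, std=0.2):
    """Weight that peaks at t = center and decays symmetrically
    toward both ends of the [0, 1] range."""
    return torch.exp(-((t - center) ** 2) / (2.0 * std ** 2))

w = gaussian_t_weight(torch.tensor([0.1, 0.5, 0.9]))
# Peak weight 1.0 at t = 0.5; 0.1 and 0.9 get the same reduced weight.
```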
alphaflow-main/custom_train_perceptual.py [EDITABLE — lines 411–428 only]
1: """Custom Flow Matching Training Script — Perceptual Loss Variant
2: Small-scale flow matching training on CIFAR-10 with a lightweight DiT.
3: The training objective (MeanFlow) is pre-implemented; your task is to
4: design an improved loss function, optionally using perceptual losses.
5: """
6:
7: import math
8: import os
9: import time
10:
11: import lpips
12: import numpy as np
13: import torch
14: import torch.nn as nn
15: import torch.nn.functional as F
16: from torch.autograd.functional import jvp
17: from torchvision import datasets, transforms
18: from torchvision.utils import save_image
19: from perceptual_utils import compute_gradient_loss, compute_multiscale_loss
20:
21: # ============================================================================
22: # Model: Lightweight DiT for CIFAR-10 (32x32)
23: # ============================================================================
24:
25: def modulate(x, shift, scale):
Tool Input
1{
2 "op": "replace",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "start_line": 411,
5 "end_line": 428,
6 "content": " # Compute denoised image from predicted velocity\n x_denoised = x_t - t * pred_mean_vel\n\n # MSE loss on velocity (per-sample)\n loss_mse_unscaled = ((pred_mean_vel - mean_vel_target) ** 2).flatten(1).mean(1)\n\n # Adaptive weighting for stability (as in baseline)\n t_flat = t.view(B)\n t_next_flat = t_next.view(B)\n dt_flat = dt.view(B)\n mask_d = dt_flat > 0\n weight = torch.ones(B, device=device)\n if mask_d.any():\n weight[mask_d] = alpha\n weight = weight / (loss_mse_unscaled.detach() + 1e-3)\n\n # Timestep-adaptive weighting for perceptual losses (Gaussian centered at t=0.5)\n # Only apply perceptual losses when t > 0.1 (stability)\n mask_aux = (t_flat > 0.1)\n # Gaussian weight: exp(-(t-0.5)^2 / (2*0.2^2)), peak at 0.5, std 0.2\n t_centered = t_flat - 0.5\n gaussian_weight = torch.exp(-(t_centered ** 2) / (2 * 0.04)) # 0.04 = 0.2^2\n # Zero weight when t <= 0.1\n gaussian_weight = gaussian_weight * mask_aux.float()\n\n # Initialize perceptual loss tensors\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multiscale = torch.zeros(B, device=device)\n\n if mask_aux.any():\n # Clamp to [-1, 1] for perceptual losses\n x_d_clamped = x_denoised[mask_aux].clamp(-1, 1)\n x_clamped = x[mask_aux].clamp(-1, 1)\n\n # LPIPS loss\n loss_lpips_per_sample = lpips_fn(x_d_clamped, x_clamped).view(-1).float()\n loss_lpips[mask_aux] = loss_lpips_per_sample\n\n # Gradient loss (edge sharpness)\n loss_grad_scalar = compute_gradient_loss(x_d_clamped, x_clamped)\n loss_grad[mask_aux] = loss_grad_scalar\n\n # Multiscale loss (structural similarity across scales)\n loss_multiscale_scalar = compute_multiscale_loss(x_d_clamped, x_clamped)\n loss_multiscale[mask_aux] = loss_multiscale_scalar\n\n # Combine losses with coefficients\n # Base MSE weight: 1.0 (already weighted by adaptive weight)\n # Perceptual weights: tuned hyperparameters\n lambda_lpips = 0.1\n lambda_grad = 0.05\n lambda_multiscale = 0.05\n\n # Weighted loss per sample\n loss_per_sample = weight * loss_mse_unscaled + \\\n gaussian_weight * (lambda_lpips * loss_lpips + \n lambda_grad * loss_grad + \n lambda_multiscale * loss_multiscale)\n\n # Final scalar loss\n loss = loss_per_sample.mean()"
7}
OK: Replaced lines 411..428 with 62 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–472. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–472 | total: 529 lines]
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 21.86, Best FID: 21.50
[Leaderboard] Results saved: {'fid': 21.86, 'best_fid': 21.5, 'elapsed_train_small': 6172}
Tool Input
1{
2 "op": "replace",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "start_line": 411,
5 "end_line": 472,
6 "content": " # Compute denoised image from predicted velocity\n x_denoised = x_t - t * pred_mean_vel\n\n # MSE loss on velocity (per-sample)\n loss_mse_unscaled = ((pred_mean_vel - mean_vel_target) ** 2).flatten(1).mean(1)\n\n # Adaptive weighting for stability (as in baseline)\n t_flat = t.view(B)\n t_next_flat = t_next.view(B)\n dt_flat = dt.view(B)\n mask_d = dt_flat > 0\n weight = torch.ones(B, device=device)\n if mask_d.any():\n weight[mask_d] = alpha\n weight = weight / (loss_mse_unscaled.detach() + 1e-3)\n\n # Timestep-adaptive weighting for perceptual losses (linear increase with t)\n # Only apply perceptual losses when t > 0.1 (stability)\n mask_aux = (t_flat > 0.1)\n # Linear weight: 0 at t=0.1, 1 at t=1.0\n linear_weight = (t_flat - 0.1) / 0.9\n linear_weight = linear_weight.clamp(min=0) # ensure non-negative\n # Zero weight when t <= 0.1\n linear_weight = linear_weight * mask_aux.float()\n\n # Initialize perceptual loss tensors\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multiscale = torch.zeros(B, device=device)\n loss_freq = torch.zeros(B, device=device)\n\n if mask_aux.any():\n # Clamp to [-1, 1] for perceptual losses\n x_d_clamped = x_denoised[mask_aux].clamp(-1, 1)\n x_clamped = x[mask_aux].clamp(-1, 1)\n\n # LPIPS loss\n loss_lpips_per_sample = lpips_fn(x_d_clamped, x_clamped).view(-1).float()\n loss_lpips[mask_aux] = loss_lpips_per_sample\n\n # Gradient loss (edge sharpness)\n loss_grad_scalar = compute_gradient_loss(x_d_clamped, x_clamped)\n loss_grad[mask_aux] = loss_grad_scalar\n\n # Multiscale loss (structural similarity across scales)\n loss_multiscale_scalar = compute_multiscale_loss(x_d_clamped, x_clamped)\n loss_multiscale[mask_aux] = loss_multiscale_scalar\n\n # Frequency-domain loss (FFT magnitude + phase)\n # Compute FFT of both images\n fft_d = torch.fft.rfft2(x_d_clamped, dim=(-2, -1))\n fft_clean = torch.fft.rfft2(x_clamped, dim=(-2, -1))\n # Magnitude loss\n mag_loss = torch.abs(torch.abs(fft_d) - torch.abs(fft_clean)).mean(dim=(1,2,3))\n # Phase loss (circular distance)\n phase_loss = 1 - torch.cos(torch.angle(fft_d) - torch.angle(fft_clean))\n phase_loss = phase_loss.mean(dim=(1,2,3))\n # Combine magnitude and phase losses\n loss_freq_per_sample = mag_loss + 0.5 * phase_loss\n loss_freq[mask_aux] = loss_freq_per_sample\n\n # Combine losses with coefficients\n # Base MSE weight: 1.0 (already weighted by adaptive weight)\n # Perceptual weights: tuned hyperparameters\n lambda_lpips = 0.05\n lambda_grad = 0.02\n lambda_multiscale = 0.02\n lambda_freq = 0.03\n\n # Weighted loss per sample\n loss_per_sample = weight * loss_mse_unscaled + \\\n linear_weight * (lambda_lpips * loss_lpips + \n lambda_grad * loss_grad + \n lambda_multiscale * loss_multiscale +\n lambda_freq * loss_freq)\n\n # Final scalar loss\n loss = loss_per_sample.mean()"
7}
OK: Replaced lines 411..472 with 78 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–488. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–488 | total: 545 lines]
Tool Input
1{
2 "op": "replace",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "start_line": 411,
5 "end_line": 428,
6 "content": " # MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_complex = torch.zeros(B, device=device)\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n # FFT magnitude L1: per-channel rfft2, abs, L1 of difference\n xd_fft = torch.fft.rfft2(xd, norm='ortho')\n xc_fft = torch.fft.rfft2(xc, norm='ortho')\n loss_spec[mask] = torch.abs(torch.abs(xd_fft) - torch.abs(xc_fft)).mean(dim=(1,2,3)).float()\n # Complex FFT loss: L1 on complex difference (magnitude + phase)\n loss_complex[mask] = torch.abs(xd_fft - xc_fft).mean(dim=(1,2,3)).float()\n\n # Combine losses with weights\n w_mse = 1.0\n w_lpips = 0.1\n w_grad = 0.05\n w_multi = 0.01\n w_spec = 0.05\n w_complex = 0.05\n \n loss_per_sample = (w_mse * loss_mse_unscaled +\n perceptual_w * (w_lpips * loss_lpips +\n w_grad * loss_grad +\n w_multi * loss_multi +\n w_spec * loss_spec +\n w_complex * loss_complex))\n loss = loss_per_sample.mean()"
7}
OK: Replaced lines 411..428 with 43 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–453. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–453 | total: 510 lines]
Tool Input
1{
2 "op": "replace",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "start_line": 411,
5 "end_line": 428,
6 "content": " # MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_spec = torch.zeros(B, device=device)\n loss_complex = torch.zeros(B, device=device)\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n # FFT magnitude L1: per-channel rfft2, abs, L1 of difference\n xd_fft = torch.fft.rfft2(xd, dim=(-2, -1))\n xc_fft = torch.fft.rfft2(xc, dim=(-2, -1))\n loss_spec[mask] = (xd_fft.abs() - xc_fft.abs()).abs().flatten(1).mean(1).float()\n # Complex loss: L1 on real and imaginary parts\n loss_real = (xd_fft.real - xc_fft.real).abs().flatten(1).mean(1)\n loss_imag = (xd_fft.imag - xc_fft.imag).abs().flatten(1).mean(1)\n loss_complex[mask] = (loss_real + loss_imag).float()\n\n # Combine losses\n loss_vel = loss_mse_unscaled.mean()\n loss_perceptual = (perceptual_w * (loss_lpips + loss_grad + loss_multi + loss_spec + loss_complex)).mean()\n loss = loss_vel + loss_perceptual"
7}
OK: Replaced lines 411..428 with 34 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–444. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–444 | total: 501 lines]
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 26.73, Best FID: 26.94
### train_medium (scripts/train_medium.sh)
Training results:
FID: 32.58, Best FID: 26.61
[Leaderboard] Results saved: {'fid': 32.58, 'best_fid': 26.61, 'fid_small': 26.73, 'best_fid_small': 26.94, 'elapsed_train_small': 6137, 'fid_medium': 32.58, 'best_fid_medium': 26.61, 'elapsed_train_medium': 7894}
Tool Input
1{
2 "op": "replace",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "start_line": 411,
5 "end_line": 444,
6 "content": " # MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_mag = torch.zeros(B, device=device)\n loss_phase = torch.zeros(B, device=device)\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n # FFT magnitude L1\n xd_fft = torch.fft.rfft2(xd, dim=(-2, -1))\n xc_fft = torch.fft.rfft2(xc, dim=(-2, -1))\n loss_mag[mask] = (xd_fft.abs() - xc_fft.abs()).abs().flatten(1).mean(1).float()\n # Phase consistency loss: normalized complex difference\n # Normalize by magnitude to focus on phase\n eps = 1e-6\n xd_mag = xd_fft.abs().clamp(min=eps)\n xc_mag = xc_fft.abs().clamp(min=eps)\n xd_unit = xd_fft / xd_mag\n xc_unit = xc_fft / xc_mag\n # L1 on real and imag parts of normalized spectra\n loss_real = (xd_unit.real - xc_unit.real).abs().flatten(1).mean(1)\n loss_imag = (xd_unit.imag - xc_unit.imag).abs().flatten(1).mean(1)\n loss_phase[mask] = (loss_real + loss_imag).float()\n\n # Combine losses\n loss_vel = loss_mse_unscaled.mean()\n loss_perceptual = (perceptual_w * (loss_lpips + loss_grad + loss_multi + loss_mag + loss_phase)).mean()\n loss = loss_vel + loss_perceptual"
7}
OK: Replaced lines 411..444 with 41 line(s) in alphaflow-main/custom_train_perceptual.py. Editable range: 411–451. [Current file: alphaflow-main/custom_train_perceptual.py | editable: 411–451 | total: 508 lines]
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### train_small (scripts/train_small.sh)
Training results:
FID: 23.59, Best FID: 23.94
### train_medium (scripts/train_medium.sh)
Training results:
FID: 24.85, Best FID: 23.36
[Leaderboard] Results saved: {'fid': 24.85, 'best_fid': 23.36, 'fid_small': 23.59, 'best_fid_small': 23.94, 'elapsed_train_small': 6114, 'fid_medium': 24.85, 'best_fid_medium': 23.36, 'elapsed_train_medium': 7897}
Tool Input
1{
2 "op": "replace",
3 "filename": "alphaflow-main/custom_train_perceptual.py",
4 "start_line": 411,
5 "end_line": 451,
6 "content": " # MSE on velocity\n err = pred_mean_vel - mean_vel_target\n loss_mse_unscaled = (err ** 2).flatten(1).mean(1)\n\n # Auxiliary perceptual losses on denoised image (mask t<=0.1 edge case)\n x_denoised = x_t - t * pred_mean_vel\n t_flat = t.view(B)\n mask = (t_flat > 0.1)\n perceptual_w = ((1.0 - t_flat) ** 2) * mask.float()\n\n loss_lpips = torch.zeros(B, device=device)\n loss_grad = torch.zeros(B, device=device)\n loss_multi = torch.zeros(B, device=device)\n loss_mag = torch.zeros(B, device=device)\n loss_phase = torch.zeros(B, device=device)\n if mask.any():\n xd = x_denoised[mask].clamp(-1, 1).float()\n xc = x[mask].clamp(-1, 1).float()\n loss_lpips[mask] = lpips_fn(xd, xc).view(-1).float()\n loss_grad[mask] = compute_gradient_loss(xd, xc).float()\n loss_multi[mask] = compute_multiscale_loss(xd, xc).float()\n # FFT magnitude L1\n xd_fft = torch.fft.rfft2(xd, dim=(-2, -1))\n xc_fft = torch.fft.rfft2(xc, dim=(-2, -1))\n loss_mag[mask] = (xd_fft.abs() - xc_fft.abs()).abs().flatten(1).mean(1).float()\n # Phase consistency loss: normalized complex difference\n eps = 1e-6\n xd_mag = xd_fft.abs().clamp(min=eps)\n xc_mag = xc_fft.abs().clamp(min=eps)\n xd_unit = xd_fft / xd_mag\n xc_unit = xc_fft / xc_mag\n loss_real = (xd_unit.real - xc_unit.real).abs().flatten(1).mean(1)\n loss_imag = (xd_unit.imag - xc_unit.imag).abs().flatten(1).mean(1)\n loss_phase[mask] = (loss_real + loss_imag).float()\n\n # Weighting factors for each perceptual loss\n w_lpips = 1.0\n w_grad = 0.1\n w_multi = 0.1\n w_mag = 0.5\n w_phase = 0.5\n perceptual_total = (w_lpips * loss_lpips + w_grad * loss_grad + w_multi * loss_multi +\n w_mag * loss_mag + w_phase * loss_phase)\n \n # Overall perceptual scaling\n lambda_perceptual = 0.5\n loss_vel = loss_mse_unscaled.mean()\n loss_perceptual = lambda_perceptual * (perceptual_w * perceptual_total).mean()\n loss = loss_vel + loss_perceptual"
7}