cv-vae-loss

Tags: Computer Vision · diffusers-main · rigorous codebase

Description

VAE Loss Function Design for Image Reconstruction

Objective

Design a training loss function for a Variational Autoencoder (VAE) that achieves the best reconstruction quality on CIFAR-10.

Background

Variational Autoencoders encode images into a compressed latent representation and decode them back. The quality of reconstruction depends critically on the training loss function. Standard approaches use combinations of:

  • Reconstruction loss: L1 or L2 pixel-level error
  • KL divergence: Regularizes the latent space toward a standard normal prior
  • Perceptual loss: LPIPS or VGG-based feature matching for perceptual quality
  • Adversarial loss: Discriminator-based training for sharpness
  • Frequency-domain loss: FFT-based weighting to preserve fine detail

Recent work on the Prism Hypothesis (UAE, Fan et al.) demonstrates that explicitly handling different frequency bands in the training objective can significantly improve reconstruction quality. The key insight is that semantic information concentrates at low frequencies while fine perceptual detail lives in higher bands.
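One way to act on that insight is to weight the reconstruction error differently per frequency band. The sketch below is illustrative only, assuming a simple radial mask in the shifted FFT plane; the cutoff and band weights are hypothetical hyperparameters, not values from the cited work.

```python
import torch

def frequency_weighted_loss(recon, target, cutoff=0.25, low_w=1.0, high_w=2.0):
    """Split the pixel-space error into low/high frequency bands via a 2-D FFT.

    recon, target: [B, C, H, W] images in [-1, 1].
    cutoff: radius (fraction of the spectrum) separating the bands.
    low_w, high_w: illustrative band weights (tuning knobs, not from the paper).
    """
    # FFT of the residual, shifted so the DC component sits at the center
    err = torch.fft.fftshift(torch.fft.fft2(recon - target), dim=(-2, -1))
    H, W = err.shape[-2:]
    fy = torch.linspace(-0.5, 0.5, H, device=err.device).view(-1, 1)
    fx = torch.linspace(-0.5, 0.5, W, device=err.device).view(1, -1)
    radius = torch.sqrt(fx**2 + fy**2)        # normalized spatial frequency
    low_mask = (radius <= cutoff).float()     # 1 inside the low-frequency band
    mag = err.abs()                           # spectral error magnitude
    low = (mag * low_mask).mean()
    high = (mag * (1.0 - low_mask)).mean()
    return low_w * low + high_w * high
```

Upweighting the high band penalizes blur (lost fine detail) more than the low band, which is already well captured by a plain L2 term.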

Task

Implement the VAELoss class in custom_train.py (lines 32–76). Your loss function will be used to train an AutoencoderKL model from the diffusers library on CIFAR-10 32×32 images.

Editable Region (lines 32–76)

class VAELoss(nn.Module):
    def __init__(self, device):
        super().__init__()
        # Initialize your loss components here

    def forward(self, recon, target, posterior, step):
        # recon:     [B, 3, 32, 32] reconstructed images in [-1, 1]
        # target:    [B, 3, 32, 32] original images in [-1, 1]
        # posterior:  DiagonalGaussianDistribution
        #             - posterior.kl() -> KL divergence per sample
        #             - posterior.mean, posterior.logvar
        # step:       current training step (int)
        #
        # Return: (loss_tensor, metrics_dict)
        ...
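As a starting point, one possible fill-in for this region is a plain MSE + KL objective with a linear KL warmup. This is a sketch, not the reference solution: the `kl_weight` and `kl_warmup_steps` values are illustrative, and an LPIPS term (via `lpips.LPIPS(net='vgg')`) could be added on top.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAELoss(nn.Module):
    """Illustrative baseline: MSE reconstruction + warmed-up KL regularization."""

    def __init__(self, device, kl_weight=1e-6, kl_warmup_steps=1000):
        super().__init__()
        self.kl_weight = kl_weight            # illustrative final KL weight
        self.kl_warmup_steps = kl_warmup_steps

    def forward(self, recon, target, posterior, step):
        rec = F.mse_loss(recon, target)       # pixel-level L2 error
        kl = posterior.kl().mean()            # per-sample KL, averaged over batch
        # Ramp the KL weight up linearly so early training focuses on reconstruction
        w = self.kl_weight * min(1.0, step / self.kl_warmup_steps)
        loss = rec + w * kl
        return loss, {"rec": rec.item(), "kl": kl.item(), "kl_weight": w}
```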

Available Libraries

  • torch, torch.nn, torch.nn.functional — standard PyTorch
  • torch.fft — frequency-domain operations (fft2, ifft2, fftshift, etc.)
  • lpips — learned perceptual loss: lpips.LPIPS(net='vgg').to(device)
  • numpy, math

Architecture (Fixed)

The model is AutoencoderKL from diffusers with 3 blocks and 2 downsample stages, giving latent resolution 8×8 (f=4 compression) suited for 32×32 input:

  • latent_channels=4, layers_per_block=2
  • GroupNorm (32 groups) + SiLU activation

Channel widths and latent channels scale via environment variables:

  • Small: BLOCK_OUT_CHANNELS=(64,128,256), LATENT_CHANNELS=4 — lightweight
  • Medium: BLOCK_OUT_CHANNELS=(96,192,384), LATENT_CHANNELS=8 — standard
  • Large: BLOCK_OUT_CHANNELS=(128,256,512), LATENT_CHANNELS=16 — wide

Training (Fixed)

  • Optimizer: AdamW, lr=4e-4, weight_decay=1e-4
  • LR schedule: 5% warmup + cosine decay
  • Mixed precision (autocast + GradScaler)
  • Gradient clipping at 1.0
  • EMA with rate 0.999
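The schedule and EMA pieces of this recipe can be sketched as below. These are minimal reference implementations of the stated rules (5% warmup + cosine decay, EMA rate 0.999), not the harness's exact code.

```python
import math
import torch

def lr_at(step, total_steps, base_lr=4e-4, warmup_frac=0.05):
    """5% linear warmup followed by cosine decay to zero."""
    warmup = max(1, int(warmup_frac * total_steps))
    if step < warmup:
        return base_lr * step / warmup
    t = (step - warmup) / max(1, total_steps - warmup)  # progress in [0, 1]
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))

@torch.no_grad()
def ema_update(ema_model, model, rate=0.999):
    """In-place exponential moving average of parameters."""
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(rate).add_(p, alpha=1.0 - rate)
```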

Evaluation

Reconstruction quality is measured on the full CIFAR-10 test set (10,000 images):

| Metric | Direction | Description |
| --- | --- | --- |
| rFID | lower is better | Reconstruction FID between original and reconstructed test images (primary metric) |
| PSNR | higher is better | Peak signal-to-noise ratio in dB |
| SSIM | higher is better | Structural similarity index |
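For reference, PSNR on images in [-1, 1] can be computed as below (a sketch; the evaluation harness's exact implementation may differ, e.g. by rescaling to [0, 1] first, which gives the same value).

```python
import torch

def psnr(recon, target, data_range=2.0):
    """PSNR in dB; data_range = max - min = 2.0 for images in [-1, 1]."""
    mse = torch.mean((recon - target) ** 2)
    return 10.0 * torch.log10(data_range ** 2 / mse)
```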

Training scales:

  • Small (group 1): 20,000 steps, channels (64,128,256), latent_channels=4
  • Medium (group 2): 30,000 steps, channels (96,192,384), latent_channels=8
  • Large (group 3): 30,000 steps, channels (128,256,512), latent_channels=16

Baselines

| Name | Strategy | Description |
| --- | --- | --- |
| l2-kl | MSE + KL | Simplest VAE loss: pixel-level L2 reconstruction + KL regularization |
| perceptual | MSE + LPIPS + KL | Standard practice: adds learned perceptual similarity (VGG features) |
| vqgan | L1 + LPIPS + GAN + KL | Multi-objective: combines L1 reconstruction, LPIPS perceptual loss, and PatchGAN adversarial training |

Code

custom_train.py
"""VAE Training on CIFAR-10 with configurable loss function.

Uses AutoencoderKL architecture (fixed). Only the loss function is editable.
"""

import copy
import math
import os
import shutil
import sys
import time
from datetime import timedelta

import numpy as np
import torch

Results

| Model | Type | best rFID (small) | best rFID (medium) | best rFID (large) | PSNR (medium) | PSNR (small) | SSIM (medium) | SSIM (small) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| l2-kl | baseline | 53.000 | 15.750 | 5.350 | - | - | - | - |
| perceptual | baseline | 19.640 | 10.690 | 3.840 | - | - | - | - |
| vqgan | baseline | 15.800 | 7.130 | 3.380 | - | - | - | - |
| anthropic/claude-opus-4.6 | vanilla | - | 7.900 | - | 32.040 | - | 0.974 | - |
| deepseek-reasoner | vanilla | - | 14.060 | - | 22.520 | - | 0.860 | - |
| google/gemini-3.1-pro-preview | vanilla | 21.040 | 11.070 | - | 30.710 | 26.220 | 0.967 | 0.914 |
| qwen/qwen3.6-plus | vanilla | 18.200 | 10.910 | - | 29.820 | 25.540 | 0.961 | 0.909 |
| anthropic/claude-opus-4.6 | agent | 24.280 | 7.900 | 4.050 | - | - | - | - |
| deepseek-reasoner | agent | 20.080 | 7.470 | 3.980 | - | - | - | - |
| google/gemini-3.1-pro-preview | agent | 20.940 | 9.490 | 3.430 | - | - | - | - |
| qwen/qwen3.6-plus | agent | 18.200 | 10.910 | 4.360 | - | - | - | - |

Agent Conversations