Diffusion-Prior Inverse Solver

Studies how diffusion priors and measurement guidance can be combined for inverse-problem reconstruction.

AI for Science · InverseBench
ai4sci-inverse-diffusion-algo

Description

Task: Inverse Problem Algorithm Design with Diffusion Priors

Research Question

Design a novel algorithm for solving scientific inverse problems using pre-trained diffusion model priors. Given a forward operator A and observation y = A(x) + noise, the algorithm should reconstruct x by leveraging a learned diffusion prior p(x).
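Assume the noise is i.i.d. Gaussian with standard deviation $\sigma_y$. This is an assumption: the inpainting task below states σ = 0.05, and the other tasks are not specified here. Under it, the likelihood and the posterior score that the approaches below work with are:

$$\log p(y \mid x) = -\frac{1}{2\sigma_y^2}\,\|A(x) - y\|^2 + \text{const}, \qquad \nabla_{x_t} \log p_t(x_t \mid y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y \mid x_t).$$

The first term of the posterior score is supplied by the diffusion prior. The second term, the likelihood at noise level $t$, is intractable in general, so each method needs an approximation or a reformulation of it.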

Background

Diffusion models learn rich priors p(x) over signal distributions. For inverse problems, we want to sample from the posterior p(x|y) ∝ p(y|x) p(x). Existing approaches include:

  • DPS — Diffusion Posterior Sampling (Chung et al., "Diffusion Posterior Sampling for General Noisy Inverse Problems", ICLR 2023; arXiv:2209.14687). Combines the diffusion model's score ∇_{x_t} log p_t(x_t) with measurement guidance at each denoising step. It approximates the intractable ∇_{x_t} log p(y|x_t) by ∇_{x_t} log p(y|x̂_0(x_t)), where x̂_0 is the denoiser's Tweedie estimate. Code: https://github.com/DPS2022/diffusion-posterior-sampling.
  • REDDiff — Variational / Regularization-by-Denoising-Diffusion (Mardani, Song, Kautz, Vahdat, "A Variational Perspective on Solving Inverse Problems with Diffusion Models", ICLR 2024; arXiv:2305.04391). Variational formulation that yields a regularization-by-denoising update where denoisers at different timesteps concurrently impose structural constraints. Code: https://github.com/NVlabs/RED-diff.
  • LGD — Loss-Guided Diffusion (Song et al., "Loss-Guided Diffusion Models for Plug-and-Play Controllable Generation", ICML 2023). Estimates the guidance term by Monte Carlo sampling around the denoised estimate. This reduces the bias of the point-estimate approximation used by DPS.
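A single DPS-style update can be sketched on a toy problem where everything is analytic. The following are assumptions, not the benchmark's setup:

- the prior is $x_0 \sim \mathcal N(0, s^2 I)$;
- noising is $x_t = x_0 + \sigma\epsilon$, so Tweedie's estimate is $\hat x_0 = \tfrac{s^2}{s^2+\sigma^2}x_t$ and its Jacobian is a scaled identity;
- the forward operator is a dense matrix `A`.

With a neural denoiser, the chain-rule product in `dps_guidance` is what autograd computes when it backpropagates through the network. The DPS paper uses an ancestral DDPM step and scales the guidance by $1/\|y - A\hat x_0\|$. Here an Euler probability-flow step and a fixed `zeta` are used instead, and all function names are hypothetical.

```python
import numpy as np

# Toy assumptions (not the benchmark): prior x0 ~ N(0, s2*I), VE-style noising
# x_t = x0 + sigma*eps, and a dense linear forward operator A.


def tweedie(x_t, sigma, s2=1.0):
    """Exact E[x0 | x_t] for the Gaussian prior above."""
    return s2 / (s2 + sigma ** 2) * x_t


def dps_guidance(x_t, y, A, sigma, s2=1.0):
    """Gradient of ||A x0_hat(x_t) - y||^2 with respect to x_t."""
    c = s2 / (s2 + sigma ** 2)                  # Jacobian d x0_hat / d x_t = c * I
    grad_x0 = 2.0 * A.T @ (A @ (c * x_t) - y)   # gradient w.r.t. x0_hat
    return c * grad_x0                          # chain rule back to x_t


def dps_step(x_t, y, A, sigma, sigma_next, zeta, s2=1.0):
    """Euler PF-ODE step from sigma to sigma_next, then the DPS correction."""
    x0_hat = tweedie(x_t, sigma, s2)
    d = (x_t - x0_hat) / sigma                  # dx/dsigma for sigma(t) = t
    x_next = x_t + (sigma_next - sigma) * d
    return x_next - zeta * dps_guidance(x_t, y, A, sigma, s2)
```

The guidance gradient is evaluated at the current state $x_t$ and subtracted from the next state. This matches the ordering of the update in the DPS algorithm.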

What to Implement

Implement the Custom class in algo/custom.py. You must implement:

  1. __init__: Set up your algorithm (schedulers, optimizers, hyperparameters).
  2. inference(observation, num_samples): Given observation y, return reconstructed x.
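The two required methods can be sketched with a stand-in `Algo` base class that mirrors the `base.py` signature shown in the Code section, so the block runs on its own. Everything else is hypothetical:

- NumPy arrays replace torch tensors.
- `shape` is passed explicitly. In the benchmark it would come from `self.net.img_channels` and `self.net.img_resolution`.
- The loop is a simple denoise, correct, re-noise scheme chosen for brevity. It is not any of the published methods above.
- `GaussianNet` and `MaskOp` are toy stand-ins that only copy the call signatures of `self.net` and `self.forward_op`.

```python
import numpy as np
from abc import ABC, abstractmethod


class Algo(ABC):
    """Stand-in mirroring algo/base.py so this sketch runs on its own."""

    def __init__(self, net, forward_op):
        self.net = net
        self.forward_op = forward_op

    @abstractmethod
    def inference(self, observation, num_samples=1, **kwargs):
        ...


class Custom(Algo):
    def __init__(self, net, forward_op, num_steps=50, sigma_max=10.0,
                 sigma_min=0.01, data_lr=0.1, data_iters=5, seed=0):
        super().__init__(net, forward_op)
        # Hypothetical choice: geometric noise levels from sigma_max down to sigma_min.
        self.sigmas = np.geomspace(sigma_max, sigma_min, num_steps)
        self.data_lr = data_lr
        self.data_iters = data_iters
        self.rng = np.random.default_rng(seed)

    def inference(self, observation, num_samples=1, shape=(8,), **kwargs):
        samples = []
        for _ in range(num_samples):
            x = self.sigmas[0] * self.rng.standard_normal(shape)
            for i, sigma in enumerate(self.sigmas):
                x0 = self.net(x, sigma)                   # prior: Tweedie estimate
                for _ in range(self.data_iters):          # likelihood: descend ||A(x0) - y||^2
                    g, _ = self.forward_op.gradient(x0, observation, return_loss=True)
                    x0 = x0 - self.data_lr * g
                if i + 1 < len(self.sigmas):              # re-noise to the next level
                    x = x0 + self.sigmas[i + 1] * self.rng.standard_normal(shape)
                else:
                    x = x0
            samples.append(x)
        return np.stack(samples)


class GaussianNet:
    """Toy denoiser: exact E[x0 | x_t] for x0 ~ N(0, I), x_t = x0 + sigma*eps."""

    def __call__(self, x, sigma):
        return x / (1.0 + sigma ** 2)


class MaskOp:
    """Toy operator with the forward / loss / gradient methods listed below."""

    def __init__(self, mask):
        self.mask = mask

    def forward(self, x):
        return self.mask * x

    def loss(self, x, y):
        return float(np.sum((self.forward(x) - y) ** 2))

    def gradient(self, x, y, return_loss=False):
        g = 2.0 * self.mask * (self.forward(x) - y)
        return (g, self.loss(x, y)) if return_loss else g


mask = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=float)
op = MaskOp(mask)
y = op.forward(np.arange(8, dtype=float))
algo = Custom(GaussianNet(), op)
out = algo.inference(y, num_samples=2)   # shape (2, 8)
```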

Available Components

  • self.net(x, sigma) → denoised estimate (Tweedie's formula: E[x_0 | x_t]).
  • self.forward_op.forward(x) → compute A(x).
  • self.forward_op.gradient(x, y, return_loss=True) → (∇_x ||A(x) - y||², loss).
  • self.forward_op.loss(x, y) → ||A(x) - y||².
  • Scheduler(num_steps, schedule, timestep, scaling) → diffusion noise schedule.
  • DiffusionSampler(scheduler).sample(model, x_start) → unconditional sampling.
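Tweedie's formula links the denoiser to the score. For a noisy state $x_t = x_0 + \sigma\epsilon$ with denoiser $D$, the score is $\nabla\log p_\sigma(x) = (D(x;\sigma) - x)/\sigma^2$.

The sketch below assumes the EDM-style parameterization $\sigma(t) = t$ with no scaling, which is the `schedule='linear'`, `scaling='none'` case. It is an assumption that `DiffusionSampler(..., solver='euler')` integrates exactly this ODE.

A Gaussian prior $\mathcal N(0, s^2 I)$ is a hypothetical stand-in for `self.net`. It is used because its denoiser and its ODE solution, $x(\sigma) \propto \sqrt{s^2 + \sigma^2}$, are both known in closed form, which makes the Euler integrator easy to check.

```python
import numpy as np


def gaussian_denoiser(x, sigma, s2=4.0):
    """Exact E[x0 | x_t] for x0 ~ N(0, s2*I), x_t = x0 + sigma*eps (toy stand-in for self.net)."""
    return s2 / (s2 + sigma ** 2) * x


def score_from_denoiser(x, sigma, denoiser):
    """Tweedie's formula rearranged: grad_x log p_sigma(x) = (D(x; sigma) - x) / sigma^2."""
    return (denoiser(x, sigma) - x) / sigma ** 2


def euler_pf_ode(x, sigmas, denoiser):
    """Euler steps of the PF-ODE dx/dsigma = (x - D(x; sigma)) / sigma over a decreasing grid."""
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, s_cur)) / s_cur
        x = x + (s_next - s_cur) * d
    return x


sigmas = np.geomspace(80.0, 0.002, 20000)
x_T = np.array([80.0, -40.0, 8.0])
x_0 = euler_pf_ode(x_T, sigmas, gaussian_denoiser)
# Closed form for the Gaussian prior: x(sigma) = x_T * sqrt(s2 + sigma^2) / sqrt(s2 + sigma_max^2)
exact = x_T * np.sqrt(4.0 + 0.002 ** 2) / np.sqrt(4.0 + 80.0 ** 2)
```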

The pretrained denoiser, the forward-operator definitions, and the evaluation problems are fixed; the algorithm only chooses how to combine these pieces.

Evaluation

The algorithm is tested on three scientific inverse problems:

  1. Inverse Scattering (optical tomography): Recover permittivity from scattered EM fields. Metrics: PSNR, SSIM.
  2. Black Hole Imaging (radio astronomy): Reconstruct black hole images from sparse interferometric observations (EHT data). Metrics: PSNR, blur-PSNR (f=15), closure-phase chi-squared.
  3. FFHQ256 Image Inpainting (computer vision): Recover an FFHQ-256 face image from a masked observation (box mask) with additive Gaussian noise (σ=0.05). The forward operator is a fixed pixel-wise mask. Metrics: PSNR, SSIM, LPIPS.
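The inpainting forward model in item 3 can be sketched as $y = M \odot x + \sigma n$ with a binary box mask $M$. Because $A$ is diagonal and idempotent, $A^\top(Ax - y) = M \odot (M \odot x - y)$. The gradient that `forward_op.gradient` must return is therefore cheap to compute and vanishes inside the box.

This is a hypothetical NumPy stand-in, not the benchmark's operator. The box position and size, the single-channel image, and adding noise to every pixel (including the masked ones) are all assumptions.

```python
import numpy as np


def box_mask(h, w, top, left, size):
    """Binary mask: 1 = observed pixel, 0 = inside the missing square box."""
    m = np.ones((h, w))
    m[top:top + size, left:left + size] = 0.0
    return m


class BoxInpainting:
    """Hypothetical stand-in exposing forward / loss / gradient like self.forward_op."""

    def __init__(self, mask, sigma_noise=0.05, seed=0):
        self.mask = mask
        self.sigma_noise = sigma_noise
        self.rng = np.random.default_rng(seed)

    def forward(self, x):
        return self.mask * x                          # A(x): zero out the box

    def observe(self, x):
        # y = A(x) + sigma * n; noise is added to every pixel here (assumption)
        return self.forward(x) + self.sigma_noise * self.rng.standard_normal(x.shape)

    def loss(self, x, y):
        return float(np.sum((self.forward(x) - y) ** 2))

    def gradient(self, x, y, return_loss=False):
        g = 2.0 * self.mask * (self.forward(x) - y)   # 2 A^T (A x - y), with A^T = A
        return (g, self.loss(x, y)) if return_loss else g


mask = box_mask(8, 8, top=2, left=2, size=4)
op = BoxInpainting(mask)
x_true = np.random.default_rng(1).uniform(size=(8, 8))
y = op.observe(x_true)
g, loss = op.gradient(np.zeros((8, 8)), y, return_loss=True)
```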

Higher PSNR/SSIM is better; lower LPIPS and chi-squared are better.
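For reference, PSNR is $10\log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$. It rises as the mean squared error falls, which is why higher is better.

The helper below is a minimal sketch. Its `data_range` default of 1.0 assumes images scaled to [0, 1]. The benchmark's own evaluator may normalize differently; for example, `blackhole.py` imports `piq.psnr`.

```python
import numpy as np


def psnr(x, ref, data_range=1.0):
    """Peak signal-to-noise ratio in dB between an estimate x and a reference."""
    x = np.asarray(x, dtype=float)
    ref = np.asarray(ref, dtype=float)
    mse = np.mean((x - ref) ** 2)
    if mse == 0.0:
        return float("inf")          # identical images
    return 10.0 * np.log10(data_range ** 2 / mse)


ref = np.zeros((4, 4))
print(psnr(ref + 0.1, ref))          # about 20 dB: MSE = 0.01
```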

Editable Region

The entire algo/custom.py file is editable. You may define any helper classes/functions within this file.

Code

algo/custom.py
```python
import torch
from tqdm import tqdm
from algo.base import Algo
from utils.scheduler import Scheduler
from utils.diffusion import DiffusionSampler
import numpy as np


class Custom(Algo):
    """Custom algorithm for solving inverse problems with diffusion priors.

    Available utilities:
        - self.net: pre-trained diffusion model.
        - self.net(x, sigma) returns denoised estimate (Tweedie's formula).
        - self.net.img_channels, self.net.img_resolution: image shape info.
    ...
```
algo/base.py
```python
from abc import ABC, abstractmethod


class Algo(ABC):
    def __init__(self, net,  # pre-trained diffusion model
                 forward_op  # forward operator of the inverse problem
                 ):
        self.net = net
        self.forward_op = forward_op

    @abstractmethod
    def inference(self, observation, num_samples=1, **kwargs):
        '''
        Args:
            - observation: observation for one single ground truth
        ...
```
utils/scheduler.py
```python
import numpy as np
import copy

'''
    Scheduler for diffusion sampling following EDM framework.
    schedule (\sigma(t)): linear, sqrt, vp
    timestep (discretization of t): log, poly-n, vp
    scaling: none, vp

    Example:
        VP: Scheduler(num_steps=1000, schedule='vp', timestep='vp', scaling='vp')
        VE: Scheduler(num_steps=1000, schedule='sqrt', timestep='log', scaling='none')
        EDM: Scheduler(num_steps=200, schedule='linear', timestep='poly-7', scaling='none')

    Example Usage: See DiffusionSampler in utils/diffusion.py for unconditional diffusion sampling.
...
```
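The `EDM` example pairs `schedule='linear'` ($\sigma(t) = t$) with `timestep='poly-7'`. Assuming `poly-n` refers to the polynomial spacing of Karras et al. (2022) with $\rho = n$, the noise levels are

$$\sigma_i = \Big(\sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}\big(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\big)\Big)^{\rho}.$$

The sketch uses the EDM paper's defaults, $\sigma_{\min} = 0.002$ and $\sigma_{\max} = 80$. The benchmark's `Scheduler` may use other endpoints, and EDM's appended final level $\sigma_N = 0$ is omitted here.

```python
import numpy as np


def poly_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels spaced uniformly in sigma^(1/rho), from sigma_max down to sigma_min."""
    i = np.arange(num_steps)
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    return (hi + i / (num_steps - 1) * (lo - hi)) ** rho


sigmas = poly_sigmas(200)
# Larger rho concentrates steps at small sigma, where fine detail is resolved.
```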
utils/diffusion.py
```python
from tqdm import tqdm
import torch
import numpy as np
from utils.scheduler import Scheduler

class DiffusionSampler:
    """
    Diffusion sampler for reverse SDE or PF-ODE
    """

    def __init__(self, scheduler, solver='euler'):
        """
        Initializes the diffusion sampler with the given scheduler and solver.

        Parameters:
        ...
```
inverse_problems/base.py
```python
from abc import ABC, abstractmethod
from torch.autograd import grad

import torch
from typing import Dict


class BaseOperator(ABC):
    def __init__(self, sigma_noise=0.0, unnorm_shift=0.0, unnorm_scale=1.0, device='cuda'):
        self.sigma_noise = sigma_noise
        self.unnorm_shift = unnorm_shift
        self.unnorm_scale = unnorm_scale
        self.device = device

    @abstractmethod
    ...
```
blackhole.py
```python
from inverse_problems.base import BaseOperator
import ehtim.statistics.dataframes as ehdf
import pandas as pd
import torch
import numpy as np
import ehtim as eh
from eval import Evaluator
import copy
from piq import psnr
import torch.nn.functional as F
from typing import Dict


class BlackHoleImaging(BaseOperator):
    """
    ...
```
image_restore.py
```python
from abc import ABC, abstractmethod
import numpy as np
import scipy

import torch
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from torchvision.transforms import InterpolationMode

from typing import List, Optional
from .base import BaseOperator

# helper functions for implementing the operators
class Blurkernel(torch.nn.Module):
    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        ...
```

Method Summary

Auto-summarized from each method's code by an LLM reviewer; not the model's original output.
Claude Opus 4.6 · Pseudocode · high

DPS + proximal $x_0$ refinement + EMA

Combine DPS Jacobian guidance with iterative clean-space proximal refinement of $\hat x_0$, both smoothed by per-step EMA momentum.

1. for diffusion step $i = 1\ldots N$ do
2. $\hat x_0 \leftarrow \mathrm{net}(x_t/s_t, \sigma_t)$
3. $g_{dps} \leftarrow (\partial \hat x_0/\partial x_t)^\top \nabla_{\hat x_0}\|A\hat x_0 - y\|^2$
4. $m_{dps} \leftarrow \beta\, m_{dps} + (1-\beta)\,g_{dps}$
5. $\tilde x_0 \leftarrow \hat x_0$; for $k = 1\ldots K$: $\tilde x_0 \leftarrow \tilde x_0 - \eta\, \nabla_{\tilde x_0}\|A\tilde x_0 - y\|^2$ (with grad-norm clip)
6. $m_{prox} \leftarrow \beta\, m_{prox} + (1-\beta)(\tilde x_0 - \hat x_0)$
7. $w_\sigma \leftarrow (1+\sigma_t)^{-p}$; $\hat x_0' \leftarrow \hat x_0 + \lambda_{prox}\, w_\sigma\, m_{prox}$
8. apply the SDE/ODE step using the score $(\hat x_0' - x_t/s_t)/(\sigma_t^2 s_t)$, then $x_{t-1} \mathrel{-}= \lambda_{dps}\, m_{dps}$
9. if an inpainting mask $M$ is given: $x_{t-1} \leftarrow (1-M)\,x_{t-1} + M\,(w\, y + (1-w)\,x_{t-1})$
Δ vs. baseline: adds two ingredients on top of DPS. The first is an inner proximal-gradient loop that refines $\hat x_0$ in the clean signal domain. The second is EMA momentum on both the DPS and proximal updates. It also injects mask-based pixel replacement for inpainting.
Hyperparameters: $\lambda_{dps}$ = {50, 1e-3, 1.0}; $\lambda_{prox}$ = {1.0, 0.5, 1.0}; $\beta$ = {0.7, 0.5, 0.6}; $K$ = {3, 2, 3}; $\eta$ = {1e-3, 5e-4, 1e-2}; $p$ = {1.0, 1.0, 0.5}; use_data_replacement = true (inpainting only).

Recovers DPS when $\lambda_{prox} = 0$, $\beta = 0$, and data replacement is disabled.
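Steps 5 to 7 of the pseudocode can be sketched with a dense matrix `A` standing in for the forward operator. The clip threshold `max_grad_norm` is a hypothetical value, since the summary only says that the gradient norm is clipped. The function names are illustrative and are not taken from the agent's code.

```python
import numpy as np


def prox_refine(x0_hat, A, y, eta, K, max_grad_norm=1.0):
    """Step 5: K clipped gradient steps on ||A x - y||^2, starting at x0_hat."""
    x = x0_hat.copy()
    for _ in range(K):
        g = 2.0 * A.T @ (A @ x - y)
        norm = np.linalg.norm(g)
        if norm > max_grad_norm:                  # grad-norm clip
            g = g * (max_grad_norm / norm)
        x = x - eta * g
    return x


def ema(m, update, beta):
    """Steps 4 and 6: m <- beta * m + (1 - beta) * update."""
    return beta * m + (1.0 - beta) * update


def corrected_x0(x0_hat, m_prox, sigma, lam_prox, p):
    """Step 7: x0' = x0_hat + lam_prox * (1 + sigma)^(-p) * m_prox."""
    return x0_hat + lam_prox * (1.0 + sigma) ** (-p) * m_prox


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 6))
y = rng.standard_normal(4)
x0_hat = rng.standard_normal(6)
m_prox = np.zeros(6)                                     # EMA buffer, reset per run
x_tilde = prox_refine(x0_hat, A, y, eta=1e-2, K=3)       # step 5
m_prox = ema(m_prox, x_tilde - x0_hat, beta=0.6)         # step 6
x0_new = corrected_x0(x0_hat, m_prox, sigma=2.0, lam_prox=1.0, p=0.5)  # step 7
```

The weight $(1+\sigma_t)^{-p}$ shrinks the proximal correction at high noise levels, where $\hat x_0$ is still blurry. The correction then reaches full strength as $\sigma_t \to 0$.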

Results