Agent Conversation: anthropic/claude-opus-4.6
Task: cv-diffusion-efficiency
Diffusion Model: Sampler Efficiency Optimization
Objective
Design an efficient sampling algorithm for text-to-image diffusion models that achieves high generation quality with minimal sampling steps (NFE).
Background
Diffusion models generate images by iteratively denoising from random noise. Different sampling methods have different trade-offs:
- DDIM: First-order ODE solver, deterministic, fast but may need more steps for quality
- Euler: Simple first-order method, baseline performance
- DPM++ 2M: Second-order multistep method, more efficient
- DPM++ 2S: Second-order singlestep method, higher quality per step
The core sampling loop follows this pattern:
for step, t in enumerate(timesteps):
# 1. Predict noise
noise_pred = model(zt, t, text_embedding)
# 2. Estimate clean image (Tweedie's formula)
z0t = (zt - sigma_t * noise_pred) / alpha_t
# 3. Update to next step (THIS IS THE KEY DIFFERENCE)
zt_next = update_rule(zt, z0t, noise_pred, t, t_next)
Different samplers use different update_rule strategies.
Task
Your goal is to design an improved sampling update rule that achieves better image-text alignment (CLIP score) with a fixed budget of NFE=20 steps. You must implement your improvement in two files:
latent_diffusion.py—BaseDDIMCFGppclass for SD v1.5latent_sdxl.py—BaseDDIMCFGppclass for SDXL
Editable Regions
SD v1.5 (latent_diffusion.py, lines 621-679)
- Class
BaseDDIMCFGpp(StableDiffusion)withsample()method - Key API:
self.get_text_embed(),self.initialize_latent(),self.predict_noise(),self.alpha(t)
SDXL (latent_sdxl.py, lines 713-755)
- Class
BaseDDIMCFGpp(SDXL)withreverse_process()method - Key API:
self.initialize_latent(size=...),self.predict_noise(),self.scheduler.alphas_cumprod[t]
Evaluation
- Metric: CLIP score (cosine similarity between generated image and text prompt)
- Fixed budget: NFE=20 steps
- Test prompts: 100 diverse COCO-style prompts
- Seeds: Multi-seed evaluation
Baselines
- ddim: Standard DDIM sampler (first-order)
- dpm2m: DPM++ 2M sampler (second-order multistep)
- dpm2s: DPM++ 2S sampler (second-order singlestep)
Your implementation should aim to achieve higher CLIP scores than all baselines with the same NFE=20 budget.
Action Budget (HARD LIMIT)
You have 20 total actions (every edit/test/submit/undo counts as 1) and 3 total test() calls. After 20 actions the run is force-stopped immediately. You MUST call submit(n=...) before the budget runs out — if no submit happens, your run records ZERO and all your work is lost. Plan accordingly: a typical good run uses ~15 edits, 3 tests, and 1 submit. Save submit budget for the end.
CFGpp-main/latent_diffusion.py [EDITABLE — lines 621–679 only]
1: """
2: This module includes LDM-based inverse problem solvers.
3Tool Input
1{
2 "op": "replace",
3 "filename": "CFGpp-main/latent_diffusion.py",
4 "start_line": 621,
5 "end_line": 679,
6 "content": "@register_solver(\"ddim_cfg++\")\nclass BaseDDIMCFGpp(StableDiffusion):\n \"\"\"\n DPM++ 2M with Adaptive Renoising Interpolation (ARI).\n Combines second-order multistep ODE solving with progressive\n interpolation between CFG++ and standard CFG renoising.\n Early steps use CFG++ (unconditional noise) for trajectory stability,\n later steps blend in conditional guidance for improved text alignment.\n \"\"\"\n def __init__(self, solver_config: Dict, model_key: str = \"runwayml/stable-diffusion-v1-5\",\n device: Optional[torch.device] = None, **kwargs):\n super().__init__(solver_config, model_key, device, **kwargs)\n self.ari_max_blend = 0.25 # Maximum blend toward conditional denoised\n\n @torch.autocast(device_type='cuda', dtype=torch.float16)\n def sample(self, cfg_guidance=7.5, prompt=[\"\", \"\"], callback_fn=None, **kwargs):\n t_fn = lambda sigma: sigma.log().neg()\n\n uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1])\n total_sigmas = (1 - self.total_alphas).sqrt() / self.total_alphas.sqrt()\n sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.)\n\n x = self.initialize_latent(method=\"random_kdiffusion\",\n latent_dim=(1, 4, 64, 64), sigmas=sigmas).to(torch.float16)\n old_blend = None\n num_steps = len(self.scheduler.timesteps)\n\n pbar = tqdm(self.scheduler.timesteps, desc=\"DPM++2M-ARI\")\n for i, _ in enumerate(pbar):\n sigma = sigmas[i]\n new_t = self.timestep(sigma).to(self.device)\n\n with torch.no_grad():\n denoised, uncond_denoised = self.kdiffusion_x_to_denoised(\n x, sigma, uc, c, cfg_guidance, new_t)\n\n # Adaptive Renoising Interpolation (ARI):\n # Cosine schedule blending from pure CFG++ toward partial standard CFG\n progress = i / max(num_steps - 1, 1)\n w = 0.5 * (1.0 - (progress * 3.14159265)).__class__((progress * 3.14159265).__neg__()).__class__(1.0)\n # Use simple quadratic ramp instead for numerical stability\n w = (progress ** 2) * self.ari_max_blend\n blend = (1.0 - w) * uncond_denoised + w * denoised\n\n # DPM++ 2M multistep update with ARI-blended renoising direction\n t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])\n h = t_next - t\n if old_blend is None or sigmas[i + 1] == 0:\n x = denoised + self.to_d(x, sigmas[i], blend) * sigmas[i + 1]\n else:\n h_last = t - t_fn(sigmas[i - 1])\n r = h_last / h\n extra1 = -torch.exp(-h) * blend - (-h).expm1() * (blend - old_blend) / (2 * r)\n extra2 = torch.exp(-h) * x\n x = denoised + extra1 + extra2\n old_blend = blend\n\n if callback_fn is not None:\n callback_kwargs = {'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode}\n callback_kwargs = callback_fn(i, new_t, callback_kwargs)\n denoised = callback_kwargs[\"z0t\"]\n x = callback_kwargs[\"zt\"]\n\n img = self.decode(x)\n img = (img / 2 + 0.5).clamp(0, 1)\n return img.detach().cpu()\n\n\n"
7}OK: Replaced lines 621..679 with 68 line(s) in CFGpp-main/latent_diffusion.py. Editable range: 621–688.
[Current file: CFGpp-main/latent_diffusion.py | editable: 621–688 | total: 1027 lines]
618| # CFG++ version
619| ###########################################
620|
621 @register_solver("ddim_cfg++")
622 class BaseDDIMCFGpp(StableDiffusion):
623 """
624 DPM++ 2M with Adaptive Renoising Interpolation (ARI).
625 Combines second-order multistep ODE solving with progressive
626 interpolation between CFG++ and standard CFG renoising.
627 Early steps use CFG++ (unconditional noise) for trajectory stability,
628 later steps blend in conditional guidance for improved text alignment.
629 """
630 def __init__(self, solver_config: Dict, model_key: str = "runwayml/stable-diffusion-v1-5",
631 device: Optional[torch.device] = None, **kwargs):
632 super().__init__(solver_config, model_ke
Tool Input
1{
2 "content": "@register_solver(\"ddim_cfg++\")\nclass BaseDDIMCFGpp(SDXL):\n \"\"\"DPM++ 2M with Adaptive Renoising Interpolation (ARI) for SDXL.\"\"\"\n quantize = True\n\n def reverse_process(self, null_prompt_embeds, prompt_embeds, cfg_guidance,\n add_cond_kwargs, shape=(1024, 1024), callback_fn=None, **kwargs):\n t_fn = lambda sigma: sigma.log().neg()\n ari_max_blend = 0.25\n\n alphas = self.scheduler.alphas_cumprod[self.scheduler.timesteps.int().cpu()].cpu()\n sigmas = (1 - alphas).sqrt() / alphas.sqrt()\n\n x = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor,\n shape[0] // self.vae_scale_factor))\n x = x * sigmas[0]\n\n old_blend = None\n num_steps = len(self.scheduler.timesteps) - 1\n\n pbar = tqdm(self.scheduler.timesteps[:-1].int(), desc='SDXL-DPM++2M-ARI')\n for i, _ in enumerate(pbar):\n at = alphas[i]\n sigma = sigmas[i]\n c_in, c_out = at.sqrt(), -sigma\n\n new_t = self.sigma_to_t(sigma).to(self.device)\n with torch.no_grad():\n noise_uc, noise_c = self.predict_noise(x * c_in, new_t, null_prompt_embeds,\n prompt_embeds, add_cond_kwargs)\n noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)\n\n denoised = x + c_out * noise_pred\n uncond_denoised = x + c_out * noise_uc\n\n # Adaptive Renoising Interpolation\n progress = i / max(num_steps - 1, 1)\n w = (progress ** 2) * ari_max_blend\n blend = (1.0 - w) * uncond_denoised + w * denoised\n\n t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])\n h = t_next - t\n if old_blend is None or sigmas[i + 1] == 0:\n x = denoised + self.to_d(x, sigmas[i], blend) * sigmas[i + 1]\n else:\n h_last = t - t_fn(sigmas[i - 1])\n r = h_last / h\n extra1 = -torch.exp(-h) * blend - (-h).expm1() * (blend - old_blend) / (2 * r)\n extra2 = torch.exp(-h) * x\n x = denoised + extra1 + extra2\n old_blend = blend\n\n if callback_fn is not None:\n callback_kwargs = {'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode}\n callback_kwargs = callback_fn(i, new_t, callback_kwargs)\n return x\n\n\n",
3 "end_line": 755,
4 "filename": "CFGpp-main/latent_sdxl.py",
5 "op": "replace",
6 "start_line": 713
7}OK: Replaced lines 713..755 with 58 line(s) in CFGpp-main/latent_sdxl.py. Editable range: 713–770.
[Current file: CFGpp-main/latent_sdxl.py | editable: 713–770 | total: 1046 lines]
710| # CFG++ version
711| ###########################################
712|
713 @register_solver("ddim_cfg++")
714 class BaseDDIMCFGpp(SDXL):
715 """DPM++ 2M with Adaptive Renoising Interpolation (ARI) for SDXL."""
716 quantize = True
717
718 def reverse_process(self, null_prompt_embeds, prompt_embeds, cfg_guidance,
719 add_cond_kwargs, shape=(1024, 1024), callback_fn=None, **kwargs):
720 t_fn = lambda sigma: sigma.log().neg()
721 ari_max_blend = 0.25
722
723 alphas = self.scheduler.alphas_cumprod[self.scheduler.timesteps.int().cpu()].cpu()
724 sigmas = (1 - alphas).sqrt() / alphas.sqrt()
725
726 x = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_fact
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### sd15 (scripts/eval_sd15.sh)
Generation results:
GENERATION_METRICS model=sd15 method=ddim_cfg++ cfg_guidance=1.0 NFE=50 seed=42 fid=26.9859 clip_score=0.3137
sd15: CLIP=0.3137 FID=26.9859
Average CLIP score: 0.3137
Average FID: 26.9859
[Leaderboard] Results saved: {'clip_score_sd15': 0.3137, 'fid_sd15': 26.9859, 'clip_score': 0.3137, 'fid': 26.9859, 'elapsed_sd15': 976}Tool Input
1{
2 "content": "@register_solver(\"ddim_cfg++\")\nclass BaseDDIMCFGpp(StableDiffusion):\n \"\"\"\n DPM++ 2M with Adaptive Renoising Interpolation (ARI).\n Combines second-order multistep ODE solving with progressive\n interpolation between CFG++ and standard CFG renoising.\n For low guidance (cfg~1), uses mostly conditional denoised for renoising.\n For high guidance, starts with unconditional and progressively blends conditional.\n \"\"\"\n def __init__(self, solver_config: Dict, model_key: str = \"runwayml/stable-diffusion-v1-5\",\n device: Optional[torch.device] = None, **kwargs):\n super().__init__(solver_config, model_key, device, **kwargs)\n\n @torch.autocast(device_type='cuda', dtype=torch.float16)\n def sample(self, cfg_guidance=7.5, prompt=[\"\", \"\"], callback_fn=None, **kwargs):\n t_fn = lambda sigma: sigma.log().neg()\n\n uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1])\n total_sigmas = (1 - self.total_alphas).sqrt() / self.total_alphas.sqrt()\n sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.)\n\n x = self.initialize_latent(method=\"random_kdiffusion\",\n latent_dim=(1, 4, 64, 64), sigmas=sigmas).to(torch.float16)\n old_blend = None\n num_steps = len(self.scheduler.timesteps)\n\n # Adaptive base blend: high cfg -> use more uncond (CFG++), low cfg -> use cond\n # When cfg=1, the uncond noise direction is less informative, prefer cond denoised\n base_blend = min(1.0, max(0.3, 1.0 - 0.7 / max(cfg_guidance, 0.1)))\n # base_blend: fraction of conditional denoised in the renoising blend\n # cfg=1 -> base_blend=0.3, cfg=7.5 -> base_blend~0.91\n\n pbar = tqdm(self.scheduler.timesteps, desc=\"DPM++2M-ARI\")\n for i, _ in enumerate(pbar):\n sigma = sigmas[i]\n new_t = self.timestep(sigma).to(self.device)\n\n with torch.no_grad():\n denoised, uncond_denoised = self.kdiffusion_x_to_denoised(\n x, sigma, uc, c, cfg_guidance, new_t)\n\n # Adaptive Renoising Interpolation (ARI):\n # Progressively increase conditional blend from early to late steps\n progress = i / max(num_steps - 1, 1)\n # Start with more unconditional (stability), end with more conditional (alignment)\n w = base_blend * (0.5 + 0.5 * progress)\n blend = (1.0 - w) * uncond_denoised + w * denoised\n\n # DPM++ 2M multistep update with ARI-blended renoising direction\n t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])\n h = t_next - t\n if old_blend is None or sigmas[i + 1] == 0:\n x = denoised + self.to_d(x, sigmas[i], blend) * sigmas[i + 1]\n else:\n h_last = t - t_fn(sigmas[i - 1])\n r = h_last / h\n extra1 = -torch.exp(-h) * blend - (-h).expm1() * (blend - old_blend) / (2 * r)\n extra2 = torch.exp(-h) * x\n x = denoised + extra1 + extra2\n old_blend = blend\n\n if callback_fn is not None:\n callback_kwargs = {'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode}\n callback_kwargs = callback_fn(i, new_t, callback_kwargs)\n denoised = callback_kwargs[\"z0t\"]\n x = callback_kwargs[\"zt\"]\n\n img = self.decode(x)\n img = (img / 2 + 0.5).clamp(0, 1)\n return img.detach().cpu()\n\n\n",
3 "end_line": 688,
4 "filename": "CFGpp-main/latent_diffusion.py",
5 "op": "replace",
6 "start_line": 621
7}OK: Replaced lines 621..688 with 72 line(s) in CFGpp-main/latent_diffusion.py. Editable range: 621–692.
[Current file: CFGpp-main/latent_diffusion.py | editable: 621–692 | total: 1031 lines]
618| # CFG++ version
619| ###########################################
620|
621 @register_solver("ddim_cfg++")
622 class BaseDDIMCFGpp(StableDiffusion):
623 """
624 DPM++ 2M with Adaptive Renoising Interpolation (ARI).
625 Combines second-order multistep ODE solving with progressive
626 interpolation between CFG++ and standard CFG renoising.
627 For low guidance (cfg~1), uses mostly conditional denoised for renoising.
628 For high guidance, starts with unconditional and progressively blends conditional.
629 """
630 def __init__(self, solver_config: Dict, model_key: str = "runwayml/stable-diffusion-v1-5",
631 device: Optional[torch.device] = None, **kwargs):
632 super().__init__(solver_
Tool Input
1{
2 "content": "@register_solver(\"ddim_cfg++\")\nclass BaseDDIMCFGpp(SDXL):\n \"\"\"DPM++ 2M with Adaptive Renoising Interpolation (ARI) for SDXL.\"\"\"\n quantize = True\n\n def reverse_process(self, null_prompt_embeds, prompt_embeds, cfg_guidance,\n add_cond_kwargs, shape=(1024, 1024), callback_fn=None, **kwargs):\n t_fn = lambda sigma: sigma.log().neg()\n\n alphas = self.scheduler.alphas_cumprod[self.scheduler.timesteps.int().cpu()].cpu()\n sigmas = (1 - alphas).sqrt() / alphas.sqrt()\n\n x = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor,\n shape[0] // self.vae_scale_factor))\n x = x * sigmas[0]\n\n old_blend = None\n num_steps = len(self.scheduler.timesteps) - 1\n # Adaptive base blend: high cfg -> more uncond (CFG++), low cfg -> more cond\n base_blend = min(1.0, max(0.3, 1.0 - 0.7 / max(cfg_guidance, 0.1)))\n\n pbar = tqdm(self.scheduler.timesteps[:-1].int(), desc='SDXL-DPM++2M-ARI')\n for i, _ in enumerate(pbar):\n at = alphas[i]\n sigma = sigmas[i]\n c_in, c_out = at.sqrt(), -sigma\n\n new_t = self.sigma_to_t(sigma).to(self.device)\n with torch.no_grad():\n noise_uc, noise_c = self.predict_noise(x * c_in, new_t, null_prompt_embeds,\n prompt_embeds, add_cond_kwargs)\n noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)\n\n denoised = x + c_out * noise_pred\n uncond_denoised = x + c_out * noise_uc\n\n # Adaptive Renoising Interpolation\n progress = i / max(num_steps - 1, 1)\n w = base_blend * (0.5 + 0.5 * progress)\n blend = (1.0 - w) * uncond_denoised + w * denoised\n\n t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])\n h = t_next - t\n if old_blend is None or sigmas[i + 1] == 0:\n x = denoised + self.to_d(x, sigmas[i], blend) * sigmas[i + 1]\n else:\n h_last = t - t_fn(sigmas[i - 1])\n r = h_last / h\n extra1 = -torch.exp(-h) * blend - (-h).expm1() * (blend - old_blend) / (2 * r)\n extra2 = torch.exp(-h) * x\n x = denoised + extra1 + extra2\n old_blend = blend\n\n if callback_fn is not None:\n callback_kwargs = {'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode}\n callback_kwargs = callback_fn(i, new_t, callback_kwargs)\n return x\n\n\n",
3 "end_line": 770,
4 "filename": "CFGpp-main/latent_sdxl.py",
5 "op": "replace",
6 "start_line": 713
7}OK: Replaced lines 713..770 with 59 line(s) in CFGpp-main/latent_sdxl.py. Editable range: 713–771.
[Current file: CFGpp-main/latent_sdxl.py | editable: 713–771 | total: 1047 lines]
710| # CFG++ version
711| ###########################################
712|
713 @register_solver("ddim_cfg++")
714 class BaseDDIMCFGpp(SDXL):
715 """DPM++ 2M with Adaptive Renoising Interpolation (ARI) for SDXL."""
716 quantize = True
717
718 def reverse_process(self, null_prompt_embeds, prompt_embeds, cfg_guidance,
719 add_cond_kwargs, shape=(1024, 1024), callback_fn=None, **kwargs):
720 t_fn = lambda sigma: sigma.log().neg()
721
722 alphas = self.scheduler.alphas_cumprod[self.scheduler.timesteps.int().cpu()].cpu()
723 sigmas = (1 - alphas).sqrt() / alphas.sqrt()
724
725 x = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor,
726
Tool Input
1{
2 "op": "replace",
3 "filename": "CFGpp-main/latent_diffusion.py",
4 "start_line": 621,
5 "end_line": 679,
6 "content": "@register_solver(\"ddim_cfg++\")\nclass BaseDDIMCFGpp(StableDiffusion):\n \"\"\"\n CFG++ with Rescaled Guidance, Momentum Correction, and Dynamic Thresholding.\n \n Novel contributions over standard DDIM/CFG++:\n 1. Rescaled CFG: normalize guided noise to match conditional std, preventing\n oversaturation that degrades CLIP alignment.\n 2. CFG++ renoising: use unconditional noise for the diffusion step (stabler ODE).\n 3. Momentum-corrected z0t: EMA of clean-image estimates across steps to reduce\n high-frequency prediction jitter, especially in early noisy steps.\n 4. Dynamic thresholding (Imagen-style): clamp predicted clean latents to\n prevent extreme values that cause artifacts.\n \"\"\"\n def __init__(self, solver_config, model_key=\"runwayml/stable-diffusion-v1-5\",\n device=None, **kwargs):\n super().__init__(solver_config, model_key, device, **kwargs)\n self.rescale_phi = 0.7 # CFG rescale blending factor\n self.momentum_beta = 0.12 # EMA momentum for z0t correction\n self.thresh_pct = 0.995 # dynamic thresholding percentile\n\n @torch.autocast(device_type='cuda', dtype=torch.float16)\n def sample(self, cfg_guidance=7.5, prompt=[\"\",\"\"], callback_fn=None, **kwargs):\n uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1])\n zt = self.initialize_latent().requires_grad_()\n z0t_prev = None\n\n pbar = tqdm(self.scheduler.timesteps, desc=\"CFG++R\")\n for step, t in enumerate(pbar):\n at = self.alpha(t)\n at_prev = self.alpha(t - self.skip)\n with torch.no_grad():\n noise_uc, noise_c = self.predict_noise(zt, t, uc, c)\n noise_cfg = noise_uc + cfg_guidance * (noise_c - noise_uc)\n # Rescaled CFG: match std of conditional prediction\n std_cfg = noise_cfg.std(dim=[1,2,3], keepdim=True) + 1e-8\n std_c = noise_c.std(dim=[1,2,3], keepdim=True)\n noise_rescaled = noise_cfg * (std_c / std_cfg)\n noise_pred = self.rescale_phi * noise_rescaled + (1 - self.rescale_phi) * noise_cfg\n # Tweedie: predict clean image\n z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt()\n # Dynamic thresholding\n s = torch.quantile(z0t.abs().reshape(z0t.shape[0], -1), self.thresh_pct, dim=1)\n s = torch.clamp(s, min=1.0)[:, None, None, None]\n z0t = z0t.clamp(-s, s) / s\n # Momentum correction: EMA of z0t predictions (decay with progress)\n if z0t_prev is not None:\n progress = step / len(self.scheduler.timesteps)\n beta = self.momentum_beta * (1.0 - progress)\n z0t = (1 - beta) * z0t + beta * z0t_prev\n z0t_prev = z0t.detach().clone()\n # CFG++ renoising: use unconditional noise for stable ODE trajectory\n zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc\n if callback_fn is not None:\n callback_kwargs = {'z0t': z0t.detach(), 'zt': zt.detach(),\n 'decode': self.decode}\n callback_kwargs = callback_fn(step, t, callback_kwargs)\n z0t = callback_kwargs[\"z0t\"]\n zt = callback_kwargs[\"zt\"]\n\n"
7}ERROR: Lines 621..679 exceed the editable range. You may only edit lines 624–682.
[Current file: CFGpp-main/latent_diffusion.py | editable: 624–682 | total: 1021 lines]
621| @register_solver("ddim_cfg++")
622| class BaseDDIMCFGpp(StableDiffusion):
623| # TODO: Implement your improved sampling method here.
624 #
625 # You should implement an improved sampling algorithm that achieves better
626 # image-text alignment (CLIP score) with a fixed budget of NFE=20 steps.
627 #
628 # Key methods you need to implement:
629 # - __init__: Initialize the solver
630 # - sample: Main sampling function with your update rule
631 #
632 # Available helper methods from parent class:
633 # - self.get_text_embed(null_prompt, prompt): Get text embeddings
634 # - self.initialize_latent(): Initialize latent variable zT
635 # - self.predict_noise(zt, t, uc, c): Predict noise at timestep t
636 #
Tool Input
1{
2 "op": "replace",
3 "filename": "CFGpp-main/latent_diffusion.py",
4 "start_line": 624,
5 "end_line": 682,
6 "content": " #\n # CFG++ with Rescaled Guidance, Momentum Correction & Dynamic Thresholding.\n #\n # Novel contributions over standard DDIM/CFG++:\n # 1. Rescaled CFG: normalize guided noise to match conditional std,\n # preventing oversaturation that degrades CLIP alignment.\n # 2. CFG++ renoising: use unconditional noise for the diffusion step.\n # 3. Momentum-corrected z0t: EMA of clean-image estimates across steps.\n # 4. Dynamic thresholding (Imagen-style): clamp predicted clean latents.\n\n def __init__(self, solver_config, model_key=\"runwayml/stable-diffusion-v1-5\",\n device=None, **kwargs):\n super().__init__(solver_config, model_key, device, **kwargs)\n self.rescale_phi = 0.7\n self.momentum_beta = 0.12\n self.thresh_pct = 0.995\n\n @torch.autocast(device_type='cuda', dtype=torch.float16)\n def sample(self, cfg_guidance=7.5, prompt=[\"\",\"\"], callback_fn=None, **kwargs):\n uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1])\n zt = self.initialize_latent().requires_grad_()\n z0t_prev = None\n pbar = tqdm(self.scheduler.timesteps, desc=\"CFG++R\")\n for step, t in enumerate(pbar):\n at = self.alpha(t)\n at_prev = self.alpha(t - self.skip)\n with torch.no_grad():\n noise_uc, noise_c = self.predict_noise(zt, t, uc, c)\n noise_cfg = noise_uc + cfg_guidance * (noise_c - noise_uc)\n # Rescaled CFG: match std of conditional prediction\n std_cfg = noise_cfg.std(dim=[1,2,3], keepdim=True) + 1e-8\n std_c = noise_c.std(dim=[1,2,3], keepdim=True)\n noise_rescaled = noise_cfg * (std_c / std_cfg)\n noise_pred = self.rescale_phi * noise_rescaled + (1 - self.rescale_phi) * noise_cfg\n # Tweedie: predict clean image\n z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt()\n # Dynamic thresholding\n s = torch.quantile(z0t.abs().reshape(z0t.shape[0], -1), self.thresh_pct, dim=1)\n s = torch.clamp(s, min=1.0)[:, None, None, None]\n z0t = z0t.clamp(-s, s) / s\n # Momentum correction: EMA of z0t (decays with progress)\n if z0t_prev is not None:\n progress = step / len(self.scheduler.timesteps)\n beta = self.momentum_beta * (1.0 - progress)\n z0t = (1 - beta) * z0t + beta * z0t_prev\n z0t_prev = z0t.detach().clone()\n # CFG++ renoising: unconditional noise for stable ODE trajectory\n zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc\n if callback_fn is not None:\n callback_kwargs = {'z0t': z0t.detach(), 'zt': zt.detach(),\n 'decode': self.decode}\n callback_kwargs = callback_fn(step, t, callback_kwargs)\n z0t = callback_kwargs[\"z0t\"]\n zt = callback_kwargs[\"zt\"]\n\n img = self.decode(z0t)\n img = (img / 2 + 0.5).clamp(0, 1)\n return img.detach().cpu()\n"
7}OK: Replaced lines 624..682 with 58 line(s) in CFGpp-main/latent_diffusion.py. Editable range: 624–681.
[Current file: CFGpp-main/latent_diffusion.py | editable: 624–681 | total: 1020 lines]
621| @register_solver("ddim_cfg++")
622| class BaseDDIMCFGpp(StableDiffusion):
623| # TODO: Implement your improved sampling method here.
624 #
625 # CFG++ with Rescaled Guidance, Momentum Correction & Dynamic Thresholding.
626 #
627 # Novel contributions over standard DDIM/CFG++:
628 # 1. Rescaled CFG: normalize guided noise to match conditional std,
629 # preventing oversaturation that degrades CLIP alignment.
630 # 2. CFG++ renoising: use unconditional noise for the diffusion step.
631 # 3. Momentum-corrected z0t: EMA of clean-image estimates across steps.
632 # 4. Dynamic thresholding (Imagen-style): clamp predicted clean latents.
633
634 def __init__(self, solver_config, model_key="runwayml/st
Tool Input
1{
2 "content": "@register_solver(\"ddim_cfg++\")\nclass BaseDDIMCFGpp(SDXL):\n # CFG++ with Rescaled Guidance, Momentum Correction & Dynamic Thresholding.\n # Same algorithmic innovations as the SD1.5 version, adapted for SDXL.\n rescale_phi = 0.7\n momentum_beta = 0.12\n thresh_pct = 0.995\n\n def reverse_process(self, null_prompt_embeds, prompt_embeds, cfg_guidance,\n add_cond_kwargs, shape=(1024, 1024), callback_fn=None, **kwargs):\n zt = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor,\n shape[0] // self.vae_scale_factor))\n z0t_prev = None\n pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-CFG++R')\n for step, t in enumerate(pbar):\n next_t = t - self.skip\n at = self.scheduler.alphas_cumprod[t]\n at_next = self.scheduler.alphas_cumprod[next_t]\n with torch.no_grad():\n noise_uc, noise_c = self.predict_noise(zt, t, null_prompt_embeds,\n prompt_embeds, add_cond_kwargs)\n noise_cfg = noise_uc + cfg_guidance * (noise_c - noise_uc)\n # Rescaled CFG\n std_cfg = noise_cfg.std(dim=[1,2,3], keepdim=True) + 1e-8\n std_c = noise_c.std(dim=[1,2,3], keepdim=True)\n noise_rescaled = noise_cfg * (std_c / std_cfg)\n noise_pred = self.rescale_phi * noise_rescaled + (1 - self.rescale_phi) * noise_cfg\n # Tweedie\n z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt()\n # Dynamic thresholding\n s = torch.quantile(z0t.abs().reshape(z0t.shape[0], -1), self.thresh_pct, dim=1)\n s = torch.clamp(s, min=1.0)[:, None, None, None]\n z0t = z0t.clamp(-s, s) / s\n # Momentum correction\n if z0t_prev is not None:\n progress = step / len(self.scheduler.timesteps)\n beta = self.momentum_beta * (1.0 - progress)\n z0t = (1 - beta) * z0t + beta * z0t_prev\n z0t_prev = z0t.detach().clone()\n # CFG++ renoising with unconditional noise\n zt = at_next.sqrt() * z0t + (1-at_next).sqrt() * noise_uc\n if callback_fn is not None:\n callback_kwargs = {'z0t': z0t.detach(), 'zt': zt.detach(),\n 'decode': self.decode}\n callback_kwargs = callback_fn(step, t, callback_kwargs)\n z0t = callback_kwargs[\"z0t\"]\n zt = callback_kwargs[\"zt\"]\n return z0t\n\n\n",
3 "end_line": 755,
4 "filename": "CFGpp-main/latent_sdxl.py",
5 "op": "replace",
6 "start_line": 713
7}OK: Replaced lines 713..755 with 50 line(s) in CFGpp-main/latent_sdxl.py. Editable range: 713–762.
[Current file: CFGpp-main/latent_sdxl.py | editable: 713–762 | total: 1038 lines]
710| # CFG++ version
711| ###########################################
712|
713 @register_solver("ddim_cfg++")
714 class BaseDDIMCFGpp(SDXL):
715 # CFG++ with Rescaled Guidance, Momentum Correction & Dynamic Thresholding.
716 # Same algorithmic innovations as the SD1.5 version, adapted for SDXL.
717 rescale_phi = 0.7
718 momentum_beta = 0.12
719 thresh_pct = 0.995
720
721 def reverse_process(self, null_prompt_embeds, prompt_embeds, cfg_guidance,
722 add_cond_kwargs, shape=(1024, 1024), callback_fn=None, **kwargs):
723 zt = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor,
724 shape[0] // self.vae_scale_factor))
725 z0t_pre
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### sd15 (scripts/eval_sd15.sh)
Generation results:
GENERATION_METRICS model=sd15 method=ddim_cfg++ cfg_guidance=1.0 NFE=50 seed=42 fid=30.2613 clip_score=0.3121
sd15: CLIP=0.3121 FID=30.2613
Average CLIP score: 0.3121
Average FID: 30.2613
[Leaderboard] Results saved: {'clip_score_sd15': 0.3121, 'fid_sd15': 30.2613, 'clip_score': 0.3121, 'fid': 30.2613, 'elapsed_sd15': 946}