mlsys-sparse-attention
Description
Sparse Attention for Diffusion Model Inference
Research Question
Design an efficient sparse attention algorithm that accelerates diffusion model inference while preserving output quality. The algorithm should exploit the structured sparsity patterns inherent in diffusion model attention maps to skip unimportant computations.
Background
Attention is the computational bottleneck in large diffusion models for image and video generation. Full attention has O(N^2) complexity in sequence length N. FlashAttention-2 optimizes memory access patterns but still computes all attention scores. Recent work shows that diffusion model attention maps exhibit high block-level sparsity — many query-key block pairs contribute negligibly to the output. By identifying and skipping these blocks, we can achieve significant speedups.
Key approaches in the literature:
- Block-sparse attention: Divide Q/K into blocks, predict which blocks are important, skip the rest
- Mean-similarity filtering: Use cheap statistics (e.g., mean Q-K similarity) to predict block importance
- Quantized attention: Use int8/fp8 quantization for Q/K/V to reduce memory bandwidth
- Top-k selection: Only compute attention for the top-k most relevant blocks per query block
- Triton kernels: Write custom GPU kernels for fused block-sparse computation
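The first two approaches above compose naturally: block-mean Q/K similarity is a cheap proxy for block importance, and top-k selection turns it into a sparsity pattern. A minimal single-head NumPy sketch (toy code, not the benchmark's kernel; all names here are illustrative) combining them:

```python
import numpy as np

def block_sparse_attention(q, k, v, block=4, keep_ratio=0.5):
    """Toy block-sparse attention for one head.

    q, k, v: (N, D) arrays. Queries/keys are split into blocks of `block`
    rows; for each query block, only the top-k key blocks ranked by
    mean-Q / mean-K similarity are computed. Illustrative only.
    """
    N, D = q.shape
    nb = N // block
    qm = q.reshape(nb, block, D).mean(axis=1)   # block-mean queries
    km = k.reshape(nb, block, D).mean(axis=1)   # block-mean keys
    score = qm @ km.T                           # cheap block-importance proxy
    topk = max(1, int(np.ceil(keep_ratio * nb)))
    keep = np.argsort(-score, axis=1)[:, :topk] # kept key blocks per query block

    out = np.zeros_like(q)
    scale = 1.0 / np.sqrt(D)
    for i in range(nb):
        qi = q[i * block:(i + 1) * block]       # (block, D) query block
        cols = np.concatenate(
            [np.arange(j * block, (j + 1) * block) for j in keep[i]])
        s = (qi @ k[cols].T) * scale            # scores only on kept blocks
        s = s - s.max(axis=1, keepdims=True)    # numerical stability
        p = np.exp(s)
        p /= p.sum(axis=1, keepdims=True)       # softmax over kept keys only
        out[i * block:(i + 1) * block] = p @ v[cols]
    return out
```

With `keep_ratio=1.0` this reduces exactly to dense attention; a real implementation would fuse the selection and computation in a Triton kernel rather than loop in Python.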
What You Can Modify
The CustomAttention class (lines 28-92) in custom_sparse_bench.py:
- The sparsity pattern selection algorithm (which blocks to compute)
- The attention computation kernel (how to compute attention on selected blocks)
- Quantization strategy (int8/fp8 for Q, K, V)
- Block size and granularity of sparsity decisions
- Model-specific configuration via the configure() method
- Any caching or state management via reset()
Interface contract: forward(q, k, v, is_causal=False) -> output
- q, k, v: float16/bfloat16 tensors of shape (B, H, N, D), where B=batch, H=heads, N=seq_len, D=head_dim
- output: same shape and dtype as q
- Must be numerically stable and produce reasonable attention output
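A skeleton that satisfies this contract could look like the following. This is a dense fallback sketch, not the benchmark's actual CustomAttention; everything beyond the forward/configure/reset names stated above (the block_size attribute, the _cfg/_cache fields) is an assumption, and the sparsity logic would replace the dense score computation:

```python
import torch

class CustomAttention:
    """Contract-conforming skeleton (dense fallback; illustrative only)."""

    def __init__(self):
        self.block_size = 64  # granularity of sparsity decisions (assumption)
        self._cfg = {}
        self._cache = None

    def configure(self, **model_kwargs):
        # Model-specific tuning, e.g. per-layer sparsity thresholds.
        self._cfg = model_kwargs

    def reset(self):
        # Clear any per-sequence state between prompts.
        self._cache = None

    def forward(self, q, k, v, is_causal=False):
        # q, k, v: (B, H, N, D) float16/bfloat16; output matches q's shape/dtype.
        scale = q.shape[-1] ** -0.5
        s = (q.float() @ k.float().transpose(-2, -1)) * scale  # fp32 accumulate
        if is_causal:
            n = s.shape[-1]
            mask = torch.triu(
                torch.ones(n, n, dtype=torch.bool, device=s.device), 1)
            s = s.masked_fill(mask, float("-inf"))
        p = torch.softmax(s, dim=-1)            # stable: computed in fp32
        return (p @ v.float()).to(q.dtype)
```

Computing scores and softmax in fp32 before casting back to the input dtype is one simple way to meet the numerical-stability requirement with half-precision inputs.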
Evaluation
- Models/Benchmarks: CogVideoX-2b (video generation, bfloat16, 49 frames, 50 steps, head_dim=64), Kernel micro-benchmark (synthetic tensors, multiple configs with head_dim=64/128), Wan2.1-T2V-1.3B (video generation, bfloat16, 33 frames, 30 steps, head_dim=128)
- Quality measurement: For each model, inference is run on 4 diverse prompts. LPIPS is computed per-prompt against the full-attention reference output, then averaged. Both mean and max LPIPS are reported.
- Metrics (per model):
  - throughput: frames/s (video) or images/s (image) -- higher is better
  - latency_s: wall-clock inference time in seconds -- lower is better
  - peak_mem_gb: peak GPU memory usage in GB -- lower is better
  - lpips_mean: mean LPIPS across all prompts vs full-attention reference -- lower is better (0.0 = identical)
  - lpips_max: worst-case LPIPS across prompts -- lower is better
  - quality_pass: binary (1 = pass, 0 = fail) -- whether mean LPIPS <= 0.05
- Quality constraint: LPIPS degradation must stay <= 5% (mean LPIPS <= 0.05). This is a hard requirement. A submission that achieves high throughput but fails the quality check is considered invalid. The quality check ensures that sparse attention does not produce visually degraded outputs.
- Hardware: H100 GPU with CUDA 12, Triton available
- Goal: Maximize throughput while satisfying the quality constraint (mean LPIPS <= 0.05)
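The quality gate aggregates the four per-prompt LPIPS scores as described above; note that only the mean is a hard constraint, while the max is reported for diagnostics. A small sketch of that aggregation (function name and signature are illustrative, not from the benchmark harness):

```python
def quality_gate(lpips_per_prompt, threshold=0.05):
    """Aggregate per-prompt LPIPS into the reported metrics.

    Returns (lpips_mean, lpips_max, quality_pass), where quality_pass
    follows the benchmark's hard constraint: mean LPIPS <= threshold.
    """
    lpips_mean = sum(lpips_per_prompt) / len(lpips_per_prompt)
    lpips_max = max(lpips_per_prompt)
    return lpips_mean, lpips_max, int(lpips_mean <= threshold)
```

For example, one bad prompt (LPIPS 0.10) among three perfect ones still passes the mean gate, although the reported lpips_max surfaces the outlier.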
Code
1"""2Sparse Attention Benchmark for Diffusion Model Inference.34Evaluates a custom attention backend on CogVideoX-2b and PixArt-alpha.5The custom module is set as `attn.inner_attention` and called by the model's6standard attention processor, which handles QKV projection, norms, and rotary7embeddings. The agent only needs to implement the core attention computation.8"""910# ================================================================11# FIXED — imports and utilities (do not modify)12# ================================================================13import argparse14import gc15import os
Additional context files (read-only):
- SpargeAttn/spas_sage_attn/core.py
- SpargeAttn/spas_sage_attn/autotune.py
- SpargeAttn/Triton_SpargeAttn/triton_kernel_example.py
Results
| Model | Type | throughput fps cogvideo ↑ | latency s cogvideo ↓ | peak mem gb cogvideo ↓ | lpips mean cogvideo ↓ | lpips max cogvideo ↓ | quality pass cogvideo ↑ | mean speedup kernel ↑ | peak mem gb kernel ↓ | throughput fps wan ↑ | latency s wan ↓ | peak mem gb wan ↓ | lpips mean wan ↓ | lpips max wan ↓ | quality pass wan ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| flash_attn_2 | baseline | 1.161 | 42.213 | 16.574 | 0.000 | 0.000 | 1.000 | 1.031 | 0.189 | 1.729 | 19.082 | 14.325 | 0.000 | 0.000 | 1.000 |
| sage_attention | baseline | 1.145 | 42.774 | 16.577 | 0.089 | 0.309 | 0.000 | 0.893 | 0.281 | 1.646 | 20.050 | 14.325 | 0.000 | 0.000 | 1.000 |
| sparge_attn | baseline | 1.832 | 26.747 | 16.574 | 0.507 | 0.584 | 0.000 | 1.194 | 0.313 | 2.174 | 15.179 | 14.325 | 0.000 | 0.000 | 1.000 |