mlsys-sparse-attention

Tags: ML Systems · SpargeAttn · rigorous codebase

Description

Sparse Attention for Diffusion Model Inference

Research Question

Design an efficient sparse attention algorithm that accelerates diffusion model inference while preserving output quality. The algorithm should exploit the structured sparsity patterns inherent in diffusion model attention maps to skip unimportant computations.

Background

Attention is the computational bottleneck in large diffusion models for image and video generation. Full attention has O(N^2) complexity in sequence length N. FlashAttention-2 optimizes memory access patterns but still computes all attention scores. Recent work shows that diffusion model attention maps exhibit high block-level sparsity — many query-key block pairs contribute negligibly to the output. By identifying and skipping these blocks, we can achieve significant speedups.
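To make "skipping blocks" concrete, the sketch below applies a precomputed boolean block mask inside an otherwise dense attention computation. It is a reference implementation for clarity only (a real kernel would never materialize the full score matrix), and `block_masked_attention`, `block_mask`, and `BLOCK` are illustrative names, not part of the benchmark:

```python
import torch

def block_masked_attention(q, k, v, block_mask, BLOCK=64):
    """q, k, v: (B, H, N, D); block_mask: (B, H, N//BLOCK, N//BLOCK) bool,
    True where a query-block/key-block pair should be computed."""
    B, H, N, D = q.shape
    scale = D ** -0.5
    # Dense scores, materialized here only for clarity: (B, H, N, N).
    scores = (q @ k.transpose(-2, -1)) * scale
    # Expand the block-level mask to token resolution and mask skipped blocks.
    token_mask = (block_mask
                  .repeat_interleave(BLOCK, dim=-2)
                  .repeat_interleave(BLOCK, dim=-1))
    scores = scores.masked_fill(~token_mask, float("-inf"))
    # Softmax in fp32 for numerical stability; each query row must keep at
    # least one unmasked block (e.g., the diagonal) or it produces NaNs.
    attn = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return attn @ v
```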

Key approaches in the literature:

  • Block-sparse attention: Divide Q/K into blocks, predict which blocks are important, skip the rest
  • Mean-similarity filtering: Use cheap statistics (e.g., mean Q-K similarity) to predict block importance (combined with top-k selection in the sketch after this list)
  • Quantized attention: Use int8/fp8 quantization for Q/K/V to reduce memory bandwidth
  • Top-k selection: Only compute attention for the top-k most relevant blocks per query block
  • Triton kernels: Write custom GPU kernels for fused block-sparse computation
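As a concrete instance of the mean-similarity and top-k ideas above, this sketch mean-pools Q and K within blocks, scores block pairs by pooled similarity, and keeps the top-k key blocks per query block. It is a plausible selection heuristic under these assumptions, not SpargeAttn's actual algorithm; all names are illustrative:

```python
import torch

def select_topk_blocks(q, k, BLOCK=64, topk=8):
    """q, k: (B, H, N, D) with N divisible by BLOCK.
    Returns a (B, H, N//BLOCK, N//BLOCK) boolean block mask."""
    B, H, N, D = q.shape
    nb = N // BLOCK
    # Cheap statistic: mean-pool the tokens inside each block -> (B, H, nb, D).
    q_pool = q.view(B, H, nb, BLOCK, D).mean(dim=3)
    k_pool = k.view(B, H, nb, BLOCK, D).mean(dim=3)
    # Pooled Q-K similarity as a proxy for block importance: (B, H, nb, nb).
    sim = q_pool @ k_pool.transpose(-2, -1)
    # Keep the top-k key blocks per query block.
    idx = sim.topk(min(topk, nb), dim=-1).indices
    block_mask = torch.zeros(B, H, nb, nb, dtype=torch.bool, device=q.device)
    block_mask.scatter_(-1, idx, True)
    return block_mask
```

The mask produced here can feed the `block_masked_attention` sketch above for correctness checks; in practice both steps would be fused into a Triton kernel so that skipped blocks are never computed at all.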

What You Can Modify

The CustomAttention class (lines 28-92) in custom_sparse_bench.py:

  • The sparsity pattern selection algorithm (which blocks to compute)
  • The attention computation kernel (how to compute attention on selected blocks)
  • Quantization strategy (int8/fp8 for Q, K, V)
  • Block size and granularity of sparsity decisions
  • Model-specific configuration via the configure() method
  • Any caching or state management via reset()

Interface contract: forward(q, k, v, is_causal=False) -> output (a minimal skeleton follows the list below)

  • q, k, v: float16/bfloat16 tensors of shape (B, H, N, D) where B=batch, H=heads, N=seq_len, D=head_dim
  • output: same shape and dtype as q
  • Must be numerically stable and produce reasonable attention output
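A minimal skeleton that satisfies this contract, assuming only the names given above (the dense fallback simply calls PyTorch's scaled_dot_product_attention; the configure() keyword arguments are hypothetical knobs, and the real class in custom_sparse_bench.py may differ):

```python
import torch
import torch.nn.functional as F

class CustomAttention(torch.nn.Module):
    def configure(self, **kwargs):
        # Model-specific tuning; block size and top-k are assumed knobs.
        self.block = kwargs.get("block", 64)
        self.topk = kwargs.get("topk", 8)

    def reset(self):
        # Clear any cached sparsity patterns between inference runs.
        pass

    def forward(self, q, k, v, is_causal=False):
        # q, k, v: (B, H, N, D) float16/bfloat16; output matches q's
        # shape and dtype. Dense fallback for correctness; replace with
        # the block-sparse computation.
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

Replacing the forward body with the block-selection and masked-attention sketches above gives a functional (if unfused) sparse baseline to iterate on.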

Evaluation

  • Models/Benchmarks:
    • CogVideoX-2b: video generation, bfloat16, 49 frames, 50 steps, head_dim=64
    • Kernel micro-benchmark: synthetic tensors, multiple configs with head_dim=64/128
    • Wan2.1-T2V-1.3B: video generation, bfloat16, 33 frames, 30 steps, head_dim=128
  • Quality measurement: For each model, inference is run on 4 diverse prompts. LPIPS is computed per-prompt against the full-attention reference output, then averaged. Both mean and max LPIPS are reported (a sketch of this computation follows the list).
  • Metrics (per model):
    • throughput: frames/s (video) or images/s (image) -- higher is better
    • latency_s: wall-clock inference time in seconds -- lower is better
    • peak_mem_gb: peak GPU memory usage in GB -- lower is better
    • lpips_mean: mean LPIPS across all prompts vs full-attention reference -- lower is better (0.0 = identical)
    • lpips_max: worst-case LPIPS across prompts -- lower is better
    • quality_pass: binary (1 = pass, 0 = fail) -- whether mean LPIPS <= 0.05
  • Quality constraint: LPIPS degradation must stay <= 5% (mean LPIPS <= 0.05). This is a hard requirement. A submission that achieves high throughput but fails the quality check is considered invalid. The quality check ensures that sparse attention does not produce visually degraded outputs.
  • Hardware: H100 GPU with CUDA 12, Triton available
  • Goal: Maximize throughput while satisfying the quality constraint (mean LPIPS <= 0.05)
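A sketch of how the per-prompt quality gate could be computed, assuming the standard lpips package and per-prompt outputs already extracted as RGB image tensors in [-1, 1] (the benchmark's actual harness is not shown here):

```python
import lpips

loss_fn = lpips.LPIPS(net="alex")  # perceptual distance; 0.0 = identical

def quality_gate(outputs, references, threshold=0.05):
    """outputs/references: lists of (1, 3, H, W) tensors in [-1, 1],
    one per prompt."""
    scores = [loss_fn(o, r).item() for o, r in zip(outputs, references)]
    mean_lpips = sum(scores) / len(scores)
    return {
        "lpips_mean": mean_lpips,
        "lpips_max": max(scores),
        "quality_pass": int(mean_lpips <= threshold),
    }
```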

Code

custom_sparse_bench.py
1"""
2Sparse Attention Benchmark for Diffusion Model Inference.
3
4Evaluates a custom attention backend on CogVideoX-2b and PixArt-alpha.
5The custom module is set as `attn.inner_attention` and called by the model's
6standard attention processor, which handles QKV projection, norms, and rotary
7embeddings. The agent only needs to implement the core attention computation.
8"""
9
10# ================================================================
11# FIXED — imports and utilities (do not modify)
12# ================================================================
13import argparse
14import gc
15import os

Additional context files (read-only):

  • SpargeAttn/spas_sage_attn/core.py
  • SpargeAttn/spas_sage_attn/autotune.py
  • SpargeAttn/Triton_SpargeAttn/triton_kernel_example.py

Results

CogVideoX-2b:

| Model | Type | throughput (frames/s) | latency (s) | peak mem (GB) | LPIPS mean | LPIPS max | quality pass |
|---|---|---|---|---|---|---|---|
| flash_attn_2 | baseline | 1.161 | 42.213 | 16.574 | 0.000 | 0.000 | 1.000 |
| sage_attention | baseline | 1.145 | 42.774 | 16.577 | 0.089 | 0.309 | 0.000 |
| sparge_attn | baseline | 1.832 | 26.747 | 16.574 | 0.507 | 0.584 | 0.000 |

Kernel micro-benchmark:

| Model | Type | mean speedup | peak mem (GB) |
|---|---|---|---|
| flash_attn_2 | baseline | 1.031 | 0.189 |
| sage_attention | baseline | 0.893 | 0.281 |
| sparge_attn | baseline | 1.194 | 0.313 |

Wan2.1-T2V-1.3B:

| Model | Type | throughput (frames/s) | latency (s) | peak mem (GB) | LPIPS mean | LPIPS max | quality pass |
|---|---|---|---|---|---|---|---|
| flash_attn_2 | baseline | 1.729 | 19.082 | 14.325 | 0.000 | 0.000 | 1.000 |
| sage_attention | baseline | 1.646 | 20.050 | 14.325 | 0.000 | 0.000 | 1.000 |
| sparge_attn | baseline | 2.174 | 15.179 | 14.325 | 0.000 | 0.000 | 1.000 |