Fused Causal Attention Kernel

Studies how fused self-attention kernels improve throughput and latency while preserving numerical agreement.

ML Systems & Efficient ML · flash-attention
mlsys-fused-attention

Description

Fused Attention Kernel Design for H100 GPUs

Research Question

Design an efficient fused self-attention forward pass kernel using OpenAI Triton that maximizes throughput (TFLOPs/s) on H100 GPUs while maintaining numerical correctness.

Background

Self-attention is the computational bottleneck of Transformer models. The standard implementation materializes the full N x N attention score matrix, requiring O(N^2) memory and O(N^2 d) FLOPs. FlashAttention (Dao et al., NeurIPS 2022; arXiv:2205.14135) introduced a tiled, IO-aware algorithm using online softmax that avoids materializing the full matrix, reducing HBM accesses from O(Nd + N^2) to O(N^2 d^2 / M) where M is SRAM size.
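To make the online-softmax idea concrete, here is a minimal pure-Python sketch for a single query row. The function names and the block size are illustrative assumptions, not part of the benchmark code; a real kernel applies the same running-max and running-sum update to whole tiles held in SRAM.

```python
import math

def online_softmax_weighted_sum(scores, values, block=4):
    """Compute sum_j softmax(scores)_j * values_j, visiting the keys block by block."""
    m = float("-inf")  # running maximum of all scores seen so far
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running unnormalized weighted sum of values
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescales old statistics to the new maximum
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return acc / l

def reference_weighted_sum(scores, values):
    """Same quantity computed with the full score row available at once."""
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)
```

Only `m`, `l`, and `acc` persist across blocks, which is why the full N x N score matrix never has to be written to HBM.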

Subsequent versions improved throughput through better parallelism and hardware-specific optimizations:

  • FlashAttention-2 (Dao, 2023; arXiv:2307.08691): reduced non-matmul FLOPs, parallelized over sequence length, better warp-level work partitioning.
  • FlashAttention-3 (Shah et al., NeurIPS 2024; arXiv:2407.08608): exploits H100 Hopper features — warp specialization (producer/consumer warps overlapping TMA loads with GMMA compute), GEMM-softmax interleaving, and FP8 support. The paper reports ~740 TFLOPs/s on H100 in FP16 (~75% utilization).

Task

Modify the custom_attention_forward function and the associated Triton kernel _custom_attn_fwd to implement an efficient fused attention forward pass. You may:

  • Redesign the tiling strategy (block sizes, tile shapes)
  • Optimize the online softmax computation (e.g., use exp2 instead of exp, delay rescaling)
  • Improve memory access patterns (coalescing, prefetching)
  • Split the causal/non-causal iteration into separate passes to avoid per-block masking overhead
  • Use Triton autotuning (@triton.autotune) to search configurations
  • Define multiple helper kernels if needed
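The exp2 suggestion in the list above rests on the identity exp(x) = 2^(x * log2(e)), so the softmax scale and the log2(e) factor can be folded into a single constant. The following pure-Python sketch checks that identity numerically; the function names are hypothetical, and inside a Triton kernel the power of two would typically be computed with tl.math.exp2 rather than the `**` operator used here.

```python
import math

LOG2E = math.log2(math.e)  # about 1.4427

def softmax_exp(scores, sm_scale):
    """Reference softmax of sm_scale * scores using the natural exponential."""
    z = [s * sm_scale for s in scores]
    mx = max(z)
    e = [math.exp(x - mx) for x in z]
    total = sum(e)
    return [x / total for x in e]

def softmax_exp2(scores, sm_scale):
    """Same softmax, with sm_scale * log2(e) folded into one multiplier and base-2 exponentials."""
    c = sm_scale * LOG2E  # single fused constant applied to the raw scores
    z = [s * c for s in scores]
    mx = max(z)
    e = [2.0 ** (x - mx) for x in z]
    total = sum(e)
    return [x / total for x in e]
```

Because the constant is applied once (for example to Q before the loop), the inner loop saves one multiply per score element.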

Interface

def custom_attention_forward(q, k, v, causal=True, sm_scale=None):
    """
    Args:
        q, k, v: (batch, nheads, seqlen, headdim), contiguous, FP16
        causal: if True, apply causal mask (key_pos <= query_pos)
        sm_scale: softmax scale factor (default: 1/sqrt(headdim))
    Returns:
        output: (batch, nheads, seqlen, headdim), same dtype as input
    """

Correctness constraint: max absolute difference from reference (PyTorch SDPA) must be < 1e-2.

Evaluation

Benchmarked on multiple causal configurations aligned with the FA3 paper (total tokens = 16384):

Config            Batch  SeqLen  Heads  HeadDim
hdim64_seq4k      4      4096    32     64
hdim128_seq8k     2      8192    16     128
hdim256_seq16k    1      16384   8      256

All configurations use FP16, causal masking, on H100 80GB SXM5.

Metrics (per configuration):

  • tflops: achieved TFLOPs/s (higher is better) — primary metric
  • latency_ms: kernel latency in milliseconds (lower is better)
  • correct: binary (1 if max_diff < 1e-2, else 0) — hard constraint

FLOP formula (FA2/FA3 convention): 4 * batch * seqlen^2 * nheads * headdim / 2 (causal).
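As a worked example of the FLOP formula, the sketch below evaluates it for the three benchmark configurations and converts a latency into TFLOPs/s. The helper names and the `CONFIGS` dictionary are illustrative assumptions, and the 1.0 ms latency in the usage note is a made-up number, not a measured result.

```python
def attention_flops(batch, seqlen, nheads, headdim, causal=True):
    """Forward-pass FLOPs under the FA2/FA3 convention (halved for causal)."""
    flops = 4 * batch * seqlen**2 * nheads * headdim
    return flops // 2 if causal else flops

def tflops_per_s(flops, latency_ms):
    """Convert a FLOP count and a latency in milliseconds to TFLOPs/s."""
    return flops / (latency_ms * 1e-3) / 1e12

# (batch, seqlen, nheads, headdim) copied from the configuration table
CONFIGS = {
    "hdim64_seq4k": (4, 4096, 32, 64),
    "hdim128_seq8k": (2, 8192, 16, 128),
    "hdim256_seq16k": (1, 16384, 8, 256),
}

if __name__ == "__main__":
    for name, cfg in CONFIGS.items():
        print(name, attention_flops(*cfg))
```

The three causal FLOP counts are 2^38, 2^39, and 2^40, so the work doubles with each configuration even though the token count is fixed. At a hypothetical latency of 1.0 ms, the first configuration would correspond to roughly 275 TFLOPs/s.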

Hints

  • The default template provides a basic flash attention kernel. Key optimization opportunities:
    1. Two-pass causal: split the K/V loop into non-causal blocks (no mask check) and causal boundary blocks, reducing branch overhead
    2. Block size tuning: different (BLOCK_M, BLOCK_N) for different headdims — larger blocks amortize loop overhead but increase register pressure
    3. Triton autotuning: use @triton.autotune with configs=[...] to search block sizes; each candidate is compiled and benchmarked the first time the kernel is launched with a new key value
    4. Reduced rescaling: in the online softmax, the rescaling acc *= alpha can be deferred or batched to reduce non-matmul operations
    5. Memory coalescing: ensure K/V loads are coalesced along the headdim dimension
  • The Triton tutorial fused attention demonstrates the two-pass approach
  • FA3 achieves its speedup through Hopper-specific CUDA features (warp specialization, TMA, GMMA) that are not directly accessible from Triton — closing the gap requires algorithmic cleverness in the Triton DSL
  • Available imports: torch, triton, triton.language as tl, math, torch.nn.functional as F

Code

custom_triton_bench.py
"""Fused Attention Kernel Benchmark - H100 GPU.

Benchmark harness for evaluating custom Triton attention kernels.
Aligned with Flash Attention 3 (Shah et al., NeurIPS 2024) benchmarks.
FLOP formula (FA2/FA3): 4 * batch * seqlen^2 * nheads * headdim (halved for causal).
Total tokens fixed at 16384 (batch = 16384 / seqlen).
"""

import argparse
import math
import os
import time

import torch
import torch.nn.functional as F

Method Summary

Auto-summarized from each method's code by an LLM reviewer; not the model's original output.
Claude Opus 4.6 · Pseudocode · high

Two-pass Causal + Wide Autotune Sweep

Flash-v3-style two-pass causal kernel that folds sm_scale * log2(e) into Q and uses an enlarged 19-config autotune sweep over (BLOCK_M, BLOCK_N, num_stages, num_warps).

1. wrapper: $q \leftarrow q \cdot (s \cdot \log_2 e)$, then launch autotuned _custom_attn_fwd
2. kernel: each (m_block, head) program loads Q tile once
3. pass 1, non-causal blocks (no mask, no boundary check):
   for $n$ in $[0, \lfloor m_0/B_N \rfloor B_N)$:
     $S = QK^\top$; online-softmax in $\exp_2$; acc += PV
4. pass 2, causal-boundary blocks:
   for $n$ in $[\text{boundary}, (m_0+1)B_M)$:
     $S \leftarrow \mathrm{where}(i \geq j, S, -\infty)$; same online-softmax update
5. acc /= l_i; store O
Autotune key = (seqlen, BLOCK_DMODEL, IS_CAUSAL); sweeps BLOCK_M ∈ {16, 32, 64, 128, 256}, BLOCK_N ∈ {32, 64, 128}, num_stages ∈ {2..4}, num_warps ∈ {4, 8}
Δ vs. baseline: Beyond flash_v3, it removes the per-load boundary mask in pass 1 (full tiles) and roughly doubles the autotune search space, including very large BLOCK_M=256 tiles for hdim=64 and small BLOCK_M ∈ {16, 32} tiles for hdim=256 to manage register pressure.
scale_fusion = q *= sm_scale * log2(e)
two_pass_causal = non-causal full tiles + causal boundary tile
autotune_configs = 19
autotune_key = (seqlen, BLOCK_DMODEL, IS_CAUSAL)
BLOCK_DMODEL = headdim
non_causal_load_mask = removed
Recovers flash_v3 baseline at the matching autotune choice (default config set is a strict superset)
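The two-pass split in the summary above can be sketched in pure Python as an index computation. This is one reading of the pseudocode, not the agent's actual kernel: `split_kv_blocks` is a hypothetical helper, `m_block` is taken to be the query-tile index, and the pass-1 bound is computed from the tile's first query row.

```python
def split_kv_blocks(m_block, block_m, block_n):
    """Return (full, boundary) lists of K/V block start offsets for one query tile.

    full:     blocks in which every (query, key) pair satisfies key <= query,
              so no mask is needed (pass 1).
    boundary: blocks that may cross the causal diagonal and need masking (pass 2).
    Blocks starting at or beyond the tile's last row are skipped entirely.
    """
    row_lo = m_block * block_m                 # first query row in this tile
    row_hi = row_lo + block_m                  # one past the last query row
    full_end = (row_lo // block_n) * block_n   # largest block-aligned offset <= row_lo
    full = list(range(0, full_end, block_n))
    boundary = list(range(full_end, row_hi, block_n))
    return full, boundary

if __name__ == "__main__":
    # Example: query tile 3 with BLOCK_M=128 and BLOCK_N=64
    print(split_kv_blocks(3, 128, 64))
```

For that example the tile covers query rows 384 to 511, so blocks starting at 0 through 320 need no mask and only the blocks starting at 384 and 448 go through the masked path. The fraction of masked blocks shrinks as the query tile moves down the sequence, which is where the two-pass structure saves work.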

Results