FlashAttention-4: Algorithm and Kernel Co-Design for Blackwell GPUs

On March 5, 2026, Tri Dao and collaborators from Princeton University, Together AI, Meta, NVIDIA, and Colfax Research released FlashAttention-4 — a ground-up redesign of the attention kernel optimized for NVIDIA’s Blackwell GPUs. The new kernel reaches up to 1,605 TFLOPs/s on the B200 (71% hardware utilization), running 1.3x faster than cuDNN 9.13 and 2.7x faster than Triton-based implementations.
Why a New Version?
Each GPU generation scales tensor core throughput faster than shared memory bandwidth or special function units (SFUs). On Blackwell, tensor cores can process tiles of 128x256x16 — roughly 4x larger than Hopper’s — but shared memory and exponential-function units haven’t kept pace. FlashAttention-4 treats this asymmetry as a first-class design constraint, co-designing both the algorithm and kernel pipeline to hide non-matmul bottlenecks behind tensor core work.
How It Works
FA4 introduces a warp-specialized, multi-stage asynchronous pipeline where different warps handle distinct roles concurrently within a single kernel:
- Load warps stream Q, K, V tiles from global memory into shared memory via the Tensor Memory Accelerator (TMA)
- MMA warps perform matrix multiplications to compute unnormalized attention scores on 5th-gen async tensor cores
- Softmax warps (8 dedicated warps) normalize scores and track running statistics
- Correction warps rescale outputs only when the running maximum changes enough to affect numerical stability — reducing rescaling operations by roughly 10x compared to FA3
- Epilogue warps store final results back to global memory
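The conditional rescaling done by the correction warps can be illustrated with a small pure-Python sketch for a single query row. This is not FA4's kernel code: the function name `lazy_softmax_weighted_sum`, the block size, and the threshold `tau` are assumptions made for illustration. The idea shown is that the accumulator is rescaled only when a block's maximum exceeds the stored reference maximum by more than `tau`. Otherwise the stale reference is kept, which is still exact because the numerator and denominator share the same reference.

```python
import math

def lazy_softmax_weighted_sum(scores, values, block=4, tau=8.0):
    """Blockwise softmax(scores)-weighted sum of values for one query row.

    Rescales the running sums only when the block max exceeds the stored
    reference max by more than tau (tau is an illustrative assumption).
    Returns (output_vector, number_of_rescales).
    """
    m_ref = -math.inf                  # reference max used inside exp()
    l = 0.0                            # running softmax denominator
    acc = [0.0] * len(values[0])       # running unnormalized output
    rescales = 0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)
        if m_blk > m_ref + tau:
            # Correction step: bring old sums onto the new reference max.
            factor = math.exp(m_ref - m_blk)   # exp(-inf) == 0.0 on first block
            l *= factor
            acc = [a * factor for a in acc]
            m_ref = m_blk
            rescales += 1
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_ref)    # bounded by exp(tau), so no overflow
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
    return [a / l for a in acc], rescales
```

Setting `tau=0.0` recovers the usual online-softmax behavior of rescaling whenever the maximum increases, so the two settings can be compared directly on the same inputs.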
A key optimization is the hybrid exponential computation: instead of relying solely on the GPU’s Special Function Units (which bottleneck at high throughput), FA4 approximates 2^x using a cubic polynomial on FMA units for smaller head dimensions, matching BF16 precision while freeing SFUs for other work.
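A rough sketch of the polynomial idea, in plain Python: split x into an integer part and a fractional part, evaluate a cubic for the fractional part using Horner's rule (three multiply-adds), then scale by the power of two. The coefficients below come from a Taylor expansion of 2^f around f = 0.5, chosen here only because they can be derived in a few lines. They are an assumption for illustration and are not the coefficients FA4 uses, which this article does not give.

```python
import math

LN2 = math.log(2.0)
SQRT2 = math.sqrt(2.0)

# Cubic Taylor coefficients of 2**f around f = 0.5 (illustrative stand-in,
# not FA4's actual polynomial): c_k = 2**0.5 * ln(2)**k / k!
C0 = SQRT2
C1 = SQRT2 * LN2
C2 = SQRT2 * LN2 ** 2 / 2.0
C3 = SQRT2 * LN2 ** 3 / 6.0

def exp2_poly(x):
    """Approximate 2**x with range reduction plus a cubic polynomial."""
    n = math.floor(x)          # integer part: handled exactly by ldexp
    t = (x - n) - 0.5          # fractional part, recentred to [-0.5, 0.5)
    # Horner evaluation: each step maps onto one fused multiply-add.
    p = ((C3 * t + C2) * t + C1) * t + C0
    return math.ldexp(p, n)    # p * 2**n
```

With these coefficients the relative error stays below 2^-8 across the fractional range, which is the scale of BF16's rounding error; a minimax-fitted cubic would do noticeably better with the same number of operations.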
The backward pass also sees major improvements. Intermediate results are stored in Blackwell’s new Tensor Memory (TMEM) — 256 KB per SM wired directly to tensor cores — reducing shared-memory traffic. A 2-CTA MMA mode distributes accumulation across paired CTAs, cutting atomic reductions by 50%. FA4 even supports a deterministic execution mode for reproducible training at 85–90% of peak throughput.
FlexAttention Integration
PyTorch’s FlexAttention API now supports FA4 as a backend on Hopper and Blackwell GPUs. Researchers can write custom attention variants (ALiBi, sliding window, document masking, soft-capping) as simple Python score_mod functions and have them JIT-compiled into FA4 kernels — delivering 1.2x to 3.2x speedups over the previous Triton backend without writing any CUDA.
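To give a sense of what a score_mod looks like, here is a minimal pure-Python emulation. In PyTorch the real entry point is `torch.nn.attention.flex_attention.flex_attention`, and a score_mod receives `(score, b, h, q_idx, kv_idx)` as tensors. The sketch below keeps that argument order but works on plain floats and integers so it runs without a GPU. The helper `reference_attention_row` and the symmetric distance penalty are hypothetical and only for illustration, and nothing here shows how the FA4 backend is selected.

```python
import math

def softcap(cap):
    """Soft-capping: squash scores smoothly into (-cap, cap)."""
    def score_mod(score, b, h, q_idx, kv_idx):
        return cap * math.tanh(score / cap)
    return score_mod

def alibi_like(slopes):
    """ALiBi-style bias: penalize scores by per-head slope times distance."""
    def score_mod(score, b, h, q_idx, kv_idx):
        return score - slopes[h] * abs(q_idx - kv_idx)
    return score_mod

def reference_attention_row(scores, b, h, q_idx, score_mod):
    """Hypothetical reference: apply score_mod per key, then softmax."""
    modded = [score_mod(s, b, h, q_idx, kv) for kv, s in enumerate(scores)]
    m = max(modded)
    weights = [math.exp(s - m) for s in modded]
    z = sum(weights)
    return [w / z for w in weights]

# Usage: equal raw scores, so the bias alone decides the attention weights.
probs = reference_attention_row([0.0] * 5, b=0, h=1, q_idx=2,
                                score_mod=alibi_like([0.5, 1.0]))
```

In the PyTorch version the same functions would use tensor operations such as `torch.tanh` in place of the `math` calls, and the JIT compiler would fuse them into the attention kernel.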
What This Means
FlashAttention has become critical infrastructure for large-scale transformer training, and FA4 continues that trajectory by squeezing maximum performance from the latest hardware. The 71% utilization figure on B200 is notable — attention kernels have historically been memory-bound and hard to optimize beyond 50–60% on previous generations.
The entire implementation is written in CuTe-DSL, CUTLASS’s Python-based kernel DSL, which compiles 20–30x faster than equivalent C++ templates. The source code is available on GitHub. However, FA4 currently requires Blackwell or Hopper hardware; users on older GPUs will continue using FlashAttention-2, since FA3 is itself a Hopper-targeted kernel.
For the AI research community, the practical impact is faster and cheaper training runs on Blackwell clusters. The FlexAttention integration also means custom attention patterns no longer carry a steep performance penalty.


