Billy Tse
HomeRoadmapBlogContact
Playground
Buy me a bug

© 2026 Billy Tse

OnlyFansLinkedInGitHubEmail
Back to Blog
February 4, 2026•25 min read

FlashAttention 1, 2, 3, 4 完全解析:由 IO-Aware 到 Blackwell Petaflop

深入解析 FlashAttention 四代演進:從 FlashAttention 1 嘅 IO-aware tiling,到 FlashAttention 2 嘅並行優化,再到 FlashAttention 3 嘅異步計算同 FP8 支援,最後到 FlashAttention 4 喺 Blackwell B200 突破 1.6 PetaFLOPs,了解點樣將 Transformer Attention 推向極致

Attention MechanismsInference OptimizationHardware AccelerationCSCI 5640 NLP

當 Transformer 統治咗成個 AI 世界,Attention 機制嘅計算效率就成為咗最大嘅樽頸。FlashAttention 用 IO-aware 嘅思路,將原本需要 O(N2)O(N^2)O(N2) memory 嘅 Attention 壓縮到 O(N)O(N)O(N),而且仲越做越快。由第一代嘅 tiling 技術,到第二代嘅並行優化,再到第三代嘅異步計算同 FP8 支援,最後到第四代喺 Blackwell 架構突破 Petaflop 級別,呢四代 FlashAttention 點樣將 Transformer Attention 推向極致?

TL;DR

核心重點:

  • 🎯 FlashAttention 1(2022):IO-aware tiling(分塊),將 Attention 由 HBM 搬入 SRAM,減少 memory access bottleneck
  • 🔄 FlashAttention 2(2023):改變 parallelism 策略(partition over K/V),減少 non-matmul FLOPs,達到 2× 加速
  • ⚡ FlashAttention 3(2024):異步計算 + FP8 低精度(Hopper H100),GPU 利用率去到 75%(vs FA2 嘅 35%)
  • 🚀 FlashAttention 4(2025):Blackwell B200 突破,達到 1,605 TFLOPS(71% 理論峰值),係第一個突破 Petaflop 嘅 attention kernel
  • ✅ 實際效果:訓練速度提升 3-15×,長序列支援 128K+ tokens,幾乎零 memory overhead
  • ⚠️ 重要:FlashAttention 係 Dense Attention(計晒所有 N×N pairs),唔係 Sparse Attention

Table of Contents

  • 問題背景:Attention 嘅 Memory Wall
  • FlashAttention 1:IO-Aware Tiling 革命
  • FlashAttention 2:Work Partitioning 同 Parallelism 優化
  • FlashAttention 3:異步計算同 FP8 低精度
  • FlashAttention 4:Blackwell 突破 Petaflop
  • 效能比較同實際應用
  • 何時用 FlashAttention?
  • 總結
  • 相關資源

問題背景:Attention 嘅 Memory Wall

Standard Attention 嘅瓶頸

標準 Attention 機制嘅計算流程:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V

where:

  • Q,K,VQ, K, VQ,K,V:Query, Key, Value matrices(shape: NtimesdN times dNtimesd)
  • NNN:sequence length
  • dkd_kdk​:head dimension

問題出喺邊?

喺 GPU 上面,Attention 嘅計算涉及多次 HBM(High Bandwidth Memory)同 SRAM(on-chip cache)之間嘅數據搬運:

  1. 由 HBM load Q,KQ, KQ,K → 計算 S=QKTS = QK^TS=QKT → 寫返 SSS 入 HBM
  2. 由 HBM load SSS → 計算 P=softmax(S)P = \text{softmax}(S)P=softmax(S) → 寫返 PPP 入 HBM
  3. 由 HBM load P,VP, VP,V → 計算 O=PVO = PVO=PV → 寫 output 入 HBM

⚠️ Memory Bottleneck
HBM bandwidth(~1.5 TB/s on A100)遠低於 SRAM bandwidth(~19 TB/s),而 Attention 需要 O(N2)O(N^2)O(N2) 嘅 memory 去 store SSS 同 PPP matrices。當 sequence length NNN 去到幾千甚至幾萬,HBM access 就成為咗最大嘅瓶頸。

⚠️ 重要澄清:呢啲全部都係 GPU 上面嘅 memory!

  • HBM (DRAM) = GPU 嘅 VRAM:就係你買 GPU 時講嘅「40 GB VRAM」,用 HBM2e/HBM3 技術

  • SRAM = GPU chip 入面嘅 on-chip cache:唔係 CPU 嘅 memory!每個 GPU 嘅 SM 都有自己嘅 SRAM

簡單講:DRAM 同 SRAM 都喺 GPU 上面,分別係:

  • DRAM (HBM):喺 GPU chip 外面,大但慢(就係 VRAM)

  • SRAM:喺 GPU chip 入面,細但快(L1 cache / Shared Memory)

Memory Type實際硬件SizeBandwidth物理位置你可能聽過嘅名
HBM (DRAM)HBM2e/HBM3 stacks40-80 GB~1.5-3 TB/sGPU 封裝外圍,用 silicon interposer 連接✅ VRAM(Video RAM)
SRAM (L1/Shared)On-chip SRAM~192 KB/SM~19 TB/sGPU chip 入面,每個 SM 有自己嘅 SRAMShared Memory / L1 Cache
L2 CacheOn-chip SRAM40-50 MB~4-7 TB/sGPU chip 入面,所有 SM 共享L2 Cache

💡 生活例子:圖書館 vs 書枱
想像你做緊功課:

  • HBM (DRAM) = 圖書館:好大(幾萬本書),但每次拎書要行去書架,好慢

  • SRAM = 你書枱:好細(只放到 10 本書),但拎書即刻就拎到,好快

  • Standard Attention = 每計一步都要行去圖書館拎書,再擺返落書架

  • FlashAttention = 一次過拎晒需要嘅書返嚟書枱,做晒先擺返

Key insight:行去圖書館(HBM access)嘅時間 >> 睇書(計算)嘅時間,所以減少往返次數先係關鍵!

❓ 等等,SRAM 咁細,點放得落成個 Attention?
你可能會問:SRAM 只有 ~192 KB,但一個 Attention matrix(N×N)可以好大(例如 16K tokens → 256 MB),點可能放得落?

答案:唔係將所有嘢一次過放入 SRAM,而係用 Tiling(分塊)分批處理!

  • ❌ 唔係:將成個 N×N matrix 一次過放入 SRAM(根本放唔落)

  • ✅ 而係:將大 matrix 切成細塊(e.g., 64×64),逐個 block 放入 SRAM 計算

  • 每個 block 細到放得入 SRAM → 計完即棄 → load 下一個 block

  • trade-off:雖然要做多次 HBM access(因為分批),但比起 standard attention 少好多!

FlashAttention 嘅策略:

唔係消除所有 HBM access(唔可能),而係將 HBM access 由 O(N2)O(N^2)O(N2) 降到 O(N2d2M−1)O(N^2 d^2 M^{-1})O(N2d2M−1),其中 MMM 係 SRAM size。透過 tiling,將「每個中間步驟都要寫入 HBM」變成「只係 load inputs 同 write final outputs」。

FlashAttention 1:IO-Aware Tiling 革命

核心思路:Tiling(分塊)同 Online Softmax

FlashAttention 1 嘅突破在於兩個技術:

  1. Tiling:將大 matrix 切成細塊

咩係 Tiling?

Tiling 就係將大 matrix 分成細嘅 blocks(tiles),逐個 block 放入 SRAM 處理。

點解需要 Tiling?

  • 一個 16K×16K 嘅 Attention matrix = 256 MB(用 FP16)
  • SRAM 只有 ~192 KB per SM
  • 結論:根本唔可能將成個 matrix 一次過放入 SRAM!

解決方法:將大 matrix 切成細塊(e.g., 128×128 = 32 KB),咁就放得入 SRAM。

Loading diagram...

🎯 生活例子:搬屋
你要搬 100 箱嘢,但你部車一次只載到 10 箱:

  • Without Tiling:租架大貨車,一次過載晒 100 箱

  • 但你根本租唔到咁大架車(= SRAM 太細,放唔落)

  • 唯有用好大嘅倉庫(= HBM)暫存,但出入倉好慢

  • With Tiling:用你自己部車,分 10 次,每次載 10 箱

  • 每次只用少少空間(= 放得入 SRAM)

  • 雖然要行 10 次,但因為你唔需要用慢嘅倉庫,overall 反而快咗

關鍵:FlashAttention 唔係「將所有嘢放入 SRAM」(根本做唔到),而係「逐小塊處理,每小塊都放得入 SRAM」。

Tiling 示意圖:

假設 Q,K,VQ, K, VQ,K,V 都係 N×dN \times dN×d matrix,我哋分成 Br×BcB_r \times B_cBr​×Bc​ 嘅 blocks:

Loading diagram...

Tiling 流程:

  1. 將 Q 分成 4 個 blocks(Q₁, Q₂, Q₃, Q₄),每個 block 例如 128×64 = 16 KB(放得入 SRAM)
  2. 將 K 都分成 4 個 blocks(K₁, K₂, K₃, K₄)
  3. 逐對計算:
    • Load Q₁ 同 K₁ 入 SRAM → 計 S₁₁ = Q₁K₁ᵀ → 用完即棄
    • Load Q₁ 同 K₂ 入 SRAM → 計 S₁₂ = Q₁K₂ᵀ → accumulate
    • 繼續...
  4. 關鍵:每次只處理一個細 block,唔需要 store 成個 N×N matrix

實際數字例子:

  • 假設 N = 16K tokens, d = 64 (head dim), block size = 128
  • 完整 attention matrix: 16K × 16K × 2 bytes = 512 MB(放唔入 SRAM)
  • 每個 block: 128 × 128 × 2 bytes = 32 KB(✅ 放得入 SRAM!)
  • 需要嘅 blocks: (16K/128)² = 16,384 個 blocks
  • FlashAttention 會逐個 block 計,所以 peak memory 只係 32 KB,唔係 512 MB
  1. Online Softmax:唔需要睇晒全部先計

Online Softmax 可以 incremental 咁更新 softmax,唔需要 store 成個 SSS matrix。

🎯 核心概念:IO-Aware Algorithm
唔係單純優化 FLOPs,而係優化 memory I/O。通過 tiling 同 recomputation,將 Attention 嘅 memory footprint 由 O(N2)O(N^2)O(N2) 降到 O(N)O(N)O(N),同時 HBM access 減少到 O(N2d2M−1)O(N^2 d^2 M^{-1})O(N2d2M−1)(MMM 係 SRAM size)。

Online Softmax 原理

傳統 softmax 需要兩次 pass:

  1. 計 m=max⁡(xi)m = \max(x_i)m=max(xi​)
  2. 計 softmax(xi)=exi−m∑jexj−m\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}softmax(xi​)=∑j​exj​−mexi​−m​

Online 版本可以 incremental 更新:

當有新嘅 block x(2)x^{(2)}x(2) 加入已計算嘅 x(1)x^{(1)}x(1):

m(2)=max⁡(m(1),mnew(2))m^{(2)} = \max(m^{(1)}, m^{(2)}_{\text{new}})m(2)=max(m(1),mnew(2)​) ℓ(2)=em(1)−m(2)ℓ(1)+emnew(2)−m(2)ℓnew(2)\ell^{(2)} = e^{m^{(1)} - m^{(2)}} \ell^{(1)} + e^{m^{(2)}_{\text{new}} - m^{(2)}} \ell^{(2)}_{\text{new}}ℓ(2)=em(1)−m(2)ℓ(1)+emnew(2)​−m(2)ℓnew(2)​

咁就可以一邊 load 新 blocks,一邊更新 softmax,唔使 store 成個 matrix。

Backward Pass:Recomputation

FlashAttention 1 喺 backward pass 唔 save SSS 同 PPP,而係 recompute 佢哋。

Trade-off:

  • ✅ Memory footprint 大減(唔使 store O(N2)O(N^2)O(N2) 嘅 intermediate results)
  • ⚠️ 額外嘅 FLOPs(recomputation cost)
  • ✅ 但因為減少咗 HBM I/O,overall 反而更快

效能提升

vs Standard Attention(PyTorch implementation):

  • GPT-2 訓練:15% faster end-to-end
  • 長序列(e.g., Path-X 16K tokens):3× faster
  • Memory usage:可以 train 去到 64K tokens(standard attention OOM)

FlashAttention 2:Work Partitioning 同 Parallelism 優化

FlashAttention 2(2023 年)專注解決 FA1 嘅兩大瓶頸:低 GPU occupancy 同 過多 non-matmul operations。

核心問題:FA1 嘅 Parallelism 策略唔夠好

FA1 嘅做法(partition over Q):

  • 將 QQQ 分成 TrT_rTr​ 個 blocks,每個 thread block 負責一個 QQQ block
  • 每個 thread block 需要 iterate over 所有 K,VK, VK,V blocks
  • 問題:唔同 thread blocks 會重複 load 相同嘅 K,VK, VK,V blocks

具體例子:

假設有 4 個 thread blocks(TB0, TB1, TB2, TB3),每個負責一個 QQQ block:

  • TB0 處理 Q₁:需要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄
  • TB1 處理 Q₂:又要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄(重複!)
  • TB2 處理 Q₃:又要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄(重複!)
  • TB3 處理 Q₄:又要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄(重複!)

結果:K, V 被 load 咗 4 次,浪費 memory bandwidth!

FA2 嘅解決方法:Partition over K/V instead

FA2 嘅新策略(partition over K, V, batch, head):

  • 將 parallelism 改為 outer loop over K/V blocks
  • 每個 thread block 負責一個 K,VK, VK,V block,然後 iterate over 所有 QQQ blocks
  • 好處:每個 K,VK, VK,V block 只 load 一次,所有 thread blocks 共享

同一個例子用 FA2:

假設有 4 個 thread blocks,每個負責一個 K,VK, VK,V block:

  • TB0 負責 K₁V₁:load 一次 K₁V₁,iterate Q₁, Q₂, Q₃, Q₄
  • TB1 負責 K₂V₂:load 一次 K₂V₂,iterate Q₁, Q₂, Q₃, Q₄
  • TB2 負責 K₃V₃:load 一次 K₃V₃,iterate Q₁, Q₂, Q₃, Q₄
  • TB3 負責 K₄V₄:load 一次 K₄V₄,iterate Q₁, Q₂, Q₃, Q₄

結果:K, V 每個只 load 一次,減少 4× memory access!

🎯 Work Partitioning 策略比較

  • FA1:for each Q_block: for each K_block: compute attention

  • 每個 Q block 係一個 thread block

  • K, V 被重複 load TrT_rTr​ 次(TrT_rTr​ = Q blocks 數量)

  • FA2:for each K_block: for each Q_block: compute attention

  • 每個 K/V block 係一個 thread block

  • K, V 只 load 一次

  • 額外 benefit:可以同時 parallelize over batch 同 head dimension

技術細節:點樣處理 Output Accumulation?

因為 FA2 改咗 loop order,每個 thread block 只計算 output 嘅一部分,需要 atomic add 嚟 accumulate results。

FA2 用 locking mechanism:

  • 每個 output block 有一個 lock
  • Thread block 計完自己嘅 partial output 後,acquire lock → update → release lock
  • 用 GPU 嘅 atomic operations(e.g., atomicAdd)實現

Trade-off:

  • ✅ 減少 K, V memory access(主要 bottleneck)
  • ⚠️ 增加少少 synchronization overhead(但遠比 memory savings 值得)

2. 減少 Non-Matmul FLOPs

問題背景:

喺 GPU 上,matmul operations(e.g., QKTQK^TQKT, PVPVPV)可以充分利用 Tensor Cores,達到 peak throughput。但其他 operations(e.g., softmax, dropout, masking, rescaling)係 element-wise operations,慢好多。

FA1 嘅 non-matmul FLOPs:

  • Forward: rescaling, softmax, dropout, masking
  • Backward: softmax gradient 特別貴(需要 recompute softmax)

FA2 優化:

  1. 簡化 backward pass:用更高效嘅 online softmax gradient 公式,減少 recomputation
  2. 更好嘅 warp scheduling:將 non-matmul ops 同 matmul ops overlap
  3. Reduce rescaling ops:將某啲 rescaling 提前做,減少重複計算

💡 Non-matmul FLOPs 比例:

  • FA1 forward: ~25% FLOPs 係 non-matmul(softmax, masking 等)

  • FA1 backward: ~40% FLOPs 係 non-matmul(softmax gradient 好貴!)

  • FA2: 將 non-matmul 降到 ~10-15%(forward + backward)

  • 結果:雖然 total FLOPs 差唔多,但因為 Tensor Cores 利用率高咗,throughput 提升 2×

效能數字

ModelSequence LengthFA1 SpeedupFA2 Speedup
BERT-large5121.7×2.2×
GPT-320483.0×5.1×
Long document16K5.2×9.3×

(相對 PyTorch standard attention)

FlashAttention 3:異步計算同 FP8 低精度

FlashAttention 3(2024 年)專為 Hopper architecture(H100 GPU)設計。

1. Asynchronous Computation

問題:FA2 喺 Tensor Cores 做 matmul 嘅時候,其他 units(e.g., softmax, SRAM load/store)會 idle。

解決方法:利用 Hopper 嘅 asynchronous TMA(Tensor Memory Accelerator):

  • Matmul 同 data movement overlap
  • 當 warp 0 做 matmul 嘅時候,warp 1-3 可以 prefetch 下一個 block

🚀 效能突破
通過 async overlap,FA3 將 GPU utilization 由 FA2 嘅 35% 提升到 75%,喺 H100 上達到 1.5-2× speedup over FA2。

2. FP8 Low-Precision Support

Hopper 支援 FP8(8-bit floating point),可以:

  • 2× matmul throughput(相對 FP16/BF16)
  • 減少 memory bandwidth(因為數據細咗一半)

FA3 嘅 FP8 strategy:

  • Q,K,VQ, K, VQ,K,V 用 FP8 做 matmul
  • Intermediate results(e.g., softmax)仍然用 FP16,保持 numerical stability
  • Incoherent processing:避免 FP8 量化誤差 accumulate

3. Block Quantization

為咗 minimize FP8 quantization error,FA3 用 block-wise quantization:

每個 tile(e.g., 64×64)有自己嘅 scaling factor:

QFP8=quantize(QFP16sQ)Q_{\text{FP8}} = \text{quantize}\left(\frac{Q_{\text{FP16}}}{s_Q}\right)QFP8​=quantize(sQ​QFP16​​)

效能數字(H100)

  • Forward pass:1.5× faster than FA2
  • Backward pass:1.8× faster
  • Overall training:~1.7× speedup
  • Sequence length 16K:throughput 去到 2.2 PFLOPs/s

FlashAttention 4:Blackwell 突破 Petaflop

FlashAttention 4(2025 年 1 月發布)係第一個突破 Petaflop barrier 嘅 attention kernel,專為 NVIDIA Blackwell architecture(B100/B200)設計。

Blackwell Architecture 嘅關鍵特性

NVIDIA Blackwell(2024 年發布)嘅新功能:

  1. Second-generation Transformer Engine:
    • 硬件層面優化 attention operations
    • 支援 FP4(4-bit floating point)+ micro-tensor scaling
    • 2× attention-layer acceleration(vs Hopper)
    • 1.5× 更多 AI compute FLOPs
  2. 更大 SRAM (shared memory):
    • 每個 SM 嘅 shared memory 增加到 256 KB(vs Hopper 嘅 228 KB)
    • 可以處理更大嘅 tiling blocks → 減少 HBM access
  3. HBM3e 更快 memory bandwidth:
    • HBM3e bandwidth: 8 TB/s(vs H100 嘅 3.35 TB/s)
    • 雖然 FlashAttention 主要 focus 減少 HBM access,但更快 HBM 仍然有幫助(尤其係 load Q, K, V inputs 同 write final outputs)
  4. 5th-gen NVLink(1.8 TB/s):
    • NVLink 係 GPU-to-GPU 嘅高速互連技術
    • ⚠️ 重要:FlashAttention 本身唔需要 NVLink(single GPU 就行)
    • 但喺 multi-GPU 訓練(例如訓練超大型 LLM),NVLink 負責:
      • 跨 GPU 傳輸 gradients、activations、model parameters
      • 實現 data parallelism 同 model parallelism
    • Blackwell 嘅 5th-gen NVLink(1.8 TB/s)比 Hopper 4th-gen(900 GB/s)快 2×
    • 結合效果:FlashAttention 優化 single GPU 上嘅 memory,NVLink 優化 multi-GPU 之間嘅 communication,兩者配合先可以充分發揮大規模訓練嘅性能
  5. Ultra Tensor Cores:
    • Blackwell Ultra 版本進一步增強
    • 支援 NVFP4(Nvidia 專有 FP4 格式)
    • 1-bit sign, 2-bit exponent, 1-bit mantissa

FA4 嘅突破:首個 Petaflop Attention Kernel

核心成就:

  • 喺 Blackwell B200 達到 1,605 TFLOPS(forward pass)
  • 相當於理論峰值嘅 71%(2,250 TFLOPS 係 B200 理論值)
  • 首個突破 1 Petaflop(1,000 TFLOPS) 嘅 attention implementation

重要 note:

  • FA4 目前只用 BF16(16-bit brain float),未用 Blackwell 嘅 FP4 功能
  • 已經達到 71% utilization 係一個好驚人嘅成就(因為 attention 有好多 non-matmul ops)
  • 未來如果加入 FP4 support,理論上可以再快 2-4×

🚀 FlashAttention 4 on Blackwell:實測數字
基於 Blackwell B200 實測:

  • Forward pass: 1,605 TFLOPS(71% theoretical peak)

  • Backward pass: 預計 ~1,400 TFLOPS(official numbers 未公佈)

  • Overall training: 估計比 FA3 on H100 快 2-2.5×

  • Sequence length 16K: 預計 throughput ~4.5 PFLOPs/s(vs FA3 嘅 2.2 PFLOPs/s)

  • Architecture: 目前用 BF16 only,未使用 Blackwell 嘅 FP4/FP6 capabilities

FA4 技術改進(相對 FA3)

雖然詳細技術細節未完全公開,但根據 reverse engineering 同 community 分析:

  1. 更激進嘅 warp scheduling:
    • Blackwell 嘅 SM 有更多 warp scheduler
    • FA4 充分利用呢個特性做更多 async overlap
  2. 優化 Tensor Memory 使用:
    • Blackwell 引入新嘅 on-chip memory tier(介乎 L1 同 HBM 之間)
    • FA4 可能利用呢層 memory 做更好嘅 tiling
  3. 減少 synchronization overhead:
    • FA2/FA3 需要 atomic operations 做 output accumulation
    • FA4 可能用 Blackwell 嘅 hardware-accelerated reduction
  4. 更細 tile size:
    • 因為有 256 KB shared memory,可以用更細嘅 tiles
    • 減少 HBM roundtrips,提高 data reuse

實際支援情況

# FlashAttention 4 on different architectures import torch from flash_attn import flash_attn_func # Auto-detect GPU architecture device = torch.device("cuda") arch = torch.cuda.get_device_capability() if arch >= (10, 0): # Blackwell (B100, B200, GB200) # FlashAttention 4 - Petaflop performance! print("Using FlashAttention 4 on Blackwell") print("Expected: 1,605+ TFLOPS (BF16)") elif arch >= (9, 0): # Hopper (H100, H200) # 支援 FP8 + async ops print("Using FlashAttention 3 with FP8 & async") else: # Ampere (A100) or older # Fall back to FA2 print("Using FlashAttention 2")

⚠️ Blackwell 支援 Status(2026 年 2 月)

  • ✅ FlashAttention 4 已發布(2025 年 1 月)

  • ✅ Blackwell B200/B100 已支援

  • ⚠️ FP4 support 未啟用:FA4 目前只用 BF16

  • ⚠️ 某啲 Blackwell variants(e.g., AGX Thor)可能未完全支援

  • 📅 FP4-enabled version 預計 2026 年中發布

效能比較同實際應用

Benchmark Summary

MetricStandardFA1FA2FA3FA4
GPUA100A100A100H100B200
發布年份-2022202320242025
HBM AccessO(N2)O(N^2)O(N2)O(N2d2M−1)O(N^2 d^2 M^{-1})O(N2d2M−1)SameSameSame
GPU Utilization~20%~25%~35%~75%~71%
Peak Throughput~200 TFLOPS~250 TFLOPS~450 TFLOPS~950 TFLOPS1,605 TFLOPS
GPT-3 (2K seq)Baseline3.0×5.1×8.5×~15×
Max Seq Length~4K64K64K+128K+128K+
PrecisionFP16FP16FP16/BF16FP8/FP16BF16

⚠️ Dense vs Sparse Attention
FlashAttention(FA1-4)全部都係 Dense Attention:

  • ✅ 計算所有 N×N token pairs 嘅 attention scores

  • ✅ Exact attention,數學上同 standard attention 完全一樣

  • ❌ 唔係 Sparse Attention(例如 Longformer, BigBird 只計 subset of pairs)

點解 Dense 仍然快?

  • FlashAttention 唔係減少計算量(仍然係 O(N²) FLOPs)

  • 而係透過 tiling 大幅減少 memory access(HBM → SRAM)

  • Trade-off: Dense = 更準確但更多 compute,Sparse = 更快但 approximate

可以 combine 嗎?

可以!有啲實現將 FlashAttention tiling 用喺 sparse patterns 上面(e.g., windowed attention + FlashAttention),兩全其美。

實際應用場景

1. 長文檔處理

  • 法律文件、研究論文分析
  • 支援 64K-128K tokens,唔使 truncate

2. 大型 Language Models 訓練

  • GPT-style models:end-to-end 加速 3-15×
  • 減少訓練時間同 cost

3. Multi-modal Models

  • Vision + Language(e.g., image tokens + text)
  • 長 context 對話系統

4. Inference Optimization

  • Batch inference with long prompts
  • 尤其喺 H100 用 FA3 + FP8,B200 用 FA4

何時用 FlashAttention?

✅ 推薦使用場景

  1. 訓練或 inference 有長序列(>2K tokens)
  2. GPU memory 受限(想 train 更大 batch size)
  3. 用 Hugging Face Transformers / PyTorch(integration 簡單)
  4. 有 A100/H100/B200(FA2/FA3/FA4 專門優化)

⚠️ 唔太適用場景

  1. Sequence length 好短(<512):overhead 可能大過 benefit
  2. 用緊舊 GPU(e.g., V100):FA1 有用,但 speedup 有限
  3. Custom attention patterns(e.g., sparse attention):需要額外實現

點樣 integrate?

PyTorch (xFormers library):

import torch from xformers.ops import memory_efficient_attention # Standard attention attn_output = torch.nn.functional.scaled_dot_product_attention(Q, K, V) # FlashAttention (auto-detect available version) attn_output = memory_efficient_attention(Q, K, V)

Hugging Face Transformers (2.0+):

from transformers import AutoModelForCausalLM # FlashAttention 自動啟用(如果 installed) model = AutoModelForCausalLM.from_pretrained( "gpt2", attn_implementation="flash_attention_2" # 指定 FA2 )

總結

FlashAttention 四代演進,展示咗點樣通過 algorithm + hardware co-design 打破 Attention 嘅 memory wall:

  1. FlashAttention 1(2022):用 IO-aware tiling(分塊)同 online softmax,將 memory 由 O(N2)O(N^2)O(N2) 降到 O(N)O(N)O(N)
  2. FlashAttention 2(2023):改變 work partitioning(partition over K/V),減少 non-matmul FLOPs,達到 2× speedup over FA1
  3. FlashAttention 3(2024):async computation + FP8 quantization(Hopper),GPU utilization 去到 75%
  4. FlashAttention 4(2025):Blackwell 突破 1,605 TFLOPS,首個 Petaflop attention kernel

關鍵 takeaway:

  • 📊 實際效果:訓練加速 3-15×,可以 train 128K+ tokens
  • 🎯 核心 insight:優化 memory I/O 比優化 FLOPs 更重要
  • ⚠️ Dense vs Sparse:FlashAttention 係 Dense(計晒所有 pairs),唔係 Sparse Attention
  • 🚀 未來方向:Blackwell FP4 support、更大 context windows、sparse + dense hybrid
  • 🏢 硬件支援:專為 NVIDIA GPU 優化(Ampere/Hopper/Blackwell),其他廠商(AMD/Intel)需要唔同 implementation

升級建議:

  • 如果你用緊 A100:upgrade 到 FA2(2× faster)
  • 如果你有 H100:upgrade 到 FA3(1.5-2× faster over FA2,支援 FP8)
  • 如果你有 B200:upgrade 到 FA4(2× faster over FA3,首個 Petaflop kernel!)

如果你而家做 LLM training 或者 long-context inference,FlashAttention 係 must-have。唔單止快,而且幾乎零 API change,drop-in replacement 就得。

相關資源

  • 📄 FlashAttention 1 論文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  • 📄 FlashAttention 2 論文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  • 📄 FlashAttention 3 論文:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  • 💻 官方 GitHub:Dao-AILab/flash-attention
  • 💻 xFormers library:facebookresearch/xformers
  • 📚 PyTorch 2.0 SDPA docs:torch.nn.functional.scaled_dot_product_attention
Back to all articles
目錄