當 Transformer 統治咗成個 AI 世界,Attention 機制嘅計算效率就成為咗最大嘅樽頸。FlashAttention 用 IO-aware 嘅思路,將原本需要 memory 嘅 Attention 壓縮到 ,而且仲越做越快。由第一代嘅 tiling 技術,到第二代嘅並行優化,再到第三代嘅異步計算同 FP8 支援,最後到第四代喺 Blackwell 架構突破 Petaflop 級別,呢四代 FlashAttention 點樣將 Transformer Attention 推向極致?
TL;DR
核心重點:
- 🎯 FlashAttention 1(2022):IO-aware tiling(分塊),將 Attention 由 HBM 搬入 SRAM,減少 memory access bottleneck
- 🔄 FlashAttention 2(2023):改變 parallelism 策略(partition over K/V),減少 non-matmul FLOPs,達到 2× 加速
- ⚡ FlashAttention 3(2024):異步計算 + FP8 低精度(Hopper H100),GPU 利用率去到 75%(vs FA2 嘅 35%)
- 🚀 FlashAttention 4(2025):Blackwell B200 突破,達到 1,605 TFLOPS(71% 理論峰值),係第一個突破 Petaflop 嘅 attention kernel
- ✅ 實際效果:訓練速度提升 3-15×,長序列支援 128K+ tokens,幾乎零 memory overhead
- ⚠️ 重要:FlashAttention 係 Dense Attention(計晒所有 N×N pairs),唔係 Sparse Attention
Table of Contents
- 問題背景:Attention 嘅 Memory Wall
- FlashAttention 1:IO-Aware Tiling 革命
- FlashAttention 2:Work Partitioning 同 Parallelism 優化
- FlashAttention 3:異步計算同 FP8 低精度
- FlashAttention 4:Blackwell 突破 Petaflop
- 效能比較同實際應用
- 何時用 FlashAttention?
- 總結
- 相關資源
問題背景:Attention 嘅 Memory Wall
Standard Attention 嘅瓶頸
標準 Attention 機制嘅計算流程:
where:
- :Query, Key, Value matrices(shape: )
- :sequence length
- :head dimension
問題出喺邊?
喺 GPU 上面,Attention 嘅計算涉及多次 HBM(High Bandwidth Memory)同 SRAM(on-chip cache)之間嘅數據搬運:
- 由 HBM load → 計算 → 寫返 入 HBM
- 由 HBM load → 計算 → 寫返 入 HBM
- 由 HBM load → 計算 → 寫 output 入 HBM
⚠️ Memory Bottleneck
HBM bandwidth(~1.5 TB/s on A100)遠低於 SRAM bandwidth(~19 TB/s),而 Attention 需要 嘅 memory 去 store 同 matrices。當 sequence length 去到幾千甚至幾萬,HBM access 就成為咗最大嘅瓶頸。
⚠️ 重要澄清:呢啲全部都係 GPU 上面嘅 memory!
HBM (DRAM) = GPU 嘅 VRAM:就係你買 GPU 時講嘅「40 GB VRAM」,用 HBM2e/HBM3 技術
SRAM = GPU chip 入面嘅 on-chip cache:唔係 CPU 嘅 memory!每個 GPU 嘅 SM 都有自己嘅 SRAM
簡單講:DRAM 同 SRAM 都喺 GPU 上面,分別係:
DRAM (HBM):喺 GPU chip 外面,大但慢(就係 VRAM)
SRAM:喺 GPU chip 入面,細但快(L1 cache / Shared Memory)
| Memory Type | 實際硬件 | Size | Bandwidth | 物理位置 | 你可能聽過嘅名 |
|---|---|---|---|---|---|
| HBM (DRAM) | HBM2e/HBM3 stacks | 40-80 GB | ~1.5-3 TB/s | GPU 封裝外圍,用 silicon interposer 連接 | ✅ VRAM(Video RAM) |
| SRAM (L1/Shared) | On-chip SRAM | ~192 KB/SM | ~19 TB/s | GPU chip 入面,每個 SM 有自己嘅 SRAM | Shared Memory / L1 Cache |
| L2 Cache | On-chip SRAM | 40-50 MB | ~4-7 TB/s | GPU chip 入面,所有 SM 共享 | L2 Cache |
💡 生活例子:圖書館 vs 書枱
想像你做緊功課:
HBM (DRAM) = 圖書館:好大(幾萬本書),但每次拎書要行去書架,好慢
SRAM = 你書枱:好細(只放到 10 本書),但拎書即刻就拎到,好快
Standard Attention = 每計一步都要行去圖書館拎書,再擺返落書架
FlashAttention = 一次過拎晒需要嘅書返嚟書枱,做晒先擺返
Key insight:行去圖書館(HBM access)嘅時間 >> 睇書(計算)嘅時間,所以減少往返次數先係關鍵!
❓ 等等,SRAM 咁細,點放得落成個 Attention?
你可能會問:SRAM 只有 ~192 KB,但一個 Attention matrix(N×N)可以好大(例如 16K tokens → 256 MB),點可能放得落?答案:唔係將所有嘢一次過放入 SRAM,而係用 Tiling(分塊)分批處理!
❌ 唔係:將成個 N×N matrix 一次過放入 SRAM(根本放唔落)
✅ 而係:將大 matrix 切成細塊(e.g., 64×64),逐個 block 放入 SRAM 計算
每個 block 細到放得入 SRAM → 計完即棄 → load 下一個 block
trade-off:雖然要做多次 HBM access(因為分批),但比起 standard attention 少好多!
FlashAttention 嘅策略:
唔係消除所有 HBM access(唔可能),而係將 HBM access 由 降到 ,其中 係 SRAM size。透過 tiling,將「每個中間步驟都要寫入 HBM」變成「只係 load inputs 同 write final outputs」。
FlashAttention 1:IO-Aware Tiling 革命
核心思路:Tiling(分塊)同 Online Softmax
FlashAttention 1 嘅突破在於兩個技術:
- Tiling:將大 matrix 切成細塊
咩係 Tiling?
Tiling 就係將大 matrix 分成細嘅 blocks(tiles),逐個 block 放入 SRAM 處理。
點解需要 Tiling?
- 一個 16K×16K 嘅 Attention matrix = 256 MB(用 FP16)
- SRAM 只有 ~192 KB per SM
- 結論:根本唔可能將成個 matrix 一次過放入 SRAM!
解決方法:將大 matrix 切成細塊(e.g., 128×128 = 32 KB),咁就放得入 SRAM。
🎯 生活例子:搬屋
你要搬 100 箱嘢,但你部車一次只載到 10 箱:
Without Tiling:租架大貨車,一次過載晒 100 箱
但你根本租唔到咁大架車(= SRAM 太細,放唔落)
唯有用好大嘅倉庫(= HBM)暫存,但出入倉好慢
With Tiling:用你自己部車,分 10 次,每次載 10 箱
每次只用少少空間(= 放得入 SRAM)
雖然要行 10 次,但因為你唔需要用慢嘅倉庫,overall 反而快咗
關鍵:FlashAttention 唔係「將所有嘢放入 SRAM」(根本做唔到),而係「逐小塊處理,每小塊都放得入 SRAM」。
Tiling 示意圖:
假設 都係 matrix,我哋分成 嘅 blocks:
Tiling 流程:
- 將 Q 分成 4 個 blocks(Q₁, Q₂, Q₃, Q₄),每個 block 例如 128×64 = 16 KB(放得入 SRAM)
- 將 K 都分成 4 個 blocks(K₁, K₂, K₃, K₄)
- 逐對計算:
- Load Q₁ 同 K₁ 入 SRAM → 計 S₁₁ = Q₁K₁ᵀ → 用完即棄
- Load Q₁ 同 K₂ 入 SRAM → 計 S₁₂ = Q₁K₂ᵀ → accumulate
- 繼續...
- 關鍵:每次只處理一個細 block,唔需要 store 成個 N×N matrix
實際數字例子:
- 假設 N = 16K tokens, d = 64 (head dim), block size = 128
- 完整 attention matrix: 16K × 16K × 2 bytes = 512 MB(放唔入 SRAM)
- 每個 block: 128 × 128 × 2 bytes = 32 KB(✅ 放得入 SRAM!)
- 需要嘅 blocks: (16K/128)² = 16,384 個 blocks
- FlashAttention 會逐個 block 計,所以 peak memory 只係 32 KB,唔係 512 MB
- Online Softmax:唔需要睇晒全部先計
Online Softmax 可以 incremental 咁更新 softmax,唔需要 store 成個 matrix。
🎯 核心概念:IO-Aware Algorithm
唔係單純優化 FLOPs,而係優化 memory I/O。通過 tiling 同 recomputation,將 Attention 嘅 memory footprint 由 降到 ,同時 HBM access 減少到 ( 係 SRAM size)。
Online Softmax 原理
傳統 softmax 需要兩次 pass:
- 計
- 計
Online 版本可以 incremental 更新:
當有新嘅 block 加入已計算嘅 :
咁就可以一邊 load 新 blocks,一邊更新 softmax,唔使 store 成個 matrix。
Backward Pass:Recomputation
FlashAttention 1 喺 backward pass 唔 save 同 ,而係 recompute 佢哋。
Trade-off:
- ✅ Memory footprint 大減(唔使 store 嘅 intermediate results)
- ⚠️ 額外嘅 FLOPs(recomputation cost)
- ✅ 但因為減少咗 HBM I/O,overall 反而更快
效能提升
vs Standard Attention(PyTorch implementation):
- GPT-2 訓練:15% faster end-to-end
- 長序列(e.g., Path-X 16K tokens):3× faster
- Memory usage:可以 train 去到 64K tokens(standard attention OOM)
FlashAttention 2:Work Partitioning 同 Parallelism 優化
FlashAttention 2(2023 年)專注解決 FA1 嘅兩大瓶頸:低 GPU occupancy 同 過多 non-matmul operations。
核心問題:FA1 嘅 Parallelism 策略唔夠好
FA1 嘅做法(partition over Q):
- 將 分成 個 blocks,每個 thread block 負責一個 block
- 每個 thread block 需要 iterate over 所有 blocks
- 問題:唔同 thread blocks 會重複 load 相同嘅 blocks
具體例子:
假設有 4 個 thread blocks(TB0, TB1, TB2, TB3),每個負責一個 block:
- TB0 處理 Q₁:需要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄
- TB1 處理 Q₂:又要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄(重複!)
- TB2 處理 Q₃:又要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄(重複!)
- TB3 處理 Q₄:又要 load K₁, K₂, K₃, K₄ 同 V₁, V₂, V₃, V₄(重複!)
結果:K, V 被 load 咗 4 次,浪費 memory bandwidth!
FA2 嘅解決方法:Partition over K/V instead
FA2 嘅新策略(partition over K, V, batch, head):
- 將 parallelism 改為 outer loop over K/V blocks
- 每個 thread block 負責一個 block,然後 iterate over 所有 blocks
- 好處:每個 block 只 load 一次,所有 thread blocks 共享
同一個例子用 FA2:
假設有 4 個 thread blocks,每個負責一個 block:
- TB0 負責 K₁V₁:load 一次 K₁V₁,iterate Q₁, Q₂, Q₃, Q₄
- TB1 負責 K₂V₂:load 一次 K₂V₂,iterate Q₁, Q₂, Q₃, Q₄
- TB2 負責 K₃V₃:load 一次 K₃V₃,iterate Q₁, Q₂, Q₃, Q₄
- TB3 負責 K₄V₄:load 一次 K₄V₄,iterate Q₁, Q₂, Q₃, Q₄
結果:K, V 每個只 load 一次,減少 4× memory access!
🎯 Work Partitioning 策略比較
FA1:
for each Q_block: for each K_block: compute attention每個 Q block 係一個 thread block
K, V 被重複 load 次( = Q blocks 數量)
FA2:
for each K_block: for each Q_block: compute attention每個 K/V block 係一個 thread block
K, V 只 load 一次
額外 benefit:可以同時 parallelize over batch 同 head dimension
技術細節:點樣處理 Output Accumulation?
因為 FA2 改咗 loop order,每個 thread block 只計算 output 嘅一部分,需要 atomic add 嚟 accumulate results。
FA2 用 locking mechanism:
- 每個 output block 有一個 lock
- Thread block 計完自己嘅 partial output 後,acquire lock → update → release lock
- 用 GPU 嘅 atomic operations(e.g.,
atomicAdd)實現
Trade-off:
- ✅ 減少 K, V memory access(主要 bottleneck)
- ⚠️ 增加少少 synchronization overhead(但遠比 memory savings 值得)
2. 減少 Non-Matmul FLOPs
問題背景:
喺 GPU 上,matmul operations(e.g., , )可以充分利用 Tensor Cores,達到 peak throughput。但其他 operations(e.g., softmax, dropout, masking, rescaling)係 element-wise operations,慢好多。
FA1 嘅 non-matmul FLOPs:
- Forward: rescaling, softmax, dropout, masking
- Backward: softmax gradient 特別貴(需要 recompute softmax)
FA2 優化:
- 簡化 backward pass:用更高效嘅 online softmax gradient 公式,減少 recomputation
- 更好嘅 warp scheduling:將 non-matmul ops 同 matmul ops overlap
- Reduce rescaling ops:將某啲 rescaling 提前做,減少重複計算
💡 Non-matmul FLOPs 比例:
FA1 forward: ~25% FLOPs 係 non-matmul(softmax, masking 等)
FA1 backward: ~40% FLOPs 係 non-matmul(softmax gradient 好貴!)
FA2: 將 non-matmul 降到 ~10-15%(forward + backward)
結果:雖然 total FLOPs 差唔多,但因為 Tensor Cores 利用率高咗,throughput 提升 2×
效能數字
| Model | Sequence Length | FA1 Speedup | FA2 Speedup |
|---|---|---|---|
| BERT-large | 512 | 1.7× | 2.2× |
| GPT-3 | 2048 | 3.0× | 5.1× |
| Long document | 16K | 5.2× | 9.3× |
(相對 PyTorch standard attention)
FlashAttention 3:異步計算同 FP8 低精度
FlashAttention 3(2024 年)專為 Hopper architecture(H100 GPU)設計。
1. Asynchronous Computation
問題:FA2 喺 Tensor Cores 做 matmul 嘅時候,其他 units(e.g., softmax, SRAM load/store)會 idle。
解決方法:利用 Hopper 嘅 asynchronous TMA(Tensor Memory Accelerator):
- Matmul 同 data movement overlap
- 當 warp 0 做 matmul 嘅時候,warp 1-3 可以 prefetch 下一個 block
🚀 效能突破
通過 async overlap,FA3 將 GPU utilization 由 FA2 嘅 35% 提升到 75%,喺 H100 上達到 1.5-2× speedup over FA2。
2. FP8 Low-Precision Support
Hopper 支援 FP8(8-bit floating point),可以:
- 2× matmul throughput(相對 FP16/BF16)
- 減少 memory bandwidth(因為數據細咗一半)
FA3 嘅 FP8 strategy:
- 用 FP8 做 matmul
- Intermediate results(e.g., softmax)仍然用 FP16,保持 numerical stability
- Incoherent processing:避免 FP8 量化誤差 accumulate
3. Block Quantization
為咗 minimize FP8 quantization error,FA3 用 block-wise quantization:
每個 tile(e.g., 64×64)有自己嘅 scaling factor:
效能數字(H100)
- Forward pass:1.5× faster than FA2
- Backward pass:1.8× faster
- Overall training:~1.7× speedup
- Sequence length 16K:throughput 去到 2.2 PFLOPs/s
FlashAttention 4:Blackwell 突破 Petaflop
FlashAttention 4(2025 年 1 月發布)係第一個突破 Petaflop barrier 嘅 attention kernel,專為 NVIDIA Blackwell architecture(B100/B200)設計。
Blackwell Architecture 嘅關鍵特性
NVIDIA Blackwell(2024 年發布)嘅新功能:
- Second-generation Transformer Engine:
- 硬件層面優化 attention operations
- 支援 FP4(4-bit floating point)+ micro-tensor scaling
- 2× attention-layer acceleration(vs Hopper)
- 1.5× 更多 AI compute FLOPs
- 更大 SRAM (shared memory):
- 每個 SM 嘅 shared memory 增加到 256 KB(vs Hopper 嘅 228 KB)
- 可以處理更大嘅 tiling blocks → 減少 HBM access
- HBM3e 更快 memory bandwidth:
- HBM3e bandwidth: 8 TB/s(vs H100 嘅 3.35 TB/s)
- 雖然 FlashAttention 主要 focus 減少 HBM access,但更快 HBM 仍然有幫助(尤其係 load Q, K, V inputs 同 write final outputs)
- 5th-gen NVLink(1.8 TB/s):
- NVLink 係 GPU-to-GPU 嘅高速互連技術
- ⚠️ 重要:FlashAttention 本身唔需要 NVLink(single GPU 就行)
- 但喺 multi-GPU 訓練(例如訓練超大型 LLM),NVLink 負責:
- 跨 GPU 傳輸 gradients、activations、model parameters
- 實現 data parallelism 同 model parallelism
- Blackwell 嘅 5th-gen NVLink(1.8 TB/s)比 Hopper 4th-gen(900 GB/s)快 2×
- 結合效果:FlashAttention 優化 single GPU 上嘅 memory,NVLink 優化 multi-GPU 之間嘅 communication,兩者配合先可以充分發揮大規模訓練嘅性能
- Ultra Tensor Cores:
- Blackwell Ultra 版本進一步增強
- 支援 NVFP4(Nvidia 專有 FP4 格式)
- 1-bit sign, 2-bit exponent, 1-bit mantissa
FA4 嘅突破:首個 Petaflop Attention Kernel
核心成就:
- 喺 Blackwell B200 達到 1,605 TFLOPS(forward pass)
- 相當於理論峰值嘅 71%(2,250 TFLOPS 係 B200 理論值)
- 首個突破 1 Petaflop(1,000 TFLOPS) 嘅 attention implementation
重要 note:
- FA4 目前只用 BF16(16-bit brain float),未用 Blackwell 嘅 FP4 功能
- 已經達到 71% utilization 係一個好驚人嘅成就(因為 attention 有好多 non-matmul ops)
- 未來如果加入 FP4 support,理論上可以再快 2-4×
🚀 FlashAttention 4 on Blackwell:實測數字
基於 Blackwell B200 實測:
Forward pass: 1,605 TFLOPS(71% theoretical peak)
Backward pass: 預計 ~1,400 TFLOPS(official numbers 未公佈)
Overall training: 估計比 FA3 on H100 快 2-2.5×
Sequence length 16K: 預計 throughput ~4.5 PFLOPs/s(vs FA3 嘅 2.2 PFLOPs/s)
Architecture: 目前用 BF16 only,未使用 Blackwell 嘅 FP4/FP6 capabilities
FA4 技術改進(相對 FA3)
雖然詳細技術細節未完全公開,但根據 reverse engineering 同 community 分析:
- 更激進嘅 warp scheduling:
- Blackwell 嘅 SM 有更多 warp scheduler
- FA4 充分利用呢個特性做更多 async overlap
- 優化 Tensor Memory 使用:
- Blackwell 引入新嘅 on-chip memory tier(介乎 L1 同 HBM 之間)
- FA4 可能利用呢層 memory 做更好嘅 tiling
- 減少 synchronization overhead:
- FA2/FA3 需要 atomic operations 做 output accumulation
- FA4 可能用 Blackwell 嘅 hardware-accelerated reduction
- 更細 tile size:
- 因為有 256 KB shared memory,可以用更細嘅 tiles
- 減少 HBM roundtrips,提高 data reuse
實際支援情況
# FlashAttention 4 on different architectures
import torch
from flash_attn import flash_attn_func
# Auto-detect GPU architecture
device = torch.device("cuda")
arch = torch.cuda.get_device_capability()
if arch >= (10, 0): # Blackwell (B100, B200, GB200)
# FlashAttention 4 - Petaflop performance!
print("Using FlashAttention 4 on Blackwell")
print("Expected: 1,605+ TFLOPS (BF16)")
elif arch >= (9, 0): # Hopper (H100, H200)
# 支援 FP8 + async ops
print("Using FlashAttention 3 with FP8 & async")
else: # Ampere (A100) or older
# Fall back to FA2
print("Using FlashAttention 2")
⚠️ Blackwell 支援 Status(2026 年 2 月)
✅ FlashAttention 4 已發布(2025 年 1 月)
✅ Blackwell B200/B100 已支援
⚠️ FP4 support 未啟用:FA4 目前只用 BF16
⚠️ 某啲 Blackwell variants(e.g., AGX Thor)可能未完全支援
📅 FP4-enabled version 預計 2026 年中發布
效能比較同實際應用
Benchmark Summary
| Metric | Standard | FA1 | FA2 | FA3 | FA4 |
|---|---|---|---|---|---|
| GPU | A100 | A100 | A100 | H100 | B200 |
| 發布年份 | - | 2022 | 2023 | 2024 | 2025 |
| HBM Access | Same | Same | Same | ||
| GPU Utilization | ~20% | ~25% | ~35% | ~75% | ~71% |
| Peak Throughput | ~200 TFLOPS | ~250 TFLOPS | ~450 TFLOPS | ~950 TFLOPS | 1,605 TFLOPS |
| GPT-3 (2K seq) | Baseline | 3.0× | 5.1× | 8.5× | ~15× |
| Max Seq Length | ~4K | 64K | 64K+ | 128K+ | 128K+ |
| Precision | FP16 | FP16 | FP16/BF16 | FP8/FP16 | BF16 |
⚠️ Dense vs Sparse Attention
FlashAttention(FA1-4)全部都係 Dense Attention:
✅ 計算所有 N×N token pairs 嘅 attention scores
✅ Exact attention,數學上同 standard attention 完全一樣
❌ 唔係 Sparse Attention(例如 Longformer, BigBird 只計 subset of pairs)
點解 Dense 仍然快?
FlashAttention 唔係減少計算量(仍然係 O(N²) FLOPs)
而係透過 tiling 大幅減少 memory access(HBM → SRAM)
Trade-off: Dense = 更準確但更多 compute,Sparse = 更快但 approximate
可以 combine 嗎?
可以!有啲實現將 FlashAttention tiling 用喺 sparse patterns 上面(e.g., windowed attention + FlashAttention),兩全其美。
實際應用場景
1. 長文檔處理
- 法律文件、研究論文分析
- 支援 64K-128K tokens,唔使 truncate
2. 大型 Language Models 訓練
- GPT-style models:end-to-end 加速 3-15×
- 減少訓練時間同 cost
3. Multi-modal Models
- Vision + Language(e.g., image tokens + text)
- 長 context 對話系統
4. Inference Optimization
- Batch inference with long prompts
- 尤其喺 H100 用 FA3 + FP8,B200 用 FA4
何時用 FlashAttention?
✅ 推薦使用場景
- 訓練或 inference 有長序列(>2K tokens)
- GPU memory 受限(想 train 更大 batch size)
- 用 Hugging Face Transformers / PyTorch(integration 簡單)
- 有 A100/H100/B200(FA2/FA3/FA4 專門優化)
⚠️ 唔太適用場景
- Sequence length 好短(<512):overhead 可能大過 benefit
- 用緊舊 GPU(e.g., V100):FA1 有用,但 speedup 有限
- Custom attention patterns(e.g., sparse attention):需要額外實現
點樣 integrate?
PyTorch (xFormers library):
import torch
from xformers.ops import memory_efficient_attention
# Standard attention
attn_output = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
# FlashAttention (auto-detect available version)
attn_output = memory_efficient_attention(Q, K, V)
Hugging Face Transformers (2.0+):
from transformers import AutoModelForCausalLM
# FlashAttention 自動啟用(如果 installed)
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
attn_implementation="flash_attention_2" # 指定 FA2
)
總結
FlashAttention 四代演進,展示咗點樣通過 algorithm + hardware co-design 打破 Attention 嘅 memory wall:
- FlashAttention 1(2022):用 IO-aware tiling(分塊)同 online softmax,將 memory 由 降到
- FlashAttention 2(2023):改變 work partitioning(partition over K/V),減少 non-matmul FLOPs,達到 2× speedup over FA1
- FlashAttention 3(2024):async computation + FP8 quantization(Hopper),GPU utilization 去到 75%
- FlashAttention 4(2025):Blackwell 突破 1,605 TFLOPS,首個 Petaflop attention kernel
關鍵 takeaway:
- 📊 實際效果:訓練加速 3-15×,可以 train 128K+ tokens
- 🎯 核心 insight:優化 memory I/O 比優化 FLOPs 更重要
- ⚠️ Dense vs Sparse:FlashAttention 係 Dense(計晒所有 pairs),唔係 Sparse Attention
- 🚀 未來方向:Blackwell FP4 support、更大 context windows、sparse + dense hybrid
- 🏢 硬件支援:專為 NVIDIA GPU 優化(Ampere/Hopper/Blackwell),其他廠商(AMD/Intel)需要唔同 implementation
升級建議:
- 如果你用緊 A100:upgrade 到 FA2(2× faster)
- 如果你有 H100:upgrade 到 FA3(1.5-2× faster over FA2,支援 FP8)
- 如果你有 B200:upgrade 到 FA4(2× faster over FA3,首個 Petaflop kernel!)
如果你而家做 LLM training 或者 long-context inference,FlashAttention 係 must-have。唔單止快,而且幾乎零 API change,drop-in replacement 就得。
相關資源
- 📄 FlashAttention 1 論文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- 📄 FlashAttention 2 論文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- 📄 FlashAttention 3 論文:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
- 💻 官方 GitHub:Dao-AILab/flash-attention
- 💻 xFormers library:facebookresearch/xformers
- 📚 PyTorch 2.0 SDPA docs:torch.nn.functional.scaled_dot_product_attention