Billy Tse
HomeRoadmapBlogContact
Playground
Buy me a bug

© 2026 Billy Tse

OnlyFansLinkedInGitHubEmail
Back to Blog
January 29, 2026•28 min read

Transformer 架構詳解:Attention、QKV 同 Multi-Head 機制

深入淺出講解 Transformer 架構,包括 Self-Attention 機制、Query/Key/Value 概念、Encoder-Decoder 設計,以及 Multi-Head Attention 嘅運作原理

TransformerAttention MechanismsCSCI 5640 NLP

Transformer 架構詳解

Transformer 係 2017 年由 Google 提出嘅模型架構,出自經典論文 "Attention is All You Need"。呢個名真係改得好,因為 Transformer 真係將 attention 機制發揮到極致,完全唔用 RNN 或者 CNN。

我第一次學 Transformer 嗰陣覺得好複雜,但係理解咗之後就發覺佢嘅設計真係好巧妙。而家幾乎所有大型語言模型(GPT、BERT、LLaMA 等等)都係基於 Transformer 架構,所以理解佢真係好重要。

點解需要 Transformer?

喺 Transformer 之前,NLP 任務主要用 RNN (Recurrent Neural Networks) 同佢嘅變種(LSTM、GRU)。但係 RNN 有幾個大問題:

順序處理限制

RNN 要一個字一個字咁處理,唔可以平行化。處理 "I love machine learning" 嘅時候,要先處理 "I",再處理 "love",再處理 "machine"……要等前面處理完先可以處理後面。

長程依賴問題

就算有 LSTM,處理好長嘅句子嗰陣,前面嘅信息都會逐漸消失。例如一句有 100 個字嘅句子,第 100 個字好難記得第 1 個字係咩。

訓練慢

因為順序處理,訓練時間好長,就算有好多 GPU 都冇辦法加速。

Transformer 就係為咗解決呢啲問題而設計出嚟。佢用 attention 機制令到:

  • ✅ 所有位置可以同時處理(平行化)
  • ✅ 任何兩個位置都可以直接交互(解決長程依賴)
  • ✅ 訓練速度快好多

Attention 機制核心概念

咩係 Attention?

Attention 嘅核心思想好簡單:當你處理一個字嗰陣,應該要睇返其他邊啲字。

例子:

"The animal didn't cross the street because it was too tired."

當我哋處理 "it" 呢個字嗰陣,要知道 "it" 係指咩。人類會自然咁 attend 返去 "animal" 呢個字,因為佢係 "it" 嘅 antecedent。

Attention 機制就係想 model 呢種關係——俾 model 可以動態咁決定應該關注邊啲字。

Query、Key、Value (QKV) 概念

Attention 用咗三個向量嚟計算:

Query (查詢)

  • 代表「我想搵啲咩」
  • 係當前 token 想要嘅信息

Key (鍵)

  • 代表「我提供咩信息」
  • 係每個 token 嘅標識

Value (值)

  • 代表「我實際嘅內容」
  • 係實際會被提取出嚟嘅信息

類比:餐廳點菜

我成日用呢個比喻嚟記 QKV:

想像你去餐廳:

  • Query: 你想要嘅嘢(例如 "我想要啲辣嘅菜")
  • Key: 餐牌上面每道菜嘅描述(例如 "麻婆豆腐 - 辣"、"清蒸魚 - 清淡")
  • Value: 實際嘅菜(真正上到枱嘅食物)

你會用你嘅 Query("想要辣嘅") 同每道菜嘅 Key(描述) 做比較,搵出最 match 嘅(麻婆豆腐),然後攞返對應嘅 Value(真正嘅麻婆豆腐)。

Self-Attention 數學公式

Self-Attention 嘅計算公式好出名:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V

Mermaid 流程圖

Loading diagram...

拆解一下:

步驟 1: 計算相似度

Score=QKT\text{Score} = QK^TScore=QKT

將 Query 同所有 Key 做內積(dot product),計算相似度。相似度越高,個 score 越大。

點解用內積? 因為兩個向量嘅內積可以反映佢哋嘅相似程度。如果兩個向量指向同一個方向,內積會好大。

步驟 2: 縮放

Scaled Score=QKTdk\text{Scaled Score} = \frac{QK^T}{\sqrt{d_k}}Scaled Score=dk​​QKT​

除以 dk\sqrt{d_k}dk​​(Key 向量嘅維度開方)。

點解要縮放? 當維度好大嗰陣(例如 dk=512d_k = 512dk​=512),內積嘅值會變得好大,搞到 softmax 之後啲梯度會好細(gradient 消失問題)。除以 dk\sqrt{d_k}dk​​ 可以穩定訓練。

步驟 3: Softmax

Attention Weights=softmax(QKTdk)\text{Attention Weights} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)Attention Weights=softmax(dk​​QKT​)

用 softmax 將 scores 轉換成概率分佈,所有權重加埋等於 1。

步驟 4: 加權求和

Output=Attention Weights×V\text{Output} = \text{Attention Weights} \times VOutput=Attention Weights×V

用 attention weights 對 Values 做加權平均,得出最終輸出。

具體數字例子

假設我哋有句子 "I love AI",每個字已經 embed 成向量。簡化起見,假設維度係 4。

輸入向量 (已經 embed):

  • "I": [1,0,1,0][1, 0, 1, 0][1,0,1,0]
  • "love": [0,1,0,1][0, 1, 0, 1][0,1,0,1]
  • "AI": [1,1,0,0][1, 1, 0, 0][1,1,0,0]

生成 QKV

用三個唔同嘅矩陣 WQW_QWQ​、WKW_KWK​、WVW_VWV​ 將輸入轉換成 Query、Key、Value:

Q = Input @ W_Q # Query K = Input @ W_K # Key V = Input @ W_V # Value

假設轉換之後:

Q = [[1, 0], K = [[1, 1], V = [[2, 0], [0, 1], [0, 1], [0, 2], [1, 1]] [1, 0]] [1, 1]]

計算 Attention (處理第一個字 "I")

當我哋處理 "I" 嗰陣,佢嘅 Query 係 [1,0][1, 0][1,0]。

1. 計算同所有 Keys 嘅相似度:

  • Score("I", "I") = [1,0]⋅[1,1]=1[1, 0] \cdot [1, 1] = 1[1,0]⋅[1,1]=1
  • Score("I", "love") = [1,0]⋅[0,1]=0[1, 0] \cdot [0, 1] = 0[1,0]⋅[0,1]=0
  • Score("I", "AI") = [1,0]⋅[1,0]=1[1, 0] \cdot [1, 0] = 1[1,0]⋅[1,0]=1

2. 縮放 (假設 dk=2d_k = 2dk​=2,所以除以 2≈1.41\sqrt{2} \approx 1.412​≈1.41):

  • Scaled scores: [0.71,0,0.71][0.71, 0, 0.71][0.71,0,0.71]

3. Softmax:

softmax([0.71,0,0.71])≈[0.42,0.16,0.42]\text{softmax}([0.71, 0, 0.71]) \approx [0.42, 0.16, 0.42]softmax([0.71,0,0.71])≈[0.42,0.16,0.42]

4. 加權求和:

Output=0.42×[2,0]+0.16×[0,2]+0.42×[1,1]=[1.26,0.74]\text{Output} = 0.42 \times [2, 0] + 0.16 \times [0, 2] + 0.42 \times [1, 1] = [1.26, 0.74]Output=0.42×[2,0]+0.16×[0,2]+0.42×[1,1]=[1.26,0.74]

呢個就係 "I" 經過 self-attention 之後嘅輸出!佢結合咗其他字(特別係 "I" 同 "AI")嘅信息。

Multi-Head Attention

點解要 Multi-Head?

如果只有一個 attention head,model 只可以用一種方式嚟理解句子。但係語言好複雜——有啲時候你想關注語法關係,有啲時候想關注語義關係。

例子:

"The dog chased the cat because it was hungry."

  • Head 1 可能專注語法:搵主語、謂語、賓語
  • Head 2 可能專注 coreference:"it" 指 "dog"
  • Head 3 可能專注語義:理解 "chased" 同 "hungry" 嘅因果關係

Multi-head attention 就係平行咁跑多個 attention,每個 head 可以學習到唔同嘅 pattern。

Multi-Head Attention 運作原理

Mermaid 架構圖

Loading diagram...

架構

假設我哋有 8 個 heads(Transformer 原本 paper 用 8 個):

def multi_head_attention(x, num_heads=8): # 1. 將輸入 project 去 Q, K, V Q = x @ W_Q # [seq_len, d_model] K = x @ W_K V = x @ W_V # 2. 分成多個 heads # 將 d_model 分成 num_heads 份 # 例如 d_model=512, num_heads=8, 每個 head 得 512/8=64 維 Q = split_heads(Q, num_heads) # [num_heads, seq_len, d_k] K = split_heads(K, num_heads) V = split_heads(V, num_heads) # 3. 每個 head 獨立計算 attention outputs = [] for i in range(num_heads): attn_output = attention(Q[i], K[i], V[i]) outputs.append(attn_output) # 4. Concatenate 所有 heads concat_output = concatenate(outputs) # [seq_len, d_model] # 5. 最後一個 linear projection final_output = concat_output @ W_O return final_output

維度變化

呢度好多人會搞亂,我自己都花咗時間理解:

假設:

  • 輸入維度:dmodel=512d_{\text{model}} = 512dmodel​=512
  • Heads 數量:h=8h = 8h=8
  • 每個 head 嘅維度:dk=dv=dmodel/h=64d_k = d_v = d_{\text{model}} / h = 64dk​=dv​=dmodel​/h=64

過程:

  1. 輸入:[seq_len,512][seq\_len, 512][seq_len,512]
  2. Project 去 Q/K/V:[seq_len,512][seq\_len, 512][seq_len,512] 各一個
  3. 分成 8 個 heads:[8,seq_len,64][8, seq\_len, 64][8,seq_len,64]
  4. 每個 head 做 attention:[8,seq_len,64][8, seq\_len, 64][8,seq_len,64]
  5. Concatenate:[seq_len,512][seq\_len, 512][seq_len,512] (因為 8×64=5128 \times 64 = 5128×64=512)
  6. 最後 projection:[seq_len,512][seq\_len, 512][seq_len,512]

重點: 每個 head 用嘅參數(矩陣 WQ,WK,WVW_Q, W_K, W_VWQ​,WK​,WV​)係唔同嘅!所以佢哋會學到唔同嘅 attention patterns。

視覺化 Multi-Head Attention

想像你有 8 個 heads 處理句子 "The cat sat on the mat":

Head 1: The → cat (主語-名詞關係) cat → sat (主語-動詞關係) sat → mat (動詞-賓語關係) Head 2: cat → The (限定詞關係) mat → the (限定詞關係) Head 3: on → mat (介詞關係) sat → on (動詞-介詞關係) ...每個 head 專注唔同 patterns

每個 head 學習到嘅 attention pattern 都唔同。有啲 heads 會學到位置關係(attend 隔籬嘅字),有啲會學到語義關係。

Encoder-Decoder 架構

Transformer 原本設計係用嚟做翻譯(machine translation),所以有 Encoder 同 Decoder 兩部分。

Encoder (編碼器)

作用: 理解輸入句子,將佢 encode 成有意義嘅 representation。

結構: 由 N 個相同嘅 layer 疊起嚟(原本 paper 用 N=6)

每個 Encoder layer 有兩個主要部分:

1. Multi-Head Self-Attention

Input → Multi-Head Self-Attention → Add & Norm
  • 每個字 attend 去所有其他字(包括自己)
  • 理解字與字之間嘅關係

2. Feed-Forward Network (FFN)

→ Feed-Forward → Add & Norm → Output
  • 兩層 fully connected network
  • 每個位置獨立處理(但係用同一組參數)
class EncoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.ffn = FeedForward(d_model, d_ff) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) def forward(self, x): # Self-attention + residual connection + norm attn_output = self.self_attn(x, x, x) x = self.norm1(x + attn_output) # Feed-forward + residual connection + norm ffn_output = self.ffn(x) x = self.norm2(x + ffn_output) return x

重點:

  • Residual connection (殘差連接):x+Sublayer(x)x + \text{Sublayer}(x)x+Sublayer(x) 幫助訓練深層網絡
  • Layer Normalization: 穩定訓練

Decoder (解碼器)

作用: 根據 encoder 嘅輸出,逐個字咁生成翻譯。

結構: 都係由 N 個相同嘅 layer 疊起嚟

每個 Decoder layer 有三個主要部分:

1. Masked Multi-Head Self-Attention

Output Embedding → Masked Self-Attention → Add & Norm
  • 同 encoder 嘅 self-attention 好似,但係有 mask
  • 點解要 mask? 因為 decode 嗰陣係由左到右生成,唔可以睇到未來嘅字
  • 例如生成第 3 個字嗰陣,只可以睇到第 1 同第 2 個字

2. Cross-Attention (Encoder-Decoder Attention)

→ Cross-Attention → Add & Norm
  • Query 嚟自 decoder (當前生成嘅內容)
  • Key 同 Value 嚟自 encoder (輸入句子嘅 representation)
  • 呢個係 decoder "睇返" 輸入句子嘅機制

3. Feed-Forward Network

→ Feed-Forward → Add & Norm → Output
  • 同 encoder 嘅 FFN 一樣
class DecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff): super().__init__() self.masked_self_attn = MultiHeadAttention(d_model, num_heads) self.cross_attn = MultiHeadAttention(d_model, num_heads) self.ffn = FeedForward(d_model, d_ff) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.norm3 = LayerNorm(d_model) def forward(self, x, encoder_output, mask): # Masked self-attention attn_output = self.masked_self_attn(x, x, x, mask=mask) x = self.norm1(x + attn_output) # Cross-attention (Q from decoder, K,V from encoder) cross_attn_output = self.cross_attn( query=x, key=encoder_output, value=encoder_output ) x = self.norm2(x + cross_attn_output) # Feed-forward ffn_output = self.ffn(x) x = self.norm3(x + ffn_output) return x

Encoder vs Decoder 比較

EncoderDecoder
Self-AttentionBi-directional (睇晒成句)Masked (只睇前面)
Cross-Attention冇有 (attend 去 encoder 輸出)
用途理解輸入生成輸出
處理方式Parallel (一次過處理晒)Auto-regressive (逐個生成)

翻譯例子:英文 → 中文

輸入: "I love AI"

輸出: "我 愛 AI"

Encoding 階段

  1. "I love AI" → Encoder
  2. 每個 encoder layer 處理:
    • Self-attention: 理解 "love" 同 "I" 同 "AI" 嘅關係
    • FFN: 提取更高層次嘅特徵
  3. 最後得到 encoder 輸出:HencH_{\text{enc}}Henc​

Decoding 階段 (逐步生成)

Step 1: 生成第一個字

  • Decoder input: [START]
  • Masked self-attention: 只睇到 [START]
  • Cross-attention: attend 去成句 "I love AI" (透過 HencH_{\text{enc}}Henc​)
  • Output: "我"

Step 2: 生成第二個字

  • Decoder input: [START] 我
  • Masked self-attention: 睇到 [START] 同 "我"
  • Cross-attention: 再 attend 返去 "I love AI"
  • Output: "愛"

Step 3: 生成第三個字

  • Decoder input: [START] 我 愛
  • Masked self-attention: 睇到前面所有字
  • Cross-attention: attend 去 "I love AI"
  • Output: "AI"

Step 4: 生成結束符

  • Output: [END] → 停止

完整架構圖

Mermaid 流程圖

Loading diagram...

文字版架構圖

輸入: "I love AI" ↓ Input Embedding + Positional Encoding ↓ ┌─────────────────────────────┐ │ Encoder (×6 layers) │ │ ┌─────────────────────┐ │ │ │ Multi-Head Attn │ │ │ │ Add & Norm │ │ │ │ Feed-Forward │ │ │ │ Add & Norm │ │ │ └─────────────────────┘ │ └─────────────────────────────┘ ↓ H_enc (encoder output) 輸出: [START] ↓ Output Embedding + Positional Encoding ↓ ┌────────────────────────────────────┐ │ Decoder (×6 layers) │ │ ┌────────────────────────────┐ │ │ │ Masked Multi-Head Attn │ │ │ │ Add & Norm │ │ │ │ Cross-Attention ← H_enc │ ← Encoder output │ │ Add & Norm │ │ │ │ Feed-Forward │ │ │ │ Add & Norm │ │ │ └────────────────────────────┘ │ └────────────────────────────────────┘ ↓ Linear + Softmax ↓ 輸出概率: P("我") = 0.9, P("你") = 0.05, ... ↓ 生成: "我"

Encoder-Only vs Decoder-Only vs Encoder-Decoder

而家唔同嘅模型用唔同嘅架構:

Encoder-Only (例如 BERT)

結構: 只有 Encoder 部分

特點: Bi-directional attention (可以睇前面同後面)

用途:

  • 文本分類
  • Named Entity Recognition
  • Question Answering (理解型任務)

例子: 判斷 "呢個電影好正" 係正面定負面評價

# BERT 用法 input = "呢個電影好正 [MASK]" output = encoder(input) # 用 output 做分類或者 fill mask

Decoder-Only (例如 GPT)

結構: 只有 Decoder 部分 (但係冇 cross-attention)

特點: Causal/Auto-regressive (只睇前面)

用途:

  • 文本生成
  • Completion
  • 對話系統

例子: 生成故事

# GPT 用法 input = "Once upon a time" output = decoder(input) # "there was a princess"

GPT 其實係用咗 decoder 嘅 masked self-attention,但係唔需要 cross-attention (因為冇 encoder 輸出要 attend)。

Encoder-Decoder (例如 T5, BART)

結構: 有齊 Encoder 同 Decoder

特點: 結合兩者優點

用途:

  • 機器翻譯
  • 文本摘要
  • 問答系統

例子: 文本摘要

# T5 用法 input = "好長嘅文章內容..." encoder_output = encoder(input) summary = decoder(encoder_output) # "簡短總結"
架構Attention 類型代表模型最啱做咩
Encoder-OnlyBi-directionalBERT, RoBERTa理解、分類
Decoder-OnlyCausal (單向)GPT, LLaMA生成
Encoder-DecoderBi-dir + Causal + CrossT5, BART翻譯、摘要

其他重要組件

Positional Encoding (位置編碼)

Attention 本身冇位置信息——"I love AI" 同 "AI love I" 對佢嚟講冇分別(因為都係計 token 之間嘅相似度)。

所以要加 positional encoding 話俾 model 知每個字嘅位置:

PE(pos,2i)=sin⁡(pos100002i/d)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)PE(pos,2i)​=sin(100002i/dpos​) PE(pos,2i+1)=cos⁡(pos100002i/d)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)PE(pos,2i+1)​=cos(100002i/dpos​)
  • pospospos: 位置 (0, 1, 2, ...)
  • iii: 維度 index
  • ddd: embedding 維度

呢個設計令到:

  • 每個位置有獨特嘅 encoding
  • 相對位置可以透過 sin/cos 嘅性質計算出嚟

我之前寫咗篇 RoPE 嘅文章,嗰個係 positional encoding 嘅改進版,有興趣可以睇返!

Feed-Forward Network (FFN)

每個 encoder/decoder layer 入面都有個 FFN:

FFN(x)=ReLU(xW1+b1)W2+b2\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2FFN(x)=ReLU(xW1​+b1​)W2​+b2​
  • 兩層 fully connected
  • 中間有 ReLU activation
  • 中間層嘅維度通常係 dmodeld_{\text{model}}dmodel​ 嘅 4 倍(例如 512 → 2048 → 512)

作用:

  • Attention 負責 "邊個字同邊個字有關係"
  • FFN 負責 "對每個位置做 transformation"
  • 兩者配合,先可以有強大嘅表達能力

Layer Normalization

每個 sub-layer 之後都會做 Layer Norm:

LayerNorm(x)=γx−μσ+β\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sigma} + \betaLayerNorm(x)=γσx−μ​+β
  • μ\muμ: mean
  • σ\sigmaσ: standard deviation
  • γ,β\gamma, \betaγ,β: learnable parameters

作用: 穩定訓練,避免梯度爆炸/消失

Residual Connection (殘差連接)

每個 sub-layer 都有 residual connection:

Output=LayerNorm(x+Sublayer(x))\text{Output} = \text{LayerNorm}(x + \text{Sublayer}(x))Output=LayerNorm(x+Sublayer(x))

呢個係由 ResNet 引入嘅概念,對訓練深層網絡好重要。冇佢嘅話,疊 6 層都會好難 train。

訓練同推理

訓練時 (Training)

Encoder: 一次過處理成句輸入

Decoder: 都係一次過處理(用 teacher forcing)

Teacher Forcing: 訓練時,就算 decoder 第 2 步生成錯咗,第 3 步都係用正確嘅答案做 input,而唔係用自己生成嘅錯誤輸出。

例子:

  • Target: "我 愛 AI"
  • Step 1: Input [START] → Output "我" ✓
  • Step 2: Input [START] 我 → Output "喜歡" ✗ (錯咗,正確係 "愛")
  • Step 3: Input [START] 我 愛 ← 用返正確答案,唔係 "喜歡"

呢個可以加快訓練,但係會有 exposure bias 問題。

推理時 (Inference)

Encoder: 一次過處理輸入

Decoder: Auto-regressive (逐個字生成)

每生成一個字,就將佢加入 input,再生成下一個字,直到出現 [END] token。

通常會用 beam search 而唔係 greedy decoding,可以搵到更好嘅結果。

點解 Transformer 咁成功?

我自己覺得 Transformer 成功有幾個原因:

1. 平行化

唔似 RNN 要順序處理,Transformer 可以同時計所有位置,訓練快好多。我試過 train LSTM 同 Transformer,速度差好遠。

2. 長程依賴

任何兩個位置都可以直接交互(透過 attention),唔似 RNN 要經過好多步先連接到。

3. 可解釋性

可以 visualize attention weights,睇到 model 關注邊啲字。呢個對理解 model 行為好有用。

4. 靈活性

可以輕易調整成 encoder-only、decoder-only 或者 encoder-decoder,適應唔同任務。

5. Scalability

增加 layers、增加 heads、增加維度都好容易,而且效果會越嚟越好(scaling laws)。

優化版 Attention 機制

標準 self-attention 有 O(n2)O(n^2)O(n2) 嘅計算複雜度同記憶體需求,所以研究者提出咗好多優化方法。呢度講下兩個最重要嘅優化:Flash Attention 同 PagedAttention。

Flash Attention

問題:記憶體存取係瓶頸

我一開始以為 attention 慢係因為計算量大,但其實真正嘅瓶頸係 memory access (記憶體存取)。

現代 GPU 有幾種記憶體:

  • HBM (High Bandwidth Memory): 大但慢(例如 40GB A100)
  • SRAM (on-chip memory): 細但快好多(例如 20MB)

標準 attention 實現會:

  1. 由 HBM 讀 Q, K → 計 QKTQK^TQKT → 寫返 HBM
  2. 由 HBM 讀 QKTQK^TQKT → 計 softmax → 寫返 HBM
  3. 由 HBM 讀 softmax 結果同 V → 計最終輸出 → 寫返 HBM

每次都要讀寫 HBM,好慢!而且要 store 成個 attention matrix [n×n][n \times n][n×n] 落 HBM,好食記憶體。

Flash Attention 嘅解決方法

Flash Attention (由 Tri Dao 等人喺 2022 年提出) 用咗兩個關鍵技巧:

1. Tiling (分塊計算)

唔好一次過計成個 attention matrix,而係將 Q, K, V 分成細塊,逐塊咁計。每塊可以 fit 入去 SRAM,減少 HBM 存取。

2. Recomputation (重新計算)

Backward pass 嗰陣唔 store attention matrix,而係 recompute 佢。雖然多咗計算,但係少咗記憶體存取,overall 反而快咗!

算法流程

簡化版嘅 Flash Attention 流程:

def flash_attention(Q, K, V, block_size): """ Q, K, V: [seq_len, d_model] block_size: 每次處理幾多個 tokens (例如 128) """ seq_len = Q.shape[0] output = zeros_like(Q) # 將 Q 分成 blocks for i in range(0, seq_len, block_size): Q_block = Q[i:i+block_size] # Load 入 SRAM # 對於呢個 Q block,逐個 K/V block 咁處理 block_output = zeros_like(Q_block) block_max = -inf block_sum = 0 for j in range(0, seq_len, block_size): K_block = K[j:j+block_size] # Load 入 SRAM V_block = V[j:j+block_size] # 計呢個 block 嘅 attention scores = Q_block @ K_block.T / sqrt(d_k) # Online softmax (唔使 store 成個 matrix) new_max = max(block_max, scores.max()) scores = exp(scores - new_max) # 更新 output correction = exp(block_max - new_max) block_output = block_output * correction + scores @ V_block block_sum = block_sum * correction + scores.sum() block_max = new_max # Normalize output[i:i+block_size] = block_output / block_sum return output

關鍵點:

  • 一次只 load 細塊入 SRAM,減少 HBM 存取
  • 用 online softmax 技巧,唔使 store 成個 attention matrix
  • 唔寫中間結果返 HBM

效能提升

Flash Attention 嘅改進真係好明顯:

速度:

  • GPT-2 (1024 tokens): 快 2-4x
  • 長序列 (2048+ tokens): 快 5-9x

記憶體:

  • 記憶體使用由 O(n2)O(n^2)O(n2) 降到 O(n)O(n)O(n)
  • 可以 train 更長嘅 sequence

例子:

喺 A100 GPU 上面,標準 attention 最多 handle 到 1024 tokens (batch size=16),但係 Flash Attention 可以去到 4096 tokens!

我自己試過用 Flash Attention train 長文檔模型,記憶體省咗好多,可以用大啲嘅 batch size,訓練快咗成倍。

Flash Attention 2 & 3

Flash Attention 2 (2023):

  • 進一步優化 parallelism
  • 減少 non-matmul operations
  • 再快 2x

Flash Attention 3 (2024):

  • 針對 Hopper GPU (H100/H800)
  • 利用新硬件特性(例如 wgmma 指令)
  • 達到接近理論上限嘅速度

而家好多 framework 都內置咗 Flash Attention:

import torch import torch.nn.functional as F # PyTorch 2.0+ 內置支援 output = F.scaled_dot_product_attention( query, key, value, attn_mask=None, is_causal=False ) # 會自動用 Flash Attention!

PagedAttention

問題:Inference 嘅記憶體浪費

當你用 LLM 做 inference (例如 ChatGPT 咁逐個字生成) 嗰陣,要 cache 之前所有 tokens 嘅 Key 同 Value (KV cache),避免重複計算。

KV Cache 記憶體計算:

對於一個 13B 參數嘅模型(例如 LLaMA-13B):

  • 每個 token 嘅 KV cache: ~800 KB
  • 生成 2048 tokens: ~1.6 GB per request
  • 處理 100 個 concurrent requests: ~160 GB!

問題係:我哋要預先 allocate 記憶體俾最長可能嘅 sequence (例如 2048 tokens),就算實際只用咗 100 tokens 都要 reserve 晒。呢個叫做 internal fragmentation (內部碎片化)。

另外,就算 requests 長度唔同,都好難 share 記憶體,造成 external fragmentation (外部碎片化)。

研究發現,實際上只有 20-40% 嘅 KV cache 記憶體真係有用緊!

PagedAttention 嘅解決方法

PagedAttention (由 UC Berkeley 嘅 vLLM 團隊喺 2023 年提出) 借鑑咗操作系統嘅 virtual memory (虛擬記憶體) 概念。

核心思想:

  • 將 KV cache 分成固定大小嘅 blocks/pages (例如每個 block 存 16 tokens)
  • 用 page table 嚟 map logical blocks → physical blocks
  • 唔同 requests 可以 share 相同嘅 blocks (例如 shared prompt)
  • 按需 allocate blocks,冇 internal fragmentation

視覺化例子

假設我哋要生成呢個 request:

Prompt: "Translate to French: Hello, how are you?"

Generation: "Bonjour, comment allez-vous?"

傳統方法 (Contiguous Memory)

記憶體: [________________________________] ↑ 預先 allocate 2048 tokens 嘅空間 實際用咗: [Prompt tokens][Generated tokens][______浪費______] 12 tokens + 6 tokens = 18/2048 (0.9% 使用率!)

PagedAttention (Paged Memory)

Logical view: [Block 0] → [Block 1] → [Block 2] → ... Prompt Prompt Generated tokens tokens tokens (16) (2) (6) Physical memory: Page table: Logical Block 0 → Physical Block 5 [16 tokens] Logical Block 1 → Physical Block 12 [2 tokens] Logical Block 2 → Physical Block 3 [6 tokens] 只 allocate 需要嘅 blocks,冇浪費!

Shared Prefix 優化

更勁嘅係,如果多個 requests 有相同嘅 prefix,可以 share blocks!

例子: Batch translation

Request 1: "Translate to French: Hello" → "Bonjour" Request 2: "Translate to French: Goodbye" → "Au revoir" Request 3: "Translate to French: Thank you" → "Merci"

所有 requests 都 share "Translate to French: " 嘅 KV cache blocks:

Shared blocks (read-only): ["Translate"] ["to"] ["French"] [":"] ↑ ↑ ↑ ↑ └───────────┴────────┴────────┴─── Shared by all 3 requests! Private blocks (copy-on-write): Request 1: ["Hello"] ["Bonjour"] Request 2: ["Goodbye"] ["Au revoir"] Request 3: ["Thank you"] ["Merci"]

呢個對 few-shot prompting 特別有用,因為好多 requests 會 share 同一個長 prompt!

實際效能

vLLM (用咗 PagedAttention) 同其他 serving systems 比較:

Throughput (每秒處理幾多 requests):

  • 比 FasterTransformer 快 2-4x
  • 比 Hugging Face Text Generation Inference 快 2-3x
  • 比 naive PyTorch 快 8-10x

記憶體使用:

  • 減少 55-80% 記憶體浪費
  • 可以同時處理更多 requests

例子:

喺 A100 (40GB) 上面 serve LLaMA-13B:

  • 傳統方法: 最多 ~20 concurrent requests
  • vLLM (PagedAttention): 可以去到 50+ concurrent requests

我喺 production 用過 vLLM,真係可以省返好多 GPU,throughput 提升明顯。特別係當你有好多 requests 用相同 prompt 嗰陣(例如 chatbot with system prompt),memory sharing 嘅效果好好。

PagedAttention 實現

核心係改動 attention 計算,令佢可以 access 非連續嘅 KV blocks:

def paged_attention(Q, K_pages, V_pages, page_table, block_size=16): """ Q: [batch, num_heads, 1, head_dim] # 當前 token K_pages: [num_physical_blocks, block_size, num_heads, head_dim] V_pages: [num_physical_blocks, block_size, num_heads, head_dim] page_table: [batch, max_num_blocks] # logical → physical mapping """ batch_size = Q.shape[0] outputs = [] for b in range(batch_size): # Get physical block indices for this request physical_blocks = page_table[b] # [num_blocks_used] # Gather K and V from physical blocks K_seq = [] V_seq = [] for block_idx in physical_blocks: if block_idx >= 0: # -1 means unused K_seq.append(K_pages[block_idx]) V_seq.append(V_pages[block_idx]) K_concat = concat(K_seq, dim=0) # [seq_len, num_heads, head_dim] V_concat = concat(V_seq, dim=0) # Standard attention scores = Q[b] @ K_concat.T / sqrt(head_dim) attn_weights = softmax(scores, dim=-1) output = attn_weights @ V_concat outputs.append(output) return stack(outputs)

實際 implementation 會用 CUDA kernel 嚟優化 gather 操作,令存取非連續記憶體都好快。

Flash Attention vs PagedAttention 比較

Flash AttentionPagedAttention
目標加快 attention 計算減少 KV cache 記憶體浪費
主要用途Training + InferenceInference (serving)
優化對象計算效率 + 記憶體存取記憶體管理
技術Tiling + RecomputationVirtual memory + Paging
加速2-9x faster2-4x throughput
記憶體節省O(n2)‘→‘O(n)O(n^2)` → `O(n)O(n2)‘→‘O(n) (attention matrix)55-80% 減少浪費 (KV cache)
可以一齊用?✅ 可以!vLLM 已經整合咗 Flash Attention

佢哋解決唔同問題,所以可以同時用:

  • Flash Attention: 令每次 attention 計算快啲
  • PagedAttention: 令記憶體管理好啲,可以 serve 更多 requests

結合埋一齊就係最強組合!

其他 Attention 優化

除咗 Flash Attention 同 PagedAttention,仲有其他方向:

Multi-Query Attention (MQA)

  • 所有 heads share 同一組 K, V
  • 大幅減少 KV cache size
  • 用於 PaLM, StarCoder

Grouped-Query Attention (GQA)

  • MQA 同 Multi-Head 嘅折衷
  • 將 heads 分組,每組 share K, V
  • 用於 LLaMA 2, Mistral

例子: 32 heads

  • Multi-Head: 32 組 KV
  • GQA (8 groups): 8 組 KV
  • MQA: 1 組 KV

GQA 可以減少 KV cache 去 1/4,而 quality 幾乎冇 drop!

局限同改進方向

雖然 Transformer 好強,但係都有問題:

計算複雜度 O(n2)O(n^2)O(n2)

Self-attention 要計所有 token pairs,序列長咗計算量就爆炸。呢個係點解 context window 通常得幾千個 tokens。

改進方向:

  • Sparse attention (Longformer, BigBird)
  • Linear attention (Performer, FNet)
  • 分層處理 (Hierarchical Transformers)
  • Flash Attention 同 PagedAttention 已經大幅改善!

記憶體需求

要 store 成個 attention matrix,長序列會 out of memory。

缺乏 inductive bias

Transformer 冇 built-in 嘅結構假設(唔似 CNN 有 locality, RNN 有 sequentiality),所以需要好多 data 先 train 得好。

我嘅睇法

學咗 Transformer 之後,我覺得呢個架構真係改變咗成個 NLP 領域。由 2017 到而家,幾乎所有 SOTA 模型都係基於 Transformer。

最令我欣賞嘅係佢嘅簡潔性。雖然一開始覺得複雜,但係理解咗之後就發覺每個部分都有佢嘅作用,而且設計得好合理。

而家研究方向主要係點樣處理更長嘅 context (解決 O(n2)O(n^2)O(n2) 問題) 同點樣減少計算資源。我覺得未來可能會見到:

  • 更高效嘅 attention 機制
  • 混合架構(結合 Transformer 同其他方法)
  • 針對特定任務優化嘅變種

如果你啱啱開始學 deep learning for NLP,我會建議由細嘅 Transformer 開始 implement (例如幾層、細 dimension),親手做一次真係會理解得深好多。

參考資料

  • Vaswani, A., et al. (2017). Attention is All You Need. NeurIPS 2017 [呢篇係必讀!]
  • The Illustrated Transformer by Jay Alammar [好正嘅視覺化講解]
  • The Annotated Transformer [有完整 code implementation]
  • Devlin, J., et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers. arXiv:1810.04805
  • Radford, A., et al. (2018). Improving Language Understanding by Generative Pre-Training. OpenAI
  • Raffel, C., et al. (2019). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. arXiv:1910.10683
Back to all articles
目錄