Transformer 架構詳解
Transformer 係 2017 年由 Google 提出嘅模型架構,出自經典論文 "Attention is All You Need"。呢個名真係改得好,因為 Transformer 真係將 attention 機制發揮到極致,完全唔用 RNN 或者 CNN。
我第一次學 Transformer 嗰陣覺得好複雜,但係理解咗之後就發覺佢嘅設計真係好巧妙。而家幾乎所有大型語言模型(GPT、BERT、LLaMA 等等)都係基於 Transformer 架構,所以理解佢真係好重要。
點解需要 Transformer?
喺 Transformer 之前,NLP 任務主要用 RNN (Recurrent Neural Networks) 同佢嘅變種(LSTM、GRU)。但係 RNN 有幾個大問題:
順序處理限制
RNN 要一個字一個字咁處理,唔可以平行化。處理 "I love machine learning" 嘅時候,要先處理 "I",再處理 "love",再處理 "machine"……要等前面處理完先可以處理後面。
長程依賴問題
就算有 LSTM,處理好長嘅句子嗰陣,前面嘅信息都會逐漸消失。例如一句有 100 個字嘅句子,第 100 個字好難記得第 1 個字係咩。
訓練慢
因為順序處理,訓練時間好長,就算有好多 GPU 都冇辦法加速。
Transformer 就係為咗解決呢啲問題而設計出嚟。佢用 attention 機制令到:
- ✅ 所有位置可以同時處理(平行化)
- ✅ 任何兩個位置都可以直接交互(解決長程依賴)
- ✅ 訓練速度快好多
Attention 機制核心概念
咩係 Attention?
Attention 嘅核心思想好簡單:當你處理一個字嗰陣,應該要睇返其他邊啲字。
例子:
"The animal didn't cross the street because it was too tired."
當我哋處理 "it" 呢個字嗰陣,要知道 "it" 係指咩。人類會自然咁 attend 返去 "animal" 呢個字,因為佢係 "it" 嘅 antecedent。
Attention 機制就係想 model 呢種關係——俾 model 可以動態咁決定應該關注邊啲字。
Query、Key、Value (QKV) 概念
Attention 用咗三個向量嚟計算:
Query (查詢)
- 代表「我想搵啲咩」
- 係當前 token 想要嘅信息
Key (鍵)
- 代表「我提供咩信息」
- 係每個 token 嘅標識
Value (值)
- 代表「我實際嘅內容」
- 係實際會被提取出嚟嘅信息
類比:餐廳點菜
我成日用呢個比喻嚟記 QKV:
想像你去餐廳:
- Query: 你想要嘅嘢(例如 "我想要啲辣嘅菜")
- Key: 餐牌上面每道菜嘅描述(例如 "麻婆豆腐 - 辣"、"清蒸魚 - 清淡")
- Value: 實際嘅菜(真正上到枱嘅食物)
你會用你嘅 Query("想要辣嘅") 同每道菜嘅 Key(描述) 做比較,搵出最 match 嘅(麻婆豆腐),然後攞返對應嘅 Value(真正嘅麻婆豆腐)。
Self-Attention 數學公式
Self-Attention 嘅計算公式好出名:
Mermaid 流程圖
拆解一下:
步驟 1: 計算相似度
將 Query 同所有 Key 做內積(dot product),計算相似度。相似度越高,個 score 越大。
點解用內積? 因為兩個向量嘅內積可以反映佢哋嘅相似程度。如果兩個向量指向同一個方向,內積會好大。
步驟 2: 縮放
除以 (Key 向量嘅維度開方)。
點解要縮放? 當維度好大嗰陣(例如 ),內積嘅值會變得好大,搞到 softmax 之後啲梯度會好細(gradient 消失問題)。除以 可以穩定訓練。
步驟 3: Softmax
用 softmax 將 scores 轉換成概率分佈,所有權重加埋等於 1。
步驟 4: 加權求和
用 attention weights 對 Values 做加權平均,得出最終輸出。
具體數字例子
假設我哋有句子 "I love AI",每個字已經 embed 成向量。簡化起見,假設維度係 4。
輸入向量 (已經 embed):
- "I":
- "love":
- "AI":
生成 QKV
用三個唔同嘅矩陣 、、 將輸入轉換成 Query、Key、Value:
Q = Input @ W_Q # Query
K = Input @ W_K # Key
V = Input @ W_V # Value
假設轉換之後:
Q = [[1, 0], K = [[1, 1], V = [[2, 0],
[0, 1], [0, 1], [0, 2],
[1, 1]] [1, 0]] [1, 1]]
計算 Attention (處理第一個字 "I")
當我哋處理 "I" 嗰陣,佢嘅 Query 係 。
1. 計算同所有 Keys 嘅相似度:
- Score("I", "I") =
- Score("I", "love") =
- Score("I", "AI") =
2. 縮放 (假設 ,所以除以 ):
- Scaled scores:
3. Softmax:
4. 加權求和:
呢個就係 "I" 經過 self-attention 之後嘅輸出!佢結合咗其他字(特別係 "I" 同 "AI")嘅信息。
Multi-Head Attention
點解要 Multi-Head?
如果只有一個 attention head,model 只可以用一種方式嚟理解句子。但係語言好複雜——有啲時候你想關注語法關係,有啲時候想關注語義關係。
例子:
"The dog chased the cat because it was hungry."
- Head 1 可能專注語法:搵主語、謂語、賓語
- Head 2 可能專注 coreference:"it" 指 "dog"
- Head 3 可能專注語義:理解 "chased" 同 "hungry" 嘅因果關係
Multi-head attention 就係平行咁跑多個 attention,每個 head 可以學習到唔同嘅 pattern。
Multi-Head Attention 運作原理
Mermaid 架構圖
架構
假設我哋有 8 個 heads(Transformer 原本 paper 用 8 個):
def multi_head_attention(x, num_heads=8):
# 1. 將輸入 project 去 Q, K, V
Q = x @ W_Q # [seq_len, d_model]
K = x @ W_K
V = x @ W_V
# 2. 分成多個 heads
# 將 d_model 分成 num_heads 份
# 例如 d_model=512, num_heads=8, 每個 head 得 512/8=64 維
Q = split_heads(Q, num_heads) # [num_heads, seq_len, d_k]
K = split_heads(K, num_heads)
V = split_heads(V, num_heads)
# 3. 每個 head 獨立計算 attention
outputs = []
for i in range(num_heads):
attn_output = attention(Q[i], K[i], V[i])
outputs.append(attn_output)
# 4. Concatenate 所有 heads
concat_output = concatenate(outputs) # [seq_len, d_model]
# 5. 最後一個 linear projection
final_output = concat_output @ W_O
return final_output
維度變化
呢度好多人會搞亂,我自己都花咗時間理解:
假設:
- 輸入維度:
- Heads 數量:
- 每個 head 嘅維度:
過程:
- 輸入:
- Project 去 Q/K/V: 各一個
- 分成 8 個 heads:
- 每個 head 做 attention:
- Concatenate: (因為 )
- 最後 projection:
重點: 每個 head 用嘅參數(矩陣 )係唔同嘅!所以佢哋會學到唔同嘅 attention patterns。
視覺化 Multi-Head Attention
想像你有 8 個 heads 處理句子 "The cat sat on the mat":
Head 1: The → cat (主語-名詞關係)
cat → sat (主語-動詞關係)
sat → mat (動詞-賓語關係)
Head 2: cat → The (限定詞關係)
mat → the (限定詞關係)
Head 3: on → mat (介詞關係)
sat → on (動詞-介詞關係)
...每個 head 專注唔同 patterns
每個 head 學習到嘅 attention pattern 都唔同。有啲 heads 會學到位置關係(attend 隔籬嘅字),有啲會學到語義關係。
Encoder-Decoder 架構
Transformer 原本設計係用嚟做翻譯(machine translation),所以有 Encoder 同 Decoder 兩部分。
Encoder (編碼器)
作用: 理解輸入句子,將佢 encode 成有意義嘅 representation。
結構: 由 N 個相同嘅 layer 疊起嚟(原本 paper 用 N=6)
每個 Encoder layer 有兩個主要部分:
1. Multi-Head Self-Attention
Input → Multi-Head Self-Attention → Add & Norm
- 每個字 attend 去所有其他字(包括自己)
- 理解字與字之間嘅關係
2. Feed-Forward Network (FFN)
→ Feed-Forward → Add & Norm → Output
- 兩層 fully connected network
- 每個位置獨立處理(但係用同一組參數)
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, d_ff)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
def forward(self, x):
# Self-attention + residual connection + norm
attn_output = self.self_attn(x, x, x)
x = self.norm1(x + attn_output)
# Feed-forward + residual connection + norm
ffn_output = self.ffn(x)
x = self.norm2(x + ffn_output)
return x
重點:
- Residual connection (殘差連接): 幫助訓練深層網絡
- Layer Normalization: 穩定訓練
Decoder (解碼器)
作用: 根據 encoder 嘅輸出,逐個字咁生成翻譯。
結構: 都係由 N 個相同嘅 layer 疊起嚟
每個 Decoder layer 有三個主要部分:
1. Masked Multi-Head Self-Attention
Output Embedding → Masked Self-Attention → Add & Norm
- 同 encoder 嘅 self-attention 好似,但係有 mask
- 點解要 mask? 因為 decode 嗰陣係由左到右生成,唔可以睇到未來嘅字
- 例如生成第 3 個字嗰陣,只可以睇到第 1 同第 2 個字
2. Cross-Attention (Encoder-Decoder Attention)
→ Cross-Attention → Add & Norm
- Query 嚟自 decoder (當前生成嘅內容)
- Key 同 Value 嚟自 encoder (輸入句子嘅 representation)
- 呢個係 decoder "睇返" 輸入句子嘅機制
3. Feed-Forward Network
→ Feed-Forward → Add & Norm → Output
- 同 encoder 嘅 FFN 一樣
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, d_ff)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
def forward(self, x, encoder_output, mask):
# Masked self-attention
attn_output = self.masked_self_attn(x, x, x, mask=mask)
x = self.norm1(x + attn_output)
# Cross-attention (Q from decoder, K,V from encoder)
cross_attn_output = self.cross_attn(
query=x,
key=encoder_output,
value=encoder_output
)
x = self.norm2(x + cross_attn_output)
# Feed-forward
ffn_output = self.ffn(x)
x = self.norm3(x + ffn_output)
return x
Encoder vs Decoder 比較
| Encoder | Decoder | |
|---|---|---|
| Self-Attention | Bi-directional (睇晒成句) | Masked (只睇前面) |
| Cross-Attention | 冇 | 有 (attend 去 encoder 輸出) |
| 用途 | 理解輸入 | 生成輸出 |
| 處理方式 | Parallel (一次過處理晒) | Auto-regressive (逐個生成) |
翻譯例子:英文 → 中文
輸入: "I love AI"
輸出: "我 愛 AI"
Encoding 階段
- "I love AI" → Encoder
- 每個 encoder layer 處理:
- Self-attention: 理解 "love" 同 "I" 同 "AI" 嘅關係
- FFN: 提取更高層次嘅特徵
- 最後得到 encoder 輸出:
Decoding 階段 (逐步生成)
Step 1: 生成第一個字
- Decoder input:
[START] - Masked self-attention: 只睇到
[START] - Cross-attention: attend 去成句 "I love AI" (透過 )
- Output: "我"
Step 2: 生成第二個字
- Decoder input:
[START] 我 - Masked self-attention: 睇到
[START]同 "我" - Cross-attention: 再 attend 返去 "I love AI"
- Output: "愛"
Step 3: 生成第三個字
- Decoder input:
[START] 我 愛 - Masked self-attention: 睇到前面所有字
- Cross-attention: attend 去 "I love AI"
- Output: "AI"
Step 4: 生成結束符
- Output:
[END]→ 停止
完整架構圖
Mermaid 流程圖
文字版架構圖
輸入: "I love AI"
↓
Input Embedding + Positional Encoding
↓
┌─────────────────────────────┐
│ Encoder (×6 layers) │
│ ┌─────────────────────┐ │
│ │ Multi-Head Attn │ │
│ │ Add & Norm │ │
│ │ Feed-Forward │ │
│ │ Add & Norm │ │
│ └─────────────────────┘ │
└─────────────────────────────┘
↓ H_enc (encoder output)
輸出: [START]
↓
Output Embedding + Positional Encoding
↓
┌────────────────────────────────────┐
│ Decoder (×6 layers) │
│ ┌────────────────────────────┐ │
│ │ Masked Multi-Head Attn │ │
│ │ Add & Norm │ │
│ │ Cross-Attention ← H_enc │ ← Encoder output
│ │ Add & Norm │ │
│ │ Feed-Forward │ │
│ │ Add & Norm │ │
│ └────────────────────────────┘ │
└────────────────────────────────────┘
↓
Linear + Softmax
↓
輸出概率: P("我") = 0.9, P("你") = 0.05, ...
↓
生成: "我"
Encoder-Only vs Decoder-Only vs Encoder-Decoder
而家唔同嘅模型用唔同嘅架構:
Encoder-Only (例如 BERT)
結構: 只有 Encoder 部分
特點: Bi-directional attention (可以睇前面同後面)
用途:
- 文本分類
- Named Entity Recognition
- Question Answering (理解型任務)
例子: 判斷 "呢個電影好正" 係正面定負面評價
# BERT 用法
input = "呢個電影好正 [MASK]"
output = encoder(input)
# 用 output 做分類或者 fill mask
Decoder-Only (例如 GPT)
結構: 只有 Decoder 部分 (但係冇 cross-attention)
特點: Causal/Auto-regressive (只睇前面)
用途:
- 文本生成
- Completion
- 對話系統
例子: 生成故事
# GPT 用法
input = "Once upon a time"
output = decoder(input) # "there was a princess"
GPT 其實係用咗 decoder 嘅 masked self-attention,但係唔需要 cross-attention (因為冇 encoder 輸出要 attend)。
Encoder-Decoder (例如 T5, BART)
結構: 有齊 Encoder 同 Decoder
特點: 結合兩者優點
用途:
- 機器翻譯
- 文本摘要
- 問答系統
例子: 文本摘要
# T5 用法
input = "好長嘅文章內容..."
encoder_output = encoder(input)
summary = decoder(encoder_output) # "簡短總結"
| 架構 | Attention 類型 | 代表模型 | 最啱做咩 |
|---|---|---|---|
| Encoder-Only | Bi-directional | BERT, RoBERTa | 理解、分類 |
| Decoder-Only | Causal (單向) | GPT, LLaMA | 生成 |
| Encoder-Decoder | Bi-dir + Causal + Cross | T5, BART | 翻譯、摘要 |
其他重要組件
Positional Encoding (位置編碼)
Attention 本身冇位置信息——"I love AI" 同 "AI love I" 對佢嚟講冇分別(因為都係計 token 之間嘅相似度)。
所以要加 positional encoding 話俾 model 知每個字嘅位置:
- : 位置 (0, 1, 2, ...)
- : 維度 index
- : embedding 維度
呢個設計令到:
- 每個位置有獨特嘅 encoding
- 相對位置可以透過 sin/cos 嘅性質計算出嚟
我之前寫咗篇 RoPE 嘅文章,嗰個係 positional encoding 嘅改進版,有興趣可以睇返!
Feed-Forward Network (FFN)
每個 encoder/decoder layer 入面都有個 FFN:
- 兩層 fully connected
- 中間有 ReLU activation
- 中間層嘅維度通常係 嘅 4 倍(例如 512 → 2048 → 512)
作用:
- Attention 負責 "邊個字同邊個字有關係"
- FFN 負責 "對每個位置做 transformation"
- 兩者配合,先可以有強大嘅表達能力
Layer Normalization
每個 sub-layer 之後都會做 Layer Norm:
- : mean
- : standard deviation
- : learnable parameters
作用: 穩定訓練,避免梯度爆炸/消失
Residual Connection (殘差連接)
每個 sub-layer 都有 residual connection:
呢個係由 ResNet 引入嘅概念,對訓練深層網絡好重要。冇佢嘅話,疊 6 層都會好難 train。
訓練同推理
訓練時 (Training)
Encoder: 一次過處理成句輸入
Decoder: 都係一次過處理(用 teacher forcing)
Teacher Forcing: 訓練時,就算 decoder 第 2 步生成錯咗,第 3 步都係用正確嘅答案做 input,而唔係用自己生成嘅錯誤輸出。
例子:
- Target: "我 愛 AI"
- Step 1: Input
[START]→ Output "我" ✓ - Step 2: Input
[START] 我→ Output "喜歡" ✗ (錯咗,正確係 "愛") - Step 3: Input
[START] 我 愛← 用返正確答案,唔係 "喜歡"
呢個可以加快訓練,但係會有 exposure bias 問題。
推理時 (Inference)
Encoder: 一次過處理輸入
Decoder: Auto-regressive (逐個字生成)
每生成一個字,就將佢加入 input,再生成下一個字,直到出現 [END] token。
通常會用 beam search 而唔係 greedy decoding,可以搵到更好嘅結果。
點解 Transformer 咁成功?
我自己覺得 Transformer 成功有幾個原因:
1. 平行化
唔似 RNN 要順序處理,Transformer 可以同時計所有位置,訓練快好多。我試過 train LSTM 同 Transformer,速度差好遠。
2. 長程依賴
任何兩個位置都可以直接交互(透過 attention),唔似 RNN 要經過好多步先連接到。
3. 可解釋性
可以 visualize attention weights,睇到 model 關注邊啲字。呢個對理解 model 行為好有用。
4. 靈活性
可以輕易調整成 encoder-only、decoder-only 或者 encoder-decoder,適應唔同任務。
5. Scalability
增加 layers、增加 heads、增加維度都好容易,而且效果會越嚟越好(scaling laws)。
優化版 Attention 機制
標準 self-attention 有 嘅計算複雜度同記憶體需求,所以研究者提出咗好多優化方法。呢度講下兩個最重要嘅優化:Flash Attention 同 PagedAttention。
Flash Attention
問題:記憶體存取係瓶頸
我一開始以為 attention 慢係因為計算量大,但其實真正嘅瓶頸係 memory access (記憶體存取)。
現代 GPU 有幾種記憶體:
- HBM (High Bandwidth Memory): 大但慢(例如 40GB A100)
- SRAM (on-chip memory): 細但快好多(例如 20MB)
標準 attention 實現會:
- 由 HBM 讀 Q, K → 計 → 寫返 HBM
- 由 HBM 讀 → 計 softmax → 寫返 HBM
- 由 HBM 讀 softmax 結果同 V → 計最終輸出 → 寫返 HBM
每次都要讀寫 HBM,好慢!而且要 store 成個 attention matrix 落 HBM,好食記憶體。
Flash Attention 嘅解決方法
Flash Attention (由 Tri Dao 等人喺 2022 年提出) 用咗兩個關鍵技巧:
1. Tiling (分塊計算)
唔好一次過計成個 attention matrix,而係將 Q, K, V 分成細塊,逐塊咁計。每塊可以 fit 入去 SRAM,減少 HBM 存取。
2. Recomputation (重新計算)
Backward pass 嗰陣唔 store attention matrix,而係 recompute 佢。雖然多咗計算,但係少咗記憶體存取,overall 反而快咗!
算法流程
簡化版嘅 Flash Attention 流程:
def flash_attention(Q, K, V, block_size):
"""
Q, K, V: [seq_len, d_model]
block_size: 每次處理幾多個 tokens (例如 128)
"""
seq_len = Q.shape[0]
output = zeros_like(Q)
# 將 Q 分成 blocks
for i in range(0, seq_len, block_size):
Q_block = Q[i:i+block_size] # Load 入 SRAM
# 對於呢個 Q block,逐個 K/V block 咁處理
block_output = zeros_like(Q_block)
block_max = -inf
block_sum = 0
for j in range(0, seq_len, block_size):
K_block = K[j:j+block_size] # Load 入 SRAM
V_block = V[j:j+block_size]
# 計呢個 block 嘅 attention
scores = Q_block @ K_block.T / sqrt(d_k)
# Online softmax (唔使 store 成個 matrix)
new_max = max(block_max, scores.max())
scores = exp(scores - new_max)
# 更新 output
correction = exp(block_max - new_max)
block_output = block_output * correction + scores @ V_block
block_sum = block_sum * correction + scores.sum()
block_max = new_max
# Normalize
output[i:i+block_size] = block_output / block_sum
return output
關鍵點:
- 一次只 load 細塊入 SRAM,減少 HBM 存取
- 用 online softmax 技巧,唔使 store 成個 attention matrix
- 唔寫中間結果返 HBM
效能提升
Flash Attention 嘅改進真係好明顯:
速度:
- GPT-2 (1024 tokens): 快 2-4x
- 長序列 (2048+ tokens): 快 5-9x
記憶體:
- 記憶體使用由 降到
- 可以 train 更長嘅 sequence
例子:
喺 A100 GPU 上面,標準 attention 最多 handle 到 1024 tokens (batch size=16),但係 Flash Attention 可以去到 4096 tokens!
我自己試過用 Flash Attention train 長文檔模型,記憶體省咗好多,可以用大啲嘅 batch size,訓練快咗成倍。
Flash Attention 2 & 3
Flash Attention 2 (2023):
- 進一步優化 parallelism
- 減少 non-matmul operations
- 再快 2x
Flash Attention 3 (2024):
- 針對 Hopper GPU (H100/H800)
- 利用新硬件特性(例如 wgmma 指令)
- 達到接近理論上限嘅速度
而家好多 framework 都內置咗 Flash Attention:
import torch
import torch.nn.functional as F
# PyTorch 2.0+ 內置支援
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
is_causal=False
) # 會自動用 Flash Attention!
PagedAttention
問題:Inference 嘅記憶體浪費
當你用 LLM 做 inference (例如 ChatGPT 咁逐個字生成) 嗰陣,要 cache 之前所有 tokens 嘅 Key 同 Value (KV cache),避免重複計算。
KV Cache 記憶體計算:
對於一個 13B 參數嘅模型(例如 LLaMA-13B):
- 每個 token 嘅 KV cache: ~800 KB
- 生成 2048 tokens: ~1.6 GB per request
- 處理 100 個 concurrent requests: ~160 GB!
問題係:我哋要預先 allocate 記憶體俾最長可能嘅 sequence (例如 2048 tokens),就算實際只用咗 100 tokens 都要 reserve 晒。呢個叫做 internal fragmentation (內部碎片化)。
另外,就算 requests 長度唔同,都好難 share 記憶體,造成 external fragmentation (外部碎片化)。
研究發現,實際上只有 20-40% 嘅 KV cache 記憶體真係有用緊!
PagedAttention 嘅解決方法
PagedAttention (由 UC Berkeley 嘅 vLLM 團隊喺 2023 年提出) 借鑑咗操作系統嘅 virtual memory (虛擬記憶體) 概念。
核心思想:
- 將 KV cache 分成固定大小嘅 blocks/pages (例如每個 block 存 16 tokens)
- 用 page table 嚟 map logical blocks → physical blocks
- 唔同 requests 可以 share 相同嘅 blocks (例如 shared prompt)
- 按需 allocate blocks,冇 internal fragmentation
視覺化例子
假設我哋要生成呢個 request:
Prompt: "Translate to French: Hello, how are you?"
Generation: "Bonjour, comment allez-vous?"
傳統方法 (Contiguous Memory)
記憶體: [________________________________]
↑ 預先 allocate 2048 tokens 嘅空間
實際用咗: [Prompt tokens][Generated tokens][______浪費______]
12 tokens + 6 tokens = 18/2048 (0.9% 使用率!)
PagedAttention (Paged Memory)
Logical view: [Block 0] → [Block 1] → [Block 2] → ...
Prompt Prompt Generated
tokens tokens tokens
(16) (2) (6)
Physical memory:
Page table:
Logical Block 0 → Physical Block 5 [16 tokens]
Logical Block 1 → Physical Block 12 [2 tokens]
Logical Block 2 → Physical Block 3 [6 tokens]
只 allocate 需要嘅 blocks,冇浪費!
Shared Prefix 優化
更勁嘅係,如果多個 requests 有相同嘅 prefix,可以 share blocks!
例子: Batch translation
Request 1: "Translate to French: Hello" → "Bonjour"
Request 2: "Translate to French: Goodbye" → "Au revoir"
Request 3: "Translate to French: Thank you" → "Merci"
所有 requests 都 share "Translate to French: " 嘅 KV cache blocks:
Shared blocks (read-only):
["Translate"] ["to"] ["French"] [":"]
↑ ↑ ↑ ↑
└───────────┴────────┴────────┴─── Shared by all 3 requests!
Private blocks (copy-on-write):
Request 1: ["Hello"] ["Bonjour"]
Request 2: ["Goodbye"] ["Au revoir"]
Request 3: ["Thank you"] ["Merci"]
呢個對 few-shot prompting 特別有用,因為好多 requests 會 share 同一個長 prompt!
實際效能
vLLM (用咗 PagedAttention) 同其他 serving systems 比較:
Throughput (每秒處理幾多 requests):
- 比 FasterTransformer 快 2-4x
- 比 Hugging Face Text Generation Inference 快 2-3x
- 比 naive PyTorch 快 8-10x
記憶體使用:
- 減少 55-80% 記憶體浪費
- 可以同時處理更多 requests
例子:
喺 A100 (40GB) 上面 serve LLaMA-13B:
- 傳統方法: 最多 ~20 concurrent requests
- vLLM (PagedAttention): 可以去到 50+ concurrent requests
我喺 production 用過 vLLM,真係可以省返好多 GPU,throughput 提升明顯。特別係當你有好多 requests 用相同 prompt 嗰陣(例如 chatbot with system prompt),memory sharing 嘅效果好好。
PagedAttention 實現
核心係改動 attention 計算,令佢可以 access 非連續嘅 KV blocks:
def paged_attention(Q, K_pages, V_pages, page_table, block_size=16):
"""
Q: [batch, num_heads, 1, head_dim] # 當前 token
K_pages: [num_physical_blocks, block_size, num_heads, head_dim]
V_pages: [num_physical_blocks, block_size, num_heads, head_dim]
page_table: [batch, max_num_blocks] # logical → physical mapping
"""
batch_size = Q.shape[0]
outputs = []
for b in range(batch_size):
# Get physical block indices for this request
physical_blocks = page_table[b] # [num_blocks_used]
# Gather K and V from physical blocks
K_seq = []
V_seq = []
for block_idx in physical_blocks:
if block_idx >= 0: # -1 means unused
K_seq.append(K_pages[block_idx])
V_seq.append(V_pages[block_idx])
K_concat = concat(K_seq, dim=0) # [seq_len, num_heads, head_dim]
V_concat = concat(V_seq, dim=0)
# Standard attention
scores = Q[b] @ K_concat.T / sqrt(head_dim)
attn_weights = softmax(scores, dim=-1)
output = attn_weights @ V_concat
outputs.append(output)
return stack(outputs)
實際 implementation 會用 CUDA kernel 嚟優化 gather 操作,令存取非連續記憶體都好快。
Flash Attention vs PagedAttention 比較
| Flash Attention | PagedAttention | |
|---|---|---|
| 目標 | 加快 attention 計算 | 減少 KV cache 記憶體浪費 |
| 主要用途 | Training + Inference | Inference (serving) |
| 優化對象 | 計算效率 + 記憶體存取 | 記憶體管理 |
| 技術 | Tiling + Recomputation | Virtual memory + Paging |
| 加速 | 2-9x faster | 2-4x throughput |
| 記憶體節省 | (attention matrix) | 55-80% 減少浪費 (KV cache) |
| 可以一齊用? | ✅ 可以!vLLM 已經整合咗 Flash Attention |
佢哋解決唔同問題,所以可以同時用:
- Flash Attention: 令每次 attention 計算快啲
- PagedAttention: 令記憶體管理好啲,可以 serve 更多 requests
結合埋一齊就係最強組合!
其他 Attention 優化
除咗 Flash Attention 同 PagedAttention,仲有其他方向:
Multi-Query Attention (MQA)
- 所有 heads share 同一組 K, V
- 大幅減少 KV cache size
- 用於 PaLM, StarCoder
Grouped-Query Attention (GQA)
- MQA 同 Multi-Head 嘅折衷
- 將 heads 分組,每組 share K, V
- 用於 LLaMA 2, Mistral
例子: 32 heads
- Multi-Head: 32 組 KV
- GQA (8 groups): 8 組 KV
- MQA: 1 組 KV
GQA 可以減少 KV cache 去 1/4,而 quality 幾乎冇 drop!
局限同改進方向
雖然 Transformer 好強,但係都有問題:
計算複雜度
Self-attention 要計所有 token pairs,序列長咗計算量就爆炸。呢個係點解 context window 通常得幾千個 tokens。
改進方向:
- Sparse attention (Longformer, BigBird)
- Linear attention (Performer, FNet)
- 分層處理 (Hierarchical Transformers)
- Flash Attention 同 PagedAttention 已經大幅改善!
記憶體需求
要 store 成個 attention matrix,長序列會 out of memory。
缺乏 inductive bias
Transformer 冇 built-in 嘅結構假設(唔似 CNN 有 locality, RNN 有 sequentiality),所以需要好多 data 先 train 得好。
我嘅睇法
學咗 Transformer 之後,我覺得呢個架構真係改變咗成個 NLP 領域。由 2017 到而家,幾乎所有 SOTA 模型都係基於 Transformer。
最令我欣賞嘅係佢嘅簡潔性。雖然一開始覺得複雜,但係理解咗之後就發覺每個部分都有佢嘅作用,而且設計得好合理。
而家研究方向主要係點樣處理更長嘅 context (解決 問題) 同點樣減少計算資源。我覺得未來可能會見到:
- 更高效嘅 attention 機制
- 混合架構(結合 Transformer 同其他方法)
- 針對特定任務優化嘅變種
如果你啱啱開始學 deep learning for NLP,我會建議由細嘅 Transformer 開始 implement (例如幾層、細 dimension),親手做一次真係會理解得深好多。
參考資料
- Vaswani, A., et al. (2017). Attention is All You Need. NeurIPS 2017 [呢篇係必讀!]
- The Illustrated Transformer by Jay Alammar [好正嘅視覺化講解]
- The Annotated Transformer [有完整 code implementation]
- Devlin, J., et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers. arXiv:1810.04805
- Radford, A., et al. (2018). Improving Language Understanding by Generative Pre-Training. OpenAI
- Raffel, C., et al. (2019). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. arXiv:1910.10683