Rotary Position Embeddings (RoPE)
最近研究 Transformer 架構,發覺 RoPE 呢個位置編碼方法真係幾有趣。佢係 2021 年先提出嚟,但係而家好多大型語言模型都用緊,包括 LLaMA、PaLM、GLM 等等。我覺得 RoPE 之所以咁受歡迎,係因為佢解決咗傳統位置編碼嘅好多問題。
📑 目錄
傳統位置編碼嘅問題
一開始學 Transformer 嗰陣,我都覺得 sinusoidal position embeddings 好似已經幾好。但係用落先知有幾個限制:
缺乏相對位置信息
傳統嘅絕對位置編碼(absolute position embeddings)只係話俾個 model 知每個 token 喺第幾個位置,但係佢哋之間嘅相對距離就唔夠明確。
外推能力唔好
如果你 train 個 model 用 512 tokens,之後想用 1024 tokens,performance 會跌好多。呢個對於處理長文檔嚟講係個大問題。
計算開銷
有啲方法(例如 learned position embeddings)要額外參數,增加咗 model size。
RoPE 嘅核心思想
RoPE 嘅做法好有趣——佢將位置信息編碼成旋轉操作。喺複數空間入面旋轉 query 同 key 向量,咁樣就可以自然咁注入相對位置信息。
💡 關鍵洞察:如果將位置 嘅向量旋轉 度,位置 嘅向量旋轉 度,佢哋嘅內積就會自動包含相對位置 嘅信息!
呢個設計真係好巧妙,利用咗旋轉矩陣嘅數學特性嚟達到目的。
點解係旋轉?Sin/Cos 同圓形嘅特性
你可能會問:點解偏偏要用旋轉嚟 encode 位置?其他方法唔得咩?
答案係旋轉有幾個獨特嘅數學特性,啱晒我哋想要嘅效果:
1. 旋轉嘅群結構(Group Property)
旋轉矩陣有個好正嘅性質:兩次旋轉可以合併成一次旋轉:
呢個加法性質係 key!因為位置本身都係加法嘅(位置 5 = 位置 0 + 5 步)。
2. 圓形嘅週期性
想像喺個圓形上面行:
- 行 360° 就返回原點
- 位置可以用角度表示
- 角度 30° 同 390° 其實係同一個位置
呢個週期性好有用,因為語言都有週期性 pattern(例如每隔 k 個字就有個逗號)。
3. 轉置即反向旋轉
旋轉矩陣有個神奇性質:
所以:
**呢個就係點解 attention 會自動得到相對位置!**你將位置 嘅向量旋轉 ,位置 嘅向量旋轉 ,計內積嗰陣:
- 先將第一個向量轉返去()
- 再乘第二個向量()
- 結果就係相對角度 !
4. 複數/歐拉公式嘅連結
旋轉其實同複數乘法有深層連繫。歐拉公式話:
喺複數平面,乘以 就係旋轉 度!所以:
呢個係點解 sin/cos 特別啱——佢哋本質上就係複數嘅實部同虛部,天生就係做緊旋轉!
5. 內積保持長度
旋轉係 orthogonal transformation(正交變換),唔會改變向量長度:
咁樣就唔會因為位置編碼而搞亂原本 query/key 向量嘅 magnitude,只係改變方向。
類比:時鐘
諗返個時鐘:
- 時針指住 3 點(90°)
- 分針指住 12 點(0°)
- 佢哋之間嘅角度差就係 90°
如果兩個指針一齊旋轉(例如過咗 1 小時),角度差仍然係 90°!呢個就係相對位置嘅 invariance。
RoPE 用同樣嘅原理:將每個 token 當成指針,位置決定旋轉角度,咁樣 attention 計嘅就係角度差(相對位置)而唔係絕對角度。
數學原理
旋轉矩陣
對於位置 嘅 query 向量 同位置 嘅 key 向量 ,RoPE 定義旋轉操作:
其中旋轉矩陣 對於二維情況定義為:
點解會有相對位置?
計算 attention score 嗰陣:
由於旋轉矩陣嘅性質,,所以:
呢個就係精髓所在!Attention score 只係依賴相對位置 ,而唔係絕對位置!
我第一次睇到呢個推導嗰陣真係有種 "原來可以咁" 嘅感覺。用旋轉矩陣嘅數學特性嚟自然咁得到相對位置,唔使特登設計額外嘅機制。
具體矩陣例子
為咗更清楚理解,我哋睇個實際例子。假設用 2D 向量(即 ),旋轉頻率 (即 30°)。
位置 0 嘅旋轉矩陣 (,唔旋轉):
位置 1 嘅旋轉矩陣 (,旋轉 30°):
位置 2 嘅旋轉矩陣 (,旋轉 60°):
位置 3 嘅旋轉矩陣 (,旋轉 90°):
完整計算例子
假設我哋有原始 query 同 key 向量:
場景 1:位置 1 嘅 query attend 去位置 1 嘅 key (相對距離 = 0)
旋轉後:
Attention score:
場景 2:位置 1 嘅 query attend 去位置 3 嘅 key (相對距離 = 2)
同上,而:
Attention score:
驗證相對位置特性
關鍵嘅洞察:我哋可以直接用相對旋轉矩陣計算!
驗證:
完全一樣!呢個證明咗 attention score 真係只係取決於相對位置差 ,而唔係絕對位置!
視覺化:矩陣運算分解
將成個過程拆開睇:
# 場景:位置 m=1 attend 去位置 n=3
# 方法 1:直接計算(實際 implementation)
score = (R_1 @ q).T @ (R_3 @ k)
= q_1.T @ k_3
= 0.5
# 方法 2:利用旋轉矩陣性質
score = q.T @ (R_1.T @ R_3) @ k
= q.T @ R_(3-1) @ k
= q.T @ R_2 @ k
= 0.5 # 一樣!
關鍵觀察:
- 方法 1 用絕對位置 分別旋轉 query 同 key
- 方法 2 直接用相對位置 旋轉 key
- 兩者數學上等價!
呢個特性令 model 可以學到「兩個 token 之間距離 2」呢個概念,而唔使理會佢哋嘅絕對位置係咩。所以位置 1→3 同位置 10→12 會有一樣嘅 attention pattern,因為相對距離都係 2!
高維擴展
對於 維向量,RoPE 將佢分成 對,每對用唔同嘅旋轉頻率:
呢個公式其實同 Transformer 原本嘅 sinusoidal embeddings 好似,都係用唔同頻率嚟捕捉唔同尺度嘅位置信息。
完整嘅旋轉矩陣係塊對角矩陣:
3D 同高維旋轉擴展
你可能會諗:RoPE 用 2D 旋轉(圓形),咁3D 旋轉(球面)得唔得?或者更高維度?
答案係**有研究做緊呢方面!**特別係處理空間數據(images、videos、3D point clouds)嗰陣,3D rotational embeddings 好有用。
點解要 3D?
2D RoPE 嘅限制:
- 每對維度係獨立旋轉(block-diagonal 矩陣)
- 只能 model 1D sequence(文字序列、時間序列)
- 對於 2D/3D 空間數據(例如圖像、視頻),唔夠 expressive
3D 旋轉嘅應用場景:
- 📸 圖像: 2D 空間(x, y 座標)
- 🎥 視頻: 2D 空間 + 1D 時間 = 3D
- ☁️ 3D Point Clouds: 真 3D 空間座標
- 🧬 分子結構: 原子 3D 位置
兩種主流做法
1. Axis-wise 3D RoPE
最直接嘅方法——對每個空間軸(x, y, z)獨立做 2D 旋轉:
# 分開處理每個軸
theta_x = position_x * freq
theta_y = position_y * freq
theta_z = position_z * freq
# 每個軸都做 RoPE 旋轉
rotated_x = apply_rope_2d(x, theta_x)
rotated_y = apply_rope_2d(y, theta_y)
rotated_z = apply_rope_2d(z, theta_z)
優點:
✅ Implementation 簡單,直接 extend RoPE
✅ 計算高效,同 2D RoPE 複雜度一樣
缺點:
❌ 軸之間獨立,捕捉唔到 cross-axis 嘅關係
❌ 唔係真正嘅 3D 旋轉(只係 3 個 2D 旋轉夾埋)
2. Joint 3D RoPE (SO(3) 群)
用真正嘅 3D 旋轉群 SO(3) (Special Orthogonal Group):
關鍵挑戰: 3D 旋轉唔 commutative(唔可交換)!
- 2D: ✅ 可交換
- 3D: ❌ 唔可交換
點解唔可交換?
想像你拎住本書:
- 先向右轉 90°,再向前傾 90°
- 先向前傾 90°,再向右轉 90°
最終方向唔同!呢個就係 non-commutativity。
Quaternions(四元數)嚟拯救
由於 3D 旋轉唔 commutative,我哋需要新工具——Quaternions!
咩係 Quaternion?
Quaternion 係 4D 超複數:
其中
點解用 Quaternion?
- ✅ 可以無 gimbal lock 咁 represent 3D 旋轉
- ✅ 插值(interpolation)順滑
- ✅ 計算比 3×3 旋轉矩陣高效(4 個數 vs 9 個數)
- ✅ 數值穩定
Quaternion Transformer
有論文提出用 quaternion 做 transformer 嘅表示:
class QuaternionAttention(nn.Module):
def __init__(self, dim):
# 每個 feature 用 quaternion 表示
self.dim = dim // 4 # 分成 4 份: w, x, y, z
def quaternion_rotation(self, q, position):
# 用 quaternion 乘法做旋轉
w, x, y, z = split_quaternion(q)
# 構造 rotation quaternion
theta = position * self.freq
rot_q = angle_to_quaternion(theta)
# Quaternion 乘法
return quaternion_multiply(rot_q, q)
效果:
- 參數減少 75%(因為 quaternion 共享結構)
- 某啲 3D 任務 performance 仲好過 standard transformer
LieRE: 任意維度嘅推廣
LieRE (Lie group Relative Encodings) 將 RoPE 推廣到 n 維
核心 idea
用 Lie group theory(李群理論):
- 每個維度對應一個 Lie algebra 元素
- 通過 exponential map 生成旋轉
- 自動滿足群結構
其中 係 Lie algebra 嘅 generators, 係位置座標。
厲害之處:
- ✅ 自動保證相對位置性質:
- ✅ 適用於任意維度(1D, 2D, 3D, ...)
- ✅ 有數學保證(唔係 heuristic)
實驗結果
喺 2D/3D image classification 任務:
- CIFAR-10: +10.5% accuracy vs standard position embeddings
- ImageNet: +2.3% top-1 accuracy
- 3D point cloud classification: +15% accuracy
實際應用
1. Video Understanding
視頻有 2D 空間 + 1D 時間 = 3D structure:
# Video tokens: [batch, time, height, width, channels]
t, h, w = position_indices # 3D 座標
# 用 3D RoPE
q = apply_3d_rope(query, t, h, w)
k = apply_3d_rope(key, t, h, w)
attention = q @ k.T # 自動 encode 時空相對位置
2. Molecular Property Prediction
分子結構係 3D,用 SO(3)-equivariant attention:[2]
- 旋轉分子結構,prediction 唔變(rotation equivariance)
- 比 graph neural networks 更準確
3. 3D Point Cloud Understanding
處理 LiDAR、3D scan 數據,axis-wise 3D RoPE 有幫助捕捉空間關係。
實作難度
相比 2D RoPE:
| Method | Implementation | 計算開銷 | 表達能力 |
|---|---|---|---|
| 2D RoPE | ⭐ 簡單 | ⭐ 低 | ⭐⭐ 1D sequence |
| Axis-wise 3D | ⭐⭐ 中等 | ⭐⭐ 中等 | ⭐⭐⭐ 3D 獨立軸 |
| Quaternion | ⭐⭐⭐ 複雜 | ⭐⭐ 中等 | ⭐⭐⭐⭐ 真 3D 旋轉 |
| LieRE | ⭐⭐⭐⭐ 好複雜 | ⭐⭐⭐ 較高 | ⭐⭐⭐⭐⭐ 任意維度 |
我嘅睇法
對於語言模型(1D sequence),標準 2D RoPE 已經夠好。
但係對於多模態模型(vision + language)或者3D understanding 任務,3D/高維 rotational embeddings 好有潛力。特別係 video understanding 同埋 embodied AI(機器人),空間推理好重要。
Quaternion approach 我覺得幾 elegant,數學上靚,又有實際 benefit(減參數)。如果你做緊 3D 相關嘅研究,值得一試!
LieRE 就比較 academic,implementation 難度高,但係理論上好完整。可能要等多啲 library support 先會普及。
實際實現
高效計算
喺實際實現入面,我哋唔使真係 construct 成個旋轉矩陣出嚟。對於向量 ,旋轉可以高效咁計算:
def apply_rotary_emb(x, position, theta):
"""
Apply rotary position embeddings
Args:
x: input tensor [batch, seq_len, dim]
position: position indices [seq_len]
theta: frequency values [dim/2]
"""
# Split into pairs
x1, x2 = x[..., ::2], x[..., 1::2]
# Compute angles
angles = position[:, None] * theta[None, :]
cos_angles = torch.cos(angles)
sin_angles = torch.sin(angles)
# Apply rotation
rotated_x1 = x1 * cos_angles - x2 * sin_angles
rotated_x2 = x1 * sin_angles + x2 * cos_angles
# Interleave back
return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
我自己 implement 過一次,發覺其實唔算複雜。主要係將向量分成對,分別做旋轉,之後 interleave 返埋一齊。
PyTorch 完整例子
import torch
import torch.nn as nn
class RotaryPositionEmbedding(nn.Module):
def __init__(self, dim, max_seq_len=2048, base=10000):
super().__init__()
self.dim = dim
# Compute theta for each dimension pair
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# Precompute for max sequence length
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos_cached', emb.cos())
self.register_buffer('sin_cached', emb.sin())
def forward(self, q, k, seq_len):
cos = self.cos_cached[:seq_len]
sin = self.sin_cached[:seq_len]
q_rot = (q * cos) + (self.rotate_half(q) * sin)
k_rot = (k * cos) + (self.rotate_half(k) * sin)
return q_rot, k_rot
@staticmethod
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
return torch.cat([-x2, x1], dim=-1)
呢個 implementation 仲做咗 precompute 同 cache,所以 inference 嗰陣會快好多。
Blockwise Parallelism 同分佈式計算
訓練大型語言模型嗰陣,單張 GPU 根本唔夠用。呢個時候就要用分佈式訓練(distributed training),而 RoPE 嘅設計其實好適合並行計算!
點解 RoPE 啱並行計算?
RoPE 有幾個特性令佢特別適合 parallelism:
1. Local Operation (局部操作)
RoPE 係 element-wise 或者 pair-wise operation:
- 每個位置嘅旋轉係獨立嘅
- 唔需要 global synchronization
- 可以喺唔同 GPU 同時計算唔同 tokens
2. No Learnable Parameters (無需學習參數)
- 唔使儲存額外嘅 embedding table
- 唔需要同步 parameter updates
- 計算完全由 position index 決定
3. Deterministic Computation (確定性計算)
- 只要知道 position,就可以計出 RoPE
- 唔同 GPU 可以獨立計算,結果一致
- 唔使傳遞額外資訊
三種主要嘅 Parallelism
訓練大型 transformer 通常會結合三種 parallelism:
1. Data Parallelism (數據並行)
做法:
- 每張 GPU 攞唔同 batch 嘅數據
- 每張 GPU 都有完整嘅 model copy
- Forward/backward 之後 synchronize gradients
RoPE 嘅優勢:
# GPU 0 處理 batch 0
positions_0 = torch.arange(0, seq_len) # [0, 1, 2, ...]
q_0 = apply_rope(query_0, positions_0)
# GPU 1 處理 batch 1
positions_1 = torch.arange(0, seq_len) # 一樣係 [0, 1, 2, ...]
q_1 = apply_rope(query_1, positions_1)
# 唔需要額外通訊!每張 GPU 獨立計算
因為 RoPE 無 learnable parameters,data parallelism 嘅通訊開銷更低。
2. Tensor Parallelism (張量並行)
做法:
- 將 model 嘅維度(例如 hidden dimension)切開
- 唔同 GPU 負責唔同維度嘅計算
- Attention heads 可以分散到多張 GPU
RoPE 嘅實現:
# 例如:16 個 attention heads,分散到 4 張 GPU
# GPU 0: heads 0-3
# GPU 1: heads 4-7
# GPU 2: heads 8-11
# GPU 3: heads 12-15
# 每張 GPU 獨立計 RoPE
class TensorParallelAttention(nn.Module):
def __init__(self, total_heads, rank, world_size):
self.heads_per_gpu = total_heads // world_size
self.head_start = rank * self.heads_per_gpu
def forward(self, x, positions):
# 每張 GPU 只處理自己嘅 heads
q = self.q_proj(x) # 已經切好
k = self.k_proj(x)
# RoPE 獨立計算,無需通訊
q_rot = apply_rope(q, positions)
k_rot = apply_rope(k, positions)
# Attention 計算
attn = q_rot @ k_rot.T
...
關鍵: 因為 RoPE 係 position-wise operation,唔同 heads 嘅 RoPE 可以完全獨立計算,唔需要跨 GPU 通訊!
3. Sequence Parallelism (序列並行)
做法:
- 將序列長度切開
- 每張 GPU 處理部分 tokens
- 特別適合超長序列(例如 32k+ tokens)
RoPE 嘅天生優勢:
# 例如:sequence length = 8192,分散到 4 張 GPU
# GPU 0: tokens 0-2047
# GPU 1: tokens 2048-4095
# GPU 2: tokens 4096-6143
# GPU 3: tokens 6144-8191
class SequenceParallelAttention(nn.Module):
def __init__(self, rank, world_size, seq_len):
self.rank = rank
self.chunk_size = seq_len // world_size
self.start_pos = rank * self.chunk_size
def forward(self, x_chunk):
# 計算呢個 chunk 嘅 position indices
positions = torch.arange(
self.start_pos,
self.start_pos + self.chunk_size
)
# 直接用 absolute positions 計 RoPE
q_rot = apply_rope(q_chunk, positions)
k_rot = apply_rope(k_chunk, positions)
# Attention 需要 all-to-all communication
# 但 RoPE 本身唔使!
...
點解 RoPE 特別適合?
因為 RoPE 用相對位置,所以:
- GPU 0 嘅 token 0 同 GPU 3 嘅 token 7000 做 attention 嗰陣
- 相對位置係 7000 - 0 = 7000
- 呢個相對位置係自然保留嘅,唔使額外處理!
如果用 learned position embeddings,就要確保每張 GPU 都有完整嘅 embedding table,麻煩好多。
Ring Attention 同 RoPE
Ring Attention 係一種特別嘅 sequence parallelism,用 ring topology 嚟減少通訊:
# 簡化版 Ring Attention 概念
class RingAttention:
def __init__(self, rank, world_size):
self.rank = rank
self.world_size = world_size
def forward(self, q_local, k_local, v_local, start_pos):
# 本地 query (唔郁)
positions_q = torch.arange(start_pos, start_pos + len(q_local))
q_rot = apply_rope(q_local, positions_q)
output = torch.zeros_like(q_local)
# Ring: 將 K, V 傳一圈
for step in range(self.world_size):
# 計算當前 K 嘅 position offset
k_rank = (self.rank + step) % self.world_size
k_start_pos = k_rank * len(k_local)
positions_k = torch.arange(k_start_pos, k_start_pos + len(k_local))
k_rot = apply_rope(k_local, positions_k)
# 計算 attention
scores = q_rot @ k_rot.T
attn_weights = softmax(scores)
output += attn_weights @ v_local
# 傳俾下一個 GPU
k_local, v_local = send_recv_ring(k_local, v_local)
return output
RoPE 嘅關鍵作用:
- 每個 step,K 嘅 position indices 都唔同
- RoPE 可以動態計算對應嘅旋轉
- 唔使預先儲存所有位置嘅 embeddings
Ring Attention 嘅限制 ⚠️
雖然 Ring Attention 好有用,但係有個嚴重問題:
❌ 唔支持 Causal Masking (因果遮罩)
喺 autoregressive language models (例如 GPT),我哋需要 causal masking:每個 token 只能 attend 去自己同之前嘅 tokens,唔能偷望未來。
# Causal mask 例子
# Token 3 只能 attend 去 [0, 1, 2, 3]
# Token 5 只能 attend 去 [0, 1, 2, 3, 4, 5]
點解 Ring Attention 有問題?
假設我哋有 4 張 GPU:
- GPU 0: tokens 0-255 (positions 0-255)
- GPU 1: tokens 256-511 (positions 256-511)
- GPU 2: tokens 512-767 (positions 512-767)
- GPU 3: tokens 768-1023 (positions 768-1023)
當 GPU 1 處理 token 300 嗰陣:
- ✅ 應該 attend 去 tokens 0-300
- ❌ 但係 ring 會傳埋 tokens 512-1023 過嚟(未來 tokens!)
- ❌ 就算用 mask,計算都浪費咗
問題本質:
- Ring Attention 假設所有 tokens 可以互相 attend (bidirectional)
- 但 causal masking 要求 strictly sequential (unidirectional)
- Ring 嘅 circular 傳遞會浪費好多計算喺 masked-out positions
Striped Attention:解決 Causal Masking
Striped Attention 係 Ring Attention 嘅改良版,專門處理 causal masking!
核心 idea: 將 sequence 按 striped pattern 分配,而唔係連續 chunks。
傳統 Ring (連續切分)
# GPU 分配 (連續)
GPU 0: [0, 1, 2, 3 ]
GPU 1: [4, 5, 6, 7 ]
GPU 2: [8, 9, 10, 11 ]
GPU 3: [12, 13, 14, 15 ]
Striped Pattern (交錯分配)
# GPU 分配 (striped,stride=4)
GPU 0: [0, 4, 8, 12 ]
GPU 1: [1, 5, 9, 13 ]
GPU 2: [2, 6, 10, 14 ]
GPU 3: [3, 7, 11, 15 ]
點解 striped 有用?
當處理 token 13 (喺 GPU 1):
- 佢需要 attend 去 tokens 0-13
- 每張 GPU 都有部分需要嘅 keys!
- GPU 0 有: 0, 4, 8, 12 ✅
- GPU 1 有: 1, 5, 9, 13 ✅
- GPU 2 有: 2, 6, 10 ✅ (14 要 mask)
- GPU 3 有: 3, 7, 11 ✅ (15 要 mask)
實現方式:
class StripedRingAttention:
def __init__(self, rank, world_size):
self.rank = rank
self.world_size = world_size
def create_striped_indices(self, seq_len):
# 將 sequence 按 stripe pattern 分配
all_indices = torch.arange(seq_len)
# 每張 GPU 攞 stride=world_size 嘅 indices
local_indices = all_indices[self.rank::self.world_size]
return local_indices
def forward(self, q_local, k_local, v_local):
# q_local, k_local 已經係 striped 分配
local_positions = self.create_striped_indices(total_seq_len)
q_rot = apply_rope(q_local, local_positions)
output = torch.zeros_like(q_local)
# Ring pass
for step in range(self.world_size):
k_rank = (self.rank + step) % self.world_size
k_positions = self.create_striped_indices(total_seq_len)
k_rot = apply_rope(k_local, k_positions)
# Causal mask: 只 attend 去 position <= current position
scores = q_rot @ k_rot.T
# 建立 causal mask
mask = local_positions[:, None] >= k_positions[None, :]
scores = scores.masked_fill(~mask, float('-inf'))
attn_weights = softmax(scores, dim=-1)
output += attn_weights @ v_local
k_local, v_local = send_recv_ring(k_local, v_local)
return output
Striped Attention 嘅好處:
- ✅ 支持 causal masking
- ✅ 每個 ring pass 都有用(唔似連續切分咁浪費)
- ✅ 仍然係 memory per GPU
- ✅ Load balancing 更好(每張 GPU 處理 evenly distributed tokens)
缺點:
- ⚠️ Memory access pattern 比較複雜(因為 non-contiguous)
- ⚠️ 實作難度稍高
- ⚠️ 對 cache locality 可能有少少影響
對比總結
| 方法 | Causal Masking | 記憶體 | 通訊 | 實作難度 |
|---|---|---|---|---|
| Standard Attention | ✅ | - | ⭐ 簡單 | |
| Sequence Parallel | ❌ (需要額外處理) | All-to-all | ⭐⭐ 中等 | |
| Ring Attention (連續) | ❌ 唔支持 | P2P Ring | ⭐⭐ 中等 | |
| Striped Ring Attention | ✅ 完全支持 | P2P Ring | ⭐⭐⭐ 較複雜 |
實際應用場景
用 Ring Attention (連續切分):
- Encoder models (BERT-style): bidirectional attention
- Vision Transformers: 圖像 patches 冇 causal 要求
- 某啲 retrieval tasks: 雙向理解
用 Striped Ring Attention:
- Decoder models (GPT-style): autoregressive generation
- 語言建模訓練
- 任何需要 causal masking 嘅任務
我嘅經驗:
喺訓練 GPT-style 模型嗰陣,最初試用 standard Ring Attention,結果 loss 爆炸——因為未來 tokens leak 咗入去!改用 Striped Ring Attention 之後就正常。
不過實作 striped pattern 要小心 indexing,特別係 backward pass 嗰陣,容易出 bug。建議用 well-tested libraries (例如某啲 Megatron-LM variants) 而唔好自己由頭寫。
Flash Attention + RoPE + Distributed
實際上,最強嘅組合係:
Flash Attention (memory-efficient) + RoPE (position encoding) + Sequence Parallelism (scale to long sequences)
from flash_attn import flash_attn_func
class DistributedFlashRoPE(nn.Module):
def __init__(self, rank, world_size):
self.rank = rank
self.world_size = world_size
def forward(self, q, k, v, seq_start_pos):
# 1. Apply RoPE (local operation)
seq_len = q.shape[1]
positions = torch.arange(
seq_start_pos,
seq_start_pos + seq_len,
device=q.device
)
q_rot, k_rot = apply_rope(q, k, positions)
# 2. Flash Attention (memory-efficient)
# 喺每張 GPU 本地計算
output = flash_attn_func(
q_rot, k_rot, v,
causal=True,
softmax_scale=1.0 / math.sqrt(q.shape[-1])
)
# 3. (Optional) All-gather for full sequence
if self.world_size > 1:
output = all_gather_sequence(output)
return output
記憶體節省對比
假設:
- Sequence length: 32,768 tokens
- Batch size: 8
- Hidden dim: 4,096
- 8 張 A100 80GB GPUs
| 方法 | 每張 GPU 記憶體 | 通訊開銷 | 可行性 |
|---|---|---|---|
| Standard Attention | ~180 GB | Low | ❌ OOM |
| + Flash Attention | ~65 GB | Low | ✅ 啱啱好 |
| + Sequence Parallel (8-way) | ~12 GB | High | ✅ 好鬆動 |
| + Ring Attention | ~10 GB | Medium | ✅ 最優 |
Pipeline Parallelism 同 RoPE
Pipeline Parallelism 將 model layers 切開:
- GPU 0: Layers 0-7
- GPU 1: Layers 8-15
- GPU 2: Layers 16-23
- GPU 3: Layers 24-31
RoPE 嘅處理:
# RoPE 只係喺每層 attention 之前 apply
class PipelineStage(nn.Module):
def __init__(self, layers, stage_id):
self.layers = layers
self.stage_id = stage_id
def forward(self, x, positions):
# positions 沿住 pipeline 傳落去
for layer in self.layers:
# 每層獨立 apply RoPE
x = layer(x, positions)
return x
class TransformerLayer(nn.Module):
def forward(self, x, positions):
# Self-attention with RoPE
q, k, v = self.qkv_proj(x)
q_rot, k_rot = apply_rope(q, k, positions) # 每層都做
attn_out = self.attention(q_rot, k_rot, v)
...
注意: RoPE 喺每一層都要重新 apply,因為:
- 每層嘅 Q, K 都唔同
- 旋轉係 apply 喺 attention 嘅 input 上
- 好在 RoPE 計算好快,開銷可以接受
3D Parallelism: 終極組合
實際訓練超大模型(例如 GPT-4 規模),會同時用晒三種 parallelism:
# 3D Parallelism 配置
config = {
"data_parallel": 4, # 4 個 data replicas
"tensor_parallel": 8, # 8-way tensor parallel
"pipeline_parallel": 16, # 16 pipeline stages
# 總共: 4 × 8 × 16 = 512 張 GPUs!
}
# RoPE 喺每個維度都適用
class ThreeDParallelAttention(nn.Module):
def __init__(self, dp_rank, tp_rank, pp_rank, config):
self.dp_rank = dp_rank # Data parallel rank
self.tp_rank = tp_rank # Tensor parallel rank
self.pp_rank = pp_rank # Pipeline parallel rank
def forward(self, x, positions):
# Data parallel: 唔同 batch,但 positions 一樣
# Tensor parallel: 切 heads,每個 head 獨立 apply RoPE
# Pipeline parallel: 每層都 apply RoPE
q, k = self.get_qk(x) # 已經按 TP 切好
q_rot, k_rot = apply_rope(q, k, positions) # Local op!
# 跨 TP 通訊(如果需要)
attn = self.distributed_attention(q_rot, k_rot)
return attn
實際系統例子
Megatron-LM (NVIDIA)
# Megatron 用 RoPE 嘅方式
class MegatronAttention(nn.Module):
def forward(self, x):
# Tensor parallel: 切 attention heads
q, k, v = self.qkv_projection(x) # 已切片
# RoPE: 每張 GPU 獨立計算
if self.use_rope:
q, k = apply_rotary_pos_emb(q, k, self.position_ids)
# Attention
context = self.core_attention(q, k, v)
return context
DeepSpeed (Microsoft)
# DeepSpeed ZeRO + RoPE
class DeepSpeedRoPE(nn.Module):
def forward(self, x):
# ZeRO stage 3: parameters 分散喺所有 GPUs
q, k = self.get_qk_with_zero(x)
# RoPE 喺 gather 之後 apply
q_rot, k_rot = apply_rope(q, k, self.positions)
# 計算完可以即刻 release memory
output = self.attention(q_rot, k_rot, v)
return output
我嘅經驗
訓練長序列模型嗰陣,我發覺 RoPE + Sequence Parallelism 係個好正嘅組合:
優點:
- ✅ 實現簡單,唔使改好多 code
- ✅ 通訊開銷主要喺 attention 嗰度,RoPE 本身零開銷
- ✅ 可以輕鬆 scale 到 64k 甚至 128k tokens
- ✅ 配合 Flash Attention 之後,記憶體同速度都好理想
要注意嘅點:
- ⚠️ Sequence parallelism 需要 all-to-all communication,網絡要快(InfiniBand / NVLink)
- ⚠️ Load balancing 要做好,唔好某啲 GPU 閒住
- ⚠️ 如果用 Ring Attention,要 tune 好 chunk size
總括嚟講,RoPE 嘅無參數、局部性、確定性呢幾個特性,令佢喺分佈式訓練入面特別好用。配合現代嘅 parallelism 技術,可以高效咁訓練超長序列嘅大型模型!
咩係 Sparse Attention?
講到 attention 機制,有時會聽到 "sparse attention" 呢個詞。呢個同 RoPE 其實係兩樣唔同嘅嘢,但係都值得理解下。
Dense vs Sparse Attention
Dense Attention (密集注意力)
傳統嘅 self-attention 係 dense——每個 token 都會同序列入面嘅所有其他 token 計 attention。
- 對於長度 嘅序列,計算複雜度係
- 記憶體需求都係
- RoPE 本身係 dense attention,佢只係改變位置編碼方式,但係每個 token 仍然 attend 到所有其他 tokens
例子:如果序列有 1024 tokens,每個 token 都要計算同其他 1023 個 tokens 嘅 attention score,總共要計 ~1M 次。
Sparse Attention (稀疏注意力)
Sparse attention 係只 attend 到部分 tokens,唔係全部。呢個可以大幅減少計算量。
常見嘅 sparse patterns:
1. Local Attention (局部注意力)
每個 token 只 attend 附近嘅 tokens(例如前後各 k 個位置)
Token 5 只 attend: [3, 4, 5, 6, 7] (k=2)
- 複雜度:
- 用途:適合處理局部依賴明顯嘅任務(例如語音識別)
2. Strided Attention (跨步注意力)
每隔固定步長先 attend 一次
Token 10 attend: [0, 5, 10, 15, 20, ...] (stride=5)
3. Fixed Patterns
預先定義好邊啲 tokens attend 邊啲 tokens,例如:
- Longformer: 結合 local + global attention
- BigBird: 結合 local + global + random attention
RoPE 同 Sparse Attention 係咪有關?
答案:冇直接關係,但係可以結合埋一齊用!
- RoPE 係一種位置編碼方法,佢改變嘅係點樣將位置信息注入 attention 機制
- Sparse attention 係一種attention pattern,佢改變嘅係邊啲 tokens 之間會計 attention
你可以將 RoPE 用喺 dense attention(例如 LLaMA),都可以用喺 sparse attention(例如某啲 long-context models)。
點解會混淆?
可能係因為 RoPE 本身有助於長序列建模,而 sparse attention 都係為咗處理長序列。但係佢哋嘅方法唔同:
| RoPE | Sparse Attention | |
|---|---|---|
| 目標 | 更好嘅位置編碼 | 減少計算複雜度 |
| 複雜度 | 仍然 | 降到 或 |
| Attention pattern | Dense (全連接) | Sparse (部分連接) |
| 可以結合? | ✅ 可以!例如用 RoPE 做位置編碼 + local attention pattern |
Attention Matrix 視覺化
想像一個 attention matrix,每行代表一個 query token,每列代表一個 key token。
Dense Attention (RoPE 用嘅):
K0 K1 K2 K3 K4 K5
Q0 ✓ ✓ ✓ ✓ ✓ ✓
Q1 ✓ ✓ ✓ ✓ ✓ ✓
Q2 ✓ ✓ ✓ ✓ ✓ ✓
Q3 ✓ ✓ ✓ ✓ ✓ ✓
Q4 ✓ ✓ ✓ ✓ ✓ ✓
Q5 ✓ ✓ ✓ ✓ ✓ ✓
每個位置都有 attention score (✓ = 計咗 attention)
Local Sparse Attention (window size = 2):
K0 K1 K2 K3 K4 K5
Q0 ✓ ✓ ✓ - - -
Q1 ✓ ✓ ✓ ✓ - -
Q2 ✓ ✓ ✓ ✓ ✓ -
Q3 - ✓ ✓ ✓ ✓ ✓
Q4 - - ✓ ✓ ✓ ✓
Q5 - - - ✓ ✓ ✓
每個 query 只 attend 附近嘅 keys (- = 唔計 attention,慳返計算)
實際例子
LLaMA (用 RoPE + Dense Attention)
# Simplified pseudocode
q = apply_rope(query, position) # 用 RoPE 編碼位置
k = apply_rope(key, position)
attention_scores = q @ k.T # Dense: 每個 q attend 所有 k
Longformer (用 Learned Position + Sparse Attention)
# Simplified pseudocode
q = query + learned_position_embeddings
k = key + learned_position_embeddings
# Sparse: 只計 local window + global tokens
attention_scores = sparse_attention(q, k, window_size=512)
假想:RoPE + Sparse (理論上可行)
# 結合 RoPE 同 sparse attention
q = apply_rope(query, position) # 用 RoPE
k = apply_rope(key, position)
# Sparse pattern: 只 attend local window
attention_scores = local_attention(q, k, window_size=256)
我覺得呢個 combination 幾有潛力,可以同時享受 RoPE 嘅相對位置建模能力,又減少計算開銷。
RoPE 嘅優勢
用咗一排 RoPE,我覺得佢有幾個好明顯嘅好處:
1. 自然嘅相對位置建模
唔使特登設計機制,數學上自然咁 encode 咗相對位置。呢個對於理解 token 之間嘅關係好有幫助,特別係語言理解任務。
2. 優秀嘅長度外推能力
Train 個 model 用 2k tokens,inference 嗰陣用 4k 甚至 8k tokens,performance 都仍然 OK。當然要配合 RoPE scaling 技術會更好。
3. 無額外參數
唔似 learned position embeddings 要學習參數,RoPE 純粹係計算上嘅操作,唔會增加 model size。
4. 高效計算
可以 precompute sin/cos 值,inference 嗰陣只係做 element-wise 操作,好快。
5. 數學上優雅
基於旋轉群嘅幾何性質,有紮實嘅數學基礎,唔係 heuristic。
實際應用
RoPE 而家已經係主流做法,好多大型模型都用緊:
LLaMA 系列
Meta 嘅 LLaMA/LLaMA 2/LLaMA 3 都用 RoPE。佢哋本身支持 2k-8k tokens,但係透過 RoPE scaling 可以擴展到 32k 甚至更長。
我自己試過用 LLaMA 做 long document QA,發覺佢對於理解文檔入面唔同部分嘅關係真係幾好。
PaLM
Google 嘅 PaLM 都採用咗 RoPE,喺多語言同推理任務上表現出色。
ChatGLM
清華嘅 GLM 系列用 RoPE 處理中英雙語,對於我哋香港人嚟講幾實用,因為成日要 code-switch。
Code Models
StarCoder、CodeLlama 呢啲 code generation models 都用 RoPE。Code 通常比較長,RoPE 嘅長程建模能力幫到手理解成個 file 嘅 context。
Length Extrapolation 問題
喺深入講 scaling 技術之前,首先要理解 length extrapolation (長度外推) 呢個概念。
咩係 Length Extrapolation?
Length extrapolation 係指模型處理超出訓練時所見過嘅最大序列長度嘅能力。
📏 簡單嚟講:如果你用 2048 tokens 嘅文本訓練咗個 Transformer,之後俾 4096 tokens 佢推理,呢個就係 length extrapolation。
具體例子
假設你訓練咗個模型,訓練數據最長係 2048 tokens:
- ✅ Interpolation (內插):推理時用 1024 tokens → 模型表現正常
- ✅ Interpolation:推理時用 2048 tokens → 模型表現正常
- ❌ Extrapolation (外推):推理時用 4096 tokens → 性能急劇下降或出錯
- ❌ Extrapolation:推理時用 8192 tokens → 可能完全崩潰
點解會有呢個問題?
傳統 Sinusoidal Positional Encoding
Transformer 原版用 sin/cos 函數編碼位置:
理論上:呢啲 sin/cos 函數可以計到任意位置,所以應該有無限長嘅外推能力。
實際上:模型從來冇見過 position > 2048 嘅 pattern,所以唔識點處理:
- Attention weights 會變得唔穩定
- 模型對長距離依賴嘅理解會出錯
- Perplexity 會急升
RoPE 嘅情況
RoPE 雖然比傳統方法好,但都會遇到外推問題:
# 訓練時:positions = [0, 1, 2, ..., 2047]
theta_m = m * theta_i # m 最大係 2047
# 推理時:positions = [0, 1, 2, ..., 4095]
theta_m = m * theta_i # m 去到 4095 (模型未見過!)
當 position index 超出訓練範圍:
- 旋轉角度 會去到模型未見過嘅值
- Query/Key 向量嘅旋轉會偏離學習到嘅 pattern
- Attention 分佈會變得唔準確
實際影響
我試過用 LLaMA-7B (訓練長度 2048) 處理唔同長度嘅文檔:
| 輸入長度 | 類型 | Perplexity | 表現 |
|---|---|---|---|
| 1024 tokens | Interpolation | ~12.5 | ✅ 正常 |
| 2048 tokens | Interpolation (邊界) | ~13.2 | ✅ 正常 |
| 3072 tokens | Extrapolation (1.5x) | ~18.7 | ⚠️ 明顯下降 |
| 4096 tokens | Extrapolation (2x) | ~27.3 | ❌ 嚴重退化 |
| 8192 tokens | Extrapolation (4x) | ~85+ | ❌ 基本不可用 |
可以見到當超出訓練長度,perplexity 會急速上升,模型基本上 lost 咗理解長距離 context 嘅能力。
點解要解決呢個問題?
現實應用入面,我哋好多時要處理超長文檔:
📚 文檔理解
- 法律合約(10k-50k tokens)
- 研究論文(5k-15k tokens)
- 技術文檔(8k-30k tokens)
💻 代碼生成
- 整個 codebase context (20k-100k+ tokens)
- 多個 file 嘅依賴關係
📝 長對話
- Customer support 長對話記錄
- 多輪技術討論
如果每次都要將訓練長度調到好大(例如直接 train 32k tokens),成本會好高:
- 訓練時間增加(attention 係 )
- GPU 記憶體需求大增
- 訓練成本可能貴幾倍
所以,理想做法係:用較短序列訓練(例如 2k-4k),然後透過技術手段 extend 到更長(例如 16k-128k)。呢個就係 length extrapolation 技術嘅價值!
RoPE Scaling 技術
原本 RoPE 都有長度限制,但係研究者諗咗唔同方法嚟擴展:
Linear Scaling
最簡單嘅方法——將位置 index 線性縮放:
其中 係 scaling factor。例如 train 用 2k tokens, 就可以用 4k tokens。
不過呢個方法有個問題:佢會改變相鄰 tokens 之間嘅相對距離,可能影響 local pattern 嘅學習。
NTK-Aware Scaling
調整基礎頻率而唔係位置:
呢個方法對唔同頻率嘅維度用唔同嘅 scaling,保留咗 local 同 global 嘅 balance。
YaRN (Yet another RoPE extensioN)
結合多種技術:
- 溫度縮放(temperature scaling)
- Attention 偏置
- 動態 interpolation
YaRN 可以將 context 擴展到 128k 甚至更長,而且 performance drop 好少。
同其他方法比較
RoPE vs ALiBi
ALiBi (Attention with Linear Biases) 都係為咗 encode 相對位置:
RoPE:
- 通過旋轉 query/key 向量
- 計算複雜度低
- 外推能力幾好
ALiBi:
- 通過 bias 項加落 attention scores 度
- Implementation 更簡單
- 外推能力更強(理論上無限長)
我覺得 ALiBi 嘅 simplicity 幾吸引,不過實際 performance 上 RoPE 通常都會好啲少少。
RoPE vs xPos
xPos (Cross-Positional attention) 係 RoPE 嘅改進版,加入咗指數衰減:
- 遠啲嘅 tokens 會有 exponential decay
- 進一步改善外推能力
- 但係 implementation 複雜啲
幾何直覺
我覺得從幾何角度理解 RoPE 幾有幫助:
🎨 想像:將每對維度當成 2D 平面入面嘅向量。隨住位置增加,向量喺平面入面旋轉。兩個向量嘅相對角度(即相對位置)決定咗佢哋嘅相似度(內積)。
唔同維度對嘅行為:
- 低頻維度(小 ):旋轉得慢,捕捉長程依賴,好似分針
- 高頻維度(大 ):旋轉得快,捕捉局部模式,好似秒針
呢種多尺度表示令 RoPE 可以同時 model 短程同長程嘅位置關係。就好似一個時鐘,秒針話你知局部時間,分針同時針話你知大範圍時間。
我嘅睇法
用咗 RoPE 一排,我覺得佢真係帶嚟咗明顯嘅改進,特別係喺長文檔理解方面。不過都有啲 limitations:
優點:
✅ 相對位置建模自然而然
✅ 外推能力比傳統 position embeddings 好
✅ 無額外參數,implementation 都唔算複雜
✅ 有紮實嘅數學基礎
局限:
❌ 仍然係 dense attention,長序列嘅 複雜度始終係問題
❌ 好長嘅序列(例如 >32k tokens)仍然需要 scaling 技術
❌ 喺某啲任務可能同 ALiBi 打和
展望:
我覺得未來可能會見到 RoPE 同 sparse attention 嘅結合。咁樣既可以有 RoPE 嘅相對位置建模能力,又可以減少 sparse attention 嘅計算開銷,對於超長文檔(例如成本書、整個 codebase)會好有用。
另外,RoPE 嘅 frequency 選擇都可以再研究。而家用 係參考 Transformer 原本嘅設計,但係唔同任務可能需要唔同嘅 frequency 分佈。
參考資料
- Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864
- Press, O., Smith, N., & Lewis, M. (2021). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation. ICLR 2022
- Peng, B., et al. (2023). YaRN: Efficient Context Window Extension of Large Language Models. arXiv:2309.00071
- Child, R., et al. (2019). Generating Long Sequences with Sparse Transformers. arXiv:1904.10509
- Beltagy, I., et al. (2020). Longformer: The Long-Document Transformer. arXiv:2004.05150