ai agent手撕真题- ROPE旋转位置编码
先搞懂:为什么需要位置编码?
Transformer 模型本身是 位置无关 的!它不知道哪个词在前,哪个词在后。比如
"我喜欢吃苹果" 和 "苹果喜欢吃我"
Transformer 会认为这两个句子的结构完全一样,因为它只关心词与词之间的关系,不关心顺序。但实际上这两个句子的意思完全不同!
所以需要 位置编码 来告诉模型每个词的位置信息
传统位置编码问题
方案1:学习型位置编码,初始化一个位置向量,全程全靠train
- 给每个位置随机初始化一个向量
- 让模型自己学习每个位置应该是什么样的
- ❌ 缺点:只能处理训练时见过的长度,遇到更长的句子就傻眼了
方案2:正弦位置编码
- 用三角函数(sin、cos)计算位置编码
- ✅ 优点:可以处理任意长度的句子
- ❌ 缺点:只能编码 绝对位置 ,无法很好地捕捉 相对位置
ROPE核心思想:旋转位置
想象一个简单的场景:你有一个向量 [x1, x2] ,现在给它旋转一个角度 θ 。
旋转前:
(x1, x2) → 在坐标系中的一个点
旋转后:
(x1*cosθ - x2*sinθ, x1*sinθ + x2*cosθ)
那么,ROPE的方法核心就是,对于每个词,不同位置的词旋转成不同的角度:
(1)位置0:不旋转(角度0)
(2)位置1:旋转θ度
(3)位置2:旋转2θ度
·····
(m)位置m:旋转mθ度
因此推理到高维向量,它旋转mθ度,那么计算公式就是:
第一步将向量分为两半
假设我们有一个维度为4的向量: [a, b, c, d]
分成两个2维向量:
- 偶数维度: [a, c]
- 奇数维度: [b, d]
第二步:对每一半应用旋转
a' = a * cos(mθ) - b * sin(mθ)
b' = a * sin(mθ) + b * cos(mθ)
c' = c * cos(mθ) - d * sin(mθ)
d' = c * sin(mθ) + d * cos(mθ)
第三步:再拼接回来
最终得到: [a', b', c', d']
频率
但是在原文当中,论文的θ并不是对于一个向量中的各个维度都是统一的,他也是变化的
ROPE 的频率公式:
θ = 1 / 10000^(2i/d) - i :维度索引(0, 1, 2, ...) - d :总维度
- 维度越低(i越小),频率越高,旋转越快
- 维度越高(i越大),频率越低,旋转越慢
优势
优势1:长度外推
训练时用1024长度,推理时可以处理2048、4096甚至更长的序列!
因为旋转是无限循环的,无论多长的位置都能计算。
优势2:相对位置编码
模型能理解"猫追狗"和"狗追猫"的区别,因为它知道谁在前谁在后。
优势3:计算高效
不需要存储额外的位置向量,计算开销很小。
优势4:参数为0
不需要额外学习位置向量的参数,减少内存占用。
手撕
import torch
import math
class ROPE:
def __init__(self, dim: int, max_len: int = 2048):
self.dim = dim
self.max_len = max_len
# 计算旋转频率
self.freqs = self._compute_freqs(dim)
def _compute_freqs(self, dim: int):
"""计算旋转频率"""
# 频率计算公式: theta = 1 / 10000^(2i/d)
# 其中 i 是维度索引,d 是总维度
inv_freq = 1.0/(10000.0 ** (torch.arange(0, dim, 2).float() / dim))
def forward(self, x: torch.Tensor):
"""
应用旋转位置编码
Args:
x: 输入张量 [batch_size, seq_len, dim]
Returns:
编码后的张量 [batch_size, seq_len, dim]
"""
batch_size, seq_len, dim = x.shape
# 计算位置索引
positions = torch.arange(seq_len, device=x.device)
# 计算频率: [seq_len, dim/2]
freqs = self.freqs[:dim // 2].to(x.device)
freqs = positions[:, None] * freqs[None, :]
# 转换为复数表示
cos_freqs = torch.cos(freqs)
sin_freqs = torch.sin(freqs)
# 应用旋转
# 将输入分为两部分
x1 = x[..., ::2] # 偶数索引
x2 = x[..., 1::2] # 奇数索引
# 计算旋转后的向量
# x' = [x1*cos - x2*sin, x1*sin + x2*cos]
x1_rot = x1 * cos_freqs - x2 * sin_freqs
x2_rot = x1 * sin_freqs + x2 * cos_freqs
# 合并旋转后的向量
rotated_x = torch.cat([x1_rot, x2_rot], dim=-1)
return rotated_x
#agent##AI求职记录##25届秋招公司红黑榜#ai agent每日手撕,喜欢还请关注+收藏
查看5道真题和解析