ai agent手撕真题- ROPE旋转位置编码

先搞懂:为什么需要位置编码?

Transformer 模型本身是 位置无关 的!它不知道哪个词在前,哪个词在后。比如

"我喜欢吃苹果" 和 "苹果喜欢吃我"

Transformer 会认为这两个句子的结构完全一样,因为它只关心词与词之间的关系,不关心顺序。但实际上这两个句子的意思完全不同!

所以需要 位置编码 来告诉模型每个词的位置信息

传统位置编码问题

方案1:学习型位置编码,初始化一个位置向量,全程全靠train

  • 给每个位置随机初始化一个向量
  • 让模型自己学习每个位置应该是什么样的
  • ❌ 缺点:只能处理训练时见过的长度,遇到更长的句子就傻眼了

方案2:正弦位置编码

  • 用三角函数(sin、cos)计算位置编码
  • ✅ 优点:可以处理任意长度的句子
  • ❌ 缺点:只能编码 绝对位置 ,无法很好地捕捉 相对位置

ROPE核心思想:旋转位置

想象一个简单的场景:你有一个向量 [x1, x2] ,现在给它旋转一个角度 θ 。

旋转前:

(x1, x2) → 在坐标系中的一个点

旋转后:

(x1*cosθ - x2*sinθ, x1*sinθ + x2*cosθ)

那么,ROPE的方法核心就是,对于每个词,不同位置的词旋转成不同的角度

(1)位置0:不旋转(角度0)

(2)位置1:旋转θ度

(3)位置2:旋转2θ度

·····

(m)位置m:旋转mθ度

因此推理到高维向量,它旋转mθ度,那么计算公式就是:

第一步将向量分为两半

假设我们有一个维度为4的向量: [a, b, c, d]

分成两个2维向量:

  • 偶数维度: [a, c]
  • 奇数维度: [b, d]

第二步:对每一半应用旋转

a' = a * cos(mθ) - b * sin(mθ)

b' = a * sin(mθ) + b * cos(mθ)

c' = c * cos(mθ) - d * sin(mθ)

d' = c * sin(mθ) + d * cos(mθ)

第三步:再拼接回来

最终得到: [a', b', c', d']

频率

但是在原文当中,论文的θ并不是对于一个向量中的各个维度都是统一的,他也是变化的

ROPE 的频率公式:

θ = 1 / 10000^(2i/d)
- i :维度索引(0, 1, 2, ...)
- d :总维度
  • 维度越低(i越小),频率越高,旋转越快
  • 维度越高(i越大),频率越低,旋转越慢

优势

优势1:长度外推

训练时用1024长度,推理时可以处理2048、4096甚至更长的序列!

因为旋转是无限循环的,无论多长的位置都能计算。

优势2:相对位置编码

模型能理解"猫追狗"和"狗追猫"的区别,因为它知道谁在前谁在后。

优势3:计算高效

不需要存储额外的位置向量,计算开销很小。

优势4:参数为0

不需要额外学习位置向量的参数,减少内存占用。

手撕

import torch
import math

class ROPE:
	def __init__(self, dim: int, max_len: int = 2048):
	  self.dim = dim
	  self.max_len = max_len
	  # 计算旋转频率
	  self.freqs = self._compute_freqs(dim)
	
	def _compute_freqs(self, dim: int):
	  """计算旋转频率"""
        # 频率计算公式: theta = 1 / 10000^(2i/d)
        # 其中 i 是维度索引,d 是总维度
		inv_freq = 1.0/(10000.0 ** (torch.arange(0, dim, 2).float() / dim))
	
	def forward(self, x: torch.Tensor):
        """
        应用旋转位置编码
        
        Args:
            x: 输入张量 [batch_size, seq_len, dim]
            
        Returns:
            编码后的张量 [batch_size, seq_len, dim]
        """
        batch_size, seq_len, dim = x.shape
        
        # 计算位置索引
        positions = torch.arange(seq_len, device=x.device)
        
        # 计算频率: [seq_len, dim/2]
        freqs = self.freqs[:dim // 2].to(x.device)
        freqs = positions[:, None] * freqs[None, :]
        
        # 转换为复数表示
        cos_freqs = torch.cos(freqs)
        sin_freqs = torch.sin(freqs)
        
        # 应用旋转
        # 将输入分为两部分
        x1 = x[..., ::2]  # 偶数索引
        x2 = x[..., 1::2]  # 奇数索引
        
        # 计算旋转后的向量
        # x' = [x1*cos - x2*sin, x1*sin + x2*cos]
        x1_rot = x1 * cos_freqs - x2 * sin_freqs
        x2_rot = x1 * sin_freqs + x2 * cos_freqs
        
        # 合并旋转后的向量
        rotated_x = torch.cat([x1_rot, x2_rot], dim=-1)
        
        return rotated_x

#agent##AI求职记录##25届秋招公司红黑榜#
ai agent每日手撕 文章被收录于专栏

ai agent每日手撕,喜欢还请关注+收藏

全部评论
创作不易,麻烦点赞加关注,这样作者才能坚持每日一更
点赞 回复 分享
发布于 昨天 22:31 广东

相关推荐

03-27 02:23
门头沟学院 Java
鼠鼠bg:9本无实习,项目上只有烂大街的黑马点评和小林的agent,算法只会hot100,常规八股比较熟练但是缺乏深度,属于典型的大众脸半个月前满怀信心开始投递暑期实习1.处女面是腾讯,面试官很好,即使是烂大街的点评也会探讨项目漏洞和技术方案,可惜鼠鼠准备不充分,没把握住机会2.第二次面携程,体验很好,面试官问的问题也比较常规,也是顺利进入二面,第一次面试通过给了鼠鼠很大的鼓舞3.二战腾讯,遇到了懂ai的面试官,问到transformer底层架构(当时还不会),还有agent的很多新名词,鼠鼠答上来大半最后还是遗憾挂掉,但是也从中学到了很多东西,回去恶补4.一战字节,字节的面试难度鼠鼠早有耳闻,面试前看了很多同部门的面经结合自己的简历做了很多模拟,结果面试官对鼠鼠简历上的东西毫无兴趣,只问了https握手经过几个rdt还有cas在操作系统层面的具体实现(闻所未闻的八股),两个问题过后给出两道非hot100手撕,鼠鼠大概被字节拉黑了吧。5.携程二面,面试官很好,会引导鼠鼠思考推理,问了场景设计,还有一些八股的深挖,比如hashmap负载因子0.75是怎么算出来的,在jdk1.7之前为什么使用链表仍能保持一个较快的查询速度,鼠鼠确实绞尽脑汁想不出来总结一下,面试被问到什么大概取决于面试官的心情吧,特别是目前ai时代,每场面试都会被意想不到的角度拷打,鼠鼠也不是很懂该往哪个方向努力了,只能面到不会的再补吧。后面可能沉下心来,日常和暑期同时投一投吧,希望最终能有一个offer,也渴望得到各位uu的宝贵建议
查看5道真题和解析
点赞 评论 收藏
分享
什么时候能收到off...:为什么上次那么难啊,pdd出题的人真神了
拼多多集团-PDD笔试
点赞 评论 收藏
分享
评论
1
1
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务