加性注意力中的 unsqueeze() 魔法:给张量"长高"的趣味解释
想象你有一叠 3D 彩色纸片(张量),而 unsqueeze()
就像给这些纸片偷偷增加一层隐形夹层,让它们能和其他纸片对齐拼接!
🎲 1. 为什么需要 unsqueeze()
?
在加性注意力机制中,我们需要计算 查询(queries)和键(keys) 的相似度。但它们的形状不匹配:
queries
形状:(batch_size, 查询个数, num_hiddens)
keys
形状:(batch_size, “键-值”对个数, num_hiddens)
为了让它们能逐元素相加(广播机制),我们需要让它们的形状变成:
queries
→(batch_size, 查询个数, 1, num_hiddens)
(增加一个维度)keys
→(batch_size, 1, “键-值”对个数, num_hiddens)
(增加一个维度)
这样,queries + keys
就能自动扩展成 (batch_size, 查询个数, “键-值”对个数, num_hiddens)
,实现批量计算相似度!
📦 2. unsqueeze()
的具体操作
(1) queries.unsqueeze(2)
- 作用:在
queries
的第 2 个维度(从 0 开始数)插入一个长度为 1 的维度。 - 形状变化:
- 原始
queries
:(batch_size, 查询个数, num_hiddens)
queries.unsqueeze(2)
:(batch_size, 查询个数, 1, num_hiddens)
- 原始
小白比喻:
- 原本
queries
是一叠 平铺的纸片(3D)。 unsqueeze(2)
相当于在这叠纸片里偷偷塞了一层透明塑料膜(新增一个维度),变成 4D 纸片堆!
(2) keys.unsqueeze(1)
- 作用:在
keys
的第 1 个维度插入一个长度为 1 的维度。 - 形状变化:
- 原始
keys
:(batch_size, “键-值”对个数, num_hiddens)
keys.unsqueeze(1)
:(batch_size, 1, “键-值”对个数, num_hiddens)
- 原始
小白比喻:
- 原本
keys
是另一叠 平铺的纸片(3D)。 unsqueeze(1)
相当于在这叠纸片里横向插入一层透明塑料膜(新增一个维度),变成 4D 纸片堆!
🧙 3. 广播机制的魔法
现在:
queries.unsqueeze(2)
:(batch_size, 查询个数, 1, num_hiddens)
keys.unsqueeze(1)
:(batch_size, 1, “键-值”对个数, num_hiddens)
当它们相加时,PyTorch 会自动扩展维度,变成:
(batch_size, 查询个数, “键-值”对个数, num_hiddens)
效果:
- 每个查询(query)会自动和所有键(keys)计算相似度,无需写循环!
💡 4. 为什么不能直接 queries + keys
?
如果直接相加:
queries
形状:(batch_size, 查询个数, num_hiddens)
keys
形状:(batch_size, “键-值”对个数, num_hiddens)
PyTorch 无法自动对齐,会报错!
必须用 unsqueeze()
让它们维度匹配,才能触发广播机制。
🎯 5. 总结(小白记忆口诀)
操作 | 作用 | 类比 |
---|---|---|
queries.unsqueeze(2) |
给查询纸片偷偷加一层夹层 | 从 3D → 4D,让查询能和所有键对齐 |
keys.unsqueeze(1) |
给键纸片横向加一层夹层 | 从 3D → 4D,让键能和所有查询对齐 |
queries + keys |
广播计算相似度 | 自动扩展维度,批量计算注意力分数 |
幽默总结:
unsqueeze()
就像给张量**"长高",让它们能站在一起对齐**!- 没有它,查询和键就像不同高度的积木,没法拼在一起;有了它,就能自动搭建注意力大厦! 🏗️
这样,加性注意力就能高效计算查询和键的相似度啦! ✨
大模型小白拆解站 文章被收录于专栏
想和大模型零障碍对话?这里是你的入门急救站! 从大模型到底是啥到训练时都在干啥,用大白话拆解技术原理;从参数是个啥到微调怎么玩,用生活案例讲透核心概念。拒绝枯燥公式,只有能听懂的干货和冷到爆的梗;帮你从大模型小白变身入门小能手,轻松get前沿AI知识!