PyTorch中的squeeze和unsqueeze:张量的"瘦身"与"增肥"魔法
想象你有一个张量(Tensor),它就像一包压缩饼干(squeeze
)或一盒膨化食品(unsqueeze
)。PyTorch的squeeze
和unsqueeze
操作,就是让张量瘦身或增肥的魔法!
1. 什么是squeeze
?(张量的"瘦身术")
squeeze
的作用是"挤掉"张量里所有大小为1的维度,就像把一包多层压缩饼干压扁成一包普通饼干。
数学解释:
- 如果张量某个维度的长度=1(比如
shape=[1, 3, 1, 4]
),squeeze
会直接删除这个维度,变成[3, 4]
。 - 如果不指定维度,它会挤掉所有大小为1的维度。
- 如果指定维度(比如
squeeze(0)
),则只挤掉第0维(如果它的长度是1)。
类比:
- 原始张量:
[1, 3, 1, 4]
(像一包4层压缩饼干,其中第0层和第2层只有1片饼干) -
squeeze()
后:[3, 4]
(直接压扁成普通饼干,去掉所有空层) -
squeeze(0)
后:[3, 1, 4]
(只去掉第0层,第2层还是1片饼干)
代码示例:
import torch x = torch.randn(1, 3, 1, 4) # shape=[1, 3, 1, 4] y = x.squeeze() # shape=[3, 4](挤掉所有大小为1的维度) z = x.squeeze(0) # shape=[3, 1, 4](只挤掉第0维)
2. 什么是unsqueeze
?(张量的"增肥术")
unsqueeze
的作用是"增加一个大小为1的维度",就像把一包普通饼干包装进一个新盒子,变成多层压缩饼干。
数学解释:
unsqueeze(dim)
会在指定维度dim
的位置插入一个长度=1的维度。- 比如
shape=[3, 4]
→unsqueeze(0)
→shape=[1, 3, 4]
(在第0维外面套一个盒子)。
类比:
- 原始张量:
[3, 4]
(普通饼干) -
unsqueeze(0)
后:[1, 3, 4]
(套一个新盒子,变成1层压缩饼干) -
unsqueeze(1)
后:[3, 1, 4]
(在第1维外面套一个盒子,变成3×1×4的膨化饼干)
代码示例:
import torch x = torch.randn(3, 4) # shape=[3, 4] y = x.unsqueeze(0) # shape=[1, 3, 4](在第0维外面套一个盒子) z = x.unsqueeze(1) # shape=[3, 1, 4](在第1维外面套一个盒子)
3. 什么时候用squeeze
和unsqueeze
?
✅ squeeze
的典型场景:
- 神经网络输入/输出调整:比如模型输出
[1, 10]
(batch=1),但你想去掉batch维度,变成[10]
。 - 矩阵运算前对齐维度:比如
[3, 1]
和[3, 4]
相乘会报错,但squeeze
后[3]
和[3, 4]
可以广播(Broadcasting)。
✅ unsqueeze
的典型场景:
- 增加batch维度:比如
[3, 4]
→[1, 3, 4]
,让数据变成1个样本的batch,方便输入模型。 - 广播(Broadcasting)计算:比如
[3, 4]
和[4]
不能直接相加,但unsqueeze(0)
后[1, 4]
可以广播成[3, 4]
。
4. 常见错误 & 注意事项
❌ 错误1:试图squeeze
一个长度>1的维度
x = torch.randn(2, 3) # shape=[2, 3] y = x.squeeze(0) # 报错!因为第0维长度=2(不是1)
❌ 错误2:unsqueeze
的维度超出范围
x = torch.randn(3, 4) # shape=[3, 4] y = x.unsqueeze(3) # 报错!因为当前只有2维,最大只能`unsqueeze(2)`
5. 总结
- x.squeeze(): 挤掉张量x所有大小为1的维度
- x.squeeze(0): 只挤掉张量x指定维度大小为1的维度
- x.unsqueeze(0): 在张量x指定维度插入一个长度为1的维度
- 适用场景:调整维度以匹配神经网络输入、广播计算等。
掌握这两个操作后,你的PyTorch张量维度管理能力会直接起飞!
Python核心知识唠明白 文章被收录于专栏
想学Python怕被线程池|元组解包劝退?本专栏用打工人打工魂|拆快递|交换奶茶的生活化比喻,把核心知识点讲成唠家常!从线程池原理到元组解包技巧,每篇带代码实战+避坑指南,小白边看边练,无痛掌握。新手入门、老萌新优化代码都适用;学完直接上手批量下载、处理Excel、优化爬虫,Python原来这么简单好玩!