倒反天罡视角:从Seq2Seq 编码器看穿 PyTorch 的 Permute 和模块机制
一般都是熟悉pytorch框架后学习模型,今天就让我们倒反天罡一下,从一个编码器的实现上了解一下pytorch框架的Permute操作和模块机制.具体地
下面是一个序列到序列(Seq2Seq)模型里的“编码器”部分的代码,它的任务就像一个“翻译器的耳朵”:听清楚输入的每一个词,把它们变成“神经网络能理解的内部表示”:
class Seq2SeqEncoder(d2l.Encoder):
"""用于序列到序列学习的循环神经网络编码器
Defined in :numref:`sec_seq2seq`"""
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
super(Seq2SeqEncoder, self).__init__(**kwargs)
# 嵌入层
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
dropout=dropout)
def forward(self, X, *args):
# 输出'X'的形状:(batch_size,num_steps,embed_size)
X = self.embedding(X)
# 在循环神经网络模型中,第一个轴对应于时间步
X = X.permute(1, 0, 2)
# 如果未提及状态,则默认为0
output, state = self.rnn(X)
# output的形状:(num_steps,batch_size,num_hiddens)
# state的形状:(num_layers,batch_size,num_hiddens)
return output, state
针对上述代码,本篇帖子重点要搞清楚下面两个问题:
a) permute(1, 0, 2)
的必要性,和 b) self.embedding(x)
为什么可以这么直接用,而不是 self.embedding.forward(x)
。进而理解PyTorch 的 Permute 和模块机制在模型世界中的实际用处.
1. 先讲讲:代码里都干了什么?
self.embedding = nn.Embedding(vocab_size, embed_size)
这是一个“词嵌入层”,作用就像是给每个单词找个“向量座位”。比如“猫”可能坐在(0.1, 0.7, -0.3),而“狗”坐在(0.2, 0.6, -0.1)——总之大家都坐在一个高维空间里。
self.rnn = nn.GRU(embed_size, num_hiddens, num_layers)
然后是 RNN(这里是 GRU),它就像一个“多层记忆蛋糕”,一边吃进每个时间步的单词向量,一边“消化”并记住上下文关系。
问题 A:X.permute(1, 0, 2)
是干嘛的?
🌟 关键词:维度变换(维度调个顺序)
-
你从
self.embedding(X)
得到的X
形状是:(batch_size, num_steps, embed_size)
即:我有多少句话(批量),每句话多少词,每个词是多长的向量
-
但你这个 RNN(GRU)吃饭的方式有点特殊,它要吃的维度顺序是:
(time_step, batch_size, input_size)
它吃饭是按时间顺序来吃,第一口要吃所有样本的第一个词,第二口是所有样本的第二个词……
🧙所以怎么办?
我们得把饭碗顺序调一调,把“时间步”这个轴放在最前面,这就是这行代码的意义:
X = X.permute(1, 0, 2)
-
permute(1, 0, 2):把
dim=0
(原来的 batch_size)和dim=1
(原来的时间步)交换顺序 -
现在
X
的形状变成了:(num_steps, batch_size, embed_size)
-
RNN 满意了,可以开吃了 😋,同时也相当于批量处理批次中每各时间步上的此向量,在RNN串行建模的框架下尽可能提高处理效率
重点 B:为啥 self.embedding(x)
能直接调用?
🌟 关键词:PyTorch 的模块机制
在 PyTorch 中,所有的 nn.Module
(比如 nn.Embedding
, nn.GRU
)都有一个神奇的特性:
✅ 你可以直接用:
self.embedding(x)
这其实会自动调用它内部的 .forward(x)
方法。
❓那为什么不用写 self.embedding.forward(x)
呢?
因为 PyTorch 规定了调用方式:
self.embedding(x)
其实是self.embedding.__call__(x)
,- 它内部会自动调用
self.embedding.forward(x)
,并帮你处理很多底层细节(比如 hooks、参数注册、GPU/CPU切换等)——这些你写forward(x)
是处理不了的。
🎩 举个例子:
class SillyExample(nn.Module):
def forward(self, x):
print("I am in forward!")
return x * 2
silly = SillyExample()
silly(3) # 调用 __call__,自动执行 forward
# 输出:I am in forward!
你用 silly.forward(3)
也行,但那是“裸奔调用”,不会处理很多 PyTorch 的魔法,比如 hooks 或 device 设置。
所以写模型的时候,乖乖用 self.embedding(x)
就对了,别试图“硬肛” .forward()
。
2 总结成一句话:
🧠 permute(1, 0, 2)
是为了把“吃饭顺序”从“每句话整体吃”改为“逐时间吃”;
🧠 self.embedding(x)
是 PyTorch 安排的优雅调用方式,帮你自动走 forward()
和一堆底层魔法。
想学Python怕被线程池|元组解包劝退?本专栏用打工人打工魂|拆快递|交换奶茶的生活化比喻,把核心知识点讲成唠家常!从线程池原理到元组解包技巧,每篇带代码实战+避坑指南,小白边看边练,无痛掌握。新手入门、老萌新优化代码都适用;学完直接上手批量下载、处理Excel、优化爬虫,Python原来这么简单好玩!