倒反天罡视角:从Seq2Seq 编码器看穿 PyTorch 的 Permute 和模块机制

alt 一般都是熟悉pytorch框架后学习模型,今天就让我们倒反天罡一下,从一个编码器的实现上了解一下pytorch框架的Permute操作和模块机制.具体地 下面是一个序列到序列(Seq2Seq)模型里的“编码器”部分的代码,它的任务就像一个“翻译器的耳朵”:听清楚输入的每一个词,把它们变成“神经网络能理解的内部表示”:

class Seq2SeqEncoder(d2l.Encoder):
    """用于序列到序列学习的循环神经网络编码器

    Defined in :numref:`sec_seq2seq`"""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
                          dropout=dropout)

    def forward(self, X, *args):
        # 输出'X'的形状:(batch_size,num_steps,embed_size)
        X = self.embedding(X)
        # 在循环神经网络模型中,第一个轴对应于时间步
        X = X.permute(1, 0, 2)
        # 如果未提及状态,则默认为0
        output, state = self.rnn(X)
        # output的形状:(num_steps,batch_size,num_hiddens)
        # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state

针对上述代码,本篇帖子重点要搞清楚下面两个问题: a) permute(1, 0, 2) 的必要性,和 b) self.embedding(x) 为什么可以这么直接用,而不是 self.embedding.forward(x)。进而理解PyTorch 的 Permute 和模块机制在模型世界中的实际用处.

1. 先讲讲:代码里都干了什么?

self.embedding = nn.Embedding(vocab_size, embed_size)

这是一个“词嵌入层”,作用就像是给每个单词找个“向量座位”。比如“猫”可能坐在(0.1, 0.7, -0.3),而“狗”坐在(0.2, 0.6, -0.1)——总之大家都坐在一个高维空间里。

self.rnn = nn.GRU(embed_size, num_hiddens, num_layers)

然后是 RNN(这里是 GRU),它就像一个“多层记忆蛋糕”,一边吃进每个时间步的单词向量,一边“消化”并记住上下文关系。

问题 A:X.permute(1, 0, 2) 是干嘛的?

🌟 关键词:维度变换(维度调个顺序)

  • 你从 self.embedding(X) 得到的 X 形状是:

    (batch_size, num_steps, embed_size)
    

    即:我有多少句话(批量),每句话多少词,每个词是多长的向量

  • 但你这个 RNN(GRU)吃饭的方式有点特殊,它要吃的维度顺序是:

    (time_step, batch_size, input_size)
    

    它吃饭是按时间顺序来吃,第一口要吃所有样本的第一个词,第二口是所有样本的第二个词……

🧙‍所以怎么办?

我们得把饭碗顺序调一调,把“时间步”这个轴放在最前面,这就是这行代码的意义:

X = X.permute(1, 0, 2)
  • permute(1, 0, 2):把 dim=0(原来的 batch_size)和 dim=1(原来的时间步)交换顺序

  • 现在 X 的形状变成了:

    (num_steps, batch_size, embed_size)
    
  • RNN 满意了,可以开吃了 😋,同时也相当于批量处理批次中每各时间步上的此向量,在RNN串行建模的框架下尽可能提高处理效率

重点 B:为啥 self.embedding(x) 能直接调用?

🌟 关键词:PyTorch 的模块机制

在 PyTorch 中,所有的 nn.Module(比如 nn.Embedding, nn.GRU)都有一个神奇的特性:

✅ 你可以直接用:

self.embedding(x)

这其实会自动调用它内部的 .forward(x) 方法

❓那为什么不用写 self.embedding.forward(x) 呢?

因为 PyTorch 规定了调用方式:

  • self.embedding(x) 其实是 self.embedding.__call__(x)
  • 它内部会自动调用 self.embedding.forward(x),并帮你处理很多底层细节(比如 hooks、参数注册、GPU/CPU切换等)——这些你写 forward(x) 是处理不了的。

🎩 举个例子:

class SillyExample(nn.Module):
    def forward(self, x):
        print("I am in forward!")
        return x * 2

silly = SillyExample()
silly(3)  # 调用 __call__,自动执行 forward
# 输出:I am in forward!

你用 silly.forward(3) 也行,但那是“裸奔调用”,不会处理很多 PyTorch 的魔法,比如 hooks 或 device 设置。

所以写模型的时候,乖乖用 self.embedding(x) 就对了,别试图“硬肛” .forward()

2 总结成一句话:

🧠 permute(1, 0, 2) 是为了把“吃饭顺序”从“每句话整体吃”改为“逐时间吃”;

🧠 self.embedding(x) 是 PyTorch 安排的优雅调用方式,帮你自动走 forward() 和一堆底层魔法。

Python核心知识唠明白 文章被收录于专栏

想学Python怕被线程池|元组解包劝退?本专栏用打工人打工魂|拆快递|交换奶茶的生活化比喻,把核心知识点讲成唠家常!从线程池原理到元组解包技巧,每篇带代码实战+避坑指南,小白边看边练,无痛掌握。新手入门、老萌新优化代码都适用;学完直接上手批量下载、处理Excel、优化爬虫,Python原来这么简单好玩!

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务