pytorch_detach 切断网络反传

detach

官方文档中,对这个方法是这么介绍的。

    detach = _add_docstr(_C._TensorBase.detach, r""" Returns a new Tensor, detached from the current graph. The result will never require gradient. .. note:: Returned Tensor uses the same data tensor as the original one. In-place modifications on either of them will be seen, and may trigger errors in correctness checks. """)
  • 返回一个新的从当前图中分离的 Variable。
  • 返回的 Variable 永远不会需要梯度
  • 如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
  • 还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor

import torch
from torch.nn import init

t1 = torch.tensor([1., 2.],requires_grad=True)
t2 = torch.tensor([2., 3.],requires_grad=True)
v3 = t1 + t2

v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值

print(v3, v3_detached)    # v3 中tensor 的值也会改变
print(v3.requires_grad,v3_detached.requires_grad)

''' tensor([4., 7.], grad_fn=<AddBackward0>) tensor([4., 7.]) True False '''

在pytorch中通过拷贝需要切断位置前的tensor实现这个功能。tensor中拷贝的函数有两个,一个是clone(),另外一个是copy_(),clone()相当于完全复制了之前的tensor,他的梯度也会复制,而且在反向传播时,克隆的样本和结果是等价的,可以简单的理解为clone只是给了同一个tensor不同的代号,和‘=’等价。所以如果想要生成一个新的分开的tensor,请使用copy_()。
不过对于这样的操作,pytorch中有专门的函数——detach()。

用户自己创建的节点是leaf_node(如图中的abc三个节点),不依赖于其他变量,对于leaf_node不能进行in_place操作.根节点是计算图的最终目标(如图y),通过链式法则可以计算出所有节点相对于根节点的梯度值.这一过程通过调用root.backward()就可以实现.
因此,detach所做的就是,重新声明一个变量,指向原变量的存放位置,但是requires_grad为false.更深入一点的理解是,计算图从detach过的变量这里就断了, 它变成了一个leaf_node.即使之后重新将它的requires_node置为true,它也不会具有梯度.

pytorch 梯度

  • (0.4之后),tensor和variable合并,tensor具有grad、grad_fn等属性;
  • 默认创建的tensor,grad默认为False, 如果当前tensor_grad为None,则不会向前传播,如果有其它支路具有grad,则只传播其它支路的grad
# 默认创建requires_grad = False的Tensor
x = torch.ones(1)   # create a tensor with requires_grad=False (default)
print(x.requires_grad)
 # out: False
 
 # 创建另一个Tensor,同样requires_grad = False
y = torch.ones(1)  # another tensor with requires_grad=False
 # both inputs have requires_grad=False. so does the output
z = x + y
 # 因为两个Tensor x,y,requires_grad=False.都无法实现自动微分,
 # 所以操作(operation)z=x+y后的z也是无法自动微分,requires_grad=False
print(z.requires_grad)
 # out: False
 
 # then autograd won't track this computation. let's verify!
 # 因而无法autograd,程序报错
# z.backward()
 # out:程序报错:RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

    
# now create a tensor with requires_grad=True
w = torch.ones(1, requires_grad=True)
print(w.requires_grad)
 # out: True
 
 # add to the previous result that has require_grad=False
 # 因为total的操作中输入Tensor w的requires_grad=True,因而操作可以进行反向传播和自动求导。
total = w + z
# the total sum now requires grad!
total.requires_grad
# out: True
# autograd can compute the gradients as well
total.backward()
print(w.grad)
#out: tensor([ 1.])

# and no computation is wasted to compute gradients for x, y and z, which don't require grad
# 由于z,x,y的requires_grad=False,所以并没有计算三者的梯度
z.grad == x.grad == y.grad == None
# True

nn.Paramter

import torch.nn.functional as F

# With square kernels and equal stride
filters = torch.randn(8,4,3,3)
weiths = torch.nn.Parameter(torch.randn(8,4,3,3))

inputs = torch.randn(1,4,5,5)

out = F.conv2d(inputs, weiths, stride=2,padding=1)
print(out.shape)

con2d = torch.nn.Conv2d(4,8,3,stride=2,padding=1)

out_2 = con2d(inputs)

print(out_2.shape)

全部评论

相关推荐

找工作勤劳小蜜蜂:自我描述部分太差,完全看不出想从事什么行业什么岗位,也看不出想在哪个地区发展,这样 会让HR很犹豫,从而把你简历否决掉。现在企业都很注重员工稳定性和专注性,特别对于热爱本行业的员工。 你实习的工作又太传统的it开发(老旧),这部分公司已经趋于被淘汰,新兴的互联网服务业,比如物流,电商,新传媒,游戏开发和传统的It开发有天然区别。不是说传统It开发不行,而是就业岗位太少,基本趋于饱和,很多老骨头还能坚持,不需要新血液。 工作区域(比如长三角,珠三角,成渝)等也是HR考虑的因素之一,也是要你有个坚定的决心。否则去几天,人跑了,HR会被用人单位骂死。
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
正在热议
更多
# AI面会问哪些问题? #
24847次浏览 491人参与
# 中国电信笔试 #
31080次浏览 283人参与
# 米连集团26产品管培生项目 #
12951次浏览 285人参与
# 你的实习产出是真实的还是包装的? #
18817次浏览 330人参与
# 如果秋招能重来,我会____ #
96691次浏览 500人参与
# 春招至今,你的战绩如何? #
59982次浏览 543人参与
# 开放七大实习专项,百度暑期实习值得冲吗 #
14143次浏览 209人参与
# i人适合做什么工作 #
36914次浏览 124人参与
# 我是面试官,请用一句话让我破防 #
79511次浏览 219人参与
# 哪些公司真双非友好? #
69200次浏览 287人参与
# 金三银四,你的春招进行到哪个阶段了? #
21567次浏览 277人参与
# 找AI工作可以去哪些公司? #
7673次浏览 186人参与
# 从事AI岗需要掌握哪些技术栈? #
7676次浏览 251人参与
# 投递几十家公司,到现在0offer,大家都一样吗 #
339915次浏览 2165人参与
# 面试尴尬现场 #
220755次浏览 861人参与
# 五一之后,实习真的很难找吗? #
102797次浏览 584人参与
# 你做过最难的笔试是哪家公司 #
30108次浏览 193人参与
# 你小时候最想从事什么职业 #
159840次浏览 2072人参与
# 应届生第一份工资要多少合适 #
20483次浏览 84人参与
# 阿里笔试 #
176460次浏览 1302人参与
# 一张图晒出你司的标语 #
3821次浏览 72人参与
# 面试被问期望薪资时该如何回答 #
382459次浏览 2163人参与
牛客网
牛客网在线编程
牛客网题解
牛客企业服务