深度学习网络的参数的获取及计算
#获取网络中的参数 #1、定义一个网络 import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3) def forward(self,x): x = self.conv(x) return x #2、随机模拟获得一个图像数据, 按 NCHW 的顺序填入数字 img = torch.randn((1, 3, 5, 5)) #3、创建网络实例 net = Net() #4、网络输出 y = net(img) #5、查看网络中的参数及计算网络中参数总量 #网络中的参数的计算方式:参数每一维尺寸相乘,如[5,3,3,3]--- 5*3*3*3 total_params = 0 for param in net.parameters(): print(param) #查看网络中每层每个参数的数值 print('--------------') print(param.size()) #查看网络中每层参数的尺寸 print('***************') dims = len(param.size()) #获取参数尺寸的维数 p = 1 for i in range(dims): p *= param.size(i) total_params += p print('总参数数量为:', total_params) #6、同时查看参数名称和参数 for name, param in net.named_parameters(): print('同时查看参数名称和参数') print(name) print(param) print(param.requires_grad) #查看该参数是否需要进梯度更新