深度学习网络的参数的获取及计算

#获取网络中的参数

#1、定义一个网络
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)

    def forward(self,x):
        x = self.conv(x)
        return x

#2、随机模拟获得一个图像数据, 按 NCHW 的顺序填入数字
img = torch.randn((1, 3, 5, 5))

#3、创建网络实例
net = Net()

#4、网络输出
y = net(img)

#5、查看网络中的参数及计算网络中参数总量
#网络中的参数的计算方式:参数每一维尺寸相乘,如[5,3,3,3]--- 5*3*3*3
total_params = 0
for param in net.parameters():
    print(param)     #查看网络中每层每个参数的数值
    print('--------------')
    print(param.size()) #查看网络中每层参数的尺寸
    print('***************')

    dims = len(param.size()) #获取参数尺寸的维数
    p = 1
    for i in range(dims):
        p *= param.size(i)
    total_params += p
print('总参数数量为:', total_params)



#6、同时查看参数名称和参数    
for name, param in net.named_parameters():
    print('同时查看参数名称和参数')
    print(name)
    print(param)
    print(param.requires_grad) #查看该参数是否需要进梯度更新

Jupyter上的python代码

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
正在热议
更多
# 春招至今,你的战绩如何? #
7177次浏览 66人参与
# 你的实习产出是真实的还是包装的? #
1396次浏览 34人参与
# 米连集团26产品管培生项目 #
5067次浏览 206人参与
# 军工所铁饭碗 vs 互联网高薪资,你会选谁 #
7186次浏览 38人参与
# 简历第一个项目做什么 #
31388次浏览 317人参与
# 当下环境,你会继续卷互联网,还是看其他行业机会 #
186612次浏览 1116人参与
# MiniMax求职进展汇总 #
23325次浏览 304人参与
# 研究所笔面经互助 #
118806次浏览 577人参与
# 面试紧张时你会有什么表现? #
30431次浏览 188人参与
# 简历中的项目经历要怎么写? #
309712次浏览 4171人参与
# AI时代,哪些岗位最容易被淘汰 #
62931次浏览 760人参与
# 职能管理面试记录 #
10749次浏览 59人参与
# 网易游戏笔试 #
6391次浏览 83人参与
# 腾讯音乐求职进展汇总 #
160477次浏览 1107人参与
# 把自己当AI,现在最消耗你token的问题是什么? #
7054次浏览 155人参与
# 正在春招的你,也参与了去年秋招吗? #
362899次浏览 2633人参与
# 你怎么看待AI面试 #
179553次浏览 1193人参与
# 小红书求职进展汇总 #
226958次浏览 1357人参与
# 你觉得通信/硬件有必要实习吗? #
155407次浏览 1065人参与
# 从哪些方向判断这个offer值不值得去? #
56717次浏览 357人参与
# 校招笔试 #
468348次浏览 2957人参与
# 你的房租占工资的比例是多少? #
92165次浏览 896人参与
牛客网
牛客网在线编程
牛客网题解
牛客企业服务