应用torchinfo计算网络的参数量

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: 应用torchinfo计算网络的参数量

1 问题

定义好一个VGG11网络模型后,我们需要验证一下我们的模型是否按需求准确无误的写出,这时可以用torchinfo库中的summary来打印一下模型各层的参数状况。这时发现表中有一个param以及在经过两个卷积后参数量(param)没变,出于想知道每层的param是怎么计算出来,于是对此进行探究。


2 方法

1、网络中的参数量(param)是什么?

param代表每一层需要训练的参数个数,在全连接层是突触权重的个数,在卷积层是卷积核的参数的个数。


2、网络中的参数量(param)的计算。

卷积层计算公式:Conv2d_param=(卷积核尺寸*输入图像通道+1)*卷积核数目

池化层:池化层不需要参数。

全连接计算公式:Fc_param=(输入数据维度+1)*神经元个数


3、解释一下图表中vgg网络的结构和组成。vgg11的网络结构即表中的第一列:

conv3-64→maxpool→conv3-128→maxpool→conv3-256→conv3-256→maxpool→conv3-512→conv3-512→maxpool→conv3-512→conv3-512→maxpool→FC-4096→FC-4096→FC-1000→softmax。


4、代码展示

import torch
from torch import nn
from torchinfo import summary
class MyNet(nn.Module):
   #定义哪些层
   def __init__(self) :
       super().__init__()
       #(1)conv3-64
       self.conv1 = nn.Conv2d(
           in_channels=1, #输入图像通道数
           out_channels=64,#卷积产生的通道数(卷积核个数)
           kernel_size=3,#卷积核尺寸
           stride=1,
           padding=1       #不改变特征图大小
       )  
       self.max_pool_1 = nn.MaxPool2d(2)
       #(2)conv3-128
       self.conv2 = nn.Conv2d(
           in_channels=64,
           out_channels=128,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_2 = nn.MaxPool2d(2)
       #(3)conv3-256
       self.conv3 = nn.Conv2d(
           in_channels=128,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv4 = nn.Conv2d(
           in_channels=256,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_3 = nn.MaxPool2d(2)
       #(4)conv3-512
       self.conv5 = nn.Conv2d(
           in_channels=256,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv6 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_4 = nn.MaxPool2d(2)
       #(5)conv3-512
       self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv8 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_5 = nn.MaxPool2d(2)
       self.fc1 = nn.Linear(in_features=7*7*512,out_features=4096)
       self.fc2 = nn.Linear(in_features=4096,out_features=4096)
       self.fc3 = nn.Linear(in_features=4096,out_features=1000)
   #计算流向
   def forward(self,x):
       x = self.conv1(x)
       x = self.max_pool_1(x)
       x = self.conv2(x)
       x = self.max_pool_2(x)
       x = self.conv3(x)
       x = self.conv4(x)
       x = self.max_pool_3(x)
       x = self.conv5(x)
       x = self.conv6(x)
       x = self.max_pool_4(x)
       x = self.conv7(x)
       x = self.conv8(x)
       x = self.max_pool_5(x)
       x = torch.flatten(x,1)  #[B,C,H,W]从C开始flatten,B不用flatten,所以要加1
       x = self.fc1(x)
       x = self.fc2(x)
       out = self.fc3(x)
       return out
if __name__ == '__main__':
   x = torch.rand(128,1,224,224)
   net = MyNet()
   out = net(x)
   #print(out.shape)
   summary(net, (12,1,224,224))

输出结果:

图片中红色方块计算过程:

1:相关代码及计算过程(卷积层)

self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )

Conv2d_param= (3*3*512+1)*512=2,359,808(Conv2d-12代码同,故param同)

2:相关代码及计算过程

self.fc3 = nn.Linear(in_features=4096,out_features=1000)

Fc_fc_param=(4096+1)*1000=4,097,000


3 结语

以上为一般情况下参数量计算方法,当然还有很多细节与很多其他情况下的计算方法没有介绍,主要用来形容模型的大小程度,针对不同batch_size下param的不同,可以用于参考来选择更合适的batch_size。

相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
9天前
|
监控 安全
公司上网监控:Mercury 在网络监控高级逻辑编程中的应用
在数字化办公环境中,公司对员工上网行为的监控至关重要。Mercury 作为一种强大的编程工具,展示了在公司上网监控领域的独特优势。本文介绍了使用 Mercury 实现网络连接监听、数据解析和日志记录的功能,帮助公司确保信息安全和工作效率。
78 51
|
5天前
|
SQL 安全 前端开发
PHP与现代Web开发:构建高效的网络应用
【10月更文挑战第37天】在数字化时代,PHP作为一门强大的服务器端脚本语言,持续影响着Web开发的面貌。本文将深入探讨PHP在现代Web开发中的角色,包括其核心优势、面临的挑战以及如何利用PHP构建高效、安全的网络应用。通过具体代码示例和最佳实践的分享,旨在为开发者提供实用指南,帮助他们在不断变化的技术环境中保持竞争力。
RS-485网络中的标准端接与交流电端接应用解析
RS-485,作为一种广泛应用的差分信号传输标准,因其传输距离远、抗干扰能力强、支持多点通讯等优点,在工业自动化、智能建筑、交通运输等领域得到了广泛应用。在构建RS-485网络时,端接技术扮演着至关重要的角色,它直接影响到网络的信号完整性、稳定性和通信质量。
|
6天前
|
机器学习/深度学习 人工智能 算法框架/工具
深度学习中的卷积神经网络(CNN)及其在图像识别中的应用
【10月更文挑战第36天】探索卷积神经网络(CNN)的神秘面纱,揭示其在图像识别领域的威力。本文将带你了解CNN的核心概念,并通过实际代码示例,展示如何构建和训练一个简单的CNN模型。无论你是深度学习的初学者还是希望深化理解,这篇文章都将为你提供有价值的见解。
|
6天前
|
网络协议 数据挖掘 5G
适用于金融和交易应用的低延迟网络:技术、架构与应用
适用于金融和交易应用的低延迟网络:技术、架构与应用
31 5
|
6天前
|
运维 物联网 网络虚拟化
网络功能虚拟化(NFV):定义、原理及应用前景
网络功能虚拟化(NFV):定义、原理及应用前景
21 3
|
6天前
|
数据可视化 算法 安全
员工上网行为管理软件:S - PLUS 在网络统计分析中的应用
在数字化办公环境中,S-PLUS 员工上网行为管理软件通过精准的数据收集、深入的流量分析和直观的可视化呈现,有效帮助企业管理员工上网行为,保障网络安全和提高运营效率。
15 1
|
10天前
|
机器学习/深度学习 移动开发 自然语言处理
HTML5与神经网络技术的结合有哪些其他应用
HTML5与神经网络技术的结合有哪些其他应用
26 3
|
8天前
|
机器学习/深度学习 人工智能 安全
人工智能与机器学习在网络安全中的应用
人工智能与机器学习在网络安全中的应用
25 0
|
10天前
|
机器学习/深度学习 人工智能 TensorFlow
深度学习中的卷积神经网络(CNN)及其在图像识别中的应用
【10月更文挑战第32天】本文将介绍深度学习中的一个重要分支——卷积神经网络(CNN),以及其在图像识别领域的应用。我们将通过一个简单的代码示例,展示如何使用Python和TensorFlow库构建一个基本的CNN模型,并对其进行训练和测试。