应用torchinfo计算网络的参数量

简介: 应用torchinfo计算网络的参数量

1 问题

定义好一个VGG11网络模型后,我们需要验证一下我们的模型是否按需求准确无误的写出,这时可以用torchinfo库中的summary来打印一下模型各层的参数状况。这时发现表中有一个param以及在经过两个卷积后参数量(param)没变,出于想知道每层的param是怎么计算出来,于是对此进行探究。


2 方法

1、网络中的参数量(param)是什么?

param代表每一层需要训练的参数个数,在全连接层是突触权重的个数,在卷积层是卷积核的参数的个数。


2、网络中的参数量(param)的计算。

卷积层计算公式:Conv2d_param=(卷积核尺寸*输入图像通道+1)*卷积核数目

池化层:池化层不需要参数。

全连接计算公式:Fc_param=(输入数据维度+1)*神经元个数


3、解释一下图表中vgg网络的结构和组成。vgg11的网络结构即表中的第一列:

conv3-64→maxpool→conv3-128→maxpool→conv3-256→conv3-256→maxpool→conv3-512→conv3-512→maxpool→conv3-512→conv3-512→maxpool→FC-4096→FC-4096→FC-1000→softmax。


4、代码展示

import torch
from torch import nn
from torchinfo import summary
class MyNet(nn.Module):
   #定义哪些层
   def __init__(self) :
       super().__init__()
       #(1)conv3-64
       self.conv1 = nn.Conv2d(
           in_channels=1, #输入图像通道数
           out_channels=64,#卷积产生的通道数(卷积核个数)
           kernel_size=3,#卷积核尺寸
           stride=1,
           padding=1       #不改变特征图大小
       )  
       self.max_pool_1 = nn.MaxPool2d(2)
       #(2)conv3-128
       self.conv2 = nn.Conv2d(
           in_channels=64,
           out_channels=128,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_2 = nn.MaxPool2d(2)
       #(3)conv3-256
       self.conv3 = nn.Conv2d(
           in_channels=128,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv4 = nn.Conv2d(
           in_channels=256,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_3 = nn.MaxPool2d(2)
       #(4)conv3-512
       self.conv5 = nn.Conv2d(
           in_channels=256,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv6 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_4 = nn.MaxPool2d(2)
       #(5)conv3-512
       self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv8 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_5 = nn.MaxPool2d(2)
       self.fc1 = nn.Linear(in_features=7*7*512,out_features=4096)
       self.fc2 = nn.Linear(in_features=4096,out_features=4096)
       self.fc3 = nn.Linear(in_features=4096,out_features=1000)
   #计算流向
   def forward(self,x):
       x = self.conv1(x)
       x = self.max_pool_1(x)
       x = self.conv2(x)
       x = self.max_pool_2(x)
       x = self.conv3(x)
       x = self.conv4(x)
       x = self.max_pool_3(x)
       x = self.conv5(x)
       x = self.conv6(x)
       x = self.max_pool_4(x)
       x = self.conv7(x)
       x = self.conv8(x)
       x = self.max_pool_5(x)
       x = torch.flatten(x,1)  #[B,C,H,W]从C开始flatten,B不用flatten,所以要加1
       x = self.fc1(x)
       x = self.fc2(x)
       out = self.fc3(x)
       return out
if __name__ == '__main__':
   x = torch.rand(128,1,224,224)
   net = MyNet()
   out = net(x)
   #print(out.shape)
   summary(net, (12,1,224,224))

输出结果:

图片中红色方块计算过程:

1:相关代码及计算过程(卷积层)

self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )

Conv2d_param= (3*3*512+1)*512=2,359,808(Conv2d-12代码同,故param同)

2:相关代码及计算过程

self.fc3 = nn.Linear(in_features=4096,out_features=1000)

Fc_fc_param=(4096+1)*1000=4,097,000


3 结语

以上为一般情况下参数量计算方法,当然还有很多细节与很多其他情况下的计算方法没有介绍,主要用来形容模型的大小程度,针对不同batch_size下param的不同,可以用于参考来选择更合适的batch_size。

相关实践学习
【AI破次元壁合照】少年白马醉春风,函数计算一键部署AI绘画平台
本次实验基于阿里云函数计算产品能力开发AI绘画平台,可让您实现“破次元壁”与角色合照,为角色换背景效果,用AI绘图技术绘出属于自己的少年江湖。
从 0 入门函数计算
在函数计算的架构中,开发者只需要编写业务代码,并监控业务运行情况就可以了。这将开发者从繁重的运维工作中解放出来,将精力投入到更有意义的开发任务上。
目录
相关文章
|
2月前
|
机器学习/深度学习 PyTorch TensorFlow
卷积神经网络深度解析:从基础原理到实战应用的完整指南
蒋星熠Jaxonic,深度学习探索者。深耕TensorFlow与PyTorch,分享框架对比、性能优化与实战经验,助力技术进阶。
|
4月前
|
监控 安全 Shell
管道符在渗透测试与网络安全中的全面应用指南
管道符是渗透测试与网络安全中的关键工具,既可用于高效系统管理,也可能被攻击者利用实施命令注入、权限提升、数据外泄等攻击。本文全面解析管道符的基础原理、实战应用与防御策略,涵盖Windows与Linux系统差异、攻击技术示例及检测手段,帮助安全人员掌握其利用方式与防护措施,提升系统安全性。
225 6
|
7月前
|
人工智能 监控 安全
NTP网络子钟的技术架构与行业应用解析
在数字化与智能化时代,时间同步精度至关重要。西安同步电子科技有限公司专注时间频率领域,以“同步天下”品牌提供可靠解决方案。其明星产品SYN6109型NTP网络子钟基于网络时间协议,实现高精度时间同步,广泛应用于考场、医院、智慧场景等领域。公司坚持技术创新,产品通过权威认证,未来将结合5G、物联网等技术推动行业进步,引领精准时间管理新时代。
|
3月前
|
机器学习/深度学习 人工智能 算法
卷积神经网络深度解析:从基础原理到实战应用的完整指南
蒋星熠Jaxonic带你深入卷积神经网络(CNN)核心技术,从生物启发到数学原理,详解ResNet、注意力机制与模型优化,探索视觉智能的演进之路。
453 11
|
7月前
|
算法 JavaScript 数据安全/隐私保护
基于GA遗传优化的最优阈值计算认知异构网络(CHN)能量检测算法matlab仿真
本内容介绍了一种基于GA遗传优化的阈值计算方法在认知异构网络(CHN)中的应用。通过Matlab2022a实现算法,完整代码含中文注释与操作视频。能量检测算法用于感知主用户信号,其性能依赖检测阈值。传统固定阈值方法易受噪声影响,而GA算法通过模拟生物进化,在复杂环境中自动优化阈值,提高频谱感知准确性,增强CHN的通信效率与资源利用率。预览效果无水印,核心程序部分展示,适合研究频谱感知与优化算法的学者参考。
|
4月前
|
数据采集 存储 数据可视化
Python网络爬虫在环境保护中的应用:污染源监测数据抓取与分析
在环保领域,数据是决策基础,但分散在多个平台,获取困难。Python网络爬虫技术灵活高效,可自动化抓取空气质量、水质、污染源等数据,实现多平台整合、实时更新、结构化存储与异常预警。本文详解爬虫实战应用,涵盖技术选型、代码实现、反爬策略与数据分析,助力环保数据高效利用。
322 0
|
4月前
|
安全 Linux
利用Libevent在CentOS 7上打造异步网络应用
总结以上步骤,您可以在CentOS 7系统上,使用Libevent有效地构建和运行异步网络应用。通过采取正确的架构和代码设计策略,能保证网络应用的高效性和稳定性。
167 0
|
7月前
|
机器学习/深度学习 算法 测试技术
图神经网络在信息检索重排序中的应用:原理、架构与Python代码解析
本文探讨了基于图的重排序方法在信息检索领域的应用与前景。传统两阶段检索架构中,初始检索速度快但结果可能含噪声,重排序阶段通过强大语言模型提升精度,但仍面临复杂需求挑战
274 0
图神经网络在信息检索重排序中的应用:原理、架构与Python代码解析
|
6月前
|
监控 安全 Linux
AWK在网络安全中的高效应用:从日志分析到威胁狩猎
本文深入探讨AWK在网络安全中的高效应用,涵盖日志分析、威胁狩猎及应急响应等场景。通过实战技巧,助力安全工程师将日志分析效率提升3倍以上,构建轻量级监控方案。文章详解AWK核心语法与网络安全专用技巧,如时间范围分析、多条件过滤和数据脱敏,并提供性能优化与工具集成方案。掌握AWK,让安全工作事半功倍!
234 0
|
6月前
|
人工智能 安全 网络安全
网络安全厂商F5推出AI Gateway,化解大模型应用风险
网络安全厂商F5推出AI Gateway,化解大模型应用风险
239 0

热门文章

最新文章