应用torchinfo计算网络的参数量

本文涉及的产品
简介: 应用torchinfo计算网络的参数量

1 问题

定义好一个VGG11网络模型后,我们需要验证一下我们的模型是否按需求准确无误的写出,这时可以用torchinfo库中的summary来打印一下模型各层的参数状况。这时发现表中有一个param以及在经过两个卷积后参数量(param)没变,出于想知道每层的param是怎么计算出来,于是对此进行探究。


2 方法

1、网络中的参数量(param)是什么?

param代表每一层需要训练的参数个数,在全连接层是突触权重的个数,在卷积层是卷积核的参数的个数。


2、网络中的参数量(param)的计算。

卷积层计算公式:Conv2d_param=(卷积核尺寸*输入图像通道+1)*卷积核数目

池化层:池化层不需要参数。

全连接计算公式:Fc_param=(输入数据维度+1)*神经元个数


3、解释一下图表中vgg网络的结构和组成。vgg11的网络结构即表中的第一列:

conv3-64→maxpool→conv3-128→maxpool→conv3-256→conv3-256→maxpool→conv3-512→conv3-512→maxpool→conv3-512→conv3-512→maxpool→FC-4096→FC-4096→FC-1000→softmax。


4、代码展示

import torch
from torch import nn
from torchinfo import summary
class MyNet(nn.Module):
   #定义哪些层
   def __init__(self) :
       super().__init__()
       #(1)conv3-64
       self.conv1 = nn.Conv2d(
           in_channels=1, #输入图像通道数
           out_channels=64,#卷积产生的通道数(卷积核个数)
           kernel_size=3,#卷积核尺寸
           stride=1,
           padding=1       #不改变特征图大小
       )  
       self.max_pool_1 = nn.MaxPool2d(2)
       #(2)conv3-128
       self.conv2 = nn.Conv2d(
           in_channels=64,
           out_channels=128,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_2 = nn.MaxPool2d(2)
       #(3)conv3-256
       self.conv3 = nn.Conv2d(
           in_channels=128,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv4 = nn.Conv2d(
           in_channels=256,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_3 = nn.MaxPool2d(2)
       #(4)conv3-512
       self.conv5 = nn.Conv2d(
           in_channels=256,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv6 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_4 = nn.MaxPool2d(2)
       #(5)conv3-512
       self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv8 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_5 = nn.MaxPool2d(2)
       self.fc1 = nn.Linear(in_features=7*7*512,out_features=4096)
       self.fc2 = nn.Linear(in_features=4096,out_features=4096)
       self.fc3 = nn.Linear(in_features=4096,out_features=1000)
   #计算流向
   def forward(self,x):
       x = self.conv1(x)
       x = self.max_pool_1(x)
       x = self.conv2(x)
       x = self.max_pool_2(x)
       x = self.conv3(x)
       x = self.conv4(x)
       x = self.max_pool_3(x)
       x = self.conv5(x)
       x = self.conv6(x)
       x = self.max_pool_4(x)
       x = self.conv7(x)
       x = self.conv8(x)
       x = self.max_pool_5(x)
       x = torch.flatten(x,1)  #[B,C,H,W]从C开始flatten,B不用flatten,所以要加1
       x = self.fc1(x)
       x = self.fc2(x)
       out = self.fc3(x)
       return out
if __name__ == '__main__':
   x = torch.rand(128,1,224,224)
   net = MyNet()
   out = net(x)
   #print(out.shape)
   summary(net, (12,1,224,224))

输出结果:

图片中红色方块计算过程:

1:相关代码及计算过程(卷积层)

self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )

Conv2d_param= (3*3*512+1)*512=2,359,808(Conv2d-12代码同,故param同)

2:相关代码及计算过程

self.fc3 = nn.Linear(in_features=4096,out_features=1000)

Fc_fc_param=(4096+1)*1000=4,097,000


3 结语

以上为一般情况下参数量计算方法,当然还有很多细节与很多其他情况下的计算方法没有介绍,主要用来形容模型的大小程度,针对不同batch_size下param的不同,可以用于参考来选择更合适的batch_size。

相关实践学习
基于函数计算一键部署掌上游戏机
本场景介绍如何使用阿里云计算服务命令快速搭建一个掌上游戏机。
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
6天前
|
机器学习/深度学习 数据采集 自然语言处理
理解并应用机器学习算法:神经网络深度解析
【5月更文挑战第15天】本文深入解析了神经网络的基本原理和关键组成,包括神经元、层、权重、偏置及损失函数。介绍了神经网络在图像识别、NLP等领域的应用,并涵盖了从数据预处理、选择网络结构到训练与评估的实践流程。理解并掌握这些知识,有助于更好地运用神经网络解决实际问题。随着技术发展,神经网络未来潜力无限。
|
4天前
|
存储 安全 网络安全
云端防御战线:融合云计算与先进网络安全策略
【5月更文挑战第17天】 随着企业纷纷迁移至云平台,数据和服务的集中化带来了前所未有的便利性。然而,这种集中化也使得网络攻击的潜在影响范围和危害程度急剧扩大。本文探讨了在快速发展的云计算环境中,如何通过综合运用最新的网络安全技术和策略来保障信息资产的安全。重点讨论了多租户环境下的数据隔离问题、加密技术的应用、入侵检测系统的集成以及安全事件管理和响应机制。同时,分析了未来云计算安全领域可能面临的新挑战,并提出了相应的应对措施。
|
3天前
|
网络协议 Python
Python 网络编程实战:构建高效的网络应用
【5月更文挑战第18天】Python在数字化时代成为构建网络应用的热门语言,因其简洁的语法和强大功能。本文介绍了网络编程基础知识,包括TCP和UDP套接字,强调异步编程、数据压缩和连接池的关键作用。提供了一个简单的TCP服务器和客户端代码示例,并提及优化与改进方向,鼓励读者通过实践提升网络应用性能。
20 6
|
6天前
|
算法 网络架构
网络地址的相关计算(超详细,快来快来!)
网络地址的相关计算(超详细,快来快来!)
28 0
|
6天前
|
开发框架 网络协议 Java
【计算机网络】—— 网络应用通信基本原理
【计算机网络】—— 网络应用通信基本原理
12 0
|
6天前
|
存储 安全 网络安全
云端防御:融合云计算与网络安全的未来
【5月更文挑战第11天】 在数字化时代,数据是新的石油,而云计算则是提炼这种石油的超级工厂。随着企业和个人越来越依赖于云服务来存储和处理数据,网络安全的重要性也呈指数级增长。本文将探讨云计算与网络安全的交汇点,分析云服务模型中的安全挑战,并提出一系列创新策略和技术,用以增强信息安全。从身份验证到数据加密,再到入侵检测系统,我们将一探究竟如何在不牺牲性能的前提下,确保云环境的稳固和可信。此外,我们还将讨论未来的趋势和潜在的研究方向,以期打造一个更加安全、可靠的云计算生态系统。
11 0
|
6天前
|
存储 安全 算法
网络安全与信息安全:防范漏洞、应用加密技术与培养安全意识
【5月更文挑战第10天】在数字化时代,网络安全与信息安全已成为维护社会稳定、保障个人隐私和确保企业资产的关键。面对日益复杂的网络威胁,本文深入探讨了网络安全漏洞的成因与影响、加密技术的基本原理与应用,以及提升全民网络安全意识的必要性和方法。通过分析当前网络安全形势,提供了一系列针对性的技术解决方案和管理策略,旨在为读者构建一个全方位的网络安全防护体系。
13 1
|
6天前
|
安全
AC/DC电源模块在通信与网络设备中的应用的研究
AC/DC电源模块在通信与网络设备中的应用的研究
AC/DC电源模块在通信与网络设备中的应用的研究
|
6天前
BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究
BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究
BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究
|
6天前
|
监控 安全 算法
网络安全与信息安全:防范漏洞、应用加密技术及提升安全意识
【5月更文挑战第8天】 在数字化时代,网络安全与信息安全已成为我们不可忽视的问题。本文将深入探讨网络安全漏洞的产生原因及其危害,加密技术的种类和应用,以及提升个人和企业的安全意识的重要性。通过对这些方面的知识分享,旨在帮助读者更好地理解网络安全的重要性,提高防范意识,保护个人信息和数据安全。

热门文章

最新文章