一文搞懂 FFN / RNN / CNN 的参数量计算公式 !!

简介: 一文搞懂 FFN / RNN / CNN 的参数量计算公式 !!

前言

为什么我们需要了解计算深度学习模型中的参数数量?

  • 计算复杂性和资源需求:模型参数越多,通常需要的计算资源(如处理器时间和内存)也越多,了解参数数量有助于估计训练和推理过程中的资源需求。
  • 模型性能:容量越大的模型可以捕获更复杂的模式,但也容易过拟合,即在训练数据上表现良好但在未见过的数据上表现差,因此,了解参数数量有助于在模型复杂性和泛化能力之间取得平衡。
  • 内存需求:对于有限的硬件资源,如移动设备和嵌入式系统,了解参数数量有助于设计合适的模型结构。
  • 调优和优化:了解模型的参数数量有助于调优超参数,比如学习率和正则化项。

1、前置条件

为了详细说明,本文重点介绍三类网络训练参数的计算方式:

  • Feed-Forward Neural Network(FFN)
  • Recurrent Neural Network(RNN)
  • Convolutional Neural Network(CNN)

同时,本文将使用Keras的API构建模型,以方便模型设计和编写简洁的代码。首先导入相关的库函数:

from keras.layers import Input, Dense, SimpleRNN, LSTM, GRU, Conv2D
from keras.layers import Bidirectional
from keras.models import Model

使用上述库函数在建立模型后,通过调用 model.count_params() 来验证有多少参数用以训练。

2、前馈神经网络FFN

前馈神经网络相对比较简单,多个全连接层构成的网络结构,我们不妨假设:

  • i:输入维度
  • h:隐藏层大小
  • o:网络输出维度

那么一个隐藏层的参数的计算公式为:

num_params = (connections between layers + biases) in every layer
           = (i×h + h) + (h×o + o)

先来看个图例,如下:

观察上述图例中,我们知道 i=3,h=5,o=2,带入上述公式,得到的训练参数量为:

num_params = (3×5+5) + (5×2+2)
           = 32

我们用代码实现上述过程,如下:

input = Input((None, 3))
dense = Dense(5)(input)
output= Dense(2)(dense)
model = Model(input, output)
print(f"train params of the model is {model.count_params()}")

运行上述代码,得到结果如下:

3、循环神经网络RNN

前馈神经网络里相对简单,我们接下来分析循环神经网络的参数计算方式,这里假设:

  • g:一个单元中的FFN数量(一般来说,RNN结构中FFN数量为1,而GRU结构中FFN数量为3个,LSTM结构中FFN数量为4个)
  • h:隐藏单元的大小
  • i:输入大小

RNN中对于每个FFN,最开始输入状态和隐藏状态是concat在一起作为输入的,因此每个FFN具有 (h+i) x h + h 个参数。所以总的参数量的计算公式为:

num_params = g × [(h+i)×h + h]

我们来看以下LSTM的例子,含有2个隐藏单元,输入维度为3,图示如下:

观察上图,我们将 g=4,h=2,i=3 带入上式,得到上述LSTM的参数量为:

num_params = g × [(h+i)×h + h] 
           = 4 × [(2+3)×2 + 2] 
           = 48

我们用代码验证上述过程,如下:

input = Input((None, 3))
lstm  = LSTM(2)(input)
model = Model(input, lstm)
print(f"train params of the model is {model.count_params()}")

结果如下:

4、卷积神经网络CNN

对于卷积神经网络,我们主要观察卷积层,这里对每一层的卷积,我们假设:

  • i:输入特征图的通道数
  • f:滤波器的尺寸
  • o:输出的通道数(等于滤波器的个数)

则对应卷积层的参数量计算公式为:

num_params = weights + biases 
           = [i × (f×f) × o] + o

我们来看个例子,对灰度图像使用 2x2 滤波器,输出为3个通道,图示如下:

观察上图,我们知道 i=1,f=2,o=3 带入上式,得到结果为:

num_params = [i × (f×f) × o] + o 
           = [1 × (2×2) × 3] + 3 
           = 15

我们用代码进行验证,如下所示:

input  = Input((None, None, 1))
conv2d = Conv2D(kernel_size=2, filters=3)(input)
model  = Model(input, conv2d)
print(f"train params of the model is {model.count_params()}")


得到结果如下:

5、复杂例子

由于卷积神经网络多在计算机视觉领域得到应用,我们再来看个稍微复杂点的例子,针对2个通道输入使用32x2 的卷积核进行卷积操作,图示如下:

观察上图,我们知道 i=2,f=2,o=3 带入上式,得到结果为:

num_params = [i × (f×f) × o] + o 
           = [2 × (2×2) × 3] + 3 
           = 27

我们用代码进行验证,如下所示:

input  = Input((None, None, 2))
conv2d = Conv2D(kernel_size=2, filters=3)(input)
model  = Model(input, conv2d)
print(f"train params of the model is {model.count_params()}")

得到结果如下:

参考: AI算法之道

目录
相关文章
|
6月前
|
机器学习/深度学习 自然语言处理 异构计算
Python深度学习面试:CNN、RNN与Transformer详解
【4月更文挑战第16天】本文介绍了深度学习面试中关于CNN、RNN和Transformer的常见问题和易错点,并提供了Python代码示例。理解这三种模型的基本组成、工作原理及其在图像识别、文本处理等任务中的应用是评估技术实力的关键。注意点包括:模型结构的混淆、过拟合的防治、输入序列长度处理、并行化训练以及模型解释性。掌握这些知识和技巧,将有助于在面试中展现优秀的深度学习能力。
207 11
|
5月前
|
机器学习/深度学习
【从零开始学习深度学习】23. CNN中的多通道输入及多通道输出计算方式及1X1卷积层介绍
【从零开始学习深度学习】23. CNN中的多通道输入及多通道输出计算方式及1X1卷积层介绍
【从零开始学习深度学习】23. CNN中的多通道输入及多通道输出计算方式及1X1卷积层介绍
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
算法金 | 秒懂 AI - 深度学习五大模型:RNN、CNN、Transformer、BERT、GPT 简介
**RNN**,1986年提出,用于序列数据,如语言模型和语音识别,但原始模型有梯度消失问题。**LSTM**和**GRU**通过门控解决了此问题。 **CNN**,1989年引入,擅长图像处理,卷积层和池化层提取特征,经典应用包括图像分类和物体检测,如LeNet-5。 **Transformer**,2017年由Google推出,自注意力机制实现并行计算,优化了NLP效率,如机器翻译。 **BERT**,2018年Google的双向预训练模型,通过掩码语言模型改进上下文理解,适用于问答和文本分类。
151 9
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
图神经网络是一类用于处理图结构数据的神经网络。与传统的深度学习模型(如卷积神经网络CNN和循环神经网络RNN)不同,
图神经网络是一类用于处理图结构数据的神经网络。与传统的深度学习模型(如卷积神经网络CNN和循环神经网络RNN)不同,
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
计算机视觉借助深度学习实现了革命性进步,从图像分类到复杂场景理解,深度学习模型如CNN、RNN重塑了领域边界。
【7月更文挑战第2天】计算机视觉借助深度学习实现了革命性进步,从图像分类到复杂场景理解,深度学习模型如CNN、RNN重塑了领域边界。AlexNet开启新时代,后续模型不断优化,推动对象检测、语义分割、图像生成等领域发展。尽管面临数据隐私、模型解释性等挑战,深度学习已广泛应用于安防、医疗、零售和农业,预示着更智能、高效的未来,同时也强调了技术创新、伦理考量的重要性。
61 1
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
一文介绍CNN/RNN/GAN/Transformer等架构 !!
一文介绍CNN/RNN/GAN/Transformer等架构 !!
202 5
|
6月前
|
机器学习/深度学习 算法 TensorFlow
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
|
6月前
|
机器学习/深度学习 自然语言处理 并行计算
神经网络结构——CNN、RNN、LSTM、Transformer !!
神经网络结构——CNN、RNN、LSTM、Transformer !!
294 0
|
6月前
|
机器学习/深度学习 存储 人工智能
存内计算芯片研究进展及应用—以基于NorFlash的卷积神经网络量化及部署研究突出存内计算特性
存内计算芯片研究进展及应用—以基于NorFlash的卷积神经网络量化及部署研究突出存内计算特性
375 3
|
6月前
|
机器学习/深度学习 人工智能 算法
深度学习及CNN、RNN、GAN等神经网络简介(图文解释 超详细)
深度学习及CNN、RNN、GAN等神经网络简介(图文解释 超详细)
599 1