使用PyTorch构建卷积神经网络(CNN)源码(详细步骤讲解+注释版) 01 手写数字识别

简介: 在使用PyTorch构建GAN生成对抗网络一文中,我们使用GAN构建了一个可以生成人脸图像的模型。但尽管是较为简单的模型,仍占用了1G左右的GPU内存,因此需要探索更加节约资源的方式。

1 卷积神经网络(CNN)简介


在使用PyTorch构建GAN生成对抗网络一文中,我们使用GAN构建了一个可以生成人脸图像的模型。但尽管是较为简单的模型,仍占用了1G左右的GPU内存,因此需要探索更加节约资源的方式。


卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,主要应用于图像处理、语音识别等领域。它的主要思想是通过卷积操作对输入图像的特征进行提取,再通过多层网络对特征进行分类和判断。


CNN的网络结构通常由卷积层、池化层和全连接层组成。卷积层的作用是对输入图像的特征进行提取,池化层的作用是减少数据的维度,以提高计算效率;全连接层则用于对特征进行分类和判断。


CNN可以通过训练学习到输入图像的特征表示,从而可以在未知图像上进行分类、识别等任务。它已经成为计算机视觉领域的重要技术,在诸多应用中取得了良好的效果。


6698439c66f74b9d8dd9868088162c02.png



2 从普通BP到CNN的网路结构转变


以前面建立好的手写数字分类器为例,(使用PyTorch构建神经网络构建手写数字分类器)在模型结构定义中,需要对神经网络层做出相应的修改:

self.model = nn.Sequential(
            # expand 1 to 10 filters
            nn.Conv2d(1, 10, kernel_size=5, stride=2),
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(10),
            # 10 filters to 10 filters
            nn.Conv2d(10, 10, kernel_size=3, stride=2),
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(10),
            View(250),
            nn.Linear(250, 10),
            nn.Sigmoid()
        )


更新后的神经网络架构如下:


第一个卷积层:把1个通道的输入图像扩展为10个通道,使用5x5的卷积核,步长为2。

第二个卷积层:10个通道的输入图像不变,使用3x3的卷积核,步长为2。

第一个全连接层:把250个节点的一维向量映射到10个节点。

其中用到的函数的含义:

4. Conv2d:对由一个或多个输入平面组成的输入信号进行二维卷积。第1个参数是输入参数,对于黑白图像,输入的通道数即为1。第2个参数是输出通道的数量。在上面的代码中,我们创建了10个卷积核,从而生成10个特征图。kernel_size函数代表了卷积核的大小,使用的是5×5的卷积核。stride是卷积核移动时的大小。该数值小于卷积核大小时,说明卷积核所覆盖的区域有重叠。

5. LeakyReLU:非线性激活函数,常用于生成对抗网络。

6. BatchNorm2d:批量归一化,用于提高网络的稳定性和收敛速度。

7. View:将多维张量展平为一维向量。(自定义函数,详见完整代码)

8. Sigmoid:S形函数,用于二分类问题的输出。


对于一个28*28像素的图片,第一步卷积之后将会生成一个12*12像素的图片(计算方式:共走了 28 − 5 2 \frac{28-5}{2} 步)。第二步卷积之后将会生成一个5*5像素的图片(计算方式:共走了 12 − 3 2

3 从普通BP到CNN的辅助修改


在网络结构中用到了View函数,在上面的参考博文中并未涉及这部分代码,因此把这给你功能进行补充。(与人脸识别篇代码中的View完全相同)


class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape,
    def forward(self, x):
        return x.view(*self.shape)


此外,修改后的CNN网络结构,其传入的图片应将其修改为4D数据。因此在模型训练时,将传入的数据进行变形。


start_time = time.perf_counter()  # 计时开始
C = Classifier()
epochs = 3
for i in range(epochs):
    print('training epoch', i+1, 'of', epochs)
    for label, image_data_tensor, target_tensor in mnist_dataset:
        C.train(image_data_tensor.view(1, 1, 28, 28), target_tensor)


注:上面两个VIEW并不相同,一个是我们自行定义用于分类器类使用的函数,一个是torch的自带功能。

除此之外代码均可保持不变,这部分的原始代码可在此找到到或文末留言申请。


4 模型评估




4f2acc3207034c2195a355fd90c3373a.png


在训练初期,可以看到模型的损失呈现迅速下降。下面使用测试集对模型准确率进行评价:

a9c7cc4150e14a31a6ea079d7fae66fa.png




使用一张图片来查看模型的生成。此处我们分别选择了一张数字0和数字6,可以发现与BP模型相比,CNN模型对结果变得更有信心了。




76f9e93264fc424eab8fd3025cf1a738.png



3ae3d35c76b548668d04a0d5438dd0c7.png




ddca2505a51d4f2f84f0140e44967edf.png


f23c0471b0824702afc2721ead777f8c.png





相关文章
|
6月前
|
前端开发 JavaScript 开发者
JavaScript:构建动态网络的引擎
JavaScript:构建动态网络的引擎
|
7月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
1144 56
|
8月前
|
机器学习/深度学习 算法 量子技术
GQNN框架:让Python开发者轻松构建量子神经网络
为降低量子神经网络的研发门槛并提升其实用性,本文介绍一个名为GQNN(Generalized Quantum Neural Network)的Python开发框架。
184 4
GQNN框架:让Python开发者轻松构建量子神经网络
|
11月前
|
边缘计算 安全 算法
阿里云CDN:构建全球化智能加速网络的数字高速公路
阿里云CDN构建全球化智能加速网络,拥有2800多个边缘节点覆盖67个国家,实现毫秒级网络延迟。其三级节点拓扑结构与智能路由系统,结合流量预测模型,确保高命中率。全栈式加速技术包括QUIC协议优化和Brotli压缩算法,保障安全与性能。五层防御机制有效抵御攻击,行业解决方案涵盖视频、物联网及游戏等领域,支持新兴AR/VR与元宇宙需求,持续推动数字内容分发技术边界。
727 13
|
6月前
|
人工智能 监控 数据可视化
如何破解AI推理延迟难题:构建敏捷多云算力网络
本文探讨了AI企业在突破算力瓶颈后,如何构建高效、稳定的网络架构以支撑AI产品化落地。文章分析了典型AI IT架构的四个层次——流量接入层、调度决策层、推理服务层和训练算力层,并深入解析了AI架构对网络提出的三大核心挑战:跨云互联、逻辑隔离与业务识别、网络可视化与QoS控制。最终提出了一站式网络解决方案,助力AI企业实现多云调度、业务融合承载与精细化流量管理,推动AI服务高效、稳定交付。
|
5月前
|
机器学习/深度学习 分布式计算 Java
Java与图神经网络:构建企业级知识图谱与智能推理系统
图神经网络(GNN)作为处理非欧几里得数据的前沿技术,正成为企业知识管理和智能推理的核心引擎。本文深入探讨如何在Java生态中构建基于GNN的知识图谱系统,涵盖从图数据建模、GNN模型集成、分布式图计算到实时推理的全流程。通过具体的代码实现和架构设计,展示如何将先进的图神经网络技术融入传统Java企业应用,为构建下一代智能决策系统提供完整解决方案。
512 0
|
6月前
|
机器学习/深度学习 算法 搜索推荐
从零开始构建图注意力网络:GAT算法原理与数值实现详解
本文详细解析了图注意力网络(GAT)的算法原理和实现过程。GAT通过引入注意力机制解决了图卷积网络(GCN)中所有邻居节点贡献相等的局限性,让模型能够自动学习不同邻居的重要性权重。
1073 0
从零开始构建图注意力网络:GAT算法原理与数值实现详解
|
6月前
|
机器学习/深度学习 传感器 数据采集
基于贝叶斯优化CNN-LSTM混合神经网络预测(Matlab代码实现)
基于贝叶斯优化CNN-LSTM混合神经网络预测(Matlab代码实现)
932 0
|
6月前
|
机器学习/深度学习 传感器 数据采集
【故障识别】基于CNN-SVM卷积神经网络结合支持向量机的数据分类预测研究(Matlab代码实现)
【故障识别】基于CNN-SVM卷积神经网络结合支持向量机的数据分类预测研究(Matlab代码实现)
411 0
|
8月前
|
监控 安全 Go
使用Go语言构建网络IP层安全防护
在Go语言中构建网络IP层安全防护是一项需求明确的任务,考虑到高性能、并发和跨平台的优势,Go是构建此类安全系统的合适选择。通过紧密遵循上述步骤并结合最佳实践,可以构建一个强大的网络防护系统,以保障数字环境的安全完整。
175 12

热门文章

最新文章

推荐镜像

更多