如何搭建VGG网络,实现Mnist数据集的图像分类

简介: 如何搭建VGG网络,实现Mnist数据集的图像分类

1 问题

如何搭建VGG网络,实现Mnist数据集的图像分类?


2 方法

步骤:

  1. 首先导包
    Import torch
    from torch import nn
  2. VGG11由8个卷积,三个全连接组成,注意池化只改变特征图大小,不改变通道数
    class MyNet(nn.Module):
       def __init__(self) -> None:
           super().__init__()
           #(1)conv3-64
           self.conv1 = nn.Conv2d(
               in_channels=3,
               out_channels=64,
               kernel_size=3,
               stride=1,
               padding=1 #! 不改变特征图的大小
           )
           #! 池化只改变特征图大小,不改变通道数
           self.max_pool_1 = nn.MaxPool2d(2)
           #(2)conv3-128
           self.conv2 = nn.Conv2d(
               in_channels=64,
               out_channels=128,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_2 = nn.MaxPool2d(2)
           #(3) conv3-256,conv3-256
           self.conv3_1 = nn.Conv2d(
               in_channels=128,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1)
           self.conv3_2 = nn.Conv2d(
               in_channels=256,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_3 = nn.MaxPool2d(2)
           #(4)conv3-512,conv3-512
           self.conv4_1 = nn.Conv2d(
               in_channels=256,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv4_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_4 = nn.MaxPool2d(2)
           #(5)conv3-512,conv3-512
           self.conv5_1 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv5_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_5 = nn.MaxPool2d(2)
           #(6)
           self.fc1 = nn.Linear(25088,4096)
           self.fc2 = nn.Linear(4096,4096)
           self.fc3 = nn.Linear(4096,1000)
       def forward(self,x):
           x = self.conv1(x)
           print(x.shape)
           x = self.max_pool_1(x)
           print(x.shape)
           x = self.conv2(x)
           print(x.shape)
           x = self.max_pool_2(x)
           print(x.shape)
           x = self.conv3_1(x)
           print(x.shape)
           x = self.conv3_2(x)
           print(x.shape)
           x = self.max_pool_3(x)
           print(x.shape)
           x = self.conv4_1(x)
           print(x.shape)
           x = self.conv4_2(x)
           print(x.shape)
           x = self.max_pool_4(x)
           print(x.shape)
           x = self.conv5_1(x)
           print(x.shape)
           x = self.conv5_2(x)
           print(x.shape)
           x = self.max_pool_5(x)
           print(x.shape)
           x = torch.flatten(x,1)
           print(x.shape)
           x = self.fc1(x)
           print(x.shape)
           x = self.fc2(x)
           print(x.shape)
           out = self.fc3(x)
           return out
  3. 给定x查看最后结果
x = torch.rand(128,3,224,224)
net = MyNet()
out = net(x)
print(out.shape)
#torch.Size([128, 1000])


3 结语

  通过本周学习让我学会了VGG11网络,从实验中我遇到的容易出错的地方是卷积的in_features和out_features容易出错,尺寸不对的时候就会报错,在多个卷积的情况下尤其需要注意,第二点容易出错的地方是卷积以及池化所有结束后,一定要使用torch.flatten进行拉伸,第三点容易出错的地方是fc1的in_features,这个我通过使用断点的方法,得到fc1前一步的size值,从而得到in_features的值,从中收获颇深。

目录
相关文章
|
2月前
|
网络协议
计算机网络的分类
【10月更文挑战第11天】 计算机网络可按覆盖范围(局域网、城域网、广域网)、传输技术(有线、无线)、拓扑结构(星型、总线型、环型、网状型)、使用者(公用、专用)、交换方式(电路交换、分组交换)和服务类型(面向连接、无连接)等多种方式进行分类,每种分类方式揭示了网络的不同特性和应用场景。
|
16天前
|
机器学习/深度学习 Serverless 索引
分类网络中one-hot编码的作用
在分类任务中,使用神经网络时,通常需要将类别标签转换为一种合适的输入格式。这时候,one-hot编码(one-hot encoding)是一种常见且有效的方法。one-hot编码将类别标签表示为向量形式,其中只有一个元素为1,其他元素为0。
20 2
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
56 3
|
2月前
|
机器学习/深度学习 Serverless 索引
分类网络中one-hot的作用
在分类任务中,使用神经网络时,通常需要将类别标签转换为一种合适的输入格式。这时候,one-hot编码(one-hot encoding)是一种常见且有效的方法。one-hot编码将类别标签表示为向量形式,其中只有一个元素为1,其他元素为0。
66 3
|
1月前
|
机器学习/深度学习 人工智能 自动驾驶
深度学习的奇迹:如何用神经网络识别图像
【10月更文挑战第33天】在这篇文章中,我们将探索深度学习的奇妙世界,特别是卷积神经网络(CNN)在图像识别中的应用。我们将通过一个简单的代码示例,展示如何使用Python和Keras库构建一个能够识别手写数字的神经网络。这不仅是对深度学习概念的直观介绍,也是对技术实践的一次尝试。让我们一起踏上这段探索之旅,看看数据、模型和代码是如何交织在一起,创造出令人惊叹的结果。
31 0
|
2月前
|
存储 分布式计算 负载均衡
|
2月前
|
安全 区块链 数据库
|
1天前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
34 17
|
12天前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。
|
13天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
36 10