如何搭建VGG网络,实现Mnist数据集的图像分类

简介: 如何搭建VGG网络,实现Mnist数据集的图像分类

1 问题

如何搭建VGG网络,实现Mnist数据集的图像分类?


2 方法

步骤:

  1. 首先导包
    Import torch
    from torch import nn
  2. VGG11由8个卷积,三个全连接组成,注意池化只改变特征图大小,不改变通道数
    class MyNet(nn.Module):
       def __init__(self) -> None:
           super().__init__()
           #(1)conv3-64
           self.conv1 = nn.Conv2d(
               in_channels=3,
               out_channels=64,
               kernel_size=3,
               stride=1,
               padding=1 #! 不改变特征图的大小
           )
           #! 池化只改变特征图大小,不改变通道数
           self.max_pool_1 = nn.MaxPool2d(2)
           #(2)conv3-128
           self.conv2 = nn.Conv2d(
               in_channels=64,
               out_channels=128,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_2 = nn.MaxPool2d(2)
           #(3) conv3-256,conv3-256
           self.conv3_1 = nn.Conv2d(
               in_channels=128,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1)
           self.conv3_2 = nn.Conv2d(
               in_channels=256,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_3 = nn.MaxPool2d(2)
           #(4)conv3-512,conv3-512
           self.conv4_1 = nn.Conv2d(
               in_channels=256,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv4_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_4 = nn.MaxPool2d(2)
           #(5)conv3-512,conv3-512
           self.conv5_1 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv5_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_5 = nn.MaxPool2d(2)
           #(6)
           self.fc1 = nn.Linear(25088,4096)
           self.fc2 = nn.Linear(4096,4096)
           self.fc3 = nn.Linear(4096,1000)
       def forward(self,x):
           x = self.conv1(x)
           print(x.shape)
           x = self.max_pool_1(x)
           print(x.shape)
           x = self.conv2(x)
           print(x.shape)
           x = self.max_pool_2(x)
           print(x.shape)
           x = self.conv3_1(x)
           print(x.shape)
           x = self.conv3_2(x)
           print(x.shape)
           x = self.max_pool_3(x)
           print(x.shape)
           x = self.conv4_1(x)
           print(x.shape)
           x = self.conv4_2(x)
           print(x.shape)
           x = self.max_pool_4(x)
           print(x.shape)
           x = self.conv5_1(x)
           print(x.shape)
           x = self.conv5_2(x)
           print(x.shape)
           x = self.max_pool_5(x)
           print(x.shape)
           x = torch.flatten(x,1)
           print(x.shape)
           x = self.fc1(x)
           print(x.shape)
           x = self.fc2(x)
           print(x.shape)
           out = self.fc3(x)
           return out
  3. 给定x查看最后结果
x = torch.rand(128,3,224,224)
net = MyNet()
out = net(x)
print(out.shape)
#torch.Size([128, 1000])


3 结语

  通过本周学习让我学会了VGG11网络,从实验中我遇到的容易出错的地方是卷积的in_features和out_features容易出错,尺寸不对的时候就会报错,在多个卷积的情况下尤其需要注意,第二点容易出错的地方是卷积以及池化所有结束后,一定要使用torch.flatten进行拉伸,第三点容易出错的地方是fc1的in_features,这个我通过使用断点的方法,得到fc1前一步的size值,从而得到in_features的值,从中收获颇深。

目录
相关文章
|
20天前
|
机器学习/深度学习 数据可视化 测试技术
YOLO11实战:新颖的多尺度卷积注意力(MSCA)加在网络不同位置的涨点情况 | 创新点如何在自己数据集上高效涨点,解决不涨点掉点等问题
本文探讨了创新点在自定义数据集上表现不稳定的问题,分析了不同数据集和网络位置对创新效果的影响。通过在YOLO11的不同位置引入MSCAAttention模块,展示了三种不同的改进方案及其效果。实验结果显示,改进方案在mAP50指标上分别提升了至0.788、0.792和0.775。建议多尝试不同配置,找到最适合特定数据集的解决方案。
165 0
|
15天前
|
网络协议
计算机网络的分类
【10月更文挑战第11天】 计算机网络可按覆盖范围(局域网、城域网、广域网)、传输技术(有线、无线)、拓扑结构(星型、总线型、环型、网状型)、使用者(公用、专用)、交换方式(电路交换、分组交换)和服务类型(面向连接、无连接)等多种方式进行分类,每种分类方式揭示了网络的不同特性和应用场景。
|
12天前
|
机器学习/深度学习 Serverless 索引
分类网络中one-hot的作用
在分类任务中,使用神经网络时,通常需要将类别标签转换为一种合适的输入格式。这时候,one-hot编码(one-hot encoding)是一种常见且有效的方法。one-hot编码将类别标签表示为向量形式,其中只有一个元素为1,其他元素为0。
25 3
|
2月前
|
机器学习/深度学习 人工智能 算法
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集("体育类", "财经类", "房产类", "家居类", "教育类", "科技类", "时尚类", "时政类", "游戏类", "娱乐类"),然后基于TensorFlow搭建CNN卷积神经网络算法模型。通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型,并保存为本地的h5格式。然后使用Django开发Web网页端操作界面,实现用户上传一段文本识别其所属的类别。
79 1
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
|
26天前
|
机器学习/深度学习 PyTorch 算法框架/工具
深度学习入门案例:运用神经网络实现价格分类
深度学习入门案例:运用神经网络实现价格分类
|
14天前
|
存储 分布式计算 负载均衡
|
14天前
|
安全 区块链 数据库
|
26天前
|
机器学习/深度学习 PyTorch API
深度学习入门:卷积神经网络 | CNN概述,图像基础知识,卷积层,池化层(超详解!!!)
深度学习入门:卷积神经网络 | CNN概述,图像基础知识,卷积层,池化层(超详解!!!)
|
3天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【10月更文挑战第23天】在数字时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将探讨网络安全漏洞、加密技术和安全意识等方面的内容,以帮助读者更好地了解如何保护自己的网络安全。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,我们将为读者提供一些实用的建议和技巧,以增强他们的网络安全防护能力。
|
1天前
|
SQL 存储 安全
网络安全与信息安全:防范漏洞、加密技术及安全意识
随着互联网的快速发展,网络安全和信息安全问题日益凸显。本文将探讨网络安全漏洞的类型及其影响、加密技术的应用以及提高个人和组织的安全意识的重要性。通过深入了解这些关键要素,我们可以更好地保护自己的数字资产免受网络攻击的威胁。