如何搭建VGG网络,实现Mnist数据集的图像分类

简介: 如何搭建VGG网络,实现Mnist数据集的图像分类

1 问题

如何搭建VGG网络,实现Mnist数据集的图像分类?


2 方法

步骤:

  1. 首先导包
    Import torch
    from torch import nn
  2. VGG11由8个卷积,三个全连接组成,注意池化只改变特征图大小,不改变通道数
    class MyNet(nn.Module):
       def __init__(self) -> None:
           super().__init__()
           #(1)conv3-64
           self.conv1 = nn.Conv2d(
               in_channels=3,
               out_channels=64,
               kernel_size=3,
               stride=1,
               padding=1 #! 不改变特征图的大小
           )
           #! 池化只改变特征图大小,不改变通道数
           self.max_pool_1 = nn.MaxPool2d(2)
           #(2)conv3-128
           self.conv2 = nn.Conv2d(
               in_channels=64,
               out_channels=128,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_2 = nn.MaxPool2d(2)
           #(3) conv3-256,conv3-256
           self.conv3_1 = nn.Conv2d(
               in_channels=128,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1)
           self.conv3_2 = nn.Conv2d(
               in_channels=256,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_3 = nn.MaxPool2d(2)
           #(4)conv3-512,conv3-512
           self.conv4_1 = nn.Conv2d(
               in_channels=256,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv4_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_4 = nn.MaxPool2d(2)
           #(5)conv3-512,conv3-512
           self.conv5_1 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv5_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_5 = nn.MaxPool2d(2)
           #(6)
           self.fc1 = nn.Linear(25088,4096)
           self.fc2 = nn.Linear(4096,4096)
           self.fc3 = nn.Linear(4096,1000)
       def forward(self,x):
           x = self.conv1(x)
           print(x.shape)
           x = self.max_pool_1(x)
           print(x.shape)
           x = self.conv2(x)
           print(x.shape)
           x = self.max_pool_2(x)
           print(x.shape)
           x = self.conv3_1(x)
           print(x.shape)
           x = self.conv3_2(x)
           print(x.shape)
           x = self.max_pool_3(x)
           print(x.shape)
           x = self.conv4_1(x)
           print(x.shape)
           x = self.conv4_2(x)
           print(x.shape)
           x = self.max_pool_4(x)
           print(x.shape)
           x = self.conv5_1(x)
           print(x.shape)
           x = self.conv5_2(x)
           print(x.shape)
           x = self.max_pool_5(x)
           print(x.shape)
           x = torch.flatten(x,1)
           print(x.shape)
           x = self.fc1(x)
           print(x.shape)
           x = self.fc2(x)
           print(x.shape)
           out = self.fc3(x)
           return out
  3. 给定x查看最后结果
x = torch.rand(128,3,224,224)
net = MyNet()
out = net(x)
print(out.shape)
#torch.Size([128, 1000])


3 结语

  通过本周学习让我学会了VGG11网络,从实验中我遇到的容易出错的地方是卷积的in_features和out_features容易出错,尺寸不对的时候就会报错,在多个卷积的情况下尤其需要注意,第二点容易出错的地方是卷积以及池化所有结束后,一定要使用torch.flatten进行拉伸,第三点容易出错的地方是fc1的in_features,这个我通过使用断点的方法,得到fc1前一步的size值,从而得到in_features的值,从中收获颇深。

目录
相关文章
|
11天前
|
算法 前端开发 数据挖掘
【类脑智能】脑网络通信模型分类及量化指标(附思维导图)
本文概述了脑网络通信模型的分类、算法原理及量化指标,介绍了扩散过程、路由协议和参数模型三种通信模型,并详细讨论了它们的性能指标、优缺点以及在脑网络研究中的应用,同时提供了思维导图以帮助理解这些概念。
13 3
【类脑智能】脑网络通信模型分类及量化指标(附思维导图)
|
1天前
|
机器学习/深度学习 人工智能 编解码
【神经网络】基于对抗神经网络的图像生成是如何实现的?
对抗神经网络,尤其是生成对抗网络(GAN),在图像生成领域扮演着重要角色。它们通过一个有趣的概念——对抗训练——来实现图像的生成。以下将深入探讨GAN是如何实现基于对抗神经网络的图像生成的
7 3
|
17天前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API三种搭建神经网络的方式及以mnist举例实现
使用Keras API构建神经网络的三种方法:使用Sequential模型、使用函数式API以及通过继承Model类来自定义模型,并提供了基于MNIST数据集的示例代码。
29 12
|
17天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+Keras】keras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
如何使用TensorFlow和Keras实现条件生成对抗网络(CGAN)并以MNIST和Fashion MNIST数据集为例进行演示。
23 3
|
4天前
|
机器学习/深度学习 监控 数据可视化
|
5天前
|
机器学习/深度学习 编解码 Android开发
MATLAB Mobile - 使用预训练网络对手机拍摄的图像进行分类
MATLAB Mobile - 使用预训练网络对手机拍摄的图像进行分类
17 0
|
16天前
|
机器学习/深度学习 数据可视化 算法框架/工具
【深度学习】Generative Adversarial Networks ,GAN生成对抗网络分类
文章概述了生成对抗网络(GANs)的不同变体,并对几种经典GAN模型进行了简介,包括它们的结构特点和应用场景。此外,文章还提供了一个GitHub项目链接,该项目汇总了使用Keras实现的各种GAN模型的代码。
28 0
|
2月前
|
算法 定位技术 网络架构
网络的分类与性能指标
可以分为广域网(WAN)、城域网(MAN)、局域网(LAN)、个人区域网(PAN)。
42 4
|
2月前
|
机器学习/深度学习 自然语言处理 搜索推荐
深度学习之分类网络
深度学习的分类网络(Classification Networks)是用于将输入数据分配到预定义类别的神经网络。它们广泛应用于图像分类、文本分类、语音识别等任务。以下是对深度学习分类网络的详细介绍,包括其基本概念、主要架构、常见模型、应用场景、优缺点及未来发展方向。
84 4
|
2月前
|
机器学习/深度学习 计算机视觉 网络架构
是VGG网络的主要特点和架构描述
是VGG网络的主要特点和架构描述:
31 1