【图像分类】TensorFlow2.7版本搭建NIN网络

简介: 注解:这里为了简单起见,只是模拟NIN网络结构,本代码只是采用3个mlpconv层和最终的全局平均池化输出层,每个mlpconv层中包含了3个1*1卷积层

NIN网络结构


注解:这里为了简单起见,只是模拟NIN网络结构,本代码只是采用3个mlpconv层和最终的全局平均池化输出层,每个mlpconv层中包含了3个1*1卷积层


mlpconv层
1*1卷积只是会改变通道维数并不会改变feature map的大小,它可以变向起到一个通道交叉全连接的作用


self.mlpconv1 = Sequential([

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1)]

       )


全局平均池化层
GlobalAveragePooling会将每个feature map所有的值相加取均值然后将这个实数作为该通道的特征值,NIN网络结构采用全局平均池化代替传统输出层使用MLP结构,这样有效防止过拟合,如果我们的任务存在1000个分类,那么我们最终的输出层的feature map的个数也为1000,然后对其进行全局平均池化,每个feature map代表一个类别,会形成一个1*1*1000的特征图,也就是一个维度为1000的特征向量,然后进行softmax操作


self.global_average_pool = GlobalAveragePooling2D()


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2021/1/12

* 时间: 22:35

* 描述: 作者原文中的手写数据集是32*32,这里mnist是28*28,所以在训练前修改了图像尺寸

       还有一种解决方式就是在第一个卷积层使用padding='same'进行填充,这样就保证了使用第一个卷积层后尺寸为28*28

       之后仍可正常进行

"""

import tensorflow as tf

from keras import Sequential

from tensorflow.keras.layers import *



class NIN(tf.keras.Model):

   def __init__(self, output_dim=10):

       super(NIN, self).__init__()

       self.mlpconv1 = Sequential([

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1)]

       )

       self.mlpconv2 = Sequential([

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1)]

       )

       self.mlpconv3 = Sequential([

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=3,

                  kernel_size=1),

           ReLU(),

           Conv2D(filters=output_dim,

                  kernel_size=1)]

       )

       self.global_average_pool = GlobalAveragePooling2D()


   def call(self, inputs):

       x = self.mlpconv1(inputs)

       x = self.mlpconv2(x)

       x = self.mlpconv3(x)

       x = self.global_average_pool(x)

       x = Softmax()(x)

       return x


调用模型,训练mnist数据集


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2021/1/12

* 时间: 22:20

* 描述:

"""


import tensorflow as tf

from tensorflow.keras import Input


# step1:加载数据集

import model

import model_sequential


(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()


# step2:将图像归一化

train_images, test_images = train_images / 255.0, test_images / 255.0


# step3:将图像的维度变为(60000,28,28,1)

train_images = tf.expand_dims(train_images, axis=3)

test_images = tf.expand_dims(test_images, axis=3)


# step5:导入模型

# history = LeNet5()

history = model.NIN(10)


# 让模型知道输入数据的形式

history.build(input_shape=(1, 28, 28, 1))


# 结局Output Shape为 multiple

history.call(Input(shape=(28, 28, 1)))

history.summary()


# step6:编译模型

history.compile(optimizer='adam',

               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),

               metrics=['accuracy'])


# 权重保存路径

checkpoint_path = "./weight/cp.ckpt"


# 回调函数,用户保存权重

save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,

                                                  save_best_only=True,

                                                  save_weights_only=True,

                                                  monitor='val_loss',

                                                  verbose=1)

# step7:训练模型

history = history.fit(train_images,

                     train_labels,

                     epochs=10,

                     batch_size=32,

                     validation_data=(test_images, test_labels),

                     callbacks=[save_callback])

目录
相关文章
|
网络协议 Linux 网络安全
使用tinc+quagga搭建个人SD-WAN网络
使用tinc+quagga搭建个人SD-WAN网络
2278 0
使用tinc+quagga搭建个人SD-WAN网络
|
Kubernetes Cloud Native 网络安全
【云原生-K8s】kubeadm搭建k8s集群1.25版本完整教程【docker、网络插件calico、中间层cri-docker】
【云原生-K8s】kubeadm搭建k8s集群1.25版本完整教程【docker、网络插件calico、中间层cri-docker】
2565 0
【云原生-K8s】kubeadm搭建k8s集群1.25版本完整教程【docker、网络插件calico、中间层cri-docker】
|
消息中间件 监控 NoSQL
ELK搭建(三):监控服务器CPU、网络、磁盘、内存指标
本期我们来讲解如何通过ELK+metricbeat来监控服务器/主机中的CPU、网络、磁盘、内存等指标变化。并绘制会数据看板来方便我们实时监控
584 0
ELK搭建(三):监控服务器CPU、网络、磁盘、内存指标
|
机器学习/深度学习 数据挖掘 PyTorch
|
消息中间件 NoSQL 前端开发
云计算搭建全部内容总结,保证可以搭建一个完整的云计算服务器,包括节点安装、实例的分配和网络的配置等内容
云计算搭建全部内容总结,保证可以搭建一个完整的云计算服务器,包括节点安装、实例的分配和网络的配置等内容
380 0
云计算搭建全部内容总结,保证可以搭建一个完整的云计算服务器,包括节点安装、实例的分配和网络的配置等内容
|
Web App开发 网络协议 安全
重学网络系列之(搭建Http实验环境)
前言 文本已收录至我的GitHub仓库,欢迎Star:github.com/bin39232820… 种一棵树最好的时间是十年前,其次是现在
208 0
|
弹性计算 网络协议 安全
ECS搭建企业级的专用网络隧道
现在很多企业都有移动办公的需求,希望出差员工,如同在公司内部一样可以访问公司的相关资源,很多企业直接购买商业VPN产品或者通过购买防火墙内的SCVPN授权方式满足此需求,在本章节将用开源软件OpenVPN实现此功能。
2467 0
ECS搭建企业级的专用网络隧道
|
机器学习/深度学习 PyTorch 算法框架/工具
从零搭建Pytorch模型教程(三)搭建Transformer网络
本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍。
从零搭建Pytorch模型教程(三)搭建Transformer网络
|
机器学习/深度学习 自然语言处理 PyTorch
搭建小型ViT网络构架进行分类任务(Pytorch)
搭建小型ViT网络构架进行分类任务(Pytorch)
858 0
|
JavaScript 前端开发
Webpack搭建ES6开发环境(部分摘自网络)
Webpack搭建ES6开发环境(部分摘自网络)
124 0
Webpack搭建ES6开发环境(部分摘自网络)

热门文章

最新文章