【图像分类】TensorFlow2.7版本搭建NIN网络

简介: 【图像分类】TensorFlow2.7版本搭建NIN网络

NIN网络结构

注解:这里为了简单起见,只是模拟NIN网络结构,本代码只是采用3个mlpconv层和最终的全局平均池化输出层,每个mlpconv层中包含了3个1*1卷积层

mlpconv层

1*1卷积只是会改变通道维数并不会改变feature map的大小,它可以变向起到一个通道交叉全连接的作用

self.mlpconv1 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1)]
        )

全局平均池化层

GlobalAveragePooling会将每个feature map所有的值相加取均值然后将这个实数作为该通道的特征值,NIN网络结构采用全局平均池化代替传统输出层使用MLP结构,这样有效防止过拟合,如果我们的任务存在1000个分类,那么我们最终的输出层的feature map的个数也为1000,然后对其进行全局平均池化,每个feature map代表一个类别,会形成一个1*1*1000的特征图,也就是一个维度为1000的特征向量,然后进行softmax操作

self.global_average_pool = GlobalAveragePooling2D()
"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2021/1/12
 * 时间: 22:35
 * 描述: 作者原文中的手写数据集是32*32,这里mnist是28*28,所以在训练前修改了图像尺寸
        还有一种解决方式就是在第一个卷积层使用padding='same'进行填充,这样就保证了使用第一个卷积层后尺寸为28*28
        之后仍可正常进行
"""
import tensorflow as tf
from keras import Sequential
from tensorflow.keras.layers import *
class NIN(tf.keras.Model):
    def __init__(self, output_dim=10):
        super(NIN, self).__init__()
        self.mlpconv1 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1)]
        )
        self.mlpconv2 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1)]
        )
        self.mlpconv3 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=output_dim,
                   kernel_size=1)]
        )
        self.global_average_pool = GlobalAveragePooling2D()
    def call(self, inputs):
        x = self.mlpconv1(inputs)
        x = self.mlpconv2(x)
        x = self.mlpconv3(x)
        x = self.global_average_pool(x)
        x = Softmax()(x)
        return x

调用模型,训练mnist数据集

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2021/1/12
 * 时间: 22:20
 * 描述:
"""
import tensorflow as tf
from tensorflow.keras import Input
# step1:加载数据集
import model
import model_sequential
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# step2:将图像归一化
train_images, test_images = train_images / 255.0, test_images / 255.0
# step3:将图像的维度变为(60000,28,28,1)
train_images = tf.expand_dims(train_images, axis=3)
test_images = tf.expand_dims(test_images, axis=3)
# step5:导入模型
# history = LeNet5()
history = model.NIN(10)
# 让模型知道输入数据的形式
history.build(input_shape=(1, 28, 28, 1))
# 结局Output Shape为 multiple
history.call(Input(shape=(28, 28, 1)))
history.summary()
# step6:编译模型
history.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# 权重保存路径
checkpoint_path = "./weight/cp.ckpt"
# 回调函数,用户保存权重
save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   monitor='val_loss',
                                                   verbose=1)
# step7:训练模型
history = history.fit(train_images,
                      train_labels,
                      epochs=10,
                      batch_size=32,
                      validation_data=(test_images, test_labels),
                      callbacks=[save_callback])


目录
相关文章
|
1月前
|
网络协议
计算机网络的分类
【10月更文挑战第11天】 计算机网络可按覆盖范围(局域网、城域网、广域网)、传输技术(有线、无线)、拓扑结构(星型、总线型、环型、网状型)、使用者(公用、专用)、交换方式(电路交换、分组交换)和服务类型(面向连接、无连接)等多种方式进行分类,每种分类方式揭示了网络的不同特性和应用场景。
|
10天前
|
网络虚拟化
生成树协议(STP)及其演进版本RSTP和MSTP,旨在解决网络中的环路问题,提高网络的可靠性和稳定性
生成树协议(STP)及其演进版本RSTP和MSTP,旨在解决网络中的环路问题,提高网络的可靠性和稳定性。本文介绍了这三种协议的原理、特点及区别,并提供了思科和华为设备的命令示例,帮助读者更好地理解和应用这些协议。
27 4
|
8天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
36 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
12天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
38 3
|
24天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
71 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
28天前
|
机器学习/深度学习 数据可视化 TensorFlow
使用TensorFlow构建一个简单的图像分类模型
【10月更文挑战第18天】使用TensorFlow构建一个简单的图像分类模型
48 1
|
1月前
|
机器学习/深度学习 Serverless 索引
分类网络中one-hot的作用
在分类任务中,使用神经网络时,通常需要将类别标签转换为一种合适的输入格式。这时候,one-hot编码(one-hot encoding)是一种常见且有效的方法。one-hot编码将类别标签表示为向量形式,其中只有一个元素为1,其他元素为0。
37 3
|
13天前
|
机器学习/深度学习 人工智能 自动驾驶
深度学习的奇迹:如何用神经网络识别图像
【10月更文挑战第33天】在这篇文章中,我们将探索深度学习的奇妙世界,特别是卷积神经网络(CNN)在图像识别中的应用。我们将通过一个简单的代码示例,展示如何使用Python和Keras库构建一个能够识别手写数字的神经网络。这不仅是对深度学习概念的直观介绍,也是对技术实践的一次尝试。让我们一起踏上这段探索之旅,看看数据、模型和代码是如何交织在一起,创造出令人惊叹的结果。
25 0
|
1月前
|
机器学习/深度学习 SQL 数据采集
基于tensorflow、CNN网络识别花卉的种类(图像识别)
基于tensorflow、CNN网络识别花卉的种类(图像识别)
27 1
|
1月前
|
机器学习/深度学习 人工智能 算法
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
玉米病害识别系统,本系统使用Python作为主要开发语言,通过收集了8种常见的玉米叶部病害图片数据集('矮花叶病', '健康', '灰斑病一般', '灰斑病严重', '锈病一般', '锈病严重', '叶斑病一般', '叶斑病严重'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。再使用Django搭建Web网页操作平台,实现用户上传一张玉米病害图片识别其名称。
56 0
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练

热门文章

最新文章