【图像分类】TensorFlow2.7版本搭建NIN网络

简介: 【图像分类】TensorFlow2.7版本搭建NIN网络

NIN网络结构

注解:这里为了简单起见,只是模拟NIN网络结构,本代码只是采用3个mlpconv层和最终的全局平均池化输出层,每个mlpconv层中包含了3个1*1卷积层

mlpconv层

1*1卷积只是会改变通道维数并不会改变feature map的大小,它可以变向起到一个通道交叉全连接的作用

self.mlpconv1 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1)]
        )

全局平均池化层

GlobalAveragePooling会将每个feature map所有的值相加取均值然后将这个实数作为该通道的特征值,NIN网络结构采用全局平均池化代替传统输出层使用MLP结构,这样有效防止过拟合,如果我们的任务存在1000个分类,那么我们最终的输出层的feature map的个数也为1000,然后对其进行全局平均池化,每个feature map代表一个类别,会形成一个1*1*1000的特征图,也就是一个维度为1000的特征向量,然后进行softmax操作

self.global_average_pool = GlobalAveragePooling2D()
"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2021/1/12
 * 时间: 22:35
 * 描述: 作者原文中的手写数据集是32*32,这里mnist是28*28,所以在训练前修改了图像尺寸
        还有一种解决方式就是在第一个卷积层使用padding='same'进行填充,这样就保证了使用第一个卷积层后尺寸为28*28
        之后仍可正常进行
"""
import tensorflow as tf
from keras import Sequential
from tensorflow.keras.layers import *
class NIN(tf.keras.Model):
    def __init__(self, output_dim=10):
        super(NIN, self).__init__()
        self.mlpconv1 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1)]
        )
        self.mlpconv2 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1)]
        )
        self.mlpconv3 = Sequential([
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=3,
                   kernel_size=1),
            ReLU(),
            Conv2D(filters=output_dim,
                   kernel_size=1)]
        )
        self.global_average_pool = GlobalAveragePooling2D()
    def call(self, inputs):
        x = self.mlpconv1(inputs)
        x = self.mlpconv2(x)
        x = self.mlpconv3(x)
        x = self.global_average_pool(x)
        x = Softmax()(x)
        return x

调用模型,训练mnist数据集

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2021/1/12
 * 时间: 22:20
 * 描述:
"""
import tensorflow as tf
from tensorflow.keras import Input
# step1:加载数据集
import model
import model_sequential
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# step2:将图像归一化
train_images, test_images = train_images / 255.0, test_images / 255.0
# step3:将图像的维度变为(60000,28,28,1)
train_images = tf.expand_dims(train_images, axis=3)
test_images = tf.expand_dims(test_images, axis=3)
# step5:导入模型
# history = LeNet5()
history = model.NIN(10)
# 让模型知道输入数据的形式
history.build(input_shape=(1, 28, 28, 1))
# 结局Output Shape为 multiple
history.call(Input(shape=(28, 28, 1)))
history.summary()
# step6:编译模型
history.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# 权重保存路径
checkpoint_path = "./weight/cp.ckpt"
# 回调函数,用户保存权重
save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   monitor='val_loss',
                                                   verbose=1)
# step7:训练模型
history = history.fit(train_images,
                      train_labels,
                      epochs=10,
                      batch_size=32,
                      validation_data=(test_images, test_labels),
                      callbacks=[save_callback])


目录
相关文章
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
PYTHON TENSORFLOW 2二维卷积神经网络CNN对图像物体识别混淆矩阵评估|数据分享
PYTHON TENSORFLOW 2二维卷积神经网络CNN对图像物体识别混淆矩阵评估|数据分享
|
7天前
|
机器学习/深度学习 人工智能 算法
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
海洋生物识别系统。以Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经网络算法,通过对22种常见的海洋生物('蛤蜊', '珊瑚', '螃蟹', '海豚', '鳗鱼', '水母', '龙虾', '海蛞蝓', '章鱼', '水獭', '企鹅', '河豚', '魔鬼鱼', '海胆', '海马', '海豹', '鲨鱼', '虾', '鱿鱼', '海星', '海龟', '鲸鱼')数据集进行训练,得到一个识别精度较高的模型文件,然后使用Django开发一个Web网页平台操作界面,实现用户上传一张海洋生物图片识别其名称。
86 7
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
|
7天前
|
机器学习/深度学习 人工智能 算法
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
昆虫识别系统,使用Python作为主要开发语言。通过TensorFlow搭建ResNet50卷积神经网络算法(CNN)模型。通过对10种常见的昆虫图片数据集('蜜蜂', '甲虫', '蝴蝶', '蝉', '蜻蜓', '蚱蜢', '蛾', '蝎子', '蜗牛', '蜘蛛')进行训练,得到一个识别精度较高的H5格式模型文件,然后使用Django搭建Web网页端可视化操作界面,实现用户上传一张昆虫图片识别其名称。
130 7
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
|
8天前
|
机器学习/深度学习 人工智能 算法
【球类识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+TensorFlow
球类识别系统,本系统使用Python作为主要编程语言,基于TensorFlow搭建ResNet50卷积神经网络算法模型,通过收集 '美式足球', '棒球', '篮球', '台球', '保龄球', '板球', '足球', '高尔夫球', '曲棍球', '冰球', '橄榄球', '羽毛球', '乒乓球', '网球', '排球'等15种常见的球类图像作为数据集,然后进行训练,最终得到一个识别精度较高的模型文件。再使用Django开发Web网页端可视化界面平台,实现用户上传一张球类图片识别其名称。
101 7
【球类识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+TensorFlow
|
27天前
|
机器学习/深度学习 算法 TensorFlow
【图像识别】谷物识别系统Python+人工智能深度学习+TensorFlow+卷积算法网络模型+图像识别
谷物识别系统,本系统使用Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经算法网络模型,通过对11种谷物图片数据集('大米', '小米', '燕麦', '玉米渣', '红豆', '绿豆', '花生仁', '荞麦', '黄豆', '黑米', '黑豆')进行训练,得到一个进度较高的H5格式的模型文件。然后使用Django框架搭建了一个Web网页端可视化操作界面。实现用户上传一张图片识别其名称。
67 0
【图像识别】谷物识别系统Python+人工智能深度学习+TensorFlow+卷积算法网络模型+图像识别
|
2月前
|
机器学习/深度学习 算法 TensorFlow
Python深度学习基于Tensorflow(6)神经网络基础
Python深度学习基于Tensorflow(6)神经网络基础
26 2
Python深度学习基于Tensorflow(6)神经网络基础
|
2月前
|
机器学习/深度学习 人工智能 算法
食物识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
食物识别系统采用TensorFlow的ResNet50模型,训练了包含11类食物的数据集,生成高精度H5模型。系统整合Django框架,提供网页平台,用户可上传图片进行食物识别。效果图片展示成功识别各类食物。[查看演示视频、代码及安装指南](https://www.yuque.com/ziwu/yygu3z/yhd6a7vai4o9iuys?singleDoc#)。项目利用深度学习的卷积神经网络(CNN),其局部感受野和权重共享机制适于图像识别,广泛应用于医疗图像分析等领域。示例代码展示了一个使用TensorFlow训练的简单CNN模型,用于MNIST手写数字识别。
70 3
|
2月前
|
机器学习/深度学习 算法
基于深度学习网络的十二生肖图像分类matlab仿真
该内容是关于使用GoogLeNet算法进行十二生肖图像分类的总结。在MATLAB2022a环境下,GoogLeNet通过Inception模块学习高层语义特征,处理不同尺寸的输入。核心程序展示了验证集上部分图像的预测标签和置信度,以4x4网格显示16张图像,每张附带预测类别和概率。
|
2月前
|
机器学习/深度学习 人工智能 算法
中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
66 0
|
2月前
|
网络协议 网络安全
在Windos Server 2016 版本配置网络参数和接入工作组网络
在Windos Server 2016 版本配置网络参数和接入工作组网络

热门文章

最新文章