tensorflow2.0图片分类实战---对fashion-mnist数据集分类

简介: tensorflow2.0图片分类实战---对fashion-mnist数据集分类

前言


其实写这篇博客的想法主要还是记载一些tf2.0常用api的用法以及如何简单快速的利用tf.keras搭建一个神经网络


1.首先讲讲tf.keras


有了它我们可以很轻松的搭建自己想搭建的网络模型,就像拼积木一样,一层一层的网络叠加起来。但是深层的网络会出现梯度消失等等问题,所以只是能搭建一个网络模型,对于模型的效果还需要一些其他知识方法来优化。对于fashion-mnist数据集的介绍可以看看下面的链接Github上fashion-mnist的介绍


2.再说说一般对于图像分类问题常用的优化方法


  • 1.图像数据的归一化(标准化):加快网络收敛,具体原理可以想象成同心圆沿着梯度到达圆心最快,而不正规的图形沿着梯度到达中心会很曲折


image.png


  • 2.数据特征增强:链接
  • 3.网络的超参数搜索:得到最好的模型参数,主要是网格搜索、随机搜索、遗传算法、启发式搜索
  • 4.dropout、earlystopping,正则化等方法的应用:通过添加遗忘层,正则化以及早停来防止模型过拟合


3.实现代码以及结果部分


#先导入一些常用库,后续用到再增加
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import sklearn
import os
import sys
#看一下版本,确认是2.0
print(tf.__version__)
复制代码


image.png


#使用keras自带的模块导入数据,并且切分训练集、验证集、测试集,对训练数据进行标准化处理
fashion_mnist=keras.datasets.fashion_mnist
(x_train_all,y_train_all),(x_test,y_test)=fashion_mnist.load_data()
print(x_train_all.shape)
print(y_train_all.shape)
print(x_test.shape)
print(y_test.shape)
#切分训练集和验证集
x_train,x_valid=x_train_all[5000:],x_train_all[:5000]
y_train,y_valid=y_train_all[5000:],y_train_all[:5000]
print(x_train.shape)
print(y_train.shape)
print(x_valid.shape)
print(y_valid.shape)
#标准化
from sklearn.preprocessing import StandardScaler
scaler=StandardScaler()
x_train_scaled=scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_valid_scaled=scaler.fit_transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_test_scaled=scaler.fit_transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
复制代码
#可视化一下图片以及对应的标签
#展示多张图片
def show_imgs(n_rows,n_cols,x_data,y_data,class_names):
    assert len(x_data)==len(y_data)#判断输入数据的信息是否对应一致
    assert n_rows*n_cols<=len(x_data)#保证不会出现数据量不够
    plt.figure(figsize=(n_cols*2,n_rows*1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index=n_cols*row+col   #得到当前展示图片的下标
            plt.subplot(n_rows,n_cols,index+1)
            plt.imshow(x_data[index],cmap="binary",interpolation="nearest")
            plt.axis("off")
            plt.title(class_names[y_data[index]])
    plt.show()
class_names=['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
show_imgs(5,5,x_train,y_train,class_names)
复制代码


image.png

#搭建网络模型
model=keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300,activation="relu"))
model.add(keras.layers.Dense(100,activation="relu"))
model.add(keras.layers.Dense(10,activation="softmax"))
model.compile(loss="sparse_categorical_crossentropy",optimizer="adam",metrics=["acc"])
model.summary()
复制代码


image.png


这里网络信息中params中的数字怎么来的呢? y=wx+b  然后根据矩阵相乘的规则从(None,784)到(None,300)中间的矩阵就是(784,300)然后偏置项b的大小是300,所以784300+300=235500,这是个小细节稍微提一下。


#训练,并且保存最好的模型、训练的记录以及使用早停防止过拟合
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('logs', current_time)
output_model=os.path.join(logdir,"fashionmnist_model.h5")
callbacks=[
    keras.callbacks.TensorBoard(log_dir=logdir),
    keras.callbacks.ModelCheckpoint(output_model,save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)
          ]
history=model.fit(x_train_scaled,y_train,epochs=30,validation_data=(x_valid_scaled,y_valid),callbacks=callbacks)
复制代码


image.png

之前我用自己命名的文件夹使用TensorBoard和ModelCheckpoint运行会出错,搜了一下好像是windows上的bug,上面的这是一种解决方法,然后打开tensorboard看一下。


image.png


最好的模型也保存为h5文件,方便调用


def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid()
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)
复制代码


这是自己绘制每次训练的变化情况,和上面的差不多


image.png


#最后在测试集上的准确率
loss,acc=model.evaluate(x_test_scaled,y_test,verbose=0)
print("在测试集上的损失为:",loss)
print("在测试集上的准确率为:",acc)
复制代码


image.png

#得到测试集上的预测标签,可视化和真实标签的区别
y_pred=model.predict(x_test_scaled)
predict = np.argmax(y_pred,axis=1) 
show_imgs(3,5,x_test,predict,class_names)
show_imgs(3,5,x_test,y_test,class_names)
复制代码


预测的结果

image.png

真实的结果


image.png


4.总结:


看了上面的例子,使用tf.keras搭建模型写法就是


model=keras.models.Sequential()
model.add(...)
model.add(...)
...
model.compile(...)
model.fit(...)
#当然也可以写成
model=keras.models.Sequential([
  ...
  ...
  ...
])
#这两者差别不大
#还有函数式的写法
inputs=...
hidden1=...(inputs)
....
#子类的写法
class ...:
  ...
复制代码


不过对于模型中的参数,比如损失函数的选择("sparse_categorical_crossentropy"与"categorical_crossentropy" 或者"binary_crossentropy")什么时候需要用到哪种损失函数最适合、每一层网络中的激活函数的选择、优化器的选择……都需要了解其中的含义才能在适当的场合使用,这里我没有给出使用超参数搜索得到最优模型参数的例子,下次应该会写一个关于超参数搜索的例子。

目录
相关文章
|
18天前
|
机器学习/深度学习 TensorFlow API
TensorFlow与Keras实战:构建深度学习模型
本文探讨了TensorFlow和其高级API Keras在深度学习中的应用。TensorFlow是Google开发的高性能开源框架,支持分布式计算,而Keras以其用户友好和模块化设计简化了神经网络构建。通过一个手写数字识别的实战案例,展示了如何使用Keras加载MNIST数据集、构建CNN模型、训练及评估模型,并进行预测。案例详述了数据预处理、模型构建、训练过程和预测新图像的步骤,为读者提供TensorFlow和Keras的基础实践指导。
153 59
|
2月前
|
数据可视化 TensorFlow 算法框架/工具
TensorFlow 实战(八)(4)
TensorFlow 实战(八)
25 1
|
2月前
|
TensorFlow API 算法框架/工具
TensorFlow 实战(八)(3)
TensorFlow 实战(八)
26 1
|
2月前
|
机器学习/深度学习 自然语言处理 TensorFlow
TensorFlow 实战(八)(5)
TensorFlow 实战(八)
25 0
|
2月前
|
并行计算 TensorFlow 算法框架/工具
TensorFlow 实战(八)(2)
TensorFlow 实战(八)
22 0
|
2月前
|
并行计算 Ubuntu TensorFlow
TensorFlow 实战(八)(1)
TensorFlow 实战(八)
23 0
|
2月前
|
API TensorFlow 算法框架/工具
TensorFlow 实战(七)(5)
TensorFlow 实战(七)
31 0
|
2月前
|
存储 TensorFlow 算法框架/工具
TensorFlow 实战(七)(4)
TensorFlow 实战(七)
29 0
|
2月前
|
存储 TensorFlow API
TensorFlow 实战(七)(3)
TensorFlow 实战(七)
29 0
|
2月前
|
机器学习/深度学习 数据可视化 TensorFlow
TensorFlow 实战(七)(2)
TensorFlow 实战(七)
24 0