四、深度学习基础:深度学习基础之手写Mnist数字识别

简介: 四、深度学习基础:深度学习基础之手写Mnist数字识别

手写数字识别

Mnist数据集是一个手写数字识别数据集,被称为深度学习界的“Hello World”。

在这里插入图片描述

Mnist数据集包含:

  • 训练集:60,000张28×28灰度图
  • 测试集:10,000张28×28灰度图

共有0~9这10个手写数字体类别。

导入必要的模块

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import datasets, Input, Model
from tensorflow.keras.layers import Flatten, Dense, Activation, BatchNormalization
from tensorflow.keras.initializers import TruncatedNormal

载入Mnist数据集

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)

维度(60000, 28, 28) —> (60000, 784)

x_train = x_train.reshape(-1, 28*28)
x_test = x_test.reshape(-1, 28*28)

归一化处理

x_train = x_train/255.0
x_test = x_test/255.0

模型搭建

用到的api:

全连接层tf.keras.layers.Dense

用到的参数:

  • units:输入整数,全连接层神经元个数。
  • activation:激活函数。

    可选项:

    • 'sigmoid':sigmoid激活函数

    • 'tanh':tanh激活函数

    • 'relu':relu激活函数

    • 'elu'或tf.keras.activations.elu(alpha=1.0):elu激活函数

    • 'selu':selu激活函数

    • 'swish': swish激活函数(tf2.2版本以上才有)

    • 'softmax': softmax函数

  • kernel_initializer:权重初始化,默认是'glorot_uniform'(即Xavier均匀初始化)。

    可选项:

    • 'RandomNormal'或tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05):正态分布采样,均值为0,标准差0.05

    • 'glorot_normal':正态分布采样,均值为0,方差为2 / (fan_in + fan_out)

    • 'glorot_uniform':均匀分布采样,范围[-limit, limit],limit = sqrt(6 / (fan_in + fan_out))

    • 'lecun_normal':正态分布采样,均值为0,方差为1 / fan_in

    • 'lecun_uniform':均匀分布采样,范围[-limit, limit],limit = sqrt(3 / fan_in)

    • 'he_normal':正态分布采样,均值为0,方差为2 / fan_in

    • 'he_uniform':均匀分布采样,范围[-limit, limit],limit = sqrt(6 / fan_in)

    fan_in是输入的神经元个数,fan_out是输出的神经元个数。

  • name:输入字符串,给该层设置一个名称。

输入层tf.keras.Input

用到的参数:

  • shape:输入层的大小。
  • name:输入字符串,给该层设置一个名称。

BN层tf.keras.layers.BatchNormalization

用到的参数:

  • axis:需要做批处理的维度,默认为-1,也就是输入数据的最后一个维度。
  • name:输入字符串,给该层设置一个名称。

模型设置tf.keras.Sequential.compile

用到的参数:

  • loss:损失函数,多分类任务一般使用"sparse_categorical_crossentropy",sparse表示先对标签做one hot编码。

构建50个隐层的神经网络,每层100个神经元。

# 输入层inputs
inputs = Input(shape=(28*28), name='input')

# 隐层dense
x = Dense(units=100, kernel_initializer='lecun_normal', name='dense_0')(inputs)
x = BatchNormalization(axis=-1, name='batchnormalization_0')(x)
x = Activation(activation='selu', name='activation_0')(x)

for i in range(49):
    x = Dense(units=100, kernel_initializer='lecun_normal', name='dense_'+str(i+1))(x)
    x = BatchNormalization(axis=-1, name='batchnormalization_'+str(i+1))(x)
    x = Activation(activation='selu', name='activation_'+str(i+1))(x)

# 输出层
outputs = Dense(units=10, activation='softmax', name='logit')(x)

# 设置模型的inputs和outputs
model = Model(inputs=inputs, outputs=outputs)

# 设置损失函数loss、优化器optimizer、评价标准metrics
model.compile(loss='sparse_categorical_crossentropy',
              optimizer="sgd", metrics=['accuracy'])

查看模型每层的参数量和输出的大小

model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           [(None, 784)]             0         
_________________________________________________________________
dense_0 (Dense)              (None, 100)               78500     
_________________________________________________________________
batch_normalization (BatchNo (None, 100)               400       
_________________________________________________________________
activation_0 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_1 (Batch (None, 100)               400       
_________________________________________________________________
activation_1 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_2 (Batch (None, 100)               400       
_________________________________________________________________
activation_2 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_3 (Batch (None, 100)               400       
_________________________________________________________________
activation_3 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_4 (Batch (None, 100)               400       
_________________________________________________________________
activation_4 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_5 (Batch (None, 100)               400       
_________________________________________________________________
activation_5 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_6 (Batch (None, 100)               400       
_________________________________________________________________
activation_6 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_7 (Batch (None, 100)               400       
_________________________________________________________________
activation_7 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_8 (Batch (None, 100)               400       
_________________________________________________________________
activation_8 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_9 (Batch (None, 100)               400       
_________________________________________________________________
activation_9 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_10 (Batc (None, 100)               400       
_________________________________________________________________
activation_10 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_11 (Batc (None, 100)               400       
_________________________________________________________________
activation_11 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_12 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_12 (Batc (None, 100)               400       
_________________________________________________________________
activation_12 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_13 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_13 (Batc (None, 100)               400       
_________________________________________________________________
activation_13 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_14 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_14 (Batc (None, 100)               400       
_________________________________________________________________
activation_14 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_15 (Batc (None, 100)               400       
_________________________________________________________________
activation_15 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_16 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_16 (Batc (None, 100)               400       
_________________________________________________________________
activation_16 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_17 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_17 (Batc (None, 100)               400       
_________________________________________________________________
activation_17 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_18 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_18 (Batc (None, 100)               400       
_________________________________________________________________
activation_18 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_19 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_19 (Batc (None, 100)               400       
_________________________________________________________________
activation_19 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_20 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_20 (Batc (None, 100)               400       
_________________________________________________________________
activation_20 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_21 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_21 (Batc (None, 100)               400       
_________________________________________________________________
activation_21 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_22 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_22 (Batc (None, 100)               400       
_________________________________________________________________
activation_22 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_23 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_23 (Batc (None, 100)               400       
_________________________________________________________________
activation_23 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_24 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_24 (Batc (None, 100)               400       
_________________________________________________________________
activation_24 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_25 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_25 (Batc (None, 100)               400       
_________________________________________________________________
activation_25 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_26 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_26 (Batc (None, 100)               400       
_________________________________________________________________
activation_26 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_27 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_27 (Batc (None, 100)               400       
_________________________________________________________________
activation_27 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_28 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_28 (Batc (None, 100)               400       
_________________________________________________________________
activation_28 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_29 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_29 (Batc (None, 100)               400       
_________________________________________________________________
activation_29 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_30 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_30 (Batc (None, 100)               400       
_________________________________________________________________
activation_30 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_31 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_31 (Batc (None, 100)               400       
_________________________________________________________________
activation_31 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_32 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_32 (Batc (None, 100)               400       
_________________________________________________________________
activation_32 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_33 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_33 (Batc (None, 100)               400       
_________________________________________________________________
activation_33 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_34 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_34 (Batc (None, 100)               400       
_________________________________________________________________
activation_34 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_35 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_35 (Batc (None, 100)               400       
_________________________________________________________________
activation_35 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_36 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_36 (Batc (None, 100)               400       
_________________________________________________________________
activation_36 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_37 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_37 (Batc (None, 100)               400       
_________________________________________________________________
activation_37 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_38 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_38 (Batc (None, 100)               400       
_________________________________________________________________
activation_38 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_39 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_39 (Batc (None, 100)               400       
_________________________________________________________________
activation_39 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_40 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_40 (Batc (None, 100)               400       
_________________________________________________________________
activation_40 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_41 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_41 (Batc (None, 100)               400       
_________________________________________________________________
activation_41 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_42 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_42 (Batc (None, 100)               400       
_________________________________________________________________
activation_42 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_43 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_43 (Batc (None, 100)               400       
_________________________________________________________________
activation_43 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_44 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_44 (Batc (None, 100)               400       
_________________________________________________________________
activation_44 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_45 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_45 (Batc (None, 100)               400       
_________________________________________________________________
activation_45 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_46 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_46 (Batc (None, 100)               400       
_________________________________________________________________
activation_46 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_47 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_47 (Batc (None, 100)               400       
_________________________________________________________________
activation_47 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_48 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_48 (Batc (None, 100)               400       
_________________________________________________________________
activation_48 (Activation)   (None, 100)               0         
_________________________________________________________________
dense_49 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_49 (Batc (None, 100)               400       
_________________________________________________________________
activation_49 (Activation)   (None, 100)               0         
_________________________________________________________________
logit (Dense)                (None, 10)                1010      
=================================================================
Total params: 594,410
Trainable params: 584,410
Non-trainable params: 10,000
_________________________________________________________________

模型训练

tf.keras.Sequential.fit

用到的参数:

  • x:输入数据。
  • y:输入标签。
  • batch_size:一次梯度更新使用的数据量。
  • epochs:数据集跑多少轮模型训练,一轮表示整个数据集训练一次。
  • validation_split:验证集占总数据量的比例,取值0~1。
  • shuffle:每轮训练是否打乱数据顺序,默认True。

返回:History对象,History.history属性会记录每一轮训练集和验证集的损失函数值和评价指标。。

model.fit(x=x_train, y=y_train, batch_size=128, epochs=10,
validation_split=0.3, shuffle=True)

# 创建一个列表,写入input层和前5层隐层的名称
layer_names = ['input'] + ['activation_'+str(int(i)) for i in range(5)]

# 设置训练的epoch个数
epochs = 10

# 设定画10行、len(layer_names)列的子图,总图大小为(30,20)
fig,ax = plt.subplots(epochs, len(layer_names), figsize=(30,20))

# 训练epochs次,每个epoch训练好后,画出layer_names指定隐层的输出分布直方图
for i in range(epochs):
    print('epoch'+str(i))
    model.fit(x=x_train, y=y_train, batch_size=128, epochs=1,
              validation_split=0.3, shuffle=True)

    # 每次训练完,输入x_train,获取指定隐层的输出,并画出直方图
    for j, name in enumerate(layer_names):
        layer_model = Model(
            inputs=model.input, outputs=model.get_layer(name).output)
        pred = layer_model.predict(x_train)

        ax[i,j].hist(pred.reshape(-1), bins=100)
        ax[i,j].set_xlabel(name)
        ax[i,j].set_ylabel('epoch'+str(i))

plt.show()
epoch0
329/329 [==============================] - 8s 23ms/step - loss: 0.7001 - accuracy: 0.7831 - val_loss: 0.6781 - val_accuracy: 0.7853
epoch1
329/329 [==============================] - 7s 20ms/step - loss: 0.3437 - accuracy: 0.8966 - val_loss: 0.3088 - val_accuracy: 0.9134
epoch2
329/329 [==============================] - 7s 20ms/step - loss: 0.2707 - accuracy: 0.9181 - val_loss: 0.3986 - val_accuracy: 0.8781
epoch3
329/329 [==============================] - 7s 20ms/step - loss: 0.2344 - accuracy: 0.9290 - val_loss: 0.7669 - val_accuracy: 0.8232
epoch4
329/329 [==============================] - 7s 21ms/step - loss: 0.2089 - accuracy: 0.9366 - val_loss: 0.3106 - val_accuracy: 0.9089
epoch5
329/329 [==============================] - 7s 20ms/step - loss: 0.1898 - accuracy: 0.9426 - val_loss: 0.6345 - val_accuracy: 0.7938
epoch6
329/329 [==============================] - 7s 20ms/step - loss: 0.1804 - accuracy: 0.9456 - val_loss: 0.5739 - val_accuracy: 0.8327
epoch7
329/329 [==============================] - 7s 21ms/step - loss: 0.1640 - accuracy: 0.9507 - val_loss: 0.1901 - val_accuracy: 0.9456
epoch8
329/329 [==============================] - 7s 20ms/step - loss: 0.1519 - accuracy: 0.9526 - val_loss: 0.1932 - val_accuracy: 0.9464
epoch9
329/329 [==============================] - 7s 20ms/step - loss: 0.1390 - accuracy: 0.9576 - val_loss: 0.1572 - val_accuracy: 0.9570

在这里插入图片描述

测试集评估结果

loss, accuracy = model.evaluate(x_test, y_test)
print('loss: ', loss)
print('accuracy: ', accuracy)
313/313 [==============================] - 1s 3ms/step - loss: 0.1528 - accuracy: 0.9569
loss:  0.15275312960147858
accuracy:  0.9569000005722046

10个epoch训练得到的模型的test accuracy最好结果:

权重初始化 激活函数 Batch Normalization Test Accuracy
N(0,1) tanh no 0.1094
N(0,0.005) tanh no 0.1135
Xavier tanh no 0.9289
He relu no 0.9204
Xavier elu no 0.7811
Xavier selu no 0.9582
N(0,1) tanh yes 0.3969

添加了Batch Normalization的N(0,1)高斯分布初始化,训练100个epoch的test accuracy为0.8867。

相关文章
|
6月前
|
机器学习/深度学习 数据采集 TensorFlow
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
|
机器学习/深度学习 算法 PyTorch
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
156 0
|
机器学习/深度学习 自然语言处理 算法
【深度学习】实验15 使用CNN完成MNIST手写体识别(Keras)
【深度学习】实验15 使用CNN完成MNIST手写体识别(Keras)
116 0
|
机器学习/深度学习 算法 TensorFlow
【深度学习】实验14 使用CNN完成MNIST手写体识别(TensorFlow)
【深度学习】实验14 使用CNN完成MNIST手写体识别(TensorFlow)
124 0
|
机器学习/深度学习 监控 测试技术
深度学习神经网络数字识别案例
深度学习神经网络数字识别案例
306 0
|
机器学习/深度学习 算法 BI
【深度学习】基于知识库的手写体数字识别(Matlab代码实现)
【深度学习】基于知识库的手写体数字识别(Matlab代码实现)
167 0
|
机器学习/深度学习
深度学习入门笔记7 手写数字识别 续
深度学习入门笔记7 手写数字识别 续
|
机器学习/深度学习 算法 数据挖掘
基于ResNet18深度学习网络的mnist手写数字数据库识别matlab仿真
基于ResNet18深度学习网络的mnist手写数字数据库识别matlab仿真
247 0
|
机器学习/深度学习 计算机视觉 Python
【深度学习实践(二)】上手手写数字识别
【深度学习实践(二)】上手手写数字识别
【深度学习实践(二)】上手手写数字识别