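Handwritten Digit Recognition
MNIST is a dataset of handwritten digits and is often called the "Hello World" of deep learning.
The MNIST dataset contains:
- Training set: 60,000 grayscale images of 28×28 pixels
- Test set: 10,000 grayscale images of 28×28 pixels
The images fall into 10 classes, the handwritten digits 0 to 9.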
Import the required modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import datasets, Input, Model
from tensorflow.keras.layers import Flatten, Dense, Activation, BatchNormalization
from tensorflow.keras.initializers import TruncatedNormal
Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
Flatten each image: (60000, 28, 28) -> (60000, 784)
x_train = x_train.reshape(-1, 28*28)
x_test = x_test.reshape(-1, 28*28)
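To see what `reshape(-1, 28*28)` does, the sketch below flattens a small synthetic array (not MNIST) in which every pixel stores its own position. The `-1` asks NumPy to infer the first dimension, and because NumPy reshapes in row-major (C) order by default, pixel (r, c) of an image ends up at column r * 28 + c of its flattened row.

```python
import numpy as np

# Synthetic stand-in for x_train: 3 images of 28x28, each pixel holding its flat index
images = np.arange(3 * 28 * 28).reshape(3, 28, 28)

# -1 lets NumPy infer the batch dimension: (3 * 28 * 28) / 784 = 3
flat = images.reshape(-1, 28 * 28)
print(flat.shape)  # (3, 784)

# Row-major order: pixel (r, c) of each image lands at column r * 28 + c
r, c = 5, 17
print(flat[1, r * 28 + c] == images[1, r, c])  # True
```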
Normalization: scale the pixel values from [0, 255] to [0, 1]
x_train = x_train/255.0
x_test = x_test/255.0
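Dividing the `uint8` pixel array by `255.0` yields `float64` values in [0, 1]. The sketch below uses random synthetic pixels to show the resulting dtype and range. Casting to `float32` first is an optional alternative (an assumption of this sketch, not part of the original code) that halves the memory: for the full training set that is 60000 × 784 × 8 = 376,320,000 bytes in float64 versus 188,160,000 bytes in float32.

```python
import numpy as np

# Synthetic uint8 "images" standing in for MNIST pixels (0-255)
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(4, 784), dtype=np.uint8)

x_scaled = x / 255.0                         # true division promotes uint8 to float64
x_scaled32 = x.astype(np.float32) / 255.0    # float32 keeps half the bytes per value

print(x_scaled.dtype, x_scaled32.dtype)      # float64 float32
print(x_scaled.min() >= 0.0, x_scaled.max() <= 1.0)  # True True
```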
Building the model
APIs used:
Fully connected layer: tf.keras.layers.Dense
Parameters used:
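- units: integer, the number of neurons in the layer.
- activation: the activation function. Options:
  - 'sigmoid': sigmoid activation
  - 'tanh': tanh activation
  - 'relu': ReLU activation
  - 'elu' or tf.keras.activations.elu (alpha defaults to 1.0): ELU activation
  - 'selu': SELU activation
  - 'swish': Swish activation (available from TF 2.2 onward)
  - 'softmax': softmax function
- kernel_initializer: how the weights are initialized; defaults to 'glorot_uniform' (Xavier uniform initialization). Options:
  - 'RandomNormal' or tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05): samples from a normal distribution with mean 0 and standard deviation 0.05
  - tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05): like RandomNormal, but values more than two standard deviations from the mean are discarded and redrawn
  - 'glorot_normal': samples from a truncated normal distribution with mean 0 and variance 2 / (fan_in + fan_out)
  - 'glorot_uniform': samples uniformly from [-limit, limit], with limit = sqrt(6 / (fan_in + fan_out))
  - 'lecun_normal': samples from a truncated normal distribution with mean 0 and variance 1 / fan_in
  - 'lecun_uniform': samples uniformly from [-limit, limit], with limit = sqrt(3 / fan_in)
  - 'he_normal': samples from a truncated normal distribution with mean 0 and variance 2 / fan_in
  - 'he_uniform': samples uniformly from [-limit, limit], with limit = sqrt(6 / fan_in)
  Here fan_in is the number of input units of the layer and fan_out is the number of output units.
- name: string, a name for the layer.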
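The activation options above are simple element-wise formulas (softmax instead normalizes over the last axis). The NumPy sketch below writes them out for illustration only; Keras uses its own TensorFlow implementations. The SELU constants are the published values from the self-normalizing networks paper, which tf.keras also uses.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def relu(x):
    return np.maximum(x, 0.0)

def elu(x, alpha=1.0):
    # x for x > 0, alpha * (e^x - 1) otherwise
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(x):
    # SELU is an ELU with fixed alpha, multiplied by a fixed scale
    return SELU_SCALE * elu(x, alpha=SELU_ALPHA)

def swish(x):
    return x * sigmoid(x)

def softmax(z):
    # Subtracting the max is for numerical stability; it does not change the result
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

x = np.array([-2.0, 0.0, 2.0])
for name, fn in [("sigmoid", sigmoid), ("tanh", np.tanh), ("relu", relu),
                 ("elu", elu), ("selu", selu), ("swish", swish)]:
    print(name, fn(x))
print("softmax", softmax(x))  # non-negative entries that sum to 1
```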
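The uniform and normal variant of each scheme target the same variance. A uniform distribution on [-limit, limit] has variance limit² / 3, so glorot_uniform gives 6 / (fan_in + fan_out) / 3 = 2 / (fan_in + fan_out), he_uniform gives 2 / fan_in and lecun_uniform gives 1 / fan_in, matching glorot_normal, he_normal and lecun_normal. The sketch below checks this both analytically and by sampling; it draws from NumPy directly rather than calling the Keras initializers.

```python
import numpy as np

def uniform_limit(scheme, fan_in, fan_out):
    # limit of U(-limit, limit) for the *_uniform initializers listed above
    return {"glorot": np.sqrt(6.0 / (fan_in + fan_out)),
            "lecun": np.sqrt(3.0 / fan_in),
            "he": np.sqrt(6.0 / fan_in)}[scheme]

def normal_variance(scheme, fan_in, fan_out):
    # variance of the matching *_normal initializers listed above
    return {"glorot": 2.0 / (fan_in + fan_out),
            "lecun": 1.0 / fan_in,
            "he": 2.0 / fan_in}[scheme]

rng = np.random.default_rng(0)
fan_in, fan_out = 784, 100  # e.g. the first hidden layer, dense_0
results = {}
for scheme in ("glorot", "lecun", "he"):
    limit = uniform_limit(scheme, fan_in, fan_out)
    w = rng.uniform(-limit, limit, size=(fan_in, fan_out))
    # (analytic variance of the uniform, target variance of the normal, sample variance)
    results[scheme] = (limit ** 2 / 3.0, normal_variance(scheme, fan_in, fan_out), w.var())
    print(scheme, results[scheme])
```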
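Input layer: tf.keras.Input
Parameters used:
- shape: the shape of a single input sample, as a tuple that excludes the batch dimension, e.g. (784,).
- name: string, a name for the layer.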
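BN layer: tf.keras.layers.BatchNormalization
Parameters used:
- axis: the axis to normalize, normally the feature axis; the default -1 means the last dimension of the input.
- name: string, a name for the layer.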
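In training mode, BatchNormalization standardizes each feature with the mean and variance of the current batch, then applies a learned scale gamma and shift beta; it also updates running averages of the mean and variance that are used at inference time. The NumPy sketch below is a simplified version of that computation, assuming the Keras defaults momentum=0.99 and epsilon=0.001 and synthetic inputs. It also shows why every BN layer in the model summary has 4 × 100 = 400 parameters.

```python
import numpy as np

def batchnorm_train_step(x, gamma, beta, moving_mean, moving_var,
                         momentum=0.99, epsilon=1e-3):
    """One training-mode BN forward pass: statistics over the batch (axis 0),
    each feature on the last axis normalized independently."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    y = gamma * x_hat + beta
    # Running statistics for inference; these are the non-trainable weights
    moving_mean = momentum * moving_mean + (1 - momentum) * mean
    moving_var = momentum * moving_var + (1 - momentum) * var
    return y, moving_mean, moving_var

units = 100
gamma, beta = np.ones(units), np.zeros(units)              # trainable
moving_mean, moving_var = np.zeros(units), np.ones(units)  # non-trainable

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(128, units))  # one synthetic batch of pre-activations
y, moving_mean, moving_var = batchnorm_train_step(x, gamma, beta, moving_mean, moving_var)

print(y.mean(axis=0)[:3], y.std(axis=0)[:3])  # about 0 and about 1 per feature
print(4 * units)                              # gamma, beta, moving mean, moving variance -> 400
```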
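Model configuration: tf.keras.Model.compile
Parameters used:
- loss: the loss function. For multi-class classification with integer labels (such as y_train here), use "sparse_categorical_crossentropy"; "sparse" means the labels are class indices rather than one-hot vectors, so no one-hot encoding is needed.
- optimizer: the optimizer; "sgd" is used here. For other optimizers, see https://tensorflow.google.cn/api_docs/python/tf/keras/optimizers
- metrics: the evaluation metrics; "accuracy" is used here. For other metrics, see https://tensorflow.google.cn/api_docs/python/tf/keras/metrics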
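To make the "sparse" distinction concrete: sparse_categorical_crossentropy reads the predicted probability of the true class directly from its integer index, while categorical_crossentropy expects one-hot labels; for the same predictions, both give -log(p_true). The NumPy sketch below uses made-up probabilities and omits the numerical safeguards (such as clipping) that Keras applies.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true, probs):
    # y_true: integer class ids, shape (n,); probs: softmax outputs, shape (n, k)
    return -np.log(probs[np.arange(len(y_true)), y_true])

def categorical_crossentropy(y_onehot, probs):
    # y_onehot: one-hot labels, shape (n, k)
    return -np.sum(y_onehot * np.log(probs), axis=-1)

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([0, 2])   # integer labels, like y_train
onehot = np.eye(3)[labels]  # the one-hot form the non-sparse loss expects

print(sparse_categorical_crossentropy(labels, probs))
print(categorical_crossentropy(onehot, probs))  # same values
```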
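Build a network with 50 hidden layers of 100 neurons each; every hidden layer is a Dense layer followed by BatchNormalization and a SELU activation.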
# Input layer: each sample is a flattened 784-dimensional vector
inputs = Input(shape=(28*28,), name='input')
# First hidden layer: Dense -> BatchNormalization -> SELU
x = Dense(units=100, kernel_initializer='lecun_normal', name='dense_0')(inputs)
x = BatchNormalization(axis=-1, name='batchnormalization_0')(x)
x = Activation(activation='selu', name='activation_0')(x)
# The remaining 49 hidden layers, 50 in total
for i in range(49):
    x = Dense(units=100, kernel_initializer='lecun_normal', name='dense_'+str(i+1))(x)
    x = BatchNormalization(axis=-1, name='batchnormalization_'+str(i+1))(x)
    x = Activation(activation='selu', name='activation_'+str(i+1))(x)
# Output layer: softmax over the 10 digit classes
outputs = Dense(units=10, activation='softmax', name='logit')(x)
# Define the model by its inputs and outputs
model = Model(inputs=inputs, outputs=outputs)
# Set the loss function, optimizer and evaluation metrics
model.compile(loss='sparse_categorical_crossentropy',
              optimizer="sgd", metrics=['accuracy'])
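The network pairs lecun_normal with selu. SELU was designed so that, with weights of mean 0 and variance 1 / fan_in, activations are pulled toward mean 0 and variance 1 from layer to layer even without normalization layers. The sketch below isolates that effect under simplifying assumptions (synthetic standard-normal inputs, 100 units per layer, no biases, no BatchNormalization, plain rather than truncated normal weights), so it illustrates the idea rather than reproducing the model above.

```python
import numpy as np

SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(x):
    return SELU_SCALE * np.where(x > 0, x, SELU_ALPHA * (np.exp(x) - 1.0))

rng = np.random.default_rng(0)
width, depth = 100, 50
a = rng.standard_normal((1000, width))  # synthetic standardized inputs
for _ in range(depth):
    # lecun_normal-style weights: mean 0, variance 1 / fan_in (untruncated here)
    w = rng.normal(0.0, np.sqrt(1.0 / width), size=(width, width))
    a = selu(a @ w)                     # no bias, no BatchNormalization

print(a.mean(), a.std())  # expected to stay near 0 and 1 after 50 layers
```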
Inspect each layer's output shape and number of parameters
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) [(None, 784)] 0
_________________________________________________________________
dense_0 (Dense) (None, 100) 78500
_________________________________________________________________
batch_normalization (BatchNo (None, 100) 400
_________________________________________________________________
activation_0 (Activation) (None, 100) 0
_________________________________________________________________
dense_1 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_1 (Batch (None, 100) 400
_________________________________________________________________
activation_1 (Activation) (None, 100) 0
_________________________________________________________________
dense_2 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_2 (Batch (None, 100) 400
_________________________________________________________________
activation_2 (Activation) (None, 100) 0
_________________________________________________________________
dense_3 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_3 (Batch (None, 100) 400
_________________________________________________________________
activation_3 (Activation) (None, 100) 0
_________________________________________________________________
dense_4 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_4 (Batch (None, 100) 400
_________________________________________________________________
activation_4 (Activation) (None, 100) 0
_________________________________________________________________
dense_5 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_5 (Batch (None, 100) 400
_________________________________________________________________
activation_5 (Activation) (None, 100) 0
_________________________________________________________________
dense_6 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_6 (Batch (None, 100) 400
_________________________________________________________________
activation_6 (Activation) (None, 100) 0
_________________________________________________________________
dense_7 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_7 (Batch (None, 100) 400
_________________________________________________________________
activation_7 (Activation) (None, 100) 0
_________________________________________________________________
dense_8 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_8 (Batch (None, 100) 400
_________________________________________________________________
activation_8 (Activation) (None, 100) 0
_________________________________________________________________
dense_9 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_9 (Batch (None, 100) 400
_________________________________________________________________
activation_9 (Activation) (None, 100) 0
_________________________________________________________________
dense_10 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_10 (Batc (None, 100) 400
_________________________________________________________________
activation_10 (Activation) (None, 100) 0
_________________________________________________________________
dense_11 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_11 (Batc (None, 100) 400
_________________________________________________________________
activation_11 (Activation) (None, 100) 0
_________________________________________________________________
dense_12 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_12 (Batc (None, 100) 400
_________________________________________________________________
activation_12 (Activation) (None, 100) 0
_________________________________________________________________
dense_13 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_13 (Batc (None, 100) 400
_________________________________________________________________
activation_13 (Activation) (None, 100) 0
_________________________________________________________________
dense_14 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_14 (Batc (None, 100) 400
_________________________________________________________________
activation_14 (Activation) (None, 100) 0
_________________________________________________________________
dense_15 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_15 (Batc (None, 100) 400
_________________________________________________________________
activation_15 (Activation) (None, 100) 0
_________________________________________________________________
dense_16 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_16 (Batc (None, 100) 400
_________________________________________________________________
activation_16 (Activation) (None, 100) 0
_________________________________________________________________
dense_17 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_17 (Batc (None, 100) 400
_________________________________________________________________
activation_17 (Activation) (None, 100) 0
_________________________________________________________________
dense_18 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_18 (Batc (None, 100) 400
_________________________________________________________________
activation_18 (Activation) (None, 100) 0
_________________________________________________________________
dense_19 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_19 (Batc (None, 100) 400
_________________________________________________________________
activation_19 (Activation) (None, 100) 0
_________________________________________________________________
dense_20 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_20 (Batc (None, 100) 400
_________________________________________________________________
activation_20 (Activation) (None, 100) 0
_________________________________________________________________
dense_21 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_21 (Batc (None, 100) 400
_________________________________________________________________
activation_21 (Activation) (None, 100) 0
_________________________________________________________________
dense_22 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_22 (Batc (None, 100) 400
_________________________________________________________________
activation_22 (Activation) (None, 100) 0
_________________________________________________________________
dense_23 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_23 (Batc (None, 100) 400
_________________________________________________________________
activation_23 (Activation) (None, 100) 0
_________________________________________________________________
dense_24 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_24 (Batc (None, 100) 400
_________________________________________________________________
activation_24 (Activation) (None, 100) 0
_________________________________________________________________
dense_25 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_25 (Batc (None, 100) 400
_________________________________________________________________
activation_25 (Activation) (None, 100) 0
_________________________________________________________________
dense_26 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_26 (Batc (None, 100) 400
_________________________________________________________________
activation_26 (Activation) (None, 100) 0
_________________________________________________________________
dense_27 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_27 (Batc (None, 100) 400
_________________________________________________________________
activation_27 (Activation) (None, 100) 0
_________________________________________________________________
dense_28 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_28 (Batc (None, 100) 400
_________________________________________________________________
activation_28 (Activation) (None, 100) 0
_________________________________________________________________
dense_29 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_29 (Batc (None, 100) 400
_________________________________________________________________
activation_29 (Activation) (None, 100) 0
_________________________________________________________________
dense_30 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_30 (Batc (None, 100) 400
_________________________________________________________________
activation_30 (Activation) (None, 100) 0
_________________________________________________________________
dense_31 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_31 (Batc (None, 100) 400
_________________________________________________________________
activation_31 (Activation) (None, 100) 0
_________________________________________________________________
dense_32 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_32 (Batc (None, 100) 400
_________________________________________________________________
activation_32 (Activation) (None, 100) 0
_________________________________________________________________
dense_33 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_33 (Batc (None, 100) 400
_________________________________________________________________
activation_33 (Activation) (None, 100) 0
_________________________________________________________________
dense_34 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_34 (Batc (None, 100) 400
_________________________________________________________________
activation_34 (Activation) (None, 100) 0
_________________________________________________________________
dense_35 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_35 (Batc (None, 100) 400
_________________________________________________________________
activation_35 (Activation) (None, 100) 0
_________________________________________________________________
dense_36 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_36 (Batc (None, 100) 400
_________________________________________________________________
activation_36 (Activation) (None, 100) 0
_________________________________________________________________
dense_37 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_37 (Batc (None, 100) 400
_________________________________________________________________
activation_37 (Activation) (None, 100) 0
_________________________________________________________________
dense_38 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_38 (Batc (None, 100) 400
_________________________________________________________________
activation_38 (Activation) (None, 100) 0
_________________________________________________________________
dense_39 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_39 (Batc (None, 100) 400
_________________________________________________________________
activation_39 (Activation) (None, 100) 0
_________________________________________________________________
dense_40 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_40 (Batc (None, 100) 400
_________________________________________________________________
activation_40 (Activation) (None, 100) 0
_________________________________________________________________
dense_41 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_41 (Batc (None, 100) 400
_________________________________________________________________
activation_41 (Activation) (None, 100) 0
_________________________________________________________________
dense_42 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_42 (Batc (None, 100) 400
_________________________________________________________________
activation_42 (Activation) (None, 100) 0
_________________________________________________________________
dense_43 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_43 (Batc (None, 100) 400
_________________________________________________________________
activation_43 (Activation) (None, 100) 0
_________________________________________________________________
dense_44 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_44 (Batc (None, 100) 400
_________________________________________________________________
activation_44 (Activation) (None, 100) 0
_________________________________________________________________
dense_45 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_45 (Batc (None, 100) 400
_________________________________________________________________
activation_45 (Activation) (None, 100) 0
_________________________________________________________________
dense_46 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_46 (Batc (None, 100) 400
_________________________________________________________________
activation_46 (Activation) (None, 100) 0
_________________________________________________________________
dense_47 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_47 (Batc (None, 100) 400
_________________________________________________________________
activation_47 (Activation) (None, 100) 0
_________________________________________________________________
dense_48 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_48 (Batc (None, 100) 400
_________________________________________________________________
activation_48 (Activation) (None, 100) 0
_________________________________________________________________
dense_49 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_49 (Batc (None, 100) 400
_________________________________________________________________
activation_49 (Activation) (None, 100) 0
_________________________________________________________________
logit (Dense) (None, 10) 1010
=================================================================
Total params: 594,410
Trainable params: 584,410
Non-trainable params: 10,000
_________________________________________________________________
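The totals in the summary can be checked by hand. A Dense layer has fan_in × units weights plus units biases, and each BatchNormalization layer keeps four vectors of length units: gamma and beta are trainable, while the moving mean and moving variance are not. The plain-Python arithmetic below reproduces the three totals reported above.

```python
# Reproduce the totals reported by model.summary() for the 50-hidden-layer network
n_inputs, units, classes, hidden_layers = 28 * 28, 100, 10, 50

def dense_params(fan_in, fan_out):
    return fan_in * fan_out + fan_out  # weight matrix + bias vector

first_dense = dense_params(n_inputs, units)                     # 78,500
other_dense = (hidden_layers - 1) * dense_params(units, units)  # 49 * 10,100
bn_trainable = hidden_layers * 2 * units                        # gamma, beta
bn_non_trainable = hidden_layers * 2 * units                    # moving mean, moving variance
output_dense = dense_params(units, classes)                     # 1,010

total = first_dense + other_dense + bn_trainable + bn_non_trainable + output_dense
print(f"Total params: {total:,}")                         # Total params: 594,410
print(f"Trainable params: {total - bn_non_trainable:,}")  # Trainable params: 584,410
print(f"Non-trainable params: {bn_non_trainable:,}")      # Non-trainable params: 10,000
```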
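Training the model
tf.keras.Model.fit
Parameters used:
- x: input data.
- y: labels.
- batch_size: the number of samples used for each gradient update.
- epochs: the number of passes over the training data; one epoch means the whole dataset has been used for training once.
- validation_split: the fraction of the training data held out as a validation set, between 0 and 1. Keras takes these samples from the end of x and y, before shuffling.
- shuffle: whether to shuffle the training data before each epoch; defaults to True.
Returns: a History object whose History.history attribute records the training and validation loss and metric values for every epoch.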
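With validation_split=0.3, 30% of the 60,000 training images are held out, which also accounts for the 329/329 steps per epoch in the training log below: the remaining 42,000 samples / batch size 128 = 328.125, rounded up to 329 because the last batch is partial. The sketch below reproduces the split with index arrays; the `int(...)` rounding of the split point is an assumption about how Keras computes it.

```python
import math
import numpy as np

num_samples, validation_split, batch_size = 60000, 0.3, 128

# Validation samples come from the end of the data, taken before any shuffling
split_at = int(num_samples * (1.0 - validation_split))
indices = np.arange(num_samples)
train_idx, val_idx = indices[:split_at], indices[split_at:]

steps_per_epoch = math.ceil(len(train_idx) / batch_size)  # the last batch is partial
print(len(train_idx), len(val_idx), steps_per_epoch)      # 42000 18000 329
```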
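# Train for 10 epochs in one call. The loop below runs the same 10 epochs one at a
# time so it can plot activations after each; run only one of the two.
model.fit(x=x_train, y=y_train, batch_size=128, epochs=10,
          validation_split=0.3, shuffle=True)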
# List the names of the input layer and the first 5 hidden activation layers
layer_names = ['input'] + ['activation_'+str(i) for i in range(5)]
# Number of training epochs
epochs = 10
# A grid of `epochs` rows by len(layer_names) columns of subplots, figure size (30, 20)
fig, ax = plt.subplots(epochs, len(layer_names), figsize=(30, 20))
# Train one epoch at a time; after each epoch, plot a histogram of the outputs of every layer in layer_names
for i in range(epochs):
    print('epoch'+str(i))
    model.fit(x=x_train, y=y_train, batch_size=128, epochs=1,
              validation_split=0.3, shuffle=True)
    # Feed x_train through the model, take the chosen layer's output and plot its histogram
    for j, name in enumerate(layer_names):
        layer_model = Model(
            inputs=model.input, outputs=model.get_layer(name).output)
        pred = layer_model.predict(x_train)
        ax[i, j].hist(pred.reshape(-1), bins=100)
        ax[i, j].set_xlabel(name)
        ax[i, j].set_ylabel('epoch'+str(i))
plt.show()
epoch0
329/329 [==============================] - 8s 23ms/step - loss: 0.7001 - accuracy: 0.7831 - val_loss: 0.6781 - val_accuracy: 0.7853
epoch1
329/329 [==============================] - 7s 20ms/step - loss: 0.3437 - accuracy: 0.8966 - val_loss: 0.3088 - val_accuracy: 0.9134
epoch2
329/329 [==============================] - 7s 20ms/step - loss: 0.2707 - accuracy: 0.9181 - val_loss: 0.3986 - val_accuracy: 0.8781
epoch3
329/329 [==============================] - 7s 20ms/step - loss: 0.2344 - accuracy: 0.9290 - val_loss: 0.7669 - val_accuracy: 0.8232
epoch4
329/329 [==============================] - 7s 21ms/step - loss: 0.2089 - accuracy: 0.9366 - val_loss: 0.3106 - val_accuracy: 0.9089
epoch5
329/329 [==============================] - 7s 20ms/step - loss: 0.1898 - accuracy: 0.9426 - val_loss: 0.6345 - val_accuracy: 0.7938
epoch6
329/329 [==============================] - 7s 20ms/step - loss: 0.1804 - accuracy: 0.9456 - val_loss: 0.5739 - val_accuracy: 0.8327
epoch7
329/329 [==============================] - 7s 21ms/step - loss: 0.1640 - accuracy: 0.9507 - val_loss: 0.1901 - val_accuracy: 0.9456
epoch8
329/329 [==============================] - 7s 20ms/step - loss: 0.1519 - accuracy: 0.9526 - val_loss: 0.1932 - val_accuracy: 0.9464
epoch9
329/329 [==============================] - 7s 20ms/step - loss: 0.1390 - accuracy: 0.9576 - val_loss: 0.1572 - val_accuracy: 0.9570
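The loop above builds a new sub-model for every layer in every epoch and runs `predict` on all 60,000 training images six times per epoch. One possible alternative, sketched below on a much smaller stand-in network (3 hidden blocks of 16 units and random inputs, both assumptions for illustration), is to build a single multi-output model once, before training. It reuses the layers of `model`, so calling it after each epoch returns the current activations. The 'input' histogram needs no model at all, since it is just the values of x_train.

```python
import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Activation, BatchNormalization, Dense

# Hypothetical small stand-in for the tutorial's network: 3 hidden blocks of 16 units
inputs = Input(shape=(784,), name="input")
x = inputs
for i in range(3):
    x = Dense(16, kernel_initializer="lecun_normal", name=f"dense_{i}")(x)
    x = BatchNormalization(name=f"batchnormalization_{i}")(x)
    x = Activation("selu", name=f"activation_{i}")(x)
outputs = Dense(10, activation="softmax", name="logit")(x)
model = Model(inputs=inputs, outputs=outputs)

# Build the multi-output extractor once; it shares model's layers and weights
layer_names = [f"activation_{i}" for i in range(3)]
extractor = Model(inputs=model.input,
                  outputs=[model.get_layer(name).output for name in layer_names])

# Random inputs in [0, 1), standing in for x_train
x_demo = np.random.default_rng(0).random((32, 784)).astype("float32")
activations = extractor.predict(x_demo, verbose=0)  # one array per requested layer
for name, act in zip(layer_names, activations):
    print(name, act.shape)
```

Inside a training loop, one `extractor.predict(...)` call per epoch would then supply the array for each histogram column.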
Evaluation on the test set
loss, accuracy = model.evaluate(x_test, y_test)
print('loss: ', loss)
print('accuracy: ', accuracy)
313/313 [==============================] - 1s 3ms/step - loss: 0.1528 - accuracy: 0.9569
loss: 0.15275312960147858
accuracy: 0.9569000005722046
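Best test accuracy of models trained for 10 epochs, by weight initialization, activation function and use of Batch Normalization:
Weight initialization | Activation | Batch Normalization | Test accuracy |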
---|---|---|---|
N(0,1) | tanh | no | 0.1094 |
N(0,0.005) | tanh | no | 0.1135 |
Xavier | tanh | no | 0.9289 |
He | relu | no | 0.9204 |
Xavier | elu | no | 0.7811 |
Xavier | selu | no | 0.9582 |
N(0,1) | tanh | yes | 0.3969 |
With Batch Normalization added, the N(0,1) Gaussian initialization reaches a test accuracy of 0.8867 after 100 epochs of training.
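A forward pass at initialization gives some intuition for the tanh rows of the table. The sketch below sends synthetic standard-normal data through 50 tanh layers of width 100 (no biases, no training, no BatchNormalization; all assumptions made for illustration) and reports how spread out and how saturated the last layer's outputs are. It reads N(0,0.005) as a standard deviation of 0.005; if that value denotes the variance, the standard deviation is about 0.07 and the signal still shrinks from layer to layer. It looks only at initial activations and does not reproduce the accuracies in the table.

```python
import numpy as np

def last_layer_activations(weight_std, depth=50, width=100, seed=0):
    """Push synthetic N(0, 1) inputs through `depth` tanh layers of size `width`
    (no bias, no BatchNormalization, no training) and return the final outputs."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((1000, width))
    for _ in range(depth):
        w = rng.normal(0.0, weight_std, size=(width, width))
        a = np.tanh(a @ w)
    return a

width = 100
configs = {
    "N(0,1)": 1.0,
    "N(0,0.005)": 0.005,                       # read as a standard deviation
    "Xavier": np.sqrt(2.0 / (width + width)),  # glorot_normal-style std, fan_in = fan_out = 100
}
for label, std in configs.items():
    a = last_layer_activations(std, width=width)
    saturated = np.mean(np.abs(a) > 0.99)      # share of outputs stuck near -1 or +1
    print(f"{label:12s} std={a.std():.3g}  saturated={saturated:.2f}")
```

Under these assumptions, N(0,1) should leave most outputs saturated near ±1, where the tanh gradient is close to zero, N(0,0.005) should shrink the signal toward zero, and the Xavier scale should keep it in between.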