Use It Directly
Please open the notebook Building a CNN Model with TensorFlow 2 and click "Open in DSW" in the upper-right corner.
TensorFlow 2 and Keras
TensorFlow 2 is a deep learning framework developed by Google on the basis of TensorFlow 1. Its architecture, APIs, and the range of supported hardware have all been deeply optimized. The TensorFlow 2 architecture consists of two main layers:
- Training layer
- Deployment layer
The main features of TensorFlow 2:
- (1) Load data with tf.data
- (2) Build models with tf.keras; premade Estimators can also be used to validate models, and TensorFlow Hub supports transfer learning
- (3) Run and debug in eager mode
- (4) Distributed training via distribution strategies
- (5) Export to SavedModel
- (6) Deploy models with TensorFlow Serving, TensorFlow Lite, or TensorFlow.js
- (7) Strong cross-platform support: TensorFlow Serving exposes models directly over HTTP/REST or gRPC/protocol buffers, TensorFlow Lite deploys to Android, iOS, and embedded systems, and TensorFlow.js runs models in JavaScript
- (8) The tf.keras functional and subclassing APIs allow complex topologies to be created
- (9) Custom training logic: tf.GradientTape and tf.custom_gradient give fine-grained control (see the sketch after this list)
- (10) Low-level APIs can be combined with the high-level ones, so everything is customizable
- (11) Advanced extensions: Ragged Tensors, Tensor2Tensor
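Features (1), (3), and (9) can be made concrete with a short sketch. The toy data, model, and hyperparameters below are placeholders chosen only to illustrate tf.data, eager execution, and tf.GradientTape; they are not part of the notebook that follows:

# A minimal sketch of features (1), (3), and (9).
import tensorflow as tf

# Toy data: 256 random "images" with 10 classes (placeholders, for illustration only).
x = tf.random.normal((256, 28, 28, 1))
y = tf.random.uniform((256,), maxval=10, dtype=tf.int32)

# Feature (1): build an input pipeline with tf.data.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(256).batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
optimizer = tf.keras.optimizers.SGD(0.01)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Feature (9): a hand-written training step with tf.GradientTape.
for batch_x, batch_y in dataset:
    with tf.GradientTape() as tape:
        loss = loss_fn(batch_y, model(batch_x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Feature (3): eager mode lets us print intermediate values directly.
    print(float(loss))

Because eager execution is the default in TensorFlow 2, the loss can be printed inside the loop with no sessions or graph compilation.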
Keras
Keras is the high-level API of TensorFlow, and it can significantly improve the efficiency of model development.
This article takes tf.keras from TensorFlow 2 as its basis and demonstrates how to use Keras to develop and train a model.
1. Import TensorFlow
import tensorflow as tf
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
2. Load the Dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step
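Before checking class balance, it can help to eyeball a few samples. A quick optional sanity check, a minimal sketch reusing the matplotlib import from step 1:

# Visual sanity check: show the first 9 training digits with their labels.
fig, axes = plt.subplots(3, 3, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(x_train[i], cmap='gray')
    ax.set_title(int(y_train[i]))
    ax.axis('off')
plt.show()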
2.1 Visualize the class distribution to check whether the samples in each class are balanced
sns.countplot(x=y_train)
<AxesSubplot:ylabel='count'>
2.2 Check whether the training data contains any NaN samples
np.isnan(x_train).any()
False
Check whether the test data contains any NaN samples
np.isnan(x_test).any()
False
3. Data Preprocessing. Two things happen here:
- Reshape the input datasets to match the input shape this model expects
- Normalize the pixel values
input_shape = (28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
x_train = x_train / 255.0
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)
x_test = x_test / 255.0
Encode the labels; one-hot encoding is used here
y_train = tf.one_hot(y_train.astype(np.int32), depth=10)
y_test = tf.one_hot(y_test.astype(np.int32), depth=10)
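An optional sanity check on the preprocessed tensors, a minimal sketch with the expected values shown as comments:

# Inputs should now be (N, 28, 28, 1) floats in [0, 1]; labels 10-dimensional one-hot rows.
print(x_train.shape, x_train.min(), x_train.max())  # (60000, 28, 28, 1) 0.0 1.0
print(y_train.shape)                                 # (60000, 10)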
4. Build the CNN Model
- Use the tf.keras.models.Sequential API to build the model
- Add convolutional layers with tf.keras.layers.Conv2D
- Max-pooling layers with tf.keras.layers.MaxPool2D
- Dropout with tf.keras.layers.Dropout
- Fully connected layers with tf.keras.layers.Dense
batch_size = 64
num_classes = 10
epochs = 50

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (5, 5), padding='same', activation='relu', input_shape=input_shape),
    tf.keras.layers.Conv2D(32, (5, 5), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(strides=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.RMSprop(epsilon=1e-08),
              loss='categorical_crossentropy',
              metrics=['acc'])
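To inspect the architecture that Sequential produced, model.summary() prints every layer together with its output shape and parameter count:

# Print a layer-by-layer overview of the network.
model.summary()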
5. Define a Callback
- This callback checks at the end of every epoch whether the accuracy has exceeded 99.5%; if it has, training is stopped
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if logs is not None and logs.get('acc', 0) > 0.995:
            print("\nReached 99.5% accuracy so cancelling training!")
            self.model.stop_training = True

callbacks = myCallback()
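Note that the callback reads the 'acc' key because metrics=['acc'] was passed to compile. For comparison, tf.keras also ships a built-in early-stopping callback. The sketch below is an alternative, not what the training run below uses; it watches validation loss instead of a fixed accuracy threshold:

# Built-in alternative: stop once val_loss has not improved for 5 epochs,
# then roll back to the best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True)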
6. Start Training
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_split=0.1,
                    callbacks=[callbacks])
Epoch 1/50
844/844 [==============================] - 85s 101ms/step - loss: 0.2355 - acc: 0.9275 - val_loss: 0.0546 - val_acc: 0.9845
Epoch 2/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0767 - acc: 0.9779 - val_loss: 0.0305 - val_acc: 0.9913
Epoch 3/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0567 - acc: 0.9837 - val_loss: 0.0287 - val_acc: 0.9925
Epoch 4/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0500 - acc: 0.9861 - val_loss: 0.0261 - val_acc: 0.9928
Epoch 5/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0490 - acc: 0.9867 - val_loss: 0.0280 - val_acc: 0.9927
Epoch 6/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0465 - acc: 0.9876 - val_loss: 0.0457 - val_acc: 0.9895
Epoch 7/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0449 - acc: 0.9872 - val_loss: 0.0290 - val_acc: 0.9928
Epoch 8/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0490 - acc: 0.9869 - val_loss: 0.0297 - val_acc: 0.9925
Epoch 9/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0487 - acc: 0.9868 - val_loss: 0.0265 - val_acc: 0.9932
Epoch 10/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0502 - acc: 0.9875 - val_loss: 0.0280 - val_acc: 0.9932
Epoch 11/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0531 - acc: 0.9864 - val_loss: 0.0345 - val_acc: 0.9908
Epoch 12/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0530 - acc: 0.9860 - val_loss: 0.0418 - val_acc: 0.9893
Epoch 13/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0534 - acc: 0.9869 - val_loss: 0.0403 - val_acc: 0.9910
Epoch 14/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0584 - acc: 0.9862 - val_loss: 0.0326 - val_acc: 0.9927
Epoch 15/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0580 - acc: 0.9863 - val_loss: 0.0455 - val_acc: 0.9910
Epoch 16/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0628 - acc: 0.9862 - val_loss: 0.0437 - val_acc: 0.9910
Epoch 17/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0642 - acc: 0.9849 - val_loss: 0.0302 - val_acc: 0.9918
Epoch 18/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0631 - acc: 0.9850 - val_loss: 0.0427 - val_acc: 0.9930
Epoch 19/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0684 - acc: 0.9847 - val_loss: 0.0348 - val_acc: 0.9923
Epoch 20/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0715 - acc: 0.9844 - val_loss: 0.0327 - val_acc: 0.9920
Epoch 21/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0694 - acc: 0.9839 - val_loss: 0.0822 - val_acc: 0.9862
Epoch 22/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0781 - acc: 0.9838 - val_loss: 0.0423 - val_acc: 0.9910
Epoch 23/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0820 - acc: 0.9831 - val_loss: 0.0392 - val_acc: 0.9903
Epoch 24/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0835 - acc: 0.9824 - val_loss: 0.0425 - val_acc: 0.9912
Epoch 25/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0798 - acc: 0.9825 - val_loss: 0.0310 - val_acc: 0.9940
Epoch 26/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0840 - acc: 0.9820 - val_loss: 0.0737 - val_acc: 0.9895
Epoch 27/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0871 - acc: 0.9823 - val_loss: 0.0597 - val_acc: 0.9897
Epoch 28/50
844/844 [==============================] - 84s 99ms/step - loss: 0.0830 - acc: 0.9829 - val_loss: 0.0333 - val_acc: 0.9908
Epoch 29/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0924 - acc: 0.9806 - val_loss: 0.0406 - val_acc: 0.9915
Epoch 30/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0924 - acc: 0.9815 - val_loss: 0.0490 - val_acc: 0.9895
Epoch 31/50
844/844 [==============================] - 84s 99ms/step - loss: 0.0962 - acc: 0.9811 - val_loss: 0.0491 - val_acc: 0.9888
Epoch 32/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0960 - acc: 0.9803 - val_loss: 0.0538 - val_acc: 0.9882
Epoch 33/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1068 - acc: 0.9793 - val_loss: 0.0411 - val_acc: 0.9905
Epoch 34/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1023 - acc: 0.9799 - val_loss: 0.0437 - val_acc: 0.9893
Epoch 35/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0992 - acc: 0.9802 - val_loss: 0.0424 - val_acc: 0.9905
Epoch 36/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1086 - acc: 0.9794 - val_loss: 0.0435 - val_acc: 0.9890
Epoch 37/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1165 - acc: 0.9784 - val_loss: 0.3074 - val_acc: 0.9313
Epoch 38/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1172 - acc: 0.9784 - val_loss: 0.0559 - val_acc: 0.9913
Epoch 39/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1283 - acc: 0.9779 - val_loss: 0.0429 - val_acc: 0.9912
Epoch 40/50
844/844 [==============================] - 84s 99ms/step - loss: 0.1209 - acc: 0.9778 - val_loss: 0.0463 - val_acc: 0.9910
Epoch 41/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1292 - acc: 0.9755 - val_loss: 0.0459 - val_acc: 0.9877
Epoch 42/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1300 - acc: 0.9750 - val_loss: 0.0437 - val_acc: 0.9888
Epoch 43/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1289 - acc: 0.9760 - val_loss: 0.0728 - val_acc: 0.9857
Epoch 44/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1393 - acc: 0.9748 - val_loss: 0.1381 - val_acc: 0.9725
Epoch 45/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1415 - acc: 0.9736 - val_loss: 0.0547 - val_acc: 0.9897
Epoch 46/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1498 - acc: 0.9744 - val_loss: 0.1881 - val_acc: 0.9838
Epoch 47/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1502 - acc: 0.9728 - val_loss: 0.0550 - val_acc: 0.9890
Epoch 48/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1449 - acc: 0.9722 - val_loss: 0.3376 - val_acc: 0.9302
Epoch 49/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1522 - acc: 0.9712 - val_loss: 0.0673 - val_acc: 0.9852
Epoch 50/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1546 - acc: 0.9735 - val_loss: 0.0588 - val_acc: 0.9883
7. Evaluate the Model
fig, ax = plt.subplots(2, 1)
ax[0].plot(history.history['loss'], color='b', label="Training Loss")
ax[0].plot(history.history['val_loss'], color='r', label="Validation Loss")
legend = ax[0].legend(loc='best', shadow=True)
ax[1].plot(history.history['acc'], color='b', label="Training Accuracy")
ax[1].plot(history.history['val_acc'], color='r', label="Validation Accuracy")
legend = ax[1].legend(loc='best', shadow=True)
8. Evaluate the Trained Model on the Test Dataset
test_loss, test_acc = model.evaluate(x_test, y_test)
313/313 [==============================] - 3s 10ms/step - loss: 0.0477 - acc: 0.9877
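Feature (5) in the list at the top mentions exporting to SavedModel. A minimal sketch of that step; the directory name 'mnist_cnn' is an arbitrary example:

# Export the trained model as a SavedModel directory, then reload and re-evaluate it.
model.save('mnist_cnn')
restored = tf.keras.models.load_model('mnist_cnn')
print(restored.evaluate(x_test, y_test, verbose=0))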
9. Use a Confusion Matrix to Examine the Model's Classification Performance
# Predict the values from the testing dataset
Y_pred = model.predict(x_test)
# Convert prediction probabilities to class indices
Y_pred_classes = np.argmax(Y_pred, axis=1)
# Convert the one-hot test labels back to class indices
Y_true = np.argmax(y_test, axis=1)
# Compute the confusion matrix
confusion_mtx = tf.math.confusion_matrix(Y_true, Y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mtx, annot=True, fmt='g')
<AxesSubplot:>
Analysis of the Results
- Judging from the confusion matrix above, the classification results are quite good. The only slightly problematic cells are [0, 6] and [3, 5]. This matches common sense: handwritten 0 and 6 are indeed quite similar, and so are 3 and 5.
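This observation can be backed with numbers by reading the largest off-diagonal entries straight out of the confusion matrix; a small sketch reusing confusion_mtx from step 9:

# Rank the most frequent confusions (largest off-diagonal entries).
cm = confusion_mtx.numpy()
np.fill_diagonal(cm, 0)                  # ignore correct predictions on the diagonal
order = np.argsort(cm, axis=None)[::-1]  # flattened indices, most frequent first
for idx in order[:5]:
    true_cls, pred_cls = np.unravel_index(idx, cm.shape)
    print(f"true {true_cls} predicted as {pred_cls}: {cm[true_cls, pred_cls]} times")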