开发者社区> 阿里云大数据Al技术> 正文
阿里云
为了无法计算的价值
打开APP
阿里云APP内打开

【DSW Gallery】Tensorflow 2构建CNN模型

简介: 本文基于TensorFlow2版本,构建了一个CNN网络,然后基于Mnist手写体数据集进行手写体的识别。本文从模型的定义,数据的加载,处理,模型的训练到最后的结果的分析以及可视化等方面提供了一个端到端的sample。用户可以基于本文了解使用TensorFlow2进行模型开发的整个流程。
+关注继续查看

直接使用

请打开Tensorflow 2构建CNN模型,并点击右上角 “ 在DSW中打开” 。

image.png


Tensorflow2 And Keras

Tensorflow 2是Google公司基于Tensorflow 1开发的深度学习框架。在架构上,API,还有所支持的硬件种类都做了深度的优化。 Tensorflow 2的架构,主要包括两层

  1. 训练层
  2. 部署层

image

Tensorflow 2的主要特性 -(1)使用tf.data加载数据 -(2)使用tf.keras构建模型,也可以使用premade estimator来验证模型,使用tensorflow hub进行迁移学习 -(3)使用eager mode进行运行和调试 -(4)使用分发策略来进行分布式训练 -(5)导出到SaveModel -(6)使用Tensorflow Server、TensorFlow Lite、TensorFlow.js部署模型 -(7)强大的跨平台能力,Tensorflow2服务直接通过HTTP/REST或者GRPC/协议缓冲区实现,TensorFlow Lite可以直接部署在Android、IOS和嵌入式系统上,TensorFlow.js在 javascript中部署模型 -(8)Tf.keras功能API和子类API,允许创建负责的拓扑结构 -(9)自定义训练逻辑,使用tf.GradientTape和tf.custom_gradient进行更细粒度的控制 -(10)底层API可以与高层结合使用,完全的可定制 -(11)高级扩展:Ragged Tensors、Tensor2Tensor

Keras Keras是Tensorflow的高阶API,可以有效的提升模型开发的效率

本文就以Tensorflow 2中的tf.keras为基础,DEMO一下如何使用Keras来开发/训练模型

1. Import Tensorflow

import tensorflow as tf
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

2. Load数据集

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

2.1 可视化的看一下目前的数据集中,各个类别的样本是否均衡

sns.countplot(y_train)
/home/pai/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variable as a keyword arg: x. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning
<AxesSubplot:ylabel='count'>

27-1.png

2.2 查看训练数据中是否有NaN的样本

np.isnan(x_train).any()
False

查看测试数据集中是否有Nan的样本

np.isnan(x_test).any()
False

3. 数据预处理,这里做两件事:

    1. reshape我们的输入数据集,以满足本文中模型的对输入数据形状的要求
    2. 归一化
input_shape = (28, 28, 1)

x_train=x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
x_train=x_train / 255.0
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)
x_test=x_test/255.0

对label进行编码,这里使用one-hot-encoding

y_train = tf.one_hot(y_train.astype(np.int32), depth=10)
y_test = tf.one_hot(y_test.astype(np.int32), depth=10)

4. 构建CNN模型

    1. 使用tf.keras.models.Sequential接口进行构建
    2. 依次添加卷积层tf.keras.layers.Conv2D
    3. MaxPooling层tf.keras.layers.MaxPool2D
    4. Dropout tf.keras.layers.Dropout
    5. 全连接层 tf.keras.layers.Dense
batch_size = 64
num_classes = 10
epochs = 50
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (5,5), padding='same', activation='relu', input_shape=input_shape),
    tf.keras.layers.Conv2D(32, (5,5), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(strides=(2,2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.RMSprop(epsilon=1e-08), loss='categorical_crossentropy', metrics=['acc'])

5. 定义相关的回调函数

  • 这个回调函数是在每一个epoch结束的时候,检查一下accuracy是不是大于99.5%,如果是的话,那么就停止训练
class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
    if(logs.get('acc')>0.995):
      print("\nReached 99.5% accuracy so cancelling training!")
      self.model.stop_training = True

callbacks = myCallback()

6. 开始训练

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_split=0.1,
                    callbacks=[callbacks])

image.png

Click to hideEpoch 1/50
844/844 [==============================] - 85s 101ms/step - loss: 0.2355 - acc: 0.9275 - val_loss: 0.0546 - val_acc: 0.9845
Epoch 2/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0767 - acc: 0.9779 - val_loss: 0.0305 - val_acc: 0.9913
Epoch 3/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0567 - acc: 0.9837 - val_loss: 0.0287 - val_acc: 0.9925
Epoch 4/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0500 - acc: 0.9861 - val_loss: 0.0261 - val_acc: 0.9928
Epoch 5/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0490 - acc: 0.9867 - val_loss: 0.0280 - val_acc: 0.9927
Epoch 6/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0465 - acc: 0.9876 - val_loss: 0.0457 - val_acc: 0.9895
Epoch 7/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0449 - acc: 0.9872 - val_loss: 0.0290 - val_acc: 0.9928
Epoch 8/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0490 - acc: 0.9869 - val_loss: 0.0297 - val_acc: 0.9925
Epoch 9/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0487 - acc: 0.9868 - val_loss: 0.0265 - val_acc: 0.9932
Epoch 10/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0502 - acc: 0.9875 - val_loss: 0.0280 - val_acc: 0.9932
Epoch 11/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0531 - acc: 0.9864 - val_loss: 0.0345 - val_acc: 0.9908
Epoch 12/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0530 - acc: 0.9860 - val_loss: 0.0418 - val_acc: 0.9893
Epoch 13/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0534 - acc: 0.9869 - val_loss: 0.0403 - val_acc: 0.9910
Epoch 14/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0584 - acc: 0.9862 - val_loss: 0.0326 - val_acc: 0.9927
Epoch 15/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0580 - acc: 0.9863 - val_loss: 0.0455 - val_acc: 0.9910
Epoch 16/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0628 - acc: 0.9862 - val_loss: 0.0437 - val_acc: 0.9910
Epoch 17/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0642 - acc: 0.9849 - val_loss: 0.0302 - val_acc: 0.9918
Epoch 18/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0631 - acc: 0.9850 - val_loss: 0.0427 - val_acc: 0.9930
Epoch 19/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0684 - acc: 0.9847 - val_loss: 0.0348 - val_acc: 0.9923
Epoch 20/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0715 - acc: 0.9844 - val_loss: 0.0327 - val_acc: 0.9920
Epoch 21/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0694 - acc: 0.9839 - val_loss: 0.0822 - val_acc: 0.9862
Epoch 22/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0781 - acc: 0.9838 - val_loss: 0.0423 - val_acc: 0.9910
Epoch 23/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0820 - acc: 0.9831 - val_loss: 0.0392 - val_acc: 0.9903
Epoch 24/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0835 - acc: 0.9824 - val_loss: 0.0425 - val_acc: 0.9912
Epoch 25/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0798 - acc: 0.9825 - val_loss: 0.0310 - val_acc: 0.9940
Epoch 26/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0840 - acc: 0.9820 - val_loss: 0.0737 - val_acc: 0.9895
Epoch 27/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0871 - acc: 0.9823 - val_loss: 0.0597 - val_acc: 0.9897
Epoch 28/50
844/844 [==============================] - 84s 99ms/step - loss: 0.0830 - acc: 0.9829 - val_loss: 0.0333 - val_acc: 0.9908
Epoch 29/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0924 - acc: 0.9806 - val_loss: 0.0406 - val_acc: 0.9915
Epoch 30/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0924 - acc: 0.9815 - val_loss: 0.0490 - val_acc: 0.9895
Epoch 31/50
844/844 [==============================] - 84s 99ms/step - loss: 0.0962 - acc: 0.9811 - val_loss: 0.0491 - val_acc: 0.9888
Epoch 32/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0960 - acc: 0.9803 - val_loss: 0.0538 - val_acc: 0.9882
Epoch 33/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1068 - acc: 0.9793 - val_loss: 0.0411 - val_acc: 0.9905
Epoch 34/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1023 - acc: 0.9799 - val_loss: 0.0437 - val_acc: 0.9893
Epoch 35/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0992 - acc: 0.9802 - val_loss: 0.0424 - val_acc: 0.9905
Epoch 36/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1086 - acc: 0.9794 - val_loss: 0.0435 - val_acc: 0.9890
Epoch 37/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1165 - acc: 0.9784 - val_loss: 0.3074 - val_acc: 0.9313
Epoch 38/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1172 - acc: 0.9784 - val_loss: 0.0559 - val_acc: 0.9913
Epoch 39/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1283 - acc: 0.9779 - val_loss: 0.0429 - val_acc: 0.9912
Epoch 40/50
844/844 [==============================] - 84s 99ms/step - loss: 0.1209 - acc: 0.9778 - val_loss: 0.0463 - val_acc: 0.9910
Epoch 41/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1292 - acc: 0.9755 - val_loss: 0.0459 - val_acc: 0.9877
Epoch 42/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1300 - acc: 0.9750 - val_loss: 0.0437 - val_acc: 0.9888
Epoch 43/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1289 - acc: 0.9760 - val_loss: 0.0728 - val_acc: 0.9857
Epoch 44/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1393 - acc: 0.9748 - val_loss: 0.1381 - val_acc: 0.9725
Epoch 45/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1415 - acc: 0.9736 - val_loss: 0.0547 - val_acc: 0.9897
Epoch 46/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1498 - acc: 0.9744 - val_loss: 0.1881 - val_acc: 0.9838
Epoch 47/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1502 - acc: 0.9728 - val_loss: 0.0550 - val_acc: 0.9890
Epoch 48/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1449 - acc: 0.9722 - val_loss: 0.3376 - val_acc: 0.9302
Epoch 49/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1522 - acc: 0.9712 - val_loss: 0.0673 - val_acc: 0.9852
Epoch 50/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1546 - acc: 0.9735 - val_loss: 0.0588 - val_acc: 0.9883

7. Evaluate模型

fig, ax = plt.subplots(2,1)
ax[0].plot(history.history['loss'], color='b', label="Training Loss")
ax[0].plot(history.history['val_loss'], color='r', label="Validation Loss",axes =ax[0])
legend = ax[0].legend(loc='best', shadow=True)

ax[1].plot(history.history['acc'], color='b', label="Training Accuracy")
ax[1].plot(history.history['val_acc'], color='r',label="Validation Accuracy")
legend = ax[1].legend(loc='best', shadow=True)

27-2.png

8. 使用test数据集做为输入,基于训练好的模型输出结果

test_loss, test_acc = model.evaluate(x_test, y_test)
313/313 [==============================] - 3s 10ms/step - loss: 0.0477 - acc: 0.9877

9. 使用confusion matrix来观察模型的分类效果

# Predict the values from the testing dataset
Y_pred = model.predict(x_test)
# Convert predictions classes to one hot vectors 
Y_pred_classes = np.argmax(Y_pred,axis = 1) 
# Convert testing observations to one hot vectors
Y_true = np.argmax(y_test,axis = 1)
# compute the confusion matrix
confusion_mtx = tf.math.confusion_matrix(Y_true, Y_pred_classes) 

plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mtx, annot=True, fmt='g')
<AxesSubplot:>

27-3.png

结果分析

  • 从上面的混淆矩阵结果来看,分类效果是比较理想的。唯一有点问题的就是[0,6]和[3,5].这个从我们的常识也知道,0和6这两个数字,如果是手写体的话,确实相似度比较高.还有3和5

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

相关文章
【DSW Gallery】使用Tensorflow来构建AutoEncoder
本文基于TensorFlow 1.x版本,实现了一个自编码器。自编码器是一个应用比较广泛的神经网络。他可以用来做非监督的异常检测,也可以用在特征工程之中,衡量feature之间的高阶非线性关系等等。
0 0
【DSW Gallery】 XGBoost:如何使用XGBoost解决回归问题
XGBoost作为机器学习领域的一款经典的Boosting算法,深受学界和工业界的推崇。其中很重要的一点就是它具有优秀的鲁棒性,并且在工程实现上面进行了大量的优化,在模型的复杂度和性能之间取得了很好的平衡。
0 0
一步一步学用Tensorflow构建卷积神经网络
本文主要和大家分享如何使用Tensorflow从头开始构建和训练卷积神经网络。这样就可以将这个知识作为一个构建块来创造有趣的深度学习应用程序了。
17836 0
TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略
TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略
0 0
简明 TensorFlow 教程 —  第三部分: 所有的模型
本文讲的是简明 TensorFlow 教程 —  第三部分: 所有的模型,在本文中,我们将讨论 TensorFlow 中当前可用的所有抽象模型,并描述该特定模型的用例以及简单的示例代码。 完整的工作示例源码。
980 0
TensorFlow构建循环神经网络
前言 前面在《循环神经网络》文章中已经介绍了深度学习的循环神经网络模型及其原理,接下去这篇文章将尝试使用TensorFlow来实现一个循环神经网络,该例子能通过训练给定的语料生成模型并实现对字符的预测。
1202 0
【DSW Gallery】使用Numpy实现卷积神经网络
Numpy是数值计算中使用非常广泛的一个工具包,可以进行高纬度空间内部的矩阵运算。本文以CNN为例子,使用Numpy来实现CNN网络的前向传递和反向传递逻辑。对于了解CNN网络的细节以及学习如何使用Numpy都很有帮助。
0 0
使用TensorFlow提供的slim模型来训练数据模型供iOS使用
使用slim模型来训练数据供移动端使用 1、  数据可以是slim提供的数据集或者是自己采集的图片 1.1、下载slim提供的数据集flowers 1.1.1、设置下载目录命令: DATA_DIR=/Users/javalong/Desktop/Test/output/flowers 1.
4228 0
+关注
阿里云大数据Al技术
阿里云大数据Al技术
文章
问答
来源圈子
更多
相关文档: 机器学习平台PAI
文章排行榜
最热
最新
相关电子书
更多
ADMM based Scalable Machine Learning on Apache Spark
立即下载
用AI高效测试移动应用
立即下载
ST+AliOS > Smart IoT !
立即下载