【DSW Gallery】Tensorflow 2构建CNN模型

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: 本文基于TensorFlow2版本,构建了一个CNN网络,然后基于Mnist手写体数据集进行手写体的识别。本文从模型的定义,数据的加载,处理,模型的训练到最后的结果的分析以及可视化等方面提供了一个端到端的sample。用户可以基于本文了解使用TensorFlow2进行模型开发的整个流程。

直接使用

请打开Tensorflow 2构建CNN模型,并点击右上角 “ 在DSW中打开” 。

image.png


Tensorflow2 And Keras

Tensorflow 2是Google公司基于Tensorflow 1开发的深度学习框架。在架构上,API,还有所支持的硬件种类都做了深度的优化。 Tensorflow 2的架构,主要包括两层

  1. 训练层
  2. 部署层

网络异常,图片无法展示
|

Tensorflow 2的主要特性 -(1)使用tf.data加载数据 -(2)使用tf.keras构建模型,也可以使用premade estimator来验证模型,使用tensorflow hub进行迁移学习 -(3)使用eager mode进行运行和调试 -(4)使用分发策略来进行分布式训练 -(5)导出到SaveModel -(6)使用Tensorflow Server、TensorFlow Lite、TensorFlow.js部署模型 -(7)强大的跨平台能力,Tensorflow2服务直接通过HTTP/REST或者GRPC/协议缓冲区实现,TensorFlow Lite可以直接部署在Android、IOS和嵌入式系统上,TensorFlow.js在 javascript中部署模型 -(8)Tf.keras功能API和子类API,允许创建负责的拓扑结构 -(9)自定义训练逻辑,使用tf.GradientTape和tf.custom_gradient进行更细粒度的控制 -(10)底层API可以与高层结合使用,完全的可定制 -(11)高级扩展:Ragged Tensors、Tensor2Tensor

Keras Keras是Tensorflow的高阶API,可以有效的提升模型开发的效率

本文就以Tensorflow 2中的tf.keras为基础,DEMO一下如何使用Keras来开发/训练模型

1. Import Tensorflow

import tensorflow as tf
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

2. Load数据集

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

2.1 可视化的看一下目前的数据集中,各个类别的样本是否均衡

sns.countplot(y_train)
/home/pai/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variable as a keyword arg: x. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning
<AxesSubplot:ylabel='count'>

27-1.png

2.2 查看训练数据中是否有NaN的样本

np.isnan(x_train).any()
False

查看测试数据集中是否有Nan的样本

np.isnan(x_test).any()
False

3. 数据预处理,这里做两件事:

  1. reshape我们的输入数据集,以满足本文中模型的对输入数据形状的要求
  2. 归一化
input_shape = (28, 28, 1)
x_train=x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
x_train=x_train / 255.0
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)
x_test=x_test/255.0

对label进行编码,这里使用one-hot-encoding

y_train = tf.one_hot(y_train.astype(np.int32), depth=10)
y_test = tf.one_hot(y_test.astype(np.int32), depth=10)

4. 构建CNN模型

  1. 使用tf.keras.models.Sequential接口进行构建
  2. 依次添加卷积层tf.keras.layers.Conv2D
  3. MaxPooling层tf.keras.layers.MaxPool2D
  4. Dropout tf.keras.layers.Dropout
  5. 全连接层 tf.keras.layers.Dense
batch_size = 64
num_classes = 10
epochs = 50
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (5,5), padding='same', activation='relu', input_shape=input_shape),
    tf.keras.layers.Conv2D(32, (5,5), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(strides=(2,2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.RMSprop(epsilon=1e-08), loss='categorical_crossentropy', metrics=['acc'])

5. 定义相关的回调函数

  • 这个回调函数是在每一个epoch结束的时候,检查一下accuracy是不是大于99.5%,如果是的话,那么就停止训练
class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
    if(logs.get('acc')>0.995):
      print("\nReached 99.5% accuracy so cancelling training!")
      self.model.stop_training = True
callbacks = myCallback()

6. 开始训练

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_split=0.1,
                    callbacks=[callbacks])

image.png

Click to hideEpoch 1/50
844/844 [==============================] - 85s 101ms/step - loss: 0.2355 - acc: 0.9275 - val_loss: 0.0546 - val_acc: 0.9845
Epoch 2/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0767 - acc: 0.9779 - val_loss: 0.0305 - val_acc: 0.9913
Epoch 3/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0567 - acc: 0.9837 - val_loss: 0.0287 - val_acc: 0.9925
Epoch 4/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0500 - acc: 0.9861 - val_loss: 0.0261 - val_acc: 0.9928
Epoch 5/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0490 - acc: 0.9867 - val_loss: 0.0280 - val_acc: 0.9927
Epoch 6/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0465 - acc: 0.9876 - val_loss: 0.0457 - val_acc: 0.9895
Epoch 7/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0449 - acc: 0.9872 - val_loss: 0.0290 - val_acc: 0.9928
Epoch 8/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0490 - acc: 0.9869 - val_loss: 0.0297 - val_acc: 0.9925
Epoch 9/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0487 - acc: 0.9868 - val_loss: 0.0265 - val_acc: 0.9932
Epoch 10/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0502 - acc: 0.9875 - val_loss: 0.0280 - val_acc: 0.9932
Epoch 11/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0531 - acc: 0.9864 - val_loss: 0.0345 - val_acc: 0.9908
Epoch 12/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0530 - acc: 0.9860 - val_loss: 0.0418 - val_acc: 0.9893
Epoch 13/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0534 - acc: 0.9869 - val_loss: 0.0403 - val_acc: 0.9910
Epoch 14/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0584 - acc: 0.9862 - val_loss: 0.0326 - val_acc: 0.9927
Epoch 15/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0580 - acc: 0.9863 - val_loss: 0.0455 - val_acc: 0.9910
Epoch 16/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0628 - acc: 0.9862 - val_loss: 0.0437 - val_acc: 0.9910
Epoch 17/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0642 - acc: 0.9849 - val_loss: 0.0302 - val_acc: 0.9918
Epoch 18/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0631 - acc: 0.9850 - val_loss: 0.0427 - val_acc: 0.9930
Epoch 19/50
844/844 [==============================] - 85s 101ms/step - loss: 0.0684 - acc: 0.9847 - val_loss: 0.0348 - val_acc: 0.9923
Epoch 20/50
844/844 [==============================] - 85s 100ms/step - loss: 0.0715 - acc: 0.9844 - val_loss: 0.0327 - val_acc: 0.9920
Epoch 21/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0694 - acc: 0.9839 - val_loss: 0.0822 - val_acc: 0.9862
Epoch 22/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0781 - acc: 0.9838 - val_loss: 0.0423 - val_acc: 0.9910
Epoch 23/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0820 - acc: 0.9831 - val_loss: 0.0392 - val_acc: 0.9903
Epoch 24/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0835 - acc: 0.9824 - val_loss: 0.0425 - val_acc: 0.9912
Epoch 25/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0798 - acc: 0.9825 - val_loss: 0.0310 - val_acc: 0.9940
Epoch 26/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0840 - acc: 0.9820 - val_loss: 0.0737 - val_acc: 0.9895
Epoch 27/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0871 - acc: 0.9823 - val_loss: 0.0597 - val_acc: 0.9897
Epoch 28/50
844/844 [==============================] - 84s 99ms/step - loss: 0.0830 - acc: 0.9829 - val_loss: 0.0333 - val_acc: 0.9908
Epoch 29/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0924 - acc: 0.9806 - val_loss: 0.0406 - val_acc: 0.9915
Epoch 30/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0924 - acc: 0.9815 - val_loss: 0.0490 - val_acc: 0.9895
Epoch 31/50
844/844 [==============================] - 84s 99ms/step - loss: 0.0962 - acc: 0.9811 - val_loss: 0.0491 - val_acc: 0.9888
Epoch 32/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0960 - acc: 0.9803 - val_loss: 0.0538 - val_acc: 0.9882
Epoch 33/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1068 - acc: 0.9793 - val_loss: 0.0411 - val_acc: 0.9905
Epoch 34/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1023 - acc: 0.9799 - val_loss: 0.0437 - val_acc: 0.9893
Epoch 35/50
844/844 [==============================] - 84s 100ms/step - loss: 0.0992 - acc: 0.9802 - val_loss: 0.0424 - val_acc: 0.9905
Epoch 36/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1086 - acc: 0.9794 - val_loss: 0.0435 - val_acc: 0.9890
Epoch 37/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1165 - acc: 0.9784 - val_loss: 0.3074 - val_acc: 0.9313
Epoch 38/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1172 - acc: 0.9784 - val_loss: 0.0559 - val_acc: 0.9913
Epoch 39/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1283 - acc: 0.9779 - val_loss: 0.0429 - val_acc: 0.9912
Epoch 40/50
844/844 [==============================] - 84s 99ms/step - loss: 0.1209 - acc: 0.9778 - val_loss: 0.0463 - val_acc: 0.9910
Epoch 41/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1292 - acc: 0.9755 - val_loss: 0.0459 - val_acc: 0.9877
Epoch 42/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1300 - acc: 0.9750 - val_loss: 0.0437 - val_acc: 0.9888
Epoch 43/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1289 - acc: 0.9760 - val_loss: 0.0728 - val_acc: 0.9857
Epoch 44/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1393 - acc: 0.9748 - val_loss: 0.1381 - val_acc: 0.9725
Epoch 45/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1415 - acc: 0.9736 - val_loss: 0.0547 - val_acc: 0.9897
Epoch 46/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1498 - acc: 0.9744 - val_loss: 0.1881 - val_acc: 0.9838
Epoch 47/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1502 - acc: 0.9728 - val_loss: 0.0550 - val_acc: 0.9890
Epoch 48/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1449 - acc: 0.9722 - val_loss: 0.3376 - val_acc: 0.9302
Epoch 49/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1522 - acc: 0.9712 - val_loss: 0.0673 - val_acc: 0.9852
Epoch 50/50
844/844 [==============================] - 84s 100ms/step - loss: 0.1546 - acc: 0.9735 - val_loss: 0.0588 - val_acc: 0.9883

7. Evaluate模型

fig, ax = plt.subplots(2,1)
ax[0].plot(history.history['loss'], color='b', label="Training Loss")
ax[0].plot(history.history['val_loss'], color='r', label="Validation Loss",axes =ax[0])
legend = ax[0].legend(loc='best', shadow=True)
ax[1].plot(history.history['acc'], color='b', label="Training Accuracy")
ax[1].plot(history.history['val_acc'], color='r',label="Validation Accuracy")
legend = ax[1].legend(loc='best', shadow=True)

27-2.png

8. 使用test数据集做为输入,基于训练好的模型输出结果

test_loss, test_acc = model.evaluate(x_test, y_test)
313/313 [==============================] - 3s 10ms/step - loss: 0.0477 - acc: 0.9877

9. 使用confusion matrix来观察模型的分类效果

# Predict the values from the testing dataset
Y_pred = model.predict(x_test)
# Convert predictions classes to one hot vectors 
Y_pred_classes = np.argmax(Y_pred,axis = 1) 
# Convert testing observations to one hot vectors
Y_true = np.argmax(y_test,axis = 1)
# compute the confusion matrix
confusion_mtx = tf.math.confusion_matrix(Y_true, Y_pred_classes) 
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mtx, annot=True, fmt='g')
<AxesSubplot:>

27-3.png

结果分析

  • 从上面的混淆矩阵结果来看,分类效果是比较理想的。唯一有点问题的就是[0,6]和[3,5].这个从我们的常识也知道,0和6这两个数字,如果是手写体的话,确实相似度比较高.还有3和5
相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
87 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【大作业-04】手把手教你构建垃圾分类系统-基于tensorflow2.3
本文介绍了基于TensorFlow 2.3的垃圾分类系统,通过B站视频和博客详细讲解了系统的构建过程。系统使用了包含8万张图片、245个类别的数据集,训练了LeNet和MobileNet两个卷积神经网络模型,并通过PyQt5构建了图形化界面,用户上传图片后,系统能识别垃圾的具体种类。此外,还提供了模型和数据集的下载链接,方便读者复现实验。垃圾分类对于提高资源利用率、减少环境污染具有重要意义。
42 0
【大作业-04】手把手教你构建垃圾分类系统-基于tensorflow2.3
|
11天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
39 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
11天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
51 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
16天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
39 3
|
28天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
72 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
1月前
|
机器学习/深度学习 数据可视化 TensorFlow
使用TensorFlow构建一个简单的图像分类模型
【10月更文挑战第18天】使用TensorFlow构建一个简单的图像分类模型
51 1
|
1月前
|
机器学习/深度学习 SQL 数据采集
基于tensorflow、CNN网络识别花卉的种类(图像识别)
基于tensorflow、CNN网络识别花卉的种类(图像识别)
30 1
|
1月前
|
机器学习/深度学习 TensorFlow API
使用 TensorFlow 和 Keras 构建图像分类器
【10月更文挑战第2天】使用 TensorFlow 和 Keras 构建图像分类器
|
1月前
|
机器学习/深度学习 编解码 算法
【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5
【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5
41 0

热门文章

最新文章

下一篇
无影云桌面