Python深度学习入门——手写数字分类

简介: Python深度学习入门——手写数字分类

什么是 Keras


Keras 是基于 TensorFlowTheano(由加拿大蒙特利尔大学开发的机器学习框架)的深度学习库,是由纯 python 编写而成的高层神经网络 API,也仅支持 Python 开发。它是为了支持快速实践而对 Tensorflow 或者 Theano 的再次封装,让我们可以不用关注过多的底层细节,能够把想法快速转换为结果。它也很灵活,且比较容易学。


安装 Keras


使用豆瓣镜像源安装 Keras 库。

pip install -i https://pypi.douban.com/simple Keras
复制代码


手写数字分类


导入数据集

加载 Keras 中的 MNIST 数据集。

from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
复制代码


训练集

  • train_images:训练集样本
  • train_labels:训练集标签


测试集

  • test_images:测试集样本
  • test_labels:测试集标签


查看数据集的形状

train_images.shape
# Out: (60000, 28, 28)
train_labels.shape
# Out: (60000,)
test_images.shape
# Out: (10000, 28, 28)
test_labels.shape
# Out: (10000,)
复制代码


构建网络


层(layer)是神经网络的核心组件,它是一种数据处理模块,你可以将它看成数据过滤器。进去一些数据,出来的数据变得更加有用。大多数深度学习都是将简单的层链接起来,从而实现渐进式 的数据蒸馏(data distillation)。深度学习模型就像是数据处理的筛子,包含一系列越来越精细的数据过滤器(即层)。


下面先导入所需模块,构造一个序列模型(Sequential),序列模型是多个网络层的线性堆叠。即“一条路走到黑”。

from keras import models
from keras import layers
network = models.Sequential()
复制代码


通过 add 方法将 layer 加入模型中。

network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
复制代码


主要参数

  • units:神经元节点数,即输出空间维度
  • activation:激活函数,若不指定,则不使用激活函数 (即线性激活: a(x) = x)
  • input_shape:即张量的形状


relu 为线性整流函数,它返回逐元素的 max(x, 0)。

再添加第二层,一个 10 路 softmax 层,通过 Softmax 函数可以将多分类的输出值转换为范围在 [0, 1]和为 1 的概率分布。

network.add(layers.Dense(10, activation='softmax'))
复制代码


编译(compile)

训练网络之前,我们还需要选择编译步骤的三个参数。

  • 损失函数(loss function):网络如何衡量在训练数据上的性能。
  • 优化器(optimizer):基于训练数据和损失函数来更新网络的机制。
  • 在训练和测试过程中需要监控的指标(metric):本例只关心精度,即正确分类的图像所 占的比例。
network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
复制代码


主要参数

  • RMSprop:RMSProp 优化器是 AdaGrad 算法的一种改进。将梯度除以最近幅度的移动平均值。
  • categorical_crossentropy:分类交叉熵,推导公式

−∑i=1outputsize yi×log⁡y^i-\sum_{i=1}^{\text {outputsize }} y_{i} \times \log _{\hat{y}_{i}}i=1outputsize yi×logy^i

对于损失函数和优化器后续文章会详细讲解


数据预处理

在开始训练之前,我们将对数据进行预处理,将其变换为 network 要求的形状 ,并缩放到所有值都在 [0, 1] 区间。


比如,之前训练图像保存在一个 uint8 类型的数组中,其形状为 (60000, 28, 28),取值区间为 [0, 255]。我们需要将其变换为一个 float32 数组,其形状为(60000, 28 * 28),取值范围为 0~1。


train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
复制代码


类别转换独热编码

现在我们需要对标签进行分类编码,即将类别标签转换为二进制(只包括0和1)的矩阵类型表示。


看一个简单的例子。我们定义一个类别标签 labels ,并通过 keras.utils.to_categorical 将其转换为独热向量。

from keras.utils import to_categorical
labels = [0,1,2,3,4,5]
convert_to_one_hot = to_categorical(labels)
convert_to_one_hot
复制代码

image.png

可以看到,原来类别标签中的每个值都转换为矩阵里的一个行向量。原标签中的 0 为[1. 0. 0. 0. 0. 0. 0. 0. 0.],第一个作为有效位,其余全部为0。

下面回到本例,对标签进行分类编码。

from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
复制代码


我们可以看一下现在训练集标签(train_labels)的形状。

# 转换前 Out: (60000,)
train_labels.shape
# Out: (60000, 10)
复制代码


训练网络

现在我们开始训练网络,通过 fit 方法来训练 network

network.fit(train_images, train_labels, epochs=5, batch_size=128)
复制代码

image.png


主要参数

  • train_images:训练集样本
  • train_labels:训练集标签
  • epochs:训练模型迭代次数
  • batch_size:每次梯度更新的样本数。在深度学习中,一般采用 SGD 训练,即每次训练在训练集中取 batchsize 个样本训练

上述每次训练输出两个值:一个是网络在训练数据上的损失(loss),即当前输出与预期值的差距,另一个是网络在训练数据上的精度(acc)。

可以看到 loss 的值随着训练次数的增加不断降低,精度最终也达到了98.9%,下面看一下模型在测试集上的性能。


模型测试

test_loss, test_acc = network.evaluate(test_images, test_labels)
test_loss, test_acc
复制代码

image.png

测试集精度为 97.9%,比训练集精度要低。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的,导致模型的泛化性能较差。



相关文章
|
1天前
|
缓存 算法 数据处理
Python入门:9.递归函数和高阶函数
在 Python 编程中,函数是核心组成部分之一。递归函数和高阶函数是 Python 中两个非常重要的特性。递归函数帮助我们以更直观的方式处理重复性问题,而高阶函数通过函数作为参数或返回值,为代码增添了极大的灵活性和优雅性。无论是实现复杂的算法还是处理数据流,这些工具都在开发者的工具箱中扮演着重要角色。本文将从概念入手,逐步带你掌握递归函数、匿名函数(lambda)以及高阶函数的核心要领和应用技巧。
Python入门:9.递归函数和高阶函数
|
1天前
|
开发者 Python
Python入门:8.Python中的函数
### 引言 在编写程序时,函数是一种强大的工具。它们可以将代码逻辑模块化,减少重复代码的编写,并提高程序的可读性和可维护性。无论是初学者还是资深开发者,深入理解函数的使用和设计都是编写高质量代码的基础。本文将从基础概念开始,逐步讲解 Python 中的函数及其高级特性。
Python入门:8.Python中的函数
|
1天前
|
存储 SQL 索引
Python入门:7.Pythond的内置容器
Python 提供了强大的内置容器(container)类型,用于存储和操作数据。容器是 Python 数据结构的核心部分,理解它们对于写出高效、可读的代码至关重要。在这篇博客中,我们将详细介绍 Python 的五种主要内置容器:字符串(str)、列表(list)、元组(tuple)、字典(dict)和集合(set)。
Python入门:7.Pythond的内置容器
|
1天前
|
存储 索引 Python
Python入门:6.深入解析Python中的序列
在 Python 中,**序列**是一种有序的数据结构,广泛应用于数据存储、操作和处理。序列的一个显著特点是支持通过**索引**访问数据。常见的序列类型包括字符串(`str`)、列表(`list`)和元组(`tuple`)。这些序列各有特点,既可以存储简单的字符,也可以存储复杂的对象。 为了帮助初学者掌握 Python 中的序列操作,本文将围绕**字符串**、**列表**和**元组**这三种序列类型,详细介绍其定义、常用方法和具体示例。
Python入门:6.深入解析Python中的序列
|
1天前
|
知识图谱 Python
Python入门:4.Python中的运算符
Python是一间强大而且便捷的编程语言,支持多种类型的运算符。在Python中,运算符被分为算术运算符、赋值运算符、复合赋值运算符、比较运算符和逻辑运算符等。本文将从基础到进阶进行分析,并通过一个综合案例展示其实际应用。
|
1天前
|
程序员 UED Python
Python入门:3.Python的输入和输出格式化
在 Python 编程中,输入与输出是程序与用户交互的核心部分。而输出格式化更是对程序表达能力的极大增强,可以让结果以清晰、美观且易读的方式呈现给用户。本文将深入探讨 Python 的输入与输出操作,特别是如何使用格式化方法来提升代码质量和可读性。
Python入门:3.Python的输入和输出格式化
|
1天前
|
存储 Linux iOS开发
Python入门:2.注释与变量的全面解析
在学习Python编程的过程中,注释和变量是必须掌握的两个基础概念。注释帮助我们理解代码的意图,而变量则是用于存储和操作数据的核心工具。熟练掌握这两者,不仅能提高代码的可读性和维护性,还能为后续学习复杂编程概念打下坚实的基础。
Python入门:2.注释与变量的全面解析
|
1天前
|
机器学习/深度学习 人工智能 算法框架/工具
Python入门:1.Python介绍
Python是一种功能强大、易于学习和运行的解释型高级语言。由**Guido van Rossum**于1991年创建,Python以其简洁、易读和十分工程化的设计而带来了庞大的用户群体和丰富的应用场景。这个语言在全球范围内都被认为是**创新和效率的重要工具**。
Python入门:1.Python介绍
|
7天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
51 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
1月前
|
机器学习/深度学习 算法 前端开发
基于Python深度学习果蔬识别系统实现
本项目基于Python和TensorFlow,使用ResNet卷积神经网络模型,对12种常见果蔬(如土豆、苹果等)的图像数据集进行训练,构建了一个高精度的果蔬识别系统。系统通过Django框架搭建Web端可视化界面,用户可上传图片并自动识别果蔬种类。该项目旨在提高农业生产效率,广泛应用于食品安全、智能农业等领域。CNN凭借其强大的特征提取能力,在图像分类任务中表现出色,为实现高效的自动化果蔬识别提供了技术支持。
基于Python深度学习果蔬识别系统实现

推荐镜像

更多