keras实战项目——CIFAR-10 图像分类

简介: 本文将首先介绍这些深度神经网络的零件,然后再分别介绍上游的批量输入模块,以及下游的凸优化模块。

我们可以简单的将深度神经网络的模块,分成以下的三个部分,即深度神经网络上游的基于生成器的 输入模块,深度神经网络本身,以及深度神经网络下游基于批量梯度下降算法的 凸优化模块:

批量输入模块
各种深度学习零件搭建的深度神经网络
凸优化模块
其中,搭建深度神经网络的零件又可以分成以下类别:

各种深度学习零件搭建的深度神经网络

image

需要强调一下,这些层与之前一样,都 同时包括了正向传播、反向传播两条通路。我们这里只介绍比较好理解的正向传播过程,基于其导数的反向过程同样也是存在的,其代码已经包括在 Tensorflow 的框架中对应的模块里,可以直接使用。

接下来的部分,我们将首先介绍这些深度神经网络的零件,然后再分别介绍上游的批量输入模块,以及下游的凸优化模块。

1.深度神经网络的基本零件

1.1 常用层:

1.1.1. Dense

Dense 层,就是我们上一篇文章里提到的 Linear 层,即 y=wx+b ,计算乘法以及加法。

1.1.2. Activation

Activation 层在我们上一篇文章中,同样出现过,即 Tanh层以及Sigmoid层,他们都是 Activation 层的一种。当然 Activation 不止有这两种形式,比如有:

image

这其中 relu 层可能是深度学习时代最重要的一种激发函数,在2011年首次被提出。由公式可见,relu 相比早期的 tanh 与 sigmoid函数, relu 有两个重要的特点,其一是在较小处都是0(sigmoid,relu)或者-5(tanh),但是较大值relu函数没有取值上限。其次是relu层在0除不可导,是一个非线性的函数:

image

即 y=x*(x>0)

对其求导,其结果是:

image

1.1.3. Dropout

Dropout 层,指的是在训练过程中,每次更新参数时将会随机断开一定百分比(rate)的输入神经元,这种方式可以用于防止过拟合。

image

1.1.4. Flatten

Flatten 层,指的是将高维的张量(Tensor, 如二维的矩阵、三维的3D矩阵等)变成一个一维张量(向量)。Flatten 层通常位于连接深度神经网络的 卷积层部分 以及 全连接层部分。

1.2 卷积层

提到卷积层,就必须讲一下卷积神经网络。我们在第一讲的最后部分提高深度学习模型预测的准确性部分,提了一句 “使用更复杂的深度学习网络,在图片中挖出数以百万计的特征”。这种“更复杂的神经网络”,指的就是卷积神经网络。卷积神经网络相比之前基于 Dense 层建立的神经网络,有所区别之处在于,卷积神经网络可以使用更少的参数,对局部特征有更好的理解。

1.2.1. Conv2D

我们这里以2D 的卷积神经网络为例,来逐一介绍卷积神经网络中的重要函数。比如我们使用一个形状如下的卷积核:

image

扫描这样一个二维矩阵,比如一张图片:

image

其过程与结果会是这样:

image

当然,这里很重要的一点,就是正如我们上一讲提到的, Linear 函数的 w, b两个参数都是变量,会在不断的训练中,不断学习更新。卷积神经网络中,卷积核其实也是一个变量。这里的

image

可能只是初始值,也可能是某一次迭代时选用的值。随着模型的不断训练,将会不断的更新成其他值,结果也将会是一个不规则的形状。清楚了其原理,卷积神经网络还需要再理解几个输入参数:

Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', ...)

其中:

filters 指的是输出的卷积层的层数。如上面的动图,只输出了一个卷积层,filters = 1,而实际运用过程中,一次会输出很多卷积层。
kernel_size 指的是卷积层的大小,是一个 二维数组,分别代表卷积层有几行、几列。
strides 指的是卷积核在输入层扫描时,在 x,y 两个方向,每间隔多长扫执行一次扫描。
padding 这里指的是是否扫描边缘。如果是 valid,则仅仅扫描已知的矩阵,即忽略边缘。而如果是 same,则将根据情况在边缘补上0,并且扫描边缘,使得输出的大小等于 input_size / strides。

1.2.2. Cropping2D

这里 Cropping2D 就比较好理解了,就是特地选取输入图像的某一个固定的小部分。比如车载摄像头检测路面的马路线时,摄像头上半部分拍到的天空就可以被 Cropping2D函数直接切掉忽略不计。

image

1.2.3. ZeroPadding2D

1.2.1部分提到输入参数时,提到 padding参数如果是same,扫描图像边缘时会补上0,确保输出数量等于 input / strides。这里 ZeroPadding2D 的作用,就是在图像外层边缘补上几层0。如下图,就是对原本 32x32x3 的图片进行 ZeroPadding2D(padding=(2, 2)) 操作后的结果:

image

1.3. 池化层

1.3.1. MaxPooling2D

可能大家在上一部分会意识到一点,就是通过与一个相同的、大小为11x11的卷积核做卷积操作,每次移动步长为1,则相邻的结果会非常接近,正是由于结果接近,有很多信息是冗余的。

因此,MaxPooling 就是一种减少模型冗余程度的方法。以 2x 2 MaxPooling 为例。图中如果是一个 4x4 的输入矩阵,则这个 4x4 的矩阵,会被分割成由两行、两列组成的 2x2 子矩阵,然后每个 2x2 子矩阵取一个最大值作为代表,由此得到一个两行、两列的结果:

image

1.3.2. AveragePooling2D

AveragePooling 与 MaxPooling 类似,不同的是一个取最大值,一个是平均值。如果上图的 MaxPooling 换成 AveragePooling2D,结果会是:

image

1.4.正则化层

除了之前提到的 Dropout 策略,以及用 GlobalAveragePooling取代全连接层的策略,还有一种方法可以降低网络的过拟合,就是正则化,这里着重介绍下 BatchNormalization。

1.4.1. BatchNormalization

BatchNormalization 确实适合降低过拟合,但他提出的本意,是为了加速神经网络训练的收敛速度。比如我们进行最优值搜索时,我们不清楚最优值位于哪里,可能是上千、上万,也可能是个负数。这种不确定性,会造成搜索时间的浪费。

BatchNormalization就是一种将需要进行最优值搜索数据,转换成标准正态分布,这样optimizer就可以加速优化:

输入:一批input 数据: B

期望输出: β,γ

image

2.深度神经网络的上下游结构

深度神经网络的参数大小动辄几十M、上百M,如何合理训练这些参数是个大问题。这就需要在这个网络的上下游,合理处理这个问题。

海量参数背后的意义是,深度神经网络可以获取海量的特征。第一讲中提到过,深度学习是脱胎于传统机器学习的,两者之间的区别,就是深度学习可以在图像处理中,自动进行特征工程,如我们第一讲所言:

想让计算机帮忙挖掘、标注这些更多的特征,这就离不开 更优化的模型 了。事实上,这几年深度学习领域的新进展,就是以这个想法为基础产生的。我们可以使用更复杂的深度学习网络,在图片中挖出数以百万计的特征。

这时候,问题也就来了。机器学习过程中,是需要一个输入文件的。这个输入文件的行、列,分别指代样本名称以及特征名称。如果是进行百万张图片的分类,每个图片都有数以百万计的特征,我们将拿到一个 百万样本 x 百万特征 的巨型矩阵。传统的机器学习方法拿到这个矩阵时,受限于计算机内存大小的限制,通常是无从下手的。也就是说,传统机器学习方法,除了在多数情况下不会自动产生这么多的特征以外,模型的训练也会是一个大问题。

深度学习算法为了实现对这一量级数据的计算,做了以下算法以及工程方面的创新:

将全部所有数据按照样本拆分成若干批次,每个批次大小通常在十几个到100多个样本之间。(详见下文 输入模块)

将产生的批次逐一参与训练,更新参数。(详见下文 凸优化模块)
使用 GPU 等计算卡代替 CPU,加速并行计算速度。

这就有点《愚公移山》的意思了。我们可以把训练深度神经网络的训练任务,想象成是搬走一座大山。成语故事中,愚公的办法是既然没有办法直接把山搬走,那就子子孙孙,每人每天搬几筐土走,山就会越来越矮,总有一天可以搬完——这种任务分解方式就如同深度学习算法的分批训练方式。同时,随着科技进步,可能搬着搬着就用翻斗车甚至是高达来代替背筐,就相当于是用 GPU 等高并行计算卡代替了 CPU。

于是,我们这里将主要提到的上游输入模块,以及下游凸优化模块,实际上就是在说如何使用愚公移山的策略,用 少量多次 的方法,去“搬”深度神经网络背后大规模计算量这座大山。

2.2. 输入模块

这一部分实际是在说,当我们有成千上万的图片,存在硬盘中时,如何实现一个函数,每调用一次,就会读取指定张数的图片(以n=32为例),将其转化成矩阵,返回输出。

有 Python 基础的人可能意识到了,这里可能是使用了 Python 的 生成器 特性。其具体作用如廖雪峰博客所言:

创建一个包含100万个元素的 list,不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了。 所以,如果 list 元素可以按照某种算法推算出来,那我们是否可以在循环的过程中不断推算出后续的元素呢?这样就不必创建完整的list,从而节省大量的空间。在Python中,这种一边循环一边计算的机制,称为生成器:generator。
其关键的写法,是把传统函数的 return 换成 yield:

image

next(generator)

即可一次返回 32 张图像以及对应的标注信息。

当然,keras 同样提供了这一模块,ImageDataGenerator,并且还是加强版,可以对图片进行 增强处理(data argument)(如旋转、反转、白化、截取等)。图片的增强处理在样本数量不多时增加样本量——因为如果图中是一只猫,旋转、反转、颜色调整之后,这张图片可能会不太相同,但它仍然是一只猫。

datagen = ImageDataGenerator(
            featurewise_center=False,
            samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False,
            zca_whitening=False,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            vertical_flip=False)

# compute quantities required for featurewise normalization
datagen.fit(X_train)

2.3 凸优化模块

这一部分谈的是,如何使用基于批量梯度下降算法的凸优化模块,优化模型参数。

前面提到,深度学习的“梯度下降”计算,可以理解成搬走一座大山,而“批量梯度下降”,则是一群人拿着土筐,一点一点把山上的土给搬下山。那么这一点具体应该如何实现呢?其实在第二讲,我们就实现了一个随机批量梯度下降(Stochastic gradient descent, SGD),这里再回顾一下:

image

当然,SGD 其实并不是一个很好的方法,有很多改进版本,可以用下面这张gif图概况:

image

Keras 里,可以直接使用 SGD, Adagrad, Adadelta, RMSProp 以及 Adam 等模块。其实在优化过程中,直接使用 Adam 默认参数,基本就可以得到最优的结果:

from keras.optimizers import Adam
adam = Adam()model.compile(loss='categorical_crossentropy',
                optimizer=adam,
                metrics=['accuracy'])

3.实战项目——CIFAR-10 图像分类

最后我们用一个keras 中的示例,首先做一些前期准备:

image

核心部分,用各种零件搭建深度神经网络:

datagen = ImageDataGenerator(
            featurewise_center=False,
            samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False,
            zca_whitening=False,
            rotation_range=0,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            vertical_flip=False)datagen.fit(X_train)

核心部分,用各种零件搭建深度神经网络:

image

image

下游部分,使用凸优化模块:

adam = Adam(lr=0.0001)model.compile(loss='categorical_crossentropy',

                   optimizer=adam,
                   metrics=['accuracy'])

最后,开始训练模型,并且评估模型准确性:

image

以上代码本人使用 Pascal TitanX 执行,50个 epoch 中,每个 epoch 用时 12s 左右,总计用时在十五分钟以内,约25 epoch 后,验证集的准确率数会逐步收敛在0.8左右。

原文发布时间为:2018-07-04
本文来自云栖社区合作伙伴“大数据挖掘DT机器学习”,了解相关信息可以关注“大数据挖掘DT机器学习

相关文章
|
机器学习/深度学习 算法 索引
LSTM(长短期记忆网络)原理介绍
LSTM算法是一种重要的目前使用最多的时间序列算法,是一种特殊的RNN(Recurrent Neural Network,循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
9166 0
LSTM(长短期记忆网络)原理介绍
|
2月前
|
图形学 Windows
IBM SPSS Amos 29安装教程 Windows版:旧版清理+自定义路径+Crack替换指南
Amos是IBM开发的专业结构方程模型(SEM)分析软件,支持回归、因子分析、路径建模等高级统计方法。本指南详述Amos 29(64位)的下载、清理、安装与授权激活全流程,助用户快速完成本地部署。(239字)
|
3月前
|
存储 安全 固态存储
2026阿里云服务器价格表:最新收费标准与38元1年、9.9元1个月、99元1年等活动价格参考
阿里云服务器收费标准涵盖实例配置、带宽及云盘三大核心组件,价格随规格、时长动态调整。2026年活动中推出多类优惠:轻量应用服务器2核4G低至9.9元/月、199元/年;经济型e实例2核2G 3M带宽99元/年;九代ECS(如计算型c9i 8核16G)年付低至6.4折。选购时需注意带宽与CPU/内存的匹配、云盘类型选择及实例适用场景。用户还可领取各种优惠券,在活动价基础上进一步减免,实现成本优化。
1425 4
|
SQL 安全 网络协议
常用和不常用端口一览表收藏
大家在学习计算机的时候,对于最常用的几个端口比如80端口肯定有很深的印象,但是对于其他一些不是那么常用的端口可能就没那么了解。所以,在一些使用频率相对较高的端口上,很容易会引发一些由于陌生而出现的错误,或者被黑客利用某些端口进行入侵。
4774 0
|
16天前
|
人工智能 安全 算法
大模型应用:AI 智能体核心引擎:RAG检索增强生成原理与医疗场景深度落地.126
本文详解RAG(检索增强生成)在医疗智能体中的落地实践:针对大模型知识过时、幻觉、专业性不足三大痛点,基于Qwen本地大模型、MiniLM嵌入、FAISS向量库与LangChain框架,实现全流程可追溯、全本地化、无幻觉的精准问答。含环境配置、适配器封装、知识库构建及调试分析。
252 7
|
16天前
|
人工智能 安全 开发者
Claw-Eval开源:300个真实任务,端到端评测AI智能体的完成度、安全性与鲁棒性
Claw-Eval是面向自主Agent的端到端评测框架,突破“只看结果”局限,聚焦任务执行全过程——可追溯、合规、容错。基于300个人工验证的真实任务,从完成度、安全性、鲁棒性三维度评估14个前沿模型,开源数据集、排行榜及代码。
378 4
|
2月前
|
人工智能 运维 机器人
保姆级图文教程|阿里云轻量服务器部署OpenClaw、Discord集成与千问Qwen3.6-Plus全配置指南
本文完整覆盖从**轻量服务器实例创建、端口放行、OpenClaw初始化、Discord深度集成、大模型API配置、技能扩展、运维排错**的全流程,所有步骤均为2026年4月最新实践,配合详细的避坑指南与运维命令,可解决新手部署中90%以上的问题。遵循**“选对海外地域、放通核心端口、准确配置凭证、及时重启服务、使用专用小号”**五大核心原则,即可实现OpenClaw 7×24小时稳定运行,通过Discord随时随地与专属AI助理交互,高效完成社群管理、内容创作、代码编写、信息查询等各类任务,快速落地AI智能化应用场景,让AI真正成为个人与团队的高效生产力工具。
688 4
|
4月前
|
数据可视化 关系型数据库 MySQL
PhpStudy2018怎么用?完整安装与使用指南(新手必看)
PhpStudy2018是一款Windows下PHP集成开发环境,一键安装Apache、MySQL、PHP,免去逐个配置的麻烦。支持多版本PHP切换、可视化服务管理,根目录(WWW)即放即用,适合初学者学习与本地网站调试。(239字)
|
存储 缓存 移动开发
HTML5 的离线储存怎么使用,工作原理
HTML5 的离线储存怎么使用,工作原理
468 0
|
11月前
|
监控 供应链 API
1688商品列表API全参数指南:从基础搜索到高级筛选
1688商品列表API是阿里巴巴B2B平台的核心接口,支持关键词搜索、高级筛选、排序与分页功能,适用于选品、价格监控等场景。数据规范、稳定高效,日均调用量大。提供Python示例代码,便于快速接入与扩展应用。