第7章 使用Keras开发神经网络

简介: 第7章 使用Keras开发神经网络 Keras基于Python,开发深度学习模型很容易。Keras将Theano和TensorFlow的数值计算封装好,几句话就可以配置并训练神经网络。本章开始使用Keras开发神经网络。

第7章 使用Keras开发神经网络

Keras基于Python,开发深度学习模型很容易。Keras将Theano和TensorFlow的数值计算封装好,几句话就可以配置并训练神经网络。本章开始使用Keras开发神经网络。本章将:

  • 将CSV数据读入Keras
  • 用Keras配置并编译多层感知器模型
  • 用验证数据集验证Keras模型

我们开始吧。

7.1 简介

虽然代码量不大,但是我们还是慢慢来。大体分几步:

  1. 导入数据
  2. 定义模型
  3. 编译模型
  4. 训练模型
  5. 测试模型
  6. 写出程序

7.2 皮马人糖尿病数据集

我们使用皮马人糖尿病数据集(Pima Indians onset of diabetes),在UCI的机器学习网站可以免费下载。数据集的内容是皮马人的医疗记录,以及过去5年内是否有糖尿病。所有的数据都是数字,问题是(是否有糖尿病是1或0),是二分类问题。数据的数量级不同,有8个属性:

  1. 怀孕次数
  2. 2小时口服葡萄糖耐量试验中的血浆葡萄糖浓度
  3. 舒张压(毫米汞柱)
  4. 2小时血清胰岛素(mu U/ml)
  5. 体重指数(BMI)
  6. 糖尿病血系功能
  7. 年龄(年)
  8. 类别:过去5年内是否有糖尿病

所有的数据都是数字,可以直接导入Keras。本书后面也会用到这个数据集。数据有768行,前5行的样本长这样:

6,148,72,35,0,33.6,0.627,50,1
1,85,66,29,0,26.6,0.351,31,0
8,183,64,0,0,23.3,0.672,32,1
1,89,66,23,94,28.1,0.167,21,0
0,137,40,35,168,43.1,2.288,33,1

数据在本书代码的data 目录,也可以在UCI机器学习的网站下载。把数据和Python文件放在一起,改名:

pima-indians-diabetes.csv

基准准确率是65.1%,在10次交叉验证中最高的正确率是77.7%。在UCI机器学习的网站可以得到数据集的更多资料。

7.3 导入资料

使用随机梯度下降时最好固定随机数种子,这样你的代码每次运行的结果都一致。这种做法在演示结果、比较算法或debug时特别有效。你可以随便选种子:

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

现在导入皮马人数据集。NumPy的loadtxt()函数可以直接带入数据,输入变量是8个,输出1个。导入数据后,我们把数据分成输入和输出两组以便交叉检验:

# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]

这样我们的数据每次结果都一致,可以定义模型了。

7.4 定义模型

Keras的模型由层构成:我们建立一个Sequential模型,一层层加入神经元。第一步是确定输入层的数目正确:在创建模型时用input_dim参数确定。例如,有8个输入变量,就设成8。

隐层怎么设置?这个问题很难回答,需要慢慢试验。一般来说,如果网络够大,即使存在问题也不会有影响。这个例子里我们用3层全连接网络。

全连接层用Dense类定义:第一个参数是本层神经元个数,然后是初始化方式和激活函数。这里的初始化方法是0到0.05的连续型均匀分布(uniform),Keras的默认方法也是这个。也可以用高斯分布进行初始化(normal)。

前两层的激活函数是线性整流函数(relu),最后一层的激活函数是S型函数(sigmoid)。之前大家喜欢用S型和正切函数,但现在线性整流函数效果更好。为了保证输出是0到1的概率数字,最后一层的激活函数是S型函数,这样映射到0.5的阈值函数也容易。前两个隐层分别有12和8个神经元,最后一层是1个神经元(是否有糖尿病)。

# create model
model = Sequential()
model.add(Dense(12, input_dim=8, init='uniform', activation='relu')) model.add(Dense(8, init='uniform', activation='relu')) model.add(Dense(1, init='uniform', activation='sigmoid'))

7.5 编译模型

定义好的模型可以编译:Keras会调用Theano或者TensorFlow编译模型。后端会自动选择表示网络的最佳方法,配合你的硬件。这步需要定义几个新的参数。训练神经网络的意义是:找到最好的一组权重,解决问题。

我们需要定义损失函数和优化算法,以及需要收集的数据。我们使用binary_crossentropy,错误的对数作为损失函数;adam作为优化算法,因为这东西好用。想深入了解请查阅:Adam: A Method for Stochastic Optimization论文。因为这个问题是分类问题,我们收集每轮的准确率。

7.6 训练模型

终于开始训练了!调用模型的fit()方法即可开始训练。

网络按轮训练,通过nb_epoch参数控制。每次送入的数据(批尺寸)可以用batch_size参数控制。这里我们只跑150轮,每次10个数据。多试试就知道了。

# Fit the model
model.fit(X, Y, nb_epoch=150, batch_size=10)

现在CPU或GPU开始煎鸡蛋了。

相关文章
|
3月前
|
Linux 开发工具 Android开发
FFmpeg开发笔记(六十)使用国产的ijkplayer播放器观看网络视频
ijkplayer是由Bilibili基于FFmpeg3.4研发并开源的播放器,适用于Android和iOS,支持本地视频及网络流媒体播放。本文详细介绍如何在新版Android Studio中导入并使用ijkplayer库,包括Gradle版本及配置更新、导入编译好的so文件以及添加直播链接播放代码等步骤,帮助开发者顺利进行App调试与开发。更多FFmpeg开发知识可参考《FFmpeg开发实战:从零基础到短视频上线》。
351 2
FFmpeg开发笔记(六十)使用国产的ijkplayer播放器观看网络视频
|
2月前
|
API
鸿蒙开发:切换至基于rcp的网络请求
本文的内容主要是把之前基于http封装的库,修改为当前的Remote Communication Kit(远场通信服务),无非就是通信的方式变了,其他都大差不差。
117 4
鸿蒙开发:切换至基于rcp的网络请求
|
2月前
|
存储 网络协议 物联网
C 语言物联网开发之网络通信与数据传输难题
本文探讨了C语言在物联网开发中遇到的网络通信与数据传输挑战,分析了常见问题并提出了优化策略,旨在提高数据传输效率和系统稳定性。
|
3月前
|
XML 开发工具 Android开发
FFmpeg开发笔记(五十六)使用Media3的Exoplayer播放网络视频
ExoPlayer最初是为了解决Android早期MediaPlayer控件对网络视频兼容性差的问题而推出的。现在,Android官方已将其升级并纳入Jetpack的Media3库,使其成为音视频操作的统一引擎。新版ExoPlayer支持多种协议,解决了设备和系统碎片化问题,可在整个Android生态中一致运行。通过修改`build.gradle`文件、布局文件及Activity代码,并添加必要的权限,即可集成并使用ExoPlayer进行网络视频播放。具体步骤包括引入依赖库、配置播放界面、编写播放逻辑以及添加互联网访问权限。
243 1
FFmpeg开发笔记(五十六)使用Media3的Exoplayer播放网络视频
|
5月前
|
安全 网络安全 Android开发
安卓与iOS开发:选择的艺术网络安全与信息安全:漏洞、加密与意识的交织
【8月更文挑战第20天】在数字时代,安卓和iOS两大平台如同两座巍峨的山峰,分别占据着移动互联网的半壁江山。它们各自拥有独特的魅力和优势,吸引着无数开发者投身其中。本文将探讨这两个平台的特点、优势以及它们在移动应用开发中的地位,帮助读者更好地理解这两个平台的差异,并为那些正在面临选择的开发者提供一些启示。
135 56
|
5月前
|
C++
C++ Qt开发:QUdpSocket网络通信组件
QUdpSocket是Qt网络编程中一个非常有用的组件,它提供了在UDP协议下进行数据发送和接收的能力。通过简单的方法和信号,可以轻松实现基于UDP的网络通信。不过,需要注意的是,UDP协议本身不保证数据的可靠传输,因此在使用QUdpSocket时,可能需要在应用层实现一些机制来保证数据的完整性和顺序,或者选择在适用的场景下使用UDP协议。
244 2
|
5月前
|
小程序 数据安全/隐私保护
Taro@3.x+Vue@3.x+TS开发微信小程序,网络请求封装
在 `src/http` 目录下创建 `request.ts` 文件,并配置 Taro 的网络请求方法 `Taro.request`,支持多种 HTTP 方法并处理数据加密。
210 0
Taro@3.x+Vue@3.x+TS开发微信小程序,网络请求封装
|
5月前
|
数据采集 存储 前端开发
豆瓣评分9.0!Python3网络爬虫开发实战,堪称教学典范!
今天我们所处的时代是信息化时代,是数据驱动的人工智能时代。在人工智能、物联网时代,万物互联和物理世界的全面数字化使得人工智能可以基于这些数据产生优质的决策,从而对人类的生产生活产生巨大价值。 在这个以数据驱动为特征的时代,数据是最基础的。数据既可以通过研发产品获得,也可以通过爬虫采集公开数据获得,因此爬虫技术在这个快速发展的时代就显得尤为重要,高端爬虫人才的收人也在逐年提高。
|
5月前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享安卓与iOS开发中的线程管理比较
【8月更文挑战第30天】本文将探讨网络安全与信息安全的重要性,并分享关于网络安全漏洞、加密技术和安全意识的知识。我们将了解常见的网络攻击类型和防御策略,以及如何通过加密技术和提高安全意识来保护个人和组织的信息安全。
|
5月前
|
安全 网络安全 Android开发
探索安卓开发之旅:从新手到专家网络安全与信息安全:防范网络威胁,保护数据安全
【8月更文挑战第29天】在这篇技术性文章中,我们将踏上一段激动人心的旅程,探索安卓开发的世界。无论你是刚开始接触编程的新手,还是希望提升技能的资深开发者,这篇文章都将为你提供宝贵的知识和指导。我们将从基础概念入手,逐步深入到安卓开发的高级主题,包括UI设计、数据存储、网络通信等方面。通过阅读本文,你将获得一个全面的安卓开发知识体系,并学会如何将这些知识应用到实际项目中。让我们一起开启这段探索之旅吧!

热门文章

最新文章