教程 | Tensorflow keras 极简神经网络构建与使用

简介: Tensorflow keras极简神经网络构建教程 Keras介绍Keras (κέρας) 在希腊语中意为号角,它来自古希腊和拉丁文学中的一个文学形象。发布于2015年,是一套高级API框架,其默认的backend是tensorflow,但是可以支持CNTK、Theano、MXNet作为backend运行。

Tensorflow keras极简神经网络构建教程

Keras介绍
Keras (κέρας) 在希腊语中意为号角,它来自古希腊和拉丁文学中的一个文学形象。发布于2015年,是一套高级API框架,其默认的backend是tensorflow,但是可以支持CNTK、Theano、MXNet作为backend运行。其特点是语法简单,容易上手,提供了大量的实验数据接口与预训练网络接口,最初是谷歌的一位工程师开发的,非常适合快速开发。Tensorflow虽然是非常流行的深度学习框架,但是tensorflow开发需要了解计算图与自动微分相关技术,对于完全没有任何深度学习基础的人不是一个很好的选择,而keras完全是为零基础的人准备,它简化了tensorflow中计算图、会话等基本概念,通过Sequential与功能API两个组件实现网络搭建,通过简单的添加一些层就可以快速搭建神经网络模型。

Mnist数据集准备
我们以mnist数据集为例,构建一个神经网络实现手写数字的训练与测试,首先我们需要认识一下mnist数据集,mnist数据集有6万张手写图像,1万张测试图像。Keras通过datase来下载与使用mnist数据集,下载与读取的代码如下:

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) =mnist.load_data()

通过下面的代码可以显示手写数字图像:

print(train_labels[0])
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([   ])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.gray)
    plt.xlabel(str(train_labels[i]))
plt.show()

image
对数据re-scale到0~1.0之间,对标签进行了one-hot编码,代码如下:

# re-scale to 0~1.0之间
train_images = train_images / 255.0
test_images = test_images / 255.0
train_labels = one_hot(train_labels)
test_labels = one_hot(test_labels)

其中one-hot编码函数如下:

def one_hot(labels):
    onehot_labels = np.zeros(shape=[len(labels), 10])
    for i in range(len(labels)):
        index = labels[i]
        onehot_labels[i][index] = 1
    return onehot_labels

建立模型
构建神经网络

输入层为28x28=784个输入节点

隐藏层120个节点

输出层10个节点

首先需要定义模型:

model = keras.Sequential()

然后按顺序添加模型各层

model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(units=120, activation=tf.nn.relu))
model.add(keras.layers.Dense(units=10, activation=tf.nn.softmax))

编译模型
模型还需要再进行几项设置才可以开始训练。这些设置会添加到模型的编译步骤:

损失函数
衡量模型在训练期间的准确率。我们希望尽可能缩小该函数,以“引导”模型朝着正确的方向优化。
优化器
根据模型看到的数据及其损失函数更新模型的方式。
指标
用于监控训练和测试步骤。以下示例使用准确率,即图像被正确分类的比例

model.compile(optimizer=tf.train.AdamOptimizer(), 
loss="categorical_crossentropy", metrics=['accuracy'])

训练模型
训练神经网络模型需要执行以下步骤:
将训练数据馈送到模型中,在本示例中为 train_images 和 train_labels 数组。
模型学习将图像与标签相关联。我们要求模型对测试集进行预测,在本示例中为 test_images 数组。我们会验证预测结果是否与 test_labels 数组中的标签一致。
要开始训练,请调用 model.fit 方法,使模型与训练数据“拟合”:

model.fit(x=train_images, y=train_labels, epochs=5)

评估模型
模型在测试集数据上运行:

test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)
print("Test Accuracy %.2f"% test_acc)

使用模型进行预测

# 开始预测
cnt = 0
predictions = model.predict(test_images)
for i in range(len(test_images)):
    target = np.argmax(predictions[i])
    label = np.argmax(test_labels[i])
    if target == label:
        cnt += 1
print("correct prediction of total : %.2f"%(cnt/len(test_images)))

卷积神经网络
mnist数据转换为四维

train_images = np.expand_dims(train_images, axis=3)
test_images = np.expand_dims(test_images, axis=3)

创建模型并构建CNN各层

model = keras.Sequential()
model.add(keras.layers.Conv2D(filters=32, kernel_size=5, strides=(1, 1),
                              padding='same', activation=tf.nn.relu, input_shape=(28, 28, 1)))
model.add(keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
model.add(keras.layers.Conv2D(filters=64, kernel_size=3, strides=(1, 1),
                              padding='same', activation=tf.nn.relu))
model.add(keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(units=128, activation=tf.nn.relu))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(units=10, activation=tf.nn.softmax))

编译与训练模型

# 训练模型
model.compile(optimizer=tf.train.AdamOptimizer(), loss="categorical_crossentropy", metrics=['accuracy'])
model.fit(x=train_images, y=train_labels, epochs=10)

image

原文发布时间为:2018-12-9
本文作者: gloomyfish
本文来自云栖社区合作伙伴“ OpenCV学堂”,了解相关信息可以关注“CVSCHOOL”微信公众号

相关文章
|
2月前
|
JSON 监控 API
在线网络PING接口检测服务器连通状态免费API教程
接口盒子提供免费PING检测API,可测试域名或IP的连通性与响应速度,支持指定地域节点,适用于服务器运维和网络监控。
|
5月前
|
数据采集 存储 监控
Python 原生爬虫教程:网络爬虫的基本概念和认知
网络爬虫是一种自动抓取互联网信息的程序,广泛应用于搜索引擎、数据采集、新闻聚合和价格监控等领域。其工作流程包括 URL 调度、HTTP 请求、页面下载、解析、数据存储及新 URL 发现。Python 因其丰富的库(如 requests、BeautifulSoup、Scrapy)和简洁语法成为爬虫开发的首选语言。然而,在使用爬虫时需注意法律与道德问题,例如遵守 robots.txt 规则、控制请求频率以及合法使用数据,以确保爬虫技术健康有序发展。
653 31
|
5月前
|
域名解析 API PHP
VM虚拟机全版本网盘+免费本地网络穿透端口映射实时同步动态家庭IP教程
本文介绍了如何通过网络穿透技术让公网直接访问家庭电脑,充分发挥本地硬件性能。相比第三方服务受限于转发带宽,此方法利用自家宽带实现更高效率。文章详细讲解了端口映射教程,包括不同网络环境(仅光猫、光猫+路由器)下的设置步骤,并提供实时同步动态IP的两种方案:自建服务器或使用三方API接口。最后附上VM虚拟机全版本下载链接,便于用户在穿透后将服务运行于虚拟环境中,提升安全性与适用性。
|
6月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
346 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
7月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
470 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
7月前
|
监控 Linux PHP
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
173 20
|
9月前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
459 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
8月前
|
前端开发 小程序 Java
uniapp-网络数据请求全教程
这篇文档介绍了如何在uni-app项目中使用第三方包发起网络请求
506 3
|
7月前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。
|
10月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
841 5

热门文章

最新文章