TensorFlow学习笔记--自定义图像识别

本文涉及的产品
视觉智能开放平台,图像资源包5000点
视觉智能开放平台,视频资源包5000点
视觉智能开放平台,分割抠图1万点
简介: TensorFlow学习笔记--自定义图像识别

零、学习目标

本篇文章主要讲解自己的图像数据如何在TnesorFlow上训练,主要从数据准备、训练模型、验证准确率和导出模型并对图片分类。重点如下:


  1. 微调
  2. 导出模型并对图片分类

一、微调

1.原理

对于新手来说,在自己的数据集上训练一个模型时,最简单的方法是在ImageNet的模型上进行微调。什么是微调呢?以VGG16为例,它的结构为5部分卷积层共13层(conv1 ~ conv5)和3层的全连接层(fc6 ~ fc8),一共16层,因此被称为VGG16。

如果将VGG16的结构用于一个新的数据集,就要去掉最后一层的全连接层,因为最后一层全连接层的输入是前一层的特征,输出的是1000类的概率,正好对应了ImageNet中的1000个类别,但是在这里,我们的类别只有6种,所以要去掉最后一层全连接层,采用一个更符合数据类别的全连接层。

这时,网络参数的初始化值就不是随机生成的了,而是利用VGG16在ImageNet上已经训练好的参数作为训练的初始值。因为ImageNet训练集上的VGG16已经包含了大量有用的卷积过滤器,使用已存在的参数不久节约时间,也有助于提高分类器的性能。


2.训练范围

在载入参数后,我们可以指定训练层数范围,训练层数可选范围如下:


3.只训练fc8这一层,保持其他层的参数不变,将VGG16作为一个特征提取器,用fc7层提起的特征做Softmax分类,这样做有利提高训练速度,但是性能不是最佳的;


4.训练所有参数,对网络中的所有参数都进行训练,性能得以提高,深度模型得以充分发挥,但是速度太慢;


5.训练部分参数,固定浅层参数不变,训练深层参数。


以上这三种方法就是神经网络的微调,通过微调可以将神经网络通过以有模型应用到自己的数据集上。


3.数据处理

我们首先将数据分为训练集和验证集,之后将图片转化为tfrecord格式【注1】。将文件夹 data_preoare 复制到项目的根部录下。这个文件夹中由所需的数据集和代码。data_preoare/pic/train 目录是训练文件所在的目录,data_preoare/pic/validation 目录是验证文件所在的目录。两个目录下又以不同分类划分了6种类别,分别为:农田、冰川、城市地区、森林、水域和岩石,每个文件夹中存放的图片为jpg格式的图片。

由于神经网络无法识别jpg格式的数据,所以需要将图片数据转为tfrecord格式的数据。 切换到 data_preoare 文件夹下,在命令行输入如下命令进行格式转换:

python data_convert.py -t pic/ --train-shards 2 --validation-shards 2 --num-threads 2 --dataset-name satellite

解释一下上面参数的含义:image.png

运行命令后,pic文件夹下会出现五个新的数据文件,以 satellite_train_ 开头的训练据文件和以 satellite_validation_ 开头的验证数据文件,并且还包含一个label.txt文件,表示图片的标签数字到真实类别字符串的映射顺序。例如tfrecod中图片标签为0,就代表类别为label.txt中的第一行类别。


注1:

文件下载地址:下载文件

注2:

如果训练数据集较大,则可以将训练数据集划分为多个数据块

注3:

线程数量必须能整除train-shars和validation-shards,这样才能抱枕每个线程中数据块的数量相等



4.下载TensorFlow Slim 源代码

下载TensorFlow Slim 是Google提供的图像分类工具。里面提供了图像分类的接口、常用的网络结构和预训练模型。

利用git下载Slim源码:git clone ht仁ps://github.corn/tensorflow/models.git,我所提供的下载地址中也有Slim源码。将 Slim 文件夹复制到根目录下即可。代码结构如下:

image.png

4. 定义dataset
slim/datasets 目录下创建 satellite.py 文件,将 flowers.py 文件中的内容复制进去。修改部分代码:

4. _FILE_PATTERN、SPLITS_TO_SIZE、_NUM_CLASSES

# 数据的文件名
_FILE_PATTERN = 'satellite_%s_*.tfrecord'
# 训练集和验证集的数量
SPLITS_TO_SIZE = {'train':4800,'validation':1200}
# 数据集中图片的类别数目
_NUM_CLASSES = 6

2. image/format

# 设定图片格式
'image/format' : tf.FixedLenFeature((),tf.string,default_value = 'jpg')

3. 修改dataset_factory.py

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
# 将satellite模块添加进来
from datasets import satellite
# satellite 数据库加入进来
datasets_map = {
  'cifar10':cifar10,
  'flowers':flowers,
  'imagenet':imagenet,
  'mnist':mnist,
  'satellite':satellite
}

5.准备训练文件夹

在slim文件夹下新建 satellite 目录、satellite/data(训练和验证数据文件夹)、satellite/train_dir(保存训练日志和模型文件夹)、satellite/pretrained。创建完目录后需要完成以下工作:


6.将转换好格式的数据(包括label.txt)复制 satellite/data 文件夹


7.下载Inception V3模型,下载地址是:下载地址,解压后,将inception_v3.ckpt文件复制到 satellite/pretrained


8.训练程序

在slim文件夹下启动命令行,输入如下命令开始训练(代码需要在TensorFlow GPU版本上运行):

python train_image_classifier.py --train_dir=satellite/train_dir --dataset_name=satellite --dataset_split_name=train --dataset_dir=satellite/data --model_name=inception_v3 --checkpoint_path=satellite/pretrained/inception_v3.ckpt --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --max_number_of_steps=100000 --batch_size=32 --learning_rate=0.001 --learning_rate_decay_type=fixed --save_interval_secs=300 --save_summaries_secs=2 --log_every_n_steps=10 --optimizer=rmsprop --weight_decay=0.00004

解释一下上面参数的含义:

参数 说明
–trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits 指定模型微调变量的范围。这里指设定表示只对 InceptionV3/Logits 和 InceptionV3/AuxLogits 两个变量微调,也就是对fc8进行微调,如果不设置此参数,将会对所有参数进行训练。
–train_dir=satellite/train_dir 在 satellite/train_dir 目录下保存日志和模型文件(heckpoint)
–dataset_name=satellite、–datasets_split_name=train 指定训练数据集
–dataset_dir=satellite/data 训练数据集保存的位置
–model_name=inception_v3 使用的模型名称
–checkpoint_path=satellite/pretrained/inception_v3.ckpt 预训练模型保存的位置
–checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits 恢复预训练模型时不回复这两层,因为这两层模型对应着ImageNet数据集的1000类,与当前数据集不符,所以不要恢复他
–max_number_of_steps 100000 最大执行步数
–batch_size=32 每步的batch数量
–learning_rate=0.001 学习率
–learning_rate_decay_type=fixed 学习率是否下降,此处固定学习率
–save_interval_secs=300 每隔300秒保存一次模型,保存到train_dir目录下
–save_summaries_secs=2 每隔2秒保存一次日志
–log_every_n_steps=10 每隔10步在屏幕上打印出训练信息
–optimizer=rmsprop 指定优化器
–weight_decay=0.00004 设定weight_decay,即模型中所有参数的二次正则化超参数

注4:

开始训练时,如果训练文件夹(satellite/train_dir)里没有保存的模型,就会自动加载 checkpoint_path 中的预训练模型,然后程序会把初始模型保存在train_dir中,命名为 model.ckpt-0,0表示第0步。之后每隔300秒就会保存一次模型,由于模型较大,所以只会保留最新的5个模型。如果中断程序运行后再次运行,会首先检查train_dir文件夹中是否存在模型,如果存在则接着存在的模型开始训练。

7. 验证模型
要查看模型的准确率,可以使用 eval_image_classifier.py 来验证,在命令行输入如下命令:

python eval_image_classifier.py --checkpoint_path=satellite/train_dir --eval_dir=statellite/eval_dir --dataset_name=satellite --dataset_split_name=validation --dataset_dir=satellite/data --model_name=inception_v3

下面来解释一下参数

image.png

执行后会打印出如下内容:

eval/Accuracy[0.51]
eval/Recall_5[0.973333336]

Accuracy表示模型的分类准确率,Recall_5表示前5次的准确率

8. TensorBoard 可视化与超参数选择
使用TnesorBoard 有助于设定训练模型的方式以及超参数,命令行输入如下参数:

tensorboard --logdir satellite/train_dir

在TensorBoard中可以查看损失变化曲线,损失变化曲线有助于调整参数。如果损失曲线比动较大,无法收敛,就有可能时学习率过大,适当减小学习率就行了。

现在做如下操作:


1.在 train_dir 中建立两个文件夹,分别存放只微调fc8和微调整个网络的模型。通过调整 train_dir 参数将这两种模型分别存入新建的文件夹中,之后使用命令:

tensorboard --logdir satellite/train_dir

浏览器打开TensorBoard就可以看到狂歌模型的损失曲线,上方的为只训练末端的损失数,下方为训练所有层的损失函数。看损失函数可以看出训练所有层比只训练末端要好。


二、到处模型并分类图片

模型训练完之后,将会进行部署。这里提供了两个文件 freeze_graph.pyclassify_image_inception_v3.py 前者用于导出识别模型,后者用于识别单张图片。在slim文件夹下执行如下命令:

python export_inference_graph.py --alsologtostderr --model_name=inception_v3 --output_file=satellite/inception_v3_inf_graph.pb --dataset_name satellite

命令执行后,会在satellite文件夹下生成一个 inception_v3_inf_graph.pb 文件,但是这个文件不包含训练获得的模型参数,需要将cheeckpoint中的模型参数保存进来,方法是使用freeze_graph.py:

python freeze_graph.py --input_graph slim/satellite/inception_v3export_inference_graph.pb --input_checkpoint slim/satellite/train_dir/model.ckpt-5271 --input_binary true --output_node_names InceptionV3/Predictions/Reshape_1 --output_graph slim/satellite/frozen_graph.pb

这里讲解一下参数:

image.png

下面开始对图片进行识别。命令行执行脚本 classify_image_inception_v3.py ,运行如下命令:

python classify_image_inception_v3.py --model_path slim/statellite/frozen_graph.pb --label_path data_preoare/pic/label.txt --image_file test_image.jpg

讲解参数:

image.png

执行完参数后,将输出每种类别的概率。

三、总结

首先简要介绍了微调神经网络的基本原理,接着详细介绍了如何使用 TensorFlow Slim 微调预训练模型,包括数据准备、定义新的 datasets 文件、训练、 验证 、 导出模型井测试单张图片等。

目录
相关文章
|
2月前
|
机器学习/深度学习 算法 TensorFlow
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
动物识别系统。本项目以Python作为主要编程语言,并基于TensorFlow搭建ResNet50卷积神经网络算法模型,通过收集4种常见的动物图像数据集(猫、狗、鸡、马)然后进行模型训练,得到一个识别精度较高的模型文件,然后保存为本地格式的H5格式文件。再基于Django开发Web网页端操作界面,实现用户上传一张动物图片,识别其名称。
82 1
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
|
20天前
|
缓存 TensorFlow 算法框架/工具
TensorFlow学习笔记(一): tf.Variable() 和tf.get_variable()详解
这篇文章详细介绍了TensorFlow中`tf.Variable()`和`tf.get_variable()`的使用方法、参数含义以及它们之间的区别。
42 0
|
5月前
|
机器学习/深度学习 人工智能 算法
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
海洋生物识别系统。以Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经网络算法,通过对22种常见的海洋生物('蛤蜊', '珊瑚', '螃蟹', '海豚', '鳗鱼', '水母', '龙虾', '海蛞蝓', '章鱼', '水獭', '企鹅', '河豚', '魔鬼鱼', '海胆', '海马', '海豹', '鲨鱼', '虾', '鱿鱼', '海星', '海龟', '鲸鱼')数据集进行训练,得到一个识别精度较高的模型文件,然后使用Django开发一个Web网页平台操作界面,实现用户上传一张海洋生物图片识别其名称。
175 7
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
|
5月前
|
机器学习/深度学习 人工智能 算法
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
乐器识别系统。使用Python为主要编程语言,基于人工智能框架库TensorFlow搭建ResNet50卷积神经网络算法,通过对30种乐器('迪吉里杜管', '铃鼓', '木琴', '手风琴', '阿尔卑斯号角', '风笛', '班卓琴', '邦戈鼓', '卡萨巴', '响板', '单簧管', '古钢琴', '手风琴(六角形)', '鼓', '扬琴', '长笛', '刮瓜', '吉他', '口琴', '竖琴', '沙槌', '陶笛', '钢琴', '萨克斯管', '锡塔尔琴', '钢鼓', '长号', '小号', '大号', '小提琴')的图像数据集进行训练,得到一个训练精度较高的模型,并将其
67 0
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
|
15天前
|
机器学习/深度学习 SQL 数据采集
基于tensorflow、CNN网络识别花卉的种类(图像识别)
基于tensorflow、CNN网络识别花卉的种类(图像识别)
14 1
|
2月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
92 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
3月前
|
机器学习/深度学习 人工智能 TensorFlow
使用Python和TensorFlow实现图像识别
【8月更文挑战第31天】本文将引导你了解如何使用Python和TensorFlow库来实现图像识别。我们将从基本的Python编程开始,逐步深入到TensorFlow的高级功能,最后通过一个简单的代码示例来展示如何训练一个模型来识别图像。无论你是初学者还是有经验的开发者,这篇文章都将为你提供有价值的信息。
152 53
|
20天前
|
TensorFlow 算法框架/工具
Tensorflow学习笔记(二):各种tf类型的函数用法集合
这篇文章总结了TensorFlow中各种函数的用法,包括创建张量、设备管理、数据类型转换、随机数生成等基础知识。
24 0
|
3月前
|
机器学习/深度学习 人工智能 算法
【眼疾病识别】图像识别+深度学习技术+人工智能+卷积神经网络算法+计算机课设+Python+TensorFlow
眼疾识别系统,使用Python作为主要编程语言进行开发,基于深度学习等技术使用TensorFlow搭建ResNet50卷积神经网络算法,通过对眼疾图片4种数据集进行训练('白内障', '糖尿病性视网膜病变', '青光眼', '正常'),最终得到一个识别精确度较高的模型。然后使用Django框架开发Web网页端可视化操作界面,实现用户上传一张眼疾图片识别其名称。
79 9
【眼疾病识别】图像识别+深度学习技术+人工智能+卷积神经网络算法+计算机课设+Python+TensorFlow
|
3月前
|
机器学习/深度学习 人工智能 TensorFlow
利用Python和TensorFlow实现简单图像识别
【8月更文挑战第31天】在这篇文章中,我们将一起踏上一段探索人工智能世界的奇妙之旅。正如甘地所言:“你必须成为你希望在世界上看到的改变。” 通过实践,我们不仅将学习如何使用Python和TensorFlow构建一个简单的图像识别模型,而且还将探索如何通过这个模型理解世界。文章以通俗易懂的方式,逐步引导读者从基础到高级,体验从编码到识别的整个过程,让每个人都能在AI的世界中看到自己的倒影。