配合视频一起食用这篇教程效果更佳:手把手教你用tensorflow2训练自己的数据集
tensorflow2.x版本对小白非常友好,2.x的api中对keras进行了合并,大家只需要安装tensorflow就可以使用里面封装好的keras,利用keras可以快速地加载数据集和构建模型,下面我们直接来看以下通过tensorflow2.3训练自己的分类数据集吧。
注:本文主要针对图片形式的数据集构建分类模型,文本数据、目标检测等任务暂不涉及。
本文使用到的代码已在码云上开源,请大家自行下载,star不迷路:
vegetables_tf2.3: 基于tensorflow2.3开发的水果蔬菜识别系统 (gitee.com)
另外我这边整理了一些物体分类的数据集,大家根据需要下载:
计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客
数据集收集
数据集收集主要有3种方式,一种是使用某些机构或者组织开源出来的数据集,另一种是自己通过拍照或者爬虫的方式来自行获取数据集,还有一种是热心网友自己采集整理之后的数据集,下面的csdn链接中我给出了一些我整理的数据集,大家可以根据自己的需要下载使用。
计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客
开源数据集
开源的分类数据集一般质量相对较好,数据集的所有者在发布前对数据集做了整理和清洗,直接使用开源的数据集可以帮助我们节省大量的时间,比较有名的有mnist数据集、cifar数据集等,另外大家可以在一些网站中寻找数据集,比如下列的几个网站:
UCI Machine Learning Repository
另外你也可以直接在搜索引擎中输入关键字来寻找数据集,比如你想要寻找垃圾分类的数据集,你可以在搜索栏中输入垃圾
分类
数据集
等关键字来直接查找,一般会有热心的网友给出数据集的链接,下载即可。
自行采集数据集
如果找不到相应的开源数据集,你也可以通过自己采集的方式来获取数据集,比如你可以通过拍照的方式来搜集你自己所需的数据集,或者是通过爬虫的方式来搜集数据集,这里有段爬虫爬取百度图片的代码,大家直接执行,输入自己想要爬取的图片名称和图片数量,即可爬取相应的图片,代码如下:
import requests
import re
import os
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1等于60张图片,2等于120张图片):')
list_1 = []
for i in range(int(x)):
name_1 = os.getcwd()
name_2 = os.path.join(name_1, 'data/' + name)
url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)
res = requests.get(url, headers=headers)
htlm_1 = res.content.decode()
a = re.findall('"objURL":"(.*?)",', htlm_1)
if not os.path.exists(name_2):
os.makedirs(name_2)
for b in a:
try:
b_1 = re.findall('https:(.*?)&', b)
b_2 = ''.join(b_1)
if b_2 not in list_1:
num = num + 1
img = requests.get(b)
f = open(os.path.join(name_1, 'data/' + name, name + str(num) + '.jpg'), 'ab')
print('---------正在下载第' + str(num) + '张图片----------')
f.write(img.content)
f.close()
list_1.append(b_2)
elif b_2 in list_1:
num_1 = num_1 + 1
continue
except Exception as e:
print('---------第' + str(num) + '张图片无法下载----------')
num_2 = num_2 + 1
continue
print('下载完成,总共下载{}张,成功下载:{}张,重复下载:{}张,下载失败:{}张'.format(num + num_1 + num_2, num, num_1, num_2))
数据集整理
放置到相应的子文件夹
数据集收集完成之后,我们还需要对数据集进行整理,如果是爬虫爬取的图片可能会有一些质量比较差的图片,那么整理之前还需要进行数据的清洗,删除质量不好的图片,数据集整理其实很简单,我们只需要将数据集进行归类即可,即相同类别的图片放在一个文件夹下,比如下面的这个数据集,白菜的文件夹下放的全是白菜的图片,土豆的文件夹下则放的全是土豆的图片。
划分训练集和测试集
注:如果是使用的开源数据集,开源数据集可能已经进行了数据集的划分,直接使用即可,不需要再次进行划分,比如这里是我下载到的农作物病虫害的数据集,已经分别提供了训练集、测试集和验证集,就不需要再次进行数据集的划分。
为了方便我们进行数据集的加载,我们还需要将图片划分为训练集和测试集,如果需要的话你还需要划分出验证集,验证集在一般的任务中是可选的,因为是自己收集的数据集的话,数据量比较少,如果再划分验证集的话可能会导致训练量不够,这里我写了一段数据集划分的代码逻辑,大家输入原始的数据集位置和划分之后的数据集位置,指定数据集划分的比例,即可完成数据集的划分。
# 作者: 宋老狗
import os
import random
from shutil import copy2
def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.0, test_scale=0.2):
'''
读取源数据文件夹,生成划分好的文件夹,分为trian、val、test三个文件夹进行
:param src_data_folder: 源文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data
:param target_data_folder: 目标文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data
:param train_scale: 训练集比例
:param val_scale: 验证集比例
:param test_scale: 测试集比例
:return:
'''
print("开始数据集划分")
class_names = os.listdir(src_data_folder)
# 在目标目录下创建文件夹
split_names = ['train', 'val', 'test']
for split_name in split_names:
split_path = os.path.join(target_data_folder, split_name)
if os.path.isdir(split_path):
pass
else:
os.mkdir(split_path)
# 然后在split_path的目录下创建类别文件夹
for class_name in class_names:
class_split_path = os.path.join(split_path, class_name)
if os.path.isdir(class_split_path):
pass
else:
os.mkdir(class_split_path)
# 按照比例划分数据集,并进行数据图片的复制
# 首先进行分类遍历
for class_name in class_names:
current_class_data_path = os.path.join(src_data_folder, class_name)
current_all_data = os.listdir(current_class_data_path)
current_data_length = len(current_all_data)
current_data_index_list = list(range(current_data_length))
random.shuffle(current_data_index_list)
train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
train_stop_flag = current_data_length * train_scale
val_stop_flag = current_data_length * (train_scale + val_scale)
current_idx = 0
train_num = 0
val_num = 0
test_num = 0
for i in current_data_index_list:
src_img_path = os.path.join(current_class_data_path, current_all_data[i])
if current_idx <= train_stop_flag:
copy2(src_img_path, train_folder)
# print("{}复制到了{}".format(src_img_path, train_folder))
train_num = train_num + 1
elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
copy2(src_img_path, val_folder)
# print("{}复制到了{}".format(src_img_path, val_folder))
val_num = val_num + 1
else:
copy2(src_img_path, test_folder)
# print("{}复制到了{}".format(src_img_path, test_folder))
test_num = test_num + 1
current_idx = current_idx + 1
print("*********************************{}*************************************".format(class_name))
print(
"{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(class_name, train_scale, val_scale, test_scale, current_data_length))
print("训练集{}:{}张".format(train_folder, train_num))
print("验证集{}:{}张".format(val_folder, val_num))
print("测试集{}:{}张".format(test_folder, test_num))
if __name__ == '__main__':
src_data_folder = "C:/Users/Scm97/Desktop/dejahu/data" # todo 原始数据集目录
target_data_folder = "C:/Users/Scm97/Desktop/dejahu/split_data" # todo 数据集分割之后存放的目录
data_set_split(src_data_folder, target_data_folder)
注:路径中最好不要出现中文
数据集划分之后,记住训练集和测试集的位置,接下来,我们就可以开始训练我们的模型了。
下面以花卉识别,我给大家演示一下,data
是演示目录,目录下存放的是5个子文文件夹,对应5种花卉,每个子文件夹下存放了相应的花卉图片,split_data
是新建的空文件夹,用于存放分割之后的数据集,这时候只需要修改代码种的两处即可。
代码默认训练集占80%,测试集占20%,修改完成之后右键直接执行即可。
执行之后你就可以得到划分好的数据集
这个时候记住训练集和测试集的目录,开始大干一场吧。
测试集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/train
训练集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/test
环境搭建
本次教程需要大家实现配置好python的环境,我们需要使用到anaconda和pycharm,不熟悉环境配置的同学可以看我得这篇博客,我在这里就不再进行赘述了。
如何在pycharm中配置anaconda的虚拟环境_dejavu的博客-CSDN博客
训练模型
模型训练的代码种,以cnn模型的训练为例,train_cnn.py
是训练cnn模型的代码,只需要修改三处即可,如下所示
train_mobilnet.py
是训练mobilenet模型的代码,训练的模型将会保存在models目录下,这里也是只需修改三处即可。
注:代码最后一行的epochs指的是跑的训练的轮数,这里默认是30,大家可以根据自己的需要增加或减少训练的轮数
修改之后直接运行即可,等代码跑完后模型就会保存在models目录下
另外,在results目录下你可以找到模型训练的过程图
模型训练的过程中会输出数据集的类名,这里记录一下,在后面的模型使用中会用到。
测试模型
模型的测试的代码为test_model.py
,也是只需要改动几处代码即可完成测试
改动如下:
测试的基本流程是:加载数据、加载模型、测试、保存结果
测试之后在命令行中会输出每个模型的准确率,并且会在results目录下生成相应的热力图
热力图中对应了每个类别的准确率,如下所示,是mobilenet测试的热力图。
使用模型
模型的时候中,我们通过Pyqt5来构建图形化界面,用户可以上传图片,并在系统中调用我们训练好的模型进行图片类别的预测。
在window.py
代码中修改四处即可完成基本功能
启动看看吧!
快去试试你自己的数据集吧!