【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集

简介: 本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。

配合视频一起食用这篇教程效果更佳:手把手教你用tensorflow2训练自己的数据集

tensorflow2.x版本对小白非常友好,2.x的api中对keras进行了合并,大家只需要安装tensorflow就可以使用里面封装好的keras,利用keras可以快速地加载数据集和构建模型,下面我们直接来看以下通过tensorflow2.3训练自己的分类数据集吧。

注:本文主要针对图片形式的数据集构建分类模型,文本数据、目标检测等任务暂不涉及。

本文使用到的代码已在码云上开源,请大家自行下载,star不迷路:

vegetables_tf2.3: 基于tensorflow2.3开发的水果蔬菜识别系统 (gitee.com)

另外我这边整理了一些物体分类的数据集,大家根据需要下载:

计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客

数据集收集

数据集收集主要有3种方式,一种是使用某些机构或者组织开源出来的数据集,另一种是自己通过拍照或者爬虫的方式来自行获取数据集,还有一种是热心网友自己采集整理之后的数据集,下面的csdn链接中我给出了一些我整理的数据集,大家可以根据自己的需要下载使用。

计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客

开源数据集

开源的分类数据集一般质量相对较好,数据集的所有者在发布前对数据集做了整理和清洗,直接使用开源的数据集可以帮助我们节省大量的时间,比较有名的有mnist数据集、cifar数据集等,另外大家可以在一些网站中寻找数据集,比如下列的几个网站:

和鲸社区 - Heywhale.com

UCI Machine Learning Repository

CSDN - 专业开发者社区

另外你也可以直接在搜索引擎中输入关键字来寻找数据集,比如你想要寻找垃圾分类的数据集,你可以在搜索栏中输入垃圾 分类 数据集等关键字来直接查找,一般会有热心的网友给出数据集的链接,下载即可。

image-20210616112527938

自行采集数据集

如果找不到相应的开源数据集,你也可以通过自己采集的方式来获取数据集,比如你可以通过拍照的方式来搜集你自己所需的数据集,或者是通过爬虫的方式来搜集数据集,这里有段爬虫爬取百度图片的代码,大家直接执行,输入自己想要爬取的图片名称和图片数量,即可爬取相应的图片,代码如下:

import requests
import re
import os

headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1等于60张图片,2等于120张图片):')
list_1 = []
for i in range(int(x)):
    name_1 = os.getcwd()
    name_2 = os.path.join(name_1, 'data/' + name)
    url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)
    res = requests.get(url, headers=headers)
    htlm_1 = res.content.decode()
    a = re.findall('"objURL":"(.*?)",', htlm_1)
    if not os.path.exists(name_2):
        os.makedirs(name_2)
    for b in a:
        try:
            b_1 = re.findall('https:(.*?)&', b)
            b_2 = ''.join(b_1)
            if b_2 not in list_1:
                num = num + 1
                img = requests.get(b)
                f = open(os.path.join(name_1, 'data/' + name, name + str(num) + '.jpg'), 'ab')
                print('---------正在下载第' + str(num) + '张图片----------')
                f.write(img.content)
                f.close()
                list_1.append(b_2)
            elif b_2 in list_1:
                num_1 = num_1 + 1
                continue
        except Exception as e:
            print('---------第' + str(num) + '张图片无法下载----------')
            num_2 = num_2 + 1
            continue

print('下载完成,总共下载{}张,成功下载:{}张,重复下载:{}张,下载失败:{}张'.format(num + num_1 + num_2, num, num_1, num_2))

数据集整理

放置到相应的子文件夹

数据集收集完成之后,我们还需要对数据集进行整理,如果是爬虫爬取的图片可能会有一些质量比较差的图片,那么整理之前还需要进行数据的清洗,删除质量不好的图片,数据集整理其实很简单,我们只需要将数据集进行归类即可,即相同类别的图片放在一个文件夹下,比如下面的这个数据集,白菜的文件夹下放的全是白菜的图片,土豆的文件夹下则放的全是土豆的图片。

image-20210616122055853

image-20210616122124060

划分训练集和测试集

注:如果是使用的开源数据集,开源数据集可能已经进行了数据集的划分,直接使用即可,不需要再次进行划分,比如这里是我下载到的农作物病虫害的数据集,已经分别提供了训练集、测试集和验证集,就不需要再次进行数据集的划分。

image-20210616122542476

为了方便我们进行数据集的加载,我们还需要将图片划分为训练集和测试集,如果需要的话你还需要划分出验证集,验证集在一般的任务中是可选的,因为是自己收集的数据集的话,数据量比较少,如果再划分验证集的话可能会导致训练量不够,这里我写了一段数据集划分的代码逻辑,大家输入原始的数据集位置和划分之后的数据集位置,指定数据集划分的比例,即可完成数据集的划分。

# 作者: 宋老狗
import os
import random
from shutil import copy2


def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.0, test_scale=0.2):
    '''
    读取源数据文件夹,生成划分好的文件夹,分为trian、val、test三个文件夹进行
    :param src_data_folder: 源文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data
    :param target_data_folder: 目标文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data
    :param train_scale: 训练集比例
    :param val_scale: 验证集比例
    :param test_scale: 测试集比例
    :return:
    '''
    print("开始数据集划分")
    class_names = os.listdir(src_data_folder)
    # 在目标目录下创建文件夹
    split_names = ['train', 'val', 'test']
    for split_name in split_names:
        split_path = os.path.join(target_data_folder, split_name)
        if os.path.isdir(split_path):
            pass
        else:
            os.mkdir(split_path)
        # 然后在split_path的目录下创建类别文件夹
        for class_name in class_names:
            class_split_path = os.path.join(split_path, class_name)
            if os.path.isdir(class_split_path):
                pass
            else:
                os.mkdir(class_split_path)

    # 按照比例划分数据集,并进行数据图片的复制
    # 首先进行分类遍历
    for class_name in class_names:
        current_class_data_path = os.path.join(src_data_folder, class_name)
        current_all_data = os.listdir(current_class_data_path)
        current_data_length = len(current_all_data)
        current_data_index_list = list(range(current_data_length))
        random.shuffle(current_data_index_list)

        train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
        val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
        test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
        train_stop_flag = current_data_length * train_scale
        val_stop_flag = current_data_length * (train_scale + val_scale)
        current_idx = 0
        train_num = 0
        val_num = 0
        test_num = 0
        for i in current_data_index_list:
            src_img_path = os.path.join(current_class_data_path, current_all_data[i])
            if current_idx <= train_stop_flag:
                copy2(src_img_path, train_folder)
                # print("{}复制到了{}".format(src_img_path, train_folder))
                train_num = train_num + 1
            elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
                copy2(src_img_path, val_folder)
                # print("{}复制到了{}".format(src_img_path, val_folder))
                val_num = val_num + 1
            else:
                copy2(src_img_path, test_folder)
                # print("{}复制到了{}".format(src_img_path, test_folder))
                test_num = test_num + 1

            current_idx = current_idx + 1

        print("*********************************{}*************************************".format(class_name))
        print(
            "{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(class_name, train_scale, val_scale, test_scale, current_data_length))
        print("训练集{}:{}张".format(train_folder, train_num))
        print("验证集{}:{}张".format(val_folder, val_num))
        print("测试集{}:{}张".format(test_folder, test_num))


if __name__ == '__main__':
    src_data_folder = "C:/Users/Scm97/Desktop/dejahu/data"  # todo 原始数据集目录
    target_data_folder = "C:/Users/Scm97/Desktop/dejahu/split_data"  # todo 数据集分割之后存放的目录
    data_set_split(src_data_folder, target_data_folder)

注:路径中最好不要出现中文

数据集划分之后,记住训练集和测试集的位置,接下来,我们就可以开始训练我们的模型了。

下面以花卉识别,我给大家演示一下,data是演示目录,目录下存放的是5个子文文件夹,对应5种花卉,每个子文件夹下存放了相应的花卉图片,split_data是新建的空文件夹,用于存放分割之后的数据集,这时候只需要修改代码种的两处即可。

image-20210616125843872

image-20210616130151242

代码默认训练集占80%,测试集占20%,修改完成之后右键直接执行即可。

image-20210616130259610

执行之后你就可以得到划分好的数据集

image-20210616130407052

这个时候记住训练集和测试集的目录,开始大干一场吧。

测试集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/train

训练集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/test

环境搭建

本次教程需要大家实现配置好python的环境,我们需要使用到anaconda和pycharm,不熟悉环境配置的同学可以看我得这篇博客,我在这里就不再进行赘述了。

如何在pycharm中配置anaconda的虚拟环境_dejavu的博客-CSDN博客

训练模型

模型训练的代码种,以cnn模型的训练为例,train_cnn.py是训练cnn模型的代码,只需要修改三处即可,如下所示

image-20210616131845980

train_mobilnet.py是训练mobilenet模型的代码,训练的模型将会保存在models目录下,这里也是只需修改三处即可。

image-20210616131957245

注:代码最后一行的epochs指的是跑的训练的轮数,这里默认是30,大家可以根据自己的需要增加或减少训练的轮数

修改之后直接运行即可,等代码跑完后模型就会保存在models目录下

image-20210616140111560

另外,在results目录下你可以找到模型训练的过程图

image-20210616140207484

模型训练的过程中会输出数据集的类名,这里记录一下,在后面的模型使用中会用到。

image-20210616134038441

测试模型

模型的测试的代码为test_model.py,也是只需要改动几处代码即可完成测试

改动如下:

image-20210616140427881

image-20210616140521754

测试的基本流程是:加载数据、加载模型、测试、保存结果

测试之后在命令行中会输出每个模型的准确率,并且会在results目录下生成相应的热力图

image-20210616140754011
image-20210616140824493

热力图中对应了每个类别的准确率,如下所示,是mobilenet测试的热力图。

heatmap_mobilenet

使用模型

模型的时候中,我们通过Pyqt5来构建图形化界面,用户可以上传图片,并在系统中调用我们训练好的模型进行图片类别的预测。

window.py代码中修改四处即可完成基本功能

image-20210616141502873

启动看看吧!

image-20210616141525770

快去试试你自己的数据集吧!

目录
相关文章
|
4天前
|
并行计算 Shell TensorFlow
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
在使用TensorFlow-GPU训练MTCNN时,如果遇到“Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED”错误,通常是由于TensorFlow、CUDA和cuDNN版本不兼容或显存分配问题导致的,可以通过安装匹配的版本或在代码中设置动态显存分配来解决。
20 1
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
|
2天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【大作业-04】手把手教你构建垃圾分类系统-基于tensorflow2.3
本文介绍了基于TensorFlow 2.3的垃圾分类系统,通过B站视频和博客详细讲解了系统的构建过程。系统使用了包含8万张图片、245个类别的数据集,训练了LeNet和MobileNet两个卷积神经网络模型,并通过PyQt5构建了图形化界面,用户上传图片后,系统能识别垃圾的具体种类。此外,还提供了模型和数据集的下载链接,方便读者复现实验。垃圾分类对于提高资源利用率、减少环境污染具有重要意义。
10 0
【大作业-04】手把手教你构建垃圾分类系统-基于tensorflow2.3
|
2天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【大作业-02】水果蔬菜识别系统-基于tensorflow2.3开发
2021年6月18日,TensorFlow 2.3物体分类代码已修复并更新。本项目支持自定义数据集训练,包括基于CNN和Mobilenet的模型,后者准确率高达97%。提供了详细的CSDN教程、B站教学视频及数据集下载链接,帮助用户快速上手。项目还包括PyQt5构建的图形界面,方便用户上传图片进行果蔬识别。更多详情与代码可在Gitee获取。
7 0
【大作业-02】水果蔬菜识别系统-基于tensorflow2.3开发
|
2天前
|
机器学习/深度学习 测试技术 TensorFlow
【大作业-01】花卉识别-基于tensorflow2.3实现
2021年6月18日更新:提供修复后的TensorFlow 2.3物体分类代码,支持自定义数据集训练。包含CSDN教程、B站视频、数据集及代码下载链接。示例项目为花卉识别,涵盖模型训练、测试、保存和使用,附带图形界面操作指南。
5 0
【大作业-01】花卉识别-基于tensorflow2.3实现
|
2月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
59 0
|
2月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
38 0
|
2月前
|
机器学习/深度学习 数据采集 TensorFlow
从零到精通:TensorFlow与卷积神经网络(CNN)助你成为图像识别高手的终极指南——深入浅出教你搭建首个猫狗分类器,附带实战代码与训练技巧揭秘
【8月更文挑战第31天】本文通过杂文形式介绍了如何利用 TensorFlow 和卷积神经网络(CNN)构建图像识别系统,详细演示了从数据准备、模型构建到训练与评估的全过程。通过具体示例代码,展示了使用 Keras API 训练猫狗分类器的步骤,旨在帮助读者掌握图像识别的核心技术。此外,还探讨了图像识别在物体检测、语义分割等领域的广泛应用前景。
17 0
|
2月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
39 0
|
1月前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
50 0
|
1月前
|
机器学习/深度学习 数据挖掘 TensorFlow
从数据小白到AI专家:Python数据分析与TensorFlow/PyTorch深度学习的蜕变之路
【9月更文挑战第10天】从数据新手成长为AI专家,需先掌握Python基础语法,并学会使用NumPy和Pandas进行数据分析。接着,通过Matplotlib和Seaborn实现数据可视化,最后利用TensorFlow或PyTorch探索深度学习。这一过程涉及从数据清洗、可视化到构建神经网络的多个步骤,每一步都需不断实践与学习。借助Python的强大功能及各类库的支持,你能逐步解锁数据的深层价值。
53 0

热门文章

最新文章