深度学习篇之数据集划分方法-附代码python详细注释

简介: 深度学习篇之数据集划分方法-附代码python详细注释

在深度学习训练模型过程中,我们第一步就是要收集相应的数据集,之后我们就是要将数据划分为训练集train和验证集val,但是有时间我们时常面临数据量庞大的问题,手动划分显然是不现实的,因为太麻烦了,而且不具有固定规律的随机性。


但是python对文件和文件夹极其强大的操作性帮助我们解决了数据集划分的问题,本篇博客我们将开源数据集划分的代码,让我们学习如何使用python划分自己的数据集。且我们在程序中设置随机种子,确保每次从数据集中抽取图片划分数据集的时候都是随机的,且保留种子,整个过程可复制。


先简单讲解一下代码的使用方法,在博客的最后会附上完整的代码

data_path = './data'#数据集存放的地方,建议在程序所在的文件夹下新建一个data文件夹,将需要划分的数据集存放进去
data_root = './'  #这里是生成的训练集和验证集所处的位置,这里设置的是在当前文件夹下。

主要在于这里,data_path,我在代码中设置的是在当前文件夹下新建一个data文件夹,将你需要划分的数据集放入data文件夹下,data_root,为我门划分完的训练集和验证集所放置的位置,我这里设置的是在当前文件夹下,我这里提到的当前文件夹下,就是与这个程序放置的位置一致的位置。


image.png


简单来说即,在程序的当前文件夹下,新建一个data文件夹用来放置自己的数据集,然后直接运行程序即可,生成的训练集和训练集train文件夹和验证集val文件夹会生成在当前文件夹下。


image.png


划分过程ing


image.png


这里的split_rate = 0.1 #这里填多少 就是验证集的比例是多少,比如填0.1就是验证集的数量占总数据集的10%。


split_rate = 0.1 #这里填多少 就是验证集的比例是多少,比如填0.1就是验证集的数量占总数据集的10%

附上数据集划分完整代码:

import os
from shutil import copy, rmtree
import random
def make_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)
# 保证随机可复现
random.seed(0)#保证每次随机抽取的都可以复现
# 将数据集中10%的数据划分到验证集中
split_rate = 0.1 #这里填多少 就是验证集的比例是多少,比如填0.1就是验证集的数量占总数据集的10%
data_path = './data'#数据集存放的地方,建议在程序所在的文件夹下新建一个data文件夹,将需要划分的数据集存放进去
data_root = './'  #这里是生成的训练集和验证集所处的位置,这里设置的是在当前文件夹下。
data_class = [cla for cla in os.listdir(data_path)]
print("数据的种类分别为:")
print(data_class)# 输出数据种类,数据种类默认为读取的文件夹的名称
# 建立保存训练集的文件夹
train_data_root = os.path.join(data_root, "train") #训练集的文件夹名称为 train
make_file(train_data_root)
for num_class in data_class:
    # 建立每个类别对应的文件夹
    make_file(os.path.join(train_data_root, num_class))
# 建立保存验证集的文件夹
val_data_root = os.path.join(data_root, "val")#验证集的文件夹名称为 val
make_file(val_data_root)
for num_class in data_class:
    # 建立每个类别对应的文件夹
    make_file(os.path.join(val_data_root, num_class))
for num_class in data_class:
    num_class_path = os.path.join(data_path, num_class)
    images = os.listdir(num_class_path)
    num = len(images)
    val_index = random.sample(images, k=int(num*split_rate))   #随机抽取图片
    for index, image in enumerate(images):
        if image in val_index:
            # 将划分到验证集中的文件复制到相应目录
            data_image_path = os.path.join(num_class_path, image)
            val_new_path = os.path.join(val_data_root, num_class)
            copy(data_image_path, val_new_path)
        else:
            # 将划分到训练集中的文件复制到相应目录
            data_image_path = os.path.join(num_class_path, image)
            train_new_path = os.path.join(train_data_root, num_class)
            copy(data_image_path, train_new_path)
    print("\r[{}] split_rating [{}/{}]".format(num_class, index+1, num), end="")  # processing bar
    print()
print("       ")
print("       ")
print("划分完成")
相关文章
|
28天前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品加工优化的深度学习模型
使用Python实现智能食品加工优化的深度学习模型
141 59
|
23天前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品市场预测的深度学习模型
使用Python实现智能食品市场预测的深度学习模型
62 5
|
9天前
|
JSON 安全 API
Python调用API接口的方法
Python调用API接口的方法
47 5
|
25天前
|
机器学习/深度学习 算法 数据可视化
使用Python实现深度学习模型:智能食品配送优化
使用Python实现深度学习模型:智能食品配送优化
41 2
|
29天前
|
机器学习/深度学习 数据采集 数据库
使用Python实现智能食品营养分析的深度学习模型
使用Python实现智能食品营养分析的深度学习模型
63 6
|
24天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
67 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
24天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
67 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
26天前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品储存管理的深度学习模型
使用Python实现智能食品储存管理的深度学习模型
53 2
|
28天前
|
机器学习/深度学习 算法 PyTorch
用Python实现简单机器学习模型:以鸢尾花数据集为例
用Python实现简单机器学习模型:以鸢尾花数据集为例
73 1
|
1月前
|
机器学习/深度学习 供应链 安全
使用Python实现智能食品供应链管理的深度学习模型
使用Python实现智能食品供应链管理的深度学习模型
87 3