【14】自定义宝可梦数据集

简介: 【14】自定义宝可梦数据集

一般的深度学习训练模型的搭建框架过程为,导入数据-建立模型-训练与测试/迁移学习,在这篇笔记中,我主要记录了自定义一个自己的数据集过程与迁移学习的方法。对于其中涉及的到的训练过程与测试过程在其他的笔记中已有提到。


对于之前用到的MNIST数据集与Cifar10数据集的导入,其实我们都只是利用了pytorch提供的函数,分别是torchvision.datasets.MNIST与torchvision.datasets.CIFAR10两个函数帮助我们实现了样本数据的导入。但是,当我们需要训练我们自己的数据集时,具体的datasets操作函数便需要我们来编写。


对于我们设计自定义数据集类时,具体有三个步骤:


  1. 继承torch.utils.data中的Dataset类
  2. 编写 __ len __ ()函数
  3. 编写 __ getitem __ ()函数


源码中的Dataset如下:


class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.
    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.
    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py


所以一个最基本的数据集类模型的原始为


class Pokemon(dataset):
  # 定义初始化函数
    def __init__(self):
        pass
  # 定义返回样本数量函数:实现返回具体样本的数目
    def __len__(self):
        pass
  # 定义返回具体样本的函数:实现读取一个具体的样本
    def __getitem__(self, item):
        pass


完善后的参考代码:


import torch
import torchvision
from torch.utils.data import Dataset, DataLoader   # 注意是Dataset而不是dataset
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
from torchvision.utils import save_image
from visdom import Visdom
import os, glob, random, csv, time
class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.image = []
        self.label = []
        # 创建一个字典存储类别与标签
        self.name2label = {}
        for name in os.listdir(root):
            # 判断文件名是否为目录
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 关键字的取值为当前的关键字个数
            self.name2label[name] = len(self.name2label.keys())
            # print(self.name2label.keys())
            # dict_keys(['bulbasaur', 'charmander', 'mewtwo', 'pikachu'])
        # print(self.name2label)
        # {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
        # 导入图像数据
        # self.load_csv('images.csv')
        self.image, self.label = self.load_csv('images.csv')
        # 设置train-val-test比例
        # nums: 700
        if mode == 'train':
            self.image = self.image[:int(0.6 * len(self.image))]
            self.label = self.label[:int(0.6 * len(self.label))]
        # nums: 233
        elif mode == 'val':
            self.image = self.image[int(0.6 * len(self.image)):int(0.8 * len(self.image))]
            self.label = self.label[int(0.6 * len(self.label)):int(0.8 * len(self.label))]
        # nums: 234
        elif mode == 'test':
            self.image = self.image[int(0.8 * len(self.image)):]
            self.label = self.label[int(0.8 * len(self.label)):]
        else:
            print("Error! 'Mode' has no such mode choice!")
    def __len__(self):
        return len(self.image)
    def __getitem__(self, item):
        # item = self.__len__()
        # print(" item: ", item)
        image = self.image[item]
        label = self.label[item]
        # print("image: ",image,"label: ",label)
        # 对图像进行预处理
        transform = transforms.Compose([
            # 转换为RGB图像
            lambda x: Image.open(x).convert('RGB'),
            # 重新确定尺寸
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            # 旋转角度
            transforms.RandomRotation(15),
            # 中心裁剪
            transforms.CenterCrop(self.resize),
            # 转换为Tensor格式
            transforms.ToTensor(),
            # 使数据分布在0附近
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        image = transform(image)
        # label是int形式,转换为tensor格式
        label = torch.tensor(label)
        return image, label
    # 导入csv样本数据
    def load_csv(self, csv_file):
        # 当没有csv数据文件时创建文件, 将数据集信息保存在一个csv_file文件中
        if not os.path.exists(os.path.join(self.root, csv_file)):
            # 用来存储图像路径信息
            image = []
            # 现查找数据集文件中的png,jpg,jpeg格式的全部图像,路径全部保存在image中
            for name in self.name2label.keys():
                # glob 模块用于查找符合特定规则的文件路径名
                image += glob.glob(os.path.join(self.root, name, '*.png'))
                image += glob.glob(os.path.join(self.root, name, '*.jpg'))
                image += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 'E:\\学习\\机器学习\\数据集\\pokemon\\bulbasaur\\00000000.png'
            # print(image, len(image))
            # 随机打乱图像
            random.shuffle(image)
            # 截取绝对路径下的图像名字
            # name = next(iter(image))
            # name = name.split('\\')[-2]
            # print(name)   # charmander
            # 读写打开文件, 注意newline=''是为了不让存储的时候回车两行
            with open(csv_file, mode='w', newline='') as f:
                # 创建 csv 对象
                writer = csv.writer(f)
                for img in image:
                    # split: 对路径进行分割,以列表形式返回
                    # os.sep: 当前操作系统所使用的路径分隔符 windows->'\' linux 和 unix->'/'
                    # ['E:/学习/机器学习/数据集/pokemon', 'pikachu', '00000179.jpg']
                    # [-2]既提取了文件夹名字: 'pikachu'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 写入一行或多行数据
                    # 形式: E:\学习\机器学习\数据集\pokemon\charmander\00000185.png,1
                    writer.writerow([img, label])
                # print('writen into csv file:', csv_file)
        # 打开csv文件读取信息
        with open(csv_file) as f:
            # 创建两个list存储图像名字与标签
            image = []
            label = []
            # #创建 csv 对象,它是一个包含所有数据的列表,每一行为名字与标签,eg: charmander,1
            reader = csv.reader(f)
            # 循环赋值各行内容
            for row in reader:
                # 导入数据, 若没有设置newline=''会报错,因为回车了两行
                image.append(row[0])
                label.append(int(row[1]))
            # print(len(image), len(label))
        if len(image) == len(label):
            return image, label
        else:
            print("Error! len(image) != len(label) !")
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x
def plot_image(img):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.xticks([])
        plt.yticks([])
    plt.show()
root = 'E:\学习\机器学习\数据集\pokemon'
viz = Visdom()
train_data = Pokemon(root=root, resize=64, mode='train')
# print(train_data.__len__())
# image, label = next(iter(train_data))
# print(image.shape, label)
# 利用DataLoader加载数据集
data = DataLoader(train_data, batch_size=64, shuffle=True)
# 测试
for epochodx, (image, label) in enumerate(data):
    # plot_image(train_data.denormalize(image))
    # time.sleep(5)
  # 保存图像在本地
    save_image(image, os.path.join('sample', 'image-{}.png'.format(epochodx + 1)), nrow=8, normalize=True)
  # 可视化操作
    # viz.images(image, nrow=8, win='batch', opts=dict(title='batch'))
    viz.images(train_data.denormalize(image), nrow=8, win='batch', opts=dict(title='batch'))
    time.sleep(5)


在visdom中输出的结果是有点奇怪的

image.png

但是保存在本地的图像是可以正常显示的

image.png

大图

image.png

原因:因为我们本身对数据集进行的transforms操作中,包含了Normalize的操作,数据分布变成了在0附近的一个分布,这就代表有些数值是小于0的,而visdom本身只能显示大于0以上的像素,所以会出现这种情况,现在只需要将图像进行Denormalize操作即可:


def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x
# 可视化操作作以下更改
# viz.images(image, nrow=8, win='batch', opts=dict(title='batch'))
viz.images(train_data.denormalize(image), nrow=8, win='batch', opts=dict(title='batch'))

可以看见,图像变得正常许多

image.png

到此,我们实现了自定义数据集的加载操作,其中image.csv文件中数据存储的格式为图像的对应路径与便签,如图所示:

image.png

目录
相关文章
|
SQL 测试技术
|
2月前
|
PyTorch 算法框架/工具
数据集学习笔记(三):调用不同数据集获取trainloader和testloader
本文介绍了如何使用PyTorch框架调用CIFAR10数据集,并获取训练和测试的数据加载器(trainloader和testloader)。
45 4
数据集学习笔记(三):调用不同数据集获取trainloader和testloader
|
1月前
|
存储 JSON API
如何创建自己的数据集!!!
本文介绍了如何创建和使用自定义数据集,特别是针对GitHub Issues的语料库。内容涵盖了从获取数据、清理数据到扩充数据集的全过程,最终将数据集上传到Hugging Face Hub并与社区分享。具体步骤包括使用GitHub REST API下载Issues,通过Python脚本进行数据处理,以及添加评论信息。此外,还介绍了如何创建数据集卡片,以提供详细的背景信息和使用指南。
35 0
|
4月前
|
自然语言处理
评估数据集CGoDial问题之数据集中包含哪些基线模型
评估数据集CGoDial问题之数据集中包含哪些基线模型
|
4月前
|
自然语言处理
评估数据集CGoDial问题之Doc2Bot数据集的问题如何解决
评估数据集CGoDial问题之Doc2Bot数据集的问题如何解决
|
7月前
|
SQL Oracle 关系型数据库
C# 利用IDbDataAdapter / IDataReader 实现通用数据集获取
C# 利用IDbDataAdapter / IDataReader 实现通用数据集获取
|
7月前
|
Python
创建模型
创建模型。
35 1
|
存储 编解码 数据安全/隐私保护
ISPRS Vaihingen 数据集解析
ISPRS Vaihingen 数据集解析
1261 0
ISPRS Vaihingen 数据集解析
|
XML JSON 算法
【数据集转换】VOC数据集转COCO数据集·代码实现+操作步骤
与VOC一个文件一个xml标注不同,COCO所有的目标框标注都是放在一个json文件中的。
1467 1
|
XML 数据可视化 数据格式
【数据集显示标注】VOC文件结构+数据集标注可视化+代码实现
【数据集显示标注】VOC文件结构+数据集标注可视化+代码实现
456 0