一般的深度学习训练模型的搭建框架过程为,导入数据-建立模型-训练与测试/迁移学习,在这篇笔记中,我主要记录了自定义一个自己的数据集过程与迁移学习的方法。对于其中涉及的到的训练过程与测试过程在其他的笔记中已有提到。
对于之前用到的MNIST数据集与Cifar10数据集的导入,其实我们都只是利用了pytorch提供的函数,分别是torchvision.datasets.MNIST与torchvision.datasets.CIFAR10两个函数帮助我们实现了样本数据的导入。但是,当我们需要训练我们自己的数据集时,具体的datasets操作函数便需要我们来编写。
对于我们设计自定义数据集类时,具体有三个步骤:
- 继承torch.utils.data中的Dataset类
- 编写 __ len __ ()函数
- 编写 __ getitem __ ()函数
源码中的Dataset如下:
class Dataset(Generic[T_co]): r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs a index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """ def __getitem__(self, index) -> T_co: raise NotImplementedError def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': return ConcatDataset([self, other]) # No `def __len__(self)` default? # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] # in pytorch/torch/utils/data/sampler.py
所以一个最基本的数据集类模型的原始为
class Pokemon(dataset): # 定义初始化函数 def __init__(self): pass # 定义返回样本数量函数:实现返回具体样本的数目 def __len__(self): pass # 定义返回具体样本的函数:实现读取一个具体的样本 def __getitem__(self, item): pass
完善后的参考代码:
import torch import torchvision from torch.utils.data import Dataset, DataLoader # 注意是Dataset而不是dataset from torchvision import transforms from PIL import Image from matplotlib import pyplot as plt from torchvision.utils import save_image from visdom import Visdom import os, glob, random, csv, time class Pokemon(Dataset): def __init__(self, root, resize, mode): super(Pokemon, self).__init__() self.root = root self.resize = resize self.image = [] self.label = [] # 创建一个字典存储类别与标签 self.name2label = {} for name in os.listdir(root): # 判断文件名是否为目录 if not os.path.isdir(os.path.join(root, name)): continue # 关键字的取值为当前的关键字个数 self.name2label[name] = len(self.name2label.keys()) # print(self.name2label.keys()) # dict_keys(['bulbasaur', 'charmander', 'mewtwo', 'pikachu']) # print(self.name2label) # {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4} # 导入图像数据 # self.load_csv('images.csv') self.image, self.label = self.load_csv('images.csv') # 设置train-val-test比例 # nums: 700 if mode == 'train': self.image = self.image[:int(0.6 * len(self.image))] self.label = self.label[:int(0.6 * len(self.label))] # nums: 233 elif mode == 'val': self.image = self.image[int(0.6 * len(self.image)):int(0.8 * len(self.image))] self.label = self.label[int(0.6 * len(self.label)):int(0.8 * len(self.label))] # nums: 234 elif mode == 'test': self.image = self.image[int(0.8 * len(self.image)):] self.label = self.label[int(0.8 * len(self.label)):] else: print("Error! 'Mode' has no such mode choice!") def __len__(self): return len(self.image) def __getitem__(self, item): # item = self.__len__() # print(" item: ", item) image = self.image[item] label = self.label[item] # print("image: ",image,"label: ",label) # 对图像进行预处理 transform = transforms.Compose([ # 转换为RGB图像 lambda x: Image.open(x).convert('RGB'), # 重新确定尺寸 transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))), # 旋转角度 transforms.RandomRotation(15), # 中心裁剪 transforms.CenterCrop(self.resize), # 转换为Tensor格式 transforms.ToTensor(), # 使数据分布在0附近 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) image = transform(image) # label是int形式,转换为tensor格式 label = torch.tensor(label) return image, label # 导入csv样本数据 def load_csv(self, csv_file): # 当没有csv数据文件时创建文件, 将数据集信息保存在一个csv_file文件中 if not os.path.exists(os.path.join(self.root, csv_file)): # 用来存储图像路径信息 image = [] # 现查找数据集文件中的png,jpg,jpeg格式的全部图像,路径全部保存在image中 for name in self.name2label.keys(): # glob 模块用于查找符合特定规则的文件路径名 image += glob.glob(os.path.join(self.root, name, '*.png')) image += glob.glob(os.path.join(self.root, name, '*.jpg')) image += glob.glob(os.path.join(self.root, name, '*.jpeg')) # 'E:\\学习\\机器学习\\数据集\\pokemon\\bulbasaur\\00000000.png' # print(image, len(image)) # 随机打乱图像 random.shuffle(image) # 截取绝对路径下的图像名字 # name = next(iter(image)) # name = name.split('\\')[-2] # print(name) # charmander # 读写打开文件, 注意newline=''是为了不让存储的时候回车两行 with open(csv_file, mode='w', newline='') as f: # 创建 csv 对象 writer = csv.writer(f) for img in image: # split: 对路径进行分割,以列表形式返回 # os.sep: 当前操作系统所使用的路径分隔符 windows->'\' linux 和 unix->'/' # ['E:/学习/机器学习/数据集/pokemon', 'pikachu', '00000179.jpg'] # [-2]既提取了文件夹名字: 'pikachu' name = img.split(os.sep)[-2] label = self.name2label[name] # 写入一行或多行数据 # 形式: E:\学习\机器学习\数据集\pokemon\charmander\00000185.png,1 writer.writerow([img, label]) # print('writen into csv file:', csv_file) # 打开csv文件读取信息 with open(csv_file) as f: # 创建两个list存储图像名字与标签 image = [] label = [] # #创建 csv 对象,它是一个包含所有数据的列表,每一行为名字与标签,eg: charmander,1 reader = csv.reader(f) # 循环赋值各行内容 for row in reader: # 导入数据, 若没有设置newline=''会报错,因为回车了两行 image.append(row[0]) label.append(int(row[1])) # print(len(image), len(label)) if len(image) == len(label): return image, label else: print("Error! len(image) != len(label) !") def denormalize(self, x_hat): mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] # x_hat = (x-mean)/std # x = x_hat*std = mean # x: [c, h, w] # mean: [3] => [3, 1, 1] mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1) std = torch.tensor(std).unsqueeze(1).unsqueeze(1) # print(mean.shape, std.shape) x = x_hat * std + mean return x def plot_image(img): fig = plt.figure() for i in range(6): plt.subplot(2, 3, i + 1) plt.tight_layout() plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none') plt.xticks([]) plt.yticks([]) plt.show() root = 'E:\学习\机器学习\数据集\pokemon' viz = Visdom() train_data = Pokemon(root=root, resize=64, mode='train') # print(train_data.__len__()) # image, label = next(iter(train_data)) # print(image.shape, label) # 利用DataLoader加载数据集 data = DataLoader(train_data, batch_size=64, shuffle=True) # 测试 for epochodx, (image, label) in enumerate(data): # plot_image(train_data.denormalize(image)) # time.sleep(5) # 保存图像在本地 save_image(image, os.path.join('sample', 'image-{}.png'.format(epochodx + 1)), nrow=8, normalize=True) # 可视化操作 # viz.images(image, nrow=8, win='batch', opts=dict(title='batch')) viz.images(train_data.denormalize(image), nrow=8, win='batch', opts=dict(title='batch')) time.sleep(5)
在visdom中输出的结果是有点奇怪的
但是保存在本地的图像是可以正常显示的
大图
原因:因为我们本身对数据集进行的transforms操作中,包含了Normalize的操作,数据分布变成了在0附近的一个分布,这就代表有些数值是小于0的,而visdom本身只能显示大于0以上的像素,所以会出现这种情况,现在只需要将图像进行Denormalize操作即可:
def denormalize(self, x_hat): mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] # x_hat = (x-mean)/std # x = x_hat*std = mean # x: [c, h, w] # mean: [3] => [3, 1, 1] mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1) std = torch.tensor(std).unsqueeze(1).unsqueeze(1) # print(mean.shape, std.shape) x = x_hat * std + mean return x # 可视化操作作以下更改 # viz.images(image, nrow=8, win='batch', opts=dict(title='batch')) viz.images(train_data.denormalize(image), nrow=8, win='batch', opts=dict(title='batch'))
可以看见,图像变得正常许多
到此,我们实现了自定义数据集的加载操作,其中image.csv文件中数据存储的格式为图像的对应路径与便签,如图所示: