DataLoader 与 Dataset
深度学习模型训练一般流程
torch.utils.data.DataLoader
功能:构建可迭代的数据装载器
dataset: Dataset类,决定数据从哪读取及如何读取
batchsize : 批大小
num_works: 是否多进程读取数据
shuffle: 每个epoch是否乱序
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
概念辨析:
Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个Iteration
Batchsize:批大小,决定一个Epoch有多少个Iteration
e.g.:
class MyDataset(Dataset): def __init__(self, data_dir, transforms=None): super().__init__() self.Label = {'1':0, '100',1} self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本 self.transform = transform def __getitem__(self, index): path_img, label = self.data_info[index] img = Image.open(path_img).convert('RGB') # 0~255 if self.transform is not None: img = self.transform(img) # 在这里做transform,转为tensor等等 return img, label def __len__(self): return len(self.data_info)
torch.utils.data.Dataset
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem :接收一个索引,返回一个样本
数据读取流程
transforms
torchvision.transforms : 常用的图像预处理方法
数据中心化
数据标准化
缩放
裁剪
旋转
翻转
填充
噪声添加
灰度变换
线性变换
仿射变换
亮度、饱和度及对比度变换
作用位置:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WEO1DU2s-1637072332139)(https://i.loli.net/2021/11/16/9c4xH5n6JegIPjS.png)]
transforms图像增强
数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2IQ1OkSu-1637072332143)(https://i.loli.net/2021/11/16/OxMqR7fTFnL8VQs.png)]
transforms——Crop
transforms.CenterCrop
功能:从图像中心裁剪图片
transforms.RandomCrop
功能:从图片中随机裁剪出尺寸为size的图片
RandomResizedCrop
功能:随机大小、长宽比裁剪图片
FiveCrop
TenCrop
功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
transforms——Flip and Rotation
RandomHorizontalFlip
RandomVerticalFlip
功能:依概率水平(左右)或垂直(上下)翻转图片
RandomRotation
功能:随机旋转图片
**degrees:**旋转角度
当为a时,在(-a,a)之间选择旋转角度
当为(a, b)时,在(a, b)之间选择旋转角度
**resample:**重采样方法
**expand:**是否扩大图片,以保持原图信息
图像变换
Pad
功能:对图片边缘进行填充
transforms.Pad(padding, fill=0, padding_mode='constant')
2.ColorJitter
功能:调整亮度、对比度、饱和度和色相
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
3.Grayscale
4.RandomGrayscale
功能:依概率将图片转换为灰度图
RandomGrayscale(num_output_channels, p=0.1) Grayscale(num_output_channels)
5.RandomAffine
功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)
6.RandomErasing
功能:对图像进行随机遮挡
7.transforms.Lambda
功能:用户自定义lambda方法
transforms.RandomChoice
transforms.RandomChoice([transforms1, transforms2, transforms3])
功能:从一系列transforms方法中随机挑选一个
transforms.RandomApply
transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)
功能:依据概率执行一组transforms操作
transforms.RandomOrder
transforms.RandomOrder([transforms1, transforms2, transforms3])
功能:对一组transforms操作打乱顺序
自定义transforms
自定义transforms要素:
仅接收一个参数,返回一个参数
注意上下游的输出与输入
class YourTransforms(object): def __init__(self, ...): def __call__(self, img): return img
总结:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-z9YuW2g3-1637072332145)(https://i.loli.net/2021/11/16/nC9JMbkfG1dZQeT.png)]
数据增强实战
原则:让训练集与测试集更接近
ansforms要素:
仅接收一个参数,返回一个参数
注意上下游的输出与输入
class YourTransforms(object): def __init__(self, ...): def __call__(self, img): return img
总结:
[外链图片转存中…(img-z9YuW2g3-1637072332145)]
数据增强实战
原则:让训练集与测试集更接近