> 本文从DataSet、DataLoader和Sampler的关系出发,介绍Pytorch实现的五种采样,并应用到DataLoader中。
✨1 DataSet、DataLoader和Sampler的关系
我们知道DataSet建立数据集,本质是读取一张张图像。而DataLoader是将DataSet中的图像一个个取出来,打包成一个个batch。
但是这里存在一个问题,DataLoader从Dataet中是如何取一张张图像的,该问题对我们训练也有影响:
假设我们数据集是按照类别放在一起的,那么DataSet的读取的图像也是按照类别放在一起的。此时,如果DataLoader顺序读取打包,则可能出现每个batch中都是同一个类别的图像。这就会影响我们模型的训练效果。
因此需要Sampler决定打包时的读取图像的顺序。这就是三者之间的关系。
✨2 Sampler
Pytorch中实现了五种Sampler:
- SequentialSampler(顺序采样)
- RandomSampler(随机采样)
- WeightedSampler(加权随机采样)
- SubsetRandomSampler(子集随机采样)
- BatchSampler(批采样)
(其中1,2,5可应用到DataLoader中,第三节详细展开)
🎃 2.1 SequentialSampler(顺序采样)
用于获取数据索引
torch.utils.data.SequentialSampler( data_source, )
参数:
- data_source:可迭代数据,一般为数据集
返回:
顺序返回数据集索引
示例:
🎉 2.2 RandomSampler(随即采样)
用于获取打乱的数据索引
torch.utils.data.SequentialSampler( data_source, num_samples, replacement, )
参数:
data_source:同上
num_samples:指定采样的数量,默认是所有
replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。
返回:
乱序返回数据集索
🎄2.3 BatchSampler(批采样)
BatchSampler将前面的Sampler采样得到的单个的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。(训练时使用的是批量数据)
torch.utils.data.BatchSampler( sampler, batch_size, drop_last, )
参数:
sampler:上述两种采样器,即SequentialSampler或RandomSampler
batch_size:batch的大小
drop_last:True或False。drop_last为True时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据。
返回:
分组完成的数据索引shape=(num_data/batch_size, batch_size)
比较抽象,下面举一个例子:
import torch.utils.data from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler a = [1,5,78,9,68] b = BatchSampler(a, 2, False) print(list(b))
可以看到已经分成三组,每组大小都是设置的batch_size=2。而drop_last=False,并未去掉于batch_size的分组。
🎄2.4 SubsetRandomSampler(子集随机采样)
torch.utils.data.SubsetRandomSampler( indices )
参数:
- indices:数据集索引
返回:
与上面返回数据的索引不同,这里返回的是对应索引的数据本身
该方法更多应用于切分数据集,比如:
import torch.utils.data from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler a = [1,5,78,9,68] b1 = torch.utils.data.SubsetRandomSampler(a[:3]) b2 = torch.utils.data.SubsetRandomSampler(a[3:]) for x in b1: print("train:", x) for x in b2: print("val:", x)
🎃 2.5 WeightedRandomSampler(加权随机采样)
torch.utils.data.WeightedRandomSampler( weights, num_samples, replacement=True, )
参数:
weights:采样到该索引的权重
num_samples:指定采样的数量,默认是所有
replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。
返回:
与上面返回数据的索引不同,这里返回的是对应索引的数据本身
示例代码:
import torch.utils.data from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler a = [1,5,78,9,68] weights = [0, 3, 1.1, 1.1, 1.1, 1.1, 1.1] b = WeightedRandomSampler(weights, 7, replacement=True) for i in b: print(i)
代码中,replacement
设置为True,允许重复采样后,由于位置1的权重为3比较大,因此被采样次数较多。
✨3 应用
了解上面五种Sampler后,如何在我们的项目中使用是重点:
- 采用
- DataLoader应用
🎃 3.1 采样
首先,创建顺序采样或随机采样,比如:
sampler = torch.utils.data.RandomSampler(train_dataset) # train_dataset,自定义数据集(重载的DataSet)
其次,在上面的基础上创建批采样:
batch_sampler_train = torch.utils.data.BatchSampler(sampler, 16, drop_last=True)
结果类似:
🎉 3.2 DataLoader应用
其中,指定顺序采样或随机采样用到DatLoader的参数sampler
。而指定批采样的参数是batch_sampler
。
由于参数之间可能冲突,使用时分为以下几种情况:
sampler和batch_sampler都为None:batch_sampler使用Pytorch实现的批采样,而sampler分为两种情况
====================================================================
a). shuffle=True,则sampler使用随机采样
b). shuffle=False,则sampler使用顺序采样====================================================================
自定义了batch_sampler,那么batch_size,shuffle,sampler,drop_last必须都是默认值
自定义了sampler,此时batch_sampler不能再指定,且shuffle必须为False。