一文读懂Pytorh Sampler

简介: 一文读懂Pytorh Sampler

> 本文从DataSet、DataLoader和Sampler的关系出发,介绍Pytorch实现的五种采样,并应用到DataLoader中。

✨1 DataSet、DataLoader和Sampler的关系

我们知道DataSet建立数据集,本质是读取一张张图像。而DataLoader是将DataSet中的图像一个个取出来,打包成一个个batch。

但是这里存在一个问题,DataLoader从Dataet中是如何取一张张图像的,该问题对我们训练也有影响:

假设我们数据集是按照类别放在一起的,那么DataSet的读取的图像也是按照类别放在一起的。此时,如果DataLoader顺序读取打包,则可能出现每个batch中都是同一个类别的图像。这就会影响我们模型的训练效果。

因此需要Sampler决定打包时的读取图像的顺序。这就是三者之间的关系。

✨2 Sampler

Pytorch中实现了五种Sampler:

  1. SequentialSampler(顺序采样)
  2. RandomSampler(随机采样)
  3. WeightedSampler(加权随机采样)
  4. SubsetRandomSampler(子集随机采样)
  5. BatchSampler(批采样)

(其中1,2,5可应用到DataLoader中,第三节详细展开)

🎃 2.1 SequentialSampler(顺序采样)

用于获取数据索引

torch.utils.data.SequentialSampler(
  data_source,
)

参数:

  1. data_source:可迭代数据,一般为数据集

返回:

顺序返回数据集索引

示例:

cbdb92202fda446bb993bb66b8c9828d.png

🎉 2.2 RandomSampler(随即采样)

用于获取打乱的数据索引

torch.utils.data.SequentialSampler(
  data_source,
  num_samples,
  replacement,
)

参数:

data_source:同上

num_samples:指定采样的数量,默认是所有

replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。

返回:

乱序返回数据集索

29e1c21b893d442598eaa92240823549.png

🎄2.3 BatchSampler(批采样)

BatchSampler将前面的Sampler采样得到的单个的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。(训练时使用的是批量数据)

torch.utils.data.BatchSampler(
  sampler,
  batch_size, 
  drop_last,
)

参数:

sampler:上述两种采样器,即SequentialSampler或RandomSampler

batch_size:batch的大小

drop_last:True或False。drop_last为True时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据。

返回:

分组完成的数据索引shape=(num_data/batch_size, batch_size)

比较抽象,下面举一个例子:

import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
a = [1,5,78,9,68]
b = BatchSampler(a, 2, False)
print(list(b))

8febe285a32b43f3b8234c29997eb5b6.png

可以看到已经分成三组,每组大小都是设置的batch_size=2。而drop_last=False,并未去掉于batch_size的分组。

🎄2.4 SubsetRandomSampler(子集随机采样)

torch.utils.data.SubsetRandomSampler(
  indices
)

参数:

  1. indices:数据集索引

返回:

与上面返回数据的索引不同,这里返回的是对应索引的数据本身

该方法更多应用于切分数据集,比如:

import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler
a = [1,5,78,9,68]
b1 = torch.utils.data.SubsetRandomSampler(a[:3])
b2 = torch.utils.data.SubsetRandomSampler(a[3:])
for x in b1:
    print("train:", x)
for x in b2:
    print("val:", x)

213d890f3445464994bc7d4c6435ada6.png

🎃 2.5 WeightedRandomSampler(加权随机采样)

torch.utils.data.WeightedRandomSampler(
   weights, 
   num_samples, 
   replacement=True,
)

参数:

weights:采样到该索引的权重

num_samples:指定采样的数量,默认是所有

replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。

返回:

与上面返回数据的索引不同,这里返回的是对应索引的数据本身

示例代码:

import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
a = [1,5,78,9,68]
weights = [0, 3, 1.1, 1.1, 1.1, 1.1, 1.1]
b = WeightedRandomSampler(weights, 7, replacement=True)
for i in b:
    print(i)

cf19c69f375949daa2be4f3326cf6f00.png

代码中,replacement设置为True,允许重复采样后,由于位置1的权重为3比较大,因此被采样次数较多。

✨3 应用

了解上面五种Sampler后,如何在我们的项目中使用是重点:

  1. 采用
  2. DataLoader应用

🎃 3.1 采样

首先,创建顺序采样或随机采样,比如:

sampler = torch.utils.data.RandomSampler(train_dataset)  # train_dataset,自定义数据集(重载的DataSet)

其次,在上面的基础上创建批采样:

batch_sampler_train = torch.utils.data.BatchSampler(sampler, 16, drop_last=True)

结果类似:

f5dc3dc9d4c046c7be18408185bf4c14.png

🎉 3.2 DataLoader应用

其中,指定顺序采样或随机采样用到DatLoader的参数sampler。而指定批采样的参数是batch_sampler

由于参数之间可能冲突,使用时分为以下几种情况:

sampler和batch_sampler都为None:batch_sampler使用Pytorch实现的批采样,而sampler分为两种情况

====================================================================

a). shuffle=True,则sampler使用随机采样

b). shuffle=False,则sampler使用顺序采样====================================================================

自定义了batch_sampler,那么batch_size,shuffle,sampler,drop_last必须都是默认值

自定义了sampler,此时batch_sampler不能再指定,且shuffle必须为False。


相关文章
|
JavaScript
面试官:v-model原理?
面试官:v-model原理?
249 2
|
存储 缓存 Rust
一文读懂 Deno
一文读懂 Deno
269 0
|
5月前
|
安全 Windows
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
详解Batch Normalization并基于PyTorch实操(附代码)
详解Batch Normalization并基于PyTorch实操(附代码)
183 2
|
机器学习/深度学习 存储 算法
一文读懂K-Means原理与Python实现
在本文中,你将学习到K-means算法的数学原理,作者会以尼日利亚音乐数据集为案例。带你了解了如何通过可视化的方式发现数据中潜在的特征。最后对训练好的K-means模型进行评估。
331 0
|
前端开发 JavaScript 程序员
【万字长文】通过grunt、gulp和fit,彻底搞懂前端的自动化构建(三)
【万字长文】通过grunt、gulp和fit,彻底搞懂前端的自动化构建
123 0
|
移动开发 资源调度 前端开发
【万字长文】通过grunt、gulp和fit,彻底搞懂前端的自动化构建(一)
【万字长文】通过grunt、gulp和fit,彻底搞懂前端的自动化构建
148 0
|
前端开发 JavaScript API
【万字长文】通过grunt、gulp和fit,彻底搞懂前端的自动化构建(二)
【万字长文】通过grunt、gulp和fit,彻底搞懂前端的自动化构建
248 0
|
索引
【Pytorch--代码技巧】各种论文代码常见技巧
博主在阅读论文原代码的时候常常看见一些没有见过的代码技巧,特此将这些内容进行汇总
178 0
|
机器学习/深度学习 算法框架/工具
5分钟入门GANS:原理解释和keras代码实现
5分钟入门GANS:原理解释和keras代码实现
231 0
5分钟入门GANS:原理解释和keras代码实现