详细介绍torch中的from torch.utils.data.sampler相关知识

简介: PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍:Sampler基类:Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。RandomSampler:随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。SequentialSampler:顺序采样器,它会按照数据集中的顺序,依次选择样本。SubsetRandomSampler:子集随机采样器

PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍:

  1. Sampler基类: Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。
  2. RandomSampler: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。
  3. SequentialSampler: 顺序采样器,它会按照数据集中的顺序,依次选择样本。
  4. SubsetRandomSampler: 子集随机采样器,它会从数据集的指定子集中随机选择样本。可以用于将数据集分成训练集和验证集等子集。
  5. WeightedRandomSampler: 加权随机采样器,它会根据指定的样本权重,进行随机采样。可以用于处理类别不平衡的问题。
  6. BatchSampler: 批次采样器,它会将样本索引分成多个批次,每个批次包含指定数量的样本索引。

这些Sampler类可以通过在DataLoader的构造函数中指定来使用。例如,可以使用RandomSampler来实现随机采样,使用SubsetRandomSampler来实现将数据集分成训练集和验证集。此外,还可以使用函数如WeightedRandomSampler来实现加权随机采样。


下面是使用上述Sampler类和函数的示例代码:


import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler
# 创建一个数据集
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
# 创建一个使用RandomSampler的DataLoader
random_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))
# 创建一个使用SequentialSampler的DataLoader
seq_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))
# 创建一个使用SubsetRandomSampler的DataLoader
train_indices = [0, 1, 2, 3, 4]
val_indices = [5, 6, 7, 8, 9]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)
# 创建一个使用WeightedRandomSampler的DataLoader
weights = [0.1, 0.9]
weighted_sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
weighted_loader = DataLoader(dataset, batch_size=2, sampler=weighted_sampler)
# 使用BatchSampler将样本索引分成多个批次
batch_sampler = torch.utils.data.sampler.BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
batch_loader = DataLoader(dataset, batch_sampler=batch_sampler)
# 遍历DataLoader,输出每个批次的数据
for data, label in random_loader:
    print(data, label)
for data, label in seq_loader:
    print(data, label)
for data, label in train_loader:
    print(data, label)
for data, label in val_loader:
    print(data, label)
for data, label in weighted_loader:
    print(data, label)
for batch_indices in batch_sampler:
    batch_data = [dataset[idx] for idx in batch_indices]
    print(batch_data)

在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们创建了5个不同的DataLoader,每个DataLoader使用不同的采样器(RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler)来从数据集中选择样本。最后,我们遍历这些DataLoader,输出每个批次的数据。


可以通过继承Sampler基类来自定义采样函数。自定义采样函数需要实现__iter__方法和__len__方法。

__iter__方法需要返回一个迭代器,迭代器的每个元素都是数据集中的一个样本的索引。在这个方法中,可以自定义样本索引的选取方式,例如根据某种规则筛选样本或者将数据集分成多个子集。

__len__方法需要返回采样器的样本数量。如果采样器使用的是数据集的全部样本,则返回数据集的长度。

下面是一个自定义采样器的示例代码:


import torch
from torch.utils.data.sampler import Sampler
class CustomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        # 在初始化方法中,可以根据需要对数据集进行处理
    def __iter__(self):
        # 在这个方法中,可以自定义样本索引的选取方式
        # 这里的示例是随机选取样本
        indices = torch.randperm(len(self.data_source)).tolist()
        return iter(indices)
    def __len__(self):
        # 在这个方法中,需要返回采样器的样本数量
        # 这里的示例是采样器的样本数量等于数据集的长度
        return len(self.data_source)

在这个示例中,我们定义了一个名为CustomSampler的采样器类,它继承自Sampler基类。在初始化方法中,我们保存了数据集,并可以根据需要对数据集进行处理。在__iter__方法中,我们自定义了样本索引的选取方式,这里的示例是随机选取样本。在__len__方法中,我们返回了采样器的样本数量,这里的示例是采样器的样本数量等于数据集的长度。

使用自定义采样器时,只需要将它传入DataLoader的构造函数即可:


dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
custom_sampler = CustomSampler(dataset)
loader = DataLoader(dataset, batch_size=2, sampler=custom_sampler)

在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们使用CustomSampler创建了一个采样器,并将它传入DataLoader的构造函数。最后,我们遍历这个DataLoader,输出每个批次的数据。

相关文章
|
5月前
|
PyTorch 算法框架/工具
【chat-gpt问答记录】torch.tensor和torch.Tensor什么区别?
【chat-gpt问答记录】torch.tensor和torch.Tensor什么区别?
127 2
|
24天前
|
TensorFlow 算法框架/工具
Tensorflow error(二):x and y must have the same dtype, got tf.float32 != tf.int32
本文讨论了TensorFlow中的一个常见错误,即在计算过程中,变量的数据类型(dtype)不一致导致的错误,并通过使用`tf.cast`函数来解决这个问题。
18 0
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+keras】解决cuDNN launch failure : input shape ([32,2,8,8]) [[{{node sequential_1/batch_nor
在使用TensorFlow 2.0和Keras训练生成对抗网络(GAN)时,遇到了“cuDNN launch failure”错误,特别是在调用self.generator.predict方法时出现,输入形状为([32,2,8,8])。此问题可能源于输入数据形状与模型期望的形状不匹配或cuDNN版本不兼容。解决方案包括设置GPU内存增长、检查模型定义和输入数据形状、以及确保TensorFlow和cuDNN版本兼容。
42 1
|
3月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】tf.keras.backend.image_data_format()的解析与举例使用
介绍了TensorFlow和Keras中tf.keras.backend.image_data_format()函数的用法。
49 5
|
3月前
|
TensorFlow API 算法框架/工具
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
49 0
|
4月前
|
PyTorch 算法框架/工具 机器学习/深度学习
|
6月前
|
PyTorch 算法框架/工具
ImportError: cannot import name ‘_DataLoaderIter‘ from ‘torch.utils.data.dataloader‘
ImportError: cannot import name ‘_DataLoaderIter‘ from ‘torch.utils.data.dataloader‘
71 2
|
6月前
|
存储 PyTorch 算法框架/工具
torch.Storage()是什么?和torch.Tensor()有什么区别?
torch.Storage()是什么?和torch.Tensor()有什么区别?
40 1
torch.argmax(dim=1)用法
)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
632 0
|
机器学习/深度学习 PyTorch API
Torch
Torch是一个用于构建深度学习模型的开源机器学习库,它基于Lua编程语言。然而,由于PyTorch的出现,现在通常所说的"torch"指的是PyTorch。PyTorch是一个基于Torch的Python库,它提供了一个灵活而高效的深度学习框架。
264 1