PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍:
- Sampler基类: Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。
- RandomSampler: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。
- SequentialSampler: 顺序采样器,它会按照数据集中的顺序,依次选择样本。
- SubsetRandomSampler: 子集随机采样器,它会从数据集的指定子集中随机选择样本。可以用于将数据集分成训练集和验证集等子集。
- WeightedRandomSampler: 加权随机采样器,它会根据指定的样本权重,进行随机采样。可以用于处理类别不平衡的问题。
- BatchSampler: 批次采样器,它会将样本索引分成多个批次,每个批次包含指定数量的样本索引。
这些Sampler类可以通过在DataLoader的构造函数中指定来使用。例如,可以使用RandomSampler来实现随机采样,使用SubsetRandomSampler来实现将数据集分成训练集和验证集。此外,还可以使用函数如WeightedRandomSampler来实现加权随机采样。
下面是使用上述Sampler类和函数的示例代码:
import torch from torch.utils.data import DataLoader from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler # 创建一个数据集 dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,))) # 创建一个使用RandomSampler的DataLoader random_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset)) # 创建一个使用SequentialSampler的DataLoader seq_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset)) # 创建一个使用SubsetRandomSampler的DataLoader train_indices = [0, 1, 2, 3, 4] val_indices = [5, 6, 7, 8, 9] train_sampler = SubsetRandomSampler(train_indices) val_sampler = SubsetRandomSampler(val_indices) train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler) val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler) # 创建一个使用WeightedRandomSampler的DataLoader weights = [0.1, 0.9] weighted_sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True) weighted_loader = DataLoader(dataset, batch_size=2, sampler=weighted_sampler) # 使用BatchSampler将样本索引分成多个批次 batch_sampler = torch.utils.data.sampler.BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False) batch_loader = DataLoader(dataset, batch_sampler=batch_sampler) # 遍历DataLoader,输出每个批次的数据 for data, label in random_loader: print(data, label) for data, label in seq_loader: print(data, label) for data, label in train_loader: print(data, label) for data, label in val_loader: print(data, label) for data, label in weighted_loader: print(data, label) for batch_indices in batch_sampler: batch_data = [dataset[idx] for idx in batch_indices] print(batch_data)
在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们创建了5个不同的DataLoader,每个DataLoader使用不同的采样器(RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler)来从数据集中选择样本。最后,我们遍历这些DataLoader,输出每个批次的数据。
可以通过继承Sampler基类来自定义采样函数。自定义采样函数需要实现__iter__方法和__len__方法。
__iter__方法需要返回一个迭代器,迭代器的每个元素都是数据集中的一个样本的索引。在这个方法中,可以自定义样本索引的选取方式,例如根据某种规则筛选样本或者将数据集分成多个子集。
__len__方法需要返回采样器的样本数量。如果采样器使用的是数据集的全部样本,则返回数据集的长度。
下面是一个自定义采样器的示例代码:
import torch from torch.utils.data.sampler import Sampler class CustomSampler(Sampler): def __init__(self, data_source): self.data_source = data_source # 在初始化方法中,可以根据需要对数据集进行处理 def __iter__(self): # 在这个方法中,可以自定义样本索引的选取方式 # 这里的示例是随机选取样本 indices = torch.randperm(len(self.data_source)).tolist() return iter(indices) def __len__(self): # 在这个方法中,需要返回采样器的样本数量 # 这里的示例是采样器的样本数量等于数据集的长度 return len(self.data_source)
在这个示例中,我们定义了一个名为CustomSampler的采样器类,它继承自Sampler基类。在初始化方法中,我们保存了数据集,并可以根据需要对数据集进行处理。在__iter__方法中,我们自定义了样本索引的选取方式,这里的示例是随机选取样本。在__len__方法中,我们返回了采样器的样本数量,这里的示例是采样器的样本数量等于数据集的长度。
使用自定义采样器时,只需要将它传入DataLoader的构造函数即可:
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,))) custom_sampler = CustomSampler(dataset) loader = DataLoader(dataset, batch_size=2, sampler=custom_sampler)
在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们使用CustomSampler创建了一个采样器,并将它传入DataLoader的构造函数。最后,我们遍历这个DataLoader,输出每个批次的数据。