【单点知识】基于实例详解PyTorch中的DataLoader类

简介: 【单点知识】基于实例详解PyTorch中的DataLoader类

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在深度学习中,数据的预处理和加载方式对模型训练的效率与效果具有重要影响。PyTorch提供了一种强大的工具——DataLoader,它能够高效地将数据集转化为适合模型训练的小批量数据,并支持多线程并行加载机制,极大地提升了数据读取速度。本文将详细介绍PyTorch中的DataLoader。


本文的说明参考了PyTorch官网文件:https://pytorch.org/docs/stable/data.html


DataLoader通常会结合ImageFolder和transforms类(即构建Dataset过程)一起使用,这两个类已经在此前文章中专题说明过:


1. DataLoader的功能

根据DataLoader的官方文档说明,将从以下5个方面说明DataLoader的功能:


1.1 可处理映射式/可迭代式数据集

PyTorch 的 DataLoader 能够处理两种形式的数据集:映射式数据集(map-style)和可迭代式数据集(iterable-style)。映射式数据集指的是那些可以通过索引直接访问其元素的数据集,它们需要实现 __getitem__ 方法和可选的 __len__ 方法。例如,torch.utils.data.Dataset 子类就是这样一种数据集,可以通过 dataset[i] 获取第 i 个样本。


可迭代式数据集则不依赖于索引访问,而是可以直接迭代出数据,这类似于 Python 中的迭代器协议。DataLoader 能够兼容这两种风格的数据集,并将它们转化为合适的形式以供模型训练使用。


1.2 可自定义数据加载顺序

DataLoader 支持通过 sampler 参数来自定义数据加载的顺序。默认情况下,如果设置了 shuffle=True,那么在每个 epoch 开始时,数据加载器会打乱数据集的顺序。但也可以传入自定义的采样器,如 RandomSampler 实现随机抽样,或 SequentialSampler 保持原有顺序,甚至可以使用 WeightedRandomSampler 来实现加权随机抽样等复杂逻辑。


1.3 可自动批量化打包数据

DataLoader 自动将数据集中的样本打包成小批量,这是通过设置 batch_size 参数来实现的。每次调用 DataLoader 的迭代器时,都会返回一个包含 batch_size 个样本的数据批次,这对于训练深度学习模型是非常关键的,因为大多数模型都需要按照批次进行前向传播和反向传播计算。


1.4 可支持多进程加载

为了加速数据加载过程,特别是对于大型数据集,DataLoader 提供了多进程支持。通过设置 num_workers 参数,可以启动多个工作进程来并发地加载数据。这意味着数据准备和模型训练可以同时进行,极大地提高了整体效率。需要注意的是,启用多进程数据加载时需要考虑数据集是否线程安全,并确保在CPU或系统资源充足的环境下运行。


注意:num_workers 参数并不一定是越大越好。以下是考虑 num_workers 设置时需要权衡的因素:


  1. 系统资源


  • CPU核心数num_workers 应设置得小于或等于可用 CPU 核心数(包括超线程)。若设置得过高,可能会导致过多的上下文切换,反而降低性能。
  • 内存限制:增加 num_workers 可能会增加内存消耗,因为每个工作进程都会缓存一部分数据。过高的 num_workers 可能会导致内存溢出。
  1. I/O 瓶颈


  • 如果数据读取主要是由硬盘 I/O 速率决定的瓶颈,那么超过某个点后增加 num_workers 不会带来进一步的速度提升,反而可能由于争抢 I/O 资源而造成负面影响。
  1. GPU 同步


  • 当数据加载速度远大于 GPU 计算速度时,更多的 num_workers 可能不会显著提高训练效率,因为 GPU 处理速度成为瓶颈。
  1. 同步与异步行为


  • PyTorch 的 DataLoader 默认实现了一个队列系统来进行数据加载的同步操作。过多的 num_workers 可能会导致队列中积累过多的数据,这些数据在被 GPU 使用前需要等待,因此并不会提高整体吞吐量。

通常的经验值可能是把num_workers设定在 CPU 核心数的一半到全部之间。不过最佳实践是要根据具体的硬件配置、数据集大小和读取速度以及模型训练速度等因素进行调整。在某些情况下,如遇到操作系统兼容性问题或者为了避免不必要的复杂性,也可能需要将 num_workers 设置为较小的值甚至 0。在 Windows 系统中,由于进程间通信的限制,有时必须将 num_workers 设置为 0 才能避免错误。


1.5 可pin住内存


如果是在 GPU 上进行训练,DataLoader 可以通过设置 pin_memory=True 来自动将 CPU 内存中的数据拷贝至 CUDA 可以直接访问的内存区域(即“pin”住内存),这样在数据从 CPU 到 GPU 的转移过程中可以享受到更快的速度。这是因为被pin住内存可以利用异步内存复制操作,避免同步等待,从而使得数据流水线更为顺畅。


2. DataLoader的调用


在PyTorch中,DataLoader是基于torch.utils.data接口进行工作的,DataLoader也是torch.utils.data中的核心类。


At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class.

2.1 DataLoader的调用方法


DataLoader的调用方法如下:

from torch.utils.data import DataLoader

dataset = ...
loader = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)
2.2 DataLoader的参数说明
  • dataset: (Dataset) 必须是一个实现了 __getitem____len__ 方法的数据集对象,用于定义如何访问和获取样本及其对应的标签。
  • batch_size: (int) 指定每个批次加载多少个样本。这是训练神经网络时常见的参数,控制着每一步梯度更新所使用的样本数量。
  • shuffle: (bool) 若为 True,则在每个 epoch 开始时都会对数据集进行随机打乱顺序。这对于防止模型过拟合和确保模型看到所有样本组合至关重要,尤其是在训练阶段。
  • sampler: (Sampler) 可选参数,用于指定数据抽样的策略。如果提供了 sampler,则 shuffle 参数将被忽略。例如,可以使用此参数实现分布式训练时的数据分片。
  • batch_sampler: (Sampler) 可直接指定一个批量抽样器,它可以返回一批批索引而不是单个索引。如果指定了 batch_sampler,则 batch_sizeshuffle 将被忽略。
  • num_workers: (int) 上面已经详细说明,不再赘述。
  • collate_fn: (callable) 自定义函数,用于合并多个样本到一个批次。默认的 collate_fn 会堆叠具有相同形状的张量。用户可以自定义此函数以满足特殊的数据整合需求。
  • pin_memory: (bool) 若为 True,则在数据加载后将其移到 CUDA 可以直接访问的页锁定内存中,从而加快数据从 CPU 到 GPU 的传输速度。
  • drop_last: (bool) 若为 True,并且最后一个批次的样本数量小于 batch_size,则丢弃该批次。这在保证所有批次样本数量一致时很有用。
  • timeout: (float) 设置数据加载过程中阻塞的最大秒数。如果设置为0,则无限期等待。
  • worker_init_fn: (callable) 用户自定义的回调函数,用于初始化每个工作进程。可以在每个工作进程中设置不同的随机种子等。
  • prefetch_factor: (int) 预取因子,决定了工作进程在向主进程提交批次的同时,提前生成多少个额外的批次。增加此值可以减少潜在的 I/O 瓶颈,但也可能增加内存占用。
  • persistent_workers: (bool) 若为 True,则保留工作进程在多个数据加载迭代之间,这样可以避免每次重新启动工作进程带来的开销,尤其在长时间运行的任务中效果更明显。然而,这要求 num_workers > 0 并且 multiprocessing.get_start_method() 返回 ‘fork’ 或 ‘spawn’。

3. DataLoader的使用实例

创建了一个 ImageFolder 数据集并应用了 transforms.ToTensor() 转换之后,你已经正确地设置了数据预处理流程,确保了 DataLoader 在批处理时接收到的是 Tensor 类型的数据。接下来要查看 DataLoader 中的特定元素,你可以通过迭代的方式来访问它们:

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([transforms.Resize((200,200)), transforms.ToTensor()])

# 加载数据集
dataset = ImageFolder('D:\\DL\\pretrain\\hymenoptera\\hymenoptera_data', transform=transform)

# 创建 DataLoader
loader = DataLoader(dataset,batch_size=4)

# 通过迭代的方式访问 DataLoader 中的元素
for i, (images, labels) in enumerate(loader):
    if i == 0:  # 仅显示第一个批次的数据
        print(f"第{i}个批次的图像张量:")
        print(images.shape)  # 显示图像张量的形状
        print("对应的标签:", labels)
    break


上述代码会打印出 DataLoader 中的第一个批次(索引为0)的图像张量及其对应的类别标签。输出为:


第0个批次的图像张量:
torch.Size([4, 3, 200, 200])
对应的标签: tensor([0, 0, 0, 0])

如果你想查看特定索引位置的样本(即上文所述的映射式map-style),但不是按批次而是直接查看原始数据集中的某个样本,那么可以直接从 ImageFolder 数据集中索引该样本,而非通过 DataLoader

# 获取数据集中的特定样本(假设索引为10)
sample_idx = 10
image, label = dataset[sample_idx]
print(f"索引 {sample_idx} 的图像张量:")
print(image.shape)
print("对应的标签:", label)

输出为:

索引 10 的图像张量:
torch.Size([3, 200, 200])
对应的标签: 0


再次强调,在实际应用中,由于 DataLoader 主要是用来进行批处理的,所以直接从其中索引单个元素并不常见。如需查看单个样本,通常是从原始 dataset 中访问。如果确实需要从 DataLoader 中一次性获取单个样本而不是一批次,需要特殊处理,例如通过 next(iter(loader)) 或者额外编写逻辑来实现。


4. 总结

最后我想再总结下 DataLoaderDataset的关系:


  • DataLoader 依赖于 Dataset 来获取原始数据,它的目的是为了更好地管理和高效地喂入数据给训练过程。
  • 使用时,首先需要基于 Dataset 构建好数据集实例,然后将这个数据集实例传给 DataLoader 构造函数,配置好加载参数后得到一个数据加载器。


总的来说,Dataset 负责定义数据源和访问逻辑,而 DataLoader 负责根据这些定义好的逻辑,按需以适合训练的形式加载和提供数据。


相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
8月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
261 0
|
8月前
|
PyTorch 算法框架/工具
Pytorch中最大池化层Maxpool的作用说明及实例使用(附代码)
Pytorch中最大池化层Maxpool的作用说明及实例使用(附代码)
782 0
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
712 0
|
8月前
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
1179 0
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch通过实例详细剖析CNN
基于Pytorch通过实例详细剖析CNN
88 1
基于Pytorch通过实例详细剖析CNN
|
8月前
|
存储 数据可视化 PyTorch
PyTorch中 Datasets & DataLoader 的介绍
PyTorch中 Datasets & DataLoader 的介绍
161 0
|
8月前
|
机器学习/深度学习 算法 大数据
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
120 3
|
8月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch用GAN生成手写数字实例(附代码)
基于Pytorch用GAN生成手写数字实例(附代码)
208 0
|
8月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch的机器学习Regression问题实例(附源码)
基于Pytorch的机器学习Regression问题实例(附源码)
100 1
|
8月前
|
机器学习/深度学习 自然语言处理 算法
PyTorch实例:简单线性回归的训练和反向传播解析
PyTorch实例:简单线性回归的训练和反向传播解析
PyTorch实例:简单线性回归的训练和反向传播解析