PyTorch 小课堂开课啦!带你解析数据处理全流程(一)

简介: OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。

640.png

一张图带你看懂全文


最近被迫开始了居家办公,这不,每天认真工(mo)作(yu)之余,也有了更多时间重新学习分析起了 PyTorch 源码分享,属于是直接站在巨人的肩膀上了。在简单捋一捋思路之后,就从 torch.utils.data 数据处理模块开始,一步步重新学习 PyTorch 的一些源码模块解析,希望也能让大家重新认识已经不陌生的 PyTorch 这个小伙伴。

640.gif

1. 迭代器介绍



OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。


· __len__(self):定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数。


· __getitem__(self):定义获取容器中指定元素时的行为,相当于 self[key] ,即允许类对象拥有索引操作。


· __iter__(self):定义当迭代容器中的元素时的行为。


除此之外,我们也需要清楚两个概念:


· 迭代(Iteration):当我们用一个循环(比如 for 循环)来遍历容器(比如列表,元组)中的元素时,这种遍历的过程可称为迭代。


· 可迭代对象(Iterable):一般指含有 __iter__() 方法或 __getitem__() 方法的对象。我们通常接触的数据结构,如序列(列表、元组和字符串)还有字典等,都支持迭代操作,也可称为可迭代对象。


那什么是迭代器(Iterator)呢?简而言之,迭代器就是一种可以被遍历的容器类对象,但它又比较特别,它需要遵循迭代器协议,那什么又是迭代器协议呢?迭代器协议(iterator protocol)是指要实现对象的__iter()____next__() 方法。一个容器或者类如果是迭代器,那么就必须实现 __iter__() 方法以及重点实现 __next__() 方法,前者会返回一个迭代器(通常是迭代器对象本身),而后者决定了迭代的规则。现在,为更好地理解迭代器的内部运行机制,我们可以看一个斐波那契数列的迭代器实现例子:

class Fibs:
    def __init__(self, n=20):
        self.a = 0
        self.b = 1
        self.n = n
    def __iter__(self):
        return self
    def __next__(self):
        self.a, self.b = self.b, self.a + self.b
        if self.a > self.n:
            raise StopIteration
        return self.a
fibs = Fibs()
for each in fibs:
    print(each)
# 输出 
# 1 1 2 3 5 8 13

一般而言,迭代器满足以下几种特性:


· 迭代器是⼀个对象,但比较特别,需要满足迭代器协议,他还可以被 for 语句循环迭代直到终⽌。


· 迭代器可以被 next() 函数调⽤,并返回⼀个值,亦可以被 iter() 函数调⽤,但返回的是一个迭代器(可以是自身)。


· 迭代器连续被 next() 函数调⽤时,依次返回⼀系列的值,但如果到了迭代的末尾,则抛出 StopIteration 异常,另外他可以没有末尾,但只要被 next() 函数调⽤,就⼀定会返回⼀个值。


· Python3 中, next() 内置函数调⽤的是对象的 __next__() ⽅法,iter() 内置函数调⽤的是对象的 __iter__() ⽅法。


那么,了解了什么是迭代器后,我们马上开始解析 torch.utils.data 模块,对于 torch.utils.data 而言,重点是其 Dataset,Sampler,DataLoader 三个模块,辅以 collate,fetch,pin_memory 等组件对特定功能予以支持。


Tips:涉及的源码皆以 PyTorch 1.7 为准。


2. Dataset



Dataset 主要负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。Dataset 中共有 Map-style datasets 和 Iterable-style datasets 两种:


1.1 Map-style dataset


torch.utils.data.Dataset 它是一种通过实现  __len__()__getitem__()方法来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。因而,在我们访问 Map-style 的数据集时,使用 dataset[idx] 即可访问 idx 对应的数据。通常,我们使用 Map-style 类型的 dataset 居多,可以看到其数据接口定义如下:

class Dataset(Generic[T_co]):
    # Generic is an Abstract base class for generic types.
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

在 PyTorch 1.7 源码中所有定义的 Dataset 都是其子类,而对于一般计算机视觉任务,我们通常也会在其中进行一些 resize,crop,flip 等预处理的操作。


值得一提的是,PyTorch 源码中并没有提供默认的 __len__() 方法实现,原因是 return NotImplemented 或者 raise NotImplementedError() 之类的默认实现都会存在各自的问题,这点我们在源码 pytorch/torch/utils/data/sampler.py 中的注释也可以得到解释。


1.2 Iterable-style dataset


torch.utils.data.IterableDataset 它是一种实现 __iter__() 来获取数据的 Dataset,Iterable-style 的数据集特别适用于以下情况:随机读取代价很大甚至不可能,且 batch size 取决于获取到的数据。其接口定义如下:

class IterableDataset(Dataset[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])
    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


特别地
,当 DataLoader 的 num_workers > 0 时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)。

1.3 其他 Dataset


除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类:


· torch.utils.data.ConcatDataset:用于连接多个 ConcatDataset 数据集。


· torch.utils.data.ChainDataset:用于连接多个 IterableDataset 数据集,在 IterableDataset 的 __add__() 方法中被调用。


· torch.utils.data.Subset:用于获取指定一个索引序列对应的子数据集。

class Subset(Dataset[T_co]):
    dataset: Dataset[T_co]
    indices: Sequence[int]
    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices
    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]
    def __len__(self):
        return len(self.indices)

· torch.utils.data.TensorDataset:用于获取封装成 tensor 的数据集,每一个样本都可通过索引张量来获得。

class TensorDataset(Dataset):
    def __init__(self, *tensor):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in tensors
    def __len__(self):
        return self.tensors[0].size(0)


3. Sampler



torch.utils.data.Sampler 主要负责提供一种遍历数据集所有元素索引的方式。可支持我们自定义,也可以使用 PyTorch 本身提供的,其基类接口定义如下:

lass Sampler(Generic[T_co]):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

特别地__len()__ 方法虽不是必要的,但是当 DataLoader 需要计算 length 的时候必须定义,这点在源码中也有注释加以体现。


同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类:


· torch.utils.data.SequentialSampler:顺序采样样本,始终按照同一个顺序。


· torch.utils.data.RandomSampler:可指定有无放回地,进行随机采样样本元素。


· torch.utils.data.SubsetRandomSampler:无放回地按照给定的索引列表采样样本元素。


· torch.utils.data.WeightedRandomSampler:按照给定的概率来采样样本。样本元素来自 [0,…,len(weights)-1] ,给定概率(权重)。


· torch.utils.data.BatchSampler:在一个 batch 中封装一个其他的采样器, 返回一个 batch 大小的 index 索引。


· torch.utils.data.DistributedSample:将数据加载限制为数据集子集的采样器。与 torch.nn.parallel.DistributedDataParallel 结合使用。在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递。


4. DataLoader


torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以通过参数设置如 sampler, batch size, pin memory 等自定义数据加载顺序以及控制数据批处理功能。其接口定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

对于每个参数的含义,下面通过一个表格进行直观地介绍:

640.png

从参数定义中,我们可以看到 DataLoader 主要支持以下几个功能:


· 支持加载 map-style 和 iterable-style 的 dataset,主要涉及到的参数是 dataset。


· 自定义数据加载顺序,主要涉及到的参数有 shuffle,sampler,batch_sampler,collate_fn。


· 自动把数据整理成batch序列,主要涉及到的参数有 batch_size,batch_sampler,collate_fn,drop_last。


· 单进程和多进程的数据加载,主要涉及到的参数有 num_workers,worker_init_fn。


· 自动进行锁页内存读取 (memory pinning),主要涉及到的参数 pin_memory。


· 支持数据预加载,主要涉及的参数 prefetch_factor。


3.1 批处理

3.1.1 自动批处理(默认)


DataLoader 支持通过参数 batch_size, drop_last, batch_sampler,自动地把取出的数据整理(collate)成批次样本(batch),其中 batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample 参数,一次就生成一个 keys list。


在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象整个过程,其表示方式大致如下:

# For Map-style
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])


3.1.2 关闭自动批处理


当我们想用 dataset 代码手动处理 batch,或仅加载单个 sample data 时,可将 batch_size 和 batch_sampler 设为 None, 将关闭自动批处理。此时,由 Dataset 产生的 sample 将会直接被 collate_fn 处理。抽象整个过程,其表示方式大致如下:

# For Map-style
for index in sampler:
    yield collate_fn(dataset[index])
# For Iterable-style
for data in iter(dataset):
    yield collate_fn(data)


3.1.3 collate_fn


当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。


而当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情:


· 添加新的批次维度(一般是第一维)。


· 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。


· 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或 list,当数据类型不能转换的时候)。这在 list,tuples,namedtuples 同样适用。


自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。


5. 三者关系



通过以上解析的三者工作内容,不难可以推出其内在关系:


1)设置 Dataset,将数据 data source 包装成 Dataset 类,暴露出提取接口。


2)设置 Sampler,决定采样方式。我们虽然能从 Dataset 中提取元素了,但还是需要设置 Sampler 告诉程序提取 Dataset 的策略。


3)将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle,batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。


至此我们就可以了解到了 Dataset,Sampler,Dataloader 三个类的基本定义以及对应实现功能,同时也介绍了批处理对应参数组件。总结来说,我们需要记得的是三点,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。


文章来源:【OpenMMLab

2022-04-08 18:05

目录
相关文章
|
2天前
|
算法 数据处理 开发者
FFmpeg库的使用与深度解析:解码音频流流程
FFmpeg库的使用与深度解析:解码音频流流程
43 0
|
2天前
|
机器学习/深度学习 算法 PyTorch
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
357 1
|
2天前
|
机器学习/深度学习 存储 PyTorch
Pytorch中in-place操作相关错误解析及detach()方法说明
Pytorch中in-place操作相关错误解析及detach()方法说明
116 0
|
2天前
|
消息中间件 Unix Linux
Linux进程间通信(IPC)介绍:详细解析IPC的执行流程、状态和通信机制
Linux进程间通信(IPC)介绍:详细解析IPC的执行流程、状态和通信机制
90 1
|
2天前
|
数据采集 数据可视化 大数据
Python在数据科学中的实际应用:从数据清洗到可视化的全流程解析
Python在数据科学中的实际应用:从数据清洗到可视化的全流程解析
49 1
|
2天前
|
数据采集 机器学习/深度学习 数据可视化
数据科学项目实战:完整的Python数据分析流程案例解析
【4月更文挑战第12天】本文以Python为例,展示了数据分析的完整流程:从CSV文件加载数据,执行预处理(处理缺失值和异常值),进行数据探索(可视化和统计分析),选择并训练线性回归模型,评估模型性能,以及结果解释与可视化。每个步骤都包含相关代码示例,强调了数据科学项目中理论与实践的结合。
|
2天前
|
PyTorch 数据处理 算法框架/工具
pytorch 数据处理备忘
pytorch 数据处理备忘
10 1
|
2天前
|
算法 Linux 调度
xenomai内核解析--xenomai与普通linux进程之间通讯XDDP(一)--实时端socket创建流程
xenomai与普通linux进程之间通讯XDDP(一)--实时端socket创建流程
13 1
xenomai内核解析--xenomai与普通linux进程之间通讯XDDP(一)--实时端socket创建流程
|
2天前
|
Linux 调度 数据库
|
2天前
|
Linux API 调度
xenomai内核解析-xenomai实时线程创建流程
本文介绍了linux硬实时操作系统xenomai pthread_creta()接口的底层实现原理,解释了如何在双内核间创建和调度一个xenomai任务。本文是基于源代码的分析,提供了详细的流程和注释,同时给出了结论部分,方便读者快速了解核心内容。
20 0
xenomai内核解析-xenomai实时线程创建流程