PyTorch 小课堂开课啦!带你解析数据处理全流程(一)

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。

640.png

一张图带你看懂全文


最近被迫开始了居家办公,这不,每天认真工(mo)作(yu)之余,也有了更多时间重新学习分析起了 PyTorch 源码分享,属于是直接站在巨人的肩膀上了。在简单捋一捋思路之后,就从 torch.utils.data 数据处理模块开始,一步步重新学习 PyTorch 的一些源码模块解析,希望也能让大家重新认识已经不陌生的 PyTorch 这个小伙伴。

640.gif

1. 迭代器介绍



OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。


· __len__(self):定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数。


· __getitem__(self):定义获取容器中指定元素时的行为,相当于 self[key] ,即允许类对象拥有索引操作。


· __iter__(self):定义当迭代容器中的元素时的行为。


除此之外,我们也需要清楚两个概念:


· 迭代(Iteration):当我们用一个循环(比如 for 循环)来遍历容器(比如列表,元组)中的元素时,这种遍历的过程可称为迭代。


· 可迭代对象(Iterable):一般指含有 __iter__() 方法或 __getitem__() 方法的对象。我们通常接触的数据结构,如序列(列表、元组和字符串)还有字典等,都支持迭代操作,也可称为可迭代对象。


那什么是迭代器(Iterator)呢?简而言之,迭代器就是一种可以被遍历的容器类对象,但它又比较特别,它需要遵循迭代器协议,那什么又是迭代器协议呢?迭代器协议(iterator protocol)是指要实现对象的__iter()____next__() 方法。一个容器或者类如果是迭代器,那么就必须实现 __iter__() 方法以及重点实现 __next__() 方法,前者会返回一个迭代器(通常是迭代器对象本身),而后者决定了迭代的规则。现在,为更好地理解迭代器的内部运行机制,我们可以看一个斐波那契数列的迭代器实现例子:

class Fibs:
    def __init__(self, n=20):
        self.a = 0
        self.b = 1
        self.n = n
    def __iter__(self):
        return self
    def __next__(self):
        self.a, self.b = self.b, self.a + self.b
        if self.a > self.n:
            raise StopIteration
        return self.a
fibs = Fibs()
for each in fibs:
    print(each)
# 输出 
# 1 1 2 3 5 8 13

一般而言,迭代器满足以下几种特性:


· 迭代器是⼀个对象,但比较特别,需要满足迭代器协议,他还可以被 for 语句循环迭代直到终⽌。


· 迭代器可以被 next() 函数调⽤,并返回⼀个值,亦可以被 iter() 函数调⽤,但返回的是一个迭代器(可以是自身)。


· 迭代器连续被 next() 函数调⽤时,依次返回⼀系列的值,但如果到了迭代的末尾,则抛出 StopIteration 异常,另外他可以没有末尾,但只要被 next() 函数调⽤,就⼀定会返回⼀个值。


· Python3 中, next() 内置函数调⽤的是对象的 __next__() ⽅法,iter() 内置函数调⽤的是对象的 __iter__() ⽅法。


那么,了解了什么是迭代器后,我们马上开始解析 torch.utils.data 模块,对于 torch.utils.data 而言,重点是其 Dataset,Sampler,DataLoader 三个模块,辅以 collate,fetch,pin_memory 等组件对特定功能予以支持。


Tips:涉及的源码皆以 PyTorch 1.7 为准。


2. Dataset



Dataset 主要负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。Dataset 中共有 Map-style datasets 和 Iterable-style datasets 两种:


1.1 Map-style dataset


torch.utils.data.Dataset 它是一种通过实现  __len__()__getitem__()方法来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。因而,在我们访问 Map-style 的数据集时,使用 dataset[idx] 即可访问 idx 对应的数据。通常,我们使用 Map-style 类型的 dataset 居多,可以看到其数据接口定义如下:

class Dataset(Generic[T_co]):
    # Generic is an Abstract base class for generic types.
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

在 PyTorch 1.7 源码中所有定义的 Dataset 都是其子类,而对于一般计算机视觉任务,我们通常也会在其中进行一些 resize,crop,flip 等预处理的操作。


值得一提的是,PyTorch 源码中并没有提供默认的 __len__() 方法实现,原因是 return NotImplemented 或者 raise NotImplementedError() 之类的默认实现都会存在各自的问题,这点我们在源码 pytorch/torch/utils/data/sampler.py 中的注释也可以得到解释。


1.2 Iterable-style dataset


torch.utils.data.IterableDataset 它是一种实现 __iter__() 来获取数据的 Dataset,Iterable-style 的数据集特别适用于以下情况:随机读取代价很大甚至不可能,且 batch size 取决于获取到的数据。其接口定义如下:

class IterableDataset(Dataset[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])
    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


特别地
,当 DataLoader 的 num_workers > 0 时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)。

1.3 其他 Dataset


除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类:


· torch.utils.data.ConcatDataset:用于连接多个 ConcatDataset 数据集。


· torch.utils.data.ChainDataset:用于连接多个 IterableDataset 数据集,在 IterableDataset 的 __add__() 方法中被调用。


· torch.utils.data.Subset:用于获取指定一个索引序列对应的子数据集。

class Subset(Dataset[T_co]):
    dataset: Dataset[T_co]
    indices: Sequence[int]
    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices
    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]
    def __len__(self):
        return len(self.indices)

· torch.utils.data.TensorDataset:用于获取封装成 tensor 的数据集,每一个样本都可通过索引张量来获得。

class TensorDataset(Dataset):
    def __init__(self, *tensor):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in tensors
    def __len__(self):
        return self.tensors[0].size(0)


3. Sampler



torch.utils.data.Sampler 主要负责提供一种遍历数据集所有元素索引的方式。可支持我们自定义,也可以使用 PyTorch 本身提供的,其基类接口定义如下:

lass Sampler(Generic[T_co]):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

特别地__len()__ 方法虽不是必要的,但是当 DataLoader 需要计算 length 的时候必须定义,这点在源码中也有注释加以体现。


同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类:


· torch.utils.data.SequentialSampler:顺序采样样本,始终按照同一个顺序。


· torch.utils.data.RandomSampler:可指定有无放回地,进行随机采样样本元素。


· torch.utils.data.SubsetRandomSampler:无放回地按照给定的索引列表采样样本元素。


· torch.utils.data.WeightedRandomSampler:按照给定的概率来采样样本。样本元素来自 [0,…,len(weights)-1] ,给定概率(权重)。


· torch.utils.data.BatchSampler:在一个 batch 中封装一个其他的采样器, 返回一个 batch 大小的 index 索引。


· torch.utils.data.DistributedSample:将数据加载限制为数据集子集的采样器。与 torch.nn.parallel.DistributedDataParallel 结合使用。在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递。


4. DataLoader


torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以通过参数设置如 sampler, batch size, pin memory 等自定义数据加载顺序以及控制数据批处理功能。其接口定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

对于每个参数的含义,下面通过一个表格进行直观地介绍:

640.png

从参数定义中,我们可以看到 DataLoader 主要支持以下几个功能:


· 支持加载 map-style 和 iterable-style 的 dataset,主要涉及到的参数是 dataset。


· 自定义数据加载顺序,主要涉及到的参数有 shuffle,sampler,batch_sampler,collate_fn。


· 自动把数据整理成batch序列,主要涉及到的参数有 batch_size,batch_sampler,collate_fn,drop_last。


· 单进程和多进程的数据加载,主要涉及到的参数有 num_workers,worker_init_fn。


· 自动进行锁页内存读取 (memory pinning),主要涉及到的参数 pin_memory。


· 支持数据预加载,主要涉及的参数 prefetch_factor。


3.1 批处理

3.1.1 自动批处理(默认)


DataLoader 支持通过参数 batch_size, drop_last, batch_sampler,自动地把取出的数据整理(collate)成批次样本(batch),其中 batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample 参数,一次就生成一个 keys list。


在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象整个过程,其表示方式大致如下:

# For Map-style
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])


3.1.2 关闭自动批处理


当我们想用 dataset 代码手动处理 batch,或仅加载单个 sample data 时,可将 batch_size 和 batch_sampler 设为 None, 将关闭自动批处理。此时,由 Dataset 产生的 sample 将会直接被 collate_fn 处理。抽象整个过程,其表示方式大致如下:

# For Map-style
for index in sampler:
    yield collate_fn(dataset[index])
# For Iterable-style
for data in iter(dataset):
    yield collate_fn(data)


3.1.3 collate_fn


当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。


而当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情:


· 添加新的批次维度(一般是第一维)。


· 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。


· 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或 list,当数据类型不能转换的时候)。这在 list,tuples,namedtuples 同样适用。


自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。


5. 三者关系



通过以上解析的三者工作内容,不难可以推出其内在关系:


1)设置 Dataset,将数据 data source 包装成 Dataset 类,暴露出提取接口。


2)设置 Sampler,决定采样方式。我们虽然能从 Dataset 中提取元素了,但还是需要设置 Sampler 告诉程序提取 Dataset 的策略。


3)将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle,batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。


至此我们就可以了解到了 Dataset,Sampler,Dataloader 三个类的基本定义以及对应实现功能,同时也介绍了批处理对应参数组件。总结来说,我们需要记得的是三点,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。


文章来源:【OpenMMLab

2022-04-08 18:05

目录
相关文章
|
13天前
|
监控 安全 开发工具
鸿蒙HarmonyOS应用开发 | HarmonyOS Next-从应用开发到上架全流程解析
HarmonyOS Next是华为推出的最新版本鸿蒙操作系统,强调多设备协同和分布式技术,提供丰富的开发工具和API接口。本文详细解析了从应用开发到上架的全流程,包括环境搭建、应用设计与开发、多设备适配、测试调试、应用上架及推广等环节,并介绍了鸿蒙原生应用开发者激励计划,帮助开发者更好地融入鸿蒙生态。通过DevEco Studio集成开发环境和华为提供的多种支持工具,开发者可以轻松创建并发布高质量的鸿蒙应用,享受技术和市场推广的双重支持。
171 11
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
49 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
16天前
|
域名解析 弹性计算 安全
阿里云服务器租用、注册域名、备案及域名解析完整流程参考(图文教程)
对于很多初次建站的用户来说,选购云服务器和注册应及备案和域名解析步骤必须了解的,目前轻量云服务器2核2G68元一年,2核4G4M服务器298元一年,域名注册方面,阿里云推出域名1元购买活动,新用户注册com和cn域名2年首年仅需0元,xyz和top等域名首年仅需1元。对于建站的用户来说,购买完云服务器并注册好域名之后,下一步还需要操作备案和域名绑定。本文为大家展示阿里云服务器的购买流程,域名注册、绑定以及备案的完整流程,全文以图文教程形式为大家展示具体细节及注意事项,以供新手用户参考。
|
2月前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
57 12
|
29天前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
3月前
|
JavaScript 前端开发 开发者
Vue执行流程及渲染解析
【10月更文挑战第2天】
115 58
|
3月前
|
机器学习/深度学习 算法 PyTorch
Pytorch-RMSprop算法解析
关注B站【肆十二】,观看更多实战教学视频。本期介绍深度学习中的RMSprop优化算法,通过调整每个参数的学习率来优化模型训练。示例代码使用PyTorch实现,详细解析了RMSprop的参数及其作用。适合初学者了解和实践。
87 1
|
3月前
|
JavaScript 前端开发 UED
Vue执行流程及渲染解析
【10月更文挑战第5天】
|
3月前
|
存储 搜索推荐 数据库
运用LangChain赋能企业规章制度制定:深入解析Retrieval-Augmented Generation(RAG)技术如何革新内部管理文件起草流程,实现高效合规与个性化定制的完美结合——实战指南与代码示例全面呈现
【10月更文挑战第3天】构建公司规章制度时,需融合业务实际与管理理论,制定合规且促发展的规则体系。尤其在数字化转型背景下,利用LangChain框架中的RAG技术,可提升规章制定效率与质量。通过Chroma向量数据库存储规章制度文本,并使用OpenAI Embeddings处理文本向量化,将现有文档转换后插入数据库。基于此,构建RAG生成器,根据输入问题检索信息并生成规章制度草案,加快更新速度并确保内容准确,灵活应对法律与业务变化,提高管理效率。此方法结合了先进的人工智能技术,展现了未来规章制度制定的新方向。
51 3
|
3月前
|
存储 缓存 边缘计算
揭秘直播带货背后的黑科技:播放流程全解析!
大家好,我是小米,今天聊聊社区直播带货的技术细节。我们将探讨直播播放流程中的关键技术,包括 HTTP DASH 协议、POP(Point of Presence)缓存和一致性哈希算法等。通过这些技术,直播流能根据网络状况动态调整清晰度,保证流畅体验。POP 和 DC 的多层次缓存设计减少了延迟,提升了观看效果。无论是技术人员还是直播运营者,都能从中受益。希望通过本文,你能更好地理解直播背后的技术原理。
59 3

热门文章

最新文章

推荐镜像

更多