深度学习实践篇 第六章:webdataset

简介: 简要介绍webdataset的使用。

参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/


背景

训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以'xxx.png'的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的getitem(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。

WebDataset

WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。

webdataset的生成

webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。

tar --sort=name -cf dataset.tar dataset/

我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到'mnist.tar'文件中,因此它按照顺序将数据一个一个多写入了文件里。

dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):
    if index%1000==0:
        print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态
    sink.write({
   
   
        "__key__": "sample%06d" % index, # 当前的数据的index
        "input.pyd": input, # 数据的input
        "output.pyd": output, # 数据的target
    })
sink.close() # 关闭当前文件。

这里的sink_write写入了是一个dict,其中'key'这一项决定了你想保存的数据的前缀名,’input.pyd'是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的'pyd',就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm','png','jpg'等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用'cls'格式。

webdataset的加载

对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。

import webdataset as wds
dataset = wds.WebDataset(url)

我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
image.png

直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("rgb")
    .to_tuple("jpg;png", "json")
)

这里的decode传入的'rgb'属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py

decoders = {
   
   
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pth": lambda data: torch_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}

如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于imageio.imread(io.BytesIO(value))处理数据,将它转为图片。

def my_decoder(key, value):
        if not key.endswith(".jpg"):
            return None
        assert isinstance(value, bytes)
        return value

dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)

示例代码

最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。

import multiprocessing as mp
import webdataset as wds
import pickle
import os

def write_samples(dataset, tar_index, sample_index,save_dir):
    for t_idx, s_idx in zip(tar_index, sample_index):
        fname = os.path.join(save_dir,str(t_idx)+'.tar')
        stream = wds.TarWriter(fname)
        for idx in s_idx:
            data = dataset[idx]
            sample = {
   
   }
            sample['__key__'] = "sample%06d" % idx
            for key, value in data.items():
                sample[key +'.pyd'] = value
            stream.write(sample)
        stream.close()

def dataset2tar(dataset, save_dir,num_tars, num_workers):
    num_len = len(dataset)
    data_index = [i for i in range(num_len)]
    samples = [data_index[i::num_tars] for i in range(num_tars)]
    tar_index = list(range(num_tars))
    jobs = []
    for i in range(num_workers):
        job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))
        job.start()
        jobs.append(job)

    for job in jobs:
        job.join()

def pyd_decoder(key, data):
    if not key.endswith(".pyd"):
        return None
    result = pickle.loads(data)
    return result
相关文章
|
1月前
|
机器学习/深度学习 传感器 数据采集
深度学习在故障检测中的应用:从理论到实践
深度学习在故障检测中的应用:从理论到实践
147 6
|
2月前
|
机器学习/深度学习 人工智能 TensorFlow
人工智能浪潮下的自我修养:从Python编程入门到深度学习实践
【10月更文挑战第39天】本文旨在为初学者提供一条清晰的道路,从Python基础语法的掌握到深度学习领域的探索。我们将通过简明扼要的语言和实际代码示例,引导读者逐步构建起对人工智能技术的理解和应用能力。文章不仅涵盖Python编程的基础,还将深入探讨深度学习的核心概念、工具和实战技巧,帮助读者在AI的浪潮中找到自己的位置。
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
揭秘人工智能:深度学习的奥秘与实践
在本文中,我们将深入浅出地探索深度学习的神秘面纱。从基础概念到实际应用,你将获得一份简明扼要的指南,助你理解并运用这一前沿技术。我们避开复杂的数学公式和冗长的论述,以直观的方式呈现深度学习的核心原理和应用实例。无论你是技术新手还是有经验的开发者,这篇文章都将为你打开一扇通往人工智能新世界的大门。
|
1月前
|
机器学习/深度学习 算法 TensorFlow
深度学习中的自编码器:从理论到实践
在这篇文章中,我们将深入探讨深度学习的一个重要分支——自编码器。自编码器是一种无监督学习算法,它可以学习数据的有效表示。我们将首先介绍自编码器的基本概念和工作原理,然后通过一个简单的Python代码示例来展示如何实现一个基本的自编码器。最后,我们将讨论自编码器的一些变体,如稀疏自编码器和降噪自编码器,以及它们在实际应用中的优势。
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
揭秘AI:深度学习的奥秘与实践
本文将深入浅出地探讨人工智能中的一个重要分支——深度学习。我们将从基础概念出发,逐步揭示深度学习的原理和工作机制。通过生动的比喻和实际代码示例,本文旨在帮助初学者理解并应用深度学习技术,开启AI之旅。
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
深入浅出深度学习:从理论到实践的探索之旅
在人工智能的璀璨星空中,深度学习如同一颗耀眼的新星,以其强大的数据处理能力引领着技术革新的浪潮。本文将带您走进深度学习的核心概念,揭示其背后的数学原理,并通过实际案例展示如何应用深度学习模型解决现实世界的问题。无论您是初学者还是有一定基础的开发者,这篇文章都将为您提供宝贵的知识和启发。
60 5
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习中的卷积神经网络(CNN): 从理论到实践
本文将深入浅出地介绍卷积神经网络(CNN)的工作原理,并带领读者通过一个简单的图像分类项目,实现从理论到代码的转变。我们将探索CNN如何识别和处理图像数据,并通过实例展示如何训练一个有效的CNN模型。无论你是深度学习领域的新手还是希望扩展你的技术栈,这篇文章都将为你提供宝贵的知识和技能。
383 7
|
2月前
|
机器学习/深度学习 自然语言处理 语音技术
深入探索深度学习中的兼容性函数:从原理到实践
深入探索深度学习中的兼容性函数:从原理到实践
42 3
|
2月前
|
机器学习/深度学习 自然语言处理 网络架构
深度学习中的正则化技术:从理论到实践
在深度学习的海洋中,正则化技术如同灯塔指引着模型训练的方向。本文将深入探讨正则化的核心概念、常见类型及其在防止过拟合中的应用。通过实例分析,我们将展示如何在实践中运用这些技术以提升模型的泛化能力。
|
3月前
|
机器学习/深度学习 调度 计算机视觉
深度学习中的学习率调度:循环学习率、SGDR、1cycle 等方法介绍及实践策略研究
本文探讨了多种学习率调度策略在神经网络训练中的应用,强调了选择合适学习率的重要性。文章介绍了阶梯式衰减、余弦退火、循环学习率等策略,并分析了它们在不同实验设置下的表现。研究表明,循环学习率和SGDR等策略在提高模型性能和加快训练速度方面表现出色,而REX调度则在不同预算条件下表现稳定。这些策略为深度学习实践者提供了实用的指导。
83 2
深度学习中的学习率调度:循环学习率、SGDR、1cycle 等方法介绍及实践策略研究