深度学习实践篇 第六章:webdataset

简介: 简要介绍webdataset的使用。

参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/


背景

训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以'xxx.png'的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的getitem(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。

WebDataset

WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。

webdataset的生成

webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。

tar --sort=name -cf dataset.tar dataset/

我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到'mnist.tar'文件中,因此它按照顺序将数据一个一个多写入了文件里。

dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):
    if index%1000==0:
        print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态
    sink.write({
   
   
        "__key__": "sample%06d" % index, # 当前的数据的index
        "input.pyd": input, # 数据的input
        "output.pyd": output, # 数据的target
    })
sink.close() # 关闭当前文件。

这里的sink_write写入了是一个dict,其中'key'这一项决定了你想保存的数据的前缀名,’input.pyd'是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的'pyd',就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm','png','jpg'等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用'cls'格式。

webdataset的加载

对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。

import webdataset as wds
dataset = wds.WebDataset(url)

我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
image.png

直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("rgb")
    .to_tuple("jpg;png", "json")
)

这里的decode传入的'rgb'属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py

decoders = {
   
   
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pth": lambda data: torch_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}

如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于imageio.imread(io.BytesIO(value))处理数据,将它转为图片。

def my_decoder(key, value):
        if not key.endswith(".jpg"):
            return None
        assert isinstance(value, bytes)
        return value

dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)

示例代码

最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。

import multiprocessing as mp
import webdataset as wds
import pickle
import os

def write_samples(dataset, tar_index, sample_index,save_dir):
    for t_idx, s_idx in zip(tar_index, sample_index):
        fname = os.path.join(save_dir,str(t_idx)+'.tar')
        stream = wds.TarWriter(fname)
        for idx in s_idx:
            data = dataset[idx]
            sample = {
   
   }
            sample['__key__'] = "sample%06d" % idx
            for key, value in data.items():
                sample[key +'.pyd'] = value
            stream.write(sample)
        stream.close()

def dataset2tar(dataset, save_dir,num_tars, num_workers):
    num_len = len(dataset)
    data_index = [i for i in range(num_len)]
    samples = [data_index[i::num_tars] for i in range(num_tars)]
    tar_index = list(range(num_tars))
    jobs = []
    for i in range(num_workers):
        job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))
        job.start()
        jobs.append(job)

    for job in jobs:
        job.join()

def pyd_decoder(key, data):
    if not key.endswith(".pyd"):
        return None
    result = pickle.loads(data)
    return result
相关文章
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习中的迁移学习:从理论到实践
科技进步不断推动人工智能的发展,其中深度学习已成为最炙手可热的领域。然而,训练深度学习模型通常需要大量的数据和计算资源,这对于许多实际应用来说是一个显著的障碍。迁移学习作为一种有效的方法,通过利用已有模型在新任务上的再训练,大大减少了数据和计算资源的需求。本文将详细探讨迁移学习的理论基础、各种实现方法以及其在实际应用中的优势和挑战。
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
探索深度学习的奥秘:从理论到实践
【5月更文挑战第31天】本文将深入探讨深度学习的理论基础和实践应用,揭示其在解决复杂问题中的强大能力。我们将从深度学习的基本概念开始,然后讨论其在不同领域的应用,最后分享一些实践经验和技巧。
|
5天前
|
机器学习/深度学习 人工智能 自然语言处理
揭秘深度学习:从理论到实践的技术之旅
【7月更文挑战第10天】本文将深入探索深度学习的奥秘,从其理论基础讲起,穿越关键技术和算法的发展,直至应用案例的实现。我们将一窥深度学习如何变革数据处理、图像识别、自然语言处理等领域,并讨论当前面临的挑战与未来发展趋势。
|
9天前
|
机器学习/深度学习 搜索推荐 算法
深度学习在推荐系统中的应用:技术解析与实践
【7月更文挑战第6天】深度学习在推荐系统中的应用为推荐算法的发展带来了新的机遇和挑战。通过深入理解深度学习的技术原理和应用场景,并结合具体的实践案例,我们可以更好地构建高效、准确的推荐系统,为用户提供更加个性化的推荐服务。
|
1月前
|
机器学习/深度学习 API TensorFlow
Keras深度学习框架入门与实践
**Keras**是Python的高级神经网络API,支持TensorFlow、Theano和CNTK后端。因其用户友好、模块化和可扩展性受到深度学习开发者欢迎。本文概述了Keras的基础,包括**模型构建**(Sequential和Functional API)、**编译与训练**(选择优化器、损失函数和评估指标)以及**评估与预测**。还提供了一个**代码示例**,展示如何使用Keras构建和训练简单的卷积神经网络(CNN)进行MNIST手写数字分类。最后,强调Keras简化了复杂神经网络的构建和训练过程。【6月更文挑战第7天】
25 7
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
探索深度学习:从理论到实践
【6月更文挑战第4天】本文深入探讨了深度学习的理论基础和实践应用,包括其发展历程、主要模型、以及在图像识别、自然语言处理等领域的应用。文章不仅提供了对深度学习的全面理解,还通过实例展示了如何将理论知识转化为实际的技术解决方案。
|
2月前
|
机器学习/深度学习 传感器 自动驾驶
基于深度学习的图像识别技术在自动驾驶系统中的应用构建高效云原生应用:云平台的选择与实践
【5月更文挑战第31天】 随着人工智能技术的飞速发展,深度学习已经成为推动计算机视觉进步的关键力量。特别是在图像识别领域,通过模仿人脑处理信息的方式,深度学习模型能够从大量数据中学习并识别复杂的图像模式。本文将探讨深度学习技术在自动驾驶系统中图像识别方面的应用,重点分析卷积神经网络(CNN)的结构与优化策略,以及如何通过这些技术提高自动驾驶车辆的环境感知能力。此外,文章还将讨论目前所面临的挑战和未来的研究方向。
|
2月前
|
机器学习/深度学习 算法 大数据
基于深度学习的图像识别技术:原理与实践
基于深度学习的图像识别技术:原理与实践
46 4
|
2月前
|
机器学习/深度学习 传感器 自动驾驶
基于深度学习的图像识别技术在自动驾驶系统中的应用深入理解操作系统内存管理:原理与实践
【5月更文挑战第28天】 随着人工智能技术的飞速发展,图像识别作为其重要分支之一,在多个领域展现出了广泛的应用潜力。尤其是在自动驾驶系统中,基于深度学习的图像识别技术已成为实现车辆环境感知和决策的关键。本文将深入探讨深度学习算法在自动驾驶图像识别中的作用,分析其面临的挑战以及未来的发展趋势,并以此为基础,展望该技术对自动驾驶安全性和效率的影响。
|
2月前
|
机器学习/深度学习 数据采集 算法
利用深度学习优化图像识别准确性的策略与实践
【5月更文挑战第26天】 在计算机视觉领域,图像识别的准确性直接影响着算法的实用性和效率。本文针对当前深度学习在图像识别中的应用进行探讨,提出了一系列优化策略,旨在提升模型的识别精度。文中首先概述了深度学习在图像识别中的基础框架,随后深入分析了数据预处理、网络结构设计、损失函数定制以及训练技巧等方面的优化方法。通过实验验证,这些策略能显著提高模型在复杂环境下的表现能力。