深度学习实践篇 第六章:webdataset

简介: 简要介绍webdataset的使用。

参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/


背景

训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以'xxx.png'的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的getitem(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。

WebDataset

WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。

webdataset的生成

webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。

tar --sort=name -cf dataset.tar dataset/

我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到'mnist.tar'文件中,因此它按照顺序将数据一个一个多写入了文件里。

dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):
    if index%1000==0:
        print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态
    sink.write({
   
   
        "__key__": "sample%06d" % index, # 当前的数据的index
        "input.pyd": input, # 数据的input
        "output.pyd": output, # 数据的target
    })
sink.close() # 关闭当前文件。

这里的sink_write写入了是一个dict,其中'key'这一项决定了你想保存的数据的前缀名,’input.pyd'是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的'pyd',就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm','png','jpg'等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用'cls'格式。

webdataset的加载

对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。

import webdataset as wds
dataset = wds.WebDataset(url)

我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
image.png

直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("rgb")
    .to_tuple("jpg;png", "json")
)

这里的decode传入的'rgb'属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py

decoders = {
   
   
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pth": lambda data: torch_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}

如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于imageio.imread(io.BytesIO(value))处理数据,将它转为图片。

def my_decoder(key, value):
        if not key.endswith(".jpg"):
            return None
        assert isinstance(value, bytes)
        return value

dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)

示例代码

最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。

import multiprocessing as mp
import webdataset as wds
import pickle
import os

def write_samples(dataset, tar_index, sample_index,save_dir):
    for t_idx, s_idx in zip(tar_index, sample_index):
        fname = os.path.join(save_dir,str(t_idx)+'.tar')
        stream = wds.TarWriter(fname)
        for idx in s_idx:
            data = dataset[idx]
            sample = {
   
   }
            sample['__key__'] = "sample%06d" % idx
            for key, value in data.items():
                sample[key +'.pyd'] = value
            stream.write(sample)
        stream.close()

def dataset2tar(dataset, save_dir,num_tars, num_workers):
    num_len = len(dataset)
    data_index = [i for i in range(num_len)]
    samples = [data_index[i::num_tars] for i in range(num_tars)]
    tar_index = list(range(num_tars))
    jobs = []
    for i in range(num_workers):
        job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))
        job.start()
        jobs.append(job)

    for job in jobs:
        job.join()

def pyd_decoder(key, data):
    if not key.endswith(".pyd"):
        return None
    result = pickle.loads(data)
    return result
相关文章
|
3天前
|
机器学习/深度学习 算法
揭秘深度学习中的对抗性网络:理论与实践
【5月更文挑战第18天】 在深度学习领域的众多突破中,对抗性网络(GANs)以其独特的机制和强大的生成能力受到广泛关注。不同于传统的监督学习方法,GANs通过同时训练生成器与判别器两个模型,实现了无监督学习下的高效数据生成。本文将深入探讨对抗性网络的核心原理,解析其数学模型,并通过案例分析展示GANs在图像合成、风格迁移及增强学习等领域的应用。此外,我们还将讨论当前GANs面临的挑战以及未来的发展方向,为读者提供一个全面而深入的视角以理解这一颠覆性技术。
|
4天前
|
机器学习/深度学习 人工智能 算法
【AI】从零构建深度学习框架实践
【5月更文挑战第16天】 本文介绍了从零构建一个轻量级的深度学习框架tinynn,旨在帮助读者理解深度学习的基本组件和框架设计。构建过程包括设计框架架构、实现基本功能、模型定义、反向传播算法、训练和推理过程以及性能优化。文章详细阐述了网络层、张量、损失函数、优化器等组件的抽象和实现,并给出了一个基于MNIST数据集的分类示例,与TensorFlow进行了简单对比。tinynn的源代码可在GitHub上找到,目前支持多种层、损失函数和优化器,适用于学习和实验新算法。
59 2
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
深度理解深度学习:从理论到实践的探索
【5月更文挑战第3天】 在人工智能的浪潮中,深度学习以其卓越的性能和广泛的应用成为了研究的热点。本文将深入探讨深度学习的核心理论,解析其背后的数学原理,并通过实际案例分析如何将这些理论应用于解决现实世界的问题。我们将从神经网络的基础结构出发,逐步过渡到复杂的模型架构,同时讨论优化算法和正则化技巧。通过本文,读者将对深度学习有一个全面而深刻的认识,并能够在实践中更加得心应手地应用这些技术。
|
6天前
|
机器学习/深度学习 人工智能 缓存
安卓应用性能优化实践探索深度学习在图像识别中的应用进展
【4月更文挑战第30天】随着智能手机的普及,移动应用已成为用户日常生活的重要组成部分。对于安卓开发者而言,确保应用流畅、高效地运行在多样化的硬件上是一大挑战。本文将探讨针对安卓平台进行应用性能优化的策略和技巧,包括内存管理、多线程处理、UI渲染效率提升以及电池使用优化,旨在帮助开发者构建更加健壮、响应迅速的安卓应用。 【4月更文挑战第30天】 随着人工智能技术的迅猛发展,深度学习已成为推动计算机视觉领域革新的核心动力。本篇文章将深入分析深度学习技术在图像识别任务中的最新应用进展,并探讨其面临的挑战与未来发展趋势。通过梳理卷积神经网络(CNN)的优化策略、转移学习的实践应用以及增强学习与生成对
|
6天前
|
机器学习/深度学习 搜索推荐 算法
推荐系统算法的研究与实践:协同过滤、基于内容的推荐和深度学习推荐模型
推荐系统算法的研究与实践:协同过滤、基于内容的推荐和深度学习推荐模型
277 1
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
从零开始学习深度学习:入门指南与实践建议
本文将引导读者进入深度学习领域的大门,从基础概念到实际应用,为初学者提供全面的学习指南和实践建议。通过系统化的学习路径规划和案例实践,帮助读者快速掌握深度学习的核心知识和技能,迈出在人工智能领域的第一步。
|
6天前
|
机器学习/深度学习 Python
有没有一些开源的深度学习项目可以帮助我实践所学的知识?
【2月更文挑战第14天】【2月更文挑战第40篇】有没有一些开源的深度学习项目可以帮助我实践所学的知识?
|
6天前
|
机器学习/深度学习 人工智能 算法
【深度学习】因果推断与机器学习的高级实践 | 数学建模
【深度学习】因果推断与机器学习的高级实践 | 数学建模
|
9月前
|
机器学习/深度学习 PyTorch 算法框架/工具
深度学习实践篇 第五章:模型保存与加载
简要介绍pytorch中模型的保存与加载。
110 0
|
6天前
|
机器学习/深度学习 人工智能 算法
基于AidLux的工业视觉少样本缺陷检测实战应用---深度学习分割模型UNET的实践部署
  工业视觉在生产和制造中扮演着关键角色,而缺陷检测则是确保产品质量和生产效率的重要环节。工业视觉的前景与发展在于其在生产制造领域的关键作用,尤其是在少样本缺陷检测方面,借助AidLux技术和深度学习分割模型UNET的实践应用,深度学习分割模型UNET的实践部署变得至关重要。
72 1