从零搭建Pytorch模型教程(一)数据读取 ​

本文涉及的产品
云原生大数据计算服务MaxCompute,500CU*H 100GB 3个月
简介: 本文介绍了classdataset的几个要点,由哪些部分组成,每个部分需要完成哪些事情,如何进行数据增强,如何实现自己设计的数据增强。然后,介绍了分布式训练的数据加载方式,数据读取的整个流程,当面对超大数据集时,内存不足的改进思路。欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、最新论文解读、各种技术教程、CV招聘信息发布等。关注公众号可邀请加入免费版的知识星球和技术交流群。

(零) 概述


浮躁是人性的一个典型的弱点,很多人总擅长看别人分享的现成代码解读的文章,看起来学会了好多东西,实际上仍然不具备自己从零搭建一个pipeline的能力。


在公众号(CV技术指南)的交流群里(群内交流氛围不错,有需要的请关注公众号加群),常有不少人问到一些问题,根据这些问题明显能看出是对pipeline不了解,却已经在搞项目或论文了,很难想象如果基本的pipeline都不懂,如何分析代码问题所在?如何分析结果不正常的可能原因?遇到问题如何改?


Pytorch在这几年逐渐成为了学术上的主流框架,其具有简单易懂的特点。网上有很多pytorch的教程,如果是一个已经懂的人去看这些教程,确实pipeline的要素都写到了,感觉这教程挺不错的。但实际上更多地像是写给自己看的一个笔记,记录了pipeline要写哪些东西,却没有介绍要怎么写,为什么这么写,刚入门的小白看的时候容易云里雾里。

鉴于此,本教程尝试对于pytorch搭建一个完整pipeline写一个比较明确且易懂的说明。

本教程将介绍以下内容:


  1. 准备数据,自定义classdataset,分布式训练的数据加载方式,加载超大数据集的改进思路。


  1. 搭建模型与模型初始化。


  1. 编写训练过程,包括加载预训练模型、设置优化器、设置损失函数等。


  1. 可视化并保存训练过程。


  1. 编写推理函数。

(一)数据读取


classdataset的定义

先来看一个完整的classdataset


import torch.utils.data as data
import torchvision.transforms as transforms
class MyDataset(data.Dataset):
   def __init__(self,data_folder):
       self.data_folder = data_folder
       self.filenames = []
       self.labels = []
       per_classes = os.listdir(data_folder)
       for per_class in per_classes:
           per_class_paths = os.path.join(data_folder, per_class)
           label = torch.tensor(int(per_class))
           per_datas = os.listdir(per_class_paths)
           for per_data in per_datas:
               self.filenames.append(os.path.join(per_class_paths, per_data))
               self.labels.append(label)
   def __getitem__(self, index):
       image = Image.open(self.filenames[index])
       label = self.labels[index]
       data = self.proprecess(image)
       return data, label
   def __len__(self):
       return len(self.filenames)
   def proprecess(self,data):
       transform_train_list = [
           transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
           transforms.Pad(self.opt.pad, padding_mode='edge'),
           transforms.RandomCrop((self.opt.h, self.opt.w)),
           transforms.RandomHorizontalFlip(),
           transforms.ToTensor(),
           transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
      ]
       return transforms.Compose(transform_train_list)

classdataset的几个要点:


  1. classdataset类继承torch.utils.data.dataset。


  1. classdataset的作用是将任意格式的数据,通过读取、预处理或数据增强后以tensor的形式输出。其中任意格式的数据可能是以文件夹名作为类别的形式、或以txt文件存储图片地址的形式、或视频、或十几帧图像作为一份样本的形式。而输出则指的是经过处理后的一个batch的tensor格式数据和对应标签。


  1. classdataset主要有三个函数要完成:__init__函数、__getitem__ 函数和__len__函数。

 

__init__函数


init函数主要是完成两个静态变量的赋值。一个是用于存储所有数据路径的变量,变量的每个元素即为一份训练样本,(注:如果一份样本是十几帧图像,则变量每个元素存储的是这十几帧图像的路径),可以命名为self.filenames。一个是用于存储与数据路径变量一一对应的标签变量,可以命名为self.labels。


假如数据集的格式如下:


#这里的0,1指的是类别0,1
/data_path/0/image0.jpg
/data_path/0/image1.jpg
/data_path/0/image2.jpg
/data_path/0/image3.jpg
......
/data_path/1/image0.jpg
/data_path/1/image1.jpg
/data_path/1/image2.jpg
/data_path/1/image3.jpg

可通过per_classes = os.listdir(data_path) 获得所有类别的文件夹,在此处per_classes的每个元素即为对应的数据标签,通过for遍历per_classes即可获得每个类的标签,将其转换成int的tensor形式即可。在for下获得每个类下每张图片的路径,通过self.join获得每份样本的路径,通过append添加到self.filenames中。

 

__getitem__ 函数


getitem 函数主要是根据索引返回对应的数据。这个索引是在训练前通过dataloader切片获得的,这里先不管。它的参数默认是index,即每次传回在init函数中获得的所有样本中索引对应的数据和标签。因此,可通过下面两行代码找到对应的数据和标签。

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读

image = Image.open(self.filenames[index]))
label = self.labels[index]

获得数据后,进行数据预处理。数据预处理主要通过 torchvision.transforms 来完成,这里面已经包含了常用的预处理、数据增强方式


上面这里介绍了最常用的几种,主要就是resize,随机裁剪,翻转,归一化等。

最后通过transforms.Compose(transform_train_list)来执行。

 

除了这些已经有的数据增强方式外,在《数据增强方法总结》中还介绍了十几种特殊的数据增强方式,像这种自己设计了一种新的数据增强方式,该如何添加进去呢

下面以随机擦除作为例子。


class RandomErasing(object):
   """ Randomly selects a rectangle region in an image and erases its pixels.
      'Random Erasing Data Augmentation' by Zhong et al.
      See https://arxiv.org/pdf/1708.04896.pdf
  Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
  """
   def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
       self.probability = probability
       self.mean = mean
       self.sl = sl
       self.sh = sh
       self.r1 = r1
   def __call__(self, img):
       if random.uniform(0, 1) > self.probability:
           return img
       for attempt in range(100):
           area = img.size()[1] * img.size()[2]
           target_area = random.uniform(self.sl, self.sh) * area
           aspect_ratio = random.uniform(self.r1, 1 / self.r1)
           h = int(round(math.sqrt(target_area * aspect_ratio)))
           w = int(round(math.sqrt(target_area / aspect_ratio)))
           if w < img.size()[2] and h < img.size()[1]:
               x1 = random.randint(0, img.size()[1] - h)
               y1 = random.randint(0, img.size()[2] - w)
               if img.size()[0] == 3:
                   img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                   img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                   img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
               else:
                   img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
               return img
       return img

如上所示,自己写一个类RandomErasing,继承object,在call函数里完成你的操作。在transform_train_list里添加上RandomErasing的定义即可。


transform_train_list = [
          transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
          transforms.Pad(self.opt.pad, padding_mode='edge'),
          transforms.RandomCrop((self.opt.h, self.opt.w)),
          transforms.RandomHorizontalFlip(),
          transforms.ToTensor(),
          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
          RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0])
          #添加到这里
      ]

__len__函数


len函数主要就是返回数据长度,即样本的总数量。前面介绍了self.filenames的每个元素即为每份样本的路径,因此,self.filename的长度就是样本的数量。通过return len(self.filenames)即可返回数据长度。

 

验证classdataset

train_dataset = My_Dataset(data_folder=data_folder)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
print('there are total %s batches for train' % (len(train_loader)))
for i,(data,label) in enumerate(train_loader):
    print(data.size(),label.size())

分布式训练的数据加载方式


前面介绍的是单卡的数据加载,实际上分布式也是这样,但为了高速高效读取,每张卡上也会保存所有数据的信息,即self.filenames和self.labels的信息。只是在DistributedSampler 中会给每张卡分配互不交叉的索引,然后由torch.utils.data.DataLoader来加载。

dataset = My_Dataset(data_folder=data_folder)
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)

数据读取的完整流程


结合上面这段代码,在这里,我们介绍以下读取数据的整个流程。


  1. 首先定义一个classdataset,在初始化函数里获得所有数据的信息。


  1. classdataset中实现getitem函数,通过索引来获取对应的数据,然后对数据进行预处理和数据增强。


  1. 在模型训练前,初始化classdataset,通过Dataloader来加载数据,其加载方式是通过Dataloader中分配的索引,调用getitem函数来获取。


关于索引的分配,在单卡上,可通过设置shuffle=True来随机生成索引顺序;在多机多卡的分布式训练上,shuffle操作通过DistributedSampler来完成,因此shuffle与sampler只能有一个,另一个必须为None。

 

超大数据集的加载思路


问题所在


再回顾一下上面这个流程,前面提到所有数据信息在classdataset初始化部分都会保存在变量中,因此当面对超大数据集时,会出现内存不足的情况。

思路


将切片获取索引的步骤放到classdataset初始化的位置,此时每张卡都是保存不同的数据子集。通过这种方式,可以将内存用量减少到原来的world_size倍(world_size指卡的数量)。


参考代码


class RankDataset(Dataset):
   '''
  实际流程
  获取rank和world_size 信息 -> 获取dataset长度 -> 根据dataset长度产生随机indices ->
  给不同的rank 分配indices -> 根据这些indices产生metas
  '''
   def __init__(self, meta_file, world_size, rank, seed):
       super(RankDataset, self).__init__()
       random.seed(seed)
       np.random.seed(seed)
       self.world_size = world_size
       self.rank = rank
       self.metas = self.parse(meta_file)
   def parse(self, meta_file):
       dataset_size = self.get_dataset_size(meta_file)                                     # 获取metafile的行数
       local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size)   # 根据world size和rank,获取当前epoch,当前rank需要训练的index。
       self.metas = self.read_file(meta_file, local_rank_index)
   def __getitem__(self, idx):
       return self.metas[idx]
   def __len__(self):
       return len(self.metas)
##train
for epoch_num in range(epoch_num):
   dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num)
   sampler = RandomSampler(datset)
   dataloader = DataLoader(
               dataset=dataset,
               batch_size=32,
               shuffle=False,
               num_workers=4,
               sampler=sampler)

但这种思路比较明显的问题时,为了让每张卡上在每个epoch都加载不同的训练子集,因此需要在每个epoch重新build dataloader。


总结


本篇文章介绍了数据读取的完整流程,如何自定义classdataset,如何进行数据增强,自己设计的数据增强如何写,分布式训练是如何加载数据的,超大数据集的数据加载改进思路。

相信读完本文的读者对数据读取有了比较清晰的认识,下一篇将介绍搭建模型与模型初始化。


欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、最新论文解读、各种技术教程、CV招聘信息发布等。关注公众号可邀请加入免费版的知识星球和技术交流群。

相关实践学习
基于MaxCompute的热门话题分析
Apsara Clouder大数据专项技能认证配套课程:基于MaxCompute的热门话题分析
相关文章
|
18天前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
72 1
|
18天前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
55 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
146 9
|
4月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
154 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
3月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1129 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
108 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
|
5月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
5月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
683 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
7月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
470 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
5月前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
496 17

热门文章

最新文章

推荐镜像

更多