【菜菜的CV进阶之路-Pytorch基础-数据处理】自定义数据集加载及预处理

简介: 【菜菜的CV进阶之路-Pytorch基础-数据处理】自定义数据集加载及预处理

前提:

本文的记录前提是---有一个完整、已调通的pytorch网络项目,因为暂时比赛要用,完整项目等过一段时间再打包发到github上...

比如:加载的pytorch自带cifar数据集:

1. # train、test图像预处理和增强
2. transform_train = transforms.Compose(
3.     [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
4.      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
5. 
6. transform_test = transforms.Compose(
7.     [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
8. 
9. #加载train、test数据集
10. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
12. 
13. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
14. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

数据预处理torchvision.transforms这一部分主要是进行数据的中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等。如果不懂请查看官方文档:TORCHVISION.TRANSFORMS

构建Dataset子类

那么如果我想将cifar数据集换成我自己的数据集怎么办呢?答案是:如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。例如:

1. from __future__ import print_function
2. import torch.utils.data as data
3. import torch
4. 
5. class MyDataset(data.Dataset):
6.     def __init__(self, images, labels):#这一部分用于读取训练、测试数据
7.         self.images = images
8.         self.labels = labels
9. 
10.     def __getitem__(self, index):#这一部分将读取的数据输出,返回的是tensor格式
11.         img, target = self.images[index], self.labels[index]
12.         return img, target
13. 
14.     def __len__(self):
15.         return len(self.images)
16. 
17. dataset = MyDataset(images, labels)

查看torchvision.datasets.CIFAR10的源码也可以看到cifar10也是继承了Dataset这个类

在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是__len__和__getitem__。__len__返回数据集的大小,__getitem__实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。

在创建DataLoader时会判断__getitem__返回值的数据类型,然后用不同的if/else分支把数据转换成tensor,所以,_getitem_返回值的数据类型可选择范围很多,一种可以选择的数据类型是:图像为numpy.array,标记为int数据类型。

实例:

比如这里我需要读入我自己的32*32大小的图像数据,则代码为:

1. # 图像预处理和增强
2. transform_train = transforms.Compose(
3.     [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
4.      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
5. 
6. transform_test = transforms.Compose(
7.     [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
8. 
9. #加载train、test数据集
10. trainset = MyDataset(imgdir='./Train',imgpath='./Train.txt', transform=transform_train)
11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)
12. 
13. testset = MyDataset(imgdir='./Test',imgpath='./Test.txt', train=False, transform=transform_test)
14. testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)

这里将参数说明一下:MyDataset中imgdir为图片的存放位置,imgpath中包含每张图片的name以及图片的label,其它同cifar

torch.utils.data.DataLoader()函数,合成数据并且提供迭代访问。主要由两部分组成:

1. - dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。 
2. - batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:
3. 
4.     - batch-size。样本每个batch的大小,默认为1。
5.     - shuffle。是否打乱数据,默认为False。
6.     - num_workers。数据分为几个线程处理默认为0。
7.     - sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False

之后就是在MyDataset里读取和输出自定义数据集: 简而言之就是在继承了Dataset类的Mydataset里,__init__函数中读取图像数据,DataLoader再通过__getitem__中获取输出

1. import os
2. import cv2
3. from PIL import Image
4. import numpy as np
5. from torch.utils.data import Dataset
6. 
7. class MyDataset(Dataset):
8.     def __init__(self, imgdir, imgpath,train=True,
9.                  transform=None, target_transform=None):
10.         self.root = os.path.expanduser(imgdir)
11.         self.transform = transform
12.         self.target_transform = target_transform
13.         self.train = train  # training set or test set
14. 
15.         # now load the picked numpy arrays
16.         if self.train:
17.             self.train_data = []#images
18.             self.train_labels = []#labels
19.             #read the images and labels in file Train.txt or Test.txt
20.             with open(imgpath,"r") as imgpath:
21.                 for line in imgpath:
22.                     line=line.split(' ')
23.                     image=Image.open(line[0])
24.                     image = np.array(image)
25.                     self.train_data.append(image)#将读取的图片放入train_data list里
26.                     self.train_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
27.                 imgpath.close()
28.             self.train_data = np.array(self.train_data)
29.             self.train_data = self.train_data.reshape((1000,3, 32, 32))
30.             self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
31.         else:
32.             self.test_data = []#images
33.             self.test_labels = []#labels
34.             #read the images and labels in file Train.txt or Test.txt
35.             with open(imgpath,"r") as imgpath:
36.                 for line in imgpath:
37.                     line = line.split(' ')
38.                     image=Image.open(line[0])
39.                     image = np.array(image)
40.                     self.test_data.append(image)#将读取的图片放入train_data list里
41.                     self.test_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
42.                 imgpath.close()
43.             self.test_data = np.array(self.test_data)
44.             self.test_data = self.test_data.reshape((200, 3,32, 32))
45.             self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
46. 
47.     def __getitem__(self, index):
48.         """
49.         Args:
50.             index (int): Index
51. 
52.         Returns:
53.             tuple: (image, target) where target is index of the target class.
54.         """
55.         if self.train:
56.             img, target = self.train_data[index], self.train_labels[index]
57.         else:
58.             img, target = self.test_data[index], self.test_labels[index]
59. 
60.         # doing this so that it is consistent with all other datasets
61.         # to return a PIL Image
62.         img = Image.fromarray(img)
63. 
64.         if self.transform is not None:
65.             img = self.transform(img)
66. 
67.         if self.target_transform is not None:
68.             target = self.target_transform(target)
69. 
70.         return img,target
71. 
72.     def __len__(self):
73.         if self.train:
74.             return len(self.train_data)
75.         else:
76.             return len(self.test_data)

参考:

https://blog.csdn.net/GYGuo95/article/details/78821520

https://blog.csdn.net/victoriaw/article/details/72356453

https://pytorch.org/docs/stable/torchvision/transforms.html


AIEarth是一个由众多领域内专家博主共同打造的学术平台,旨在建设一个拥抱智慧未来的学术殿堂!【平台地址:https://devpress.csdn.net/aiearth】 很高兴认识你!加入我们共同进步!

目录
相关文章
|
3月前
|
机器学习/深度学习 存储 PyTorch
PyTorch自定义学习率调度器实现指南
本文将详细介绍如何通过扩展PyTorch的 ``` LRScheduler ``` 类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。
128 2
|
6月前
|
机器学习/深度学习 人工智能 PyTorch
|
3月前
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
145 4
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
4月前
|
机器学习/深度学习 PyTorch 数据处理
PyTorch数据处理:torch.utils.data模块的7个核心函数详解
在机器学习和深度学习项目中,数据处理是至关重要的一环。PyTorch作为一个强大的深度学习框架,提供了多种灵活且高效的数据处理工具
36 1
|
6月前
|
数据采集 PyTorch 数据处理
PyTorch的数据处理
PyTorch中,`Dataset`封装自定义数据集,`DataLoader`负责批量加载和多线程读取。例如,定义一个简单的`Dataset`类,包含数据和标签,然后使用`DataLoader`指定批大小和工作线程数。数据预处理包括导入如Excel的数据,图像数据集可通过`torchvision.datasets`加载。示例展示了如何从Excel文件创建`Dataset`,并用`DataLoader`读取。
|
6月前
|
机器学习/深度学习 人工智能 PyTorch
人工智能平台PAI产品使用合集之Alink是否加载预训练好的pytorch模型
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
6月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】18. Pytorch中自定义层的几种方法:nn.Module、ParameterList和ParameterDict
【从零开始学习深度学习】18. Pytorch中自定义层的几种方法:nn.Module、ParameterList和ParameterDict
|
6月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】