【菜菜的CV进阶之路-Pytorch基础-数据处理】自定义数据集加载及预处理

简介: 【菜菜的CV进阶之路-Pytorch基础-数据处理】自定义数据集加载及预处理

前提:

本文的记录前提是---有一个完整、已调通的pytorch网络项目,因为暂时比赛要用,完整项目等过一段时间再打包发到github上...

比如:加载的pytorch自带cifar数据集:

1. # train、test图像预处理和增强
2. transform_train = transforms.Compose(
3.     [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
4.      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
5. 
6. transform_test = transforms.Compose(
7.     [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
8. 
9. #加载train、test数据集
10. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
12. 
13. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
14. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

数据预处理torchvision.transforms这一部分主要是进行数据的中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等。如果不懂请查看官方文档:TORCHVISION.TRANSFORMS

构建Dataset子类

那么如果我想将cifar数据集换成我自己的数据集怎么办呢?答案是:如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。例如:

1. from __future__ import print_function
2. import torch.utils.data as data
3. import torch
4. 
5. class MyDataset(data.Dataset):
6.     def __init__(self, images, labels):#这一部分用于读取训练、测试数据
7.         self.images = images
8.         self.labels = labels
9. 
10.     def __getitem__(self, index):#这一部分将读取的数据输出,返回的是tensor格式
11.         img, target = self.images[index], self.labels[index]
12.         return img, target
13. 
14.     def __len__(self):
15.         return len(self.images)
16. 
17. dataset = MyDataset(images, labels)

查看torchvision.datasets.CIFAR10的源码也可以看到cifar10也是继承了Dataset这个类

在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是__len__和__getitem__。__len__返回数据集的大小,__getitem__实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。

在创建DataLoader时会判断__getitem__返回值的数据类型,然后用不同的if/else分支把数据转换成tensor,所以,_getitem_返回值的数据类型可选择范围很多,一种可以选择的数据类型是:图像为numpy.array,标记为int数据类型。

实例:

比如这里我需要读入我自己的32*32大小的图像数据,则代码为:

1. # 图像预处理和增强
2. transform_train = transforms.Compose(
3.     [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
4.      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
5. 
6. transform_test = transforms.Compose(
7.     [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
8. 
9. #加载train、test数据集
10. trainset = MyDataset(imgdir='./Train',imgpath='./Train.txt', transform=transform_train)
11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)
12. 
13. testset = MyDataset(imgdir='./Test',imgpath='./Test.txt', train=False, transform=transform_test)
14. testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)

这里将参数说明一下:MyDataset中imgdir为图片的存放位置,imgpath中包含每张图片的name以及图片的label,其它同cifar

torch.utils.data.DataLoader()函数,合成数据并且提供迭代访问。主要由两部分组成:

1. - dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。 
2. - batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:
3. 
4.     - batch-size。样本每个batch的大小,默认为1。
5.     - shuffle。是否打乱数据,默认为False。
6.     - num_workers。数据分为几个线程处理默认为0。
7.     - sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False

之后就是在MyDataset里读取和输出自定义数据集: 简而言之就是在继承了Dataset类的Mydataset里,__init__函数中读取图像数据,DataLoader再通过__getitem__中获取输出

1. import os
2. import cv2
3. from PIL import Image
4. import numpy as np
5. from torch.utils.data import Dataset
6. 
7. class MyDataset(Dataset):
8.     def __init__(self, imgdir, imgpath,train=True,
9.                  transform=None, target_transform=None):
10.         self.root = os.path.expanduser(imgdir)
11.         self.transform = transform
12.         self.target_transform = target_transform
13.         self.train = train  # training set or test set
14. 
15.         # now load the picked numpy arrays
16.         if self.train:
17.             self.train_data = []#images
18.             self.train_labels = []#labels
19.             #read the images and labels in file Train.txt or Test.txt
20.             with open(imgpath,"r") as imgpath:
21.                 for line in imgpath:
22.                     line=line.split(' ')
23.                     image=Image.open(line[0])
24.                     image = np.array(image)
25.                     self.train_data.append(image)#将读取的图片放入train_data list里
26.                     self.train_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
27.                 imgpath.close()
28.             self.train_data = np.array(self.train_data)
29.             self.train_data = self.train_data.reshape((1000,3, 32, 32))
30.             self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
31.         else:
32.             self.test_data = []#images
33.             self.test_labels = []#labels
34.             #read the images and labels in file Train.txt or Test.txt
35.             with open(imgpath,"r") as imgpath:
36.                 for line in imgpath:
37.                     line = line.split(' ')
38.                     image=Image.open(line[0])
39.                     image = np.array(image)
40.                     self.test_data.append(image)#将读取的图片放入train_data list里
41.                     self.test_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
42.                 imgpath.close()
43.             self.test_data = np.array(self.test_data)
44.             self.test_data = self.test_data.reshape((200, 3,32, 32))
45.             self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
46. 
47.     def __getitem__(self, index):
48.         """
49.         Args:
50.             index (int): Index
51. 
52.         Returns:
53.             tuple: (image, target) where target is index of the target class.
54.         """
55.         if self.train:
56.             img, target = self.train_data[index], self.train_labels[index]
57.         else:
58.             img, target = self.test_data[index], self.test_labels[index]
59. 
60.         # doing this so that it is consistent with all other datasets
61.         # to return a PIL Image
62.         img = Image.fromarray(img)
63. 
64.         if self.transform is not None:
65.             img = self.transform(img)
66. 
67.         if self.target_transform is not None:
68.             target = self.target_transform(target)
69. 
70.         return img,target
71. 
72.     def __len__(self):
73.         if self.train:
74.             return len(self.train_data)
75.         else:
76.             return len(self.test_data)

参考:

https://blog.csdn.net/GYGuo95/article/details/78821520

https://blog.csdn.net/victoriaw/article/details/72356453

https://pytorch.org/docs/stable/torchvision/transforms.html


AIEarth是一个由众多领域内专家博主共同打造的学术平台,旨在建设一个拥抱智慧未来的学术殿堂!【平台地址:https://devpress.csdn.net/aiearth】 很高兴认识你!加入我们共同进步!

目录
相关文章
|
3月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
84 0
|
5月前
|
数据可视化 PyTorch 算法框架/工具
使用PyTorch搭建VGG模型进行图像风格迁移实战(附源码和数据集)
使用PyTorch搭建VGG模型进行图像风格迁移实战(附源码和数据集)
152 1
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
92 0
|
25天前
|
数据采集 机器学习/深度学习 PyTorch
PyTorch中的数据加载与预处理
【4月更文挑战第17天】了解PyTorch中的数据加载与预处理至关重要。通过`Dataset`和`DataLoader`,我们可以自定义数据集、实现批处理、数据混洗及多线程加载。`transforms`模块用于数据预处理,如图像转Tensor和归一化。本文展示了CIFAR10数据集的加载和预处理示例,强调了这些工具在深度学习项目中的重要性。
|
1月前
|
机器学习/深度学习 数据可视化 PyTorch
利用PyTorch实现基于MNIST数据集的手写数字识别
利用PyTorch实现基于MNIST数据集的手写数字识别
25 2
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
19 0
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
|
3月前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)
PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)
86 1