前提:
本文的记录前提是---有一个完整、已调通的pytorch网络项目,因为暂时比赛要用,完整项目等过一段时间再打包发到github上...
比如:加载的pytorch自带cifar数据集:
1. # train、test图像预处理和增强 2. transform_train = transforms.Compose( 3. [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), 4. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) 5. 6. transform_test = transforms.Compose( 7. [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) 8. 9. #加载train、test数据集 10. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 12. 13. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 14. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
数据预处理torchvision.transforms这一部分主要是进行数据的中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等。如果不懂请查看官方文档:TORCHVISION.TRANSFORMS
构建Dataset子类
那么如果我想将cifar数据集换成我自己的数据集怎么办呢?答案是:如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。例如:
1. from __future__ import print_function 2. import torch.utils.data as data 3. import torch 4. 5. class MyDataset(data.Dataset): 6. def __init__(self, images, labels):#这一部分用于读取训练、测试数据 7. self.images = images 8. self.labels = labels 9. 10. def __getitem__(self, index):#这一部分将读取的数据输出,返回的是tensor格式 11. img, target = self.images[index], self.labels[index] 12. return img, target 13. 14. def __len__(self): 15. return len(self.images) 16. 17. dataset = MyDataset(images, labels)
查看torchvision.datasets.CIFAR10的源码也可以看到cifar10也是继承了Dataset这个类
在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是__len__和__getitem__。__len__返回数据集的大小,__getitem__实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。
在创建DataLoader时会判断__getitem__返回值的数据类型,然后用不同的if/else分支把数据转换成tensor,所以,_getitem_返回值的数据类型可选择范围很多,一种可以选择的数据类型是:图像为numpy.array,标记为int数据类型。
实例:
比如这里我需要读入我自己的32*32大小的图像数据,则代码为:
1. # 图像预处理和增强 2. transform_train = transforms.Compose( 3. [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), 4. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) 5. 6. transform_test = transforms.Compose( 7. [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) 8. 9. #加载train、test数据集 10. trainset = MyDataset(imgdir='./Train',imgpath='./Train.txt', transform=transform_train) 11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2) 12. 13. testset = MyDataset(imgdir='./Test',imgpath='./Test.txt', train=False, transform=transform_test) 14. testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)
这里将参数说明一下:MyDataset中imgdir为图片的存放位置,imgpath中包含每张图片的name以及图片的label,其它同cifar
torch.utils.data.DataLoader()
函数,合成数据并且提供迭代访问。主要由两部分组成:
1. - dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。 2. - batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有: 3. 4. - batch-size。样本每个batch的大小,默认为1。 5. - shuffle。是否打乱数据,默认为False。 6. - num_workers。数据分为几个线程处理默认为0。 7. - sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False
之后就是在MyDataset里读取和输出自定义数据集: 简而言之就是在继承了Dataset类的Mydataset里,__init__函数中读取图像数据,DataLoader再通过
__getitem__中获取输出
1. import os 2. import cv2 3. from PIL import Image 4. import numpy as np 5. from torch.utils.data import Dataset 6. 7. class MyDataset(Dataset): 8. def __init__(self, imgdir, imgpath,train=True, 9. transform=None, target_transform=None): 10. self.root = os.path.expanduser(imgdir) 11. self.transform = transform 12. self.target_transform = target_transform 13. self.train = train # training set or test set 14. 15. # now load the picked numpy arrays 16. if self.train: 17. self.train_data = []#images 18. self.train_labels = []#labels 19. #read the images and labels in file Train.txt or Test.txt 20. with open(imgpath,"r") as imgpath: 21. for line in imgpath: 22. line=line.split(' ') 23. image=Image.open(line[0]) 24. image = np.array(image) 25. self.train_data.append(image)#将读取的图片放入train_data list里 26. self.train_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里 27. imgpath.close() 28. self.train_data = np.array(self.train_data) 29. self.train_data = self.train_data.reshape((1000,3, 32, 32)) 30. self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 31. else: 32. self.test_data = []#images 33. self.test_labels = []#labels 34. #read the images and labels in file Train.txt or Test.txt 35. with open(imgpath,"r") as imgpath: 36. for line in imgpath: 37. line = line.split(' ') 38. image=Image.open(line[0]) 39. image = np.array(image) 40. self.test_data.append(image)#将读取的图片放入train_data list里 41. self.test_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里 42. imgpath.close() 43. self.test_data = np.array(self.test_data) 44. self.test_data = self.test_data.reshape((200, 3,32, 32)) 45. self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 46. 47. def __getitem__(self, index): 48. """ 49. Args: 50. index (int): Index 51. 52. Returns: 53. tuple: (image, target) where target is index of the target class. 54. """ 55. if self.train: 56. img, target = self.train_data[index], self.train_labels[index] 57. else: 58. img, target = self.test_data[index], self.test_labels[index] 59. 60. # doing this so that it is consistent with all other datasets 61. # to return a PIL Image 62. img = Image.fromarray(img) 63. 64. if self.transform is not None: 65. img = self.transform(img) 66. 67. if self.target_transform is not None: 68. target = self.target_transform(target) 69. 70. return img,target 71. 72. def __len__(self): 73. if self.train: 74. return len(self.train_data) 75. else: 76. return len(self.test_data)
参考:
https://blog.csdn.net/GYGuo95/article/details/78821520
https://blog.csdn.net/victoriaw/article/details/72356453
https://pytorch.org/docs/stable/torchvision/transforms.html
AIEarth是一个由众多领域内专家博主共同打造的学术平台,旨在建设一个拥抱智慧未来的学术殿堂!【平台地址:https://devpress.csdn.net/aiearth】 很高兴认识你!加入我们共同进步!