自定义数据集分为导入和打包两个过程。导入有三种方式,重载Dataset,构建迭代器,ImageFolder函数。打包利用DataLoader(数据集打包为一个个batch)。
✨1 导入
🎈1.1 重载Dataset
利用pytorch官方提供的自定义数据集的接口。
导入类:
from torch.utils.data import DataSet
构造类的基本形式:
class MyDataSet(DataSet): def __init__(self, [param1, param2, ...]): # 1.图像或序列数据路径 # 2.label路径或内容 # 3.数据增强操作初始化... def __len__(self): # 返回数据数量 def __getitem__(self, index): # 1.根据index获取单个图像和真实标签 # 2.对单个图像和标签进行数据增强操作
🎈1.2 图像通道问题
目前常用的打开图像库为opencv
和Image
。其中opencv打开后是BGR形式,Image.open打开后是RGB,而模型训练时要求的顺序是RGB。
因此,如果使用opencv打开图像,后面需要将BGR格式转化为RGB:
image = cv2.imread(filepath) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) h, w, c = image.shape
同时,也附上Image图像转化的代码:
img = Image.open(file).convert('RGB') w,h = img.size
🌭1.3 ImageFolder
ImageFolder导入数据集时支持数据预处理,但是对储存数据的文件夹格式要求比较严格,格式如下:
要求训练和测试放在不同的文件夹,并且其中每一类的数据放在不同类的文件夹下。
类别文件夹的命名可识别,利用ImageFolder导入数据后,会把其中类别文件夹的名字作为标签映射为数字,属性class_to_idx
可以提取到。
举例:
from torchvision import datasets data = datasets.ImageFolder(root, transform) # root: 训练集/测试集路径,如图示 root = train # transform: 预处理方法
✨2 打包
DataLoader类构造
from torch.utils.data import DataLoader DataLoader( dataset, batch_size, shuffle=False, sampler=None, batch_sample=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None )
参数(这里重点总结几个)
参数 描述
dataset 传入的数据集
batch_size batch的尺寸
shuffle 在每个epoch开始的时候,对数据进行重新排序
num_workers 默认为0,这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程
collate_fn 打包的规则,一个函数,下面会详细解释
pin_memory 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
🍕2.1 构造num_workers
这个参数决定了有几个进程来处理data loading,一般边在8,batch_size,cpu数量之间来获取:
num_workers = min(8, batch_size if batch_size >1 else 0, os.cup_count())
在windows中该参数无法使用!!!
🎆2.2 collate_fn
打包规则,最后返回的必须是打包的训练/测试图片,打包的训练/测试真实标签。
如果这个参数自定义,下面给一个实例:
def collate_fn(batch): # 解释下面这行代码做的事情(假设batch_size=2),即将每个batch中的图片和真实标签分开打包。 # 原batch的结构如下:batch(batch1, batch2), batch1/batch2(image, target) # 结果转化为:images(image1, image2), targets(target1, target2) images, targets = list(zip(*batch)) # cat_list将batch中的图片统一大小,这里是自定义的,可以按照自己的需求替换。 batched_imgs = cat_list(images, fill_value=0) batched_targets = cat_list(targets, fill_value=255) # 返回: 打包的训练/测试图片,打包的训练/测试真实标签 return batched_imgs, batched_targets def cat_list(images, fill_value=0): max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) batch_shape = (len(images),) + max_size batched_imgs = images[0].new(*batch_shape).fill_(fill_value) for img, pad_img in zip(images, batched_imgs): pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) return batched_imgs
✨3 摄像头数据集的搭建
🍔3.1 完整代码
参考YOLOv5的实现,采用迭代器的方式创建。
class LoadStream: def __init__(self, idx: int): super(LoadStream, self).__init__() cap = cv2.VideoCapture(idx) assert cap.isOpened(), "摄像头{}打开失败".format(idx) self.cap = cap self.w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) _, self.img = self.cap.read() self.fps = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 def __iter__(self): # 迭代器必须实现的函数 self.count = 1 return self def __next__(self): self.count += 1 _, self.img = self.cap.read() return self.img
🎃3.2 __init__
def __init__(self, idx: int): super(LoadStream, self).__init__() cap = cv2.VideoCapture(idx) assert cap.isOpened(), "摄像头{}打开失败".format(idx) self.cap = cap self.w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) _, self.img = self.cap.read() self.fps = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0
- 1. ==cap = cv2.VideoCapture(idx)==创建一个摄像头资源,其中idx代表摄像头编号,一般电脑自带的摄像头默认为0
- 2. self.w和self.h记录打开图像的宽和高,在目标检测等任务中可能用到。
- 3. _, self.img = self.cap.read()初始化图像数据
4. self.fps
是视频的帧率
🎄3.3 __iter__
def __iter__(self): # 迭代器必须实现的函数 self.count = 1 return self
__iter__
是迭代器必须实现的参数,初始化计数器self.count
,并返回类的声明self
🍔3.4 __next__
def __next__(self): self.count += 1 _, self.img = self.cap.read() return self.img
这里是该迭代器的核心代码,每次迭代都会执行一次__next__
:
1. self.count += 1
,计数器加12. _, self.img = self.cap.read()
获取下一帧的图像
如果需要图像增强等,可以在return self.img
之前执行。