Pytorch基本使用—自定义数据集

简介: Pytorch基本使用—自定义数据集

自定义数据集分为导入和打包两个过程。导入有三种方式,重载Dataset,构建迭代器,ImageFolder函数。打包利用DataLoader(数据集打包为一个个batch)。

✨1 导入

🎈1.1 重载Dataset

利用pytorch官方提供的自定义数据集的接口。

导入类

from torch.utils.data import DataSet

构造类的基本形式

class MyDataSet(DataSet):
    def __init__(self, [param1, param2, ...]):
        # 1.图像或序列数据路径
        # 2.label路径或内容
        # 3.数据增强操作初始化...
    def __len__(self):
        # 返回数据数量
    def __getitem__(self, index):
        # 1.根据index获取单个图像和真实标签
        # 2.对单个图像和标签进行数据增强操作

🎈1.2 图像通道问题

目前常用的打开图像库为opencvImage。其中opencv打开后是BGR形式,Image.open打开后是RGB,而模型训练时要求的顺序是RGB

因此,如果使用opencv打开图像,后面需要将BGR格式转化为RGB:

image = cv2.imread(filepath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w, c = image.shape

同时,也附上Image图像转化的代码:

img = Image.open(file).convert('RGB')
w,h = img.size

🌭1.3 ImageFolder

ImageFolder导入数据集时支持数据预处理,但是对储存数据的文件夹格式要求比较严格,格式如下:


dbaed7f9026b4647910d272454225d43.png

要求训练和测试放在不同的文件夹,并且其中每一类的数据放在不同类的文件夹下

类别文件夹的命名可识别,利用ImageFolder导入数据后,会把其中类别文件夹的名字作为标签映射为数字,属性class_to_idx可以提取到

举例:

from torchvision import datasets
data = datasets.ImageFolder(root, transform)
# root: 训练集/测试集路径,如图示 root = train
# transform: 预处理方法

✨2 打包

DataLoader类构造

from torch.utils.data import DataLoader
DataLoader(
    dataset, 
    batch_size, 
    shuffle=False, 
    sampler=None, 
    batch_sample=None, 
    num_workers=0, 
    collate_fn=None, 
    pin_memory=False, 
    drop_last=False, 
    timeout=0,
    worker_init_fn=None
)

参数(这里重点总结几个)

参数 描述

dataset 传入的数据集

batch_size batch的尺寸

shuffle 在每个epoch开始的时候,对数据进行重新排序

num_workers 默认为0,这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程

collate_fn 打包的规则,一个函数,下面会详细解释

pin_memory 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中

🍕2.1 构造num_workers

这个参数决定了有几个进程来处理data loading,一般边在8,batch_size,cpu数量之间来获取:

num_workers = min(8, batch_size if batch_size >1 else 0, os.cup_count())

在windows中该参数无法使用!!!

🎆2.2 collate_fn

打包规则,最后返回的必须是打包的训练/测试图片,打包的训练/测试真实标签

如果这个参数自定义,下面给一个实例

def collate_fn(batch):
    # 解释下面这行代码做的事情(假设batch_size=2),即将每个batch中的图片和真实标签分开打包。
    # 原batch的结构如下:batch(batch1, batch2), batch1/batch2(image, target)
    # 结果转化为:images(image1, image2), targets(target1, target2)
    images, targets = list(zip(*batch))
    # cat_list将batch中的图片统一大小,这里是自定义的,可以按照自己的需求替换。
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    # 返回: 打包的训练/测试图片,打包的训练/测试真实标签
    return batched_imgs, batched_targets
def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs

✨3 摄像头数据集的搭建

🍔3.1 完整代码

参考YOLOv5的实现,采用迭代器的方式创建。

class LoadStream:
    def __init__(self, idx: int):
        super(LoadStream, self).__init__()
        cap = cv2.VideoCapture(idx)
        assert cap.isOpened(), "摄像头{}打开失败".format(idx)
        self.cap = cap
        self.w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        _, self.img = self.cap.read()
        self.fps = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0
    def __iter__(self):  # 迭代器必须实现的函数
        self.count = 1
        return self
    def __next__(self):
        self.count += 1
        _, self.img = self.cap.read()
        return self.img

🎃3.2 __init__

    def __init__(self, idx: int):
        super(LoadStream, self).__init__()
        cap = cv2.VideoCapture(idx)
        assert cap.isOpened(), "摄像头{}打开失败".format(idx)
        self.cap = cap
        self.w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        _, self.img = self.cap.read()
        self.fps = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0
  1. 1. ==cap = cv2.VideoCapture(idx)==创建一个摄像头资源,其中idx代表摄像头编号,一般电脑自带的摄像头默认为0
  2. 2. self.w和self.h记录打开图像的宽和高,在目标检测等任务中可能用到。
  3. 3. _, self.img = self.cap.read()初始化图像数据
  4. 4. self.fps是视频的帧率

🎄3.3 __iter__

    def __iter__(self):  # 迭代器必须实现的函数
        self.count = 1
        return self

__iter__是迭代器必须实现的参数,初始化计数器self.count,并返回类的声明self

🍔3.4 __next__

    def __next__(self):
        self.count += 1
        _, self.img = self.cap.read()
        return self.img

这里是该迭代器的核心代码,每次迭代都会执行一次__next__

  1. 1. self.count += 1,计数器加1
  2. 2. _, self.img = self.cap.read()获取下一帧的图像

如果需要图像增强等,可以在return self.img之前执行。

相关文章
|
2月前
|
机器学习/深度学习 存储 PyTorch
PyTorch自定义学习率调度器实现指南
本文将详细介绍如何通过扩展PyTorch的 ``` LRScheduler ``` 类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。
63 2
|
5月前
|
机器学习/深度学习 人工智能 PyTorch
|
2月前
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
|
5月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
19天前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
82 2
|
21天前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
43 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
23天前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
39 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力