pytorch中自定义数据集加载对象重写Dataset

简介: pytorch中自定义数据集加载对象重写Dataset

在pytorch中,数据加载可以通过自动逸的数据集对象来实现,数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现相应的方法。

下面针对给定任务进行重写Dataset类:

我们所有的图片都是在一个文件下,每个图像的标签含在一个csv文件中,所以不能利用Pytorch中的ImageFolder进行加载,所以需要自己重写DataSet类,实现读写数据。

重写DataSet类,需要重写3个方法:

  • __init__:该方法主要就是一些参数初始化工作,定义一些路径或者变量什么的
  • __getitem__:该方法是加载数据用的,用于读取每一条数据,他会有一个参数idx,就是对应的索引,从0开始,由于我们的图片是从001.jpg到280.jpg,所以可以利用这个索引依次读取文件夹中的所有图片,然后从标签csv中读取它对应的行拿到对应的标签,然后返回即可
  • __len__:返回整个数据集的大小
# 加载数据集,自己重写DataSet类
class dataset(Dataset):
    # image_dir为数据目录,label_file,为标签文件
    def __init__(self, image_dir, label_file, transform=None):
        self.image_dir = image_dir # 图像文件所在路径
        self.label_file = pd.read_csv(label_file) # 图像对应的标签文件
        self.transform = transform # 数据转换操作
    # 加载每一项数据
    def __getitem__(self, idx):
        # 每个图片,其中idx为数据索引
        img_name = os.path.join(self.image_dir, '%.3d.jpg' % (idx + 1)) # 加载每一张照片
        image = Image.open(img_name)
        # 对应标签
        labels = (self.label_file[['cream', 'fruits', 'sprinkle_toppings']] == 'yes').astype(int).values[idx, :]
        if self.transform:
            image = self.transform(image)
        # 返回一张照片,一个标签
        return image, labels
    # 数据集大小
    def __len__(self):
        return (len(self.label_file))

如果上面任务能够明白,其实Dataset类不局限于这么写,它可以实现多种数据读取方法,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。


目录
相关文章
|
2月前
|
机器学习/深度学习 存储 PyTorch
PyTorch自定义学习率调度器实现指南
本文将详细介绍如何通过扩展PyTorch的 ``` LRScheduler ``` 类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。
62 2
|
5月前
|
机器学习/深度学习 人工智能 PyTorch
|
2月前
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
|
5月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
18天前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
78 2
|
20天前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
43 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
22天前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
37 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力