pytorch中自定义数据集加载对象重写Dataset

简介: pytorch中自定义数据集加载对象重写Dataset

在pytorch中,数据加载可以通过自动逸的数据集对象来实现,数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现相应的方法。

下面针对给定任务进行重写Dataset类:

我们所有的图片都是在一个文件下,每个图像的标签含在一个csv文件中,所以不能利用Pytorch中的ImageFolder进行加载,所以需要自己重写DataSet类,实现读写数据。

重写DataSet类,需要重写3个方法:

  • __init__:该方法主要就是一些参数初始化工作,定义一些路径或者变量什么的
  • __getitem__:该方法是加载数据用的,用于读取每一条数据,他会有一个参数idx,就是对应的索引,从0开始,由于我们的图片是从001.jpg到280.jpg,所以可以利用这个索引依次读取文件夹中的所有图片,然后从标签csv中读取它对应的行拿到对应的标签,然后返回即可
  • __len__:返回整个数据集的大小
# 加载数据集,自己重写DataSet类
class dataset(Dataset):
    # image_dir为数据目录,label_file,为标签文件
    def __init__(self, image_dir, label_file, transform=None):
        self.image_dir = image_dir # 图像文件所在路径
        self.label_file = pd.read_csv(label_file) # 图像对应的标签文件
        self.transform = transform # 数据转换操作
    # 加载每一项数据
    def __getitem__(self, idx):
        # 每个图片,其中idx为数据索引
        img_name = os.path.join(self.image_dir, '%.3d.jpg' % (idx + 1)) # 加载每一张照片
        image = Image.open(img_name)
        # 对应标签
        labels = (self.label_file[['cream', 'fruits', 'sprinkle_toppings']] == 'yes').astype(int).values[idx, :]
        if self.transform:
            image = self.transform(image)
        # 返回一张照片,一个标签
        return image, labels
    # 数据集大小
    def __len__(self):
        return (len(self.label_file))

如果上面任务能够明白,其实Dataset类不局限于这么写,它可以实现多种数据读取方法,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。


目录
相关文章
|
9月前
|
机器学习/深度学习 监控 算法
利用PyTorch处理个人数据集
如此看来,整个处理个人数据集的过程就像进行一场球赛。你设立球场,安排队员,由教练训练,最后你可以看到他们的表现。不断地学习,不断地调整,你的模型也会越来越厉害。 当然,这个过程看似简单,但在实际操作时可能会奇怪各种问题。需要你在实践中不断摸索,不断学习。可是不要怕,只要你热爱,不怕困难,你一定能驯服你的数据,让他们为你所用!
170 35
|
机器学习/深度学习 存储 PyTorch
PyTorch自定义学习率调度器实现指南
本文将详细介绍如何通过扩展PyTorch的 ``` LRScheduler ``` 类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。
788 2
|
机器学习/深度学习 人工智能 PyTorch
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
955 4
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
11月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
964 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
5月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
380 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节

推荐镜像

更多