pytorch-TensorFlow-加载直接的数据集

简介: 笔记

在我刚刚学习深度学习的时候,就只会用现有的数据集。当我想训练直接的模型的时候,却不知道该怎样弄,但时是花了两天在网上寻找教程,可是都不太适合新手学习,所以今天我就来总结一下pytorch里面加载自己的数据集的方法。


方法一:利用torch.utils.data.TensorDataset,也是我认为最简单的方法

from torch.utils.data import TensorDataset,DataLoader
x =       #对应你加载的数据   是tensor类型
y =      #对应你数据的标签  是tensor类型
data_set = TensorDataset(x,y)   
data_loader = DataLoader(data_set,batch_size = 100,shuffle = True,drop_last = True) #生成可迭代的迭代器
#训练时
for index (img,label) in enumerate(data_loaer):
  #训练代码

方法二:重写Dataset类

from torch.utils.data import Dataset,DataLoader
class dataset(Dataset):  #继承Dataset类
  def __init__(self,root,transform=None):  #root为你数据的路径  下面以图片为例
  imgs = os.listdir(root)    #返回root里面所用文件的名称
  self.img = [os.path.join(root,img) for img in imgs]  #生成每张照片的路径
  self.transform = transform  #对图片的一些处理
  def __getitem__(self,index):
    img_path = self.img(index)
    label = 0 if 'woman' in img_path.split('/')[-1] else 1 #自己根据文件名来设置标签,例如这里是文件名包不包含woman。
    data =  Image.open(img_path) #读取图片,在PIL库中
    data = self.transform(data)   #对图片进行装换,一定要转换成tensor
    return data,label    #返回图片和其标签
  def __len__(self):
    return len(self.img)   #返回数据的大小
data = dataset(root)
data_loader = DataLoader(data,batch_size=100,shuffle=True,drop_last= True)
              #batch_size每一次迭代数据的大小       shuffle对图片是否打散    drop_last对最后的数据如果不满足batch_size的大小就舍弃

方法三:利用ImageFolder对图片进行读取

ImageFolder是对整个文件夹进行对取,每个文件夹的内容,会自动被归为一类。

from torchvision.datasets import ImageFolder
data = ImageFolder(root)  #root的根目录放保存每一类的文件夹
data_loader = DataLoader(data,batch_size=100,shuffle=True,drop_last= True)

基于这三种简单易行的方法,你可以很方便的根据你数据的存放的形式进行构造自己的数据集。

下面介绍一下tf 2.0的构造自己的数据集,先看一下代码吧,也不难。

import tensorflow as tf
x =     #数据。numpy类型 
y =   #标签。numpy类型
# x,y tensor类型的我还没尝试过你可以去转换成tensor类型的试试。
data = tf.data.Dataset.slices((x,y))
data_load = data.repeat().shuffle(5000).batch(128).prefetch(1)
for step , (x,y) in enumerate(data_load.take(2000),1):

这种方法与pytorch的第一种方法类似。

prefetch(x) ,表示预先准备下x次迭代的数据,提高效率。

batch , 表示每一次迭代的数据大小,比如图片的话就是每一次迭代128张图片的数据。

shuffle(x) ,表示打乱数据的次序,x表示打乱的次数,每一次迭代算一次。

take(x) ,表示训练的epoch数,在pytorch里面需要在外面在嵌套一次for循环来设置训练的epochs数。


Thank for your reading !!!

公众号:FPGA之旅

目录
相关文章
|
1月前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
39 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
5月前
|
机器学习/深度学习 人工智能 PyTorch
|
2月前
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
113 4
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
3月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
88 0
|
3月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
80 0
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+Keras】keras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
如何使用TensorFlow和Keras实现条件生成对抗网络(CGAN)并以MNIST和Fashion MNIST数据集为例进行演示。
49 3
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
|
5月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】

热门文章

最新文章