torch在构造数据集时,如何将dataloader和next结合,使得next输出的是dataloader已经构造好的数据集

简介: 在使用torch.utils.data.DataLoader构造数据集并进行迭代时,可以通过调用iter()函数将DataLoader对象转换为一个可迭代对象,然后再通过调用next()函数取出下一个batch的数据。例如,假设我们已经定义好了一个MyDataset类来表示自己的数据集,并且使用DataLoader来对其进行批量处理:

在使用torch.utils.data.DataLoader构造数据集并进行迭代时,可以通过调用iter()函数将DataLoader对象转换为一个可迭代对象,然后再通过调用next()函数取出下一个batch的数据。

例如,假设我们已经定义好了一个MyDataset类来表示自己的数据集,并且使用DataLoader来对其进行批量处理:

import torch.utils.data as data
class MyDataset(data.Dataset):
    def __init__(self, ...):
        # 构造函数,初始化数据集
        ...
    def __getitem__(self, index):
        # 根据索引获取样本和标签
        ...
    def __len__(self):
        # 获取数据集大小
        ...
dataset = MyDataset(...)
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

现在,我们可以通过调用iter()函数将dataloader对象转换为一个可迭代对象:

iterator = iter(dataloader)

然后,我们就可以通过调用next()函数来逐个获取数据集中的下一个batch数据:

batch_data = next(iterator)

这里的batch_data就是一个包含32个样本及其对应标签的张量,可以直接传入模型进行训练或者预测。


下面给出一个案例

import torch
from torch.utils.data import DataLoader, TensorDataset
x_data = torch.randn(10, 3)
y_data = torch.randn(10, 1)
# for i, j in zip(x_data ,y_data):
#     print(i, j)
# import random
# 创建TensorDataset对象
dataset = TensorDataset(x_data, y_data)
# 创建DataLoader对象,并指定batch_size和是否要进行打乱
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
# 遍历每个小批量数据
for batch_x, batch_y in dataloader:
    # 在这里执行训练或评估操作
    print(batch_x, batch_y)
iterator = iter(dataloader)

输出

tensor([[-0.8643, -1.6477,  0.8695],
        [-1.4600,  0.2523, -0.4641]]) tensor([[-0.3900],
        [ 0.3526]])
tensor([[ 1.2054, -0.9444, -1.0735],
        [-1.4717,  0.8908,  0.5036]]) tensor([[ 0.8096],
        [-1.0215]])
tensor([[ 1.4094,  0.1649,  0.2448],
        [-0.6039, -1.4968,  0.1234]]) tensor([[ 0.0931],
        [-0.3150]])
tensor([[-0.8754, -0.1743,  0.7225],
        [-2.0970,  0.8257,  0.7893]]) tensor([[-0.8036],
        [ 1.4351]])
tensor([[-0.4395, -0.5905, -1.2884],
        [ 1.4488,  0.2629, -0.1280]]) tensor([[-0.9477],
        [ 0.1115]])
while True:
    try:
        batch_data = next(iterator)
        print(batch_data) 
    except:
        break

输出

[tensor([[-0.8643, -1.6477,  0.8695],
        [-1.4600,  0.2523, -0.4641]]), tensor([[-0.3900],
        [ 0.3526]])]
[tensor([[ 1.2054, -0.9444, -1.0735],
        [-1.4717,  0.8908,  0.5036]]), tensor([[ 0.8096],
        [-1.0215]])]
[tensor([[ 1.4094,  0.1649,  0.2448],
        [-0.6039, -1.4968,  0.1234]]), tensor([[ 0.0931],
        [-0.3150]])]
[tensor([[-0.8754, -0.1743,  0.7225],
        [-2.0970,  0.8257,  0.7893]]), tensor([[-0.8036],
        [ 1.4351]])]
[tensor([[-0.4395, -0.5905, -1.2884],
        [ 1.4488,  0.2629, -0.1280]]), tensor([[-0.9477],
        [ 0.1115]])]
相关文章
|
缓存 PyTorch 数据处理
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
1053 0
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
|
2月前
|
计算机视觉
数据集学习笔记(三):COCO创建dataloader用于训练
如何使用COCO数据集创建dataloader进行训练,包括安装环境、加载数据集代码、定义数据转换、创建数据集对象以及创建dataloader。
50 5
|
5月前
|
机器学习/深度学习 存储 算法
查询模型的方法knn_model.pkl
【7月更文挑战第28天】
64 3
|
4月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
37 0
|
机器学习/深度学习 Linux PyTorch
Dataset and DataLoader 加载数据集
Dataset and DataLoader 加载数据集
155 0
|
PyTorch 算法框架/工具
【PyTorch】自定义数据集处理/dataset/DataLoader等
【PyTorch】自定义数据集处理/dataset/DataLoader等
187 0
|
机器学习/深度学习 存储 PyTorch
怎么调用pytorch中mnist数据集
怎么调用pytorch中mnist数据集
231 0
|
PyTorch 算法框架/工具
如何将x_data和y_data利用torch转换成小批量数据,并要求打乱数据,以及将数据标准化或者归一化,如何处理?
以上代码中,在定义预处理操作transform时,只在Normalize函数的第一个参数中传入x_data的均值和标准差,而在第二个参数中传入空元组,表示不对y_data进行标准化。 接着,将标准化后的x_data和原始的y_data转换为张量格式,并将它们合并为一个TensorDataset对象。最后,定义dataloader对象,设置batch_size和shuffle参数,并使用上述数据集对象作为输入数据。
326 0
|
机器学习/深度学习 固态存储 数据处理
【目标检测之数据集预处理】继承Dataset定义自己的数据集【附代码】(上)
在深度学习训练中,除了设计有效的卷积神经网络框架外,更重要的是数据的处理。在训练之前需要对训练数据进行预处理。比如在目标检测网络训练中,首先需要划分训练集和测试集,然后对标签、边界框等进行处理后才能送入网络进行训练,本文章以VOC数据集格式为例,对数据集进行预处理后送入目标检测网络进行训练。【附代码】
356 0
【目标检测之数据集预处理】继承Dataset定义自己的数据集【附代码】(上)
|
数据采集 并行计算 PyTorch
【目标检测之数据集加载】利用DataLoader加载已预处理后的数据集【附代码】
在前一篇文章中,已经通过继承Dataset预处理自己的数据集 ,接下来就是使用pytorch提供的DataLoader函数加载数据集。
646 0
【目标检测之数据集加载】利用DataLoader加载已预处理后的数据集【附代码】