torch在构造数据集时,如何将dataloader和next结合,使得next输出的是dataloader已经构造好的数据集

简介: 在使用torch.utils.data.DataLoader构造数据集并进行迭代时,可以通过调用iter()函数将DataLoader对象转换为一个可迭代对象,然后再通过调用next()函数取出下一个batch的数据。例如,假设我们已经定义好了一个MyDataset类来表示自己的数据集,并且使用DataLoader来对其进行批量处理:

在使用torch.utils.data.DataLoader构造数据集并进行迭代时,可以通过调用iter()函数将DataLoader对象转换为一个可迭代对象,然后再通过调用next()函数取出下一个batch的数据。

例如,假设我们已经定义好了一个MyDataset类来表示自己的数据集,并且使用DataLoader来对其进行批量处理:

import torch.utils.data as data
class MyDataset(data.Dataset):
    def __init__(self, ...):
        # 构造函数,初始化数据集
        ...
    def __getitem__(self, index):
        # 根据索引获取样本和标签
        ...
    def __len__(self):
        # 获取数据集大小
        ...
dataset = MyDataset(...)
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

现在,我们可以通过调用iter()函数将dataloader对象转换为一个可迭代对象:

iterator = iter(dataloader)

然后,我们就可以通过调用next()函数来逐个获取数据集中的下一个batch数据:

batch_data = next(iterator)

这里的batch_data就是一个包含32个样本及其对应标签的张量,可以直接传入模型进行训练或者预测。


下面给出一个案例

import torch
from torch.utils.data import DataLoader, TensorDataset
x_data = torch.randn(10, 3)
y_data = torch.randn(10, 1)
# for i, j in zip(x_data ,y_data):
#     print(i, j)
# import random
# 创建TensorDataset对象
dataset = TensorDataset(x_data, y_data)
# 创建DataLoader对象,并指定batch_size和是否要进行打乱
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
# 遍历每个小批量数据
for batch_x, batch_y in dataloader:
    # 在这里执行训练或评估操作
    print(batch_x, batch_y)
iterator = iter(dataloader)

输出

tensor([[-0.8643, -1.6477,  0.8695],
        [-1.4600,  0.2523, -0.4641]]) tensor([[-0.3900],
        [ 0.3526]])
tensor([[ 1.2054, -0.9444, -1.0735],
        [-1.4717,  0.8908,  0.5036]]) tensor([[ 0.8096],
        [-1.0215]])
tensor([[ 1.4094,  0.1649,  0.2448],
        [-0.6039, -1.4968,  0.1234]]) tensor([[ 0.0931],
        [-0.3150]])
tensor([[-0.8754, -0.1743,  0.7225],
        [-2.0970,  0.8257,  0.7893]]) tensor([[-0.8036],
        [ 1.4351]])
tensor([[-0.4395, -0.5905, -1.2884],
        [ 1.4488,  0.2629, -0.1280]]) tensor([[-0.9477],
        [ 0.1115]])
while True:
    try:
        batch_data = next(iterator)
        print(batch_data) 
    except:
        break

输出

[tensor([[-0.8643, -1.6477,  0.8695],
        [-1.4600,  0.2523, -0.4641]]), tensor([[-0.3900],
        [ 0.3526]])]
[tensor([[ 1.2054, -0.9444, -1.0735],
        [-1.4717,  0.8908,  0.5036]]), tensor([[ 0.8096],
        [-1.0215]])]
[tensor([[ 1.4094,  0.1649,  0.2448],
        [-0.6039, -1.4968,  0.1234]]), tensor([[ 0.0931],
        [-0.3150]])]
[tensor([[-0.8754, -0.1743,  0.7225],
        [-2.0970,  0.8257,  0.7893]]), tensor([[-0.8036],
        [ 1.4351]])]
[tensor([[-0.4395, -0.5905, -1.2884],
        [ 1.4488,  0.2629, -0.1280]]), tensor([[-0.9477],
        [ 0.1115]])]
相关文章
|
6月前
|
存储 PyTorch 算法框架/工具
PyTorch 中的 Tensor:属性、数据生成和基本操作
PyTorch 中的 Tensor:属性、数据生成和基本操作
206 0
|
缓存 PyTorch 数据处理
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
981 0
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
|
1月前
|
计算机视觉
数据集学习笔记(三):COCO创建dataloader用于训练
如何使用COCO数据集创建dataloader进行训练,包括安装环境、加载数据集代码、定义数据转换、创建数据集对象以及创建dataloader。
39 5
|
4月前
|
机器学习/深度学习 存储 算法
查询模型的方法knn_model.pkl
【7月更文挑战第28天】
45 3
|
3月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
31 0
|
机器学习/深度学习 Linux PyTorch
Dataset and DataLoader 加载数据集
Dataset and DataLoader 加载数据集
144 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch使用专题 | 2 :Pytorch中数据读取-Dataset、Dataloader 、TensorDataset 和 Sampler 的使用
介绍Pytorch中数据读取-Dataset、Dataloader 、TensorDataset 和 Sampler 的使用
|
PyTorch 算法框架/工具
【PyTorch】自定义数据集处理/dataset/DataLoader等
【PyTorch】自定义数据集处理/dataset/DataLoader等
181 0
|
机器学习/深度学习 存储 PyTorch
怎么调用pytorch中mnist数据集
怎么调用pytorch中mnist数据集
218 0
随机抽样方法——DataFrame.sample()
随机抽样方法——DataFrame.sample()