在使用torch.utils.data.DataLoader
构造数据集并进行迭代时,可以通过调用iter()
函数将DataLoader
对象转换为一个可迭代对象,然后再通过调用next()
函数取出下一个batch的数据。
例如,假设我们已经定义好了一个MyDataset
类来表示自己的数据集,并且使用DataLoader
来对其进行批量处理:
import torch.utils.data as data class MyDataset(data.Dataset): def __init__(self, ...): # 构造函数,初始化数据集 ... def __getitem__(self, index): # 根据索引获取样本和标签 ... def __len__(self): # 获取数据集大小 ... dataset = MyDataset(...) dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)
现在,我们可以通过调用iter()
函数将dataloader
对象转换为一个可迭代对象:
iterator = iter(dataloader)
然后,我们就可以通过调用next()
函数来逐个获取数据集中的下一个batch数据:
batch_data = next(iterator)
这里的batch_data
就是一个包含32个样本及其对应标签的张量,可以直接传入模型进行训练或者预测。
下面给出一个案例
import torch from torch.utils.data import DataLoader, TensorDataset x_data = torch.randn(10, 3) y_data = torch.randn(10, 1) # for i, j in zip(x_data ,y_data): # print(i, j) # import random # 创建TensorDataset对象 dataset = TensorDataset(x_data, y_data) # 创建DataLoader对象,并指定batch_size和是否要进行打乱 dataloader = DataLoader(dataset, batch_size=2, shuffle=False) # 遍历每个小批量数据 for batch_x, batch_y in dataloader: # 在这里执行训练或评估操作 print(batch_x, batch_y) iterator = iter(dataloader)
输出
tensor([[-0.8643, -1.6477, 0.8695], [-1.4600, 0.2523, -0.4641]]) tensor([[-0.3900], [ 0.3526]]) tensor([[ 1.2054, -0.9444, -1.0735], [-1.4717, 0.8908, 0.5036]]) tensor([[ 0.8096], [-1.0215]]) tensor([[ 1.4094, 0.1649, 0.2448], [-0.6039, -1.4968, 0.1234]]) tensor([[ 0.0931], [-0.3150]]) tensor([[-0.8754, -0.1743, 0.7225], [-2.0970, 0.8257, 0.7893]]) tensor([[-0.8036], [ 1.4351]]) tensor([[-0.4395, -0.5905, -1.2884], [ 1.4488, 0.2629, -0.1280]]) tensor([[-0.9477], [ 0.1115]])
while True: try: batch_data = next(iterator) print(batch_data) except: break
输出
[tensor([[-0.8643, -1.6477, 0.8695], [-1.4600, 0.2523, -0.4641]]), tensor([[-0.3900], [ 0.3526]])] [tensor([[ 1.2054, -0.9444, -1.0735], [-1.4717, 0.8908, 0.5036]]), tensor([[ 0.8096], [-1.0215]])] [tensor([[ 1.4094, 0.1649, 0.2448], [-0.6039, -1.4968, 0.1234]]), tensor([[ 0.0931], [-0.3150]])] [tensor([[-0.8754, -0.1743, 0.7225], [-2.0970, 0.8257, 0.7893]]), tensor([[-0.8036], [ 1.4351]])] [tensor([[-0.4395, -0.5905, -1.2884], [ 1.4488, 0.2629, -0.1280]]), tensor([[-0.9477], [ 0.1115]])]