Pytorch 的 torch.utils.data.DataLoader 参数详解

简介: Pytorch 的 torch.utils.data.DataLoader 参数详解

DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式。


1、dataset:(数据类型 dataset)


输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。


2、batch_size:(数据类型 int)


每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。


3、shuffle:(数据类型 bool)


洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。


4、collate_fn:(数据类型 callable,没见过的类型)


将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)


5、batch_sampler:(数据类型 Sampler)


批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。


6、sampler:(数据类型 Sampler)


采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。


7、num_workers:(数据类型 Int)


工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。


8、pin_memory:(数据类型 bool)


内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。


9、drop_last:(数据类型 bool)


丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。


10、timeout:(数据类型 numeric)


超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。


11、worker_init_fn(数据类型 callable,没见过的类型)


注:

batch_size表示一次训练喂给模型的数据量,不能超GPU的显存,还有如果过大,会导致训练次数下降,导致loss和准确率下降。

num_workers表示CPU加载数据的线程数,但是不是越大越好,每次线程创建会耗费大量时间。

参考连接:https://blog.csdn.net/qq_32998593/article/details/92849585


使用方法:

1.先加载数据

2.dataloader读取数据

3.for循环从loader里面读取数据

train_data = torchvision.datasets.CIFAR10(root="../data", train=True, 
                                              transform=torchvision.transforms.ToTensor(),
                                              download=True)
train_dataloader = DataLoader(train_data, batch_size=8, num_workers=1, pin_memory=True)
for data in train_dataloader:

参考连接:https://blog.csdn.net/qq_36653505/article/details/84728855

相关文章
|
6天前
|
PyTorch 算法框架/工具
ImportError: cannot import name ‘_DataLoaderIter‘ from ‘torch.utils.data.dataloader‘
ImportError: cannot import name ‘_DataLoaderIter‘ from ‘torch.utils.data.dataloader‘
9 2
|
6天前
|
存储 PyTorch 算法框架/工具
torch.Storage()是什么?和torch.Tensor()有什么区别?
torch.Storage()是什么?和torch.Tensor()有什么区别?
9 1
|
6天前
|
PyTorch 算法框架/工具
实战pytorch中utils.data.TensorDataset和utils.data.DataLoader工具
本文主要说明pytorch框架中utils.data.TensorDataset和utils.data.DataLoader两个工具类。
96 0
|
PyTorch 算法框架/工具 索引
详细介绍torch中的from torch.utils.data.sampler相关知识
PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍: Sampler基类: Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。 RandomSampler: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。 SequentialSampler: 顺序采样器,它会按照数据集中的顺序,依次选择样本。 SubsetRandomSampler: 子集随机采样器
513 0
|
存储 测试技术
测试模型时,为什么要with torch.no_grad(),为什么要model.eval(),如何使用with torch.no_grad(),model.eval(),同时使用还是只用其中之一
在测试模型时,我们通常使用with torch.no_grad()和model.eval()这两个方法来确保模型在评估过程中的正确性和效率。
620 0
|
机器学习/深度学习 数据可视化 PyTorch
visdom的用法,详细的介绍torch相关案例
Visdom是一款用于创建交互式可视化的Python库,通常在深度学习中用于监视训练进度和可视化结果。在PyTorch中,可以使用Visdom轻松地创建图形和可视化数据。在这个例子中,我们首先使用 torchvision.utils.make_grid 方法加载了一批随机生成的图像,并使用 visdom.image 方法将其可视化。然后,我们创建了一些随机的二维数据,并使用 visdom.scatter 方法将其可视化为散点图。 需要注意的是,Visdom的Web界面支持实时更新,这意味着在训练模型或执行其他任务时,可以使用Visdom实时监视进度和结果。
194 0
|
PyTorch 算法框架/工具
pytorch中torch.where()使用方法
pytorch中torch.where()使用方法
604 0
|
网络虚拟化
在torch_geometric.datasets中使用Planetoid手动导入Core数据集及发生相关错误解决方案
在torch_geometric.datasets中使用Planetoid手动导入Core数据集及发生相关错误解决方案
621 0
在torch_geometric.datasets中使用Planetoid手动导入Core数据集及发生相关错误解决方案