Pytorch 的 torch.utils.data.DataLoader 参数详解

简介: Pytorch 的 torch.utils.data.DataLoader 参数详解

DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式。


1、dataset:(数据类型 dataset)


输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。


2、batch_size:(数据类型 int)


每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。


3、shuffle:(数据类型 bool)


洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。


4、collate_fn:(数据类型 callable,没见过的类型)


将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)


5、batch_sampler:(数据类型 Sampler)


批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。


6、sampler:(数据类型 Sampler)


采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。


7、num_workers:(数据类型 Int)


工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。


8、pin_memory:(数据类型 bool)


内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。


9、drop_last:(数据类型 bool)


丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。


10、timeout:(数据类型 numeric)


超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。


11、worker_init_fn(数据类型 callable,没见过的类型)


注:

batch_size表示一次训练喂给模型的数据量,不能超GPU的显存,还有如果过大,会导致训练次数下降,导致loss和准确率下降。

num_workers表示CPU加载数据的线程数,但是不是越大越好,每次线程创建会耗费大量时间。

参考连接:https://blog.csdn.net/qq_32998593/article/details/92849585


使用方法:

1.先加载数据

2.dataloader读取数据

3.for循环从loader里面读取数据

train_data = torchvision.datasets.CIFAR10(root="../data", train=True, 
                                              transform=torchvision.transforms.ToTensor(),
                                              download=True)
train_dataloader = DataLoader(train_data, batch_size=8, num_workers=1, pin_memory=True)
for data in train_dataloader:

参考连接:https://blog.csdn.net/qq_36653505/article/details/84728855

相关文章
|
5月前
|
数据采集 存储 缓存
【Python-Tensorflow】tf.data.Dataset的解析与使用
本文详细介绍了TensorFlow中`tf.data.Dataset`类的使用,包括创建数据集的方法(如`from_generator()`、`from_tensor_slices()`、`from_tensors()`)、数据集函数(如`apply()`、`as_numpy_iterator()`、`batch()`、`cache()`等),以及如何通过这些函数进行高效的数据预处理和操作。
99 7
|
5月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】tf.keras.backend.image_data_format()的解析与举例使用
介绍了TensorFlow和Keras中tf.keras.backend.image_data_format()函数的用法。
58 5
|
5月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+keras】解决cuDNN launch failure : input shape ([32,2,8,8]) [[{{node sequential_1/batch_nor
在使用TensorFlow 2.0和Keras训练生成对抗网络(GAN)时,遇到了“cuDNN launch failure”错误,特别是在调用self.generator.predict方法时出现,输入形状为([32,2,8,8])。此问题可能源于输入数据形状与模型期望的形状不匹配或cuDNN版本不兼容。解决方案包括设置GPU内存增长、检查模型定义和输入数据形状、以及确保TensorFlow和cuDNN版本兼容。
58 1
|
5月前
|
TensorFlow 算法框架/工具 Python
【Tensorflow 2】解决'Tensor' object has no attribute 'numpy'
解决'Tensor' object has no attribute 'numpy'
98 3
|
5月前
|
TensorFlow API 算法框架/工具
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
71 0
|
8月前
|
PyTorch 算法框架/工具
ImportError: cannot import name ‘_DataLoaderIter‘ from ‘torch.utils.data.dataloader‘
ImportError: cannot import name ‘_DataLoaderIter‘ from ‘torch.utils.data.dataloader‘
112 2
|
API 数据格式
TensorFlow2._:model.summary() Output Shape为multiple解决方法
TensorFlow2._:model.summary() Output Shape为multiple解决方法
292 0
TensorFlow2._:model.summary() Output Shape为multiple解决方法
|
8月前
|
PyTorch 算法框架/工具
实战pytorch中utils.data.TensorDataset和utils.data.DataLoader工具
本文主要说明pytorch框架中utils.data.TensorDataset和utils.data.DataLoader两个工具类。
199 0
|
网络虚拟化
在torch_geometric.datasets中使用Planetoid手动导入Core数据集及发生相关错误解决方案
在torch_geometric.datasets中使用Planetoid手动导入Core数据集及发生相关错误解决方案
815 1
在torch_geometric.datasets中使用Planetoid手动导入Core数据集及发生相关错误解决方案
|
TensorFlow 算法框架/工具
成功解决AttributeError: module 'tensorflow.python.keras' has no attribute 'Model'
成功解决AttributeError: module 'tensorflow.python.keras' has no attribute 'Model'