【单点知识】基于实例讲解PyTorch中的ImageFolder类

简介: 【单点知识】基于实例讲解PyTorch中的ImageFolder类

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

torchvision.datasets.ImageFolderPyTorch 中用于加载图像分类数据集的一个实用类。它特别适用于图像分类任务(可以说是图像分类任务离不开ImageFolder),因为它能够自动将文件夹结构映射到类别标签上


本文将基于实例详细介绍ImageFolder类。


1. ImageFolder功能

  1. ImageFolder 自动遍历指定目录下的所有子文件夹,并将每个子文件夹视为一个不同的类别或标签。(最重要的功能)
  2. 它允许用户轻松地加载和迭代训练/验证集中的图像数据,同时自动提供对应的类别标签
  3. 可以方便地与 PyTorch 数据加载器(DataLoader)结合使用,实现批量化、并行化数据加载,这对于深度学习模型训练非常关键。


2 基本使用方法及参数解析

2.1 基本调用方式
from torchvision.datasets import ImageFolder
from torchvision import transforms
import PIL

root = ...
transforms = transforms.Compose([...])
target_transform = transforms.Compose([...])

folder = ImageFolder(root=root, transform=transforms,
                     target_transform= target_transform, loader=PIL.Image.open)


2.2 构造参数说明
  • root(必需):字符串类型,表示数据集所在的根目录。
  • transform(可选,几乎是必须):一个 transforms 对象,用于对读取的图像进行一系列预处理操作,如调整大小、转为 tensor、标准化等。
  • target_transform(可选):对目标标签进行转换的函数或变换对象。
  • loader(可选,从没用过):用于加载图像文件的函数。
2.3 属性
  • classes:一个包含所有类别名称的列表,按照文件夹名字排序。
  • class_to_idx:一个字典,key为类别名称,value为对应的整数索引。
  • imgs:一个元组,形为(image path, class_index)
2.4 方法
  • __getitem__(index):通过索引获取单个数据项,返回一个元组 (image, target),其中 image 是经过 transform 处理后的图像 tensor,target 是对应的类别索引。


利用 ImageFolderDataLoader 结合的方式,可以高效地准备和访问数据,从而大大简化了图像分类任务中数据预处理和加载的工作流程。


3. PyTorch实例说明

3.1 实例数据集

使用之前文章用过的hymenoptera数据集的简化版,其文件结构如下:

hymenoptera/hymenoptera_dataset/
├── train  #训练组
│   ├── ants  #4张图像
│   └── bees  #5张图像
└── val   #验证组
    ├── ants  #6张图像
    └── bees  #7张图像
3.2 实例说明
  • ImageFolder的属性实例说明:
from torchvision.datasets import ImageFolder

image_path1 = '.\hymenoptera'
image_path2 = '.\hymenoptera\hymenoptera_data'
image_path3 = '.\hymenoptera\hymenoptera_data\\train'

folder1 = ImageFolder(root=image_path1)
print(folder1.classes, '|',folder1.class_to_idx,'|', folder1.imgs)
folder2 = ImageFolder(root=image_path2)
print(folder2.classes, '|',folder2.class_to_idx,'|', folder2.imgs)
folder3 = ImageFolder(root=image_path3)
print(folder3.classes, '|',folder3.class_to_idx,'|', folder3.imgs)

输出为:

['hymenoptera_data'] | {'hymenoptera_data': 0} | [('.\\hymenoptera\\hymenoptera_data\\train\\ants\\0013035.jpg', 0), ...)]
['train', 'val'] | {'train': 0, 'val': 1} | [('.\\hymenoptera\\hymenoptera_data\\train\\ants\\0013035.jpg', 0), ...)]
['ants', 'bees'] | {'ants': 0, 'bees': 1} | [('.\\hymenoptera\\hymenoptera_data\\train\\ants\\0013035.jpg', 0), ...)]


这里可以看出,虽然是同一个数据集,但是root选择的文件夹层级不同会有不同的classes和自动编号的classes_id。但是无论选择哪个文件夹层级,imgs都能自动遍历所选文件路径下的所有图像。(从下面的方法说明也能看出)


  • ImageFolder的方法实例说明:
from torchvision.datasets import ImageFolder

image_path1 = '.\hymenoptera'
image_path2 = '.\hymenoptera\hymenoptera_data'
image_path3 = '.\hymenoptera\hymenoptera_data\\train'

folder1 = ImageFolder(root=image_path1)
folder2 = ImageFolder(root=image_path2)
folder3 = ImageFolder(root=image_path3)

print(len(folder1),len(folder2),len(folder3)) #输出:22 22 9

for img,id in folder3:
    print(img,'|',id)
'''
输出:
<PIL.Image.Image image mode=RGB size=768x512 at 0x20AA7E5D550> | 0
<PIL.Image.Image image mode=RGB size=500x375 at 0x20AA7E5D7D0> | 0
<PIL.Image.Image image mode=RGB size=500x369 at 0x20AA7E5ECD0> | 0
<PIL.Image.Image image mode=RGB size=500x181 at 0x20AA7E5D850> | 0
<PIL.Image.Image image mode=RGB size=500x450 at 0x20AA7E5D7D0> | 1
<PIL.Image.Image image mode=RGB size=500x412 at 0x20AA7E5ECD0> | 1
<PIL.Image.Image image mode=RGB size=500x334 at 0x20AA7E5D850> | 1
<PIL.Image.Image image mode=RGB size=500x400 at 0x20AA7E5D7D0> | 1
<PIL.Image.Image image mode=RGB size=500x173 at 0x20AA7E5ECD0> | 1
'''

print(folder3[0])   #输出:(<PIL.Image.Image image mode=RGB size=768x512 at 0x20AA7E5D850>, 0)

4. 总结

torchvision.datasets.ImageFolder 是 PyTorch 中用于加载图像数据集的一个重要类,它特别适用于组织结构清晰的图片分类任务。ImageFolder 假设数据集按照以下结构存储:

  • 根目录(root)
  • 类别1
  • 图像1.jpg
  • 图像2.jpg
  • 类别2
  • 图像1.jpg
  • 图像3.png

每个子文件夹代表一个类别,并且该文件夹的下的图像都会被自动编号。ImageFolder 会遍历这些子文件夹中的所有图像文件,并将它们与其对应的类别标签配对。


使用 ImageFolder 时,其返回的对象是一个可迭代的数据加载器,每次迭代都会产出一个 (image, label) 元组,其中 image 可以是 PIL Image 或经过 transform 转换后的张量形式的图像,而 label 则是从对应文件夹名称映射得到的整数类别标签。


通过结合 PyTorch 的 DataLoader,可以方便地实现图像数据的批量化加载,非常适合于训练深度学习模型。


相关文章
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【单点知识】基于实例详解PyTorch中的DataLoader类
【单点知识】基于实例详解PyTorch中的DataLoader类
701 2
|
8月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
261 0
|
8月前
|
PyTorch 算法框架/工具
Pytorch中最大池化层Maxpool的作用说明及实例使用(附代码)
Pytorch中最大池化层Maxpool的作用说明及实例使用(附代码)
782 0
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
712 0
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch通过实例详细剖析CNN
基于Pytorch通过实例详细剖析CNN
88 1
基于Pytorch通过实例详细剖析CNN
|
8月前
|
机器学习/深度学习 算法 大数据
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
120 3
|
8月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch用GAN生成手写数字实例(附代码)
基于Pytorch用GAN生成手写数字实例(附代码)
208 0
|
8月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch的机器学习Regression问题实例(附源码)
基于Pytorch的机器学习Regression问题实例(附源码)
100 1
|
8月前
|
机器学习/深度学习 自然语言处理 算法
PyTorch实例:简单线性回归的训练和反向传播解析
PyTorch实例:简单线性回归的训练和反向传播解析
PyTorch实例:简单线性回归的训练和反向传播解析
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【单点知识】基于实例讲解PyTorch中的transforms类
【单点知识】基于实例讲解PyTorch中的transforms类
110 0