对于新手入门,我们最常见的就是猫狗分类数据集,但是对于已经在本地的图像文件,我们一改如何加载进来呢?
这里pytorch中给出了ImageFolder函数,它可以将指定路径下的所有图像读取进来,对于ImageFolder使用方法,我们需要将所有图像按照文件夹保存,例如所有猫的图像放入到cat文件夹中,所有狗的图像放入到dog文件夹中,该函数就会自动识别类别,将图像所在的目录名称作为label。
torchvision.datasets.ImageFolder()
参数列表:
- root:图像文件读取路径
- transform:对图像数据采取的数据增强策略
- target_transform:对label进行转换
- loader:指定加载图像的函数
- is_valid_file:获取图像路径,检查文件的有效性
下面给出代码示例,例如我们图像有两个类别:ants和bees,我们需要将对应的图像放入到相应的文件夹中即可。
data_path = r'./data/train' dataset = torchvision.datasets.ImageFolder(data_path)
返回对应图像
返回对应label
如果需要对读取的图像进行处理,只需要将transform操作作为参数传入即可。