PyTorch 深度学习实战 | 基于 ResNet 的花卉图片分类

简介: 本期将提供一个利用深度学习进行花卉图片分类的案例,并使用迁移学习的方法解决训练数据较少的问题。图片分类是根据图像的语义信息对不同的图片进行区分,是计算机视觉中的基本问题,也是图像检测、图像分割、物体跟踪等高阶视觉任务的基础。在深度学习领域,图片分类的任务一般基于卷积神经网络来完成,如常见的卷积神经网络有 VGG、GoogleNet、ResNet 等。而在图像分类领域,数据标记是最基础和烦琐的工作。有时由于条件限制,往往得不到很多经过标记的、用于训练的图片,其中一个解决办法就是对已经预训练好的模型进行迁移学习。本文是以 ResNet 为基础,对花卉图片进行迁移学习,从而完成对花卉图片的分类任

“工欲善其事,必先利其器”。如果直接使用 Python 完成模型的构建、导出等工作,势必会耗费相当多的时间,而且大部分工作都是深度学习中共同拥有的部分,即重复工作。所以本案例为了快速实现效果,就直接使用将这些共有部分整理成框架的 TensorFlow 和 Keras 来完成开发工作。TensorFlow 是 Google 公司开源的基于数据流图的科学计算库,适合用于机器学习、深度学习等人工智能领域。Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow、CNTK 或 Theano 作为后端运行。Keras 的开发重点是支持快速的实验,所以,本案例中,大部分与模型有关的工作都是基于 Keras API 来完成的。而现在版本的 TensorFlow 已经将 Keras 集成了进来,所以只需要安装 TensorFlow 即可。注意,由于本案例采用的 ResNet 网络较深,所以模型训练需要消耗的资源较多,需要 GPU 来加速训练过程。

1、环境安装

安装 TensorFlow 的 GPU 版本是相对比较繁杂的事情,需要找对应的驱动,安装合适版本的 CUDA 和 cuDNN。而一种比较方便的办法就是使用 Anaconda 来进行 tensorflow-gpu 的安装。具体的安装过程可以参考本书的附录 A.2 部分。其他需要安装的依赖包的名称及版本号如下:

其他依赖包可以在 Anaconda 界面上进行选择安装,也可以将其添加到 requirements.txt 文件,然后使用 conda install -yes -file requirements.txt 命令进行安装。另外,Conda 可以创建不同的环境来支持不同的开发要求。例如,有些工程需要 TensorFlow 1.15.0 环境来进行开发,而另外一些工程需要 TensorFlow 2.1.0 来进行开发,替换整个工作环境或者重新安装 TensorFlow 都不是很好的选择。所以,本案例使用 Conda 创建虚拟环境来解决。

2、数据集简介

在进行模型构建和训练之前,需要进行数据收集。为了简化收集工作,本案例采用已标记好的花卉数据集 Oxford 102 Flowers。数据集可以从 VGG 官方网站上进行下载。单击如图 1 所示的 Downloads 区域的 1、4 和 5 对应的超链接就可以下载所需要的文件。

image.png


■ 图 1 Oxford 102 Flowers 数据集下载网站

该数据集由牛津大学工程科学系于 2008 年发布,是一个英国本土常见花卉的图片数据集,包含 102 个类别,每类包含 40 ~ 258 张图片。在基于深度学习的图像分类任务中,这样较为少量的图片还是比较有挑战性的。Oxford 102 Flowers 的分类细节和部分类别的图片及对应的数量如图 2 所示。

image.png


■ 图 2 Oxford 102 Flowers 的分类细节和部分类别的图片及对应的数量

除了图片文件(dataset images),数据集中还包含图片分割标记文件(image segmentations)、分类标记文件(the image iabels)和数据集划分文件(the data splits)。由于本案例中不涉及图片分割,所以使用的是图片、分类标记和数据集划分文件。

3、数据集的下载与处理

Python urllib 库提供了 urlretrieve()函数可以直接将远程数据下载到本地。可以使用 urlretrieve()函数下载所需文件;然后把压缩的图片文件进行解压,并解析分类标记文件和数据集划分文件;再根据数据集划分文件并分成训练集、验证集和测试集;最后,向不同类别的数据集中按图片所标识的花的种类分类存放图片文件。代码及详细注释如代码清单 1 所示。

代码清单 1

import os
from urllib.request import urlretrieve
import tarfile
from scipy. io import loadmat2
from shutil import copyfile
import glob
import numpy as np

"""
函数说明:按照分类(labels)复制未分组的图片到指定的位置10
Parameters:
    data path - 数据存放目录
    labels - 数据对应的标签,需要按标签放到不同的目录
"""

def copy_data_files(data path, labels) :
if not os. path, exists( data path) :
  os.mkdir(data path)
  
  # 创建分类目录
for i in range(0,102) :
os.mkdir(os.path.join( data path, str(i)))

for label in labels:
src path = str(label[0])
dst path = os.path. join(data path, label[1], src path. split(os. sep)[ - 1])
copyfile(src path, dst path)

if_name_ _== '_main_':
  # 检查本地数据集目录是否存在,若不存在,则需创建 
  data set path = "./data'
  if not os. path. exists( data set path) :
    os.mkdir(data set path)
    
#下载 102 Category Elower 数据集并解压 
flowers archive file = "102flowers.tgz'
flowers_url frefix = "https://www,robots.ox.ac.uk/~vgg/data/flowers/102/'
flowers archive path = os.path, join(data set path, flowers archive file)
if not os path.exists(flowers archive path) :
print("正在下载图片文件...")
urlretrieve(flowers url frefix + flowers archive file, flowers archive path)
print("图片文件下载完成.")
print("正在解压图片文件...")
tarfile. open(flowers archive path)..extractall(path = data set_path)
print("图片文件解压完成,")

# 下载标识文件,标识不同文件的类别
flowers labels file = "imagelabels.mat'
flowers labels path = os.path. join(data set path, flowers labels file)
   if not os.path.exists(flowers labels path) :
    print("正在下载标识文件...")
urlretrieve(flowers url frefix + flowers labels file, flowers labels path)
print("标识文件下载完成")
flower_labels = loadmat(flowers_labels_path)['labels'][0] - 1

#下载数据集分类文件,包含训练集、验证集和测试集
sets splits file = "setid.mat"
sets splits_path = os.path. join(data set path, sets splits file)
if not os.path,exists( sets splits path) :
print("正在下载数据集分类文件...")
urlretrieve(flowers url frefix + sets splits file, sets splits path)
print("数据集分类文件下载完成")
sets_splits = loadmat( sets splits path)

# 由于数据集分类文件中测试集数量比训练集多,所以进行了对调
train set = sets splits['tstid'][0] - 1
valid set = sets splits[ 'valid'][0] - 1
test_set = sets splits['trnid'][0] - 1

# 获取图片文件名并找到图片对应的分类标识
image files = sorted(glob.glob(os.path. join(data set path, 'jpg', ' x .jpg')))
image labels = np.array([i for i in zip(image files, flower labels)])

# 将训练集、验证集和测试集分别放在不同的目录下
print("正在进行训练集的复制...")
copy_data files(os.path. join(data set path, 'train'), image labels[train set, :]
  print("已完成训练集的复制,开始复制验证集...")
copy_data files(os.path. join(data_set_path, 'valid'), image labels[valid set, :]
  print("已完成验证集的复制,开始复制测试集...")
copy_data files(os.path, join(data set_path, 'test'), image labels[test set, :] 
  print("已完成测试集的复制,所有的图片下载和预处理工作已完成.")

下载的图片数据有 330MB 左右。国外的网站有时候下载比较慢,可以用下载工具下载,或者使用参考书前言中提供的二维码进行下载。

需要说明的是,分类标记文件 imagelabels.mat 和数据集划分文件 setid.mat 是 MATLAB 的数据存储的标准格式,可以用 MATLAB 程序打开进行查看。本案例中使用 scipy 库的 loadmat()函数对 .mat 文件进行读取。图片分类后的目录结构如图 3 所示。

image.png


■ 图 3 图片分类后的目录结构

目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch深度学习 ? 带你从入门到精通!!!
🌟 蒋星熠Jaxonic,深度学习探索者。三年深耕PyTorch,从基础到部署,分享模型构建、GPU加速、TorchScript优化及PyTorch 2.0新特性,助力AI开发者高效进阶。
PyTorch深度学习 ? 带你从入门到精通!!!
|
3月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow与PyTorch深度对比分析:从基础原理到实战选择的完整指南
蒋星熠Jaxonic,深度学习探索者。本文深度对比TensorFlow与PyTorch架构、性能、生态及应用场景,剖析技术选型关键,助力开发者在二进制星河中驾驭AI未来。
742 13
|
4月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
278 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
3月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
5月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
349 9
|
7月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
330 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
11月前
|
机器学习/深度学习 数据可视化 算法
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
572 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
|
10月前
|
机器学习/深度学习 自然语言处理 算法
PyTorch PINN实战:用深度学习求解微分方程
物理信息神经网络(PINN)是一种将深度学习与物理定律结合的创新方法,特别适用于微分方程求解。传统神经网络依赖大规模标记数据,而PINN通过将微分方程约束嵌入损失函数,显著提高数据效率。它能在流体动力学、量子力学等领域实现高效建模,弥补了传统数值方法在高维复杂问题上的不足。尽管计算成本较高且对超参数敏感,PINN仍展现出强大的泛化能力和鲁棒性,为科学计算提供了新路径。文章详细介绍了PINN的工作原理、技术优势及局限性,并通过Python代码演示了其在微分方程求解中的应用,验证了其与解析解的高度一致性。
3719 5
PyTorch PINN实战:用深度学习求解微分方程
|
12月前
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 ResNet50图像分类
本实验基于PyTorch,在昇腾平台上使用ResNet50对CIFAR10数据集进行图像分类训练。内容涵盖ResNet50的网络架构、残差模块分析及训练代码详解。通过端到端的实战讲解,帮助读者理解如何在深度学习中应用ResNet50模型,并实现高效的图像分类任务。实验包括数据预处理、模型搭建、训练与测试等环节,旨在提升模型的准确率和训练效率。
609 54
|
11月前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。
1004 0

热门文章

最新文章

推荐镜像

更多