4.3 图像分类ResNet实战:眼疾识别

简介: 这篇文章介绍了使用ResNet网络进行眼疾识别的实战流程,涵盖了计算机视觉任务研发的全流程,包括数据处理、数据预处理、数据读取器的定义,以及如何利用iChallenge-PM数据集进行模型训练和评估。

4.3 图像分类ResNet实战:眼疾识别

基本的计算机视觉任务研发全流程如图1所示:

图1:基本的计算机视觉任务研发全流程

其中,基本的计算机视觉任务研发全流程包含模型训练、模型预测和模型部署三大步骤。每个步骤又包含单独的流程:

  • 数据准备:根据网络接收的数据格式,完成相应的预处理和跑批量数据读取器操作,保证模型正常读取;
  • 模型构建:设计卷积网络结构;
  • 特征提取:使用构建的模型提取数据的特征信息;
  • 损失函数:通过损失函数衡量模型的预测值和真实值的不一致程度,通常损失函数越小,模型性能越好;
  • 模型评估:在模型训练中或训练结束后,对模型进行评估测试,观察准确率;
  • 模型预测:使用训练好的模型进行测试,也需要准备数据和模型特征提取,最后对结果进行解析。

4.3.1 数据处理

4.3.1.1 数据集介绍

如今近视已经成为困扰人们健康的一项全球性负担,在近视人群中,有超过35%的人患有重度近视。近视会拉长眼睛的光轴,也可能引起视网膜或者络网膜的病变。随着近视度数的不断加深,高度近视有可能引发病理性病变,这将会导致以下几种症状:视网膜或者络网膜发生退化、视盘区域萎缩、漆裂样纹损害、Fuchs斑等。因此,及早发现近视患者眼睛的病变并采取治疗,显得非常重要。

iChallenge-PM是百度大脑和中山大学中山眼科中心联合举办的iChallenge比赛中,提供的关于病理性近视(Pathologic Myopia,PM)的医疗类数据集,包含1200个受试者的眼底视网膜图片,训练、验证和测试数据集各400张。iChallenge-PM分为2类:病理性近视和非病理性近视(包含高度近视和正常眼睛),数据集命名规则如下:

  • 病理性近视(PM):文件名以P开头
  • 非病理性近视(non-PM):
  • 高度近视(high myopia):文件名以H开头
  • 正常眼睛(normal):文件名以N开头

有监督图像分类任务的数据包含图片和对应的类别标签(如猫、狗等),我们将病理性患者的图片作为正样本,标签为1; 非病理性患者的图片作为负样本,标签为0。iChallenge-PM数据集的示意如图2所示。

图2:iChallenge-PM数据集示例

4.3.1.2 数据集下载

AIStudio项目包含了iChallenge-PM数据集,位于/home/aistudio/data/data19065目录,包括如下三个文件:

  • training.zip:训练中的图片和标签
  • validation.zip:验证集的图片
  • valid_gt.zip:验证集的标签

通过如下命令解压数据集(执行一次即可),解压缩后存放在/home/aistudio/work/palm目录下:


注意

valid_gt.zip文件解压缩之后,需要将“/home/aistudio/work/palm/PALM-Validation-GT/”目录下的“PM_Label_and_Fovea_Location.xlsx”文件转存成.csv格式,本节代码示例中已经提前转成文件labels.csv。


# 如果已经解压过,不需要运行此段代码,否则由于文件已经存在,解压时会报错
!unzip -o -q -d /home/aistudio/work/palm /home/aistudio/data/data19065/training.zip
%cd /home/aistudio/work/palm/PALM-Training400/
!unzip -o -q PALM-Training400.zip
!unzip -o -q -d /home/aistudio/work/palm /home/aistudio/data/data19065/validation.zip
!unzip -o -q -d /home/aistudio/work/palm /home/aistudio/data/data19065/valid_gt.zip
#返回家目录,生成模型文件位于/home/aistudio/
%cd /home/aistudio/

/home/aistudio/work/palm/PALM-Training400

/home/aistudio

从数据集中选取两张图片,并将图片显示出来。代码如下所示:

import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
# 文件名以N开头的是正常眼底图片,以P开头的是病变眼底图片
file1 = 'N0012.jpg'
file2 = 'P0095.jpg'
# 读取图片
img1 = Image.open(os.path.join(DATADIR, file1))
img1 = np.array(img1)
img2 = Image.open(os.path.join(DATADIR, file2))
img2 = np.array(img2)
# 画出读取的图片
plt.figure(figsize=(16, 8))
f = plt.subplot(121)
f.set_title('Normal', fontsize=20)
plt.imshow(img1)
f = plt.subplot(122)
f.set_title('PM', fontsize=20)
plt.imshow(img2)
plt.show()

In [ ]

# 查看图片形状
img1.shape, img2.shape

((2056, 2124, 3), (2056, 2124, 3))

4.3.1.3 数据预处理

图像分类网络对输入图片的格式、大小有一定的要求,数据灌入模型前,需要对数据进行预处理操作,使图片满足网络训练以及预测的需要。本实验主要应用了如下方法:

  • 调整图片大小:将每张图缩放到 224 × 224大小 ,统一大小训练速度更快
  • 归一化:将像素值调整到[-1, 1] 之间 ,效果更好

代码如下所示:

import cv2
import numpy as np
# 对读入的图像数据进行预处理
def transform_img(img):
    # 将图片尺寸缩放道 224x224
    img = cv2.resize(img, (224, 224))
    # 读入的图像数据格式是[H, W, C]
    # 使用转置操作将其变成[C, H, W]
    img = np.transpose(img, (2,0,1))
    img = img.astype('float32')
    # 将数据范围调整到[-1.0, 1.0]之间
    img = img / 255.
    img = img * 2.0 - 1.0
    return img

4.1.1.4 定义数据读取器

上面的代码仅展示了读取一张图片和预处理的方法,但在真实场景的模型训练与评估过程中,通常会使用批量数据读取和预处理的方式。

In [3]

import cv2
import random
import numpy as np
import os
# 定义训练集数据读取器
def data_loader(datadir, batch_size=10, mode = 'train'):
    # 将datadir目录下的文件列出来,每条文件都要读入
    filenames = os.listdir(datadir)
    def reader():
        if mode == 'train':
            # 训练时随机打乱数据顺序
            random.shuffle(filenames)
        batch_imgs = []
        batch_labels = []
        for name in filenames:
            filepath = os.path.join(datadir, name)
            img = cv2.imread(filepath)
            img = transform_img(img)
            if name[0] == 'H' or name[0] == 'N':
                # H开头的文件名表示高度近似,N开头的文件名表示正常视力
                # 高度近视和正常视力的样本,都不是病理性的,属于负样本,标签为0
                label = 0
            elif name[0] == 'P':
                # P开头的是病理性近视,属于正样本,标签为1
                label = 1
            else:
                raise('Not excepted file name')
            # 每读取一个样本的数据,就将其放入数据列表中
            batch_imgs.append(img)
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                # 当数据列表的长度等于batch_size的时候,
                # 把这些数据当作一个mini-batch,并作为数据生成器的一个输出
                imgs_array = np.array(batch_imgs).astype('float32')
                labels_array = np.array(batch_labels).reshape(-1, 1)
                yield imgs_array, labels_array
                batch_imgs = []
                batch_labels = []
        if len(batch_imgs) > 0:
            # 剩余样本数目不足一个batch_size的数据,一起打包成一个mini-batch
            imgs_array = np.array(batch_imgs).astype('float32')
            labels_array = np.array(batch_labels).reshape(-1, 1)
            yield imgs_array, labels_array
    return reader
# 定义验证集数据读取器
def valid_data_loader(datadir, csvfile, batch_size=10, mode='valid'):
    # 训练集读取时通过文件名来确定样本标签,验证集则通过csvfile来读取每个图片对应的标签
    # 请查看解压后的验证集标签数据,观察csvfile文件里面所包含的内容
    # csvfile文件所包含的内容格式如下,每一行代表一个样本,
    # 其中第一列是图片id,第二列是文件名,第三列是图片标签,
    # 第四列和第五列是Fovea的坐标,与分类任务无关
    # ID,imgName,Label,Fovea_X,Fovea_Y
    # 1,V0001.jpg,0,1157.74,1019.87
    # 2,V0002.jpg,1,1285.82,1080.47
    # 打开包含验证集标签的csvfile,并读入其中的内容
    filelists = open(csvfile).readlines()
    def reader():
        batch_imgs = []
        batch_labels = []
        for line in filelists[1:]:
            line = line.strip().split(',')
            name = line[1]
            label = int(line[2])
            # 根据图片文件名加载图片,并对图像数据作预处理
            filepath = os.path.join(datadir, name)
            img = cv2.imread(filepath)
            img = transform_img(img)
            # 每读取一个样本的数据,就将其放入数据列表中
            batch_imgs.append(img)
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                # 当数据列表的长度等于batch_size的时候,
                # 把这些数据当作一个mini-batch,并作为数据生成器的一个输出
                imgs_array = np.array(batch_imgs).astype('float32')
                labels_array = np.array(batch_labels).reshape(-1, 1)
                yield imgs_array, labels_array
                batch_imgs = []
                batch_labels = []
        if len(batch_imgs) > 0:
            # 剩余样本数目不足一个batch_size的数据,一起打包成一个mini-batch
            imgs_array = np.array(batch_imgs).astype('float32')
            labels_array = np.array(batch_labels).reshape(-1, 1)
            yield imgs_array, labels_array
    return reader

In [6]

import paddle
paddle.seed(100)
# 查看数据形状
DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
train_loader = data_loader(DATADIR, 
                           batch_size=10, mode='train')
data_reader = train_loader()
data = next(data_reader)
data[0].shape, data[1].shape
eval_loader = data_loader(DATADIR, 
                           batch_size=10, mode='eval')
data_reader = eval_loader()
data = next(data_reader)
data[0].shape, data[1].shape

((10, 3, 224, 224), (10, 1))

相关文章
|
1月前
|
机器学习/深度学习 数据采集 人工智能
深度学习的魔法:用神经网络识别手写数字
本文将引导读者了解如何使用深度学习技术,特别是卷积神经网络(CNN)来识别手写数字。我们将从基础理论出发,逐步深入到实际操作,包括数据的预处理、模型的构建和训练,以及结果的评估。通过本文,读者不仅能掌握使用深度学习进行图像识别的技能,还能理解其背后的原理。让我们一同揭开深度学习的神秘面纱,探索其在图像处理领域的无限可能。
|
1月前
|
机器学习/深度学习 传感器 监控
深度学习之动作识别与分类
基于深度学习的动作识别与分类是指通过深度学习模型从视频或传感器数据中自动识别和分类人类动作的过程。这项技术广泛应用于视频监控、安全监控、体育分析、医疗康复、虚拟现实(VR)和增强现实(AR)等领域。
64 1
|
3月前
|
API 异构计算
4.3.2 图像分类ResNet实战:眼疾识别——模型构建
这篇文章介绍了如何使用飞桨框架中的ResNet50模型进行眼疾识别的实战,通过5个epoch的训练,在验证集上达到了约96%的准确率,并提供了模型构建、训练、评估和预测的详细代码实现。
|
5月前
|
机器学习/深度学习 存储 自动驾驶
基于YOLOv8深度学习的45种交通标志智能检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战
基于YOLOv8深度学习的45种交通标志智能检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战
|
6月前
|
机器学习/深度学习 并行计算 算法
【计算机视觉+CNN】keras+ResNet残差网络实现图像识别分类实战(附源码和数据集 超详细)
【计算机视觉+CNN】keras+ResNet残差网络实现图像识别分类实战(附源码和数据集 超详细)
182 0
|
机器学习/深度学习 数据采集 缓存
【深度学习实践(七)】卷积神经网络(Xception)之动物识别
【深度学习实践(七)】卷积神经网络(Xception)之动物识别
|
机器学习/深度学习 算法 数据库
图像识别2:图像多分类实验
图像识别2:图像多分类实验
72 0
|
机器学习/深度学习 算法 PyTorch
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
162 0
|
机器学习/深度学习 自然语言处理 算法
【深度学习】实验15 使用CNN完成MNIST手写体识别(Keras)
【深度学习】实验15 使用CNN完成MNIST手写体识别(Keras)
117 0
|
机器学习/深度学习 人工智能 文字识别
深度学习应用篇-计算机视觉-OCR光学字符识别[7]:OCR综述、常用CRNN识别方法、DBNet、CTPN检测方法等、评估指标、应用场景
深度学习应用篇-计算机视觉-OCR光学字符识别[7]:OCR综述、常用CRNN识别方法、DBNet、CTPN检测方法等、评估指标、应用场景
深度学习应用篇-计算机视觉-OCR光学字符识别[7]:OCR综述、常用CRNN识别方法、DBNet、CTPN检测方法等、评估指标、应用场景