4.3.2 图像分类ResNet实战:眼疾识别——模型构建

简介: 这篇文章介绍了如何使用飞桨框架中的ResNet50模型进行眼疾识别的实战,通过5个epoch的训练,在验证集上达到了约96%的准确率,并提供了模型构建、训练、评估和预测的详细代码实现。

4.3.2 模型构建

上一节定义好已经了解了ResNet模型结构,本节直接使用飞桨高层API中的Resnet50进行图像分类实验。

In [7]

from paddle.vision.models import resnet50
model = resnet50()

W0714 20:32:55.131150 102 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1

W0714 20:32:55.136173 102 device_context.cc:465] device: 0, cuDNN Version: 7.6.

4.3.3 损失函数

飞桨高层API中都为大家提供了实现好交叉熵损失函数,代码如下所示。

In [8]

import paddle.nn.functional as F
loss_fn = F.cross_entropy

4.3.4 模型训练

使用交叉熵损失函数,并用SGD作为优化器来训练ResNet网络。

In [9]

# -*- coding: utf-8 -*-
# LeNet 识别眼疾图片
import os
import random
import paddle
import numpy as np
class Runner(object):
    def __init__(self, model, optimizer, loss_fn):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        
        # 记录全局最优指标
        self.best_acc = 0
    
    # 定义训练过程
    def train_pm(self, train_datadir, val_datadir, **kwargs):
        print('start training ... ')
        self.model.train()
        
        num_epochs = kwargs.get('num_epochs', 0)
        csv_file = kwargs.get('csv_file', 0)
        save_path = kwargs.get("save_path", "/home/aistudio/output/")
        # 定义数据读取器,训练数据读取器
        train_loader = data_loader(train_datadir, batch_size=10, mode='train')
        
        for epoch in range(num_epochs):
            for batch_id, data in enumerate(train_loader()):
                x_data, y_data = data
                img = paddle.to_tensor(x_data)
                label = paddle.to_tensor(y_data)
                # 运行模型前向计算,得到预测值
                logits = model(img) 
                avg_loss = self.loss_fn(logits, label)
                
                if batch_id % 20 == 0:
                    print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))
                # 反向传播,更新权重,清除梯度
                avg_loss.backward()
                self.optimizer.step()
                self.optimizer.clear_grad()
            
            acc = self.evaluate_pm(val_datadir, csv_file)
            self.model.train()
            if acc > self.best_acc:
                self.save_model(save_path)
                self.best_acc = acc
    # 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度
    @paddle.no_grad()
    def evaluate_pm(self, val_datadir, csv_file):
        self.model.eval()
        accuracies = []
        losses = []
        # 验证数据读取器
        valid_loader = valid_data_loader(val_datadir, csv_file)
        for batch_id, data in enumerate(valid_loader()):
            x_data, y_data = data
            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)
            # 运行模型前向计算,得到预测值
            logits = self.model(img)
            # 多分类,使用softmax计算预测概率
            pred = F.softmax(logits)
            loss = self.loss_fn(pred, label)
            acc = paddle.metric.accuracy(pred, label)
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())
        print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
        return np.mean(accuracies)    
    
    # 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度
    @paddle.no_grad()
    def predict_pm(self, x, **kwargs):
        # 将模型设置为评估模式
        self.model.eval()
        # 运行模型前向计算,得到预测值
        logits = self.model(x)
        return logits
    
    def save_model(self, save_path):
        paddle.save(self.model.state_dict(), save_path + 'palm.pdparams')
        paddle.save(self.optimizer.state_dict(), save_path + 'palm.pdopt')
    
    def load_model(self, model_path):
        model_state_dict = paddle.load(model_path)
        self.model.set_state_dict(model_state_dict)

实例化Runner类,并传入训练配置,代码实现如下:

In [12]

# 开启0号GPU训练
use_gpu = True
paddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')
# 定义优化器
# opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters(), weight_decay=0.001)
opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
runner = Runner(model, opt, loss_fn)

用Runner在训练集上训练5个epoch,并保存准确率最高的模型作为最佳模型。

In [13]

import os
# 数据集路径
DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
DATADIR2 = '/home/aistudio/work/palm/PALM-Validation400'
CSVFILE = '/home/aistudio/labels.csv'
# 设置迭代轮数
EPOCH_NUM = 5
# 模型保存路径
PATH='/home/aistudio/output/'
if not os.path.exists(PATH):
    os.makedirs(PATH)
# 启动训练过程
runner.train_pm(DATADIR, DATADIR2, 
                num_epochs=EPOCH_NUM, csv_file=CSVFILE, save_path=PATH)

start training ...

epoch: 0, batch_id: 0, loss is: 0.3287

epoch: 0, batch_id: 20, loss is: 0.0716

[validation] accuracy/loss: 0.9625/5.9818

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/framework/io.py:729: UserWarning: The input state dict is empty, no need to save.

warnings.warn("The input state dict is empty, no need to save.")

epoch: 1, batch_id: 0, loss is: 0.1286

epoch: 1, batch_id: 20, loss is: 0.4718

[validation] accuracy/loss: 0.9650/5.9824

epoch: 2, batch_id: 0, loss is: 0.0892

epoch: 2, batch_id: 20, loss is: 0.0313

[validation] accuracy/loss: 0.9625/5.9801

epoch: 3, batch_id: 0, loss is: 0.1362

epoch: 3, batch_id: 20, loss is: 0.0569

[validation] accuracy/loss: 0.9625/5.9746

epoch: 4, batch_id: 0, loss is: 0.1036

epoch: 4, batch_id: 20, loss is: 0.0873

[validation] accuracy/loss: 0.9575/5.9856

通过运行结果可以发现,使用ResNet在眼疾筛查数据集iChallenge-PM上,经过5个epoch的训练,在验证集上的准确率可以达到96%左右。

4.3.5 模型评估

使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在评估集上的准确率。代码实现如下:

In [ ]

# 加载最优模型
runner.load_model('/home/aistudio/output/palm.pdparams')
# 模型评价
score = runner.evaluate_pm(DATADIR2, CSVFILE)

[validation] accuracy/loss: 0.9725/5.9591

4.3.6 模型预测

同样地,也可以使用保存好的模型,对测试集中的某一个数据进行模型预测,观察模型效果。代码实现如下:

In [18]

import cv2
from PIL import Image
import matplotlib.pyplot as plt
import paddle
import paddle.nn.functional as F
%matplotlib inline
# 加载最优模型
runner.load_model('/home/aistudio/output/palm.pdparams')
# 获取测试集中第一条数据
DATADIRv2 = '/home/aistudio/work/palm/PALM-Validation400'
filelists = open('/home/aistudio/labels.csv').readlines()
# 可以通过修改filelists列表的数字获取其他测试图片,可取值1-400
line = filelists[1].strip().split(',')
name, label = line[1], int(line[2])
# 读取测试图片
img = cv2.imread(os.path.join(DATADIRv2, name))
# 测试图片预处理
trans_img = transform_img(img)
unsqueeze_img = paddle.unsqueeze(paddle.to_tensor(trans_img), axis=0)
# 模型预测
logits = runner.predict_pm(unsqueeze_img)
result=F.softmax(logits)
pred_class = paddle.argmax(result).numpy()
# 输出真实类别与预测类别
print("The true category is {} and the predicted category is {}".format(label, pred_class))
# 图片可视化
show_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.imshow(show_img)
plt.show()

The true category is 0 and the predicted category is [0]

小结

在这一节里,我们通过ResNet模型实现眼疾识别,在验证集上的预测精度在95%左右,通过这个案例熟悉了基础的视觉任务构建流程。如果读者有兴趣的话,可以进一步调整学习率和训练轮数等超参数,观察是否能够得到更高的精度。

作业

本节通过调用飞桨高层API Resnet50模型from paddle.vision.models import resnet50实现了眼疾识别。更换其他模型,看是否能得到更高的精度,

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
465 0
|
6月前
|
存储 编解码 安全
带三维重建和还原的PACS源码 医学影像PACS系统源码
带三维重建和还原的PACS源码 医学影像PACS系统源码 PACS及影像存取与传输系统”( Picture Archiving and Communication System),为以实现医学影像数字化存储、诊断为核心任务,从医学影像设备(如CT、CR、DR、MR、DSA、RF等)获取影像,集中存储、综合管理医学影像及病人相关信息,建立数字化工作流程。系统可实现检查预约、病人信息登记、计算机阅片、电子报告书写、胶片打印、数据备份等一系列满足影像科室日常工作的功能,并且由于影像数字化存储,用户可利用影像处理与测量技术辅助诊断、方便快捷地查找资料或利用网络将资料传输至临床科室,还可与医院HIS、L
93 0
|
6月前
|
存储 数据采集 固态存储
带三维重建和还原功能的医学影像管理系统(pacs)源码
带三维重建和还原功能的医学影像管理系统(pacs)源码
108 0
|
6月前
|
存储 数据采集 编解码
【PACS】医学影像管理系统源码带三维重建后处理技术
【PACS】医学影像管理系统源码带三维重建后处理技术
121 0
|
6月前
|
C++
【C++医学影像PACS】CT检查中的三维重建是什么检查?
【C++医学影像PACS】CT检查中的三维重建是什么检查?
171 0
|
6月前
|
存储 数据可视化 vr&ar
突破传统 重新定义:3D医学影像PACS系统源码(包含RIS放射信息) 实现三维重建与还原
突破传统,重新定义PACS/RIS服务,洞察用户需求,关注应用场景,新一代PACS/RIS系统,系统顶层设计采用集中+分布式架构,满足医院影像全流程业务运行,同时各模块均可独立部署,满足医院未来影像信息化扩展新需求、感受新时代影像服务便捷性、易用性!系统基于平台化设计,与第三方服务自然接入无压力,从功能多样化到调阅速度快;覆盖(放射、超声、内镜、病理、核医学、心血管、临床科室等,是以影像采集、传输、存储、诊断、报告书写和科室管理)为核心应用的模块化PACS/RIS系统,实现了全院级影像信息的合理共享与应用。
122 0
突破传统 重新定义:3D医学影像PACS系统源码(包含RIS放射信息) 实现三维重建与还原
|
存储 数据库 数据安全/隐私保护
基于C++开发,支持三维重建,多平面重建技术的医学影像PACS系统源码
支持非DICOM标准的影像设备的图像采集和处理。 3)支持各种扫描仪、数码相机等影像输入设备。 4)支持各大主流厂商的CT、MR、DSA、ECT、US、数字胃肠、内镜等影像设备; 5)支持所有的DICOM相机,支持各大厂家的激光相机。 6)系统完全支持HL7接口和ICD—10编码,可与HIS系统无缝连接。 7)提供全院级、科室级工作站以及远程会诊工作站,三维重建,多平面重建。
166 0
基于C++开发,支持三维重建,多平面重建技术的医学影像PACS系统源码
|
6月前
|
数据采集 存储 数据可视化
医院影像PACS系统三维重建技术(获取数据、预处理、重建)
开放式体系结构,完全符合DICOM3.0标准,提供HL7标准接口,可实现与提供相应标准接口的HIS系统以及其他医学信息系统间的数据通信。
235 3
|
6月前
|
存储 编解码 监控
【C++】医学影像PACS三维重建后处理系统源码
系统完全符合国际标准的DICOM3.0标准
81 2
|
6月前
|
存储
医院PACS系统全套源码 强大的三维重建功能
对非DICOM影像,如超声、病理、心电图等进行了集成,做到了可以同时处理DICOM标准图像和非DICOM图像。
57 1

热门文章

最新文章