【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型

本文涉及的产品
NLP自然语言处理_高级版,每接口累计50万次
视觉智能开放平台,视频资源包5000点
视觉智能开放平台,图像资源包5000点
简介: 本文介绍了一个基于多模态大模型的医疗图像诊断项目。项目旨在通过训练一个医疗领域的多模态大模型,提高医生处理医学图像的效率,辅助诊断和治疗。作者以家中老人的脑部CT为例,展示了如何利用MedTrinity-25M数据集训练模型,经过数据准备、环境搭建、模型训练及微调、最终验证等步骤,成功使模型能够识别CT图像并给出具体的诊断意见,与专业医生的诊断结果高度吻合。

前言

随着多模态大模型的发展,其不仅限于文字处理,更能够在图像、视频、音频方面进行识别与理解。医疗领域中,医生们往往需要对各种医学图像进行处理,以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合,那么这会极大地提升诊断效率。

项目目标

训练一个医疗多模态大模型,用于图像诊断。

刚好家里老爷子近期略感头疼,去医院做了脑部CT,诊断患有垂体瘤,我将尝试使用多模态大模型进行进一步诊断。

实现过程

1. 数据集准备

为了训练模型,需要准备大量的医学图像数据。通过搜索我们找到以下训练数据:

数据名称:MedTrinity-25M
数据地址https://github.com/UCSC-VLAA/MedTrinity-25M
数据简介:MedTrinity-25M数据集是一个用于医学图像分析和计算机视觉研究的大型数据集。
数据来源:该数据集由加州大学圣克鲁兹分校(UCSC)提供,旨在促进医学图像处理和分析的研究。
数据量:MedTrinity-25M包含约2500万条医学图像数据,涵盖多种医学成像技术,如CT、MRI和超声等。
数据内容
该数据集有两份,分别是 25Mdemo25Mfull

25Mdemo (约162,000条)数据集内容如下:

25Mfull (约24,800,000条)数据集内容如下:

2. 数据下载

2.1 安装Hugging Face的Datasets库

pip install datasets

2.2 下载数据集

from datasets import load_dataset

# 加载数据集
ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

执行结果:

说明:

  • 以上方法是使用HuggingFace的Datasets库下载数据集,下载的路径为当前脚本所在路径下的cache文件夹。
  • 使用HuggingFace下载需要能够访问https://huggingface.co/ 并且在网站上申请数据集读取权限才可以。
  • 如果没有权限访问HuggingFace,可以关注"一起AI技术"公众号后,回复 “MedTrinity”获取百度网盘下载地址。

2.3 预览数据集

# 查看训练集的前1个样本
print(ds['train'][:1])

运行结果:

{
   
    'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x15DD6D06530>], 
    'id': ['8031efe0-1b5c-11ef-8929-000066532cad'], 
    'caption': ['The image is a non-contrasted computed tomography (CT) scan of the brain, showing the cerebral structures without any medical devices present. The region of interest, located centrally and in the middle of the image, exhibits an area of altered density, which is indicative of a brain hemorrhage. This area is distinct from the surrounding brain tissue, suggesting a possible hematoma or bleeding within the brain parenchyma. The location and characteristics of this abnormality may suggest a relationship with the surrounding brain tissue, potentially causing a mass effect or contributing to increased intracranial pressure.'
    ]
}

使用如下命令对数据集的图片进行可视化查看:

# 可视化image内容
from PIL import Image
import matplotlib.pyplot as plt

image = ds['train'][0]['image']  # 获取第一张图像

plt.imshow(image)
plt.axis('off')  # 不显示坐标轴
plt.show()

运行结果:

3. 数据预处理

由于后续我们要通过LLama Factory进行多模态大模型微调,所以我们需要对上述的数据集进行预处理以符合LLama Factory的要求。

3.1 LLama Factory数据格式

查看LLama Factory的多模态数据格式要求如下:

[
  {
   
    "messages": [
      {
   
        "content": "<image>他们是谁?",
        "role": "user"
      },
      {
   
        "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
        "role": "assistant"
      },
      {
   
        "content": "他们在做什么?",
        "role": "user"
      },
      {
   
        "content": "他们在足球场上庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  }
]

3.2 实现数据格式转换脚本

from datasets import load_dataset
import os
import json
from PIL import Image

def save_images_and_json(ds, output_dir="mllm_data"):
    """
    将数据集中的图像和对应的 JSON 信息保存到指定目录。

    参数:
    ds: 数据集对象,包含图像和标题。
    output_dir: 输出目录,默认为 "mllm_data"。
    """
    # 创建输出目录
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 创建一个列表来存储所有的消息和图像信息
    all_data = []

    # 遍历数据集中的每个项目
    for item in ds:
        img_path = f"{output_dir}/{item['id']}.jpg"  # 图像保存路径
        image = item["image"]  # 假设这里是一个 PIL 图像对象

        # 将图像对象保存为文件
        image.save(img_path)  # 使用 PIL 的 save 方法

        # 添加消息和图像信息到列表中
        all_data.append(
            {
   
                "messages": [
                    {
   
                        "content": "<image>图片中的诊断结果是怎样?",
                        "role": "user",
                    },
                    {
   
                        "content": item["caption"],  # 从数据集中获取的标题
                        "role": "assistant",
                    },
                ],
                "images": [img_path],  # 图像文件路径
            }
        )

    # 创建 JSON 文件
    json_file_path = f"{output_dir}/mllm_data.json"
    with open(json_file_path, "w", encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False)  # 确保中文字符正常显示

if __name__ == "__main__":
    # 加载数据集
    ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

    # 保存数据集中的图像和 JSON 信息
    save_images_and_json(ds['train'])

运行结果:
image.png

4. 模型下载

本次微调,我们使用阿里最新发布的多模态大模型:Qwen2-VL-2B-Instruct 作为底座模型。
模型说明地址https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct

使用如下命令下载模型

git lfs install
# 下载模型
git clone https://www.modelscope.cn/Qwen/Qwen2-VL-2B-Instruct.git

5. 环境准备

5.1 机器环境

硬件:

  • 显卡:4080 Super
  • 显存:16GB

软件:

  • 系统:Ubuntu 20.04 LTS
  • python:3.10
  • pytorch:2.1.2 + cuda12.1

5.2 准备虚拟环境

# 创建python3.10版本虚拟环境
conda create --name train_env python=3.10

# 激活环境
conda activate train_env

# 安装依赖包
pip install streamlit torch torchvision

# 安装Qwen2建议的transformers版本
pip install git+https://github.com/huggingface/transformers

6. 准备训练框架

下载并安装LLamaFactory框架的具体步骤,请见【课程总结】day24(上):大模型三阶段训练方法(LLaMa Factory)准备训练框架 部分内容,本章不再赘述。

6.1 修改LLaMaFactory源码以适配transformer

由于Qwen2-VL使用的transformer的版本为4.47.0.dev0,LLamaFactory还不支持,所以需要修改LLaMaFactory的代码,具体方法如下:

第一步:在 llamafactory 源码中,找到 check_dependencies() 函数,这个函数位于 src/llamafactory/extras/misc.py 文件的第 82 行。

第二步:修改 check_dependencies() 函数并保存

# 原始代码
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
# 修改后代码
require_version("transformers>=4.41.2,<=4.47.0", "To fix: pip install transformers>=4.41.2,<=4.47.0")

第三步:重新启动LLaMaFactory服务

llamafactory-cli webui

这个过程可能会提示 ImportError: accelerate>=0.34.0 is required for a normal functioning of this module, but found accelerate==0.32.0.
如遇到上述问题,可以重新安装accelerate,如下:

# 卸载旧的 accelerate
pip uninstall accelerate

# 安装新的 accelerate
pip install accelerate==0.34.0

7. 测试当前模型

第一步:启动LLaMa Factory后,访问http://0.0.0.0:7860

第二步:在web页面配置模型路径为 4.步骤 下载的模型路径,并点击加载模型

第三步:上传一张CT图片并输入问题:“请使用中文描述下这个图像并给出你的诊断结果”

由上图可以看到,模型能够识别到这是一个CT图像,显示了大概的位置以及相应的器官,但是并不能给出是否存在诊断结果。

8. 模型训练

8.1 数据准备

第一步:将 3.2步骤 生成的mllm_data文件拷贝到LLaMaFactory的data目录下

第二步:将 4.步骤 下载的底座模型Qwen2-VL 拷贝到LLaMaFactory的model目录下

第三步:修改 LLaMaFactory data目录下的dataset_info.json,增加自定义数据集:

  "mllm_med": {
   
    "file_name": "mllm_data/mllm_data.json",
    "formatting": "sharegpt",
    "columns": {
   
      "messages": "messages",
      "images": "images"
    },
    "tags": {
   
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant"
    }
  },

8.2 配置训练参数

访问LLaMaFactory的web页面,配置微调的训练参数:

  • Model name: Qwen2-VL-2B-Instruct
  • Model path: models/Qwen2-VL-2B-Instruct
  • Finetuning method: lora
  • Stage : Supervised Fine-Tuning
  • Dataset: mllm_med
  • Output dir: saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1

配置参数中最好将 save_steps 设置大一点,否则训练过程会生成非常多的训练日志,导致硬盘空间不足而训练终止。

点击Preview Command预览命令行无误后,点击Run按钮开始训练。
训练参数

llamafactory-cli train \
    --do_train True \
    --model_name_or_path models/Qwen2-VL-2B-Instruct \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --template qwen2_vl \
    --flash_attn auto \
    --dataset_dir data \
    --dataset mllm_med \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 3000 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/Qwen2-VL-2B/full/Qwen2-VL-sft-demo1 \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0 \
    --lora_target all

训练过程

训练的过程中,可以通过 watch -n 1 nvidia-smi 实时查看GPU显存的消耗情况。

经过35小时的训练,模型训练完成,损失函数如下:

损失函数一般降低至1.2左右,太低会导致模型过拟合。

8.3 合并导出模型

接下来,我们将 Lora补丁原始模型 合并导出:

  1. 切换到 Expert 标签下
  2. Model path: 选择Qwen2-VL的基座模型,即:models/Qwen2-VL-2B-Instruct
  3. Checkpoint path: 选择lora微调的输出路径,即 saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1
  4. Export path:设置一个新的路径,例如:Qwen2-VL-sft-final
  5. 点击 开始导出 按钮


导出完毕后,会在LLaMaFactory的根目录下生成一个 Qwen2-VL-sft-final 的文件夹。

9. 模型验证

9.1 模型效果对比

第一步:在LLaMa Factory中卸载之前的模型

第二步:在LLaMa Factory中加载导出的模型,并配置模型路径为 Qwen2-VL-sft-final

第三步:加载模型并上传之前的CT图片提问同样的问题


可以看到,经过微调后的模型,可以给出具体区域存在的可能异常问题。

9.2 实际诊断

接下来,我将使用微调后的模型,为家里老爷子的CT片做诊断,看看模型给出的诊断与大夫的异同点。

我总计测试了CT片上的52张局部结果,其中具有代表性的为上述三张,可以看到模型还是比较准确地诊断出:脑部有垂体瘤,可能会影响到眼部。这与大夫给出的诊断和后续检查方案一致。

不足之处

训练集:

  • 多模态:本次训练只是采用了MedTrinity-25Mdemo数据集,如果使用MedTrinity-25Mfull数据集,效果应该会更好。
  • 中英文:本次训练集中使用的MedTrinity-25Mdemo数据集,只包含了英文数据,如果将英文标注翻译为中文,提供中英文双文数据集,相信效果会更好。
  • 对话数据集:本次训练只是使用了多模态数据集,如果增加中文对话(如:中文医疗对话数据-Chinese-medical-dialogue),相信效果会更好。

前端页面:

  • 前端页面:本次实践曾使用streamlit构建前端页面,以便图片上传和问题提出,但是在加载微调后的模型时,会出现:ValueError: No chat template is set for this processor 问题,所以转而使用LLaMaFactory的web页面进行展示。
  • 多个图片推理:在Qwen2-VL的官方指导文档中,提供了 Multi image inference 方法,本次未进行尝试,相信将多个图片交给大模型进行推理,效果会更好。

内容小结

  • Qwen2-VL-2B作为多模态大模型,具备有非常强的多模态处理能力,除了能够识别图片内容,还可以进行相关的推理。
  • 我们可以通过 LLaMaFactory 对模型进行微调,使得其具备医疗方面的处理能力。
  • 微调数据集采用开源的MedTrinity-25M数据集,该数据集有两个版本:25Mdemo和25Mfull。
  • 训练前需要对数据集进行预处理,使得其适配LLaMaFactory的微调格式。
  • 经过微调后的多模态大模型,不但可以详细地描述图片中的内容,还可以给出可能的诊断结果。
相关文章
|
4天前
|
人工智能 自动驾驶 大数据
预告 | 阿里云邀您参加2024中国生成式AI大会上海站,马上报名
大会以“智能跃进 创造无限”为主题,设置主会场峰会、分会场研讨会及展览区,聚焦大模型、AI Infra等热点议题。阿里云智算集群产品解决方案负责人丛培岩将出席并发表《高性能智算集群设计思考与实践》主题演讲。观众报名现已开放。
|
21天前
|
存储 人工智能 弹性计算
阿里云弹性计算_加速计算专场精华概览 | 2024云栖大会回顾
2024年9月19-21日,2024云栖大会在杭州云栖小镇举行,阿里云智能集团资深技术专家、异构计算产品技术负责人王超等多位产品、技术专家,共同带来了题为《AI Infra的前沿技术与应用实践》的专场session。本次专场重点介绍了阿里云AI Infra 产品架构与技术能力,及用户如何使用阿里云灵骏产品进行AI大模型开发、训练和应用。围绕当下大模型训练和推理的技术难点,专家们分享了如何在阿里云上实现稳定、高效、经济的大模型训练,并通过多个客户案例展示了云上大模型训练的显著优势。
|
24天前
|
存储 人工智能 调度
阿里云吴结生:高性能计算持续创新,响应数据+AI时代的多元化负载需求
在数字化转型的大潮中,每家公司都在积极探索如何利用数据驱动业务增长,而AI技术的快速发展更是加速了这一进程。
|
16天前
|
并行计算 前端开发 物联网
全网首发!真·从0到1!万字长文带你入门Qwen2.5-Coder——介绍、体验、本地部署及简单微调
2024年11月12日,阿里云通义大模型团队正式开源通义千问代码模型全系列,包括6款Qwen2.5-Coder模型,每个规模包含Base和Instruct两个版本。其中32B尺寸的旗舰代码模型在多项基准评测中取得开源最佳成绩,成为全球最强开源代码模型,多项关键能力超越GPT-4o。Qwen2.5-Coder具备强大、多样和实用等优点,通过持续训练,结合源代码、文本代码混合数据及合成数据,显著提升了代码生成、推理和修复等核心任务的性能。此外,该模型还支持多种编程语言,并在人类偏好对齐方面表现出色。本文为周周的奇妙编程原创,阿里云社区首发,未经同意不得转载。
11572 11
|
9天前
|
人工智能 自然语言处理 前端开发
100个降噪蓝牙耳机免费领,用通义灵码从 0 开始打造一个完整APP
打开手机,录制下你完成的代码效果,发布到你的社交媒体,前 100 个@玺哥超Carry、@通义灵码的粉丝,可以免费获得一个降噪蓝牙耳机。
4053 13
|
16天前
|
人工智能 自然语言处理 前端开发
用通义灵码,从 0 开始打造一个完整APP,无需编程经验就可以完成
通义灵码携手科技博主@玺哥超carry 打造全网第一个完整的、面向普通人的自然语言编程教程。完全使用 AI,再配合简单易懂的方法,只要你会打字,就能真正做出一个完整的应用。本教程完全免费,而且为大家准备了 100 个降噪蓝牙耳机,送给前 100 个完成的粉丝。获奖的方式非常简单,只要你跟着教程完成第一课的内容就能获得。
6783 10
|
28天前
|
缓存 监控 Linux
Python 实时获取Linux服务器信息
Python 实时获取Linux服务器信息
|
14天前
|
人工智能 自然语言处理 前端开发
什么?!通义千问也可以在线开发应用了?!
阿里巴巴推出的通义千问,是一个超大规模语言模型,旨在高效处理信息和生成创意内容。它不仅能在创意文案、办公助理、学习助手等领域提供丰富交互体验,还支持定制化解决方案。近日,通义千问推出代码模式,基于Qwen2.5-Coder模型,用户即使不懂编程也能用自然语言生成应用,如个人简历、2048小游戏等。该模式通过预置模板和灵活的自定义选项,极大简化了应用开发过程,助力用户快速实现创意。
|
3天前
|
机器学习/深度学习 人工智能 安全
通义千问开源的QwQ模型,一个会思考的AI,百炼邀您第一时间体验
Qwen团队推出新成员QwQ-32B-Preview,专注于增强AI推理能力。通过深入探索和试验,该模型在数学和编程领域展现了卓越的理解力,但仍在学习和完善中。目前,QwQ-32B-Preview已上线阿里云百炼平台,提供免费体验。
|
10天前
|
人工智能 C++ iOS开发
ollama + qwen2.5-coder + VS Code + Continue 实现本地AI 辅助写代码
本文介绍在Apple M4 MacOS环境下搭建Ollama和qwen2.5-coder模型的过程。首先通过官网或Brew安装Ollama,然后下载qwen2.5-coder模型,可通过终端命令`ollama run qwen2.5-coder`启动模型进行测试。最后,在VS Code中安装Continue插件,并配置qwen2.5-coder模型用于代码开发辅助。
729 5