【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型

本文涉及的产品
NLP 自学习平台,3个模型定制额度 1个月
NLP自然语言处理_高级版,每接口累计50万次
视觉智能开放平台,图像资源包5000点
简介: 本文介绍了一个基于多模态大模型的医疗图像诊断项目。项目旨在通过训练一个医疗领域的多模态大模型,提高医生处理医学图像的效率,辅助诊断和治疗。作者以家中老人的脑部CT为例,展示了如何利用MedTrinity-25M数据集训练模型,经过数据准备、环境搭建、模型训练及微调、最终验证等步骤,成功使模型能够识别CT图像并给出具体的诊断意见,与专业医生的诊断结果高度吻合。

前言

随着多模态大模型的发展,其不仅限于文字处理,更能够在图像、视频、音频方面进行识别与理解。医疗领域中,医生们往往需要对各种医学图像进行处理,以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合,那么这会极大地提升诊断效率。

项目目标

训练一个医疗多模态大模型,用于图像诊断。

刚好家里老爷子近期略感头疼,去医院做了脑部CT,诊断患有垂体瘤,我将尝试使用多模态大模型进行进一步诊断。

实现过程

1. 数据集准备

为了训练模型,需要准备大量的医学图像数据。通过搜索我们找到以下训练数据:

数据名称:MedTrinity-25M
数据地址https://github.com/UCSC-VLAA/MedTrinity-25M
数据简介:MedTrinity-25M数据集是一个用于医学图像分析和计算机视觉研究的大型数据集。
数据来源:该数据集由加州大学圣克鲁兹分校(UCSC)提供,旨在促进医学图像处理和分析的研究。
数据量:MedTrinity-25M包含约2500万条医学图像数据,涵盖多种医学成像技术,如CT、MRI和超声等。
数据内容
该数据集有两份,分别是 25Mdemo25Mfull

25Mdemo (约162,000条)数据集内容如下:

25Mfull (约24,800,000条)数据集内容如下:

2. 数据下载

2.1 安装Hugging Face的Datasets库

pip install datasets

2.2 下载数据集

from datasets import load_dataset

# 加载数据集
ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

执行结果:

说明:

  • 以上方法是使用HuggingFace的Datasets库下载数据集,下载的路径为当前脚本所在路径下的cache文件夹。
  • 使用HuggingFace下载需要能够访问https://huggingface.co/ 并且在网站上申请数据集读取权限才可以。
  • 如果没有权限访问HuggingFace,可以关注"一起AI技术"公众号后,回复 “MedTrinity”获取百度网盘下载地址。

2.3 预览数据集

# 查看训练集的前1个样本
print(ds['train'][:1])

运行结果:

{
   
    'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x15DD6D06530>], 
    'id': ['8031efe0-1b5c-11ef-8929-000066532cad'], 
    'caption': ['The image is a non-contrasted computed tomography (CT) scan of the brain, showing the cerebral structures without any medical devices present. The region of interest, located centrally and in the middle of the image, exhibits an area of altered density, which is indicative of a brain hemorrhage. This area is distinct from the surrounding brain tissue, suggesting a possible hematoma or bleeding within the brain parenchyma. The location and characteristics of this abnormality may suggest a relationship with the surrounding brain tissue, potentially causing a mass effect or contributing to increased intracranial pressure.'
    ]
}

使用如下命令对数据集的图片进行可视化查看:

# 可视化image内容
from PIL import Image
import matplotlib.pyplot as plt

image = ds['train'][0]['image']  # 获取第一张图像

plt.imshow(image)
plt.axis('off')  # 不显示坐标轴
plt.show()

运行结果:

3. 数据预处理

由于后续我们要通过LLama Factory进行多模态大模型微调,所以我们需要对上述的数据集进行预处理以符合LLama Factory的要求。

3.1 LLama Factory数据格式

查看LLama Factory的多模态数据格式要求如下:

[
  {
   
    "messages": [
      {
   
        "content": "<image>他们是谁?",
        "role": "user"
      },
      {
   
        "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
        "role": "assistant"
      },
      {
   
        "content": "他们在做什么?",
        "role": "user"
      },
      {
   
        "content": "他们在足球场上庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  }
]

3.2 实现数据格式转换脚本

from datasets import load_dataset
import os
import json
from PIL import Image

def save_images_and_json(ds, output_dir="mllm_data"):
    """
    将数据集中的图像和对应的 JSON 信息保存到指定目录。

    参数:
    ds: 数据集对象,包含图像和标题。
    output_dir: 输出目录,默认为 "mllm_data"。
    """
    # 创建输出目录
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 创建一个列表来存储所有的消息和图像信息
    all_data = []

    # 遍历数据集中的每个项目
    for item in ds:
        img_path = f"{output_dir}/{item['id']}.jpg"  # 图像保存路径
        image = item["image"]  # 假设这里是一个 PIL 图像对象

        # 将图像对象保存为文件
        image.save(img_path)  # 使用 PIL 的 save 方法

        # 添加消息和图像信息到列表中
        all_data.append(
            {
   
                "messages": [
                    {
   
                        "content": "<image>图片中的诊断结果是怎样?",
                        "role": "user",
                    },
                    {
   
                        "content": item["caption"],  # 从数据集中获取的标题
                        "role": "assistant",
                    },
                ],
                "images": [img_path],  # 图像文件路径
            }
        )

    # 创建 JSON 文件
    json_file_path = f"{output_dir}/mllm_data.json"
    with open(json_file_path, "w", encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False)  # 确保中文字符正常显示

if __name__ == "__main__":
    # 加载数据集
    ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

    # 保存数据集中的图像和 JSON 信息
    save_images_and_json(ds['train'])

运行结果:
image.png

4. 模型下载

本次微调,我们使用阿里最新发布的多模态大模型:Qwen2-VL-2B-Instruct 作为底座模型。
模型说明地址https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct

使用如下命令下载模型

git lfs install
# 下载模型
git clone https://www.modelscope.cn/Qwen/Qwen2-VL-2B-Instruct.git

5. 环境准备

5.1 机器环境

硬件:

  • 显卡:4080 Super
  • 显存:16GB

软件:

  • 系统:Ubuntu 20.04 LTS
  • python:3.10
  • pytorch:2.1.2 + cuda12.1

5.2 准备虚拟环境

# 创建python3.10版本虚拟环境
conda create --name train_env python=3.10

# 激活环境
conda activate train_env

# 安装依赖包
pip install streamlit torch torchvision

# 安装Qwen2建议的transformers版本
pip install git+https://github.com/huggingface/transformers

6. 准备训练框架

下载并安装LLamaFactory框架的具体步骤,请见【课程总结】day24(上):大模型三阶段训练方法(LLaMa Factory)准备训练框架 部分内容,本章不再赘述。

6.1 修改LLaMaFactory源码以适配transformer

由于Qwen2-VL使用的transformer的版本为4.47.0.dev0,LLamaFactory还不支持,所以需要修改LLaMaFactory的代码,具体方法如下:

第一步:在 llamafactory 源码中,找到 check_dependencies() 函数,这个函数位于 src/llamafactory/extras/misc.py 文件的第 82 行。

第二步:修改 check_dependencies() 函数并保存

# 原始代码
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
# 修改后代码
require_version("transformers>=4.41.2,<=4.47.0", "To fix: pip install transformers>=4.41.2,<=4.47.0")

第三步:重新启动LLaMaFactory服务

llamafactory-cli webui

这个过程可能会提示 ImportError: accelerate>=0.34.0 is required for a normal functioning of this module, but found accelerate==0.32.0.
如遇到上述问题,可以重新安装accelerate,如下:

# 卸载旧的 accelerate
pip uninstall accelerate

# 安装新的 accelerate
pip install accelerate==0.34.0

7. 测试当前模型

第一步:启动LLaMa Factory后,访问http://0.0.0.0:7860

第二步:在web页面配置模型路径为 4.步骤 下载的模型路径,并点击加载模型

第三步:上传一张CT图片并输入问题:“请使用中文描述下这个图像并给出你的诊断结果”

由上图可以看到,模型能够识别到这是一个CT图像,显示了大概的位置以及相应的器官,但是并不能给出是否存在诊断结果。

8. 模型训练

8.1 数据准备

第一步:将 3.2步骤 生成的mllm_data文件拷贝到LLaMaFactory的data目录下

第二步:将 4.步骤 下载的底座模型Qwen2-VL 拷贝到LLaMaFactory的model目录下

第三步:修改 LLaMaFactory data目录下的dataset_info.json,增加自定义数据集:

  "mllm_med": {
   
    "file_name": "mllm_data/mllm_data.json",
    "formatting": "sharegpt",
    "columns": {
   
      "messages": "messages",
      "images": "images"
    },
    "tags": {
   
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant"
    }
  },

8.2 配置训练参数

访问LLaMaFactory的web页面,配置微调的训练参数:

  • Model name: Qwen2-VL-2B-Instruct
  • Model path: models/Qwen2-VL-2B-Instruct
  • Finetuning method: lora
  • Stage : Supervised Fine-Tuning
  • Dataset: mllm_med
  • Output dir: saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1

配置参数中最好将 save_steps 设置大一点,否则训练过程会生成非常多的训练日志,导致硬盘空间不足而训练终止。

点击Preview Command预览命令行无误后,点击Run按钮开始训练。
训练参数

llamafactory-cli train \
    --do_train True \
    --model_name_or_path models/Qwen2-VL-2B-Instruct \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --template qwen2_vl \
    --flash_attn auto \
    --dataset_dir data \
    --dataset mllm_med \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 3000 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/Qwen2-VL-2B/full/Qwen2-VL-sft-demo1 \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0 \
    --lora_target all

训练过程

训练的过程中,可以通过 watch -n 1 nvidia-smi 实时查看GPU显存的消耗情况。

经过35小时的训练,模型训练完成,损失函数如下:

损失函数一般降低至1.2左右,太低会导致模型过拟合。

8.3 合并导出模型

接下来,我们将 Lora补丁原始模型 合并导出:

  1. 切换到 Expert 标签下
  2. Model path: 选择Qwen2-VL的基座模型,即:models/Qwen2-VL-2B-Instruct
  3. Checkpoint path: 选择lora微调的输出路径,即 saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1
  4. Export path:设置一个新的路径,例如:Qwen2-VL-sft-final
  5. 点击 开始导出 按钮


导出完毕后,会在LLaMaFactory的根目录下生成一个 Qwen2-VL-sft-final 的文件夹。

9. 模型验证

9.1 模型效果对比

第一步:在LLaMa Factory中卸载之前的模型

第二步:在LLaMa Factory中加载导出的模型,并配置模型路径为 Qwen2-VL-sft-final

第三步:加载模型并上传之前的CT图片提问同样的问题


可以看到,经过微调后的模型,可以给出具体区域存在的可能异常问题。

9.2 实际诊断

接下来,我将使用微调后的模型,为家里老爷子的CT片做诊断,看看模型给出的诊断与大夫的异同点。

我总计测试了CT片上的52张局部结果,其中具有代表性的为上述三张,可以看到模型还是比较准确地诊断出:脑部有垂体瘤,可能会影响到眼部。这与大夫给出的诊断和后续检查方案一致。

不足之处

训练集:

  • 多模态:本次训练只是采用了MedTrinity-25Mdemo数据集,如果使用MedTrinity-25Mfull数据集,效果应该会更好。
  • 中英文:本次训练集中使用的MedTrinity-25Mdemo数据集,只包含了英文数据,如果将英文标注翻译为中文,提供中英文双文数据集,相信效果会更好。
  • 对话数据集:本次训练只是使用了多模态数据集,如果增加中文对话(如:中文医疗对话数据-Chinese-medical-dialogue),相信效果会更好。

前端页面:

  • 前端页面:本次实践曾使用streamlit构建前端页面,以便图片上传和问题提出,但是在加载微调后的模型时,会出现:ValueError: No chat template is set for this processor 问题,所以转而使用LLaMaFactory的web页面进行展示。
  • 多个图片推理:在Qwen2-VL的官方指导文档中,提供了 Multi image inference 方法,本次未进行尝试,相信将多个图片交给大模型进行推理,效果会更好。

内容小结

  • Qwen2-VL-2B作为多模态大模型,具备有非常强的多模态处理能力,除了能够识别图片内容,还可以进行相关的推理。
  • 我们可以通过 LLaMaFactory 对模型进行微调,使得其具备医疗方面的处理能力。
  • 微调数据集采用开源的MedTrinity-25M数据集,该数据集有两个版本:25Mdemo和25Mfull。
  • 训练前需要对数据集进行预处理,使得其适配LLaMaFactory的微调格式。
  • 经过微调后的多模态大模型,不但可以详细地描述图片中的内容,还可以给出可能的诊断结果。
相关文章
|
8月前
|
文字识别 前端开发
CodeFuse-VLM 开源,支持多模态多任务预训练/微调
随着huggingface开源社区的不断更新,会有更多的vision encoder 和 LLM 底座发布,这些vision encoder 和 LLM底座都有各自的强项,例如 code-llama 适合生成代码类任务,但是不适合生成中文类的任务,因此用户常常需要根据vision encoder和LLM的特长来搭建自己的多模态大语言模型。针对多模态大语言模型种类繁多的落地场景,我们搭建了CodeFuse-VLM 框架,支持多种视觉模型和语言大模型,使得MFT-VLM可以适应不同种类的任务。
763 0
|
8月前
|
自然语言处理 物联网 Swift
零一万物开源Yi-VL多模态大模型,魔搭社区推理&微调最佳实践来啦!
近期,零一万物Yi系列模型家族发布了其多模态大模型系列,Yi Vision Language(Yi-VL)多模态语言大模型正式面向全球开源。
|
4月前
|
机器学习/深度学习 人工智能 分布式计算
使用PAI+LLaMA Factory 微调 Qwen2-VL 模型,搭建文旅领域知识问答机器人
本次教程介绍了如何使用 PAI ×LLaMA Factory 框架,基于全参方法微调 Qwen2-VL 模型,使其能够进行文旅领域知识问答,同时通过人工测试验证了微调的效果。
使用PAI+LLaMA Factory 微调 Qwen2-VL 模型,搭建文旅领域知识问答机器人
|
22天前
|
编解码 机器人 测试技术
技术实践 | 使用 PAI+LLaMA Factory 微调 Qwen2-VL 模型快速搭建专业领域知识问答机器人
Qwen2-VL是一款具备高级图像和视频理解能力的多模态模型,支持多种语言,适用于多模态应用开发。通过PAI和LLaMA Factory框架,用户可以轻松微调Qwen2-VL模型,快速构建文旅领域的知识问答机器人。本教程详细介绍了从模型部署、微调到对话测试的全过程,帮助开发者高效实现定制化多模态应用。
|
4月前
|
搜索推荐 算法
模型小,还高效!港大最新推荐系统EasyRec:零样本文本推荐能力超越OpenAI、Bert
【9月更文挑战第21天】香港大学研究者开发了一种名为EasyRec的新推荐系统,利用语言模型的强大文本理解和生成能力,解决了传统推荐算法在零样本学习场景中的局限。EasyRec通过文本-行为对齐框架,结合对比学习和协同语言模型调优,提升了推荐准确性。实验表明,EasyRec在多个真实世界数据集上的表现优于现有模型,但其性能依赖高质量文本数据且计算复杂度较高。论文详见:http://arxiv.org/abs/2408.08821
106 7
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
【AI大模型】BERT模型:揭秘LLM主要类别架构(上)
【AI大模型】BERT模型:揭秘LLM主要类别架构(上)
|
5月前
|
机器学习/深度学习 数据采集 人工智能
【机器学习】QLoRA:基于PEFT亲手量化微调Qwen2大模型
【机器学习】QLoRA:基于PEFT亲手量化微调Qwen2大模型
383 0
【机器学习】QLoRA:基于PEFT亲手量化微调Qwen2大模型
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
【AI大模型】Transformers大模型库(八):大模型微调之LoraConfig
【AI大模型】Transformers大模型库(八):大模型微调之LoraConfig
89 0
|
5月前
|
机器学习/深度学习 存储 人工智能
【机器学习】Qwen1.5-14B-Chat大模型训练与推理实战
【机器学习】Qwen1.5-14B-Chat大模型训练与推理实战
549 0
|
5月前
|
机器学习/深度学习 数据采集 人工智能
【AI大模型】Transformers大模型库(十一):Trainer训练类
【AI大模型】Transformers大模型库(十一):Trainer训练类
126 0