LISA微调技术解析:比LoRA更低的显存更快的速度

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,5000CU*H 3个月
简介: LISA是Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning的简写,由UIUC联合LMFlow团队于近期提出的一项LLM微调技术,可实现把全参训练的显存使用降低到之前的三分之一左右,而使用的技术方法却是非常简单。

背景介绍

image.png

论文地址:

https://arxiv.org/abs/2403.17919

LISA是Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning的简写,由UIUC联合LMFlow团队于近期提出的一项LLM微调技术,可实现把全参训练的显存使用降低到之前的三分之一左右,而使用的技术方法却是非常简单。例如,全参训练一个7b模型大约需要80G显存(相当于一张完整的A100显卡),但使用LISA训练后却可以使显存降低到30G左右,这使得使用40G A100显卡甚至是24G A10或者RTX 3090成为可能,且它的显存占用更低、训练速度更快。

技术背景

如果阅读者尚不熟悉深度学习基本原理,请参考魔搭社区提供的教程:

https://github.com/modelscope/modelscope-classroom/blob/main/LLM-tutorial/A.%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E4%BB%8B%E7%BB%8D.md

技术解析

LISA使用的技术原理相对简单。作者首先对LoRA训练和全参训练每个layer不同step时的L2范数的平均和进行了对比,结果如下:

8e421979-cf36-4a74-b319-a4498519b332[1].png

作者训练了GPT2和LLaMA-2-7B两个模型,发现它们自身不同layers的parameters的LoRA训练和全参训练的L2范数不同,可以间接说明LoRA训练中由于低秩矩阵的存在,因此其参数更新的重点和全参数更新重点完全不同。可以看出,在权重更新时,除底层和顶层外其它层的L2范数都较小,因此作者假设可以在全参数训练时通过冻结大部分层的参数来模拟LoRA更新的行为,使其最后的参数迭代范数达到类似的效果。

完整的算法迭代可以用下图表示:

image.png

实验

在官方实验中,作者对比了LISA和LoRA训练以及全参数的显存占用:

image.png

可以看到LISA的显存占用要小于LoRA。在训练速度上面:

image.png

官方实验结果,LISA的Forward和Backward时间要显著短于LoRA训练。在训练方面,作者进行不同尺寸的微调和大规模微调,均证明了LISA的效果要强于LoRA:

image.png

image.png

如何调节LISA的超参数呢?LISA的超参数包含两个值:

  • LISA采样的有效层数γ
  • LISA的更新频率K

消融实验对这两个值的对比如下:

image.png

可以看到LISA的性能在γ=8,采样频率K=5的时候达到最好。作者也证明,LISA对于不同的随机种子的鲁棒性很强,在此不列举表格。

魔搭社区实战评测

为了验证LISA在实际测试中的效果,我们对LISA进行了一定的实验。我们使用了魔搭社区提供的SWIFT框架(https://github.com/modelscope/swift),该框架支持LISA训练方式,且支持LoRA等通用训练方式。我们可以设置LISA的两个值:

  • lisa_activated_layers 上文的γ
  • lisa_step_interval 上文的K

我们使用如下命令进行训练:

# pip install ms-swift -U
sft.py \
 --model_type qwen-7b-chat \
 --dataset ms-agent \
 --train_dataset_mix_ratio 2.0 \
 --batch_size 1 \
 --max_length 2048 \
 --use_loss_scale True \
 --gradient_accumulation_steps 16 \
 --learning_rate 5e-05 \
 --use_flash_attn True \
 --eval_steps 2000 \
 --save_steps 2000 \
 --train_dataset_sample -1 \
 --val_dataset_sample 5000 \
 --num_train_epochs 2 \
 --check_dataset_strategy none \
 --gradient_checkpointing True \
 --weight_decay 0.01 \
 --warmup_ratio 0.03 \
 --save_total_limit 2 \
 --logging_steps 10 \
 --sft_type full \
 --lisa_activated_layers 2 \
 --lisa_step_interval 20

同时,我们将--lisa_activated_layers置为0,进行全参数训练,并且使用r=8进行了LoRA训练,得到的效果如下:

image.png

从我们的实验中可以看到下面的结论:

  1. 在显存占用中,全参数几乎是其他轻量训练方式显存占用的一倍,但是在loss中也是最低的,这说明全参数在模型训练的基础指标中仍然是最优的;
  2. LISA的显存使用比r=8(这是个常用值)的显存占用要低,其中lisa_activated_layers越低显存越低;
  3. 训练速度上LISA的训练速度也比LoRA要快一些,并且该指标也受到lisa_activated_layers的影响;
  4. 在评估指标上,LoRA更为优秀,然而评估指标受到数据集的强烈影响,由于训练主要内容是Agent数据集,因此说明LoRA在防止灾难性遗忘上具有一定的优势

image.png

LISA lisa_activated_layers=2 训练的loss

image.png

LoRA r=8 训练的loss

可以观察到LISA的训练loss较LoRA曲线更为抖动一些,猜测可能是LISA随机挑选layer进行反向传播的随机性造成的。

结论

可以看到LISA作为2024年的新晋tuner,使用一个非常简单的方式做到了部分数据集的SOTA,同时显存使用和训练速度也是很优秀的,且没有额外的使用条件。然而LISA仍然存在着一些可以分析讨论的问题,比如:是否可以通过参数范数或者参数矩阵特征值判断哪些layers应该被反向传播?是否可以在更细粒度上(qkv/mlp/layernorm)层面上控制反向传播?

如果有做过实验的同学欢迎留言讨论。

相关文章
|
23天前
|
数据库 索引
深入探索数据库索引技术:回表与索引下推解析
【10月更文挑战第15天】在数据库查询优化的领域中,回表和索引下推是两个核心概念,它们对于提高查询性能至关重要。本文将详细解释这两个术语,并探讨它们在数据库操作中的作用和影响。
43 3
|
7天前
|
网络协议 网络安全 网络虚拟化
本文介绍了十个重要的网络技术术语,包括IP地址、子网掩码、域名系统(DNS)、防火墙、虚拟专用网络(VPN)、路由器、交换机、超文本传输协议(HTTP)、传输控制协议/网际协议(TCP/IP)和云计算
本文介绍了十个重要的网络技术术语,包括IP地址、子网掩码、域名系统(DNS)、防火墙、虚拟专用网络(VPN)、路由器、交换机、超文本传输协议(HTTP)、传输控制协议/网际协议(TCP/IP)和云计算。通过这些术语的详细解释,帮助读者更好地理解和应用网络技术,应对数字化时代的挑战和机遇。
36 3
|
7天前
|
存储 网络协议 安全
30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场
本文精选了 30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场。
21 2
|
10天前
|
监控 关系型数据库 MySQL
MySQL自增ID耗尽应对策略:技术解决方案全解析
在数据库管理中,MySQL的自增ID(AUTO_INCREMENT)属性为表中的每一行提供了一个唯一的标识符。然而,当自增ID达到其最大值时,如何处理这一情况成为了数据库管理员和开发者必须面对的问题。本文将探讨MySQL自增ID耗尽的原因、影响以及有效的应对策略。
35 3
|
17天前
|
机器学习/深度学习 人工智能 自然语言处理
思通数科AI平台在尽职调查中的技术解析与应用
思通数科AI多模态能力平台结合OCR、NLP和深度学习技术,为IPO尽职调查、融资等重要交易环节提供智能化解决方案。平台自动识别、提取并分类海量文档,实现高效数据核验与合规性检查,显著提升审查速度和精准度,同时保障敏感信息管理和数据安全。
71 11
|
12天前
|
Kubernetes Cloud Native 云计算
云原生技术深度解析:重塑企业IT架构的未来####
本文深入探讨了云原生技术的核心理念、关键技术组件及其对企业IT架构转型的深远影响。通过剖析Kubernetes、微服务、容器化等核心技术,本文揭示了云原生如何提升应用的灵活性、可扩展性和可维护性,助力企业在数字化转型中保持领先地位。 ####
|
13天前
|
自然语言处理 并行计算 数据可视化
免费开源法律文档比对工具:技术解析与应用
这款免费开源的法律文档比对工具,利用先进的文本分析和自然语言处理技术,实现高效、精准的文档比对。核心功能包括文本差异检测、多格式支持、语义分析、批量处理及用户友好的可视化界面,广泛适用于法律行业的各类场景。
|
17天前
|
机器学习/深度学习 人工智能 自然语言处理
医疗行业的语音识别技术解析:AI多模态能力平台的应用与架构
AI多模态能力平台通过语音识别技术,实现实时转录医患对话,自动生成结构化数据,提高医疗效率。平台具备强大的环境降噪、语音分离及自然语言处理能力,支持与医院系统无缝集成,广泛应用于门诊记录、多学科会诊和急诊场景,显著提升工作效率和数据准确性。
|
20天前
|
监控 Cloud Native 持续交付
云原生技术深度解析:重塑现代应用开发与部署范式####
本文深入探讨了云原生技术的核心概念、关键技术组件及其在现代软件开发中的重要性。通过剖析容器化、微服务架构、持续集成/持续部署(CI/CD)等关键技术,本文旨在揭示云原生技术如何促进应用的敏捷性、可扩展性和高可用性,进而推动企业数字化转型进程。不同于传统摘要仅概述内容要点,本部分将融入具体案例分析,直观展示云原生技术在实际应用中的显著成效与挑战应对策略,为读者提供更加丰富、立体的理解视角。 ####
|
20天前
|
算法 Java 数据库连接
Java连接池技术,从基础概念出发,解析了连接池的工作原理及其重要性
本文详细介绍了Java连接池技术,从基础概念出发,解析了连接池的工作原理及其重要性。连接池通过复用数据库连接,显著提升了应用的性能和稳定性。文章还展示了使用HikariCP连接池的示例代码,帮助读者更好地理解和应用这一技术。
32 1

热门文章

最新文章

推荐镜像

更多