LISA微调技术解析:比LoRA更低的显存更快的速度

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: LISA是Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning的简写,由UIUC联合LMFlow团队于近期提出的一项LLM微调技术,可实现把全参训练的显存使用降低到之前的三分之一左右,而使用的技术方法却是非常简单。

背景介绍

image.png

论文地址:

https://arxiv.org/abs/2403.17919

LISA是Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning的简写,由UIUC联合LMFlow团队于近期提出的一项LLM微调技术,可实现把全参训练的显存使用降低到之前的三分之一左右,而使用的技术方法却是非常简单。例如,全参训练一个7b模型大约需要80G显存(相当于一张完整的A100显卡),但使用LISA训练后却可以使显存降低到30G左右,这使得使用40G A100显卡甚至是24G A10或者RTX 3090成为可能,且它的显存占用更低、训练速度更快。

技术背景

如果阅读者尚不熟悉深度学习基本原理,请参考魔搭社区提供的教程:

https://github.com/modelscope/modelscope-classroom/blob/main/LLM-tutorial/A.%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E4%BB%8B%E7%BB%8D.md

技术解析

LISA使用的技术原理相对简单。作者首先对LoRA训练和全参训练每个layer不同step时的L2范数的平均和进行了对比,结果如下:

8e421979-cf36-4a74-b319-a4498519b332[1].png

作者训练了GPT2和LLaMA-2-7B两个模型,发现它们自身不同layers的parameters的LoRA训练和全参训练的L2范数不同,可以间接说明LoRA训练中由于低秩矩阵的存在,因此其参数更新的重点和全参数更新重点完全不同。可以看出,在权重更新时,除底层和顶层外其它层的L2范数都较小,因此作者假设可以在全参数训练时通过冻结大部分层的参数来模拟LoRA更新的行为,使其最后的参数迭代范数达到类似的效果。

完整的算法迭代可以用下图表示:

image.png

实验

在官方实验中,作者对比了LISA和LoRA训练以及全参数的显存占用:

image.png

可以看到LISA的显存占用要小于LoRA。在训练速度上面:

image.png

官方实验结果,LISA的Forward和Backward时间要显著短于LoRA训练。在训练方面,作者进行不同尺寸的微调和大规模微调,均证明了LISA的效果要强于LoRA:

image.png

image.png

如何调节LISA的超参数呢?LISA的超参数包含两个值:

  • LISA采样的有效层数γ
  • LISA的更新频率K

消融实验对这两个值的对比如下:

image.png

可以看到LISA的性能在γ=8,采样频率K=5的时候达到最好。作者也证明,LISA对于不同的随机种子的鲁棒性很强,在此不列举表格。

魔搭社区实战评测

为了验证LISA在实际测试中的效果,我们对LISA进行了一定的实验。我们使用了魔搭社区提供的SWIFT框架(https://github.com/modelscope/swift),该框架支持LISA训练方式,且支持LoRA等通用训练方式。我们可以设置LISA的两个值:

  • lisa_activated_layers 上文的γ
  • lisa_step_interval 上文的K

我们使用如下命令进行训练:

# pip install ms-swift -U
sft.py \
 --model_type qwen-7b-chat \
 --dataset ms-agent \
 --train_dataset_mix_ratio 2.0 \
 --batch_size 1 \
 --max_length 2048 \
 --use_loss_scale True \
 --gradient_accumulation_steps 16 \
 --learning_rate 5e-05 \
 --use_flash_attn True \
 --eval_steps 2000 \
 --save_steps 2000 \
 --train_dataset_sample -1 \
 --val_dataset_sample 5000 \
 --num_train_epochs 2 \
 --check_dataset_strategy none \
 --gradient_checkpointing True \
 --weight_decay 0.01 \
 --warmup_ratio 0.03 \
 --save_total_limit 2 \
 --logging_steps 10 \
 --sft_type full \
 --lisa_activated_layers 2 \
 --lisa_step_interval 20

同时,我们将--lisa_activated_layers置为0,进行全参数训练,并且使用r=8进行了LoRA训练,得到的效果如下:

image.png

从我们的实验中可以看到下面的结论:

  1. 在显存占用中,全参数几乎是其他轻量训练方式显存占用的一倍,但是在loss中也是最低的,这说明全参数在模型训练的基础指标中仍然是最优的;
  2. LISA的显存使用比r=8(这是个常用值)的显存占用要低,其中lisa_activated_layers越低显存越低;
  3. 训练速度上LISA的训练速度也比LoRA要快一些,并且该指标也受到lisa_activated_layers的影响;
  4. 在评估指标上,LoRA更为优秀,然而评估指标受到数据集的强烈影响,由于训练主要内容是Agent数据集,因此说明LoRA在防止灾难性遗忘上具有一定的优势

image.png

LISA lisa_activated_layers=2 训练的loss

image.png

LoRA r=8 训练的loss

可以观察到LISA的训练loss较LoRA曲线更为抖动一些,猜测可能是LISA随机挑选layer进行反向传播的随机性造成的。

结论

可以看到LISA作为2024年的新晋tuner,使用一个非常简单的方式做到了部分数据集的SOTA,同时显存使用和训练速度也是很优秀的,且没有额外的使用条件。然而LISA仍然存在着一些可以分析讨论的问题,比如:是否可以通过参数范数或者参数矩阵特征值判断哪些layers应该被反向传播?是否可以在更细粒度上(qkv/mlp/layernorm)层面上控制反向传播?

如果有做过实验的同学欢迎留言讨论。

相关文章
|
18天前
|
机器学习/深度学习 人工智能 自然语言处理
AI技术深度解析:从基础到应用的全面介绍
人工智能(AI)技术的迅猛发展,正在深刻改变着我们的生活和工作方式。从自然语言处理(NLP)到机器学习,从神经网络到大型语言模型(LLM),AI技术的每一次进步都带来了前所未有的机遇和挑战。本文将从背景、历史、业务场景、Python代码示例、流程图以及如何上手等多个方面,对AI技术中的关键组件进行深度解析,为读者呈现一个全面而深入的AI技术世界。
88 10
|
1天前
|
自然语言处理 文字识别 数据处理
多模态文件信息抽取:技术解析与实践评测!
在大数据和人工智能时代,企业和开发者面临的挑战是如何高效处理多模态数据(文本、图像、音频、视频)以快速提取有价值信息。传统方法效率低下,难以满足现代需求。本文将深度评测阿里云的多模态文件信息抽取解决方案,涵盖部署、应用、功能与性能,揭示其在复杂数据处理中的潜力。通过自然语言处理(NLP)、计算机视觉(CV)、语音识别(ASR)等技术,该方案助力企业挖掘多模态数据的价值,提升数据利用效率。
12 4
多模态文件信息抽取:技术解析与实践评测!
|
4天前
|
域名解析 负载均衡 安全
DNS技术标准趋势和安全研究
本文探讨了互联网域名基础设施的结构性安全风险,由清华大学段教授团队多年研究总结。文章指出,DNS系统的安全性不仅受代码实现影响,更源于其设计、实现、运营及治理中的固有缺陷。主要风险包括协议设计缺陷(如明文传输)、生态演进隐患(如单点故障增加)和薄弱的信任关系(如威胁情报被操纵)。团队通过多项研究揭示了这些深层次问题,并呼吁构建更加可信的DNS基础设施,以保障全球互联网的安全稳定运行。
|
4天前
|
缓存 网络协议 安全
融合DNS技术产品和生态
本文介绍了阿里云在互联网基础资源领域的最新进展和解决方案,重点围绕共筑韧性寻址、赋能新质生产展开。随着应用规模的增长,基础服务的韧性变得尤为重要。阿里云作为互联网资源的践行者,致力于推动互联网基础资源技术研究和自主创新,打造更韧性的寻址基础服务。文章还详细介绍了浙江省IPv6创新实验室的成立背景与工作进展,以及阿里云在IPv6规模化部署、DNS产品能力升级等方面的成果。此外,阿里云通过端云融合场景下的企业级DNS服务,帮助企业构建稳定安全的DNS系统,确保企业在数字世界中的稳定运行。最后,文章强调了全链路极致高可用的企业DNS解决方案,为全球互联网基础资源的创新提供了中国标准和数字化解决方案。
|
4天前
|
缓存 边缘计算 网络协议
深入解析CDN技术:加速互联网内容分发的幕后英雄
内容分发网络(CDN)是现代互联网架构的重要组成部分,通过全球分布的服务器节点,加速网站、应用和多媒体内容的传递。它不仅提升了访问速度和用户体验,还减轻了源站服务器的负担。CDN的核心技术包括缓存机制、动态加速、流媒体加速和安全防护,广泛应用于静态资源、动态内容、视频直播及大文件下载等场景,具有低延迟、高带宽、稳定性强等优势,有效降低成本并保障安全。
25 3
|
25天前
|
机器学习/深度学习 人工智能 自然语言处理
秒级响应 + 99.9%准确率:法律行业文本比对技术解析
本工具基于先进AI技术,采用自然语言处理和语义匹配算法,支持PDF、Word等格式,实现法律文本的智能化比对。具备高精度语义匹配、多格式兼容、高性能架构及智能化标注与可视化等特点,有效解决文本复杂性和法规更新难题,提升法律行业工作效率。
|
22天前
|
数据采集 存储 JavaScript
网页爬虫技术全解析:从基础到实战
在信息爆炸的时代,网页爬虫作为数据采集的重要工具,已成为数据科学家、研究人员和开发者不可或缺的技术。本文全面解析网页爬虫的基础概念、工作原理、技术栈与工具,以及实战案例,探讨其合法性与道德问题,分享爬虫设计与实现的详细步骤,介绍优化与维护的方法,应对反爬虫机制、动态内容加载等挑战,旨在帮助读者深入理解并合理运用网页爬虫技术。
|
2月前
|
监控 算法 物联网
院内导航怎么实现?解析信息化医院的智慧导航技术
智慧医院院内导航系统通过高精度电子地图、室内定位技术和路径规划算法,提升了医疗服务质量和患者就医体验。本文深入解析了院内导航技术的实现原理、应用案例及未来趋势,助力医院管理者和技术人员优化服务。文章最后可面查看详细医院院内导航解决方案
100 2
院内导航怎么实现?解析信息化医院的智慧导航技术
|
28天前
|
机器学习/深度学习 自然语言处理 监控
智能客服系统集成技术解析和价值点梳理
在 2024 年的智能客服系统领域,合力亿捷等服务商凭借其卓越的技术实力引领潮流,它们均积极应用最新的大模型技术,推动智能客服的进步。
70 7
|
1月前
|
负载均衡 网络协议 算法
Docker容器环境中服务发现与负载均衡的技术与方法,涵盖环境变量、DNS、集中式服务发现系统等方式
本文探讨了Docker容器环境中服务发现与负载均衡的技术与方法,涵盖环境变量、DNS、集中式服务发现系统等方式,以及软件负载均衡器、云服务负载均衡、容器编排工具等实现手段,强调两者结合的重要性及面临挑战的应对措施。
76 3

热门文章

最新文章

推荐镜像

更多