❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!
🥦 微信公众号|搜一搜:蚝油菜花 🥦
🚀 快速阅读
- 量化压缩:将扩散模型的权重和激活值量化到4位,减少模型大小和内存占用。
- 加速推理:通过量化减少计算复杂度,提高模型在GPU上的推理速度。
- 低秩分支:引入低秩分支处理量化中的异常值,减少量化误差,提升图像质量。
正文(附运行示例)
SVDQuant 是什么
SVDQuant是由MIT研究团队推出的后训练量化技术,专门针对扩散模型进行优化。该技术通过将模型的权重和激活值量化至4位,显著减少了内存占用,并加速了推理过程。SVDQuant引入了一个高精度的低秩分支,用于吸收量化过程中的异常值,从而在保持图像质量的同时,实现了在16GB 4090 GPU上3.5倍的显存优化和8.7倍的延迟减少。
SVDQuant支持DiT和UNet架构,并能无缝集成现成的低秩适配器(LoRAs),无需重新量化。这为在资源受限的设备上部署大型扩散模型提供了有效的解决方案。
SVDQuant 的主要功能
- 量化压缩:将扩散模型的权重和激活值量化到4位,减少模型大小,降低内存占用。
- 加速推理:量化减少计算复杂度,提高模型在GPU上的推理速度。
- 低秩分支吸收异常值:引入低秩分支处理量化中的异常值,减少量化误差。
- 内核融合:设计推理引擎Nunchaku,基于内核融合减少内存访问,进一步提升推理效率。
- 支持多种架构:兼容DiT和UNet架构的扩散模型。
- LoRA集成:无缝集成低秩适配器(LoRAs),无需重新量化。
SVDQuant 的技术原理
- 量化处理:对模型的权重和激活值进行4位量化,对保持模型性能构成挑战。
- 异常值处理:用平滑技术将激活值中的异常值转移到权重上,基于SVD分解权重,将权重分解为低秩分量和残差。
- 低秩分支:引入16位精度的低秩分支处理权重中的异常值,将残差量化到4位,降低量化难度。
- Eckart-Young-Mirsky定理:移除权重中的主导奇异值,大幅减小权重的幅度和异常值。
- 推理引擎Nunchaku:设计推理引擎,基于融合低秩分支和低比特分支的内核,减少内存访问和内核调用次数,降低延迟。
如何运行 SVDQuant
安装依赖
首先,创建并激活一个conda环境,然后安装所需的依赖包:
conda create -n nunchaku python=3.11
conda activate nunchaku
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
pip install diffusers ninja wheel transformers accelerate sentencepiece protobuf
pip install huggingface_hub peft opencv-python einops gradio spaces GPUtil
安装 nunchaku
包
确保你已经安装了gcc/g++>=11
。如果没有,可以通过Conda安装:
conda install -c conda-forge gxx=11 gcc=11
然后从源码构建并安装nunchaku
包:
git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
pip install -e .
使用示例
在example.py
中,提供了一个运行INT4 FLUX.1-schnell模型的最小脚本:
import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
资源
- 项目官网:https://hanlab.mit.edu/projects/svdquant
- GitHub 仓库:https://github.com/mit-han-lab/nunchaku
- arXiv 技术论文:https://arxiv.org/pdf/2411.05007
- 在线体验Demo:https://svdquant.mit.edu/
❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!
🥦 微信公众号|搜一搜:蚝油菜花 🥦