❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!
🥦 微信公众号|搜一搜:蚝油菜花 🥦
🚀 快速阅读
高分辨率图像生成:HART能够直接生成1024×1024像素的高分辨率图像。
图像质量提升:基于混合Tokenizer技术,HART在图像重建和生成质量上超越传统自回归模型。
计算效率优化:在保持高图像质量的同时,显著提高计算效率,降低训练成本和推理延迟。
正文(附运行示例)
HART 是什么
HART(Hybrid Autoregressive Transformer)是麻省理工学院研究团队推出的自回归视觉生成模型。该模型能够直接生成1024×1024像素的高分辨率图像,其生成质量可与扩散模型相媲美。HART的核心技术在于其混合Tokenizer,这种技术将自动编码器的连续潜在表示分解为离散token和连续token。离散token负责捕捉图像的主要结构,而连续token则专注于细节。
HART的轻量级残差扩散模块仅用3700万参数,大幅提升了计算效率。在MJHQ-30K数据集上,HART将重构FID从2.11降至0.30,生成FID从7.85降至5.38,提升了31%。此外,HART在吞吐量上比现有扩散模型提高了4.5-7.7倍,MAC降低6.9-13.4倍。
HART 的主要功能
- 高分辨率图像生成:直接生成1024×1024像素的高分辨率图像,满足高质量视觉内容的需求。
- 图像质量提升:基于混合Tokenizer技术,HART在图像重建和生成质量上超越传统的自回归模型,与扩散模型相媲美。
- 计算效率优化:在保持高图像质量的同时,显著提高计算效率,降低训练成本和推理延迟。
- 自回归建模:基于自回归方法,逐步生成图像,支持对生成过程进行更精细的控制。
HART 的技术原理
- 混合Tokenizer:HART的核心是混合Tokenizer,将自动编码器的连续潜在表示分解为离散token和连续token。离散token负责捕捉图像的主要结构,连续token专注于细节。
- 离散自回归模型:离散部分由一个可扩展分辨率的离散自回归模型建模,支持模型在不同分辨率下生成图像。
- 轻量级残差扩散模块:连续部分由一个轻量级的残差扩散模块学习,该模块只有3700万个参数,有助于提高模型的效率。
- 效率与性能平衡:HART在FID和CLIP分数上优于现有的扩散模型,在吞吐量上提高了4.5-7.7倍,MAC降低6.9-13.4倍,实现效率与性能的良好平衡。
- 自回归生成:HART基于自回归方法,逐步生成图像,每一步都基于前一步的输出,支持模型在生成过程中逐步细化图像细节。
如何运行 HART
环境设置
首先,克隆HART的GitHub仓库并设置环境:
git clone https://github.com/mit-han-lab/hart
cd hart
conda create -n hart python=3.10
conda activate hart
conda install -c nvidia cuda-toolkit -y
pip install -e .
cd hart/kernels && python setup.py install
下载模型和Tokenizer
下载Qwen2-VL-1.5B-Instruct模型和HART tokenizer及模型:
git clone https://huggingface.co/mit-han-lab/Qwen2-VL-1.5B-Instruct
git clone https://huggingface.co/mit-han-lab/hart-0.7b-1024px
运行Gradio Demo
使用以下命令启动Gradio demo:
python app.py --model_path /path/to/model \
--text_model_path /path/to/Qwen2 \
--shield_model_path /path/to/ShieldGemma2B
命令行推理
- 使用单个提示生成图像:
python sample.py --model_path /path/to/model \
--text_model_path /path/to/Qwen2 \
--prompt "YOUR_PROMPT" \
--sample_folder_dir /path/to/save_dir \
--shield_model_path /path/to/ShieldGemma2B
- 使用多个提示生成图像:
python sample.py --model_path /path/to/model \
--text_model_path /path/to/Qwen2 \
--prompt_list [Prompt1, Prompt2, ..., PromptN] \
--sample_folder_dir /path/to/save_dir \
--shield_model_path /path/to/ShieldGemma2B
延迟基准测试
使用以下命令进行延迟基准测试:
python latency_profile.py --model_path /path/to/model \
--text_model_path /path/to/Qwen2
资源
- 项目官网:https://hanlab.mit.edu/projects/hart
- GitHub 仓库:https://github.com/mit-han-lab/hart
- arXiv 技术论文:https://arxiv.org/pdf/2410.10812
- 在线体验Demo:https://hart.mit.edu/
❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!
🥦 微信公众号|搜一搜:蚝油菜花 🥦