SANA-Sprint:基于连续时间一致性蒸馏的单步扩散模型,0.1秒即可生成图像

在线体验各类最新模型,更有模型 免费Token 额度领取!
立即体验
简介: Nvidia 提出的 SANA-Sprint 是一种混合蒸馏框架,结合连续时间一致性模型(sCM)与潜在对抗扩散蒸馏(LADD),实现快速高质量文本到图像生成。它支持 1-4 步推理,单步生成 FID 7.59、GenEval 0.74,H100 GPU 上 0.1 秒生成 1024×1024 图像,比 FLUX-Schnell 快 10 倍。通过无训练一致性变换和稳定训练技术,SANA-Sprint 克服传统方法局限,推动实时生成应用。

扩散模型已成为现代文本到图像 (T2I) 生成技术的核心,能够生成高质量图像,但其迭代式推理过程导致生成速度缓慢。多数模型通常需要 20–50 个去噪步骤,这严重制约了其在实时应用中的部署。

现有的蒸馏技术旨在加速扩散模型的采样过程,然而,这些方法往往会引入稳定性问题,在极低步数下出现质量下降,并可能导致显著的内存需求

Nvidia 提出的 SANA-Sprint 是一种混合蒸馏框架,它整合了连续时间一致性模型 (sCM)潜在对抗扩散蒸馏 (LADD),旨在实现以下目标:

  • 无步训练,并支持灵活的 1–4 步推理
  • 卓越的速度与质量平衡,单步推理即可达到 FID 7.59GenEval 0.74 的指标。
  • 在 H100 GPU 上实现 0.1 秒生成 1024×1024 图像,速度比 FLUX-Schnell 快 10 倍,同时保持更高图像质量。

本文将深入探讨 SANA-Sprint 实现上述性能的技术原理。

传统蒸馏方法在超低步数推理中的局限性

扩散模型依赖于随机微分方程 (SDE) 或常微分方程 (ODE) 进行图像生成,该过程通常需要多个步骤。尽管存在多种步数缩减技术,但每种方法都存在其固有的局限性:

  • 基于 GAN 的蒸馏方法(例如,LADD) 可以加速推理过程,但容易遭受模式崩溃泛化能力不足的问题。
  • 一致性模型 (CM) 能够实现快速采样,但在超低步数 (少于 4 步) 的情况下,由于轨迹截断误差,语义对齐性能会显著下降
  • 变分分数蒸馏 (VSD) 需要额外训练辅助扩散模型,这会显著增加 GPU 内存占用和计算开销

SANA-Sprint 通过整合 sCM 和 LADD 到统一框架中,克服了上述挑战,从而在确保快速推理的同时,实现了高图像质量

基于无训练一致性变换的预训练模型重用

扩散模型通常采用流匹配基于分数的学习方法进行训练,而一致性模型 (CM) 则基于 TrigFlow 参数化。为了实现无需重新训练的快速蒸馏,SANA-Sprint 引入了一种数学变换,可以将预训练的流匹配模型转化为 TrigFlow 模型

该变换确保了以下关键特性:

  • 时域映射的无缝衔接:实现了从 流匹配模型的 [0,1] 区间TrigFlow 模型的 [0, π/2] 区间 的平滑转换。
  • 信噪比 (SNR) 的一致性:在模型适配过程中,保持了信噪比的稳定,确保图像保真度。
  • 模型输出的正确参数化:保证了转换后模型输出的速度场与 TrigFlow 框架的公式保持一致。

通过上述变换,预训练模型可以直接应用于 SANA-Sprint 框架,无需额外的重新训练,从而显著提升了效率。

解决大规模一致性模型训练不稳定性问题

将一致性模型扩展到更高分辨率和更大模型规模时,常常会面临训练不稳定性的挑战,这主要是由于梯度爆炸现象引起的。SANA-Sprint 通过以下两项关键技术来稳定训练过程:

密集时间嵌入以抑制梯度爆炸

  • 传统扩散模型通常使用乘法因子(例如,1000 * t来缩放时间嵌入,这种方法会放大时间导数梯度,容易导致训练崩溃。
  • SANA-Sprint 采用归一化时间嵌入方法,确保时间步长表示的均匀分布,从而有效提升训练稳定性和样本质量
  • 这种方法使得模型能够更快收敛,并生成更清晰锐利的图像

QK 归一化实现稳定的自注意力和交叉注意力机制

  • 随着模型规模的扩大 (参数量从 0.6B 增至 1.6B),梯度范数变得不稳定 ( >¹⁰³),导致训练失败。
  • SANA-Sprint 在注意力层的 Query 和 Key (QK) 组件中引入 RMS 归一化,在不改变模型架构的前提下,有效稳定了梯度。
  • 仅需 5,000 次微调迭代,即可显著降低训练不稳定性,从而为大规模扩散模型的稳定蒸馏奠定基础。

结合一致性模型与对抗监督

传统一致性模型主要依赖局部轨迹学习,这导致其收敛速度较慢,并且在单步生成中容易丢失细节信息。SANA-Sprint 通过引入 基于 GAN 的对抗监督机制 (LADD) (Latent Adversarial Diffusion Distillation),对一致性模型进行了增强:

  • 使用冻结的教师模型提取高层潜在空间表征,以强制模型学习数据分布的一致性。
  • 引入多头判别器学习特征层面的差异,避免了像素空间直接比对可能导致的问题。
  • 采用 铰链损失函数,提升了训练稳定性和生成样本的真实感

该技术显著提升了单步图像生成质量,有效保留了传统一致性模型难以捕捉的高频细节

评估与结果

SANA-Sprint 在速度和质量方面均达到了新的技术水平。相较于 FLUX-Schnell,SANA-Sprint 的推理速度提升了 10 倍,同时能够生成更高质量的图像。在单步推理下,SANA-Sprint 取得了 7.59 的 FID 值和 0.74 的 GenEval 值,性能超越了需要多步推理的模型。即使在 RTX 4090 等消费级 GPU 上,SANA-Sprint 也能在 0.31 秒内生成 1024×1024 像素的图像,使得高质量 AI 图像生成技术更加普及。在 H100 GPU 上,文本到图像生成仅需 0.1 秒,ControlNet 任务耗时 0.25 秒,实现了近乎实时的视觉反馈。

总结

与需要 20 步以上的传统扩散模型不同,SANA-Sprint 仅需 1-4 步即可生成高质量图像,且无需额外的训练过程。单步推理速度极快,非常适合实时应用场景。两步生成能够在保证速度 (低于 0.25 秒) 的前提下,有效提升图像细节。四步生成则在质量和效率之间实现了最佳平衡

该论文在数学原理上具有一定的复杂性,但其技术方案堪称杰出非常值得深入阅读和研究。SANA-Sprint 的工作有望推动 Flow Matching DiT 模型的下游优化,进而实现更快、更低成本的图像生成。

蒸馏推理技术的进步,使得高质量图像生成技术更加普惠化。

https://avoid.overfit.cn/post/c9690cdfa56046e7833462825ef93352

作者:Pietro Bolcato

目录
相关文章
|
机器学习/深度学习 调度
详解 Diffusion (扩散) 模型
详解 Diffusion (扩散) 模型
2130 0
|
监控 Java API
Spring Cloud 2021.0.1 实践 Resilience4J
Spring Cloud CircuitBreaker 提供了跨不同断路器实现的抽象。它提供了在您的应用程序中使用的一致 API,让您(开发人员)选择最适合您的应用程序需求的断路器实现。
2117 0
Spring Cloud 2021.0.1 实践 Resilience4J
|
12月前
|
机器学习/深度学习 达摩院 PyTorch
GitHub 1.3k 一款能“填色回忆”的神器:DDColor 让老照片鲜活又逼真
DDColor 是阿里达摩院推出的图像自动着色模型,采用双解码器架构与 Colorfulness Loss 技术,实现黑白图到高保真彩色图的智能转换。支持 GPU/CPU 推理,兼容历史照片、动画、游戏截图等多场景,具备高效、真实、多样、易用等特点,广泛适用于影像修复、艺术创作等领域。
1474 24
|
机器学习/深度学习 人工智能 算法
Scikit-learn:Python机器学习的瑞士军刀
想要快速入门机器学习但被复杂算法吓退?本文详解Scikit-learn如何让您无需深厚数学背景也能构建强大AI模型。从数据预处理到模型评估,从垃圾邮件过滤到信用风险评估,通过实用案例和直观图表,带您掌握这把Python机器学习的'瑞士军刀'。无论您是AI新手还是经验丰富的数据科学家,都能从中获取将理论转化为实际应用的关键技巧。了解Scikit-learn与大语言模型的最新集成方式,抢先掌握机器学习的未来发展方向!
1400 12
Scikit-learn:Python机器学习的瑞士军刀
|
存储 安全 数据安全/隐私保护
解锁Python安全新姿势!AES加密:让你的数据穿上防弹衣,无惧黑客窥探?
【8月更文挑战第1天】在数字化时代,确保数据安全至关重要。AES(高级加密标准)作为一种强大的对称密钥加密算法,能有效保护数据免遭非法获取。AES支持128/192/256位密钥,通过多轮复杂的加密过程提高安全性。在Python中,利用`pycryptodome`库可轻松实现AES加密:生成密钥、定义IV,使用CBC模式进行加密与解密。需要注意的是,要妥善管理密钥并确保每次加密使用不同的IV。掌握AES加密技术,为数据安全提供坚实保障。
822 2
|
测试技术
字节Seed开源统一多模态理解和生成模型 BAGEL!
近期,字节跳动Seed推出了 BAGEL—— 一个开源的多模态理解和生成础模型,具有70亿个激活参数(总共140亿个),并在大规模交错多模态数据上进行训练。
1140 3
|
机器学习/深度学习 人工智能 编解码
R1-Onevision:开源多模态推理之王!复杂视觉难题一键解析,超越GPT-4V
R1-Onevision 是一款开源的多模态视觉推理模型,基于 Qwen2.5-VL 微调,专注于复杂视觉推理任务。它通过整合视觉和文本数据,能够在数学、科学、深度图像理解和逻辑推理等领域表现出色,并在多项基准测试中超越了 Qwen2.5-VL-7B 和 GPT-4V 等模型。
704 0
R1-Onevision:开源多模态推理之王!复杂视觉难题一键解析,超越GPT-4V
|
机器学习/深度学习 人工智能
Diff-Instruct:指导任意生成模型训练的通用框架,无需额外训练数据即可提升生成质量
Diff-Instruct 是一种从预训练扩散模型中迁移知识的通用框架,通过最小化积分Kullback-Leibler散度,指导其他生成模型的训练,提升生成性能。
453 11
Diff-Instruct:指导任意生成模型训练的通用框架,无需额外训练数据即可提升生成质量
|
数据采集 数据可视化 数据挖掘
基于Python的App流量大数据分析与可视化方案
基于Python的App流量大数据分析与可视化方案