分布匹配蒸馏:扩散模型的单步生成优化方法研究

简介: 扩散模型在生成高质量图像方面表现出色,但其迭代去噪过程计算开销大。分布匹配蒸馏(DMD)通过将多步扩散简化为单步生成器,结合分布匹配损失和对抗生成网络损失,实现高效映射噪声图像到真实图像,显著提升生成速度。DMD利用预训练模型作为教师网络,提供高精度中间表征,通过蒸馏机制优化单步生成器的输出,从而实现快速、高质量的图像生成。该方法为图像生成应用提供了新的技术路径。

扩散模型在生成高质量图像领域具有显著优势,但其迭代去噪过程导致计算开销较大。分布匹配蒸馏(Distribution Matching Distillation,DMD)通过将多步扩散过程精简为单步生成器来解决这一问题。该方法结合分布匹配损失函数和对抗生成网络损失,实现从噪声图像到真实图像的高效映射,为快速图像生成应用提供了新的技术路径。

分布匹配机制

与传统扩散模型不同,单步生成器并不直接学习完整的数据分布,而是通过强制对齐的方式逼近目标分布。这种方法摒弃了逐步近似的过程,直接建立噪声样本到目标分布的映射关系。

在此过程中,蒸馏机制起到关键作用。预训练模型作为教师网络,提供目标分布的高精度中间表征。

DMD 技术实现流程

阶段 0:系统初始化

  1. 单步生成器基于预训练扩散 unet 进行初始化,时间步设定为 T-1
  2. real_unet 作为固定权重的教师网络,表征真实数据分布
  3. fake_unet 用于对生成器的数据分布进行建模

阶段 1:噪声到图像的生成

生成器接收随机噪声图作为输入,通过单步去噪操作生成图像 x,此时生成的图像 x 符合生成器的概率密度分布 p_fake

阶段 2:高斯噪声注入

对生成图像 x 施加高斯噪声,获得噪声图像 xt,在 0.2T0.98T 范围内均匀采样时间步 t(避开极端噪声状态),噪声注入操作促进 p_fakep_real 分布的重叠,为后续分布比较创造条件

阶段 3:双重网络处理

  1. real_unet 生成 pred_real_image,作为清晰图像的参考近似
  2. fake_unet 生成 pred_fake_image,反映当前时间步的生成器分布特征

通过对比 pred_real_imagepred_fake_image 量化真实分布与生成分布的差异

阶段 4:损失计算

计算 x 与 x — grad 之间的均方误差(MSE)作为损失度量。其中 x — grad 表示经过梯度校正的输出,用于减小与真实数据分布的偏差。

阶段 5:假分布更新机制

fake_unet 通过 x 和 pred_fake_image 之间的扩散损失进行参数更新。这一过程使 fake unet 能够追踪生成器分布的动态变化。与传统 unet 使用 xt-1_pred 和 xt-1_gt 计算损失不同,这里采用 xt-1_pred 和 x 之间的损失,使 fake UNet 能够将生成器输出的噪声版本(xt)还原为当前生成器输出 x。

核心问题解析

问题 1: 为何 fake_unet 采用 xt-1_pred 和 x0 之间的散度作为损失度量,而非采用 xt-1_pred 和 xt-1_gt 的比较?

选择 xt-1_pred 和 x 之间的散度是基于 fake_unet 的核心功能考虑。其目标是将生成器输出的噪声版本(xt)映射回生成器的当前输出(x)。这种设计确保了 fake_unet 能够准确捕获生成器的动态分布特征,从而提供有效的梯度信息来优化生成器输出。

问题 2:*fake_unet 的必要性何在?是否可以直接利用预训练的 real_unet* 输出与生成器输出计算 KL 散度?

生成器的设计目标是实现单步完全去噪,而预训练的 real_unet 在相同时间步内仅能实现部分去噪。这种本质差异导致 real_unet 输出无法提供有效的 KL 散度用于生成器训练。相比之下,fake_unet 通过持续学习生成器的动态分布,能够准确approximation当前生成器输出的特征。通过比较 real_unetfake_unet 的输出,可以获得用于优化生成器概率分布的有效梯度方向,从而提升单步图像合成的质量。# 分布匹配损失机制

训练过程中,通过 KL 散度定量评估生成器分布与真实分布之间的差异。

其中 Preal 代表真实数据的概率密度函数,Pfake 表示生成器 Gθ 产生的假分布概率密度函数。

对于高维数据集,直接计算概率密度在计算复杂度上存在显著挑战。例如,对于 32×32 像素的灰度图像,其维度空间为 256¹⁰²⁴,直接计算在实际应用中不可行。

因此,采用分数函数对真实分布和生成分布进行特征表征。

这种方法使得 KL 散度的计算成为可能:Sreal 引导 x 向 Preal 的模态靠近,而 −Sfake 则促使其远离真实分布。

其中 Sreal(x) 为真实数据分布的分数函数,Sfake(x) 为生成数据分布的分数函数,∇θ Gθ(z) 表示生成器输出 x 对参数的梯度。

Sreal(x)−Sfake(x) 表征了真实分数与生成分数的差异。对于生成样本 x,由于其 Sreal 接近零,需要引入扰动以支持扩散模型从 xt 进行去噪。

Sfake 和 Sreal 的定义参考自论文 "Song et al. — Score-based generative modeling through stochastic differential equations"

最终损失函数

技术原理剖析

在时间步 t−1,利用 real_unetfake_unet 的输出构建梯度,引导生成器的当前输出 x 向 real_unet 在 t=0 时刻的输出收敛。随后计算生成器原始输出与梯度校正后输出的均方误差(MSE)。这一校正机制确保 x 能够逐步对齐真实数据分布。

损失函数的代码实现

该图展示了不同时间步的损失函数变化,详细说明了多步生成器对单步生成器的训练过程。注意: 图中未详细展示 weighting_factor 相关细节,并对底层分布作出了特定假设。

核心思想在于利用 xfake 和 xreal 之间的差异产生的梯度,将生成器输出引导至 real_unet 在 t=0 时刻的目标输出。随着训练进行,生成器输出逐步向真实分布靠近,同时带动 fake_unet 输出的优化。最终,校正后的图像 ∥x−grad∥ 收敛至真实分布。

总结

本文深入探讨了分布匹配蒸馏(DMD)的技术原理和实现机制,着重阐述了其在图像生成领域的应用价值。欢迎学术界同仁就相关技术细节提供建议和讨论,以促进该领域的持续发展。

https://avoid.overfit.cn/post/c8b74a7d05944be5908b583559294a24

作者:Om Rastogi

目录
相关文章
|
机器学习/深度学习 人工智能
手动实现一个扩散模型DDPM(下)
手动实现一个扩散模型DDPM(下)
1215 2
|
8月前
|
存储 机器学习/深度学习 缓存
vLLM 核心技术 PagedAttention 原理详解
本文系统梳理了 vLLM 核心技术 PagedAttention 的设计理念与实现机制。文章从 KV Cache 在推理中的关键作用与内存管理挑战切入,介绍了 vLLM 在请求调度、分布式执行及 GPU kernel 优化等方面的核心改进。PagedAttention 通过分页机制与动态映射,有效提升了显存利用率,使 vLLM 在保持低延迟的同时显著提升了吞吐能力。
4577 20
vLLM 核心技术 PagedAttention 原理详解
使用云起实验室安装Stable Diffusion报错问题的解决
因为huggingface目前国内已无法访问,按照原有的手册安装时就会报错,本文给出解决办法,以顺利完成安装和使用
3614 0
|
计算机视觉 iOS开发 MacOS
Alfred Clipboard History 回车自动粘贴失效
Alfred Clipboard History 回车自动粘贴失效
2035 0
Alfred Clipboard History 回车自动粘贴失效
|
JSON JavaScript 前端开发
|
5月前
|
物联网 开发者
LoRA 模型的全新玩法——AutoLoRA 带你体验 LoRA 检索与融合的魔法
LoRA 模型的全新玩法——AutoLoRA 带你体验 LoRA 检索与融合的魔法
341 0
|
机器学习/深度学习 人工智能 自然语言处理
人工智能的发展现状如何?
【10月更文挑战第16天】人工智能的发展现状如何?
|
6月前
|
机器学习/深度学习 达摩院 PyTorch
GitHub 1.3k 一款能“填色回忆”的神器:DDColor 让老照片鲜活又逼真
DDColor 是阿里达摩院推出的图像自动着色模型,采用双解码器架构与 Colorfulness Loss 技术,实现黑白图到高保真彩色图的智能转换。支持 GPU/CPU 推理,兼容历史照片、动画、游戏截图等多场景,具备高效、真实、多样、易用等特点,广泛适用于影像修复、艺术创作等领域。
755 24
|
11月前
|
Linux Shell
问题记录:解决Linux登录故障,/etc/passwd配置受损该怎么操作
修复/etc/passwd文件是解决Linux登录故障的重要步骤。通过进入单用户模式、挂载文件系统、恢复或手动修复/etc/passwd文件,可以有效解决该问题。保持定期备份系统配置文件是预防此类问题的最佳实践。
359 13