采样提速256倍,蒸馏扩散模型生成图像质量媲美教师模型,只需4步

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 采样提速256倍,蒸馏扩散模型生成图像质量媲美教师模型,只需4步

斯坦福大学联合谷歌大脑使用「两步蒸馏方法」提升无分类器指导的采样效率,在生成样本质量和采样速度上都有非常亮眼的表现。

去噪扩散概率模型(DDPM)在图像生成、音频合成、分子生成和似然估计领域都已经实现了 SOTA 性能。同时无分类器(classifier-free)指导进一步提升了扩散模型的样本质量,并已被广泛应用在包括 GLIDE、DALL·E 2 和 Imagen 在内的大规模扩散模型框架中。


然而,无分类器指导的一大关键局限是它的采样效率低下,需要对两个扩散模型评估数百次才能生成一个样本。这一局限阻碍了无分类指导模型在真实世界设置中的应用。尽管已经针对扩散模型提出了蒸馏方法,但目前这些方法不适用无分类器指导扩散模型。


为了解决这一问题,近日斯坦福大学和谷歌大脑的研究者在论文《On Distillation of Guided Diffusion Models》中提出使用两步蒸馏(two-step distillation)方法来提升无分类器指导的采样效率。


在第一步中,他们引入单一学生模型来匹配两个教师扩散模型的组合输出;在第二步中,他们利用提出的方法逐渐地将从第一步学得的模型蒸馏为更少步骤的模型。


利用提出的方法,单个蒸馏模型能够处理各种不同的指导强度,从而高效地对样本质量和多样性进行权衡。此外为了从他们的模型中采样,研究者考虑了文献中已有的确定性采样器,并进一步提出了随机采样过程。



论文地址:https://arxiv.org/pdf/2210.03142.pdf


研究者在 ImageNet 64x64 和 CIFAR-10 上进行了实验,结果表明提出的蒸馏模型只需 4 步就能生成在视觉上与教师模型媲美的样本,并且在更广泛的指导强度上只需 8 到 16 步就能实现与教师模型媲美的 FID/IS 分数,具体如下图 1 所示。



此外,在 ImageNet 64x64 上的其他实验结果也表明了,研究者提出的框架在风格迁移应用中也表现良好。


方法介绍


接下来本文讨论了蒸馏无分类器指导扩散模型的方法( distilling a classifier-free guided diffusion model)。给定一个训练好的指导模型,即教师模型之后本文分两步完成。


第一步引入一个连续时间学生模型,该模型具有可学习参数η_1,以匹配教师模型在任意时间步 t∈[0,1] 处的输出。给定一个优化范围 [w_min, w_max],对学生模型进行优化:



其中,为了合并指导权重 w,本文引入了一个 w - 条件模型,其中 w 作为学生模型的输入。为了更好地捕捉特征,本文还对 w 应用傅里叶嵌入。此外,由于初始化在模型性能中起着关键作用,因此本文初始化学生模型的参数与教师模型相同。


在第二步中,本文将离散时间步(discrete time-step)考虑在内,并逐步将第一步中的蒸馏模型转化为步数较短的学生模型,其可学习参数为η_2,每次采样步数减半。设 N 为采样步数,给定 w ~ U[w_min, w_max] 和 t∈{1,…, N},然后根据 Salimans & Ho 等人提出的方法训练学生模型。在将教师模型中的 2N 步蒸馏为学生模型中的 N 步之后,之后使用 N 步学生模型作为新的教师模型,这个过程不断重复,直到将教师模型蒸馏为 N/2 步学生模型。

N 步可确定性和随机采样:一旦模型训练完成,给定一个指定的 w ∈ [w_min, w_max],然后使用 DDIM 更新规则执行采样。


实际上,本文也可以执行 N 步随机采样,使用两倍于原始步长的确定性采样步骤,然后使用原始步长向后执行一个随机步骤 。对于,当 t > 1/N 时,本文使用以下更新规则



实验


实验评估了蒸馏方法的性能,本文主要关注模型在 ImageNet 64x64 和 CIFAR-10 上的结果。他们探索了指导权重的不同范围,并观察到所有范围都具有可比性,因此实验采用 [w_min, w_max] = [0, 4]。图 2 和表 1 报告了在 ImageNet 64x64 上所有方法的性能。




本文还进行了如下实验。具体来说,为了在两个域 A 和 B 之间执行风格迁移,本文使用在域 A 上训练的扩散模型对来自域 A 的图像进行编码,然后使用在域 B 上训练的扩散模型进行解码。由于编码过程可以理解为反向 DDIM 采样过程,本文在无分类器指导下对编码器和解码器进行蒸馏,并与下图 3 中的 DDIM 编码器和解码器进行比较。



本文还探讨了如何修改指导强度 w 以影响性能,如下图 4 所示。


相关文章
|
移动开发 JavaScript
vue 中使用vue-introjs做引导动画
vue 中使用vue-introjs做引导动画
|
7月前
|
机器学习/深度学习 人工智能 vr&ar
LHM:单图生成3D动画人!阿里开源建模核弹,高斯点云重构服装纹理
阿里巴巴通义实验室开源的LHM模型,能够从单张图像快速重建高质量可动画化的3D人体模型,支持实时渲染和姿态控制,适用于AR/VR、游戏开发等多种场景。
1500 0
LHM:单图生成3D动画人!阿里开源建模核弹,高斯点云重构服装纹理
|
10月前
|
人工智能 编解码 虚拟化
See3D:智源研究院开源的无标注视频学习 3D 生成模型
See3D 是智源研究院推出的无标注视频学习 3D 生成模型,能够从大规模无标注的互联网视频中学习 3D 先验,实现从视频中生成 3D 内容。See3D 采用视觉条件技术,支持从文本、单视图和稀疏视图到 3D 的生成,并能进行 3D 编辑与高斯渲染。
286 13
See3D:智源研究院开源的无标注视频学习 3D 生成模型
|
12月前
|
存储 索引 Python
Python的列表替换
Python的列表替换
297 0
|
vr&ar 图形学 UED
优化图形渲染与物理模拟:减少Draw Calls与利用LOD技术提升性能
【7月更文第10天】在现代游戏开发和实时渲染应用中,性能优化是至关重要的环节,它直接关系到用户体验的流畅度和真实感。本文将深入探讨两种关键技术手段——减少Draw Calls和使用Level of Detail (LOD) 技术,来提升图形渲染与物理模拟的效率。
687 2
|
开发工具 vr&ar 图形学
PicoVR Unity SDK⭐️四、基础传送方式详解
PicoVR Unity SDK⭐️四、基础传送方式详解
|
JavaScript Java 测试技术
基于SpringBoot+Vue+uniapp的合庆镇停车场车位预约系统的详细设计和实现(源码+lw+部署文档+讲解等)
基于SpringBoot+Vue+uniapp的合庆镇停车场车位预约系统的详细设计和实现(源码+lw+部署文档+讲解等)
151 0
|
数据采集 供应链 安全
利用大数据优化业务流程:策略与实践
【5月更文挑战第11天】本文探讨了利用大数据优化业务流程的策略与实践,包括明确业务目标、构建大数据平台、数据采集整合、分析挖掘及流程优化。通过实例展示了电商和制造企业如何利用大数据改进库存管理和生产流程,提高效率与客户满意度。随着大数据技术进步,其在业务流程优化中的应用将更加广泛和深入,企业需积极采纳以适应市场和客户需求。
|
安全 NoSQL Redis
服务器又被攻击了,我这样做...
近期遭遇阿里云服务器频繁报警,经分析发现是由于测试服务器所有端口对公网开放,导致自动化程序对其扫描。黑客可能利用类似Redis的未授权访问漏洞进行攻击。为避免此类问题,建议:1. 不开放不必要的端口;2. 避免以root权限运行服务;3. 设置服务器IP白名单;4. 定期更换密码。保持良好安全习惯可保障服务器安全。
3687 3
服务器又被攻击了,我这样做...
|
监控 Docker 容器
Docker从入门到精通:Docker log 命令学习
了解 Docker 日志管理对容器监控至关重要。`docker logs` 命令用于查看和管理容器日志,例如,`docker logs <container_name>` 显示容器日志,`-f` 或 `--follow` 实时跟踪日志,`--tail` 显示指定行数,`--timestamps` 添加时间戳,`--since` 按日期筛选。Docker 支持多种日志驱动,如 `syslog`,可通过 `--log-driver` 配置。有效管理日志能提升应用程序的稳定性和可维护性。

热门文章

最新文章