谢赛宁新作:表征学习有多重要?一个操作刷新SOTA,DiT训练速度暴涨18倍

简介: 谢赛宁团队提出REPA方法,通过将扩散模型中的噪声输入隐藏状态与外部预训练视觉编码器的干净图像表征对齐,显著提升扩散模型的训练效率和生成质量,为扩散模型在表征学习上的应用开辟新路径。

在人工智能领域,生成模型(generative models)一直是研究的热点。近年来,扩散模型(diffusion models)因其在生成高质量图像和音频方面的卓越表现,备受关注。然而,尽管扩散模型在生成任务上取得了显著的成果,但其在表征学习(representation learning)方面的能力却相对较弱,这在一定程度上限制了其在更广泛领域的应用。

最近,谢赛宁团队在表征学习与扩散模型的结合上取得了突破性进展。他们提出了一种名为REPresentation Alignment(REPA)的简单而有效的正则化方法,通过将扩散模型中的噪声输入隐藏状态与外部预训练视觉编码器获得的干净图像表征进行对齐,显著提高了扩散模型的训练效率和生成质量。

表征学习是机器学习中的一个重要概念,它指的是从原始数据中提取有用的特征或表示,以便于后续的分类、回归或生成任务。在扩散模型中,表征学习的质量直接影响到模型的生成能力和泛化性能。然而,由于扩散模型的训练过程主要关注于去噪(denoising),而不是显式地学习表征,导致其在表征学习方面的表现相对较弱。

谢赛宁团队的研究表明,通过引入外部高质量的视觉表征,可以有效地弥补扩散模型在表征学习方面的不足。他们提出的REPA方法,通过在扩散模型的训练过程中,将噪声输入隐藏状态的投影与外部预训练视觉编码器获得的干净图像表征进行对齐,使得模型能够更好地学习到有用的表征。

REPA方法的核心思想是,通过在扩散模型的训练过程中引入一个额外的正则化项,来约束模型的表征学习过程。具体来说,REPA方法包括以下几个步骤:

  1. 获取外部视觉表征:首先,使用一个预训练的视觉编码器(如VGG、ResNet等),从干净的图像中提取高质量的视觉表征。
  2. 计算噪声输入隐藏状态的投影:在扩散模型的训练过程中,对于每个噪声输入,计算其在模型中的隐藏状态,并进行投影操作。
  3. 对齐投影与外部视觉表征:将噪声输入隐藏状态的投影与外部视觉表征进行对齐,通过最小化两者之间的差异,来约束模型的表征学习过程。
  4. 联合优化:将REPA正则化项与扩散模型的原始损失函数进行联合优化,以实现对模型表征学习和生成能力的共同提升。

谢赛宁团队在多个流行的扩散模型和基于流的变换器(如DiTs和SiTs)上进行了实验,结果表明REPA方法能够显著提高模型的训练效率和生成质量。

在训练效率方面,REPA方法能够将SiT模型的训练速度提高超过17.5倍。例如,使用REPA方法,可以在不到400K步的训练时间内,达到与未使用REPA方法的SiT-XL模型在7M步训练时间内相同的性能水平(不使用无分类器指导)。

在生成质量方面,REPA方法在使用无分类器指导的情况下,能够达到当前最先进的FID=1.42的生成质量。这表明REPA方法不仅能够提高模型的训练效率,还能够显著提升模型的生成能力。

REPA方法的提出,为扩散模型的表征学习提供了一种新的思路和方法。其简单而有效的正则化策略,不仅能够提高模型的训练效率,还能够显著提升模型的生成质量。这对于推动扩散模型在更广泛领域的应用具有重要意义。

然而,REPA方法也存在一些潜在的问题和挑战。首先,引入外部视觉表征可能会增加模型的计算复杂度和存储需求,这对于大规模模型的训练和部署可能是一个限制因素。其次,REPA方法的效果可能依赖于外部视觉编码器的质量和与扩散模型的匹配程度,这需要在实际应用中进行仔细的选择和调整。

此外,REPA方法主要关注于图像生成任务,对于其他类型的生成任务(如音频、视频等)的适用性还有待进一步研究。同时,REPA方法在处理复杂场景和多模态数据时的表现也需要进一步验证。

论文地址:https://arxiv.org/abs/2410.06940

目录
相关文章
|
3月前
|
物联网
StableDiffusion-04 (炼丹篇) 15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(一)
StableDiffusion-04 (炼丹篇) 15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(一)
39 0
|
3月前
|
物联网
StableDiffusion-04 (炼丹篇) 15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(二)
StableDiffusion-04 (炼丹篇) 15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(二)
41 0
|
3月前
|
物联网
StableDiffusion-03 (准备篇)15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(二)
StableDiffusion-03 (准备篇)15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(二)
40 1
|
3月前
|
并行计算 Ubuntu 物联网
StableDiffusion-03 (准备篇)15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(一)
StableDiffusion-03 (准备篇)15分钟 部署服务并进行LoRA微调全过程详细记录 不到20张百变小樱Sakura微调 3090(24GB) 学不会你打我!(一)
39 0
|
7月前
|
机器学习/深度学习 自然语言处理 算法
用神经架构搜索给LLM瘦身,模型变小,准确度有时反而更高
【6月更文挑战第20天】研究人员运用神经架构搜索(NAS)压缩LLM,如LLaMA2-7B,找到小而精准的子网,降低内存与计算成本,保持甚至提升性能。实验显示在多个任务上,模型大小减半,速度加快,精度不变或提升。NAS虽需大量计算资源,但结合量化技术,能有效优化大型语言模型。[论文链接](https://arxiv.org/pdf/2405.18377)**
70 3
|
7月前
|
算法 测试技术 异构计算
【SAM模型超级进化】MobileSAM轻量化的分割一切大模型出现,模型缩小60倍,速度提高40倍,效果不减
【SAM模型超级进化】MobileSAM轻量化的分割一切大模型出现,模型缩小60倍,速度提高40倍,效果不减
|
8月前
|
人工智能 安全 测试技术
Infection-2.5登场,训练计算量仅40%、性能直逼GPT-4!
【2月更文挑战第18天】Infection-2.5登场,训练计算量仅40%、性能直逼GPT-4!
77 3
Infection-2.5登场,训练计算量仅40%、性能直逼GPT-4!
|
机器学习/深度学习 人工智能 自然语言处理
超越Transformer,清华、字节大幅刷新并行文本生成SoTA性能|ICML 2022
超越Transformer,清华、字节大幅刷新并行文本生成SoTA性能|ICML 2022
172 0
超越Transformer,清华、字节大幅刷新并行文本生成SoTA性能|ICML 2022
|
机器学习/深度学习 存储 人工智能
ICLR 2023 Spotlight|节省95%训练开销,清华黄隆波团队提出强化学习专用稀疏训练框架RLx2
ICLR 2023 Spotlight|节省95%训练开销,清华黄隆波团队提出强化学习专用稀疏训练框架RLx2
204 0
|
机器学习/深度学习 存储 自然语言处理
微软提出MiniViT | 把DeiT压缩9倍,性能依旧超越ResNet等卷积网络(一)
微软提出MiniViT | 把DeiT压缩9倍,性能依旧超越ResNet等卷积网络(一)
277 0

热门文章

最新文章