Diff-Instruct:指导任意生成模型训练的通用框架,无需额外训练数据即可提升生成质量

本文涉及的产品
视觉智能开放平台,分割抠图1万点
视觉智能开放平台,视频通用资源包5000点
视觉智能开放平台,图像通用资源包5000点
简介: Diff-Instruct 是一种从预训练扩散模型中迁移知识的通用框架,通过最小化积分Kullback-Leibler散度,指导其他生成模型的训练,提升生成性能。

❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦

原文链接:https://mp.weixin.qq.com/s/faeBUXbDsc-ZhIxmdTWOcw


🚀 快速阅读

  1. 功能:Diff-Instruct 能从预训练扩散模型中提取知识,指导其他生成模型的训练。
  2. 原理:基于积分Kullback-Leibler散度,通过计算扩散过程中的KL散度积分来比较分布。
  3. 应用:适用于预训练扩散模型的蒸馏和改进现有GAN模型,提升生成性能。

正文(附运行示例)

Diff-Instruct 是什么

公众号: 蚝油菜花 - diff_instruct

Diff-Instruct 是一种先进的知识转移方法,专门用于从预训练的扩散模型中提取知识,并指导其他生成模型的训练。它基于一种新的散度度量——积分Kullback-Leibler (IKL) 散度,通过计算沿扩散过程的KL散度积分来比较分布。这种方法能够在不需要额外数据的情况下,通过最小化IKL散度,实现对任意生成模型的训练指导。

Diff-Instruct 的通用性和有效性在学术界受到广泛关注。它不仅可以显著提升生成模型的性能,还能在多种应用场景中发挥作用,如预训练扩散模型的蒸馏和改进现有的GAN模型。

Diff-Instruct 的主要功能

  • 知识转移:Diff-Instruct 能够从预训练的扩散模型中提取知识,并将其转移到其他生成模型中,无需额外数据。
  • 指导生成模型训练:作为一个通用框架,Diff-Instruct 可以指导任意生成模型的训练,只要生成的样本对模型参数是可微分的。
  • 最小化新型散度:Diff-Instruct 通过最小化积分Kullback-Leibler (IKL) 散度来实现知识转移,这种散度专为扩散模型设计,具有更高的鲁棒性。
  • 提升生成模型性能:Diff-Instruct 在多个实验中展示了其有效性,能够显著提升生成模型的性能,特别是在单步扩散模型和GAN模型的改进上。

Diff-Instruct 的技术原理

  • 通用框架:Diff-Instruct 提出了一个通用框架,可以指导任意生成模型的训练,只要生成的样本对模型参数是可微分的。
  • 积分Kullback-Leibler (IKL) 散度:Diff-Instruct 基于IKL散度,通过计算沿扩散过程的KL散度积分来比较分布,这种散度在比较具有不对齐支持的分布时更具鲁棒性。
  • 数据自由学习:Diff-Instruct 支持使用预训练的扩散模型作为教师来指导各种生成模型,无需额外数据。
  • 灵活性:Diff-Instruct 为生成器提供了非常高的灵活性,生成器可以是基于卷积神经网络(CNN)或基于Transformer的图像生成器,如StyleGAN,或者是从预训练扩散模型适应的基于UNet的生成器。

如何运行 Diff-Instruct

首先,克隆 Diff-Instruct 的 GitHub 仓库并设置 conda 环境:

git clone https://github.com/pkulwj1994/diff_instruct.git
cd diff_instruct

source activate
conda create -n di_v100 python=3.8
conda activate di_v100
pip install torch==1.12.1 torchvision==0.13.1 tqdm click psutil scipy

接下来,准备数据集并运行蒸馏过程。例如,对于 CIFAR-10 数据集的无条件生成,可以使用以下命令:

CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 --master_port=25678 di_train.py --outdir=/logs/di/ci10-uncond --data=/data/datasets/cifar10-32x32.zip --arch=ddpmpp --batch 128 --edm_model cifar10-uncond --cond=0 --metrics fid50k_full --tick 10 --snap 50 --lr 0.00001 --glr 0.00001 --init_sigma 1.0 --fp16=0 --lr_warmup_kimg -1 --ls 1.0 --sgls 1.0

在实验中,FID 值将自动计算并在每个“snap”轮次中显示。

资源


❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦

相关文章
|
3月前
|
数据采集 自动驾驶 Java
PAI-TurboX:面向自动驾驶的训练推理加速框架
PAI-TurboX 为自动驾驶场景中的复杂数据预处理、离线大规模模型训练和实时智能驾驶推理,提供了全方位的加速解决方案。PAI-Notebook Gallery 提供PAI-TurboX 一键启动的 Notebook 最佳实践
|
5月前
|
机器学习/深度学习 人工智能 JSON
【解决方案】DistilQwen2.5-R1蒸馏小模型在PAI-ModelGallery的训练、评测、压缩及部署实践
阿里云的人工智能平台 PAI,作为一站式的机器学习和深度学习平台,对DistilQwen2.5-R1模型系列提供了全面的技术支持。无论是开发者还是企业客户,都可以通过 PAI-ModelGallery 轻松实现 Qwen2.5 系列模型的训练、评测、压缩和快速部署。本文详细介绍在 PAI 平台使用 DistilQwen2.5-R1 蒸馏模型的全链路最佳实践。
|
4月前
|
人工智能 JSON 算法
【解决方案】DistilQwen2.5-DS3-0324蒸馏小模型在PAI-ModelGallery的训练、评测、压缩及部署实践
DistilQwen 系列是阿里云人工智能平台 PAI 推出的蒸馏语言模型系列,包括 DistilQwen2、DistilQwen2.5、DistilQwen2.5-R1 等。本文详细介绍DistilQwen2.5-DS3-0324蒸馏小模型在PAI-ModelGallery的训练、评测、压缩及部署实践。
|
机器学习/深度学习 人工智能 算法
Post-Training on PAI (3):PAI-ChatLearn,PAI 自研高性能强化学习框架
人工智能平台 PAI 推出了高性能一体化强化学习框架 PAI-Chatlearn,从框架层面解决强化学习在计算性能和易用性方面的挑战。
|
3月前
|
机器学习/深度学习 人工智能 分布式计算
Post-Training on PAI (1):一文览尽开源强化学习框架在PAI平台的应用
Post-Training(即模型后训练)作为大模型落地的重要一环,能显著优化模型性能,适配特定领域需求。相比于 Pre-Training(即模型预训练),Post-Training 阶段对计算资源和数据资源需求更小,更易迭代,因此备受推崇。近期,我们将体系化地分享基于阿里云人工智能平台 PAI 在强化学习、模型蒸馏、数据预处理、SFT等方向的技术实践,旨在清晰地展现 PAI 在 Post-Training 各个环节的产品能力和使用方法,欢迎大家随时交流探讨。
|
4月前
|
机器学习/深度学习 人工智能 算法
PaperCoder:一种利用大型语言模型自动生成机器学习论文代码的框架
PaperCoder是一种基于多智能体LLM框架的工具,可自动将机器学习研究论文转化为代码库。它通过规划、分析和生成三个阶段,系统性地实现从论文到代码的转化,解决当前研究中代码缺失导致的可复现性问题。实验表明,PaperCoder在自动生成高质量代码方面显著优于基线方法,并获得专家高度认可。这一工具降低了验证研究成果的门槛,推动科研透明与高效。
339 19
PaperCoder:一种利用大型语言模型自动生成机器学习论文代码的框架
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
阿里云人工智能平台 PAI 开源 EasyDistill 框架助力大语言模型轻松瘦身
本文介绍了阿里云人工智能平台 PAI 推出的开源工具包 EasyDistill。随着大语言模型的复杂性和规模增长,它们面临计算需求和训练成本的障碍。知识蒸馏旨在不显著降低性能的前提下,将大模型转化为更小、更高效的版本以降低训练和推理成本。EasyDistill 框架简化了知识蒸馏过程,其具备多种功能模块,包括数据合成、基础和进阶蒸馏训练。通过数据合成,丰富训练集的多样性;基础和进阶蒸馏训练则涵盖黑盒和白盒知识转移策略、强化学习及偏好优化,从而提升小模型的性能。
|
5月前
|
机器学习/深度学习 算法 数据挖掘
PyTabKit:比sklearn更强大的表格数据机器学习框架
PyTabKit是一个专为表格数据设计的新兴机器学习框架,集成了RealMLP等先进深度学习技术与优化的GBDT超参数配置。相比传统Scikit-Learn,PyTabKit通过元级调优的默认参数设置,在无需复杂超参调整的情况下,显著提升中大型数据集的性能表现。其简化API设计、高效训练速度和多模型集成能力,使其成为企业决策与竞赛建模的理想工具。
169 12
PyTabKit:比sklearn更强大的表格数据机器学习框架
|
6月前
|
人工智能 自然语言处理 算法
MT-MegatronLM:国产训练框架逆袭!三合一并行+FP8黑科技,大模型训练效率暴涨200%
MT-MegatronLM 是摩尔线程推出的面向全功能 GPU 的开源混合并行训练框架,支持多种模型架构和高效混合并行训练,显著提升 GPU 集群的算力利用率。
424 18
|
6月前
|
机器学习/深度学习 人工智能 Java
Java机器学习实战:基于DJL框架的手写数字识别全解析
在人工智能蓬勃发展的今天,Python凭借丰富的生态库(如TensorFlow、PyTorch)成为AI开发的首选语言。但Java作为企业级应用的基石,其在生产环境部署、性能优化和工程化方面的优势不容忽视。DJL(Deep Java Library)的出现完美填补了Java在深度学习领域的空白,它提供了一套统一的API,允许开发者无缝对接主流深度学习框架,将AI模型高效部署到Java生态中。本文将通过手写数字识别的完整流程,深入解析DJL框架的核心机制与应用实践。
349 3

热门文章

最新文章