PyTorch 零基础入门 GAN 模型之 cGAN

简介: 最早提出 cGAN 的是论文 《Conditional Generative Adversarial Nets》,为了达到条件生成的目的,我们在输入给生成器网络 G 的噪声 z 上 concat 一个标签向量 y, 告诉生成网络生成标签所指定的数据。对于输入给判别器 D 的数据,也 concat 这样的一个标签,告诉判别网络判断输入是否为真实的该类别数据。


背景介绍



在之前的文章中,我们介绍了 GAN 的原理以及如何评价训练好的模型。可能有小伙伴看到,怎么生成的都是单一类别的图片呢,像 CIFAR10 和 ImageNet,都包含了多种类别的图片,如果我想训练一个能够生成多种类别图片的生成对抗网络该怎么做呢?


那么称为  conditional GANs (cGAN) 的模型就可以派上用场了。


cGAN 的发展



如何引入类别信息


最早提出 cGAN 的是论文 《Conditional Generative Adversarial Nets》,为了达到条件生成的目的,我们在输入给生成器网络 G 的噪声 z 上 concat 一个标签向量 y, 告诉生成网络生成标签所指定的数据。对于输入给判别器 D 的数据,也 concat 这样的一个标签,告诉判别网络判断输入是否为真实的该类别数据。

640.png

那么,cGAN 的目标函数可以表述成如下形式:

640.png

cGAN 采用 MLP 作为网络结构,一维的输入可以方便地和标签向量或者标签嵌入 concat,但是对于图像生成任务主流的 CNN 模型,无法直接采用这种引入方式,特别是对于判别器网络。


Projection GAN 通过推导发现,假设image.png其中  image.png为判别器网络,  image.png为激活函数, image.png为网络模型,在上面的目标函数下,最优的image.png image.png

根据上式,输入图像先经过网络 image.png 抽取一维特征,然后分两路,一路通过网络  image.png输出关于图像真假的判别结果,一路与经过编码的类别标签做点乘得到关于类别的判别结果,之后将两路结果相加就可以得到最终的判别结果。模型结构如下:

640.png


如何稳定训练过程


GAN 的训练过程不稳定,而 cGAN 中学习多种类别的数据,稳定训练过程则更具有挑战性。在之前的研究如 WGAN,WGANGP 中,学者们发现对网络施加 Lipschitz 约束能有效稳定 GAN 的训练过程。


相比于之前在损失函数中增加正则项的做法,SNGAN 提出了谱归一化 (spectral normalization) 来构造网络模型,使得无论网络参数是什么,都能满足 Lipschitz 约束。


如何提升生成质量


在多类别数据如 ImageNet 的训练过程中,人们发现网络更加擅长生成局部的细节纹理,如狗的毛发;而对于几何特征和整体结构,生成效果往往不尽如人意。如下图 SNGAN 生成结果为例,狗的身体结构存在很多错误和不完整的表达。

640.png


SAGAN 认为这是由于卷积模型难以捕捉到距离较远的特征,因此引入注意力机制,设计了如下的注意力模块。

640.png

这里卷积特征图经过三个 1x1 卷积 f(x), g(x), h(x),将 f(x) 的输出转置,并和 g(x) 的输出相乘,再经过 softmax 归一化得到一个 attention map ,将得到的 attention map 和 h(x) 逐像素点相乘,再经过卷积 v(x) 得到自适应注意力的特征图。


SAGAN 在生成器和判别器中引入了这个注意力模块,使得生成器可以建模图像跨区域的依赖关系,判别器可以对全局图像结构施加几何约束。


集大成者:BigGAN


在采用上面提到的投影判别,谱归一化,自注意模块的基础上,《LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS》(即 BigGAN )通过增大 batch size,提升模型宽度,显著提升了图像生成的结果。(见下图,我们在上篇文章中提到了 FID 和 IS 两个评价生成模型的 metrics,可以看到,通过提升 batch size 和 通道数,这两个 metric 有显著提升。

640.png

为了让 noise 能够直接影响不同分辨率上的特征,BigGAN 提出了 skip-z ,将 noise 分割后,分别传入 G 网络不同尺度的 layer。为了节省时间和内存,BigGAN 只对 class 做一次 embedding,将编码结果传给每个条件批归一化。以下为一个 BigGAN 的 G 网络结构图。

640.png

其中每一个 ResBlock 结构如下。分组的噪声和类别嵌入 concat 后,通过线性层得到每个 BatchNorm 的 gain 和 bias 参数,通过这种方式引入类别信息。

640.png


使用 MMGeneration 上手 BigGAN


在我们的文章 PyTorch 零基础入门 GAN 模型之基础篇 中,我们介绍了如何安装 MMGen 和训练模型。在此基础上,我们可以上手以 BigGAN 为代表的条件生成模型。


我们可以先看看 BigGAN 生成的图片长啥样,通过运行如下代码,我们可以从预训练好的  BigGAN 中 sample 类别随机的图片。

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth

640.png

当然,我们也可以用 --label 来指定采样的类别,--samples-per-classes 来指定每类采样的数量。


比如运行下面代码:

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --label 151 285 292 --samples-per-classes 5

可以得到 狗(151),猫(285),老虎(292)各五张图片。

640.png

我们还可以通过运行下面的代码看看 BigGAN 分别从噪声空间和标签空间插值的结果。

python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-z # 固定噪声
python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-y # 固定标签

结果分别如下:

640.png

640.png

现在我们可以看看怎么训练 BigGAN,首先我们需要下载 ImageNet,然后放到 ./data 文件夹下。

BigGAN 的训练有几个关键点设置,我们以 configs/_base_/models/biggan/biggan_128x128.py 和configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 为例进行说明。


首先是 spectral normalization 的实现方式,我们提供了两种实现方式,一种是 PyTorch 官方提供的实现,另一种是 BigGAN 作者 ajbrock 提供的实现。我们可以设置 sn_style 为 torch 或者 ajbrock 来选择,如果不设置,默认为 ajbrock。

# 使用 BigGAN 作者提供的 SN 实现
model = dict(
    type='BasiccGAN',
    generator=dict(xxx, sn_style='ajbrock'),
    discriminator=dict(xxx, sn_style='ajbrock'),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

在生成模型中,有时希望用一个更强的 D 来引导 G 的更新。常用的一种设置为训练若干步判别器,再训练一步生成器 ,这里 BigGAN 通过设置 train_cfg 中的 disc_steps 和 gen_steps 来实现。

train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)

这个 config 表示训练 8 步判别器,再训练一步生成器。


在上面代码中,字段 batch_accumulation_steps 涉及到梯度累积操作,因为显存限制,我们很难直接在 batchsize 为 2048 的数据上做 forward 和 backward,因此可以将多个小批量上的梯度平均化,只在 batch_accumulation_steps 次累积后进行优化。假设我们在 8 卡上训练,每张卡 batchsize 为 32,batch_accumulation_steps = 8,这样可以逼近 8*8*32 = 2048 的 batchsize。


为了稳定 GAN 的训练过程,我们往往要使用一种叫指数移动平均的技巧。通过设置 generator 网络的备份 generator_ema 。在每个 train iter 后,将更新的模型参数和历史参数加权平均后作为generator_ema 的参数,这样 generator_ema 的参数更新会比 generator 更加平滑,作为训练结束后的 inference 模型,其生成结果更好。


为了使用 ema,首先需要在上文 train_cfg 中将 use_ema 设置为 True,同时,需要在 config 中添加一个 ExponentialMovingAverageHook。

custom_hooks = [
    xxx,
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=8,
        start_iter=160000,
        interp_cfg=dict(momentum=0.9999, momentum_nontrainable=0.9999),
        priority='VERY_HIGH')
]

这里 interval 为 generator_ema 参数更新频率,start_inter 为 ema 开始 iteration,在此之前网络参数照搬 generator。interp_cfg 中 momentum 为 parameters 更新的权重,momentum_nontrainable 为 buffers 更新的权重。


对于模型训练的其他细节和具体实现,大家可以参考 MMGeneration 中的代码(当然我们也有可能再出一期详解~)


上面的设置,我们已经在 config 中为大家写好了,只需要运行如下代码,就可以开始训练自己的 BigGAN 模型了!

bash tools/dist_train.sh configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 8 --work-dir ./work_dirs/biggan

训练过程中,可以查看 work_dirs/biggan/training_samples 下的不同阶段模型生成图片。前四行为 generator_ema 生成图片,后四行为 generator 同输入下生成的图片。

640.png

其实,上面的条件采样,条件插值,模型训练对于 MMGeneration 中已支持的 SNGAN,SAGAN 也是同样适用的,欢迎大家随时使用并提出意见~

文章来源:【OpenMMLab

 2022-03-30 18:10

目录
相关文章
|
2月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
183 1
|
6月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
956 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
1月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
2月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
145 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
3月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
245 9
|
5月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
252 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
4月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1803 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
174 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
|
6月前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
737 17
|
6月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。

热门文章

最新文章

推荐镜像

更多