背景介绍
在之前的文章中,我们介绍了 GAN 的原理以及如何评价训练好的模型。可能有小伙伴看到,怎么生成的都是单一类别的图片呢,像 CIFAR10 和 ImageNet,都包含了多种类别的图片,如果我想训练一个能够生成多种类别图片的生成对抗网络该怎么做呢?
那么称为 conditional GANs (cGAN) 的模型就可以派上用场了。
cGAN 的发展
如何引入类别信息
最早提出 cGAN 的是论文 《Conditional Generative Adversarial Nets》,为了达到条件生成的目的,我们在输入给生成器网络 G 的噪声 z 上 concat 一个标签向量 y, 告诉生成网络生成标签所指定的数据。对于输入给判别器 D 的数据,也 concat 这样的一个标签,告诉判别网络判断输入是否为真实的该类别数据。
那么,cGAN 的目标函数可以表述成如下形式:
cGAN 采用 MLP 作为网络结构,一维的输入可以方便地和标签向量或者标签嵌入 concat,但是对于图像生成任务主流的 CNN 模型,无法直接采用这种引入方式,特别是对于判别器网络。
Projection GAN 通过推导发现,假设,其中 为判别器网络, 为激活函数, 为网络模型,在上面的目标函数下,最优的 为
根据上式,输入图像先经过网络 抽取一维特征,然后分两路,一路通过网络 输出关于图像真假的判别结果,一路与经过编码的类别标签做点乘得到关于类别的判别结果,之后将两路结果相加就可以得到最终的判别结果。模型结构如下:
如何稳定训练过程
GAN 的训练过程不稳定,而 cGAN 中学习多种类别的数据,稳定训练过程则更具有挑战性。在之前的研究如 WGAN,WGANGP 中,学者们发现对网络施加 Lipschitz 约束能有效稳定 GAN 的训练过程。
相比于之前在损失函数中增加正则项的做法,SNGAN 提出了谱归一化 (spectral normalization) 来构造网络模型,使得无论网络参数是什么,都能满足 Lipschitz 约束。
如何提升生成质量
在多类别数据如 ImageNet 的训练过程中,人们发现网络更加擅长生成局部的细节纹理,如狗的毛发;而对于几何特征和整体结构,生成效果往往不尽如人意。如下图 SNGAN 生成结果为例,狗的身体结构存在很多错误和不完整的表达。
SAGAN 认为这是由于卷积模型难以捕捉到距离较远的特征,因此引入注意力机制,设计了如下的注意力模块。
这里卷积特征图经过三个 1x1 卷积 f(x), g(x), h(x),将 f(x) 的输出转置,并和 g(x) 的输出相乘,再经过 softmax 归一化得到一个 attention map ,将得到的 attention map 和 h(x) 逐像素点相乘,再经过卷积 v(x) 得到自适应注意力的特征图。
SAGAN 在生成器和判别器中引入了这个注意力模块,使得生成器可以建模图像跨区域的依赖关系,判别器可以对全局图像结构施加几何约束。
集大成者:BigGAN
在采用上面提到的投影判别,谱归一化,自注意模块的基础上,《LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS》(即 BigGAN )通过增大 batch size,提升模型宽度,显著提升了图像生成的结果。(见下图,我们在上篇文章中提到了 FID 和 IS 两个评价生成模型的 metrics,可以看到,通过提升 batch size 和 通道数,这两个 metric 有显著提升。)
为了让 noise 能够直接影响不同分辨率上的特征,BigGAN 提出了 skip-z ,将 noise 分割后,分别传入 G 网络不同尺度的 layer。为了节省时间和内存,BigGAN 只对 class 做一次 embedding,将编码结果传给每个条件批归一化。以下为一个 BigGAN 的 G 网络结构图。
其中每一个 ResBlock 结构如下。分组的噪声和类别嵌入 concat 后,通过线性层得到每个 BatchNorm 的 gain 和 bias 参数,通过这种方式引入类别信息。
使用 MMGeneration 上手 BigGAN
在我们的文章 PyTorch 零基础入门 GAN 模型之基础篇 中,我们介绍了如何安装 MMGen 和训练模型。在此基础上,我们可以上手以 BigGAN 为代表的条件生成模型。
我们可以先看看 BigGAN 生成的图片长啥样,通过运行如下代码,我们可以从预训练好的 BigGAN 中 sample 类别随机的图片。
python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth
当然,我们也可以用 --label 来指定采样的类别,--samples-per-classes 来指定每类采样的数量。
比如运行下面代码:
python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --label 151 285 292 --samples-per-classes 5
可以得到 狗(151),猫(285),老虎(292)各五张图片。
我们还可以通过运行下面的代码看看 BigGAN 分别从噪声空间和标签空间插值的结果。
python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-z # 固定噪声 python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-y # 固定标签
结果分别如下:
现在我们可以看看怎么训练 BigGAN,首先我们需要下载 ImageNet,然后放到 ./data 文件夹下。
BigGAN 的训练有几个关键点设置,我们以 configs/_base_/models/biggan/biggan_128x128.py 和configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 为例进行说明。
首先是 spectral normalization 的实现方式,我们提供了两种实现方式,一种是 PyTorch 官方提供的实现,另一种是 BigGAN 作者 ajbrock 提供的实现。我们可以设置 sn_style 为 torch 或者 ajbrock 来选择,如果不设置,默认为 ajbrock。
# 使用 BigGAN 作者提供的 SN 实现 model = dict( type='BasiccGAN', generator=dict(xxx, sn_style='ajbrock'), discriminator=dict(xxx, sn_style='ajbrock'), gan_loss=dict(type='GANLoss', gan_type='hinge'))
在生成模型中,有时希望用一个更强的 D 来引导 G 的更新。常用的一种设置为训练若干步判别器,再训练一步生成器 ,这里 BigGAN 通过设置 train_cfg 中的 disc_steps 和 gen_steps 来实现。
train_cfg = dict( disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
这个 config 表示训练 8 步判别器,再训练一步生成器。
在上面代码中,字段 batch_accumulation_steps 涉及到梯度累积操作,因为显存限制,我们很难直接在 batchsize 为 2048 的数据上做 forward 和 backward,因此可以将多个小批量上的梯度平均化,只在 batch_accumulation_steps 次累积后进行优化。假设我们在 8 卡上训练,每张卡 batchsize 为 32,batch_accumulation_steps = 8,这样可以逼近 8*8*32 = 2048 的 batchsize。
为了稳定 GAN 的训练过程,我们往往要使用一种叫指数移动平均的技巧。通过设置 generator 网络的备份 generator_ema 。在每个 train iter 后,将更新的模型参数和历史参数加权平均后作为generator_ema 的参数,这样 generator_ema 的参数更新会比 generator 更加平滑,作为训练结束后的 inference 模型,其生成结果更好。
为了使用 ema,首先需要在上文 train_cfg 中将 use_ema 设置为 True,同时,需要在 config 中添加一个 ExponentialMovingAverageHook。
custom_hooks = [ xxx, dict( type='ExponentialMovingAverageHook', module_keys=('generator_ema', ), interval=8, start_iter=160000, interp_cfg=dict(momentum=0.9999, momentum_nontrainable=0.9999), priority='VERY_HIGH') ]
这里 interval 为 generator_ema 参数更新频率,start_inter 为 ema 开始 iteration,在此之前网络参数照搬 generator。interp_cfg 中 momentum 为 parameters 更新的权重,momentum_nontrainable 为 buffers 更新的权重。
对于模型训练的其他细节和具体实现,大家可以参考 MMGeneration 中的代码(当然我们也有可能再出一期详解~)
上面的设置,我们已经在 config 中为大家写好了,只需要运行如下代码,就可以开始训练自己的 BigGAN 模型了!
bash tools/dist_train.sh configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 8 --work-dir ./work_dirs/biggan
训练过程中,可以查看 work_dirs/biggan/training_samples 下的不同阶段模型生成图片。前四行为 generator_ema 生成图片,后四行为 generator 同输入下生成的图片。
其实,上面的条件采样,条件插值,模型训练对于 MMGeneration 中已支持的 SNGAN,SAGAN 也是同样适用的,欢迎大家随时使用并提出意见~
文章来源:【OpenMMLab】
2022-03-30 18:10