PyTorch 零基础入门 GAN 模型之 cGAN

简介: 最早提出 cGAN 的是论文 《Conditional Generative Adversarial Nets》,为了达到条件生成的目的,我们在输入给生成器网络 G 的噪声 z 上 concat 一个标签向量 y, 告诉生成网络生成标签所指定的数据。对于输入给判别器 D 的数据,也 concat 这样的一个标签,告诉判别网络判断输入是否为真实的该类别数据。


背景介绍



在之前的文章中,我们介绍了 GAN 的原理以及如何评价训练好的模型。可能有小伙伴看到,怎么生成的都是单一类别的图片呢,像 CIFAR10 和 ImageNet,都包含了多种类别的图片,如果我想训练一个能够生成多种类别图片的生成对抗网络该怎么做呢?


那么称为  conditional GANs (cGAN) 的模型就可以派上用场了。


cGAN 的发展



如何引入类别信息


最早提出 cGAN 的是论文 《Conditional Generative Adversarial Nets》,为了达到条件生成的目的,我们在输入给生成器网络 G 的噪声 z 上 concat 一个标签向量 y, 告诉生成网络生成标签所指定的数据。对于输入给判别器 D 的数据,也 concat 这样的一个标签,告诉判别网络判断输入是否为真实的该类别数据。

640.png

那么,cGAN 的目标函数可以表述成如下形式:

640.png

cGAN 采用 MLP 作为网络结构,一维的输入可以方便地和标签向量或者标签嵌入 concat,但是对于图像生成任务主流的 CNN 模型,无法直接采用这种引入方式,特别是对于判别器网络。


Projection GAN 通过推导发现,假设image.png其中  image.png为判别器网络,  image.png为激活函数, image.png为网络模型,在上面的目标函数下,最优的image.png image.png

根据上式,输入图像先经过网络 image.png 抽取一维特征,然后分两路,一路通过网络  image.png输出关于图像真假的判别结果,一路与经过编码的类别标签做点乘得到关于类别的判别结果,之后将两路结果相加就可以得到最终的判别结果。模型结构如下:

640.png


如何稳定训练过程


GAN 的训练过程不稳定,而 cGAN 中学习多种类别的数据,稳定训练过程则更具有挑战性。在之前的研究如 WGAN,WGANGP 中,学者们发现对网络施加 Lipschitz 约束能有效稳定 GAN 的训练过程。


相比于之前在损失函数中增加正则项的做法,SNGAN 提出了谱归一化 (spectral normalization) 来构造网络模型,使得无论网络参数是什么,都能满足 Lipschitz 约束。


如何提升生成质量


在多类别数据如 ImageNet 的训练过程中,人们发现网络更加擅长生成局部的细节纹理,如狗的毛发;而对于几何特征和整体结构,生成效果往往不尽如人意。如下图 SNGAN 生成结果为例,狗的身体结构存在很多错误和不完整的表达。

640.png


SAGAN 认为这是由于卷积模型难以捕捉到距离较远的特征,因此引入注意力机制,设计了如下的注意力模块。

640.png

这里卷积特征图经过三个 1x1 卷积 f(x), g(x), h(x),将 f(x) 的输出转置,并和 g(x) 的输出相乘,再经过 softmax 归一化得到一个 attention map ,将得到的 attention map 和 h(x) 逐像素点相乘,再经过卷积 v(x) 得到自适应注意力的特征图。


SAGAN 在生成器和判别器中引入了这个注意力模块,使得生成器可以建模图像跨区域的依赖关系,判别器可以对全局图像结构施加几何约束。


集大成者:BigGAN


在采用上面提到的投影判别,谱归一化,自注意模块的基础上,《LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS》(即 BigGAN )通过增大 batch size,提升模型宽度,显著提升了图像生成的结果。(见下图,我们在上篇文章中提到了 FID 和 IS 两个评价生成模型的 metrics,可以看到,通过提升 batch size 和 通道数,这两个 metric 有显著提升。

640.png

为了让 noise 能够直接影响不同分辨率上的特征,BigGAN 提出了 skip-z ,将 noise 分割后,分别传入 G 网络不同尺度的 layer。为了节省时间和内存,BigGAN 只对 class 做一次 embedding,将编码结果传给每个条件批归一化。以下为一个 BigGAN 的 G 网络结构图。

640.png

其中每一个 ResBlock 结构如下。分组的噪声和类别嵌入 concat 后,通过线性层得到每个 BatchNorm 的 gain 和 bias 参数,通过这种方式引入类别信息。

640.png


使用 MMGeneration 上手 BigGAN


在我们的文章 PyTorch 零基础入门 GAN 模型之基础篇 中,我们介绍了如何安装 MMGen 和训练模型。在此基础上,我们可以上手以 BigGAN 为代表的条件生成模型。


我们可以先看看 BigGAN 生成的图片长啥样,通过运行如下代码,我们可以从预训练好的  BigGAN 中 sample 类别随机的图片。

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth

640.png

当然,我们也可以用 --label 来指定采样的类别,--samples-per-classes 来指定每类采样的数量。


比如运行下面代码:

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --label 151 285 292 --samples-per-classes 5

可以得到 狗(151),猫(285),老虎(292)各五张图片。

640.png

我们还可以通过运行下面的代码看看 BigGAN 分别从噪声空间和标签空间插值的结果。

python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-z # 固定噪声
python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-y # 固定标签

结果分别如下:

640.png

640.png

现在我们可以看看怎么训练 BigGAN,首先我们需要下载 ImageNet,然后放到 ./data 文件夹下。

BigGAN 的训练有几个关键点设置,我们以 configs/_base_/models/biggan/biggan_128x128.py 和configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 为例进行说明。


首先是 spectral normalization 的实现方式,我们提供了两种实现方式,一种是 PyTorch 官方提供的实现,另一种是 BigGAN 作者 ajbrock 提供的实现。我们可以设置 sn_style 为 torch 或者 ajbrock 来选择,如果不设置,默认为 ajbrock。

# 使用 BigGAN 作者提供的 SN 实现
model = dict(
    type='BasiccGAN',
    generator=dict(xxx, sn_style='ajbrock'),
    discriminator=dict(xxx, sn_style='ajbrock'),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

在生成模型中,有时希望用一个更强的 D 来引导 G 的更新。常用的一种设置为训练若干步判别器,再训练一步生成器 ,这里 BigGAN 通过设置 train_cfg 中的 disc_steps 和 gen_steps 来实现。

train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)

这个 config 表示训练 8 步判别器,再训练一步生成器。


在上面代码中,字段 batch_accumulation_steps 涉及到梯度累积操作,因为显存限制,我们很难直接在 batchsize 为 2048 的数据上做 forward 和 backward,因此可以将多个小批量上的梯度平均化,只在 batch_accumulation_steps 次累积后进行优化。假设我们在 8 卡上训练,每张卡 batchsize 为 32,batch_accumulation_steps = 8,这样可以逼近 8*8*32 = 2048 的 batchsize。


为了稳定 GAN 的训练过程,我们往往要使用一种叫指数移动平均的技巧。通过设置 generator 网络的备份 generator_ema 。在每个 train iter 后,将更新的模型参数和历史参数加权平均后作为generator_ema 的参数,这样 generator_ema 的参数更新会比 generator 更加平滑,作为训练结束后的 inference 模型,其生成结果更好。


为了使用 ema,首先需要在上文 train_cfg 中将 use_ema 设置为 True,同时,需要在 config 中添加一个 ExponentialMovingAverageHook。

custom_hooks = [
    xxx,
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=8,
        start_iter=160000,
        interp_cfg=dict(momentum=0.9999, momentum_nontrainable=0.9999),
        priority='VERY_HIGH')
]

这里 interval 为 generator_ema 参数更新频率,start_inter 为 ema 开始 iteration,在此之前网络参数照搬 generator。interp_cfg 中 momentum 为 parameters 更新的权重,momentum_nontrainable 为 buffers 更新的权重。


对于模型训练的其他细节和具体实现,大家可以参考 MMGeneration 中的代码(当然我们也有可能再出一期详解~)


上面的设置,我们已经在 config 中为大家写好了,只需要运行如下代码,就可以开始训练自己的 BigGAN 模型了!

bash tools/dist_train.sh configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 8 --work-dir ./work_dirs/biggan

训练过程中,可以查看 work_dirs/biggan/training_samples 下的不同阶段模型生成图片。前四行为 generator_ema 生成图片,后四行为 generator 同输入下生成的图片。

640.png

其实,上面的条件采样,条件插值,模型训练对于 MMGeneration 中已支持的 SNGAN,SAGAN 也是同样适用的,欢迎大家随时使用并提出意见~

文章来源:【OpenMMLab

 2022-03-30 18:10

目录
相关文章
|
2天前
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于AlexNet的预训练模型介绍
【PyTorch实战演练】基于AlexNet的预训练模型介绍
94 0
|
2天前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
124 4
|
2天前
|
PyTorch 算法框架/工具 异构计算
pytorch 模型保存与加载
pytorch 模型保存与加载
5 0
|
2天前
|
PyTorch 算法框架/工具 Python
【pytorch框架】对模型知识的基本了解
【pytorch框架】对模型知识的基本了解
|
2天前
|
机器学习/深度学习 算法 PyTorch
PyTorch模型优化与调优:正则化、批归一化等技巧
【4月更文挑战第18天】本文探讨了PyTorch中提升模型性能的优化技巧,包括正则化(L1/L2正则化、Dropout)、批归一化、学习率调整策略和模型架构优化。正则化防止过拟合,Dropout提高泛化能力;批归一化加速训练并提升性能;学习率调整策略动态优化训练效果;模型架构优化涉及网络结构和参数的调整。这些方法有助于实现更高效的深度学习模型。
|
2天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
2天前
|
机器学习/深度学习 PyTorch 调度
PyTorch进阶:模型保存与加载,以及断点续训技巧
【4月更文挑战第17天】本文介绍了PyTorch中模型的保存与加载,以及断点续训技巧。使用`torch.save`和`torch.load`可保存和加载模型权重和状态字典。保存模型时,可选择仅保存轻量级的状态字典或整个模型对象。加载时,需确保模型结构与保存时一致。断点续训需保存训练状态,包括epoch、batch index、optimizer和scheduler状态。中断后,加载这些状态以恢复训练,节省时间和资源。
|
2天前
|
机器学习/深度学习 数据采集 PyTorch
构建你的第一个PyTorch神经网络模型
【4月更文挑战第17天】本文介绍了如何使用PyTorch构建和训练第一个神经网络模型。首先,准备数据集,如MNIST。接着,自定义神经网络模型`SimpleNet`,包含两个全连接层和ReLU激活函数。然后,定义交叉熵损失函数和SGD优化器。训练模型涉及多次迭代,计算损失、反向传播和参数更新。最后,测试模型性能,计算测试集上的准确率。这是一个基础的深度学习入门示例,为进一步探索复杂项目打下基础。
|
2天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Python中用PyTorch机器学习神经网络分类预测银行客户流失模型
Python中用PyTorch机器学习神经网络分类预测银行客户流失模型
|
2天前
|
PyTorch 算法框架/工具 Python
Pytorch构建网络模型时super(__class__, self).__init__()的作用
Pytorch构建网络模型时super(__class__, self).__init__()的作用
12 0